more caching in composition pipeline
authorblanchet
Tue, 04 Mar 2014 18:57:17 +0100
changeset 55904 0ef30d52c5e4
parent 55903 461a75d00323
child 55905 91d5085ad928
more caching in composition pipeline
src/HOL/Tools/BNF/bnf_comp.ML
src/HOL/Tools/BNF/bnf_fp_util.ML
--- a/src/HOL/Tools/BNF/bnf_comp.ML	Tue Mar 04 18:57:17 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_comp.ML	Tue Mar 04 18:57:17 2014 +0100
@@ -23,12 +23,12 @@
 
   val bnf_of_typ: BNF_Def.inline_policy -> (binding -> binding) ->
     ((string * sort) list list -> (string * sort) list) -> (string * sort) list ->
-    (string * sort) list -> typ -> (comp_cache * unfold_set) * Proof.context ->
-    (BNF_Def.bnf * (typ list * typ list)) * ((comp_cache * unfold_set) * Proof.context)
+    (string * sort) list -> typ -> (comp_cache * unfold_set) * local_theory ->
+    (BNF_Def.bnf * (typ list * typ list)) * ((comp_cache * unfold_set) * local_theory)
   val default_comp_sort: (string * sort) list list -> (string * sort) list
   val normalize_bnfs: (int -> binding -> binding) -> ''a list list -> ''a list ->
-    (''a list list -> ''a list) -> BNF_Def.bnf list -> unfold_set -> Proof.context ->
-    (int list list * ''a list) * (BNF_Def.bnf list * (unfold_set * Proof.context))
+    (''a list list -> ''a list) -> BNF_Def.bnf list -> (comp_cache * unfold_set) * local_theory ->
+    (int list list * ''a list) * (BNF_Def.bnf list * ((comp_cache * unfold_set) * local_theory))
 
   type absT_info =
     {absT: typ,
@@ -45,7 +45,7 @@
   val mk_abs: typ -> term -> term
   val mk_rep: typ -> term -> term
   val seal_bnf: (binding -> binding) -> unfold_set -> binding -> typ list -> BNF_Def.bnf ->
-    Proof.context -> (BNF_Def.bnf * (typ list * absT_info)) * local_theory
+    local_theory -> (BNF_Def.bnf * (typ list * absT_info)) * local_theory
 end;
 
 structure BNF_Comp : BNF_COMP =
@@ -61,6 +61,25 @@
 
 type comp_cache = (bnf * (typ list * typ list)) Typtab.table;
 
+fun key_of_types s Ts = Type (s, Ts);
+fun key_of_typess s = key_of_types s o map (key_of_types "");
+fun typ_of_int n = Type (string_of_int n, []);
+fun typ_of_bnf bnf =
+  key_of_typess "" [[T_of_bnf bnf], lives_of_bnf bnf, sort Term_Ord.typ_ord (deads_of_bnf bnf)];
+
+fun key_of_kill n bnf = key_of_types "k" [typ_of_int n, typ_of_bnf bnf];
+fun key_of_lift n bnf = key_of_types "l" [typ_of_int n, typ_of_bnf bnf];
+fun key_of_permute src dest bnf =
+  key_of_types "p" (map typ_of_int src @ map typ_of_int dest @ [typ_of_bnf bnf]);
+fun key_of_compose oDs Dss Ass outer inners =
+  key_of_types "c" (map (key_of_typess "") [[oDs], Dss, Ass, [map typ_of_bnf (outer :: inners)]]);
+
+fun cache_comp_simple key cache (bnf, (unfold_set, lthy)) =
+  (bnf, ((Typtab.update (key, (bnf, ([], []))) cache, unfold_set), lthy));
+
+fun cache_comp key (bnf_Ds_As, ((cache, unfold_set), lthy)) =
+  (bnf_Ds_As, ((Typtab.update (key, bnf_Ds_As) cache, unfold_set), lthy));
+
 (* TODO: Replace by "BNF_Defs.defs list"? *)
 type unfold_set = {
   map_unfolds: thm list,
@@ -118,7 +137,7 @@
     val (oDs, lthy1) = apfst (map TFree)
       (Variable.invent_types (replicate odead HOLogic.typeS) lthy);
     val (Dss, lthy2) = apfst (map (map TFree))
-        (fold_map Variable.invent_types (map (fn n => replicate n HOLogic.typeS) ideads) lthy1);
+      (fold_map Variable.invent_types (map (fn n => replicate n HOLogic.typeS) ideads) lthy1);
     val (Ass, lthy3) = apfst (replicate ilive o map TFree)
       (Variable.invent_types (replicate ilive HOLogic.typeS) lthy2);
     val As = if ilive > 0 then hd Ass else [];
@@ -305,7 +324,8 @@
 
 (* Killing live variables *)
 
-fun kill_bnf qualify n bnf (unfold_set, lthy) = if n = 0 then (bnf, (unfold_set, lthy)) else
+fun raw_kill_bnf qualify n bnf (accum as (unfold_set, lthy)) =
+  if n = 0 then (bnf, accum) else
   let
     val b = Binding.suffix_name (mk_killN n) (name_of_bnf bnf);
     val live = live_of_bnf bnf;
@@ -394,9 +414,17 @@
     (bnf', (add_bnf_to_unfolds bnf' unfold_set, lthy'))
   end;
 
+fun kill_bnf qualify n bnf (accum as ((cache, unfold_set), lthy)) =
+  let val key = key_of_kill n bnf in
+    (case Typtab.lookup cache key of
+      SOME (bnf, _) => (bnf, accum)
+    | NONE => cache_comp_simple key cache (raw_kill_bnf qualify n bnf (unfold_set, lthy)))
+  end;
+
 (* Adding dummy live variables *)
 
-fun lift_bnf qualify n bnf (unfold_set, lthy) = if n = 0 then (bnf, (unfold_set, lthy)) else
+fun raw_lift_bnf qualify n bnf (accum as (unfold_set, lthy)) =
+  if n = 0 then (bnf, accum) else
   let
     val b = Binding.suffix_name (mk_liftN n) (name_of_bnf bnf);
     val live = live_of_bnf bnf;
@@ -476,10 +504,17 @@
     (bnf', (add_bnf_to_unfolds bnf' unfold_set, lthy'))
   end;
 
+fun lift_bnf qualify n bnf (accum as ((cache, unfold_set), lthy)) =
+  let val key = key_of_lift n bnf in
+    (case Typtab.lookup cache key of
+      SOME (bnf, _) => (bnf, accum)
+    | NONE => cache_comp_simple key cache (raw_lift_bnf qualify n bnf (unfold_set, lthy)))
+  end;
+
 (* Changing the order of live variables *)
 
-fun permute_bnf qualify src dest bnf (unfold_set, lthy) =
-  if src = dest then (bnf, (unfold_set, lthy)) else
+fun raw_permute_bnf qualify src dest bnf (accum as (unfold_set, lthy)) =
+  if src = dest then (bnf, accum) else
   let
     val b = Binding.suffix_name (mk_permuteN src dest) (name_of_bnf bnf);
     val live = live_of_bnf bnf;
@@ -550,6 +585,13 @@
     (bnf', (add_bnf_to_unfolds bnf' unfold_set, lthy'))
   end;
 
+fun permute_bnf qualify src dest bnf (accum as ((cache, unfold_set), lthy)) =
+  let val key = key_of_permute src dest bnf in
+    (case Typtab.lookup cache key of
+      SOME (bnf, _) => (bnf, accum)
+    | NONE => cache_comp_simple key cache (raw_permute_bnf qualify src dest bnf (unfold_set, lthy)))
+  end;
+
 (* Composition pipeline *)
 
 fun permute_and_kill qualify n src dest bnf =
@@ -560,17 +602,17 @@
   lift_bnf qualify n bnf
   #> uncurry (permute_bnf qualify src dest);
 
-fun normalize_bnfs qualify Ass Ds sort bnfs unfold_set lthy =
+fun normalize_bnfs qualify Ass Ds sort bnfs accum =
   let
     val before_kill_src = map (fn As => 0 upto (length As - 1)) Ass;
     val kill_poss = map (find_indices op = Ds) Ass;
     val live_poss = map2 (subtract op =) kill_poss before_kill_src;
     val before_kill_dest = map2 append kill_poss live_poss;
     val kill_ns = map length kill_poss;
-    val (inners', (unfold_set', lthy')) =
+    val (inners', accum') =
       fold_map5 (fn i => permute_and_kill (qualify i))
         (if length bnfs = 1 then [0] else (1 upto length bnfs))
-        kill_ns before_kill_src before_kill_dest bnfs (unfold_set, lthy);
+        kill_ns before_kill_src before_kill_dest bnfs accum;
 
     val Ass' = map2 (map o nth) Ass live_poss;
     val As = sort Ass';
@@ -582,27 +624,36 @@
   in
     ((kill_poss, As), fold_map5 (fn i => lift_and_permute (qualify i))
       (if length bnfs = 1 then [0] else 1 upto length bnfs)
-      lift_ns after_lift_src after_lift_dest inners' (unfold_set', lthy'))
+      lift_ns after_lift_src after_lift_dest inners' accum')
   end;
 
 fun default_comp_sort Ass =
   Library.sort (Term_Ord.typ_ord o pairself TFree) (fold (fold (insert (op =))) Ass []);
 
-fun compose_bnf const_policy qualify sort outer inners oDs Dss tfreess (unfold_set, lthy) =
+fun raw_compose_bnf const_policy qualify sort outer inners oDs Dss tfreess accum =
   let
     val b = name_of_bnf outer;
 
     val Ass = map (map Term.dest_TFree) tfreess;
     val Ds = fold (fold Term.add_tfreesT) (oDs :: Dss) [];
 
-    val ((kill_poss, As), (inners', (unfold_set', lthy'))) =
-      normalize_bnfs qualify Ass Ds sort inners unfold_set lthy;
+    val ((kill_poss, As), (inners', ((cache', unfold_set'), lthy'))) =
+      normalize_bnfs qualify Ass Ds sort inners accum;
 
     val Ds = oDs @ flat (map3 (append oo map o nth) tfreess kill_poss Dss);
     val As = map TFree As;
   in
     apfst (rpair (Ds, As))
-      (clean_compose_bnf const_policy (qualify 0) b outer inners' (unfold_set', lthy'))
+      (apsnd (apfst (pair cache'))
+        (clean_compose_bnf const_policy (qualify 0) b outer inners' (unfold_set', lthy')))
+  end;
+
+fun compose_bnf const_policy qualify sort outer inners oDs Dss tfreess (accum as ((cache, _), _)) =
+  let val key = key_of_compose oDs Dss tfreess outer inners in
+    (case Typtab.lookup cache key of
+      SOME bnf_Ds_As => (bnf_Ds_As, accum)
+    | NONE =>
+      cache_comp key (raw_compose_bnf const_policy qualify sort outer inners oDs Dss tfreess accum))
   end;
 
 (* Hide the type of the bound (optimization) and unfold the definitions (nicer to the user) *)
@@ -791,13 +842,6 @@
     ((bnf', (all_deads, absT_info)), lthy')
   end;
 
-fun key_of_types Ts = Type ("", Ts);
-val key_of_typess = key_of_types o map key_of_types;
-fun key_of_comp oDs Dss Ass T = key_of_types (map key_of_typess [[oDs], Dss, Ass, [[T]]]);
-
-fun cache_comp key cache (bnf_dead_lives, (unfold_set, lthy)) =
-  (bnf_dead_lives, ((Typtab.update (key, bnf_dead_lives) cache, unfold_set), lthy));
-
 exception BAD_DEAD of typ * typ;
 
 fun bnf_of_typ _ _ _ _ Ds0 (T as TFree T') accum =
@@ -823,10 +867,10 @@
             val deads = deads_of_bnf bnf;
             val lives = lives_of_bnf bnf;
             val tvars' = Term.add_tvarsT T' [];
-            val deads_lives =
+            val Ds_As =
               pairself (map (Term.typ_subst_TVars (map fst tvars' ~~ map TFree tfrees)))
                 (deads, lives);
-          in ((bnf, deads_lives), accum) end
+          in ((bnf, Ds_As), accum) end
         else
           let
             val name = Long_Name.base_name C;
@@ -839,18 +883,12 @@
               (mk_T_of_bnf (replicate odead (TFree ("dead", []))) (replicate olive dummyT) bnf)));
             val oDs = map (nth Ts) oDs_pos;
             val Ts' = map (nth Ts) (subtract (op =) oDs_pos (0 upto length Ts - 1));
-            val ((inners, (Dss, Ass)), ((cache', unfold_set'), lthy')) =
+            val ((inners, (Dss, Ass)), (accum', lthy')) =
               apfst (apsnd split_list o split_list)
                 (fold_map2 (fn i => bnf_of_typ Smart_Inline (qualify i) sort Xs Ds0)
                 (if length Ts' = 1 then [0] else (1 upto length Ts')) Ts' accum);
-            val key = key_of_comp oDs Dss Ass T;
           in
-            (case Typtab.lookup cache' key of
-              SOME bnf_deads_lives => (bnf_deads_lives, accum)
-            | NONE =>
-              (unfold_set', lthy')
-              |> compose_bnf const_policy qualify sort bnf inners oDs Dss Ass
-              |> cache_comp key cache')
+            compose_bnf const_policy qualify sort bnf inners oDs Dss Ass (accum', lthy')
           end)
       |> tap check_bad_dead
     end;
--- a/src/HOL/Tools/BNF/bnf_fp_util.ML	Tue Mar 04 18:57:17 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_util.ML	Tue Mar 04 18:57:17 2014 +0100
@@ -584,7 +584,7 @@
         #> Binding.conceal
       end;
 
-    val ((bnfs, (deadss, livess)), ((_, unfold_set), lthy)) =
+    val ((bnfs, (deadss, livess)), accum) =
       apfst (apsnd split_list o split_list)
         (fold_map2 (fn b => bnf_of_typ Smart_Inline (raw_qualify b) fp_sort Xs Ds0) bs rhsXs
           ((empty_comp_cache, empty_unfolds), lthy));
@@ -598,8 +598,8 @@
 
     val timer = time (timer "Construction of BNFs");
 
-    val ((kill_poss, _), (bnfs', (unfold_set', lthy'))) =
-      normalize_bnfs norm_qualify Ass Ds fp_sort bnfs unfold_set lthy;
+    val ((kill_poss, _), (bnfs', ((_, unfold_set'), lthy'))) =
+      normalize_bnfs norm_qualify Ass Ds fp_sort bnfs accum;
 
     val Dss = map3 (append oo map o nth) livess kill_poss deadss;