--- a/src/HOL/Tools/BNF/bnf_comp.ML Tue Mar 04 18:57:17 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_comp.ML Tue Mar 04 18:57:17 2014 +0100
@@ -23,12 +23,12 @@
val bnf_of_typ: BNF_Def.inline_policy -> (binding -> binding) ->
((string * sort) list list -> (string * sort) list) -> (string * sort) list ->
- (string * sort) list -> typ -> (comp_cache * unfold_set) * Proof.context ->
- (BNF_Def.bnf * (typ list * typ list)) * ((comp_cache * unfold_set) * Proof.context)
+ (string * sort) list -> typ -> (comp_cache * unfold_set) * local_theory ->
+ (BNF_Def.bnf * (typ list * typ list)) * ((comp_cache * unfold_set) * local_theory)
val default_comp_sort: (string * sort) list list -> (string * sort) list
val normalize_bnfs: (int -> binding -> binding) -> ''a list list -> ''a list ->
- (''a list list -> ''a list) -> BNF_Def.bnf list -> unfold_set -> Proof.context ->
- (int list list * ''a list) * (BNF_Def.bnf list * (unfold_set * Proof.context))
+ (''a list list -> ''a list) -> BNF_Def.bnf list -> (comp_cache * unfold_set) * local_theory ->
+ (int list list * ''a list) * (BNF_Def.bnf list * ((comp_cache * unfold_set) * local_theory))
type absT_info =
{absT: typ,
@@ -45,7 +45,7 @@
val mk_abs: typ -> term -> term
val mk_rep: typ -> term -> term
val seal_bnf: (binding -> binding) -> unfold_set -> binding -> typ list -> BNF_Def.bnf ->
- Proof.context -> (BNF_Def.bnf * (typ list * absT_info)) * local_theory
+ local_theory -> (BNF_Def.bnf * (typ list * absT_info)) * local_theory
end;
structure BNF_Comp : BNF_COMP =
@@ -61,6 +61,25 @@
type comp_cache = (bnf * (typ list * typ list)) Typtab.table;
+fun key_of_types s Ts = Type (s, Ts);
+fun key_of_typess s = key_of_types s o map (key_of_types "");
+fun typ_of_int n = Type (string_of_int n, []);
+fun typ_of_bnf bnf =
+ key_of_typess "" [[T_of_bnf bnf], lives_of_bnf bnf, sort Term_Ord.typ_ord (deads_of_bnf bnf)];
+
+fun key_of_kill n bnf = key_of_types "k" [typ_of_int n, typ_of_bnf bnf];
+fun key_of_lift n bnf = key_of_types "l" [typ_of_int n, typ_of_bnf bnf];
+fun key_of_permute src dest bnf =
+ key_of_types "p" (map typ_of_int src @ map typ_of_int dest @ [typ_of_bnf bnf]);
+fun key_of_compose oDs Dss Ass outer inners =
+ key_of_types "c" (map (key_of_typess "") [[oDs], Dss, Ass, [map typ_of_bnf (outer :: inners)]]);
+
+fun cache_comp_simple key cache (bnf, (unfold_set, lthy)) =
+ (bnf, ((Typtab.update (key, (bnf, ([], []))) cache, unfold_set), lthy));
+
+fun cache_comp key (bnf_Ds_As, ((cache, unfold_set), lthy)) =
+ (bnf_Ds_As, ((Typtab.update (key, bnf_Ds_As) cache, unfold_set), lthy));
+
(* TODO: Replace by "BNF_Defs.defs list"? *)
type unfold_set = {
map_unfolds: thm list,
@@ -118,7 +137,7 @@
val (oDs, lthy1) = apfst (map TFree)
(Variable.invent_types (replicate odead HOLogic.typeS) lthy);
val (Dss, lthy2) = apfst (map (map TFree))
- (fold_map Variable.invent_types (map (fn n => replicate n HOLogic.typeS) ideads) lthy1);
+ (fold_map Variable.invent_types (map (fn n => replicate n HOLogic.typeS) ideads) lthy1);
val (Ass, lthy3) = apfst (replicate ilive o map TFree)
(Variable.invent_types (replicate ilive HOLogic.typeS) lthy2);
val As = if ilive > 0 then hd Ass else [];
@@ -305,7 +324,8 @@
(* Killing live variables *)
-fun kill_bnf qualify n bnf (unfold_set, lthy) = if n = 0 then (bnf, (unfold_set, lthy)) else
+fun raw_kill_bnf qualify n bnf (accum as (unfold_set, lthy)) =
+ if n = 0 then (bnf, accum) else
let
val b = Binding.suffix_name (mk_killN n) (name_of_bnf bnf);
val live = live_of_bnf bnf;
@@ -394,9 +414,17 @@
(bnf', (add_bnf_to_unfolds bnf' unfold_set, lthy'))
end;
+fun kill_bnf qualify n bnf (accum as ((cache, unfold_set), lthy)) =
+ let val key = key_of_kill n bnf in
+ (case Typtab.lookup cache key of
+ SOME (bnf, _) => (bnf, accum)
+ | NONE => cache_comp_simple key cache (raw_kill_bnf qualify n bnf (unfold_set, lthy)))
+ end;
+
(* Adding dummy live variables *)
-fun lift_bnf qualify n bnf (unfold_set, lthy) = if n = 0 then (bnf, (unfold_set, lthy)) else
+fun raw_lift_bnf qualify n bnf (accum as (unfold_set, lthy)) =
+ if n = 0 then (bnf, accum) else
let
val b = Binding.suffix_name (mk_liftN n) (name_of_bnf bnf);
val live = live_of_bnf bnf;
@@ -476,10 +504,17 @@
(bnf', (add_bnf_to_unfolds bnf' unfold_set, lthy'))
end;
+fun lift_bnf qualify n bnf (accum as ((cache, unfold_set), lthy)) =
+ let val key = key_of_lift n bnf in
+ (case Typtab.lookup cache key of
+ SOME (bnf, _) => (bnf, accum)
+ | NONE => cache_comp_simple key cache (raw_lift_bnf qualify n bnf (unfold_set, lthy)))
+ end;
+
(* Changing the order of live variables *)
-fun permute_bnf qualify src dest bnf (unfold_set, lthy) =
- if src = dest then (bnf, (unfold_set, lthy)) else
+fun raw_permute_bnf qualify src dest bnf (accum as (unfold_set, lthy)) =
+ if src = dest then (bnf, accum) else
let
val b = Binding.suffix_name (mk_permuteN src dest) (name_of_bnf bnf);
val live = live_of_bnf bnf;
@@ -550,6 +585,13 @@
(bnf', (add_bnf_to_unfolds bnf' unfold_set, lthy'))
end;
+fun permute_bnf qualify src dest bnf (accum as ((cache, unfold_set), lthy)) =
+ let val key = key_of_permute src dest bnf in
+ (case Typtab.lookup cache key of
+ SOME (bnf, _) => (bnf, accum)
+ | NONE => cache_comp_simple key cache (raw_permute_bnf qualify src dest bnf (unfold_set, lthy)))
+ end;
+
(* Composition pipeline *)
fun permute_and_kill qualify n src dest bnf =
@@ -560,17 +602,17 @@
lift_bnf qualify n bnf
#> uncurry (permute_bnf qualify src dest);
-fun normalize_bnfs qualify Ass Ds sort bnfs unfold_set lthy =
+fun normalize_bnfs qualify Ass Ds sort bnfs accum =
let
val before_kill_src = map (fn As => 0 upto (length As - 1)) Ass;
val kill_poss = map (find_indices op = Ds) Ass;
val live_poss = map2 (subtract op =) kill_poss before_kill_src;
val before_kill_dest = map2 append kill_poss live_poss;
val kill_ns = map length kill_poss;
- val (inners', (unfold_set', lthy')) =
+ val (inners', accum') =
fold_map5 (fn i => permute_and_kill (qualify i))
(if length bnfs = 1 then [0] else (1 upto length bnfs))
- kill_ns before_kill_src before_kill_dest bnfs (unfold_set, lthy);
+ kill_ns before_kill_src before_kill_dest bnfs accum;
val Ass' = map2 (map o nth) Ass live_poss;
val As = sort Ass';
@@ -582,27 +624,36 @@
in
((kill_poss, As), fold_map5 (fn i => lift_and_permute (qualify i))
(if length bnfs = 1 then [0] else 1 upto length bnfs)
- lift_ns after_lift_src after_lift_dest inners' (unfold_set', lthy'))
+ lift_ns after_lift_src after_lift_dest inners' accum')
end;
fun default_comp_sort Ass =
Library.sort (Term_Ord.typ_ord o pairself TFree) (fold (fold (insert (op =))) Ass []);
-fun compose_bnf const_policy qualify sort outer inners oDs Dss tfreess (unfold_set, lthy) =
+fun raw_compose_bnf const_policy qualify sort outer inners oDs Dss tfreess accum =
let
val b = name_of_bnf outer;
val Ass = map (map Term.dest_TFree) tfreess;
val Ds = fold (fold Term.add_tfreesT) (oDs :: Dss) [];
- val ((kill_poss, As), (inners', (unfold_set', lthy'))) =
- normalize_bnfs qualify Ass Ds sort inners unfold_set lthy;
+ val ((kill_poss, As), (inners', ((cache', unfold_set'), lthy'))) =
+ normalize_bnfs qualify Ass Ds sort inners accum;
val Ds = oDs @ flat (map3 (append oo map o nth) tfreess kill_poss Dss);
val As = map TFree As;
in
apfst (rpair (Ds, As))
- (clean_compose_bnf const_policy (qualify 0) b outer inners' (unfold_set', lthy'))
+ (apsnd (apfst (pair cache'))
+ (clean_compose_bnf const_policy (qualify 0) b outer inners' (unfold_set', lthy')))
+ end;
+
+fun compose_bnf const_policy qualify sort outer inners oDs Dss tfreess (accum as ((cache, _), _)) =
+ let val key = key_of_compose oDs Dss tfreess outer inners in
+ (case Typtab.lookup cache key of
+ SOME bnf_Ds_As => (bnf_Ds_As, accum)
+ | NONE =>
+ cache_comp key (raw_compose_bnf const_policy qualify sort outer inners oDs Dss tfreess accum))
end;
(* Hide the type of the bound (optimization) and unfold the definitions (nicer to the user) *)
@@ -791,13 +842,6 @@
((bnf', (all_deads, absT_info)), lthy')
end;
-fun key_of_types Ts = Type ("", Ts);
-val key_of_typess = key_of_types o map key_of_types;
-fun key_of_comp oDs Dss Ass T = key_of_types (map key_of_typess [[oDs], Dss, Ass, [[T]]]);
-
-fun cache_comp key cache (bnf_dead_lives, (unfold_set, lthy)) =
- (bnf_dead_lives, ((Typtab.update (key, bnf_dead_lives) cache, unfold_set), lthy));
-
exception BAD_DEAD of typ * typ;
fun bnf_of_typ _ _ _ _ Ds0 (T as TFree T') accum =
@@ -823,10 +867,10 @@
val deads = deads_of_bnf bnf;
val lives = lives_of_bnf bnf;
val tvars' = Term.add_tvarsT T' [];
- val deads_lives =
+ val Ds_As =
pairself (map (Term.typ_subst_TVars (map fst tvars' ~~ map TFree tfrees)))
(deads, lives);
- in ((bnf, deads_lives), accum) end
+ in ((bnf, Ds_As), accum) end
else
let
val name = Long_Name.base_name C;
@@ -839,18 +883,12 @@
(mk_T_of_bnf (replicate odead (TFree ("dead", []))) (replicate olive dummyT) bnf)));
val oDs = map (nth Ts) oDs_pos;
val Ts' = map (nth Ts) (subtract (op =) oDs_pos (0 upto length Ts - 1));
- val ((inners, (Dss, Ass)), ((cache', unfold_set'), lthy')) =
+ val ((inners, (Dss, Ass)), (accum', lthy')) =
apfst (apsnd split_list o split_list)
(fold_map2 (fn i => bnf_of_typ Smart_Inline (qualify i) sort Xs Ds0)
(if length Ts' = 1 then [0] else (1 upto length Ts')) Ts' accum);
- val key = key_of_comp oDs Dss Ass T;
in
- (case Typtab.lookup cache' key of
- SOME bnf_deads_lives => (bnf_deads_lives, accum)
- | NONE =>
- (unfold_set', lthy')
- |> compose_bnf const_policy qualify sort bnf inners oDs Dss Ass
- |> cache_comp key cache')
+ compose_bnf const_policy qualify sort bnf inners oDs Dss Ass (accum', lthy')
end)
|> tap check_bad_dead
end;
--- a/src/HOL/Tools/BNF/bnf_fp_util.ML Tue Mar 04 18:57:17 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_util.ML Tue Mar 04 18:57:17 2014 +0100
@@ -584,7 +584,7 @@
#> Binding.conceal
end;
- val ((bnfs, (deadss, livess)), ((_, unfold_set), lthy)) =
+ val ((bnfs, (deadss, livess)), accum) =
apfst (apsnd split_list o split_list)
(fold_map2 (fn b => bnf_of_typ Smart_Inline (raw_qualify b) fp_sort Xs Ds0) bs rhsXs
((empty_comp_cache, empty_unfolds), lthy));
@@ -598,8 +598,8 @@
val timer = time (timer "Construction of BNFs");
- val ((kill_poss, _), (bnfs', (unfold_set', lthy'))) =
- normalize_bnfs norm_qualify Ass Ds fp_sort bnfs unfold_set lthy;
+ val ((kill_poss, _), (bnfs', ((_, unfold_set'), lthy'))) =
+ normalize_bnfs norm_qualify Ass Ds fp_sort bnfs accum;
val Dss = map3 (append oo map o nth) livess kill_poss deadss;