# HG changeset patch # User blanchet # Date 1393955837 -3600 # Node ID 0ef30d52c5e4f874c13b0146fed978c4db9d3e90 # Parent 461a75d0032347c3d56e88ff0aed3056786bf8c7 more caching in composition pipeline diff -r 461a75d00323 -r 0ef30d52c5e4 src/HOL/Tools/BNF/bnf_comp.ML --- a/src/HOL/Tools/BNF/bnf_comp.ML Tue Mar 04 18:57:17 2014 +0100 +++ b/src/HOL/Tools/BNF/bnf_comp.ML Tue Mar 04 18:57:17 2014 +0100 @@ -23,12 +23,12 @@ val bnf_of_typ: BNF_Def.inline_policy -> (binding -> binding) -> ((string * sort) list list -> (string * sort) list) -> (string * sort) list -> - (string * sort) list -> typ -> (comp_cache * unfold_set) * Proof.context -> - (BNF_Def.bnf * (typ list * typ list)) * ((comp_cache * unfold_set) * Proof.context) + (string * sort) list -> typ -> (comp_cache * unfold_set) * local_theory -> + (BNF_Def.bnf * (typ list * typ list)) * ((comp_cache * unfold_set) * local_theory) val default_comp_sort: (string * sort) list list -> (string * sort) list val normalize_bnfs: (int -> binding -> binding) -> ''a list list -> ''a list -> - (''a list list -> ''a list) -> BNF_Def.bnf list -> unfold_set -> Proof.context -> - (int list list * ''a list) * (BNF_Def.bnf list * (unfold_set * Proof.context)) + (''a list list -> ''a list) -> BNF_Def.bnf list -> (comp_cache * unfold_set) * local_theory -> + (int list list * ''a list) * (BNF_Def.bnf list * ((comp_cache * unfold_set) * local_theory)) type absT_info = {absT: typ, @@ -45,7 +45,7 @@ val mk_abs: typ -> term -> term val mk_rep: typ -> term -> term val seal_bnf: (binding -> binding) -> unfold_set -> binding -> typ list -> BNF_Def.bnf -> - Proof.context -> (BNF_Def.bnf * (typ list * absT_info)) * local_theory + local_theory -> (BNF_Def.bnf * (typ list * absT_info)) * local_theory end; structure BNF_Comp : BNF_COMP = @@ -61,6 +61,25 @@ type comp_cache = (bnf * (typ list * typ list)) Typtab.table; +fun key_of_types s Ts = Type (s, Ts); +fun key_of_typess s = key_of_types s o map (key_of_types ""); +fun typ_of_int n = Type (string_of_int n, []); +fun typ_of_bnf bnf = + key_of_typess "" [[T_of_bnf bnf], lives_of_bnf bnf, sort Term_Ord.typ_ord (deads_of_bnf bnf)]; + +fun key_of_kill n bnf = key_of_types "k" [typ_of_int n, typ_of_bnf bnf]; +fun key_of_lift n bnf = key_of_types "l" [typ_of_int n, typ_of_bnf bnf]; +fun key_of_permute src dest bnf = + key_of_types "p" (map typ_of_int src @ map typ_of_int dest @ [typ_of_bnf bnf]); +fun key_of_compose oDs Dss Ass outer inners = + key_of_types "c" (map (key_of_typess "") [[oDs], Dss, Ass, [map typ_of_bnf (outer :: inners)]]); + +fun cache_comp_simple key cache (bnf, (unfold_set, lthy)) = + (bnf, ((Typtab.update (key, (bnf, ([], []))) cache, unfold_set), lthy)); + +fun cache_comp key (bnf_Ds_As, ((cache, unfold_set), lthy)) = + (bnf_Ds_As, ((Typtab.update (key, bnf_Ds_As) cache, unfold_set), lthy)); + (* TODO: Replace by "BNF_Defs.defs list"? *) type unfold_set = { map_unfolds: thm list, @@ -118,7 +137,7 @@ val (oDs, lthy1) = apfst (map TFree) (Variable.invent_types (replicate odead HOLogic.typeS) lthy); val (Dss, lthy2) = apfst (map (map TFree)) - (fold_map Variable.invent_types (map (fn n => replicate n HOLogic.typeS) ideads) lthy1); + (fold_map Variable.invent_types (map (fn n => replicate n HOLogic.typeS) ideads) lthy1); val (Ass, lthy3) = apfst (replicate ilive o map TFree) (Variable.invent_types (replicate ilive HOLogic.typeS) lthy2); val As = if ilive > 0 then hd Ass else []; @@ -305,7 +324,8 @@ (* Killing live variables *) -fun kill_bnf qualify n bnf (unfold_set, lthy) = if n = 0 then (bnf, (unfold_set, lthy)) else +fun raw_kill_bnf qualify n bnf (accum as (unfold_set, lthy)) = + if n = 0 then (bnf, accum) else let val b = Binding.suffix_name (mk_killN n) (name_of_bnf bnf); val live = live_of_bnf bnf; @@ -394,9 +414,17 @@ (bnf', (add_bnf_to_unfolds bnf' unfold_set, lthy')) end; +fun kill_bnf qualify n bnf (accum as ((cache, unfold_set), lthy)) = + let val key = key_of_kill n bnf in + (case Typtab.lookup cache key of + SOME (bnf, _) => (bnf, accum) + | NONE => cache_comp_simple key cache (raw_kill_bnf qualify n bnf (unfold_set, lthy))) + end; + (* Adding dummy live variables *) -fun lift_bnf qualify n bnf (unfold_set, lthy) = if n = 0 then (bnf, (unfold_set, lthy)) else +fun raw_lift_bnf qualify n bnf (accum as (unfold_set, lthy)) = + if n = 0 then (bnf, accum) else let val b = Binding.suffix_name (mk_liftN n) (name_of_bnf bnf); val live = live_of_bnf bnf; @@ -476,10 +504,17 @@ (bnf', (add_bnf_to_unfolds bnf' unfold_set, lthy')) end; +fun lift_bnf qualify n bnf (accum as ((cache, unfold_set), lthy)) = + let val key = key_of_lift n bnf in + (case Typtab.lookup cache key of + SOME (bnf, _) => (bnf, accum) + | NONE => cache_comp_simple key cache (raw_lift_bnf qualify n bnf (unfold_set, lthy))) + end; + (* Changing the order of live variables *) -fun permute_bnf qualify src dest bnf (unfold_set, lthy) = - if src = dest then (bnf, (unfold_set, lthy)) else +fun raw_permute_bnf qualify src dest bnf (accum as (unfold_set, lthy)) = + if src = dest then (bnf, accum) else let val b = Binding.suffix_name (mk_permuteN src dest) (name_of_bnf bnf); val live = live_of_bnf bnf; @@ -550,6 +585,13 @@ (bnf', (add_bnf_to_unfolds bnf' unfold_set, lthy')) end; +fun permute_bnf qualify src dest bnf (accum as ((cache, unfold_set), lthy)) = + let val key = key_of_permute src dest bnf in + (case Typtab.lookup cache key of + SOME (bnf, _) => (bnf, accum) + | NONE => cache_comp_simple key cache (raw_permute_bnf qualify src dest bnf (unfold_set, lthy))) + end; + (* Composition pipeline *) fun permute_and_kill qualify n src dest bnf = @@ -560,17 +602,17 @@ lift_bnf qualify n bnf #> uncurry (permute_bnf qualify src dest); -fun normalize_bnfs qualify Ass Ds sort bnfs unfold_set lthy = +fun normalize_bnfs qualify Ass Ds sort bnfs accum = let val before_kill_src = map (fn As => 0 upto (length As - 1)) Ass; val kill_poss = map (find_indices op = Ds) Ass; val live_poss = map2 (subtract op =) kill_poss before_kill_src; val before_kill_dest = map2 append kill_poss live_poss; val kill_ns = map length kill_poss; - val (inners', (unfold_set', lthy')) = + val (inners', accum') = fold_map5 (fn i => permute_and_kill (qualify i)) (if length bnfs = 1 then [0] else (1 upto length bnfs)) - kill_ns before_kill_src before_kill_dest bnfs (unfold_set, lthy); + kill_ns before_kill_src before_kill_dest bnfs accum; val Ass' = map2 (map o nth) Ass live_poss; val As = sort Ass'; @@ -582,27 +624,36 @@ in ((kill_poss, As), fold_map5 (fn i => lift_and_permute (qualify i)) (if length bnfs = 1 then [0] else 1 upto length bnfs) - lift_ns after_lift_src after_lift_dest inners' (unfold_set', lthy')) + lift_ns after_lift_src after_lift_dest inners' accum') end; fun default_comp_sort Ass = Library.sort (Term_Ord.typ_ord o pairself TFree) (fold (fold (insert (op =))) Ass []); -fun compose_bnf const_policy qualify sort outer inners oDs Dss tfreess (unfold_set, lthy) = +fun raw_compose_bnf const_policy qualify sort outer inners oDs Dss tfreess accum = let val b = name_of_bnf outer; val Ass = map (map Term.dest_TFree) tfreess; val Ds = fold (fold Term.add_tfreesT) (oDs :: Dss) []; - val ((kill_poss, As), (inners', (unfold_set', lthy'))) = - normalize_bnfs qualify Ass Ds sort inners unfold_set lthy; + val ((kill_poss, As), (inners', ((cache', unfold_set'), lthy'))) = + normalize_bnfs qualify Ass Ds sort inners accum; val Ds = oDs @ flat (map3 (append oo map o nth) tfreess kill_poss Dss); val As = map TFree As; in apfst (rpair (Ds, As)) - (clean_compose_bnf const_policy (qualify 0) b outer inners' (unfold_set', lthy')) + (apsnd (apfst (pair cache')) + (clean_compose_bnf const_policy (qualify 0) b outer inners' (unfold_set', lthy'))) + end; + +fun compose_bnf const_policy qualify sort outer inners oDs Dss tfreess (accum as ((cache, _), _)) = + let val key = key_of_compose oDs Dss tfreess outer inners in + (case Typtab.lookup cache key of + SOME bnf_Ds_As => (bnf_Ds_As, accum) + | NONE => + cache_comp key (raw_compose_bnf const_policy qualify sort outer inners oDs Dss tfreess accum)) end; (* Hide the type of the bound (optimization) and unfold the definitions (nicer to the user) *) @@ -791,13 +842,6 @@ ((bnf', (all_deads, absT_info)), lthy') end; -fun key_of_types Ts = Type ("", Ts); -val key_of_typess = key_of_types o map key_of_types; -fun key_of_comp oDs Dss Ass T = key_of_types (map key_of_typess [[oDs], Dss, Ass, [[T]]]); - -fun cache_comp key cache (bnf_dead_lives, (unfold_set, lthy)) = - (bnf_dead_lives, ((Typtab.update (key, bnf_dead_lives) cache, unfold_set), lthy)); - exception BAD_DEAD of typ * typ; fun bnf_of_typ _ _ _ _ Ds0 (T as TFree T') accum = @@ -823,10 +867,10 @@ val deads = deads_of_bnf bnf; val lives = lives_of_bnf bnf; val tvars' = Term.add_tvarsT T' []; - val deads_lives = + val Ds_As = pairself (map (Term.typ_subst_TVars (map fst tvars' ~~ map TFree tfrees))) (deads, lives); - in ((bnf, deads_lives), accum) end + in ((bnf, Ds_As), accum) end else let val name = Long_Name.base_name C; @@ -839,18 +883,12 @@ (mk_T_of_bnf (replicate odead (TFree ("dead", []))) (replicate olive dummyT) bnf))); val oDs = map (nth Ts) oDs_pos; val Ts' = map (nth Ts) (subtract (op =) oDs_pos (0 upto length Ts - 1)); - val ((inners, (Dss, Ass)), ((cache', unfold_set'), lthy')) = + val ((inners, (Dss, Ass)), (accum', lthy')) = apfst (apsnd split_list o split_list) (fold_map2 (fn i => bnf_of_typ Smart_Inline (qualify i) sort Xs Ds0) (if length Ts' = 1 then [0] else (1 upto length Ts')) Ts' accum); - val key = key_of_comp oDs Dss Ass T; in - (case Typtab.lookup cache' key of - SOME bnf_deads_lives => (bnf_deads_lives, accum) - | NONE => - (unfold_set', lthy') - |> compose_bnf const_policy qualify sort bnf inners oDs Dss Ass - |> cache_comp key cache') + compose_bnf const_policy qualify sort bnf inners oDs Dss Ass (accum', lthy') end) |> tap check_bad_dead end; diff -r 461a75d00323 -r 0ef30d52c5e4 src/HOL/Tools/BNF/bnf_fp_util.ML --- a/src/HOL/Tools/BNF/bnf_fp_util.ML Tue Mar 04 18:57:17 2014 +0100 +++ b/src/HOL/Tools/BNF/bnf_fp_util.ML Tue Mar 04 18:57:17 2014 +0100 @@ -584,7 +584,7 @@ #> Binding.conceal end; - val ((bnfs, (deadss, livess)), ((_, unfold_set), lthy)) = + val ((bnfs, (deadss, livess)), accum) = apfst (apsnd split_list o split_list) (fold_map2 (fn b => bnf_of_typ Smart_Inline (raw_qualify b) fp_sort Xs Ds0) bs rhsXs ((empty_comp_cache, empty_unfolds), lthy)); @@ -598,8 +598,8 @@ val timer = time (timer "Construction of BNFs"); - val ((kill_poss, _), (bnfs', (unfold_set', lthy'))) = - normalize_bnfs norm_qualify Ass Ds fp_sort bnfs unfold_set lthy; + val ((kill_poss, _), (bnfs', ((_, unfold_set'), lthy'))) = + normalize_bnfs norm_qualify Ass Ds fp_sort bnfs accum; val Dss = map3 (append oo map o nth) livess kill_poss deadss;