intermediate definitions and caching in n2m to keep terms small
authortraytel
Fri, 15 Apr 2016 21:33:47 +0200
changeset 63046 8053ef5a0174
parent 63045 c50c764aab10
child 63047 2146553e96c6
intermediate definitions and caching in n2m to keep terms small
src/HOL/BNF_Fixpoint_Base.thy
src/HOL/Tools/BNF/bnf_fp_n2m.ML
src/HOL/Tools/BNF/bnf_fp_n2m_tactics.ML
--- a/src/HOL/BNF_Fixpoint_Base.thy	Thu Apr 14 20:29:42 2016 +0200
+++ b/src/HOL/BNF_Fixpoint_Base.thy	Fri Apr 15 21:33:47 2016 +0200
@@ -284,6 +284,9 @@
   "rel_fun (rel_fun R S) (rel_fun (rel_fun R T) (rel_fun R (rel_prod S T))) BNF_Def.convol BNF_Def.convol"
   unfolding rel_fun_def convol_def by auto
 
+lemma Let_const: "Let x (\<lambda>_. c) = c"
+  unfolding Let_def ..
+
 ML_file "Tools/BNF/bnf_fp_util_tactics.ML"
 ML_file "Tools/BNF/bnf_fp_util.ML"
 ML_file "Tools/BNF/bnf_fp_def_sugar_tactics.ML"
--- a/src/HOL/Tools/BNF/bnf_fp_n2m.ML	Thu Apr 14 20:29:42 2016 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m.ML	Fri Apr 15 21:33:47 2016 +0200
@@ -39,6 +39,26 @@
     |> Thm.close_derivation
   end;
 
+val cacheN = "cache"
+fun mk_cacheN i = cacheN ^ string_of_int i ^ "_";
+val cache_threshold = Attrib.setup_config_int @{binding bnf_n2m_cache_threshold} (K 200);
+type cache = int * (term * thm) Typtab.table
+val empty_cache = (0, Typtab.empty)
+fun update_cache b0 TU t (cache as (i, tab), lthy) =
+  if size_of_term t < Config.get lthy cache_threshold then (t, (cache, lthy))
+  else
+    let
+      val b = Binding.prefix_name (mk_cacheN i) b0;
+      val ((c, thm), lthy') =
+        Local_Theory.define ((b, NoSyn), ((Binding.concealed (Thm.def_binding b), []), t)) lthy
+        |>> apsnd snd;
+    in
+      (c, ((i + 1, Typtab.update (TU, (c, thm)) tab), lthy'))
+    end;
+
+fun lookup_cache (SOME _) _ _ = NONE
+  | lookup_cache NONE TU ((_, tab), _) = Typtab.lookup tab TU |> Option.map fst;
+
 fun construct_mutualized_fp fp mutual_cliques fpTs indexed_fp_ress bs resBs (resDs, Dss) bnfs
     (absT_infos : absT_info list) lthy =
   let
@@ -254,9 +274,6 @@
     val fold_strTs = map2 mk_co_algT fold_preTs Xs;
     val resTs = map2 mk_co_algT fpTs Xs;
 
-    val ((fold_strs, fold_strs'), names_lthy) = names_lthy
-      |> mk_Frees' "s" fold_strTs;
-
     val fp_un_folds = of_fp_res #xtor_un_folds;
     val ns = map (length o #Ts o snd) indexed_fp_ress;
 
@@ -294,86 +311,109 @@
       case_fp fp (HOLogic.mk_comp (HOLogic.mk_comp (t, mk_abs absT abs), mk_rep fp_absT fp_rep))
         (HOLogic.mk_comp (mk_abs fp_absT fp_abs, HOLogic.mk_comp (mk_rep absT rep, t)));
 
-    fun mk_un_fold b_opt un_folds lthy TU =
-      let
-        val thy = Proof_Context.theory_of lthy;
-
-        val x = co_alg_argT TU;
-        val i = find_index (fn T => x = T) Xs;
-        val TUfold =
-          (case find_first (fn f => body_fun_type (fastype_of f) = TU) un_folds of
-            NONE => force_fold i TU (nth fp_un_folds i)
-          | SOME f => f);
-
-        val TUs = binder_fun_types (fastype_of TUfold);
-
-        fun mk_s TU' =
-          let
-            fun mk_absT_fp_repT repT absT = mk_absT thy repT absT ooo mk_repT;
+    val thy = Proof_Context.theory_of lthy;
+    fun mk_absT_fp_repT repT absT = mk_absT thy repT absT ooo mk_repT;
 
-            val i = find_index (fn T => co_alg_argT TU' = T) Xs;
-            val fp_abs = nth fp_abss i;
-            val fp_rep = nth fp_reps i;
-            val abs = nth abss i;
-            val rep = nth reps i;
-            val sF = co_alg_funT TU';
-            val sF' =
-              mk_absT_fp_repT (nth repTs' i) (nth absTs' i) (nth fp_absTs i) (nth fp_repTs i) sF
-                handle Term.TYPE _ => sF;
-            val F = nth fold_preTs i;
-            val s = nth fold_strs i;
-          in
-            if sF = F then s
-            else if sF' = F then mk_co_comp_abs_rep sF sF' fp_abs fp_rep abs rep s
-            else
-              let
-                val smapT = replicate live dummyT ---> mk_co_algT sF' F;
-                fun hidden_to_unit t =
-                  Term.subst_TVars (map (rpair HOLogic.unitT) (Term.add_tvar_names t [])) t;
-                val smap = map_of_bnf (nth bnfs i)
-                  |> force_typ names_lthy smapT
-                  |> hidden_to_unit;
-                val smap_argTs = strip_typeN live (fastype_of smap) |> fst;
-                fun mk_smap_arg T_to_U =
-                  (if domain_type T_to_U = range_type T_to_U then
-                    HOLogic.id_const (domain_type T_to_U)
-                  else
-                    fst (fst (mk_un_fold NONE un_folds lthy T_to_U)));
-                val smap_args = map mk_smap_arg smap_argTs;
-              in
-                mk_co_comp_abs_rep sF sF' fp_abs fp_rep abs rep
-                  (mk_co_comp s (Term.list_comb (smap, smap_args)))
-              end
-          end;
-        val t = Term.list_comb (TUfold, map mk_s TUs);
+    fun mk_un_fold b_opt ss un_folds cache_lthy TU =
+      (case lookup_cache b_opt TU cache_lthy of
+        SOME t => ((t, Drule.dummy_thm), cache_lthy)
+      | NONE =>
+        let
+          val x = co_alg_argT TU;
+          val i = find_index (fn T => x = T) Xs;
+          val TUfold =
+            (case find_first (fn f => body_fun_type (fastype_of f) = TU) un_folds of
+              NONE => force_fold i TU (nth fp_un_folds i)
+            | SOME f => f);
+  
+          val TUs = binder_fun_types (fastype_of TUfold);
+  
+          fun mk_s TU' cache_lthy =
+            let
+              val i = find_index (fn T => co_alg_argT TU' = T) Xs;
+              val fp_abs = nth fp_abss i;
+              val fp_rep = nth fp_reps i;
+              val abs = nth abss i;
+              val rep = nth reps i;
+              val sF = co_alg_funT TU';
+              val sF' =
+                mk_absT_fp_repT (nth repTs' i) (nth absTs' i) (nth fp_absTs i) (nth fp_repTs i) sF
+                  handle Term.TYPE _ => sF;
+              val F = nth fold_preTs i;
+              val s = nth ss i;
+            in
+              if sF = F then (s, cache_lthy)
+              else if sF' = F then (mk_co_comp_abs_rep sF sF' fp_abs fp_rep abs rep s, cache_lthy)
+              else
+                let
+                  val smapT = replicate live dummyT ---> mk_co_algT sF' F;
+                  fun hidden_to_unit t =
+                    Term.subst_TVars (map (rpair HOLogic.unitT) (Term.add_tvar_names t [])) t;
+                  val smap = map_of_bnf (nth bnfs i)
+                    |> force_typ names_lthy smapT
+                    |> hidden_to_unit;
+                  val smap_argTs = strip_typeN live (fastype_of smap) |> fst;
+                  fun mk_smap_arg T_to_U cache_lthy =
+                    (if domain_type T_to_U = range_type T_to_U then
+                      (HOLogic.id_const (domain_type T_to_U), cache_lthy)
+                    else
+                      mk_un_fold NONE ss un_folds cache_lthy T_to_U |>> fst);
+                  val (smap_args, cache_lthy') = fold_map mk_smap_arg smap_argTs cache_lthy;
+                in
+                  (mk_co_comp_abs_rep sF sF' fp_abs fp_rep abs rep
+                    (mk_co_comp s (Term.list_comb (smap, smap_args))), cache_lthy')
+                end
+            end;
+          val (args, cache_lthy) = fold_map mk_s TUs cache_lthy;
+          val t = Term.list_comb (TUfold, args);
+        in
+          (case b_opt of
+            NONE => update_cache b TU t cache_lthy |>> rpair Drule.dummy_thm
+          | SOME b => cache_lthy
+             |-> (fn cache =>
+               let
+                 val S = HOLogic.mk_tupleT fold_strTs;
+                 val s = HOLogic.mk_tuple ss;
+                 val u = Const (@{const_name Let}, S --> (S --> TU) --> TU) $ s $ absdummy S t;
+               in
+                 Local_Theory.define ((b, NoSyn), ((Binding.concealed (Thm.def_binding b), []), u))
+                 #>> apsnd snd ##> pair cache
+               end))
+        end);
+
+    val un_foldN = case_fp fp ctor_foldN dtor_unfoldN;
+    fun mk_un_folds (ss_names, lthy) =
+      let val ss = map2 (curry Free) ss_names fold_strTs;
       in
-        (case b_opt of
-          NONE => ((t, Drule.dummy_thm), lthy)
-        | SOME b => Local_Theory.define ((b, NoSyn), ((Binding.concealed (Thm.def_binding b), []),
-            fold_rev Term.absfree fold_strs' t)) lthy |>> apsnd snd)
+        fold2 (fn TU => fn b => fn ((un_folds, defs), cache_lthy) =>
+          mk_un_fold (SOME b) (map2 (curry Free) ss_names fold_strTs) un_folds cache_lthy TU
+          |>> (fn (f, d) => (f :: un_folds, d :: defs)))
+        resTs (map (Binding.suffix_name ("_" ^ un_foldN)) bs) (([], []), (empty_cache, lthy))
+        |>> map_prod rev rev
+        |>> pair ss
       end;
+    val ((ss, (un_folds, un_fold_defs0)), (cache, (lthy, raw_lthy))) = lthy
+      |> Local_Theory.open_target |> snd
+      |> Variable.add_fixes (mk_names n "s")
+      |> mk_un_folds
+      ||> apsnd (`(Local_Theory.close_target));
 
-    val foldN = case_fp fp ctor_foldN dtor_unfoldN;
-    fun mk_un_folds lthy =
-      fold2 (fn TU => fn b => fn ((un_folds, defs), lthy) =>
-        mk_un_fold (SOME b) un_folds lthy TU |>> (fn (f, d) => (f :: un_folds, d :: defs)))
-      resTs (map (Binding.suffix_name ("_" ^ foldN)) bs) (([], []), lthy)
-      |>> map_prod rev rev;
-    val ((raw_xtor_un_folds, raw_xtor_un_fold_defs), (lthy, raw_lthy)) = lthy
-      |> Local_Theory.open_target |> snd
-      |> mk_un_folds
-      ||> `Local_Theory.close_target;
+    val un_fold_defs = map (unfold_thms raw_lthy @{thms Let_const}) un_fold_defs0;
+
+    val cache_defs = snd cache |> Typtab.dest |> map (snd o snd);
 
     val phi = Proof_Context.export_morphism raw_lthy lthy;
 
-    val xtor_un_folds = map (Morphism.term phi) raw_xtor_un_folds;
-    val xtor_un_fold_defs = map (Morphism.thm phi) raw_xtor_un_fold_defs;
-    val xtor_un_folds' = map2 (fn raw => fn t => Const (fst (dest_Const t), fastype_of raw)) raw_xtor_un_folds xtor_un_folds;
+    val xtor_un_folds = map (head_of o Morphism.term phi) un_folds;
+    val xtor_un_fold_defs = map (Drule.abs_def o Morphism.thm phi) un_fold_defs;
+    val xtor_cache_defs = map (Drule.abs_def o Morphism.thm phi) cache_defs;
+    val xtor_un_folds' = map2 (fn raw => fn t =>
+        Const (fst (dest_Const t), fold_strTs ---> fastype_of raw))
+      un_folds xtor_un_folds;
 
     val fp_un_fold_o_maps = of_fp_res #xtor_un_fold_o_maps
       |> maps (fn thm => [thm, thm RS rewrite_comp_comp]);
 
-    val un_folds = map (fn fold => Term.list_comb (fold, fold_strs)) raw_xtor_un_folds;
     val fold_mapTs = co_swap (As @ fpTs, As @ Xs);
     val pre_fold_maps = @{map 2} (fn Ds => uncurry (mk_map_of_bnf Ds) fold_mapTs) Dss bnfs
     fun mk_pre_fold_maps fs =
@@ -398,8 +438,8 @@
     val eq_thm_prop_untyped = Term.aconv_untyped o apply2 Thm.full_prop_of;
 
     val map_thms = no_refl (maps (fn bnf =>
-       let val map_comp0 = map_comp0_of_bnf bnf RS sym
-       in [map_comp0, map_comp0 RS rewrite_comp_comp, map_id0_of_bnf bnf] end)
+        let val map_comp0 = map_comp0_of_bnf bnf RS sym
+        in [map_comp0, map_comp0 RS rewrite_comp_comp, map_id0_of_bnf bnf] end)
       fp_or_nesting_bnfs) @
       remove eq_thm_prop_untyped (case_fp fp @{thm comp_assoc[symmetric]} @{thm comp_assoc})
       (map2 (fn thm => fn bnf =>
@@ -421,18 +461,18 @@
                 fp_abs fp_rep abs rep rhs)
           end;
 
-        val goals = @{map 8} mk_goals un_folds xtors fold_strs pre_fold_maps fp_abss fp_reps abss reps;
+        val goals =
+          @{map 8} mk_goals un_folds xtors ss pre_fold_maps fp_abss fp_reps abss reps;
 
         val fp_un_folds = map (mk_pointfree lthy) (of_fp_res #xtor_un_fold_thms);
 
-        val simps = flat [simp_thms, raw_xtor_un_fold_defs, map_defs, fp_un_folds,
+        val simps = flat [simp_thms, un_fold_defs, map_defs, fp_un_folds,
           fp_un_fold_o_maps, map_thms, Rep_o_Abss];
       in
         Library.foldr1 HOLogic.mk_conj goals
         |> HOLogic.mk_Trueprop
-        |> fold_rev Logic.all fold_strs
         |> (fn goal => Goal.prove_sorry raw_lthy [] [] goal
-          (fn {context = ctxt, prems = _} => mk_xtor_un_fold_tac ctxt n simps))
+          (fn {context = ctxt, prems = _} => mk_xtor_un_fold_tac ctxt n simps cache_defs))
         |> Thm.close_derivation
         |> Morphism.thm phi
         |> split_conj_thm
@@ -454,7 +494,7 @@
               mk_co_comp_abs_rep (co_alg_funT (fastype_of lhs)) (co_alg_funT (fastype_of rhs))
                 fp_abs fp_rep abs rep rhs)
           end;
-        val prems = @{map 8} mk_prem fs fold_strs fold_maps xtors fp_abss fp_reps abss reps;
+        val prems = @{map 8} mk_prem fs ss fold_maps xtors fp_abss fp_reps abss reps;
         val concl = HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj
           (map2 (curry HOLogic.mk_eq) fs un_folds));
         val vars = Variable.add_free_names raw_lthy concl [];
@@ -470,8 +510,8 @@
             |> unfold_thms lthy (pre_map_defs @ simp_thms)) nesting_bnfs;
       in
         Goal.prove_sorry raw_lthy vars prems concl
-          (mk_xtor_un_fold_unique_tac fp raw_xtor_un_fold_defs map_arg_congs xtor_un_fold_o_maps
-            Rep_o_Abss fp_un_fold_uniques simp_thms map_thms map_defs)
+          (mk_xtor_un_fold_unique_tac fp un_fold_defs map_arg_congs xtor_un_fold_o_maps
+            Rep_o_Abss fp_un_fold_uniques simp_thms map_thms map_defs cache_defs)
           |> Thm.close_derivation
           |> case_fp fp I (fn thm => thm OF replicate n sym)
           |> Morphism.thm phi
@@ -502,7 +542,7 @@
 
     fun tac {context = ctxt, prems = _} =
       mk_xtor_un_fold_transfer_tac ctxt n xtor_un_fold_defs rel_defs fp_un_fold_transfers
-        map_transfers Abs_transfers fp_or_nesting_rel_eqs;
+        map_transfers Abs_transfers fp_or_nesting_rel_eqs xtor_cache_defs;
 
     val xtor_un_fold_transfer_thms =
       mk_xtor_co_iter_transfer_thms fp pre_rels XYphis XYphis rels ABphis
--- a/src/HOL/Tools/BNF/bnf_fp_n2m_tactics.ML	Thu Apr 14 20:29:42 2016 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m_tactics.ML	Fri Apr 15 21:33:47 2016 +0200
@@ -10,11 +10,11 @@
   val mk_rel_xtor_co_induct_tac: BNF_Util.fp_kind -> thm list -> thm list -> thm list ->
     thm list -> thm list -> thm list -> {prems: thm list, context: Proof.context} -> tactic
   val mk_xtor_un_fold_unique_tac: BNF_Util.fp_kind -> thm list -> thm list -> thm list ->
-    thm list -> thm list -> thm list -> thm list -> thm list ->
+    thm list -> thm list -> thm list -> thm list -> thm list -> thm list ->
     {context: Proof.context, prems: thm list} -> tactic
-  val mk_xtor_un_fold_tac: Proof.context -> int -> thm list -> tactic
+  val mk_xtor_un_fold_tac: Proof.context -> int -> thm list -> thm list -> tactic
   val mk_xtor_un_fold_transfer_tac: Proof.context -> int -> thm list -> thm list -> thm list ->
-    thm list -> thm list -> thm list -> tactic
+    thm list -> thm list -> thm list -> thm list -> tactic
 end;
 
 structure BNF_FP_N2M_Tactics : BNF_FP_N2M_TACTICS =
@@ -26,6 +26,10 @@
 
 val vimage2p_unfolds = o_apply :: @{thms vimage2p_def};
 
+fun unfold_at_most_once_tac ctxt thms = 
+  CONVERSION (Conv.top_sweep_conv (K (Conv.rewrs_conv thms)) ctxt);
+val unfold_once_tac = CHANGED ooo unfold_at_most_once_tac;
+
 fun mk_rel_xtor_co_induct_tac fp abs_inverses co_inducts0 rel_defs rel_monos nesting_rel_eqs0
   nesting_rel_monos0 {context = ctxt, prems = raw_C_IHs} =
   let
@@ -61,7 +65,7 @@
   end;
 
 fun mk_xtor_un_fold_unique_tac fp xtor_un_fold_defs map_arg_congs xtor_un_fold_o_maps
-   Rep_o_Abss fp_un_fold_uniques simp_thms map_thms map_defs {context = ctxt, prems} =
+   Rep_o_Abss fp_un_fold_uniques simp_thms map_thms map_defs cache_defs {context = ctxt, prems} =
   unfold_thms_tac ctxt xtor_un_fold_defs THEN
   HEADGOAL (REPEAT_DETERM o FIRST' [hyp_subst_tac_thin true ctxt, rtac ctxt refl,
     rtac ctxt conjI,
@@ -70,14 +74,16 @@
     resolve_tac ctxt map_arg_congs,
     resolve_tac ctxt fp_un_fold_uniques THEN_ALL_NEW case_fp fp (K all_tac) (rtac ctxt sym),
     SELECT_GOAL (CHANGED (unfold_thms_tac ctxt (flat [simp_thms, map_thms, map_defs,
-      xtor_un_fold_defs, xtor_un_fold_o_maps, Rep_o_Abss, prems])))]);
+      xtor_un_fold_defs, xtor_un_fold_o_maps, Rep_o_Abss, prems]))),
+    unfold_once_tac ctxt cache_defs]);
 
-fun mk_xtor_un_fold_tac ctxt n simps =
-  unfold_thms_tac ctxt simps THEN
+fun mk_xtor_un_fold_tac ctxt n simps cache_defs =
+  REPEAT_DETERM (CHANGED (unfold_thms_tac ctxt simps) ORELSE
+    CHANGED (ALLGOALS (unfold_at_most_once_tac ctxt cache_defs))) THEN
   CONJ_WRAP (K (HEADGOAL (rtac ctxt refl))) (1 upto n);
 
 fun mk_xtor_un_fold_transfer_tac ctxt n xtor_un_fold_defs rel_defs fp_un_fold_transfers
-    map_transfers Abs_transfers fp_or_nesting_rel_eqs =
+    map_transfers Abs_transfers fp_or_nesting_rel_eqs cache_defs =
   let
     val unfold = SELECT_GOAL (unfold_thms_tac ctxt fp_or_nesting_rel_eqs);
   in
@@ -87,6 +93,7 @@
         REPEAT_DETERM o (FIRST' [assume_tac ctxt, rtac ctxt @{thm id_transfer},
             rtac ctxt @{thm rel_funD[OF rel_funD[OF comp_transfer]]},
             resolve_tac ctxt fp_un_fold_transfers, resolve_tac ctxt map_transfers,
+            unfold_once_tac ctxt cache_defs,
             resolve_tac ctxt Abs_transfers, rtac ctxt @{thm vimage2p_rel_fun},
             unfold THEN' rtac ctxt @{thm vimage2p_rel_fun}])])
       fp_un_fold_transfers)