src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
changeset 54256 4843082be7ef
parent 54255 4f7c016d5bc6
child 54265 3e1d230f1c00
--- a/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Tue Nov 05 05:48:08 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Tue Nov 05 05:48:08 2013 +0100
@@ -37,6 +37,32 @@
 
 val n2mN = "n2m_"
 
+type n2m_sugar = fp_sugar list * (lfp_sugar_thms option * gfp_sugar_thms option);
+
+structure Data = Generic_Data
+(
+  type T = n2m_sugar Typtab.table;
+  val empty = Typtab.empty;
+  val extend = I;
+  val merge = Typtab.merge (eq_fst (eq_list eq_fp_sugar));
+);
+
+fun morph_n2m_sugar phi (fp_sugars, (lfp_sugar_thms_opt, gfp_sugar_thms_opt)) =
+  (map (morph_fp_sugar phi) fp_sugars,
+   (Option.map (morph_lfp_sugar_thms phi) lfp_sugar_thms_opt,
+    Option.map (morph_gfp_sugar_thms phi) gfp_sugar_thms_opt));
+
+val transfer_n2m_sugar =
+  morph_n2m_sugar o Morphism.thm_morphism o Thm.transfer o Proof_Context.theory_of;
+
+fun n2m_sugar_of ctxt =
+  Typtab.lookup (Data.get (Context.Proof ctxt))
+  #> Option.map (transfer_n2m_sugar ctxt);
+
+fun register_n2m_sugar key n2m_sugar =
+  Local_Theory.declaration {syntax = false, pervasive = false}
+    (fn phi => Data.map (Typtab.default (key, morph_n2m_sugar phi n2m_sugar)));
+
 fun unfold_let (Const (@{const_name Let}, _) $ arg1 $ arg2) = unfold_let (betapply (arg2, arg1))
   | unfold_let (Const (@{const_name prod_case}, _) $ t) =
     (case unfold_let t of
@@ -93,6 +119,9 @@
       case f x of SOME y => (y :: ys, (x :: good, bad)) | NONE => (ys, (good, x :: bad)))
     xs ([], ([], []));
 
+fun key_of_fp_eqs fp fpTs fp_eqs =
+  Type (fp_case fp "l" "g", fpTs @ maps (fn (z, T) => [TFree z, T]) fp_eqs);
+
 (* TODO: test with sort constraints on As *)
 (* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
    as deads? *)
@@ -168,75 +197,82 @@
       val mss = map (map length) ctr_Tsss;
 
       val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
-
-      val base_fp_names = Name.variant_list [] fp_b_names;
-      val fp_bs = map2 (fn b_name => fn base_fp_name =>
-          Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
-        b_names base_fp_names;
+      val key = key_of_fp_eqs fp fpTs fp_eqs;
+    in
+      (case n2m_sugar_of no_defs_lthy key of
+        SOME n2m_sugar => (n2m_sugar, no_defs_lthy)
+      | NONE =>
+        let
+          val base_fp_names = Name.variant_list [] fp_b_names;
+          val fp_bs = map2 (fn b_name => fn base_fp_name =>
+              Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
+            b_names base_fp_names;
 
-      val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
-             dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
-        fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
+          val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
+                 dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
+            fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
 
-      val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
-      val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
-
-      val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
-        mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
+          val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
+          val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
 
-      fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
+          val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
+            mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
+
+          fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
 
-      val ((co_iterss, co_iter_defss), lthy) =
-        fold_map2 (fn b =>
-          (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
-           else define_coiters [unfoldN, corecN] (the coiters_args_types))
-            (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
-        |>> split_list;
+          val ((co_iterss, co_iter_defss), lthy) =
+            fold_map2 (fn b =>
+              (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
+               else define_coiters [unfoldN, corecN] (the coiters_args_types))
+                (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
+            |>> split_list;
 
-      val rho = tvar_subst thy Ts fpTs;
-      val ctr_sugar_phi =
-        Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
-          (Morphism.term_morphism (Term.subst_TVars rho));
-      val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
-
-      val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
+          val rho = tvar_subst thy Ts fpTs;
+          val ctr_sugar_phi =
+            Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
+              (Morphism.term_morphism (Term.subst_TVars rho));
+          val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
 
-      val ((co_inducts, un_fold_thmss, co_rec_thmss, disc_unfold_thmss, disc_corec_thmss,
-            sel_unfold_thmsss, sel_corec_thmsss), fp_sugar_thms) =
-        if fp = Least_FP then
-          derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
-            xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
-            co_iterss co_iter_defss lthy
-          |> `(fn ((_, induct, _), (fold_thmss, rec_thmss, _)) =>
-            ([induct], fold_thmss, rec_thmss, [], [], [], []))
-          ||> (fn info => (SOME info, NONE))
-        else
-          derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_induct
-            dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs Xs ctrXs_Tsss kss mss ns
-            ctr_defss ctr_sugars co_iterss co_iter_defss (Proof_Context.export lthy no_defs_lthy)
-            lthy
-          |> `(fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _),
-                  (disc_unfold_thmss, disc_corec_thmss, _), _,
-                  (sel_unfold_thmsss, sel_corec_thmsss, _)) =>
-            (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss, disc_unfold_thmss,
-             disc_corec_thmss, sel_unfold_thmsss, sel_corec_thmsss))
-          ||> (fn info => (NONE, SOME info));
+          val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
 
-      val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
+          val ((co_inducts, un_fold_thmss, co_rec_thmss, disc_unfold_thmss, disc_corec_thmss,
+                sel_unfold_thmsss, sel_corec_thmsss), fp_sugar_thms) =
+            if fp = Least_FP then
+              derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
+                xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
+                co_iterss co_iter_defss lthy
+              |> `(fn ((_, induct, _), (fold_thmss, rec_thmss, _)) =>
+                ([induct], fold_thmss, rec_thmss, [], [], [], []))
+              ||> (fn info => (SOME info, NONE))
+            else
+              derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types)
+                xtor_co_induct dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs Xs
+                ctrXs_Tsss kss mss ns ctr_defss ctr_sugars co_iterss co_iter_defss
+                (Proof_Context.export lthy no_defs_lthy) lthy
+              |> `(fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _),
+                      (disc_unfold_thmss, disc_corec_thmss, _), _,
+                      (sel_unfold_thmsss, sel_corec_thmsss, _)) =>
+                (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss, disc_unfold_thmss,
+                 disc_corec_thmss, sel_unfold_thmsss, sel_corec_thmsss))
+              ||> (fn info => (NONE, SOME info));
 
-      fun mk_target_fp_sugar (kk, T) =
-        {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
-         nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
-         ctr_sugars = ctr_sugars, co_iterss = co_iterss, mapss = mapss, co_inducts = co_inducts,
-         co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss],
-         disc_co_itersss = transpose [disc_unfold_thmss, disc_corec_thmss],
-         sel_co_iterssss = transpose [sel_unfold_thmsss, sel_corec_thmsss]}
-        |> morph_fp_sugar phi;
-    in
-      ((map_index mk_target_fp_sugar fpTs, fp_sugar_thms), lthy)
+          val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
+
+          fun mk_target_fp_sugar (kk, T) =
+            {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
+             nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
+             ctr_sugars = ctr_sugars, co_iterss = co_iterss, mapss = mapss, co_inducts = co_inducts,
+             co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss],
+             disc_co_itersss = transpose [disc_unfold_thmss, disc_corec_thmss],
+             sel_co_iterssss = transpose [sel_unfold_thmsss, sel_corec_thmsss]}
+            |> morph_fp_sugar phi;
+
+          val n2m_sugar = (map_index mk_target_fp_sugar fpTs, fp_sugar_thms);
+        in
+          (n2m_sugar, lthy |> register_n2m_sugar key n2m_sugar)
+        end)
     end
   else
-    (* TODO: reorder hypotheses and predicates in (co)induction rules? *)
     ((fp_sugars0, (NONE, NONE)), no_defs_lthy0);
 
 fun indexify_callsss fp_sugar callsss =
@@ -299,7 +335,7 @@
     val Ts = actual_Ts @ missing_Ts;
 
     fun generalize_simple_type T (seen, lthy) =
-      mk_TFrees 1 lthy |> (fn ([U], lthy) => (U, ((T, U) :: seen, lthy)));
+      variant_tfrees ["aa"] lthy |> (fn ([U], lthy) => (U, ((T, U) :: seen, lthy)));
 
     fun generalize_type T (seen_lthy as (seen, _)) =
       (case AList.lookup (op =) seen T of