simplified N2M code now that 'fold' is no longer used by the sugar layer + use right context in all 'force_typ' calls
authorblanchet
Mon, 03 Mar 2014 23:05:49 +0100
changeset 55894 8f3fe443948a
parent 55893 aed17a173d16
child 55895 74a2758dcbae
child 55896 c78575827f38
simplified N2M code now that 'fold' is no longer used by the sugar layer + use right context in all 'force_typ' calls
src/HOL/Tools/BNF/bnf_fp_n2m.ML
--- a/src/HOL/Tools/BNF/bnf_fp_n2m.ML	Mon Mar 03 23:05:30 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m.ML	Mon Mar 03 23:05:49 2014 +0100
@@ -228,13 +228,11 @@
     val fold_preTs = map2 (fn Ds => mk_T_of_bnf Ds allAs) Dss bnfs;
     val rec_preTs = map (Term.typ_subst_atomic rec_theta) fold_preTs;
 
-    val fold_strTs = map2 mk_co_algT fold_preTs Xs;
     val rec_strTs = map2 mk_co_algT rec_preTs Xs;
     val resTs = map2 mk_co_algT fpTs Xs;
 
-    val (((fold_strs, fold_strs'), (rec_strs, rec_strs')), names_lthy) = names_lthy
-      |> mk_Frees' "s" fold_strTs
-      ||>> mk_Frees' "s" rec_strTs;
+    val ((rec_strs, rec_strs'), names_lthy) = names_lthy
+      |> mk_Frees' "s" rec_strTs;
 
     val co_folds = of_fp_res #xtor_co_folds;
     val co_recs = of_fp_res #xtor_co_recs;
@@ -246,13 +244,12 @@
 
     val typ_subst_nonatomic_sorted = fold_rev (typ_subst_nonatomic o single);
 
-    fun force_iter is_rec i TU TU_rec raw_fold raw_rec =
+    fun force_rec i TU TU_rec raw_fold raw_rec =
       let
         val thy = Proof_Context.theory_of lthy;
 
         val approx_fold = raw_fold
-          |> force_typ names_lthy
-            (replicate (nth ns i) dummyT ---> (if is_rec then TU_rec else TU));
+          |> force_typ names_lthy (replicate (nth ns i) dummyT ---> TU_rec);
         val subst = Term.typ_subst_atomic fold_thetaAs;
 
         fun mk_fp_absT_repT fp_repT fp_absT = mk_absT thy fp_repT fp_absT ooo mk_repT;
@@ -271,31 +268,28 @@
         val js = find_indices Type.could_unify TUs cands;
         val Tpats = map (fn j => mk_co_algT (nth fold_pre_deads_only_Ts j) (nth Xs j)) js;
       in
-        force_typ names_lthy (Tpats ---> TU) (if is_rec then raw_rec else raw_fold)
+        force_typ names_lthy (Tpats ---> TU) raw_rec
       end;
 
     fun mk_co_comp_abs_rep fp_absT absT fp_abs fp_rep abs rep t =
       fp_case fp (HOLogic.mk_comp (HOLogic.mk_comp (t, mk_abs absT abs), mk_rep fp_absT fp_rep))
         (HOLogic.mk_comp (mk_abs fp_absT fp_abs, HOLogic.mk_comp (mk_rep absT rep, t)));
 
-    fun mk_iter b_opt is_rec iters lthy TU =
+    fun mk_rec b_opt recs lthy TU =
       let
         val thy = Proof_Context.theory_of lthy;
 
         val x = co_alg_argT TU;
         val i = find_index (fn T => x = T) Xs;
-        val TUiter =
-          (case find_first (fn f => body_fun_type (fastype_of f) = TU) iters of
+        val TUrec =
+          (case find_first (fn f => body_fun_type (fastype_of f) = TU) recs of
             NONE => 
-              force_iter is_rec i
-                (TU |> (is_none b_opt andalso not is_rec) ? substT (fpTs ~~ Xs))
+              force_rec i TU
                 (TU |> is_none b_opt ? substT (map2 mk_co_productT fpTs Xs ~~ Xs))
                 (nth co_folds i) (nth co_recs i)
           | SOME f => f);
 
-        val TUs = binder_fun_types (fastype_of TUiter);
-        val iter_preTs = if is_rec then rec_preTs else fold_preTs;
-        val iter_strs = if is_rec then rec_strs else fold_strs;
+        val TUs = binder_fun_types (fastype_of TUrec);
 
         fun mk_s TU' =
           let
@@ -310,8 +304,8 @@
             val sF' =
               mk_absT_fp_repT (nth repTs' i) (nth absTs' i) (nth fp_absTs i) (nth fp_repTs i) sF
                 handle Term.TYPE _ => sF;
-            val F = nth iter_preTs i;
-            val s = nth iter_strs i;
+            val F = nth rec_preTs i;
+            val s = nth rec_strs i;
           in
             if sF = F then s
             else if sF' = F then mk_co_comp_abs_rep sF sF' fp_abs fp_rep abs rep s
@@ -327,67 +321,53 @@
                 fun mk_smap_arg TU =
                   (if domain_type TU = range_type TU then
                     HOLogic.id_const (domain_type TU)
-                  else if is_rec then
+                  else
                     let
                       val (TY, (U, X)) = TU |> dest_co_algT ||> dest_co_productT;
                       val T = mk_co_algT TY U;
                     in
-                      (case try (force_typ lthy T o build_map lthy co_proj1_const o dest_funT) T of
+                      (case try (force_typ names_lthy T o build_map lthy co_proj1_const o dest_funT) T of
                         SOME f => mk_co_product f
-                          (fst (fst (mk_iter NONE is_rec iters lthy (mk_co_algT TY X))))
+                          (fst (fst (mk_rec NONE recs lthy (mk_co_algT TY X))))
                       | NONE => mk_map_co_product
                           (build_map lthy co_proj1_const
                             (dest_funT (mk_co_algT (dest_co_productT TY |> fst) U)))
                           (HOLogic.id_const X))
-                    end
-                  else
-                    fst (fst (mk_iter NONE is_rec iters lthy TU)))
+                    end)
                 val smap_args = map mk_smap_arg smap_argTs;
               in
                 mk_co_comp_abs_rep sF sF' fp_abs fp_rep abs rep
                   (mk_co_comp (s, Term.list_comb (smap, smap_args)))
               end
           end;
-        val t = Term.list_comb (TUiter, map mk_s TUs);
+        val t = Term.list_comb (TUrec, map mk_s TUs);
       in
         (case b_opt of
           NONE => ((t, Drule.dummy_thm), lthy)
         | SOME b => Local_Theory.define ((b, NoSyn), ((Binding.conceal (Thm.def_binding b), []),
-            fold_rev Term.absfree (if is_rec then rec_strs' else fold_strs') t)) lthy |>> apsnd snd)
+            fold_rev Term.absfree rec_strs' t)) lthy |>> apsnd snd)
       end;
 
-    fun mk_iters is_rec name lthy =
-      fold2 (fn TU => fn b => fn ((iters, defs), lthy) =>
-        mk_iter (SOME b) is_rec iters lthy TU |>> (fn (f, d) => (f :: iters, d :: defs)))
-      resTs (map (Binding.suffix_name ("_" ^ name)) bs) (([], []), lthy)
+    val recN = fp_case fp ctor_recN dtor_corecN;
+    fun mk_recs lthy =
+      fold2 (fn TU => fn b => fn ((recs, defs), lthy) =>
+        mk_rec (SOME b) recs lthy TU |>> (fn (f, d) => (f :: recs, d :: defs)))
+      resTs (map (Binding.suffix_name ("_" ^ recN)) bs) (([], []), lthy)
       |>> apfst rev o apsnd rev;
-    val foldN = fp_case fp ctor_foldN dtor_unfoldN;
-    val recN = fp_case fp ctor_recN dtor_corecN;
-    val (((raw_un_folds, raw_un_fold_defs), (raw_co_recs, raw_co_rec_defs)), (lthy, raw_lthy)) =
-      lthy
-      |> mk_iters false foldN
-      ||>> mk_iters true recN
+    val ((raw_co_recs, raw_co_rec_defs), (lthy, raw_lthy)) = lthy
+      |> mk_recs
       ||> `Local_Theory.restore;
 
     val phi = Proof_Context.export_morphism raw_lthy lthy;
 
-    val un_folds = map (Morphism.term phi) raw_un_folds;
     val co_recs = map (Morphism.term phi) raw_co_recs;
 
-    val fp_fold_o_maps = of_fp_res #xtor_co_fold_o_map_thms;
     val fp_rec_o_maps = of_fp_res #xtor_co_rec_o_map_thms;
 
-    val (xtor_un_fold_thms, xtor_co_rec_thms) =
+    val xtor_co_rec_thms =
       let
-        val folds = map (fn f => Term.list_comb (f, fold_strs)) raw_un_folds;
         val recs = map (fn r => Term.list_comb (r, rec_strs)) raw_co_recs;
-        val fold_mapTs = co_swap (As @ fpTs, As @ Xs);
         val rec_mapTs = co_swap (As @ fpTs, As @ map2 mk_co_productT fpTs Xs);
-        val pre_fold_maps =
-          map2 (fn Ds => fn bnf =>
-            Term.list_comb (uncurry (mk_map_of_bnf Ds) fold_mapTs bnf,
-              map HOLogic.id_const As @ folds))
-          Dss bnfs;
         val pre_rec_maps =
           map2 (fn Ds => fn bnf =>
             Term.list_comb (uncurry (mk_map_of_bnf Ds) rec_mapTs bnf,
@@ -404,26 +384,13 @@
                 fp_abs fp_rep abs rep rhs)
           end;
 
-        val fold_goals =
-          map8 mk_goals folds xtors fold_strs pre_fold_maps fp_abss fp_reps abss reps;
-        val rec_goals = map8 mk_goals recs xtors rec_strs pre_rec_maps fp_abss fp_reps abss reps;
-
-        fun mk_thms ss goals tac =
-          Library.foldr1 HOLogic.mk_conj goals
-          |> HOLogic.mk_Trueprop
-          |> fold_rev Logic.all ss
-          |> (fn goal => Goal.prove_sorry raw_lthy [] [] goal tac)
-          |> Thm.close_derivation
-          |> Morphism.thm phi
-          |> split_conj_thm
-          |> map (fn thm => thm RS @{thm comp_eq_dest});
+        val goals = map8 mk_goals recs xtors rec_strs pre_rec_maps fp_abss fp_reps abss reps;
 
         val pre_map_defs = no_refl (map map_def_of_bnf bnfs);
         val fp_pre_map_defs = no_refl (map map_def_of_bnf pre_bnfs);
 
         val unfold_map = map (unfold_thms lthy (id_apply :: pre_map_defs));
 
-        val fp_xtor_un_folds = map (mk_pointfree lthy) (of_fp_res #xtor_co_fold_thms);
         val fp_xtor_co_recs = map (mk_pointfree lthy) (of_fp_res #xtor_co_rec_thms);
 
         val fold_thms = fp_case fp @{thm comp_assoc} @{thm comp_assoc[symmetric]} ::
@@ -450,21 +417,26 @@
         val fp_Rep_o_Abss = map mk_Rep_o_Abs fp_type_definitions;
         val Rep_o_Abss = map mk_Rep_o_Abs type_definitions;
 
-        fun mk_tac defs o_map_thms xtor_thms thms {context = ctxt, prems = _} =
-          unfold_thms_tac ctxt (flat [thms, defs, pre_map_defs, fp_pre_map_defs,
-            xtor_thms, o_map_thms, map_thms, fp_Rep_o_Abss, Rep_o_Abss]) THEN
+        fun tac {context = ctxt, prems = _} =
+          unfold_thms_tac ctxt (flat [rec_thms, raw_co_rec_defs, pre_map_defs, fp_pre_map_defs,
+            fp_xtor_co_recs, fp_rec_o_maps, map_thms, fp_Rep_o_Abss, Rep_o_Abss]) THEN
           CONJ_WRAP (K (HEADGOAL (rtac refl))) bnfs;
-
-        val fold_tac = mk_tac raw_un_fold_defs fp_fold_o_maps fp_xtor_un_folds fold_thms;
-        val rec_tac = mk_tac raw_co_rec_defs fp_rec_o_maps fp_xtor_co_recs rec_thms;
       in
-        (mk_thms fold_strs fold_goals fold_tac, mk_thms rec_strs rec_goals rec_tac)
+        Library.foldr1 HOLogic.mk_conj goals
+        |> HOLogic.mk_Trueprop
+        |> fold_rev Logic.all rec_strs
+        |> (fn goal => Goal.prove_sorry raw_lthy [] [] goal tac)
+        |> Thm.close_derivation
+        |> Morphism.thm phi
+        |> split_conj_thm
+        |> map (fn thm => thm RS @{thm comp_eq_dest})
       end;
 
     (* These results are half broken. This is deliberate. We care only about those fields that are
        used by "primrec", "primcorecursive", and "datatype_compat". *)
     val fp_res =
-      ({Ts = fpTs, bnfs = of_fp_res #bnfs, dtors = dtors, ctors = ctors, xtor_co_folds = un_folds,
+      ({Ts = fpTs, bnfs = of_fp_res #bnfs, dtors = dtors, ctors = ctors,
+        xtor_co_folds = co_recs (*theorems about wrong constants*),
         xtor_co_recs = co_recs, xtor_co_induct = xtor_co_induct_thm,
         dtor_ctors = of_fp_res #dtor_ctors (*too general types*),
         ctor_dtors = of_fp_res #ctor_dtors (*too general types*),
@@ -473,9 +445,9 @@
         xtor_map_thms = of_fp_res #xtor_map_thms (*too general types and terms*),
         xtor_set_thmss = of_fp_res #xtor_set_thmss (*too general types and terms*),
         xtor_rel_thms = of_fp_res #xtor_rel_thms (*too general types and terms*),
-        xtor_co_fold_thms = xtor_un_fold_thms,
+        xtor_co_fold_thms = xtor_co_rec_thms (*theorems about wrong constants*),
         xtor_co_rec_thms = xtor_co_rec_thms,
-        xtor_co_fold_o_map_thms = fp_fold_o_maps (*theorems about old constants*),
+        xtor_co_fold_o_map_thms = fp_rec_o_maps (*theorems about wrong, old constants*),
         xtor_co_rec_o_map_thms = fp_rec_o_maps (*theorems about old constants*),
         rel_xtor_co_induct_thm = rel_xtor_co_induct_thm}
        |> morph_fp_result (Morphism.term_morphism "BNF" (singleton (Variable.polymorphic lthy))));