src/HOL/Tools/BNF/bnf_fp_n2m.ML
changeset 62684 cb20e8828196
parent 62649 d23be25c0835
child 62689 9b8b3db6ac03
--- a/src/HOL/Tools/BNF/bnf_fp_n2m.ML	Mon Mar 21 21:18:08 2016 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m.ML	Tue Mar 22 07:18:36 2016 +0100
@@ -8,7 +8,7 @@
 signature BNF_FP_N2M =
 sig
   val construct_mutualized_fp: BNF_Util.fp_kind -> int list -> typ list ->
-    BNF_FP_Def_Sugar.fp_sugar list -> binding list -> (string * sort) list ->
+    (int * BNF_FP_Util.fp_result) list -> binding list -> (string * sort) list ->
     typ list * typ list list -> BNF_Def.bnf list -> BNF_Comp.absT_info list -> local_theory ->
     BNF_FP_Util.fp_result * local_theory
 end;
@@ -47,12 +47,10 @@
     Const (@{const_name map_sum}, fT --> gT --> mk_sumT (fAT, gAT) --> mk_sumT (fBT, gBT)) $ f $ g
   end;
 
-fun construct_mutualized_fp fp mutual_cliques fpTs (fp_sugars : fp_sugar list) bs resBs (resDs, Dss)
-    bnfs (absT_infos : absT_info list) lthy =
+fun construct_mutualized_fp fp mutual_cliques fpTs (fp_results : (int * fp_result) list) bs resBs
+    (resDs, Dss) bnfs (absT_infos : absT_info list) lthy =
   let
-    fun of_fp_res get =
-      map (fn {fp_res, fp_res_index, ...} => nth (get fp_res) fp_res_index) fp_sugars;
-
+    fun of_fp_res get = map (uncurry nth o swap o apsnd get) fp_results;
     fun mk_co_algT T U = case_fp fp (T --> U) (U --> T);
     fun co_swap pair = case_fp fp I swap pair;
     val mk_co_comp = HOLogic.mk_comp o co_swap;
@@ -68,13 +66,9 @@
     val dest_co_productT = case_fp fp HOLogic.dest_prodT dest_sumT;
     val rewrite_comp_comp = case_fp fp @{thm rewriteL_comp_comp} @{thm rewriteR_comp_comp};
 
-    val fp_absT_infos = map #absT_info fp_sugars;
+    val fp_absT_infos = of_fp_res #absT_infos;
     val fp_bnfs = of_fp_res #bnfs;
-    val pre_bnfs = map #pre_bnf fp_sugars;
-    val nesting_bnfss =
-      map (fn sugar => #fp_nesting_bnfs sugar @ #live_nesting_bnfs sugar) fp_sugars;
-    val fp_or_nesting_bnfss = fp_bnfs :: nesting_bnfss;
-    val fp_or_nesting_bnfs = distinct (op = o apply2 T_of_bnf) (flat fp_or_nesting_bnfss);
+    val pre_bnfs = of_fp_res #pre_bnfs;
 
     val fp_absTs = map #absT fp_absT_infos;
     val fp_repTs = map #repT fp_absT_infos;
@@ -130,6 +124,15 @@
     val fp_repAs = map2 mk_rep absATs fp_reps;
     val fp_repBs = map2 mk_rep absBTs fp_reps;
 
+    val typ_subst_nonatomic_sorted = fold_rev (typ_subst_nonatomic o single);
+    val sorted_theta = sort (int_ord o apply2 (Term.size_of_typ o fst)) (fpTs ~~ Xs)
+    val sorted_fpTs = map fst sorted_theta;
+
+    val nesting_bnfs = nesting_bnfs lthy
+      [[map (typ_subst_nonatomic_sorted (rev sorted_theta) o range_type o fastype_of) fp_repAs]]
+      allAs;
+    val fp_or_nesting_bnfs = distinct (op = o apply2 T_of_bnf) (fp_bnfs @ nesting_bnfs);
+
     val (((((phis, phis'), pre_phis), xs), ys), names_lthy) = names_lthy
       |> mk_Frees' "R" phiTs
       ||>> mk_Frees "S" pre_phiTs
@@ -138,9 +141,9 @@
 
     val rels =
       let
-        fun find_rel T As Bs = fp_or_nesting_bnfss
-          |> map (filter_out (curry (op = o apply2 name_of_bnf) BNF_Comp.DEADID_bnf))
-          |> get_first (find_first (fn bnf => Type.could_unify (T_of_bnf bnf, T)))
+        fun find_rel T As Bs = fp_or_nesting_bnfs
+          |> filter_out (curry (op = o apply2 name_of_bnf) BNF_Comp.DEADID_bnf)
+          |> find_first (fn bnf => Type.could_unify (T_of_bnf bnf, T))
           |> Option.map (fn bnf =>
             let val live = live_of_bnf bnf;
             in (mk_rel live As Bs (rel_of_bnf bnf), live) end)
@@ -258,9 +261,7 @@
       |> mk_Frees' "s" rec_strTs;
 
     val co_recs = of_fp_res #xtor_co_recs;
-    val ns = map (length o #Ts o #fp_res) fp_sugars;
-
-    val typ_subst_nonatomic_sorted = fold_rev (typ_subst_nonatomic o single);
+    val ns = map (length o #Ts o snd) fp_results;
 
     fun foldT_of_recT recT =
       let
@@ -288,8 +289,7 @@
         val fold_preTs' = mk_fp_absT_repTs (map subst fold_preTs);
 
         val fold_pre_deads_only_Ts =
-          map (typ_subst_nonatomic_sorted (map (rpair dummyT)
-            (As @ sort (int_ord o apply2 Term.size_of_typ) fpTs))) fold_preTs';
+          map (typ_subst_nonatomic_sorted (map (rpair dummyT) (As @ sorted_fpTs))) fold_preTs';
 
         val (mutual_clique, TUs) =
           map_split dest_co_algT (binder_fun_types (foldT_of_recT (fastype_of approx_rec)))
@@ -481,7 +481,8 @@
     (* These results are half broken. This is deliberate. We care only about those fields that are
        used by "primrec", "primcorecursive", and "datatype_compat". *)
     val fp_res =
-      ({Ts = fpTs, bnfs = of_fp_res #bnfs, dtors = dtors, ctors = ctors,
+      ({Ts = fpTs, bnfs = of_fp_res #bnfs, pre_bnfs = bnfs, absT_infos = absT_infos,
+        dtors = dtors, ctors = ctors,
         xtor_un_folds = co_recs (*wrong*), xtor_co_recs = co_recs,
         xtor_co_induct = xtor_co_induct_thm,
         dtor_ctors = of_fp_res #dtor_ctors (*too general types*),