avoid type inference + tuning
authorblanchet
Mon, 10 Sep 2012 17:36:02 +0200
changeset 49256 df98aeb80a19
parent 49255 2ecc533d6697
child 49257 e9cdacf44cc3
avoid type inference + tuning
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 17:35:53 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 17:36:02 2012 +0200
@@ -53,11 +53,19 @@
 val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
 val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT;
 
-fun mk_InN_balanced ctxt sum_T Ts t k =
+fun mk_InN_balanced sum_T n t k =
   let
-    val u =
-      Balanced_Tree.access {left = mk_Inl dummyT, right = mk_Inr dummyT, init = t} (length Ts) k;
-  in singleton (Type_Infer_Context.infer_types ctxt) (Type.constraint sum_T u) end;
+    fun repair_types T (Const (s as @{const_name Inl}, _) $ t) = repair_inj_types T s fst t
+      | repair_types T (Const (s as @{const_name Inr}, _) $ t) = repair_inj_types T s snd t
+      | repair_types _ t = t
+    and repair_inj_types T s get t =
+      let val T' = get (dest_sumT T) in
+        Const (s, T' --> T) $ repair_types T' t
+      end;
+  in
+    Balanced_Tree.access {left = mk_Inl dummyT, right = mk_Inr dummyT, init = t} n k
+    |> repair_types sum_T
+  end;
 
 val mk_sum_caseN_balanced = Balanced_Tree.make mk_sum_case;
 
@@ -231,9 +239,8 @@
         if member (op =) Cs U then Us else [T]
       | dest_rec_pair T = [T];
 
-    val ((iter_only as (gss, g_Tss, yssss), rec_only as (hss, h_Tss, zssss)),
-         (zs, cs, cpss, p_Tss, coiter_only as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
-          corec_only as ((phss, chsss), h_sum_prod_Ts, h_prod_Tss, h_Tsss))) =
+    val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
+         (zs, cs, cpss, coiter_only as ((pgss, cgsss), _), corec_only as ((phss, chsss), _))) =
       if lfp then
         let
           val y_Tsss =
@@ -257,7 +264,7 @@
             |> mk_Freessss "x" z_Tssss;
         in
           (((gss, g_Tss, map (map (map single)) ysss), (hss, h_Tss, zssss)),
-           ([], [], [], [], (([], []), [], [], []), (([], []), [], [], [])))
+           ([], [], [], (([], []), ([], [])), (([], []), ([], []))))
         end
       else
         let
@@ -277,10 +284,10 @@
               val f_Tsss =
                 map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
               val pf_Tss = map2 zip_preds_getters p_Tss f_Tsss
-            in (f_sum_prod_Ts, f_prod_Tss, f_Tsss, pf_Tss) end;
+            in (f_sum_prod_Ts, f_Tsss, pf_Tss) end;
 
-          val (g_sum_prod_Ts, g_prod_Tss, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
-          val (h_sum_prod_Ts, h_prod_Tss, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
+          val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
+          val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
 
           val ((((Free (z, _), cs), pss), gsss), _) =
             lthy
@@ -300,8 +307,8 @@
             in (pfss, cfsss) end;
         in
           ((([], [], []), ([], [], [])),
-           ([z], cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
-            (mk_terms hsss, h_sum_prod_Ts, h_prod_Tss, ph_Tss)))
+           ([z], cs, cpss, (mk_terms gsss, (g_sum_prod_Ts, pg_Tss)),
+            (mk_terms hsss, (h_sum_prod_Ts, ph_Tss))))
         end;
 
     fun pour_some_sugar_on_type (((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
@@ -321,10 +328,8 @@
           ||>> mk_Freess "x" ctr_Tss;
 
         val ctr_rhss =
-          map2 (fn k => fn xs =>
-              fold_rev Term.lambda xs (fld $ mk_InN_balanced no_defs_lthy ctr_sum_prod_T ctr_prod_Ts
-                (HOLogic.mk_tuple xs) k))
-            ks xss;
+          map2 (fn k => fn xs => fold_rev Term.lambda xs (fld $
+            mk_InN_balanced ctr_sum_prod_T n (HOLogic.mk_tuple xs) k)) ks xss;
 
         val case_binder = Binding.suffix_name ("_" ^ caseN) b;
 
@@ -429,22 +434,20 @@
           let
             val B_to_fpT = C --> fpT;
 
-            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), f_sum_prod_Ts, f_prod_Tss,
-                pf_Tss)) =
+            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), (f_sum_prod_Ts, pf_Tss))) =
               let
                 val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
 
                 val binder = Binding.suffix_name ("_" ^ suf) b;
 
-                fun mk_preds_getters_join c n cps sum_prod_T prod_Ts cfss =
+                fun mk_preds_getters_join c n cps sum_prod_T cfss =
                   Term.lambda c (mk_IfN sum_prod_T cps
-                    (map2 (mk_InN_balanced no_defs_lthy sum_prod_T prod_Ts)
-                      (map HOLogic.mk_tuple cfss) (1 upto n)));
+                    (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cfss) (1 upto n)));
 
                 val spec =
                   mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
                     Term.list_comb (fp_iter_like,
-                      map6 mk_preds_getters_join cs ns cpss f_sum_prod_Ts f_prod_Tss cfsss));
+                      map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfsss));
               in (binder, spec) end;
 
             val coiter_likes =
@@ -550,7 +553,7 @@
         lthy |> Local_Theory.notes notes |> snd
       end;
 
-    fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, xsss, ctr_defss, coiter_defs,
+    fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, _, ctr_defss, coiter_defs,
         corec_defs), lthy) =
       let
         val z = the_single zs;