avoid type inference + tuning
authorblanchet
Mon Sep 10 17:36:02 2012 +0200 (2012-09-10)
changeset 49256df98aeb80a19
parent 49255 2ecc533d6697
child 49257 e9cdacf44cc3
avoid type inference + tuning
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 17:35:53 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 17:36:02 2012 +0200
     1.3 @@ -53,11 +53,19 @@
     1.4  val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
     1.5  val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT;
     1.6  
     1.7 -fun mk_InN_balanced ctxt sum_T Ts t k =
     1.8 +fun mk_InN_balanced sum_T n t k =
     1.9    let
    1.10 -    val u =
    1.11 -      Balanced_Tree.access {left = mk_Inl dummyT, right = mk_Inr dummyT, init = t} (length Ts) k;
    1.12 -  in singleton (Type_Infer_Context.infer_types ctxt) (Type.constraint sum_T u) end;
    1.13 +    fun repair_types T (Const (s as @{const_name Inl}, _) $ t) = repair_inj_types T s fst t
    1.14 +      | repair_types T (Const (s as @{const_name Inr}, _) $ t) = repair_inj_types T s snd t
    1.15 +      | repair_types _ t = t
    1.16 +    and repair_inj_types T s get t =
    1.17 +      let val T' = get (dest_sumT T) in
    1.18 +        Const (s, T' --> T) $ repair_types T' t
    1.19 +      end;
    1.20 +  in
    1.21 +    Balanced_Tree.access {left = mk_Inl dummyT, right = mk_Inr dummyT, init = t} n k
    1.22 +    |> repair_types sum_T
    1.23 +  end;
    1.24  
    1.25  val mk_sum_caseN_balanced = Balanced_Tree.make mk_sum_case;
    1.26  
    1.27 @@ -231,9 +239,8 @@
    1.28          if member (op =) Cs U then Us else [T]
    1.29        | dest_rec_pair T = [T];
    1.30  
    1.31 -    val ((iter_only as (gss, g_Tss, yssss), rec_only as (hss, h_Tss, zssss)),
    1.32 -         (zs, cs, cpss, p_Tss, coiter_only as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
    1.33 -          corec_only as ((phss, chsss), h_sum_prod_Ts, h_prod_Tss, h_Tsss))) =
    1.34 +    val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
    1.35 +         (zs, cs, cpss, coiter_only as ((pgss, cgsss), _), corec_only as ((phss, chsss), _))) =
    1.36        if lfp then
    1.37          let
    1.38            val y_Tsss =
    1.39 @@ -257,7 +264,7 @@
    1.40              |> mk_Freessss "x" z_Tssss;
    1.41          in
    1.42            (((gss, g_Tss, map (map (map single)) ysss), (hss, h_Tss, zssss)),
    1.43 -           ([], [], [], [], (([], []), [], [], []), (([], []), [], [], [])))
    1.44 +           ([], [], [], (([], []), ([], [])), (([], []), ([], []))))
    1.45          end
    1.46        else
    1.47          let
    1.48 @@ -277,10 +284,10 @@
    1.49                val f_Tsss =
    1.50                  map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
    1.51                val pf_Tss = map2 zip_preds_getters p_Tss f_Tsss
    1.52 -            in (f_sum_prod_Ts, f_prod_Tss, f_Tsss, pf_Tss) end;
    1.53 +            in (f_sum_prod_Ts, f_Tsss, pf_Tss) end;
    1.54  
    1.55 -          val (g_sum_prod_Ts, g_prod_Tss, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
    1.56 -          val (h_sum_prod_Ts, h_prod_Tss, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
    1.57 +          val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
    1.58 +          val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
    1.59  
    1.60            val ((((Free (z, _), cs), pss), gsss), _) =
    1.61              lthy
    1.62 @@ -300,8 +307,8 @@
    1.63              in (pfss, cfsss) end;
    1.64          in
    1.65            ((([], [], []), ([], [], [])),
    1.66 -           ([z], cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
    1.67 -            (mk_terms hsss, h_sum_prod_Ts, h_prod_Tss, ph_Tss)))
    1.68 +           ([z], cs, cpss, (mk_terms gsss, (g_sum_prod_Ts, pg_Tss)),
    1.69 +            (mk_terms hsss, (h_sum_prod_Ts, ph_Tss))))
    1.70          end;
    1.71  
    1.72      fun pour_some_sugar_on_type (((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
    1.73 @@ -321,10 +328,8 @@
    1.74            ||>> mk_Freess "x" ctr_Tss;
    1.75  
    1.76          val ctr_rhss =
    1.77 -          map2 (fn k => fn xs =>
    1.78 -              fold_rev Term.lambda xs (fld $ mk_InN_balanced no_defs_lthy ctr_sum_prod_T ctr_prod_Ts
    1.79 -                (HOLogic.mk_tuple xs) k))
    1.80 -            ks xss;
    1.81 +          map2 (fn k => fn xs => fold_rev Term.lambda xs (fld $
    1.82 +            mk_InN_balanced ctr_sum_prod_T n (HOLogic.mk_tuple xs) k)) ks xss;
    1.83  
    1.84          val case_binder = Binding.suffix_name ("_" ^ caseN) b;
    1.85  
    1.86 @@ -429,22 +434,20 @@
    1.87            let
    1.88              val B_to_fpT = C --> fpT;
    1.89  
    1.90 -            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), f_sum_prod_Ts, f_prod_Tss,
    1.91 -                pf_Tss)) =
    1.92 +            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), (f_sum_prod_Ts, pf_Tss))) =
    1.93                let
    1.94                  val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
    1.95  
    1.96                  val binder = Binding.suffix_name ("_" ^ suf) b;
    1.97  
    1.98 -                fun mk_preds_getters_join c n cps sum_prod_T prod_Ts cfss =
    1.99 +                fun mk_preds_getters_join c n cps sum_prod_T cfss =
   1.100                    Term.lambda c (mk_IfN sum_prod_T cps
   1.101 -                    (map2 (mk_InN_balanced no_defs_lthy sum_prod_T prod_Ts)
   1.102 -                      (map HOLogic.mk_tuple cfss) (1 upto n)));
   1.103 +                    (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cfss) (1 upto n)));
   1.104  
   1.105                  val spec =
   1.106                    mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
   1.107                      Term.list_comb (fp_iter_like,
   1.108 -                      map6 mk_preds_getters_join cs ns cpss f_sum_prod_Ts f_prod_Tss cfsss));
   1.109 +                      map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfsss));
   1.110                in (binder, spec) end;
   1.111  
   1.112              val coiter_likes =
   1.113 @@ -550,7 +553,7 @@
   1.114          lthy |> Local_Theory.notes notes |> snd
   1.115        end;
   1.116  
   1.117 -    fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, xsss, ctr_defss, coiter_defs,
   1.118 +    fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, _, ctr_defss, coiter_defs,
   1.119          corec_defs), lthy) =
   1.120        let
   1.121          val z = the_single zs;