src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 49200 73f9aede57a4
parent 49199 7c9a3c67c55d
child 49201 c69c2c18dccb
equal deleted inserted replaced
49199:7c9a3c67c55d 49200:73f9aede57a4
    20 open BNF_FP_Sugar_Tactics
    20 open BNF_FP_Sugar_Tactics
    21 
    21 
    val caseN = "case";

    (* retype_free t T: replace the type of the free variable t by T, keeping its name.
       Any other kind of term is rejected with TERM (naming the offending term) instead of
       the bare Match exception and non-exhaustiveness warning of a partial pattern. *)
    fun retype_free (Free (s, _)) T = Free (s, T)
      | retype_free t _ = raise TERM ("retype_free", [t]);
       
    25 
       
    (* mk_tupled_fun x f xs: abstract the application "f xs" over the pattern x using
       HOLogic.tupled_lambda. Here x is presumably a variable or a (nested) tuple of the
       variables built with HOLogic.mk_tuple, as the two helpers below pass. *)
    fun mk_tupled_fun x f xs =
      let val body = Term.list_comb (f, xs)
      in HOLogic.tupled_lambda x body end;

    (* mk_uncurried_fun f xs: abstract "f xs" over the single tuple of all xs. *)
    fun mk_uncurried_fun f xs =
      let val pattern = HOLogic.mk_tuple xs
      in mk_tupled_fun pattern f xs end;

    (* mk_doubly_uncurried_fun f xss: abstract "f (flat xss)" over a tuple of tuples,
       one inner tuple per group in xss. *)
    fun mk_doubly_uncurried_fun f xss =
      let
        val pattern = HOLogic.mk_tuple (map HOLogic.mk_tuple xss);
        val args = flat xss;
      in mk_tupled_fun pattern f args end;
    25 
    30 
    (* Abort (via error) because mutually recursive types were declared with
       different type parameters. *)
    fun cannot_merge_types () =
      let val msg = "Mutually recursive types must have the same type parameters"
      in error msg end;
    27 
    32 
    28 fun merge_type_arg_constrained ctxt (T, c) (T', c') =
    33 fun merge_type_arg_constrained ctxt (T, c) (T', c') =
    29   if T = T' then
    34   if T = T' then
   240 
   245 
   241         val ns = map length ctr_Tsss;
   246         val ns = map length ctr_Tsss;
   242         val mss = map (map length) ctr_Tsss;
   247         val mss = map (map length) ctr_Tsss;
   243         val Css = map2 replicate ns Cs;
   248         val Css = map2 replicate ns Cs;
   244 
   249 
       
   250         fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
       
   251             if member (op =) Cs U then Us else [T]
       
   252           | dest_rec_pair T = [T];
       
   253 
   245         fun sugar_datatype no_defs_lthy =
   254         fun sugar_datatype no_defs_lthy =
   246           let
   255           let
   247             val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter))));
   256             val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter))));
   248             val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts;
   257             val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts;
   249             val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss;
   258             val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss;
   251             val iter_T = flat g_Tss ---> fp_T --> C;
   260             val iter_T = flat g_Tss ---> fp_T --> C;
   252 
   261 
   253             val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec))));
   262             val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec))));
   254             val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts;
   263             val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts;
   255             val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss;
   264             val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss;
   256             val h_Tss = map2 (map2 (curry (op --->))) z_Tsss Css;
   265             val z_Tssss = map (map (map dest_rec_pair)) z_Tsss;
       
   266             val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
   257             val rec_T = flat h_Tss ---> fp_T --> C;
   267             val rec_T = flat h_Tss ---> fp_T --> C;
   258 
   268 
   259             val ((gss, ysss), _) =
   269             val ((gss, ysss), _) =
   260               no_defs_lthy
   270               no_defs_lthy
   261               |> mk_Freess "f" g_Tss
   271               |> mk_Freess "f" g_Tss
   262               ||>> mk_Freesss "x" y_Tsss;
   272               ||>> mk_Freesss "x" y_Tsss;
   263 
   273 
   264             val hss = map2 (map2 retype_free) gss h_Tss;
   274             val hss = map2 (map2 retype_free) gss h_Tss;
   265             val (zsss, _) =
   275             val (zssss, _) =
   266               no_defs_lthy
   276               no_defs_lthy
   267               |> mk_Freesss "x" z_Tsss;
   277               |> mk_Freessss "x" z_Tssss;
   268 
   278 
   269             val iter_binder = Binding.suffix_name ("_" ^ iterN) b;
   279             val iter_binder = Binding.suffix_name ("_" ^ iterN) b;
   270             val rec_binder = Binding.suffix_name ("_" ^ recN) b;
   280             val rec_binder = Binding.suffix_name ("_" ^ recN) b;
   271 
   281 
   272             val iter_free = Free (Binding.name_of iter_binder, iter_T);
   282             val iter_free = Free (Binding.name_of iter_binder, iter_T);
   275             val iter_spec =
   285             val iter_spec =
   276               mk_Trueprop_eq (fold (fn gs => fn t => Term.list_comb (t, gs)) gss iter_free,
   286               mk_Trueprop_eq (fold (fn gs => fn t => Term.list_comb (t, gs)) gss iter_free,
   277                 Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
   287                 Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
   278             val rec_spec =
   288             val rec_spec =
   279               mk_Trueprop_eq (fold (fn hs => fn t => Term.list_comb (t, hs)) hss rec_free,
   289               mk_Trueprop_eq (fold (fn hs => fn t => Term.list_comb (t, hs)) hss rec_free,
   280                 Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) hss zsss));
   290                 Term.list_comb (fp_rec,
       
   291                   map2 (mk_sum_caseN oo map2 mk_doubly_uncurried_fun) hss zssss));
   281 
   292 
   282             val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy
   293             val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy
   283               |> apfst split_list o fold_map (fn (b, spec) =>
   294               |> apfst split_list o fold_map (fn (b, spec) =>
   284                 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
   295                 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
   285                 #>> apsnd snd) [(iter_binder, iter_spec), (rec_binder, rec_spec)]
   296                 #>> apsnd snd) [(iter_binder, iter_spec), (rec_binder, rec_spec)]