--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200
@@ -15,6 +15,7 @@
open BNF_Util
open BNF_Wrap
+open BNF_Def
open BNF_FP_Util
open BNF_LFP
open BNF_GFP
@@ -26,7 +27,14 @@
val itersN = "iters";
val recsN = "recs";
-fun split_list7 xs = (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs);
+fun split_list8 xs =
+ (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs, map #8 xs);
+
+fun typ_subst inst (T as Type (s, Ts)) =
+ (case AList.lookup (op =) inst T of
+ NONE => Type (s, map (typ_subst inst) Ts)
+ | SOME T' => T')
+ | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
fun retype_free (Free (s, _)) T = Free (s, T);
@@ -37,6 +45,8 @@
fun mk_uncurried2_fun f xss =
mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
+fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v))
+
fun popescu_zip [] [fs] = fs
| popescu_zip (p :: ps) (fs :: fss) = p :: fs @ popescu_zip ps fss;
@@ -160,14 +170,14 @@
val mss = map (map length) ctr_Tsss;
val Css = map2 replicate ns Cs;
- fun mk_iter_like Ts Us c =
+ fun mk_iter_like Ts Us t =
let
- val (binders, body) = strip_type (fastype_of c);
+ val (binders, body) = strip_type (fastype_of t);
val (f_Us, prebody) = split_last binders;
val Type (_, Ts0) = if lfp then prebody else body;
val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
in
- Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c
+ Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
end;
val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0;
@@ -359,7 +369,7 @@
val iter = mk_iter_like As Cs iter0;
val recx = mk_iter_like As Cs rec0;
in
- ((ctrs, iter, recx, xss, ctr_defs, iter_def, rec_def), lthy)
+ ((ctrs, iter, recx, v, xss, ctr_defs, iter_def, rec_def), lthy)
end;
fun some_gfp_sugar no_defs_lthy =
@@ -402,14 +412,19 @@
val [coiter, corec] = map (mk_iter_like As Cs o Morphism.term phi) csts;
in
- ((ctrs, coiter, corec, xss, ctr_defs, coiter_def, corec_def), lthy)
+ ((ctrs, coiter, corec, v, xss, ctr_defs, coiter_def, corec_def), lthy)
end;
in
wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
|> (if lfp then some_lfp_sugar else some_gfp_sugar)
end;
- fun pour_more_sugar_on_lfps ((ctrss, iters, recs, xsss, ctr_defss, iter_defs, rec_defs),
+ fun mk_map Ts Us t =
+ let val (Type (_, Ts0), Type (_, Us0)) = strip_type (fastype_of t) |>> List.last in
+ Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+ end;
+
+ fun pour_more_sugar_on_lfps ((ctrss, iters, recs, vs, xsss, ctr_defss, iter_defs, rec_defs),
lthy) =
let
val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
@@ -422,13 +437,40 @@
fold_rev (fold_rev Logic.all) (xs :: fss)
(mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs)));
- fun repair_iter_call (x as Free (_, T)) =
- (case find_index (curry (op =) T) fpTs of ~1 => x | j => nth giters j $ x);
+ fun build_iter_like fiter_likes maybe_tick =
+ let
+ fun build (T, U) =
+ if T = U then
+ Const (@{const_name id}, T --> T)
+ else
+ (case (find_index (curry (op =) T) fpTs, (T, U)) of
+ (~1, (Type (s, Ts), Type (_, Us))) =>
+ let
+ val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
+ val mapx = mk_map Ts Us map0;
+ val TUs = map dest_funT (fst (split_last (binder_types (fastype_of mapx))));
+ val args = map build TUs;
+ in Term.list_comb (mapx, args) end
+ | (j, _) => maybe_tick (nth vs j) (nth fiter_likes j))
+ in build end;
+
+ fun mk_U maybe_prodT =
+ typ_subst (map2 (fn fpT => fn C => (fpT, maybe_prodT fpT C)) fpTs Cs);
+
+ fun repair_calls fiter_likes maybe_cons maybe_tick maybe_prodT (x as Free (_, T)) =
+ if member (op =) fpTs T then
+ maybe_cons x [build_iter_like fiter_likes (K I) (T, mk_U (K I) T) $ x]
+ else if exists_subtype (member (op =) fpTs) T then
+ [build_iter_like fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
+ else
+ [x];
+
fun repair_rec_call (x as Free (_, T)) =
(case find_index (curry (op =) T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]);
- val gxsss = map (map (map repair_iter_call)) xsss;
- val hxsss = map (map (maps repair_rec_call)) xsss;
+ val gxsss = map (map (maps (repair_calls giters (K I) (K I) (K I)))) xsss;
+ val hxsss =
+ map (map (maps (repair_calls hrecs cons tick (curry HOLogic.mk_prodT)))) xsss;
val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss gxsss;
val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss hxsss;
@@ -455,8 +497,8 @@
lthy |> Local_Theory.notes notes |> snd
end;
- fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, xsss, ctr_defss, coiter_defs, corec_defs),
- lthy) =
+ fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, xsss, ctr_defss, coiter_defs,
+ corec_defs), lthy) =
let
val gcoiters = map (lists_bmoc pgss) coiters;
val hcorecs = map (lists_bmoc phss) corecs;
@@ -505,7 +547,7 @@
|> fold_map pour_some_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_binderss ~~
ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
- |>> split_list7
+ |>> split_list8
|> (if lfp then pour_more_sugar_on_lfps else pour_more_sugar_on_gfps);
val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^