--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Thu Sep 06 02:56:21 2012 +0200
@@ -52,9 +52,6 @@
fun args_of ((_, args), _) = args;
fun mixfix_of_ctr (_, mx) = mx;
-val uncurry_fs =
- map2 (fn f => fn xs => HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs)));
-
fun prepare_data prepare_typ gfp specs fake_lthy lthy =
let
val constrained_As =
@@ -75,7 +72,7 @@
As);
val bs = map type_binder_of specs;
- val Ts = map mk_T bs;
+ val fp_Ts = map mk_T bs;
val mixfixes = map mixfix_of_typ specs;
@@ -98,35 +95,35 @@
| A' :: _ => error ("Extra type variables on rhs: " ^
quote (Syntax.string_of_typ lthy (TFree A'))));
- val (Bs, C) =
+ val ((Cs, Xs), _) =
lthy
|> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs
|> mk_TFrees N
- ||> the_single o fst o mk_TFrees 1;
+ ||>> mk_TFrees N;
- fun is_same_rec (T as Type (s, Us)) (Type (s', Us')) =
+ fun is_same_recT (T as Type (s, Us)) (Type (s', Us')) =
s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^
quote (Syntax.string_of_typ fake_lthy T)))
- | is_same_rec _ _ = false
+ | is_same_recT _ _ = false;
- fun freeze_rec (T as Type (s, Us)) =
- (case find_index (is_same_rec T) Ts of
- ~1 => Type (s, map freeze_rec Us)
- | i => nth Bs i)
- | freeze_rec T = T;
+ fun freeze_recXs (T as Type (s, Us)) =
+ (case find_index (is_same_recT T) fp_Ts of
+ ~1 => Type (s, map freeze_recXs Us)
+ | i => nth Xs i)
+ | freeze_recXs T = T;
- val ctr_TsssBs = map (map (map freeze_rec)) ctr_Tsss;
- val sum_prod_TsBs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssBs;
+ val ctr_TsssXs = map (map (map freeze_recXs)) ctr_Tsss;
+ val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs;
- val eqs = map dest_TFree Bs ~~ sum_prod_TsBs;
+ val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
- val ((raw_unfs, raw_flds, unf_flds, fld_unfs, fld_injects), lthy') =
+ val ((raw_unfs, raw_flds, raw_fp_iters, raw_fp_recs, unf_flds, fld_unfs, fld_injects), lthy') =
fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs lthy;
val timer = time (Timer.startRealTimer ());
- fun mk_unf_or_fld get_foldedT Ts t =
- let val Type (_, Ts0) = get_foldedT (fastype_of t) in
+ fun mk_unf_or_fld get_T Ts t =
+ let val Type (_, Ts0) = get_T (fastype_of t) in
Term.subst_atomic_types (Ts0 ~~ Ts) t
end;
@@ -136,10 +133,23 @@
val unfs = map (mk_unf As) raw_unfs;
val flds = map (mk_fld As) raw_flds;
- fun pour_sugar_on_type (((((((((((b, T), fld), unf), fld_unf), unf_fld), fld_inject),
- ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss) no_defs_lthy =
+ fun mk_fp_iter_or_rec Ts Us t =
let
- val n = length ctr_binders;
+ val (binders, body) = strip_type (fastype_of t);
+ val Type (_, Ts0) = if gfp then body else List.last binders;
+ val Us0 = map (if gfp then domain_type else body_type) (fst (split_last binders));
+ in
+ Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+ end;
+
+ val fp_iters = map (mk_fp_iter_or_rec As Cs) raw_fp_iters;
+ val fp_recs = map (mk_fp_iter_or_rec As Cs) raw_fp_recs;
+
+ fun pour_sugar_on_type ((((((((((((((b, fp_T), C), fld), unf), fp_iter), fp_rec), fld_unf),
+ unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss)
+ no_defs_lthy =
+ let
+ val n = length ctr_Tss;
val ks = 1 upto n;
val ms = map length ctr_Tss;
@@ -147,11 +157,11 @@
val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
- val ((((fs, u), v), xss), _) =
+ val ((((u, v), fs), xss), _) =
lthy
- |> mk_Frees "f" case_Ts
- ||>> yield_singleton (mk_Frees "u") unf_T
- ||>> yield_singleton (mk_Frees "v") T
+ |> yield_singleton (mk_Frees "u") unf_T
+ ||>> yield_singleton (mk_Frees "v") fp_T
+ ||>> mk_Frees "f" case_Ts
||>> mk_Freess "x" ctr_Tss;
val ctr_rhss =
@@ -161,7 +171,7 @@
val case_binder = Binding.suffix_name ("_" ^ caseN) b;
val case_rhs =
- fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (uncurry_fs fs xss) $ (unf $ v));
+ fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (map2 mk_uncurried_fun fs xss) $ (unf $ v));
val (((raw_ctrs, raw_ctr_defs), (raw_case, raw_case_def)), (lthy', lthy)) = no_defs_lthy
|> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
@@ -189,8 +199,8 @@
(mk_Trueprop_eq (HOLogic.mk_eq (v, fld $ u), HOLogic.mk_eq (unf $ v, u)));
in
Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
- mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, T]) (certify lthy fld)
- (certify lthy unf) fld_unf unf_fld)
+ mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, fp_T])
+ (certify lthy fld) (certify lthy unf) fld_unf unf_fld)
|> Thm.close_derivation
|> Morphism.thm phi
end;
@@ -219,24 +229,30 @@
val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
+ (* (co)iterators, (co)recursors, (co)induction *)
+
+ val is_recT = member (op =) fp_Ts;
+
+ val ns = map length ctr_Tsss;
+ val mss = map (map length) ctr_Tsss;
+ val Css = map2 replicate ns Cs;
+
fun sugar_lfp lthy =
let
-(*###
- val fld_iter = @{term True}; (*###*)
+ val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter))));
+ val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts;
+ val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss;
+ val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
+ val iter_T = flat g_Tss ---> fp_T --> C;
- val iter_Tss = map (fn Ts => Ts) (*###*) ctr_Tss;
- val iter_Ts = map (fn Ts => Ts ---> C) iter_Tss;
-
- val iter_fs = map2 (fn Free (s, _) => fn T => Free (s, T)) fs iter_Ts
+ val ((gss, ysss), _) =
+ lthy
+ |> mk_Freess "f" g_Tss
+ ||>> apfst (unflat y_Tsss) o mk_Freess "x" (flat y_Tsss);
val iter_rhs =
- fold_rev Term.lambda fs (fld_iter $ mk_sum_caseN (uncurry_fs fs xss) $ (unf $ v));
-
-
- val uncurried_fs =
- map2 (fn f => fn xs =>
- HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs))) fs xss;
-*)
+ fold_rev (fold_rev Term.lambda) gss
+ (Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
in
lthy
end;
@@ -248,8 +264,9 @@
end;
val lthy'' =
- fold pour_sugar_on_type (bs ~~ Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~
- ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss) lthy';
+ fold pour_sugar_on_type (bs ~~ fp_Ts ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ fp_recs ~~
+ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~
+ disc_binderss ~~ sel_bindersss) lthy';
val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
(if gfp then "co" else "") ^ "datatype"));
--- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML Thu Sep 06 02:56:21 2012 +0200
@@ -88,6 +88,11 @@
val mk_sum_case: term -> term -> term
val mk_sum_caseN: term list -> term
+ val dest_sumTN: int -> typ -> typ list
+ val dest_tupleT: int -> typ -> typ list
+
+ val mk_uncurried_fun: term -> term list -> term
+
val mk_Field: term -> term
val mk_union: term * term -> term
@@ -219,6 +224,16 @@
fun mk_sum_caseN [f] = f
| mk_sum_caseN (f :: fs) = mk_sum_case f (mk_sum_caseN fs);
+fun dest_sumTN 1 T = [T]
+ | dest_sumTN n (Type (@{type_name sum}, [T, T'])) = T :: dest_sumTN (n - 1) T';
+
+(* TODO: move something like this to "HOLogic"? *)
+fun dest_tupleT 0 @{typ unit} = []
+ | dest_tupleT 1 T = [T]
+ | dest_tupleT n (Type (@{type_name prod}, [T, T'])) = T :: dest_tupleT (n - 1) T';
+
+fun mk_uncurried_fun f xs = HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs));
+
fun mk_Field r =
let val T = fst (dest_relT (fastype_of r));
in Const (@{const_name Field}, mk_relT (T, T) --> HOLogic.mk_setT T) $ r end;
--- a/src/HOL/Codatatype/Tools/bnf_gfp.ML Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML Thu Sep 06 02:56:21 2012 +0200
@@ -11,7 +11,7 @@
sig
val bnf_gfp: binding list -> mixfix list -> (string * sort) list -> typ list list ->
BNF_Def.BNF list -> local_theory ->
- (term list * term list * thm list * thm list * thm list) * local_theory
+ (term list * term list * term list * term list * thm list * thm list * thm list) * local_theory
end;
structure BNF_GFP : BNF_GFP =
@@ -1965,8 +1965,9 @@
(*transforms defined frees into consts*)
val phi = Proof_Context.export_morphism lthy_old lthy;
- val coiters = map (fst o dest_Const o Morphism.term phi) coiter_frees;
- fun mk_coiter Ts ss i = Term.list_comb (Const (nth coiters (i - 1), Library.foldr (op -->)
+ val coiters = map (Morphism.term phi) coiter_frees;
+ val coiter_names = map (fst o dest_Const) coiters;
+ fun mk_coiter Ts ss i = Term.list_comb (Const (nth coiter_names (i - 1), Library.foldr (op -->)
(map fastype_of ss, domain_type (fastype_of (nth ss (i - 1))) --> nth Ts (i - 1))), ss);
val coiter_defs = map ((fn thm => thm RS fun_cong) o Morphism.thm phi) coiter_def_frees;
@@ -2158,8 +2159,9 @@
(*transforms defined frees into consts*)
val phi = Proof_Context.export_morphism lthy_old lthy;
- val corecs = map (fst o dest_Const o Morphism.term phi) corec_frees;
- fun mk_corec ss i = Term.list_comb (Const (nth corecs (i - 1), Library.foldr (op -->)
+ val corecs = map (Morphism.term phi) corec_frees;
+ val corec_names = map (fst o dest_Const) corecs;
+ fun mk_corec ss i = Term.list_comb (Const (nth corec_names (i - 1), Library.foldr (op -->)
(map fastype_of ss, domain_type (fastype_of (nth ss (i - 1))) --> nth Ts (i - 1))), ss);
val corec_defs = map (Morphism.thm phi) corec_def_frees;
@@ -2990,7 +2992,7 @@
((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
bs thmss)
in
- ((unfs, flds, unf_fld_thms, fld_unf_thms, fld_inject_thms),
+ ((unfs, flds, coiters, corecs, unf_fld_thms, fld_unf_thms, fld_inject_thms),
lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
end;
--- a/src/HOL/Codatatype/Tools/bnf_lfp.ML Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_lfp.ML Thu Sep 06 02:56:21 2012 +0200
@@ -10,7 +10,7 @@
sig
val bnf_lfp: binding list -> mixfix list -> (string * sort) list -> typ list list ->
BNF_Def.BNF list -> local_theory ->
- (term list * term list * thm list * thm list * thm list) * local_theory
+ (term list * term list * term list * term list * thm list * thm list * thm list) * local_theory
end;
structure BNF_LFP : BNF_LFP =
@@ -1078,8 +1078,9 @@
(*transforms defined frees into consts*)
val phi = Proof_Context.export_morphism lthy_old lthy;
- val iters = map (fst o dest_Const o Morphism.term phi) iter_frees;
- fun mk_iter Ts ss i = Term.list_comb (Const (nth iters (i - 1), Library.foldr (op -->)
+ val iters = map (Morphism.term phi) iter_frees;
+ val iter_names = map (fst o dest_Const) iters;
+ fun mk_iter Ts ss i = Term.list_comb (Const (nth iter_names (i - 1), Library.foldr (op -->)
(map fastype_of ss, nth Ts (i - 1) --> range_type (fastype_of (nth ss (i - 1))))), ss);
val iter_defs = map (Morphism.thm phi) iter_def_frees;
@@ -1239,8 +1240,9 @@
(*transforms defined frees into consts*)
val phi = Proof_Context.export_morphism lthy_old lthy;
- val recs = map (fst o dest_Const o Morphism.term phi) rec_frees;
- fun mk_rec ss i = Term.list_comb (Const (nth recs (i - 1), Library.foldr (op -->)
+ val recs = map (Morphism.term phi) rec_frees;
+ val rec_names = map (fst o dest_Const) recs;
+ fun mk_rec ss i = Term.list_comb (Const (nth rec_names (i - 1), Library.foldr (op -->)
(map fastype_of ss, nth Ts (i - 1) --> range_type (fastype_of (nth ss (i - 1))))), ss);
val rec_defs = map (Morphism.thm phi) rec_def_frees;
@@ -1813,7 +1815,7 @@
((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
bs thmss)
in
- ((unfs, flds, unf_fld_thms, fld_unf_thms, fld_inject_thms),
+ ((unfs, flds, iters, recs, unf_fld_thms, fld_unf_thms, fld_inject_thms),
lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
end;
--- a/src/HOL/Codatatype/Tools/bnf_util.ML Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_util.ML Thu Sep 06 02:56:21 2012 +0200
@@ -273,8 +273,8 @@
fun mk_Frees x Ts ctxt = mk_fresh_names ctxt (length Ts) x
|>> (fn names => map2 (curry Free) names Ts);
fun mk_Freess x Tss ctxt =
- fold_map2 (fn name => fn Ts => fn ctxt =>
- mk_fresh_names ctxt (length Ts) name) (mk_names (length Tss) x) Tss ctxt
+ fold_map2 (fn name => fn Ts => fn ctxt => mk_fresh_names ctxt (length Ts) name)
+ (mk_names (length Tss) x) Tss ctxt
|>> (fn namess => map2 (map2 (curry Free)) namess Tss);
fun mk_Frees' x Ts ctxt = mk_fresh_names ctxt (length Ts) x
|>> (fn names => `(map Free) (names ~~ Ts));