# HG changeset patch # User blanchet # Date 1346892981 -7200 # Node ID 6d29d2db5f88c211e53348eb2e51c9293c60b05f # Parent eab51f249c704605f1bd39972e44dc4387495828 construct high-level iterator RHS diff -r eab51f249c70 -r 6d29d2db5f88 src/HOL/Codatatype/Tools/bnf_fp_sugar.ML --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Thu Sep 06 01:37:24 2012 +0200 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Thu Sep 06 02:56:21 2012 +0200 @@ -52,9 +52,6 @@ fun args_of ((_, args), _) = args; fun mixfix_of_ctr (_, mx) = mx; -val uncurry_fs = - map2 (fn f => fn xs => HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs))); - fun prepare_data prepare_typ gfp specs fake_lthy lthy = let val constrained_As = @@ -75,7 +72,7 @@ As); val bs = map type_binder_of specs; - val Ts = map mk_T bs; + val fp_Ts = map mk_T bs; val mixfixes = map mixfix_of_typ specs; @@ -98,35 +95,35 @@ | A' :: _ => error ("Extra type variables on rhs: " ^ quote (Syntax.string_of_typ lthy (TFree A')))); - val (Bs, C) = + val ((Cs, Xs), _) = lthy |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs |> mk_TFrees N - ||> the_single o fst o mk_TFrees 1; + ||>> mk_TFrees N; - fun is_same_rec (T as Type (s, Us)) (Type (s', Us')) = + fun is_same_recT (T as Type (s, Us)) (Type (s', Us')) = s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^ quote (Syntax.string_of_typ fake_lthy T))) - | is_same_rec _ _ = false + | is_same_recT _ _ = false; - fun freeze_rec (T as Type (s, Us)) = - (case find_index (is_same_rec T) Ts of - ~1 => Type (s, map freeze_rec Us) - | i => nth Bs i) - | freeze_rec T = T; + fun freeze_recXs (T as Type (s, Us)) = + (case find_index (is_same_recT T) fp_Ts of + ~1 => Type (s, map freeze_recXs Us) + | i => nth Xs i) + | freeze_recXs T = T; - val ctr_TsssBs = map (map (map freeze_rec)) ctr_Tsss; - val sum_prod_TsBs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssBs; + val ctr_TsssXs = map (map (map freeze_recXs)) ctr_Tsss; + val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs; - val eqs = map dest_TFree Bs ~~ sum_prod_TsBs; + val eqs = map dest_TFree Xs ~~ sum_prod_TsXs; - val ((raw_unfs, raw_flds, unf_flds, fld_unfs, fld_injects), lthy') = + val ((raw_unfs, raw_flds, raw_fp_iters, raw_fp_recs, unf_flds, fld_unfs, fld_injects), lthy') = fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs lthy; val timer = time (Timer.startRealTimer ()); - fun mk_unf_or_fld get_foldedT Ts t = - let val Type (_, Ts0) = get_foldedT (fastype_of t) in + fun mk_unf_or_fld get_T Ts t = + let val Type (_, Ts0) = get_T (fastype_of t) in Term.subst_atomic_types (Ts0 ~~ Ts) t end; @@ -136,10 +133,23 @@ val unfs = map (mk_unf As) raw_unfs; val flds = map (mk_fld As) raw_flds; - fun pour_sugar_on_type (((((((((((b, T), fld), unf), fld_unf), unf_fld), fld_inject), - ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss) no_defs_lthy = + fun mk_fp_iter_or_rec Ts Us t = let - val n = length ctr_binders; + val (binders, body) = strip_type (fastype_of t); + val Type (_, Ts0) = if gfp then body else List.last binders; + val Us0 = map (if gfp then domain_type else body_type) (fst (split_last binders)); + in + Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t + end; + + val fp_iters = map (mk_fp_iter_or_rec As Cs) raw_fp_iters; + val fp_recs = map (mk_fp_iter_or_rec As Cs) raw_fp_recs; + + fun pour_sugar_on_type ((((((((((((((b, fp_T), C), fld), unf), fp_iter), fp_rec), fld_unf), + unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss) + no_defs_lthy = + let + val n = length ctr_Tss; val ks = 1 upto n; val ms = map length ctr_Tss; @@ -147,11 +157,11 @@ val prod_Ts = map HOLogic.mk_tupleT ctr_Tss; val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss; - val ((((fs, u), v), xss), _) = + val ((((u, v), fs), xss), _) = lthy - |> mk_Frees "f" case_Ts - ||>> yield_singleton (mk_Frees "u") unf_T - ||>> yield_singleton (mk_Frees "v") T + |> yield_singleton (mk_Frees "u") unf_T + ||>> yield_singleton (mk_Frees "v") fp_T + ||>> mk_Frees "f" case_Ts ||>> mk_Freess "x" ctr_Tss; val ctr_rhss = @@ -161,7 +171,7 @@ val case_binder = Binding.suffix_name ("_" ^ caseN) b; val case_rhs = - fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (uncurry_fs fs xss) $ (unf $ v)); + fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (map2 mk_uncurried_fun fs xss) $ (unf $ v)); val (((raw_ctrs, raw_ctr_defs), (raw_case, raw_case_def)), (lthy', lthy)) = no_defs_lthy |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs => @@ -189,8 +199,8 @@ (mk_Trueprop_eq (HOLogic.mk_eq (v, fld $ u), HOLogic.mk_eq (unf $ v, u))); in Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} => - mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, T]) (certify lthy fld) - (certify lthy unf) fld_unf unf_fld) + mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, fp_T]) + (certify lthy fld) (certify lthy unf) fld_unf unf_fld) |> Thm.close_derivation |> Morphism.thm phi end; @@ -219,24 +229,30 @@ val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs]; + (* (co)iterators, (co)recursors, (co)induction *) + + val is_recT = member (op =) fp_Ts; + + val ns = map length ctr_Tsss; + val mss = map (map length) ctr_Tsss; + val Css = map2 replicate ns Cs; + fun sugar_lfp lthy = let -(*### - val fld_iter = @{term True}; (*###*) + val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter)))); + val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts; + val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss; + val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css; + val iter_T = flat g_Tss ---> fp_T --> C; - val iter_Tss = map (fn Ts => Ts) (*###*) ctr_Tss; - val iter_Ts = map (fn Ts => Ts ---> C) iter_Tss; - - val iter_fs = map2 (fn Free (s, _) => fn T => Free (s, T)) fs iter_Ts + val ((gss, ysss), _) = + lthy + |> mk_Freess "f" g_Tss + ||>> apfst (unflat y_Tsss) o mk_Freess "x" (flat y_Tsss); val iter_rhs = - fold_rev Term.lambda fs (fld_iter $ mk_sum_caseN (uncurry_fs fs xss) $ (unf $ v)); - - - val uncurried_fs = - map2 (fn f => fn xs => - HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs))) fs xss; -*) + fold_rev (fold_rev Term.lambda) gss + (Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss)); in lthy end; @@ -248,8 +264,9 @@ end; val lthy'' = - fold pour_sugar_on_type (bs ~~ Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ - ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss) lthy'; + fold pour_sugar_on_type (bs ~~ fp_Ts ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ fp_recs ~~ + fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~ + disc_binderss ~~ sel_bindersss) lthy'; val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^ (if gfp then "co" else "") ^ "datatype")); diff -r eab51f249c70 -r 6d29d2db5f88 src/HOL/Codatatype/Tools/bnf_fp_util.ML --- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML Thu Sep 06 01:37:24 2012 +0200 +++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML Thu Sep 06 02:56:21 2012 +0200 @@ -88,6 +88,11 @@ val mk_sum_case: term -> term -> term val mk_sum_caseN: term list -> term + val dest_sumTN: int -> typ -> typ list + val dest_tupleT: int -> typ -> typ list + + val mk_uncurried_fun: term -> term list -> term + val mk_Field: term -> term val mk_union: term * term -> term @@ -219,6 +224,16 @@ fun mk_sum_caseN [f] = f | mk_sum_caseN (f :: fs) = mk_sum_case f (mk_sum_caseN fs); +fun dest_sumTN 1 T = [T] + | dest_sumTN n (Type (@{type_name sum}, [T, T'])) = T :: dest_sumTN (n - 1) T'; + +(* TODO: move something like this to "HOLogic"? *) +fun dest_tupleT 0 @{typ unit} = [] + | dest_tupleT 1 T = [T] + | dest_tupleT n (Type (@{type_name prod}, [T, T'])) = T :: dest_tupleT (n - 1) T'; + +fun mk_uncurried_fun f xs = HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs)); + fun mk_Field r = let val T = fst (dest_relT (fastype_of r)); in Const (@{const_name Field}, mk_relT (T, T) --> HOLogic.mk_setT T) $ r end; diff -r eab51f249c70 -r 6d29d2db5f88 src/HOL/Codatatype/Tools/bnf_gfp.ML --- a/src/HOL/Codatatype/Tools/bnf_gfp.ML Thu Sep 06 01:37:24 2012 +0200 +++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML Thu Sep 06 02:56:21 2012 +0200 @@ -11,7 +11,7 @@ sig val bnf_gfp: binding list -> mixfix list -> (string * sort) list -> typ list list -> BNF_Def.BNF list -> local_theory -> - (term list * term list * thm list * thm list * thm list) * local_theory + (term list * term list * term list * term list * thm list * thm list * thm list) * local_theory end; structure BNF_GFP : BNF_GFP = @@ -1965,8 +1965,9 @@ (*transforms defined frees into consts*) val phi = Proof_Context.export_morphism lthy_old lthy; - val coiters = map (fst o dest_Const o Morphism.term phi) coiter_frees; - fun mk_coiter Ts ss i = Term.list_comb (Const (nth coiters (i - 1), Library.foldr (op -->) + val coiters = map (Morphism.term phi) coiter_frees; + val coiter_names = map (fst o dest_Const) coiters; + fun mk_coiter Ts ss i = Term.list_comb (Const (nth coiter_names (i - 1), Library.foldr (op -->) (map fastype_of ss, domain_type (fastype_of (nth ss (i - 1))) --> nth Ts (i - 1))), ss); val coiter_defs = map ((fn thm => thm RS fun_cong) o Morphism.thm phi) coiter_def_frees; @@ -2158,8 +2159,9 @@ (*transforms defined frees into consts*) val phi = Proof_Context.export_morphism lthy_old lthy; - val corecs = map (fst o dest_Const o Morphism.term phi) corec_frees; - fun mk_corec ss i = Term.list_comb (Const (nth corecs (i - 1), Library.foldr (op -->) + val corecs = map (Morphism.term phi) corec_frees; + val corec_names = map (fst o dest_Const) corecs; + fun mk_corec ss i = Term.list_comb (Const (nth corec_names (i - 1), Library.foldr (op -->) (map fastype_of ss, domain_type (fastype_of (nth ss (i - 1))) --> nth Ts (i - 1))), ss); val corec_defs = map (Morphism.thm phi) corec_def_frees; @@ -2990,7 +2992,7 @@ ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])])) bs thmss) in - ((unfs, flds, unf_fld_thms, fld_unf_thms, fld_inject_thms), + ((unfs, flds, coiters, corecs, unf_fld_thms, fld_unf_thms, fld_inject_thms), lthy |> Local_Theory.notes (common_notes @ notes) |> snd) end; diff -r eab51f249c70 -r 6d29d2db5f88 src/HOL/Codatatype/Tools/bnf_lfp.ML --- a/src/HOL/Codatatype/Tools/bnf_lfp.ML Thu Sep 06 01:37:24 2012 +0200 +++ b/src/HOL/Codatatype/Tools/bnf_lfp.ML Thu Sep 06 02:56:21 2012 +0200 @@ -10,7 +10,7 @@ sig val bnf_lfp: binding list -> mixfix list -> (string * sort) list -> typ list list -> BNF_Def.BNF list -> local_theory -> - (term list * term list * thm list * thm list * thm list) * local_theory + (term list * term list * term list * term list * thm list * thm list * thm list) * local_theory end; structure BNF_LFP : BNF_LFP = @@ -1078,8 +1078,9 @@ (*transforms defined frees into consts*) val phi = Proof_Context.export_morphism lthy_old lthy; - val iters = map (fst o dest_Const o Morphism.term phi) iter_frees; - fun mk_iter Ts ss i = Term.list_comb (Const (nth iters (i - 1), Library.foldr (op -->) + val iters = map (Morphism.term phi) iter_frees; + val iter_names = map (fst o dest_Const) iters; + fun mk_iter Ts ss i = Term.list_comb (Const (nth iter_names (i - 1), Library.foldr (op -->) (map fastype_of ss, nth Ts (i - 1) --> range_type (fastype_of (nth ss (i - 1))))), ss); val iter_defs = map (Morphism.thm phi) iter_def_frees; @@ -1239,8 +1240,9 @@ (*transforms defined frees into consts*) val phi = Proof_Context.export_morphism lthy_old lthy; - val recs = map (fst o dest_Const o Morphism.term phi) rec_frees; - fun mk_rec ss i = Term.list_comb (Const (nth recs (i - 1), Library.foldr (op -->) + val recs = map (Morphism.term phi) rec_frees; + val rec_names = map (fst o dest_Const) recs; + fun mk_rec ss i = Term.list_comb (Const (nth rec_names (i - 1), Library.foldr (op -->) (map fastype_of ss, nth Ts (i - 1) --> range_type (fastype_of (nth ss (i - 1))))), ss); val rec_defs = map (Morphism.thm phi) rec_def_frees; @@ -1813,7 +1815,7 @@ ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])])) bs thmss) in - ((unfs, flds, unf_fld_thms, fld_unf_thms, fld_inject_thms), + ((unfs, flds, iters, recs, unf_fld_thms, fld_unf_thms, fld_inject_thms), lthy |> Local_Theory.notes (common_notes @ notes) |> snd) end; diff -r eab51f249c70 -r 6d29d2db5f88 src/HOL/Codatatype/Tools/bnf_util.ML --- a/src/HOL/Codatatype/Tools/bnf_util.ML Thu Sep 06 01:37:24 2012 +0200 +++ b/src/HOL/Codatatype/Tools/bnf_util.ML Thu Sep 06 02:56:21 2012 +0200 @@ -273,8 +273,8 @@ fun mk_Frees x Ts ctxt = mk_fresh_names ctxt (length Ts) x |>> (fn names => map2 (curry Free) names Ts); fun mk_Freess x Tss ctxt = - fold_map2 (fn name => fn Ts => fn ctxt => - mk_fresh_names ctxt (length Ts) name) (mk_names (length Tss) x) Tss ctxt + fold_map2 (fn name => fn Ts => fn ctxt => mk_fresh_names ctxt (length Ts) name) + (mk_names (length Tss) x) Tss ctxt |>> (fn namess => map2 (map2 (curry Free)) namess Tss); fun mk_Frees' x Ts ctxt = mk_fresh_names ctxt (length Ts) x |>> (fn names => `(map Free) (names ~~ Ts));