# HG changeset patch # User blanchet # Date 1347131066 -7200 # Node ID 0b735fb2602eae9544db08aa4b8e88ab8fd5b1c3 # Parent 262ab1ac38b9fefea164b68897760fba69851e53 generate iter/rec goals diff -r 262ab1ac38b9 -r 0b735fb2602e src/HOL/Codatatype/Tools/bnf_fp_sugar.ML --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200 @@ -63,16 +63,17 @@ fun args_of ((_, args), _) = args; fun ctr_mixfix_of (_, mx) = mx; -fun prepare_datatype prepare_typ gfp specs fake_lthy lthy = +fun prepare_datatype prepare_typ gfp specs fake_lthy no_defs_lthy = let val constrained_As = map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs - |> Library.foldr1 (merge_type_args_constrained lthy); + |> Library.foldr1 (merge_type_args_constrained no_defs_lthy); val As = map fst constrained_As; val As' = map dest_TFree As; val _ = (case duplicates (op =) As of [] => () - | A :: _ => error ("Duplicate type parameter " ^ quote (Syntax.string_of_typ lthy A))); + | A :: _ => error ("Duplicate type parameter " ^ + quote (Syntax.string_of_typ no_defs_lthy A))); (* TODO: use sort constraints on type args *) @@ -83,7 +84,7 @@ As); val bs = map type_binder_of specs; - val fake_Ts = map mk_fake_T bs; + val fakeTs = map mk_fake_T bs; val mixfixes = map mixfix_of specs; @@ -104,32 +105,30 @@ val _ = (case subtract (op =) As' rhs_As' of [] => () | A' :: _ => error ("Extra type variables on rhs: " ^ - quote (Syntax.string_of_typ lthy (TFree A')))); + quote (Syntax.string_of_typ no_defs_lthy (TFree A')))); val ((Cs, Xs), _) = - lthy + no_defs_lthy |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs |> mk_TFrees N ||>> mk_TFrees N; - fun is_same_fpT (T as Type (s, Us)) (Type (s', Us')) = + fun eq_fpT (T as Type (s, Us)) (Type (s', Us')) = s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^ quote (Syntax.string_of_typ fake_lthy T))) - | is_same_fpT _ _ = false; + | eq_fpT _ _ = false; - fun freeze_fpXs (T as Type (s, Us)) = - (case find_index (is_same_fpT T) fake_Ts of - ~1 => Type (s, map freeze_fpXs Us) - | i => nth Xs i) - | freeze_fpXs T = T; + fun freeze_fp (T as Type (s, Us)) = + (case find_index (eq_fpT T) fakeTs of ~1 => Type (s, map freeze_fp Us) | j => nth Xs j) + | freeze_fp T = T; - val ctr_TsssXs = map (map (map freeze_fpXs)) fake_ctr_Tsss; + val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss; val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs; val eqs = map dest_TFree Xs ~~ sum_prod_TsXs; - val ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects), lthy') = - fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs lthy; + val ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects), lthy) = + fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs no_defs_lthy; val timer = time (Timer.startRealTimer ()); @@ -145,8 +144,9 @@ val flds = map (mk_fld As) flds0; val fpTs = map (domain_type o fastype_of) unfs; + val is_fpT = member (op =) fpTs; + val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs; - val ns = map length ctr_Tsss; val mss = map (map length) ctr_Tsss; val Css = map2 replicate ns Cs; @@ -162,8 +162,31 @@ Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c end; - val fp_iters = map (mk_iter_or_rec As Cs) fp_iters0; - val fp_recs = map (mk_iter_or_rec As Cs) fp_recs0; + val fp_iters as fp_iter1 :: _ = map (mk_iter_or_rec As Cs) fp_iters0; + val fp_recs as fp_rec1 :: _ = map (mk_iter_or_rec As Cs) fp_recs0; + + val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter1)))); + val y_Tsss = map3 (fn ms => map2 dest_tupleT ms oo dest_sumTN) mss ns fp_y_Ts; + val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css; + + fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) = + if member (op =) Cs U then Us else [T] + | dest_rec_pair T = [T]; + + val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec1)))); + val z_Tssss = + map3 (fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms oo dest_sumTN) mss ns fp_z_Ts; + val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css; + + val ((gss, ysss), _) = + lthy + |> mk_Freess "f" g_Tss + ||>> mk_Freesss "x" y_Tsss; + + val hss = map2 (map2 retype_free) gss h_Tss; + val (zssss, _) = + lthy + |> mk_Freessss "x" z_Tssss; fun pour_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf), unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss) @@ -178,7 +201,7 @@ val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss; val ((((u, v), fs), xss), _) = - lthy + no_defs_lthy |> yield_singleton (mk_Frees "u") unfT ||>> yield_singleton (mk_Frees "v") fpT ||>> mk_Frees "f" case_Ts @@ -249,39 +272,11 @@ val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs]; - (* (co)iterators, (co)recursors, (co)induction *) - - val is_fpT = member (op =) fpTs; - - fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) = - if member (op =) Cs U then Us else [T] - | dest_rec_pair T = [T]; - fun sugar_datatype no_defs_lthy = let - val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter)))); - val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts; - val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss; - val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css; val iter_T = flat g_Tss ---> fpT --> C; - - val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec)))); - val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts; - val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss; - val z_Tssss = map (map (map dest_rec_pair)) z_Tsss; - val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css; val rec_T = flat h_Tss ---> fpT --> C; - val ((gss, ysss), _) = - no_defs_lthy - |> mk_Freess "f" g_Tss - ||>> mk_Freesss "x" y_Tsss; - - val hss = map2 (map2 retype_free) gss h_Tss; - val (zssss, _) = - no_defs_lthy - |> mk_Freessss "x" z_Tssss; - val iter_binder = Binding.suffix_name ("_" ^ iterN) b; val rec_binder = Binding.suffix_name ("_" ^ recN) b; @@ -313,7 +308,7 @@ val iter = mk_iter_or_rec As Cs' iter0; val recx = mk_iter_or_rec As Cs' rec0; in - ([[ctrs], [[iter]], [[recx]], xss, gss, hss], lthy) + ([[ctrs], [[iter]], [[recx]], xss], lthy) end; fun sugar_codatatype no_defs_lthy = ([], no_defs_lthy); @@ -322,19 +317,30 @@ |> (if gfp then sugar_codatatype else sugar_datatype) end; - fun pour_more_sugar_on_datatypes ([[ctrss], [[iters]], [[recs]], xsss, gsss, hsss], lthy) = + fun pour_more_sugar_on_datatypes ([[ctrss], [[iters]], [[recs]], xsss], lthy) = let val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss; - val giters = map2 (curry flat_list_comb) iters gsss; - val hrecs = map2 (curry flat_list_comb) recs hsss; + val giters = map (fn iter => flat_list_comb (iter, gss)) iters; + val hrecs = map (fn recx => flat_list_comb (recx, hss)) recs; val (iter_thmss, rec_thmss) = let - fun mk_goal_iter_or_rec fc xctr = - mk_Trueprop_eq (fc $ xctr, fc $ xctr); + fun mk_goal_iter_or_rec fss fc xctr f xs xs' = + mk_Trueprop_eq (fc $ xctr, Term.list_comb (f, xs')); + + fun fix_iter_free (x as Free (_, T)) = + (case find_index (eq_fpT T) fpTs of ~1 => x | j => nth giters j $ x); + fun fix_rec_free (x as Free (_, T)) = + (case find_index (eq_fpT T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]); - val goal_iterss = map2 (fn giter => map (mk_goal_iter_or_rec giter)) giters xctrss; - val goal_recss = map2 (fn hrec => map (mk_goal_iter_or_rec hrec)) hrecs xctrss; + val iter_xsss = map (map (map fix_iter_free)) xsss; + val rec_xsss = map (map (maps fix_rec_free)) xsss; + + val goal_iterss = + map5 (map4 o mk_goal_iter_or_rec gss) giters xctrss gss xsss iter_xsss; + val goal_recss = + map5 (map4 o mk_goal_iter_or_rec hss) hrecs xctrss hss xsss rec_xsss; + val iter_tacss = map (map (K (fn _ => Skip_Proof.cheat_tac (Proof_Context.theory_of lthy)))) goal_iterss; (* ### map (map mk_iter_or_rec_tac); (* needs ctr_def, iter_def, fld_iter *) *) @@ -356,7 +362,7 @@ lthy |> Local_Theory.notes notes |> snd end; - val lthy'' = lthy' + val lthy' = lthy |> fold_map pour_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss) @@ -365,7 +371,7 @@ val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^ (if gfp then "co" else "") ^ "datatype")); in - (timer; lthy'') + (timer; lthy') end; fun datatype_cmd info specs lthy =