--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Tue Sep 11 23:26:03 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Tue Sep 11 23:27:19 2012 +0200
@@ -7,7 +7,10 @@
signature BNF_FP_SUGAR =
sig
- (* TODO: programmatic interface *)
+ val datatyp: bool ->
+ bool * ((((typ * sort) list * binding) * mixfix) * ((((binding * binding) *
+ (binding * typ) list) * (binding * term) list) * mixfix) list) list ->
+ local_theory -> local_theory
end;
structure BNF_FP_Sugar : BNF_FP_SUGAR =
@@ -44,9 +47,11 @@
| SOME T' => T')
| typ_subst inst T = the_default T (AList.lookup (op =) inst T);
-fun retype_free (Free (s, _)) T = Free (s, T);
+fun resort_tfree S (TFree (s, _)) = TFree (s, S);
-val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs))
+fun retype_free T (Free (s, _)) = Free (s, T);
+
+val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
fun mk_predT T = T --> HOLogic.boolT;
@@ -66,23 +71,10 @@
fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
-fun merge_type_arg_constrained ctxt (T, c) (T', c') =
- if T = T' then
- (case (c, c') of
- (_, NONE) => (T, c)
- | (NONE, _) => (T, c')
- | _ =>
- if c = c' then
- (T, c)
- else
- error ("Inconsistent sort constraints for type variable " ^
- quote (Syntax.string_of_typ ctxt T)))
- else
- cannot_merge_types ();
+fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
-fun merge_type_args_constrained ctxt (cAs, cAs') =
- if length cAs = length cAs' then map2 (merge_type_arg_constrained ctxt) cAs cAs'
- else cannot_merge_types ();
+fun merge_type_args (As, As') =
+ if length As = length As' then map2 merge_type_arg As As' else cannot_merge_types ();
fun type_args_constrained_of (((cAs, _), _), _) = cAs;
val type_args_of = map fst o type_args_constrained_of;
@@ -96,31 +88,45 @@
fun defaults_of ((_, ds), _) = ds;
fun ctr_mixfix_of (_, mx) = mx;
-fun prepare_datatype prepare_typ prepare_term lfp (no_dests, specs) fake_lthy no_defs_lthy =
+fun define_datatype prepare_constraint prepare_typ prepare_term lfp (no_dests, specs)
+ no_defs_lthy0 =
let
+ (* TODO: sanity checks on arguments *)
+
val _ = if not lfp andalso no_dests then error "Cannot define destructor-less codatatypes"
else ();
- val constrained_As =
- map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
- |> Library.foldr1 (merge_type_args_constrained no_defs_lthy);
- val As = map fst constrained_As;
- val As' = map dest_TFree As;
+ val N = length specs;
+
+ fun prepare_type_arg (ty, c) =
+ let val TFree (s, _) = prepare_typ no_defs_lthy0 ty in
+ TFree (s, prepare_constraint no_defs_lthy0 c)
+ end;
+
+ val Ass0 = map (map prepare_type_arg o type_args_constrained_of) specs;
+ val unsorted_Ass0 = map (map (resort_tfree HOLogic.typeS)) Ass0;
+ val unsorted_As = Library.foldr1 merge_type_args unsorted_Ass0;
- val _ = (case duplicates (op =) As of [] => ()
- | A :: _ => error ("Duplicate type parameter " ^
- quote (Syntax.string_of_typ no_defs_lthy A)));
+ val ((Bs, Cs), no_defs_lthy) =
+ no_defs_lthy0
+ |> fold (Variable.declare_typ o resort_tfree dummyS) unsorted_As
+ |> mk_TFrees N
+ ||>> mk_TFrees N;
- (* TODO: use sort constraints on type args *)
-
- val N = length specs;
+ (* TODO: cleaner handling of fake contexts, without "background_theory" *)
+ (*the "perhaps o try" below helps gracefully handles the case where the new type is defined in a
+ locale and shadows an existing global type*)
+ val fake_thy =
+ Theory.copy #> fold (fn spec => perhaps (try (Sign.add_type no_defs_lthy
+ (type_binder_of spec, length (type_args_constrained_of spec), mixfix_of spec)))) specs;
+ val fake_lthy = Proof_Context.background_theory fake_thy no_defs_lthy;
fun mk_fake_T b =
Type (fst (Term.dest_Type (Proof_Context.read_type_name fake_lthy true (Binding.name_of b))),
- As);
+ unsorted_As);
val bs = map type_binder_of specs;
- val fakeTs = map mk_fake_T bs;
+ val fake_Ts = map mk_fake_T bs;
val mixfixes = map mixfix_of specs;
@@ -135,39 +141,41 @@
val ctr_mixfixess = map (map ctr_mixfix_of) ctr_specss;
val sel_bindersss = map (map (map fst)) ctr_argsss;
- val fake_ctr_Tsss = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
-
+ val fake_ctr_Tsss0 = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
val raw_sel_defaultsss = map (map defaults_of) ctr_specss;
+ val (Ass as As :: _) :: fake_ctr_Tsss =
+ burrow (burrow (Syntax.check_typs fake_lthy)) (Ass0 :: fake_ctr_Tsss0);
+
+ val _ = (case duplicates (op =) unsorted_As of [] => ()
+ | A :: _ => error ("Duplicate type parameter " ^
+ quote (Syntax.string_of_typ no_defs_lthy A)));
+
val rhs_As' = fold (fold (fold Term.add_tfreesT)) fake_ctr_Tsss [];
- val _ = (case subtract (op =) As' rhs_As' of
+ val _ = (case subtract (op =) (map dest_TFree As) rhs_As' of
[] => ()
| A' :: _ => error ("Extra type variables on rhs: " ^
quote (Syntax.string_of_typ no_defs_lthy (TFree A'))));
- val ((Cs, Xs), _) =
- no_defs_lthy
- |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs
- |> mk_TFrees N
- ||>> mk_TFrees N;
-
fun eq_fpT (T as Type (s, Us)) (Type (s', Us')) =
s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^
quote (Syntax.string_of_typ fake_lthy T)))
| eq_fpT _ _ = false;
fun freeze_fp (T as Type (s, Us)) =
- (case find_index (eq_fpT T) fakeTs of ~1 => Type (s, map freeze_fp Us) | j => nth Xs j)
+ (case find_index (eq_fpT T) fake_Ts of ~1 => Type (s, map freeze_fp Us) | j => nth Bs j)
| freeze_fp T = T;
- val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
- val ctr_sum_prod_TsXs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssXs;
+ val ctr_TsssBs = map (map (map freeze_fp)) fake_ctr_Tsss;
+ val ctr_sum_prod_TsBs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssBs;
- val eqs = map dest_TFree Xs ~~ ctr_sum_prod_TsXs;
+ val fp_eqs =
+ map dest_TFree Bs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsBs;
val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects,
fp_iter_thms, fp_rec_thms), lthy)) =
- fp_bnf (if lfp then bnf_lfp else bnf_gfp) bs mixfixes As' eqs no_defs_lthy;
+ fp_bnf (if lfp then bnf_lfp else bnf_gfp) bs mixfixes (map dest_TFree unsorted_As) fp_eqs
+ no_defs_lthy0;
val add_nested_bnf_names =
let
@@ -179,7 +187,7 @@
in snd oo add end;
val nested_bnfs =
- map_filter (bnf_of lthy) (fold (fold (fold add_nested_bnf_names)) ctr_TsssXs []);
+ map_filter (bnf_of lthy) (fold (fold (fold add_nested_bnf_names)) ctr_TsssBs []);
val timer = time (Timer.startRealTimer ());
@@ -196,7 +204,7 @@
val fpTs = map (domain_type o fastype_of) unfs;
- val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
+ val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Bs ~~ fpTs)))) ctr_TsssBs;
val ns = map length ctr_Tsss;
val kss = map (fn n => 1 upto n) ns;
val mss = map (map length) ctr_Tsss;
@@ -242,8 +250,8 @@
dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
- val hss = map2 (map2 retype_free) gss h_Tss;
- val zssss_hd = map2 (map2 (map2 (fn y => fn T :: _ => retype_free y T))) ysss z_Tssss;
+ val hss = map2 (map2 retype_free) h_Tss gss;
+ val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
val (zssss_tl, _) =
lthy
|> mk_Freessss "y" (map (map (map tl)) z_Tssss);
@@ -293,7 +301,7 @@
val (s_Tssss, h_sum_prod_Ts, h_Tssss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
- val hssss_hd = map2 (map2 (map2 (fn [g] => fn T :: _ => retype_free g T))) gssss h_Tssss;
+ val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
val ((sssss, hssss_tl), _) =
lthy
|> mk_Freessss "q" s_Tssss
@@ -685,18 +693,9 @@
(timer; lthy')
end;
-fun datatype_cmd lfp (bundle as (_, specs)) lthy =
- let
- (* TODO: cleaner handling of fake contexts, without "background_theory" *)
- (*the "perhaps o try" below helps gracefully handles the case where the new type is defined in a
- locale and shadows an existing global type*)
- val fake_thy = Theory.copy
- #> fold (fn spec => perhaps (try (Sign.add_type lthy
- (type_binder_of spec, length (type_args_constrained_of spec), mixfix_of spec)))) specs;
- val fake_lthy = Proof_Context.background_theory fake_thy lthy;
- in
- prepare_datatype Syntax.read_typ Syntax.read_term lfp bundle fake_lthy lthy
- end;
+val datatyp = define_datatype (K I) (K I) (K I);
+
+val datatype_cmd = define_datatype Typedecl.read_constraint Syntax.parse_typ Syntax.read_term;
val parse_opt_binding_colon = Scan.optional (Parse.binding --| @{keyword ":"}) no_binder
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML Tue Sep 11 23:26:03 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML Tue Sep 11 23:27:19 2012 +0200
@@ -96,7 +96,7 @@
val (As, B) =
no_defs_lthy
- |> mk_TFrees (length As0)
+ |> mk_TFrees' (map Type.sort_of_atyp As0)
||> the_single o fst o mk_TFrees 1;
val fpT = Type (fpT_name, As);
@@ -572,6 +572,10 @@
map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
|> (fn thms => after_qed thms lthy)) oo prepare_wrap_datatype (K I);
+val wrap_datatype_cmd = (fn (goalss, after_qed, lthy) =>
+ Proof.theorem NONE (snd oo after_qed) (map (map (rpair [])) goalss) lthy) oo
+ prepare_wrap_datatype Syntax.read_term;
+
fun parse_bracket_list parser = @{keyword "["} |-- Parse.list parser --| @{keyword "]"};
val parse_bindings = parse_bracket_list Parse.binding;
@@ -581,10 +585,6 @@
val parse_bound_terms = parse_bracket_list parse_bound_term;
val parse_bound_termss = parse_bracket_list parse_bound_terms;
-val wrap_datatype_cmd = (fn (goalss, after_qed, lthy) =>
- Proof.theorem NONE (snd oo after_qed) (map (map (rpair [])) goalss) lthy) oo
- prepare_wrap_datatype Syntax.read_term;
-
val parse_wrap_options =
Scan.optional (@{keyword "("} |-- (@{keyword "no_dests"} >> K true) --| @{keyword ")"}) false;