--- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML Tue Jun 10 11:38:53 2014 +0200
+++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML Tue Jun 10 12:16:22 2014 +0200
@@ -56,19 +56,19 @@
val dest_case: Proof.context -> string -> typ list -> term ->
(ctr_sugar * term list * term list) option
- type ('c, 'a, 'v) ctr_spec = ((binding * 'c) * 'a list) * (binding * 'v) list
+ type ('c, 'a) ctr_spec = (binding * 'c) * 'a list
- val disc_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> binding
- val ctr_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> 'c
- val args_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> 'a list
- val sel_defaults_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> (binding * 'v) list
+ val disc_of_ctr_spec: ('c, 'a) ctr_spec -> binding
+ val ctr_of_ctr_spec: ('c, 'a) ctr_spec -> 'c
+ val args_of_ctr_spec: ('c, 'a) ctr_spec -> 'a list
val free_constructors: ({prems: thm list, context: Proof.context} -> tactic) list list ->
- ((bool * bool) * binding) * (term, binding, term) ctr_spec list -> local_theory ->
+ (((bool * bool) * binding) * (term, binding) ctr_spec list) * term list -> local_theory ->
ctr_sugar * local_theory
val parse_bound_term: (binding * string) parser
val parse_ctr_options: (bool * bool) parser
- val parse_ctr_spec: 'c parser -> 'a parser -> ('c, 'a, string) ctr_spec parser
+ val parse_ctr_spec: 'c parser -> 'a parser -> ('c, 'a) ctr_spec parser
+ val parse_sel_default_eqs: string list parser
end;
structure Ctr_Sugar : CTR_SUGAR =
@@ -313,24 +313,43 @@
| _ => NONE)
| _ => NONE);
-fun eta_expand_arg xs f_xs = fold_rev Term.lambda xs f_xs;
+fun const_or_free_name (Const (s, _)) = Long_Name.base_name s
+ | const_or_free_name (Free (s, _)) = s
+ | const_or_free_name t = raise TERM ("const_or_free_name", [t])
-type ('c, 'a, 'v) ctr_spec = ((binding * 'c) * 'a list) * (binding * 'v) list;
+fun extract_sel_default ctxt t =
+ let
+ fun malformed () =
+ error ("Malformed selector default value equation: " ^ Syntax.string_of_term ctxt t);
-fun disc_of_ctr_spec (((disc, _), _), _) = disc;
-fun ctr_of_ctr_spec (((_, ctr), _), _) = ctr;
-fun args_of_ctr_spec ((_, args), _) = args;
-fun sel_defaults_of_ctr_spec (_, ds) = ds;
+ val ((sel, (ctr, vars)), rhs) =
+ fst (Term.replace_dummy_patterns (Syntax.check_term ctxt t) 0)
+ |> HOLogic.dest_eq
+ |>> (Term.dest_comb
+ #>> const_or_free_name
+ ##> (Term.strip_comb #>> (Term.dest_Const #> fst)))
+ handle TERM _ => malformed ();
+ in
+ if forall (is_Free orf is_Var) vars andalso not (has_duplicates (op aconv) vars) then
+ ((ctr, sel), fold_rev Term.lambda vars rhs)
+ else
+ malformed ()
+ end;
-fun prepare_free_constructors prep_term (((discs_sels, no_code), raw_case_binding), ctr_specs)
- no_defs_lthy =
+type ('c, 'a) ctr_spec = (binding * 'c) * 'a list;
+
+fun disc_of_ctr_spec ((disc, _), _) = disc;
+fun ctr_of_ctr_spec ((_, ctr), _) = ctr;
+fun args_of_ctr_spec (_, args) = args;
+
+fun prepare_free_constructors prep_term
+ ((((discs_sels, no_code), raw_case_binding), ctr_specs), sel_default_eqs) no_defs_lthy =
let
(* TODO: sanity checks on arguments *)
val raw_ctrs = map ctr_of_ctr_spec ctr_specs;
val raw_disc_bindings = map disc_of_ctr_spec ctr_specs;
val raw_sel_bindingss = map args_of_ctr_spec ctr_specs;
- val raw_sel_defaultss = map sel_defaults_of_ctr_spec ctr_specs;
val n = length raw_ctrs;
val ks = 1 upto n;
@@ -338,7 +357,6 @@
val _ = if n > 0 then () else error "No constructors specified";
val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs;
- val sel_defaultss = map (map (apsnd (prep_term no_defs_lthy))) raw_sel_defaultss;
val Type (fcT_name, As0) = body_type (fastype_of (hd ctrs0));
val fc_b_name = Long_Name.base_name fcT_name;
@@ -424,8 +442,8 @@
(* TODO: Eta-expension is for compatibility with the old datatype package (but it also provides
nicer names). Consider removing. *)
- val eta_fs = map2 eta_expand_arg xss xfs;
- val eta_gs = map2 eta_expand_arg xss xgs;
+ val eta_fs = map2 (fold_rev Term.lambda) xss xfs;
+ val eta_gs = map2 (fold_rev Term.lambda) xss xgs;
val case_binding =
qualify false
@@ -484,13 +502,38 @@
val no_discs_sels =
not discs_sels andalso
forall (forall Binding.is_empty) (raw_disc_bindings :: raw_sel_bindingss) andalso
- forall null raw_sel_defaultss;
+ null sel_default_eqs;
val (all_sels_distinct, discs, selss, disc_defs, sel_defs, sel_defss, lthy') =
if no_discs_sels then
(true, [], [], [], [], [], lthy')
else
let
+ val sel_bindings = flat sel_bindingss;
+ val uniq_sel_bindings = distinct Binding.eq_name sel_bindings;
+ val all_sels_distinct = (length uniq_sel_bindings = length sel_bindings);
+
+ val sel_binding_index =
+ if all_sels_distinct then 1 upto length sel_bindings
+ else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_bindings) sel_bindings;
+
+ val all_proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss);
+ val sel_infos =
+ AList.group (op =) (sel_binding_index ~~ all_proto_sels)
+ |> sort (int_ord o pairself fst)
+ |> map snd |> curry (op ~~) uniq_sel_bindings;
+ val sel_bindings = map fst sel_infos;
+ val sel_Ts = map (curry (op -->) fcT o fastype_of o snd o snd o hd o snd) sel_infos;
+
+ val sel_default_lthy = no_defs_lthy
+ |> Proof_Context.allow_dummies
+ |> Proof_Context.add_fixes
+ (map2 (fn b => fn T => (b, SOME T, NoSyn)) sel_bindings sel_Ts)
+ |> snd;
+
+ val sel_defaults =
+ map (extract_sel_default sel_default_lthy o prep_term sel_default_lthy) sel_default_eqs;
+
fun disc_free b = Free (Binding.name_of b, mk_pred1T fcT);
fun disc_spec b exist_xs_u_eq_ctr = mk_Trueprop_eq (disc_free b $ u, exist_xs_u_eq_ctr);
@@ -499,48 +542,33 @@
Term.lambda u (alternate_disc_lhs (K o rapp u o disc_free) (3 - k));
fun mk_sel_case_args b proto_sels T =
- map2 (fn Ts => fn k =>
+ map3 (fn Const (c, _) => fn Ts => fn k =>
(case AList.lookup (op =) proto_sels k of
NONE =>
- (case AList.lookup Binding.eq_name (rev (nth sel_defaultss (k - 1))) b of
- NONE => fold_rev (Term.lambda o curry Free Name.uu) Ts (mk_undefined T)
- | SOME t => t |> Type.constraint (Ts ---> T) |> Syntax.check_term lthy)
- | SOME (xs, x) => fold_rev Term.lambda xs x)) ctr_Tss ks;
+ (case filter (curry (op =) (c, Binding.name_of b) o fst) sel_defaults of
+ [] => fold_rev (Term.lambda o curry Free Name.uu) Ts (mk_undefined T)
+ | [(_, t)] => t
+ | _ => error "Multiple default values for selector/constructor pair")
+ | SOME (xs, x) => fold_rev Term.lambda xs x)) ctrs ctr_Tss ks;
fun sel_spec b proto_sels =
let
val _ =
(case duplicates (op =) (map fst proto_sels) of
k :: _ => error ("Duplicate selector name " ^ quote (Binding.name_of b) ^
- " for constructor " ^
- quote (Syntax.string_of_term lthy (nth ctrs (k - 1))))
+ " for constructor " ^ quote (Syntax.string_of_term lthy (nth ctrs (k - 1))))
| [] => ())
val T =
(case distinct (op =) (map (fastype_of o snd o snd) proto_sels) of
[T] => T
| T :: T' :: _ => error ("Inconsistent range type for selector " ^
- quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ lthy T) ^ " vs. "
- ^ quote (Syntax.string_of_typ lthy T')));
+ quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ lthy T) ^
+ " vs. " ^ quote (Syntax.string_of_typ lthy T')));
in
mk_Trueprop_eq (Free (Binding.name_of b, fcT --> T) $ u,
Term.list_comb (mk_case As T case0, mk_sel_case_args b proto_sels T) $ u)
end;
- val sel_bindings = flat sel_bindingss;
- val uniq_sel_bindings = distinct Binding.eq_name sel_bindings;
- val all_sels_distinct = (length uniq_sel_bindings = length sel_bindings);
-
- val sel_binding_index =
- if all_sels_distinct then 1 upto length sel_bindings
- else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_bindings) sel_bindings;
-
- val proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss);
- val sel_infos =
- AList.group (op =) (sel_binding_index ~~ proto_sels)
- |> sort (int_ord o pairself fst)
- |> map snd |> curry (op ~~) uniq_sel_bindings;
- val sel_bindings = map fst sel_infos;
-
fun unflat_selss xs = unflat_lookup Binding.eq_name sel_bindings xs sel_bindingss;
val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) =
@@ -733,7 +761,7 @@
| _ => false);
val all_sel_thms =
- (if all_sels_distinct andalso forall null sel_defaultss then
+ (if all_sels_distinct andalso null sel_default_eqs then
flat sel_thmss
else
map_product (fn s => fn (xs', c) => make_sel_thm xs' c s) sel_defs
@@ -1020,19 +1048,17 @@
>> (fn js => (member (op =) js 0, member (op =) js 1)))
(false, false);
-val parse_defaults =
- @{keyword "("} |-- Parse.reserved "defaults" |-- Scan.repeat parse_bound_term --| @{keyword ")"};
-
fun parse_ctr_spec parse_ctr parse_arg =
- parse_opt_binding_colon -- parse_ctr -- Scan.repeat parse_arg --
- Scan.optional parse_defaults [];
+ parse_opt_binding_colon -- parse_ctr -- Scan.repeat parse_arg;
val parse_ctr_specs = Parse.enum1 "|" (parse_ctr_spec Parse.term Parse.binding);
+val parse_sel_default_eqs = Scan.optional (@{keyword "where"} |-- Parse.enum1 "|" Parse.prop) [];
val _ =
Outer_Syntax.local_theory_to_proof @{command_spec "free_constructors"}
"register an existing freely generated type's constructors"
(parse_ctr_options -- Parse.binding --| @{keyword "for"} -- parse_ctr_specs
+ -- parse_sel_default_eqs
>> free_constructors_cmd);
val _ = Context.>> (Context.map_theory Ctr_Sugar_Interpretation.init);