src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML
changeset 57200 aab87ffa60cc
parent 57094 589ec121ce1a
child 57260 8747af0d1012
--- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Tue Jun 10 11:38:53 2014 +0200
+++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Tue Jun 10 12:16:22 2014 +0200
@@ -56,19 +56,19 @@
   val dest_case: Proof.context -> string -> typ list -> term ->
     (ctr_sugar * term list * term list) option
 
-  type ('c, 'a, 'v) ctr_spec = ((binding * 'c) * 'a list) * (binding * 'v) list
+  type ('c, 'a) ctr_spec = (binding * 'c) * 'a list
 
-  val disc_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> binding
-  val ctr_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> 'c
-  val args_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> 'a list
-  val sel_defaults_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> (binding * 'v) list
+  val disc_of_ctr_spec: ('c, 'a) ctr_spec -> binding
+  val ctr_of_ctr_spec: ('c, 'a) ctr_spec -> 'c
+  val args_of_ctr_spec: ('c, 'a) ctr_spec -> 'a list
 
   val free_constructors: ({prems: thm list, context: Proof.context} -> tactic) list list ->
-    ((bool * bool) * binding) * (term, binding, term) ctr_spec list -> local_theory ->
+    (((bool * bool) * binding) * (term, binding) ctr_spec list) * term list -> local_theory ->
     ctr_sugar * local_theory
   val parse_bound_term: (binding * string) parser
   val parse_ctr_options: (bool * bool) parser
-  val parse_ctr_spec: 'c parser -> 'a parser -> ('c, 'a, string) ctr_spec parser
+  val parse_ctr_spec: 'c parser -> 'a parser -> ('c, 'a) ctr_spec parser
+  val parse_sel_default_eqs: string list parser
 end;
 
 structure Ctr_Sugar : CTR_SUGAR =
@@ -313,24 +313,43 @@
     | _ => NONE)
   | _ => NONE);
 
-fun eta_expand_arg xs f_xs = fold_rev Term.lambda xs f_xs;
+fun const_or_free_name (Const (s, _)) = Long_Name.base_name s
+  | const_or_free_name (Free (s, _)) = s
+  | const_or_free_name t = raise TERM ("const_or_free_name", [t])
 
-type ('c, 'a, 'v) ctr_spec = ((binding * 'c) * 'a list) * (binding * 'v) list;
+fun extract_sel_default ctxt t =
+  let
+    fun malformed () =
+      error ("Malformed selector default value equation: " ^ Syntax.string_of_term ctxt t);
 
-fun disc_of_ctr_spec (((disc, _), _), _) = disc;
-fun ctr_of_ctr_spec (((_, ctr), _), _) = ctr;
-fun args_of_ctr_spec ((_, args), _) = args;
-fun sel_defaults_of_ctr_spec (_, ds) = ds;
+    val ((sel, (ctr, vars)), rhs) =
+      fst (Term.replace_dummy_patterns (Syntax.check_term ctxt t) 0)
+      |> HOLogic.dest_eq
+      |>> (Term.dest_comb
+        #>> const_or_free_name
+        ##> (Term.strip_comb #>> (Term.dest_Const #> fst)))
+      handle TERM _ => malformed ();
+  in
+    if forall (is_Free orf is_Var) vars andalso not (has_duplicates (op aconv) vars) then
+      ((ctr, sel), fold_rev Term.lambda vars rhs)
+    else
+      malformed ()
+  end;
 
-fun prepare_free_constructors prep_term (((discs_sels, no_code), raw_case_binding), ctr_specs)
-    no_defs_lthy =
+type ('c, 'a) ctr_spec = (binding * 'c) * 'a list;
+
+fun disc_of_ctr_spec ((disc, _), _) = disc;
+fun ctr_of_ctr_spec ((_, ctr), _) = ctr;
+fun args_of_ctr_spec (_, args) = args;
+
+fun prepare_free_constructors prep_term
+    ((((discs_sels, no_code), raw_case_binding), ctr_specs), sel_default_eqs) no_defs_lthy =
   let
     (* TODO: sanity checks on arguments *)
 
     val raw_ctrs = map ctr_of_ctr_spec ctr_specs;
     val raw_disc_bindings = map disc_of_ctr_spec ctr_specs;
     val raw_sel_bindingss = map args_of_ctr_spec ctr_specs;
-    val raw_sel_defaultss = map sel_defaults_of_ctr_spec ctr_specs;
 
     val n = length raw_ctrs;
     val ks = 1 upto n;
@@ -338,7 +357,6 @@
     val _ = if n > 0 then () else error "No constructors specified";
 
     val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs;
-    val sel_defaultss = map (map (apsnd (prep_term no_defs_lthy))) raw_sel_defaultss;
 
     val Type (fcT_name, As0) = body_type (fastype_of (hd ctrs0));
     val fc_b_name = Long_Name.base_name fcT_name;
@@ -424,8 +442,8 @@
 
     (* TODO: Eta-expension is for compatibility with the old datatype package (but it also provides
        nicer names). Consider removing. *)
-    val eta_fs = map2 eta_expand_arg xss xfs;
-    val eta_gs = map2 eta_expand_arg xss xgs;
+    val eta_fs = map2 (fold_rev Term.lambda) xss xfs;
+    val eta_gs = map2 (fold_rev Term.lambda) xss xgs;
 
     val case_binding =
       qualify false
@@ -484,13 +502,38 @@
     val no_discs_sels =
       not discs_sels andalso
       forall (forall Binding.is_empty) (raw_disc_bindings :: raw_sel_bindingss) andalso
-      forall null raw_sel_defaultss;
+      null sel_default_eqs;
 
     val (all_sels_distinct, discs, selss, disc_defs, sel_defs, sel_defss, lthy') =
       if no_discs_sels then
         (true, [], [], [], [], [], lthy')
       else
         let
+          val sel_bindings = flat sel_bindingss;
+          val uniq_sel_bindings = distinct Binding.eq_name sel_bindings;
+          val all_sels_distinct = (length uniq_sel_bindings = length sel_bindings);
+
+          val sel_binding_index =
+            if all_sels_distinct then 1 upto length sel_bindings
+            else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_bindings) sel_bindings;
+
+          val all_proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss);
+          val sel_infos =
+            AList.group (op =) (sel_binding_index ~~ all_proto_sels)
+            |> sort (int_ord o pairself fst)
+            |> map snd |> curry (op ~~) uniq_sel_bindings;
+          val sel_bindings = map fst sel_infos;
+          val sel_Ts = map (curry (op -->) fcT o fastype_of o snd o snd o hd o snd) sel_infos;
+
+          val sel_default_lthy = no_defs_lthy
+            |> Proof_Context.allow_dummies
+            |> Proof_Context.add_fixes
+              (map2 (fn b => fn T => (b, SOME T, NoSyn)) sel_bindings sel_Ts)
+            |> snd;
+
+          val sel_defaults =
+            map (extract_sel_default sel_default_lthy o prep_term sel_default_lthy) sel_default_eqs;
+
           fun disc_free b = Free (Binding.name_of b, mk_pred1T fcT);
 
           fun disc_spec b exist_xs_u_eq_ctr = mk_Trueprop_eq (disc_free b $ u, exist_xs_u_eq_ctr);
@@ -499,48 +542,33 @@
             Term.lambda u (alternate_disc_lhs (K o rapp u o disc_free) (3 - k));
 
           fun mk_sel_case_args b proto_sels T =
-            map2 (fn Ts => fn k =>
+            map3 (fn Const (c, _) => fn Ts => fn k =>
               (case AList.lookup (op =) proto_sels k of
                 NONE =>
-                (case AList.lookup Binding.eq_name (rev (nth sel_defaultss (k - 1))) b of
-                  NONE => fold_rev (Term.lambda o curry Free Name.uu) Ts (mk_undefined T)
-                | SOME t => t |> Type.constraint (Ts ---> T) |> Syntax.check_term lthy)
-              | SOME (xs, x) => fold_rev Term.lambda xs x)) ctr_Tss ks;
+                (case filter (curry (op =) (c, Binding.name_of b) o fst) sel_defaults of
+                  [] => fold_rev (Term.lambda o curry Free Name.uu) Ts (mk_undefined T)
+                | [(_, t)] => t
+                | _ => error "Multiple default values for selector/constructor pair")
+              | SOME (xs, x) => fold_rev Term.lambda xs x)) ctrs ctr_Tss ks;
 
           fun sel_spec b proto_sels =
             let
               val _ =
                 (case duplicates (op =) (map fst proto_sels) of
                    k :: _ => error ("Duplicate selector name " ^ quote (Binding.name_of b) ^
-                     " for constructor " ^
-                     quote (Syntax.string_of_term lthy (nth ctrs (k - 1))))
+                     " for constructor " ^ quote (Syntax.string_of_term lthy (nth ctrs (k - 1))))
                  | [] => ())
               val T =
                 (case distinct (op =) (map (fastype_of o snd o snd) proto_sels) of
                   [T] => T
                 | T :: T' :: _ => error ("Inconsistent range type for selector " ^
-                    quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ lthy T) ^ " vs. "
-                    ^ quote (Syntax.string_of_typ lthy T')));
+                    quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ lthy T) ^
+                    " vs. " ^ quote (Syntax.string_of_typ lthy T')));
             in
               mk_Trueprop_eq (Free (Binding.name_of b, fcT --> T) $ u,
                 Term.list_comb (mk_case As T case0, mk_sel_case_args b proto_sels T) $ u)
             end;
 
-          val sel_bindings = flat sel_bindingss;
-          val uniq_sel_bindings = distinct Binding.eq_name sel_bindings;
-          val all_sels_distinct = (length uniq_sel_bindings = length sel_bindings);
-
-          val sel_binding_index =
-            if all_sels_distinct then 1 upto length sel_bindings
-            else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_bindings) sel_bindings;
-
-          val proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss);
-          val sel_infos =
-            AList.group (op =) (sel_binding_index ~~ proto_sels)
-            |> sort (int_ord o pairself fst)
-            |> map snd |> curry (op ~~) uniq_sel_bindings;
-          val sel_bindings = map fst sel_infos;
-
           fun unflat_selss xs = unflat_lookup Binding.eq_name sel_bindings xs sel_bindingss;
 
           val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) =
@@ -733,7 +761,7 @@
                 | _ => false);
 
               val all_sel_thms =
-                (if all_sels_distinct andalso forall null sel_defaultss then
+                (if all_sels_distinct andalso null sel_default_eqs then
                    flat sel_thmss
                  else
                    map_product (fn s => fn (xs', c) => make_sel_thm xs' c s) sel_defs
@@ -1020,19 +1048,17 @@
       >> (fn js => (member (op =) js 0, member (op =) js 1)))
     (false, false);
 
-val parse_defaults =
-  @{keyword "("} |-- Parse.reserved "defaults" |-- Scan.repeat parse_bound_term --| @{keyword ")"};
-
 fun parse_ctr_spec parse_ctr parse_arg =
-  parse_opt_binding_colon -- parse_ctr -- Scan.repeat parse_arg --
-  Scan.optional parse_defaults [];
+  parse_opt_binding_colon -- parse_ctr -- Scan.repeat parse_arg;
 
 val parse_ctr_specs = Parse.enum1 "|" (parse_ctr_spec Parse.term Parse.binding);
+val parse_sel_default_eqs = Scan.optional (@{keyword "where"} |-- Parse.enum1 "|" Parse.prop) [];
 
 val _ =
   Outer_Syntax.local_theory_to_proof @{command_spec "free_constructors"}
     "register an existing freely generated type's constructors"
     (parse_ctr_options -- Parse.binding --| @{keyword "for"} -- parse_ctr_specs
+       -- parse_sel_default_eqs
      >> free_constructors_cmd);
 
 val _ = Context.>> (Context.map_theory Ctr_Sugar_Interpretation.init);