src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML
changeset 57200 aab87ffa60cc
parent 57094 589ec121ce1a
child 57260 8747af0d1012
     1.1 --- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Tue Jun 10 11:38:53 2014 +0200
     1.2 +++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Tue Jun 10 12:16:22 2014 +0200
     1.3 @@ -56,19 +56,19 @@
     1.4    val dest_case: Proof.context -> string -> typ list -> term ->
     1.5      (ctr_sugar * term list * term list) option
     1.6  
     1.7 -  type ('c, 'a, 'v) ctr_spec = ((binding * 'c) * 'a list) * (binding * 'v) list
     1.8 +  type ('c, 'a) ctr_spec = (binding * 'c) * 'a list
     1.9  
    1.10 -  val disc_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> binding
    1.11 -  val ctr_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> 'c
    1.12 -  val args_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> 'a list
    1.13 -  val sel_defaults_of_ctr_spec: ('c, 'a, 'v) ctr_spec -> (binding * 'v) list
    1.14 +  val disc_of_ctr_spec: ('c, 'a) ctr_spec -> binding
    1.15 +  val ctr_of_ctr_spec: ('c, 'a) ctr_spec -> 'c
    1.16 +  val args_of_ctr_spec: ('c, 'a) ctr_spec -> 'a list
    1.17  
    1.18    val free_constructors: ({prems: thm list, context: Proof.context} -> tactic) list list ->
    1.19 -    ((bool * bool) * binding) * (term, binding, term) ctr_spec list -> local_theory ->
    1.20 +    (((bool * bool) * binding) * (term, binding) ctr_spec list) * term list -> local_theory ->
    1.21      ctr_sugar * local_theory
    1.22    val parse_bound_term: (binding * string) parser
    1.23    val parse_ctr_options: (bool * bool) parser
    1.24 -  val parse_ctr_spec: 'c parser -> 'a parser -> ('c, 'a, string) ctr_spec parser
    1.25 +  val parse_ctr_spec: 'c parser -> 'a parser -> ('c, 'a) ctr_spec parser
    1.26 +  val parse_sel_default_eqs: string list parser
    1.27  end;
    1.28  
    1.29  structure Ctr_Sugar : CTR_SUGAR =
    1.30 @@ -313,24 +313,43 @@
    1.31      | _ => NONE)
    1.32    | _ => NONE);
    1.33  
    1.34 -fun eta_expand_arg xs f_xs = fold_rev Term.lambda xs f_xs;
    1.35 +fun const_or_free_name (Const (s, _)) = Long_Name.base_name s
    1.36 +  | const_or_free_name (Free (s, _)) = s
    1.37 +  | const_or_free_name t = raise TERM ("const_or_free_name", [t])
    1.38  
    1.39 -type ('c, 'a, 'v) ctr_spec = ((binding * 'c) * 'a list) * (binding * 'v) list;
    1.40 +fun extract_sel_default ctxt t =
    1.41 +  let
    1.42 +    fun malformed () =
    1.43 +      error ("Malformed selector default value equation: " ^ Syntax.string_of_term ctxt t);
    1.44  
    1.45 -fun disc_of_ctr_spec (((disc, _), _), _) = disc;
    1.46 -fun ctr_of_ctr_spec (((_, ctr), _), _) = ctr;
    1.47 -fun args_of_ctr_spec ((_, args), _) = args;
    1.48 -fun sel_defaults_of_ctr_spec (_, ds) = ds;
    1.49 +    val ((sel, (ctr, vars)), rhs) =
    1.50 +      fst (Term.replace_dummy_patterns (Syntax.check_term ctxt t) 0)
    1.51 +      |> HOLogic.dest_eq
    1.52 +      |>> (Term.dest_comb
    1.53 +        #>> const_or_free_name
    1.54 +        ##> (Term.strip_comb #>> (Term.dest_Const #> fst)))
    1.55 +      handle TERM _ => malformed ();
    1.56 +  in
    1.57 +    if forall (is_Free orf is_Var) vars andalso not (has_duplicates (op aconv) vars) then
    1.58 +      ((ctr, sel), fold_rev Term.lambda vars rhs)
    1.59 +    else
    1.60 +      malformed ()
    1.61 +  end;
    1.62  
    1.63 -fun prepare_free_constructors prep_term (((discs_sels, no_code), raw_case_binding), ctr_specs)
    1.64 -    no_defs_lthy =
    1.65 +type ('c, 'a) ctr_spec = (binding * 'c) * 'a list;
    1.66 +
    1.67 +fun disc_of_ctr_spec ((disc, _), _) = disc;
    1.68 +fun ctr_of_ctr_spec ((_, ctr), _) = ctr;
    1.69 +fun args_of_ctr_spec (_, args) = args;
    1.70 +
    1.71 +fun prepare_free_constructors prep_term
    1.72 +    ((((discs_sels, no_code), raw_case_binding), ctr_specs), sel_default_eqs) no_defs_lthy =
    1.73    let
    1.74      (* TODO: sanity checks on arguments *)
    1.75  
    1.76      val raw_ctrs = map ctr_of_ctr_spec ctr_specs;
    1.77      val raw_disc_bindings = map disc_of_ctr_spec ctr_specs;
    1.78      val raw_sel_bindingss = map args_of_ctr_spec ctr_specs;
    1.79 -    val raw_sel_defaultss = map sel_defaults_of_ctr_spec ctr_specs;
    1.80  
    1.81      val n = length raw_ctrs;
    1.82      val ks = 1 upto n;
    1.83 @@ -338,7 +357,6 @@
    1.84      val _ = if n > 0 then () else error "No constructors specified";
    1.85  
    1.86      val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs;
    1.87 -    val sel_defaultss = map (map (apsnd (prep_term no_defs_lthy))) raw_sel_defaultss;
    1.88  
    1.89      val Type (fcT_name, As0) = body_type (fastype_of (hd ctrs0));
    1.90      val fc_b_name = Long_Name.base_name fcT_name;
    1.91 @@ -424,8 +442,8 @@
    1.92  
    1.93      (* TODO: Eta-expension is for compatibility with the old datatype package (but it also provides
    1.94         nicer names). Consider removing. *)
    1.95 -    val eta_fs = map2 eta_expand_arg xss xfs;
    1.96 -    val eta_gs = map2 eta_expand_arg xss xgs;
    1.97 +    val eta_fs = map2 (fold_rev Term.lambda) xss xfs;
    1.98 +    val eta_gs = map2 (fold_rev Term.lambda) xss xgs;
    1.99  
   1.100      val case_binding =
   1.101        qualify false
   1.102 @@ -484,13 +502,38 @@
   1.103      val no_discs_sels =
   1.104        not discs_sels andalso
   1.105        forall (forall Binding.is_empty) (raw_disc_bindings :: raw_sel_bindingss) andalso
   1.106 -      forall null raw_sel_defaultss;
   1.107 +      null sel_default_eqs;
   1.108  
   1.109      val (all_sels_distinct, discs, selss, disc_defs, sel_defs, sel_defss, lthy') =
   1.110        if no_discs_sels then
   1.111          (true, [], [], [], [], [], lthy')
   1.112        else
   1.113          let
   1.114 +          val sel_bindings = flat sel_bindingss;
   1.115 +          val uniq_sel_bindings = distinct Binding.eq_name sel_bindings;
   1.116 +          val all_sels_distinct = (length uniq_sel_bindings = length sel_bindings);
   1.117 +
   1.118 +          val sel_binding_index =
   1.119 +            if all_sels_distinct then 1 upto length sel_bindings
   1.120 +            else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_bindings) sel_bindings;
   1.121 +
   1.122 +          val all_proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss);
   1.123 +          val sel_infos =
   1.124 +            AList.group (op =) (sel_binding_index ~~ all_proto_sels)
   1.125 +            |> sort (int_ord o pairself fst)
   1.126 +            |> map snd |> curry (op ~~) uniq_sel_bindings;
   1.127 +          val sel_bindings = map fst sel_infos;
   1.128 +          val sel_Ts = map (curry (op -->) fcT o fastype_of o snd o snd o hd o snd) sel_infos;
   1.129 +
   1.130 +          val sel_default_lthy = no_defs_lthy
   1.131 +            |> Proof_Context.allow_dummies
   1.132 +            |> Proof_Context.add_fixes
   1.133 +              (map2 (fn b => fn T => (b, SOME T, NoSyn)) sel_bindings sel_Ts)
   1.134 +            |> snd;
   1.135 +
   1.136 +          val sel_defaults =
   1.137 +            map (extract_sel_default sel_default_lthy o prep_term sel_default_lthy) sel_default_eqs;
   1.138 +
   1.139            fun disc_free b = Free (Binding.name_of b, mk_pred1T fcT);
   1.140  
   1.141            fun disc_spec b exist_xs_u_eq_ctr = mk_Trueprop_eq (disc_free b $ u, exist_xs_u_eq_ctr);
   1.142 @@ -499,48 +542,33 @@
   1.143              Term.lambda u (alternate_disc_lhs (K o rapp u o disc_free) (3 - k));
   1.144  
   1.145            fun mk_sel_case_args b proto_sels T =
   1.146 -            map2 (fn Ts => fn k =>
   1.147 +            map3 (fn Const (c, _) => fn Ts => fn k =>
   1.148                (case AList.lookup (op =) proto_sels k of
   1.149                  NONE =>
   1.150 -                (case AList.lookup Binding.eq_name (rev (nth sel_defaultss (k - 1))) b of
   1.151 -                  NONE => fold_rev (Term.lambda o curry Free Name.uu) Ts (mk_undefined T)
   1.152 -                | SOME t => t |> Type.constraint (Ts ---> T) |> Syntax.check_term lthy)
   1.153 -              | SOME (xs, x) => fold_rev Term.lambda xs x)) ctr_Tss ks;
   1.154 +                (case filter (curry (op =) (c, Binding.name_of b) o fst) sel_defaults of
   1.155 +                  [] => fold_rev (Term.lambda o curry Free Name.uu) Ts (mk_undefined T)
   1.156 +                | [(_, t)] => t
   1.157 +                | _ => error "Multiple default values for selector/constructor pair")
   1.158 +              | SOME (xs, x) => fold_rev Term.lambda xs x)) ctrs ctr_Tss ks;
   1.159  
   1.160            fun sel_spec b proto_sels =
   1.161              let
   1.162                val _ =
   1.163                  (case duplicates (op =) (map fst proto_sels) of
   1.164                     k :: _ => error ("Duplicate selector name " ^ quote (Binding.name_of b) ^
   1.165 -                     " for constructor " ^
   1.166 -                     quote (Syntax.string_of_term lthy (nth ctrs (k - 1))))
   1.167 +                     " for constructor " ^ quote (Syntax.string_of_term lthy (nth ctrs (k - 1))))
   1.168                   | [] => ())
   1.169                val T =
   1.170                  (case distinct (op =) (map (fastype_of o snd o snd) proto_sels) of
   1.171                    [T] => T
   1.172                  | T :: T' :: _ => error ("Inconsistent range type for selector " ^
   1.173 -                    quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ lthy T) ^ " vs. "
   1.174 -                    ^ quote (Syntax.string_of_typ lthy T')));
   1.175 +                    quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ lthy T) ^
   1.176 +                    " vs. " ^ quote (Syntax.string_of_typ lthy T')));
   1.177              in
   1.178                mk_Trueprop_eq (Free (Binding.name_of b, fcT --> T) $ u,
   1.179                  Term.list_comb (mk_case As T case0, mk_sel_case_args b proto_sels T) $ u)
   1.180              end;
   1.181  
   1.182 -          val sel_bindings = flat sel_bindingss;
   1.183 -          val uniq_sel_bindings = distinct Binding.eq_name sel_bindings;
   1.184 -          val all_sels_distinct = (length uniq_sel_bindings = length sel_bindings);
   1.185 -
   1.186 -          val sel_binding_index =
   1.187 -            if all_sels_distinct then 1 upto length sel_bindings
   1.188 -            else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_bindings) sel_bindings;
   1.189 -
   1.190 -          val proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss);
   1.191 -          val sel_infos =
   1.192 -            AList.group (op =) (sel_binding_index ~~ proto_sels)
   1.193 -            |> sort (int_ord o pairself fst)
   1.194 -            |> map snd |> curry (op ~~) uniq_sel_bindings;
   1.195 -          val sel_bindings = map fst sel_infos;
   1.196 -
   1.197            fun unflat_selss xs = unflat_lookup Binding.eq_name sel_bindings xs sel_bindingss;
   1.198  
   1.199            val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) =
   1.200 @@ -733,7 +761,7 @@
   1.201                  | _ => false);
   1.202  
   1.203                val all_sel_thms =
   1.204 -                (if all_sels_distinct andalso forall null sel_defaultss then
   1.205 +                (if all_sels_distinct andalso null sel_default_eqs then
   1.206                     flat sel_thmss
   1.207                   else
   1.208                     map_product (fn s => fn (xs', c) => make_sel_thm xs' c s) sel_defs
   1.209 @@ -1020,19 +1048,17 @@
   1.210        >> (fn js => (member (op =) js 0, member (op =) js 1)))
   1.211      (false, false);
   1.212  
   1.213 -val parse_defaults =
   1.214 -  @{keyword "("} |-- Parse.reserved "defaults" |-- Scan.repeat parse_bound_term --| @{keyword ")"};
   1.215 -
   1.216  fun parse_ctr_spec parse_ctr parse_arg =
   1.217 -  parse_opt_binding_colon -- parse_ctr -- Scan.repeat parse_arg --
   1.218 -  Scan.optional parse_defaults [];
   1.219 +  parse_opt_binding_colon -- parse_ctr -- Scan.repeat parse_arg;
   1.220  
   1.221  val parse_ctr_specs = Parse.enum1 "|" (parse_ctr_spec Parse.term Parse.binding);
   1.222 +val parse_sel_default_eqs = Scan.optional (@{keyword "where"} |-- Parse.enum1 "|" Parse.prop) [];
   1.223  
   1.224  val _ =
   1.225    Outer_Syntax.local_theory_to_proof @{command_spec "free_constructors"}
   1.226      "register an existing freely generated type's constructors"
   1.227      (parse_ctr_options -- Parse.binding --| @{keyword "for"} -- parse_ctr_specs
   1.228 +       -- parse_sel_default_eqs
   1.229       >> free_constructors_cmd);
   1.230  
   1.231  val _ = Context.>> (Context.map_theory Ctr_Sugar_Interpretation.init);