allow default values for selectors in low-level "wrap_data" command
authorblanchet
Tue, 11 Sep 2012 16:08:55 +0200
changeset 49280 52413dc96326
parent 49279 2fcfc11374ed
child 49281 3d87f4fd0d50
allow default values for selectors in low-level "wrap_data" command
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_wrap.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 16:08:27 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 16:08:55 2012 +0200
@@ -480,7 +480,7 @@
               corec_def), lthy)
           end;
       in
-        wrap_datatype tacss (((no_dests, ctrs0), casex0), (disc_binders, sel_binderss)) lthy'
+        wrap_datatype tacss (((no_dests, ctrs0), casex0), (disc_binders, (sel_binderss, []))) lthy'
         |> (if lfp then some_lfp_sugar else some_gfp_sugar)
       end;
 
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 11 16:08:27 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 11 16:08:55 2012 +0200
@@ -11,7 +11,8 @@
   val mk_half_pairss: 'a list -> ('a * 'a) list list
   val mk_ctr: typ list -> term -> term
   val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
-    ((bool * term list) * term) * (binding list * binding list list) -> local_theory ->
+    ((bool * term list) * term) *
+      (binding list * (binding list list * (binding * term) list list)) -> local_theory ->
     (term list list * thm list * thm list list) * local_theory
   val parse_wrap_options: bool parser
 end;
@@ -56,8 +57,7 @@
 
 fun mk_half_pairss ys = mk_half_pairss' [[]] ys;
 
-(* TODO: provide a way to have a different default value, e.g. "tl Nil = Nil" *)
-fun mk_undef T Ts = Const (@{const_name undefined}, Ts ---> T);
+fun mk_undefined T = Const (@{const_name undefined}, T);
 
 fun mk_ctr Ts ctr =
   let val Type (_, Ts0) = body_type (fastype_of ctr) in
@@ -67,27 +67,29 @@
 fun eta_expand_case_arg xs f_xs = fold_rev Term.lambda xs f_xs;
 
 fun name_of_ctr c =
-  case head_of c of
+  (case head_of c of
     Const (s, _) => s
   | Free (s, _) => s
-  | _ => error "Cannot extract name of constructor";
+  | _ => error "Cannot extract name of constructor");
 
 fun prepare_wrap_datatype prep_term (((no_dests, raw_ctrs), raw_case),
-    (raw_disc_binders, raw_sel_binderss)) no_defs_lthy =
+    (raw_disc_binders, (raw_sel_binderss, raw_sel_defaultss))) no_defs_lthy =
   let
     (* TODO: sanity checks on arguments *)
     (* TODO: attributes (simp, case_names, etc.) *)
     (* TODO: case syntax *)
     (* TODO: integration with function package ("size") *)
 
-    val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs;
-    val case0 = prep_term no_defs_lthy raw_case;
-
-    val n = length ctrs0;
+    val n = length raw_ctrs;
     val ks = 1 upto n;
 
     val _ = if n > 0 then () else error "No constructors specified";
 
+    val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs;
+    val case0 = prep_term no_defs_lthy raw_case;
+    val sel_defaultss =
+      pad_list [] n (map (map (apsnd (prep_term no_defs_lthy))) raw_sel_defaultss);
+
     val Type (fpT_name, As0) = body_type (fastype_of (hd ctrs0));
     val b = Binding.qualified_name fpT_name;
 
@@ -194,11 +196,17 @@
 
           fun alternate_disc k = Term.lambda v (alternate_disc_lhs (K o disc_free) (3 - k));
 
-          fun mk_sel_case_args proto_sels T =
-            map2 (fn Ts => fn i =>
-              case AList.lookup (op =) proto_sels i of
-                NONE => mk_undef T Ts
-              | SOME (xs, x) => fold_rev Term.lambda xs x) ctr_Tss ks;
+          fun mk_sel_case_args b proto_sels T =
+            map2 (fn Ts => fn k =>
+              (case AList.lookup (op =) proto_sels k of
+                NONE =>
+                let val def_T = Ts ---> T in
+                  (case AList.lookup Binding.eq_name (rev (nth sel_defaultss (k - 1))) b of
+                    NONE => mk_undefined def_T
+                  | SOME t => fold_rev (fn T => Term.lambda (Free (Name.uu, T))) Ts
+                      (Term.subst_atomic_types [(fastype_of t, T)] t))
+                end
+              | SOME (xs, x) => fold_rev Term.lambda xs x)) ctr_Tss ks;
 
           fun sel_spec b proto_sels =
             let
@@ -216,7 +224,7 @@
                     " vs. " ^ quote (Syntax.string_of_typ no_defs_lthy T')));
             in
               mk_Trueprop_eq (Free (Binding.name_of b, fpT --> T) $ v,
-                Term.list_comb (mk_case As T, mk_sel_case_args proto_sels T) $ v)
+                Term.list_comb (mk_case As T, mk_sel_case_args b proto_sels T) $ v)
             end;
 
           val proto_selss = map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss;
@@ -537,11 +545,16 @@
 
 fun wrap_datatype tacss = (fn (goalss, after_qed, lthy) =>
   map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
-  |> (fn thms => after_qed thms lthy)) oo
-  prepare_wrap_datatype (K I) (* FIXME? (singleton o Type_Infer_Context.infer_types) *)
+  |> (fn thms => after_qed thms lthy)) oo prepare_wrap_datatype (K I);
+
+fun parse_bracket_list parser = @{keyword "["} |-- Parse.list parser --|  @{keyword "]"};
 
-val parse_bindings = @{keyword "["} |-- Parse.list Parse.binding --| @{keyword "]"};
-val parse_bindingss = @{keyword "["} |-- Parse.list parse_bindings --| @{keyword "]"};
+val parse_bindings = parse_bracket_list Parse.binding;
+val parse_bindingss = parse_bracket_list parse_bindings;
+
+val parse_bound_term = (Parse.binding --| @{keyword ":"}) -- Parse.term;
+val parse_bound_terms = parse_bracket_list parse_bound_term;
+val parse_bound_termss = parse_bracket_list parse_bound_terms;
 
 val wrap_datatype_cmd = (fn (goalss, after_qed, lthy) =>
   Proof.theorem NONE (snd oo after_qed) (map (map (rpair [])) goalss) lthy) oo
@@ -553,7 +566,8 @@
 val _ =
   Outer_Syntax.local_theory_to_proof @{command_spec "wrap_data"} "wraps an existing datatype"
     ((parse_wrap_options -- (@{keyword "["} |-- Parse.list Parse.term --| @{keyword "]"}) --
-      Parse.term -- Scan.optional (parse_bindings -- Scan.optional parse_bindingss []) ([], []))
+      Parse.term -- Scan.optional (parse_bindings -- Scan.optional (parse_bindingss --
+        Scan.optional parse_bound_termss []) ([], [])) ([], ([], [])))
      >> wrap_datatype_cmd);
 
 end;