changed discriminator default: avoid mixing ctor and dtor views
authorblanchet
Fri, 26 Apr 2013 11:04:45 +0200
changeset 51787 1267c28c7bdd
parent 51786 61ed47755088
child 51788 5fe72280a49f
changed discriminator default: avoid mixing ctor and dtor views
src/HOL/BNF/Tools/bnf_def.ML
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
src/HOL/BNF/Tools/bnf_util.ML
src/HOL/BNF/Tools/bnf_wrap.ML
--- a/src/HOL/BNF/Tools/bnf_def.ML	Fri Apr 26 09:53:11 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_def.ML	Fri Apr 26 11:04:45 2013 +0200
@@ -1320,7 +1320,7 @@
 
 val _ =
   Outer_Syntax.local_theory_to_proof @{command_spec "bnf_def"} "define a BNF for an existing type"
-    ((parse_opt_binding_colon -- Parse.term --
+    ((parse_opt_binding_colon Binding.empty -- Parse.term --
        (@{keyword "["} |-- Parse.list Parse.term --| @{keyword "]"}) -- Parse.term --
        (@{keyword "["} |-- Parse.list Parse.term --| @{keyword "]"}) -- Scan.option Parse.term)
        >> bnf_def_cmd);
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Apr 26 09:53:11 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Apr 26 11:04:45 2013 +0200
@@ -1224,7 +1224,8 @@
 val parse_type_arg_constrained =
   Parse.type_ident -- Scan.option (@{keyword "::"} |-- Parse.!!! Parse.sort)
 
-val parse_type_arg_named_constrained = parse_opt_binding_colon -- parse_type_arg_constrained
+val parse_type_arg_named_constrained =
+  parse_opt_binding_colon Binding.empty -- parse_type_arg_constrained
 
 val parse_type_args_named_constrained =
   parse_type_arg_constrained >> (single o pair Binding.empty) ||
@@ -1248,7 +1249,7 @@
   Scan.succeed no_map_rel;
 
 val parse_ctr_spec =
-  parse_opt_binding_colon -- Parse.binding -- Scan.repeat parse_ctr_arg --
+  parse_opt_binding_colon smart_binding -- Parse.binding -- Scan.repeat parse_ctr_arg --
   Scan.optional parse_defaults [] -- Parse.opt_mixfix;
 
 val parse_spec =
--- a/src/HOL/BNF/Tools/bnf_util.ML	Fri Apr 26 09:53:11 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_util.ML	Fri Apr 26 11:04:45 2013 +0200
@@ -159,8 +159,11 @@
   val certifyT: Proof.context -> typ -> ctyp
   val certify: Proof.context -> term -> cterm
 
+  val standard_binding: binding
+  val smart_binding: binding
+  val binding_eq: binding * binding -> bool
   val parse_binding_colon: Token.T list -> binding * Token.T list
-  val parse_opt_binding_colon: Token.T list -> binding * Token.T list
+  val parse_opt_binding_colon: binding -> Token.T list -> binding * Token.T list
 
   val typedef: binding * (string * sort) list * mixfix -> term ->
     (binding * binding) option -> tactic -> local_theory -> (string * Typedef.info) * local_theory
@@ -302,8 +305,16 @@
 fun certify ctxt = Thm.cterm_of (Proof_Context.theory_of ctxt);
 fun certifyT ctxt = Thm.ctyp_of (Proof_Context.theory_of ctxt);
 
+val binding_eq = (op =) o pairself Binding.dest
+
+(* The standard binding stands for a name generated following the canonical convention (e.g.
+   "is_Nil" from "Nil"). The smart binding is either the standard binding or no binding at all,
+   depending on the context. *)
+val standard_binding = @{binding _};
+val smart_binding = Binding.conceal standard_binding;
+
 val parse_binding_colon = Parse.binding --| @{keyword ":"};
-val parse_opt_binding_colon = Scan.optional parse_binding_colon Binding.empty;
+val parse_opt_binding_colon = Scan.optional parse_binding_colon;
 
 (*TODO: is this really different from Typedef.add_typedef_global?*)
 fun typedef typ set opt_morphs tac lthy =
--- a/src/HOL/BNF/Tools/bnf_wrap.ML	Fri Apr 26 09:53:11 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_wrap.ML	Fri Apr 26 11:04:45 2013 +0200
@@ -58,8 +58,6 @@
 val split_asmN = "split_asm";
 val weak_case_cong_thmsN = "weak_case_cong";
 
-val std_binding = @{binding _};
-
 val induct_simp_attrs = @{attributes [induct_simp]};
 val cong_attrs = @{attributes [cong]};
 val iff_attrs = @{attributes [iff]};
@@ -159,39 +157,47 @@
 
     val ms = map length ctr_Tss;
 
-    val raw_disc_bindings' = pad_list Binding.empty n raw_disc_bindings;
+    val raw_disc_bindings' = pad_list smart_binding n raw_disc_bindings;
 
     fun can_really_rely_on_disc k =
-      not (Binding.eq_name (nth raw_disc_bindings' (k - 1), Binding.empty)) orelse
-      nth ms (k - 1) = 0;
+      not (Binding.is_empty (nth raw_disc_bindings' (k - 1))) orelse nth ms (k - 1) = 0;
     fun can_rely_on_disc k =
       can_really_rely_on_disc k orelse (k = 1 andalso not (can_really_rely_on_disc 2));
     fun can_omit_disc_binding k m =
-      n = 1 orelse m = 0 orelse (n = 2 andalso can_rely_on_disc (3 - k));
+      m = 0 orelse n = 1 orelse (n = 2 andalso can_rely_on_disc (3 - k));
 
-    val std_disc_binding = qualify false o Binding.name o prefix isN o base_name_of_ctr;
+    fun should_really_rely_on_disc k =
+      let val b = nth raw_disc_bindings' (k - 1) in
+        not (Binding.is_empty b orelse binding_eq (b, smart_binding))
+      end;
+    fun should_rely_on_disc k =
+      should_really_rely_on_disc k orelse (k = 1 andalso not (should_really_rely_on_disc 2));
+    fun should_omit_disc_binding k =
+      n = 1 orelse (n = 2 andalso should_rely_on_disc (3 - k));
+
+    val standard_disc_binding = qualify false o Binding.name o prefix isN o base_name_of_ctr;
 
     val disc_bindings =
       raw_disc_bindings'
       |> map4 (fn k => fn m => fn ctr => fn disc =>
         Option.map (qualify false)
-          (if Binding.eq_name (disc, Binding.empty) then
-             if can_omit_disc_binding k m then NONE else SOME (std_disc_binding ctr)
-           else if Binding.eq_name (disc, std_binding) then
-             SOME (std_disc_binding ctr)
+          (if Binding.is_empty disc then
+             if can_omit_disc_binding k m then NONE else error "Cannot omit discriminator name"
+           else if binding_eq (disc, smart_binding) then
+             if should_omit_disc_binding k then NONE else SOME (standard_disc_binding ctr)
+           else if binding_eq (disc, standard_binding) then
+             SOME (standard_disc_binding ctr)
            else
              SOME disc)) ks ms ctrs0;
 
-    val no_discs = map is_none disc_bindings;
-
-    fun std_sel_binding m l = Binding.name o mk_unN m l o base_name_of_ctr;
+    fun standard_sel_binding m l = Binding.name o mk_unN m l o base_name_of_ctr;
 
     val sel_bindingss =
       pad_list [] n raw_sel_bindingss
       |> map3 (fn ctr => fn m => map2 (fn l => fn sel =>
         qualify false
-          (if Binding.eq_name (sel, Binding.empty) orelse Binding.eq_name (sel, std_binding) then
-            std_sel_binding m l ctr
+          (if Binding.is_empty sel orelse binding_eq (sel, standard_binding) then
+            standard_sel_binding m l ctr
           else
             sel)) (1 upto m) o pad_list Binding.empty m) ctrs0 ms;
 
@@ -310,9 +316,14 @@
             no_defs_lthy
             |> apfst split_list o fold_map4 (fn k => fn m => fn exist_xs_u_eq_ctr =>
               fn NONE =>
-                 if n = 1 then pair (Term.lambda u (mk_uu_eq ()), unique_disc_no_def)
-                 else if m = 0 then pair (Term.lambda u exist_xs_u_eq_ctr, refl)
-                 else pair (alternate_disc k, alternate_disc_no_def)
+                 if n = 1 then
+                   pair (Term.lambda u (mk_uu_eq ()), unique_disc_no_def)
+                 else if n = 2 andalso should_omit_disc_binding k then
+                   pair (alternate_disc k, alternate_disc_no_def)
+                 else if m = 0 then
+                   pair (Term.lambda u exist_xs_u_eq_ctr, refl)
+                 else
+                   raise Fail "missing discriminator"
                | SOME b => Specification.definition (SOME (b, NONE, NoSyn),
                    ((Thm.def_binding b, []), disc_spec b exist_xs_u_eq_ctr)) #>> apsnd snd)
               ks ms exist_xs_u_eq_ctrs disc_bindings
@@ -482,7 +493,7 @@
                   map3 mk_thms discI_thms not_discI_thms distinct_thmsss' |> `transpose
                 end;
 
-              val disc_thms = flat (map2 (fn true => K [] | false => I) no_discs disc_thmss);
+              val disc_thms = flat (map2 (fn NONE => K [] | SOME _ => I) disc_bindings disc_thmss);
 
               val (disc_exclude_thms, (disc_exclude_thmsss', disc_exclude_thmsss)) =
                 let