src/HOL/BNF/Tools/bnf_ctr_sugar.ML
changeset 52968 2b430bbb5a1a
parent 52965 0cd62cb233e0
child 52969 f2df0730f8ac
--- a/src/HOL/BNF/Tools/bnf_ctr_sugar.ML	Mon Aug 12 15:25:16 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_ctr_sugar.ML	Mon Aug 12 15:25:17 2013 +0200
@@ -39,7 +39,7 @@
   val name_of_disc: term -> string
 
   val wrap_free_constructors: ({prems: thm list, context: Proof.context} -> tactic) list list ->
-    (((bool * bool) * term list) * term) *
+    (((bool * bool) * term list) * binding) *
       (binding list * (binding list list * (binding * term) list list)) -> local_theory ->
     ctr_sugar * local_theory
   val parse_wrap_free_constructors_options: (bool * bool) parser
@@ -177,8 +177,8 @@
 
 fun eta_expand_arg xs f_xs = fold_rev Term.lambda xs f_xs;
 
-fun prepare_wrap_free_constructors prep_term ((((no_dests, rep_compat), raw_ctrs), raw_case),
-    (raw_disc_bindings, (raw_sel_bindingss, raw_sel_defaultss))) no_defs_lthy =
+fun prepare_wrap_free_constructors prep_term ((((no_dests, rep_compat), raw_ctrs),
+    raw_case_binding), (raw_disc_bindings, (raw_sel_bindingss, raw_sel_defaultss))) no_defs_lthy =
   let
     (* TODO: sanity checks on arguments *)
 
@@ -188,13 +188,10 @@
     val _ = if n > 0 then () else error "No constructors specified";
 
     val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs;
-    val case0 = prep_term no_defs_lthy raw_case;
     val sel_defaultss =
       pad_list [] n (map (map (apsnd (prep_term no_defs_lthy))) raw_sel_defaultss);
 
-    val case0T = fastype_of case0;
-    val Type (dataT_name, As0) =
-      domain_type (snd (strip_typeN (num_binder_types case0T - 1) case0T));
+    val Type (dataT_name, As0) = body_type (fastype_of (hd ctrs0));
     val data_b = Binding.qualified_name dataT_name;
     val data_b_name = Binding.name_of data_b;
 
@@ -256,15 +253,15 @@
           else
             sel)) (1 upto m) o pad_list Binding.empty m) ctrs0 ms;
 
-    val casex = mk_case As B case0;
     val case_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
 
-    val (((((((xss, xss'), yss), fs), gs), [u', v']), (p, p')), names_lthy) = no_defs_lthy |>
+    val ((((((((xss, xss'), yss), fs), gs), [u', v']), [w]), (p, p')), names_lthy) = no_defs_lthy |>
       mk_Freess' "x" ctr_Tss
       ||>> mk_Freess "y" ctr_Tss
       ||>> mk_Frees "f" case_Ts
       ||>> mk_Frees "g" case_Ts
       ||>> (apfst (map (rpair dataT)) oo Variable.variant_fixes) [data_b_name, data_b_name ^ "'"]
+      ||>> mk_Frees "z" [B]
       ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "P") HOLogic.boolT;
 
     val u = Free u';
@@ -277,16 +274,43 @@
     val xfs = map2 (curry Term.list_comb) fs xss;
     val xgs = map2 (curry Term.list_comb) gs xss;
 
+    (* TODO: Eta-expension is for compatibility with the old datatype package (but it also provides
+       nicer names). Consider removing. *)
+    val eta_fs = map2 eta_expand_arg xss xfs;
+    val eta_gs = map2 eta_expand_arg xss xgs;
+
+    val case_binding =
+      qualify false
+        (if Binding.is_empty raw_case_binding orelse
+            Binding.eq_name (raw_case_binding, standard_binding) then
+           Binding.suffix_name ("_" ^ caseN) data_b
+         else
+           raw_case_binding);
+
+    fun mk_case_disj xctr xf xs =
+      list_exists_free xs (HOLogic.mk_conj (HOLogic.mk_eq (u, xctr), HOLogic.mk_eq (w, xf)));
+
+    val case_rhs =
+      fold_rev (fold_rev Term.lambda) [fs, [u]]
+        (Const (@{const_name The}, (B --> HOLogic.boolT) --> B) $
+           Term.lambda w (Library.foldr1 HOLogic.mk_disj (map3 mk_case_disj xctrs xfs xss)));
+
+    val ((raw_case, (_, raw_case_def)), (lthy', lthy)) = no_defs_lthy
+      |> Local_Theory.define ((case_binding, NoSyn), ((Thm.def_binding case_binding, []), case_rhs))
+      ||> `Local_Theory.restore;
+
+    val phi = Proof_Context.export_morphism lthy lthy';
+
+    val case_def = Morphism.thm phi raw_case_def;
+
+    val case0 = Morphism.term phi raw_case;
+    val casex = mk_case As B case0;
+
     val fcase = Term.list_comb (casex, fs);
 
     val ufcase = fcase $ u;
     val vfcase = fcase $ v;
 
-    (* TODO: Eta-expension is for compatibility with the old datatype package (but it also provides
-       nicer names). Consider removing. *)
-    val eta_fs = map2 eta_expand_arg xss xfs;
-    val eta_gs = map2 eta_expand_arg xss xgs;
-
     val eta_fcase = Term.list_comb (casex, eta_fs);
     val eta_gcase = Term.list_comb (casex, eta_gs);
 
@@ -312,7 +336,7 @@
     val (all_sels_distinct, discs, selss, udiscs, uselss, vdiscs, vselss, disc_defs, sel_defs,
          sel_defss, lthy') =
       if no_dests then
-        (true, [], [], [], [], [], [], [], [], [], no_defs_lthy)
+        (true, [], [], [], [], [], [], [], [], [], lthy)
       else
         let
           fun disc_free b = Free (Binding.name_of b, mk_pred1T dataT);
@@ -328,7 +352,7 @@
                 NONE =>
                 (case AList.lookup Binding.eq_name (rev (nth sel_defaultss (k - 1))) b of
                   NONE => fold_rev (Term.lambda o curry Free Name.uu) Ts (mk_undefined T)
-                | SOME t => t |> Type.constraint (Ts ---> T) |> Syntax.check_term no_defs_lthy)
+                | SOME t => t |> Type.constraint (Ts ---> T) |> Syntax.check_term lthy)
               | SOME (xs, x) => fold_rev Term.lambda xs x)) ctr_Tss ks;
 
           fun sel_spec b proto_sels =
@@ -337,14 +361,14 @@
                 (case duplicates (op =) (map fst proto_sels) of
                    k :: _ => error ("Duplicate selector name " ^ quote (Binding.name_of b) ^
                      " for constructor " ^
-                     quote (Syntax.string_of_term no_defs_lthy (nth ctrs (k - 1))))
+                     quote (Syntax.string_of_term lthy (nth ctrs (k - 1))))
                  | [] => ())
               val T =
                 (case distinct (op =) (map (fastype_of o snd o snd) proto_sels) of
                   [T] => T
                 | T :: T' :: _ => error ("Inconsistent range type for selector " ^
-                    quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ no_defs_lthy T) ^
-                    " vs. " ^ quote (Syntax.string_of_typ no_defs_lthy T')));
+                    quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ lthy T) ^ " vs. "
+                    ^ quote (Syntax.string_of_typ lthy T')));
             in
               mk_Trueprop_eq (Free (Binding.name_of b, dataT --> T) $ u,
                 Term.list_comb (mk_case As T case0, mk_sel_case_args b proto_sels T) $ u)
@@ -368,7 +392,7 @@
           fun unflat_selss xs = unflat_lookup Binding.eq_name sel_bindings xs sel_bindingss;
 
           val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) =
-            no_defs_lthy
+            lthy
             |> apfst split_list o fold_map3 (fn k => fn exist_xs_u_eq_ctr => fn b =>
                 if Binding.is_empty b then
                   if n = 1 then pair (Term.lambda u (mk_uu_eq ()), unique_disc_no_def)
@@ -432,16 +456,11 @@
         map (map mk_goal) (mk_half_pairss (`I (xss ~~ xctrs)))
       end;
 
-    val cases_goal =
-      map3 (fn xs => fn xctr => fn xf =>
-        fold_rev Logic.all (fs @ xs) (mk_Trueprop_eq (fcase $ xctr, xf))) xss xctrs xfs;
-
-    val goalss = [exhaust_goal] :: inject_goalss @ half_distinct_goalss @ [cases_goal];
+    val goalss = [exhaust_goal] :: inject_goalss @ half_distinct_goalss;
 
     fun after_qed thmss lthy =
       let
-        val ([exhaust_thm], (inject_thmss, (half_distinct_thmss, [case_thms]))) =
-          (hd thmss, apsnd (chop (n * n)) (chop n (tl thmss)));
+        val ([exhaust_thm], (inject_thmss, half_distinct_thmss)) = (hd thmss, chop n (tl thmss));
 
         val inject_thms = flat inject_thmss;
 
@@ -470,6 +489,19 @@
             |> Thm.close_derivation
           end;
 
+        val case_thms =
+          let
+            val goals =
+              map3 (fn xctr => fn xf => fn xs =>
+                fold_rev Logic.all (fs @ xs) (mk_Trueprop_eq (fcase $ xctr, xf))) xctrs xfs xss;
+          in
+            map4 (fn k => fn goal => fn injects => fn distinctss =>
+                Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, ...} =>
+                  mk_case_tac ctxt n k case_def injects distinctss)
+                |> Thm.close_derivation)
+              ks goals inject_thmss distinct_thmsss
+          end;
+
         val (all_sel_thms, sel_thmss, disc_thmss, disc_thms, discI_thms, disc_exclude_thms,
              disc_exhaust_thms, collapse_thms, expand_thms, case_conv_thms) =
           if no_dests then
@@ -767,7 +799,7 @@
     "wrap an existing freely generated type's constructors"
     ((parse_wrap_free_constructors_options -- (@{keyword "["} |-- Parse.list Parse.term --|
         @{keyword "]"}) --
-      Parse.term -- Scan.optional (parse_bindings -- Scan.optional (parse_bindingss --
+      parse_binding -- Scan.optional (parse_bindings -- Scan.optional (parse_bindingss --
         Scan.optional parse_bound_termss []) ([], [])) ([], ([], [])))
      >> wrap_free_constructors_cmd);