src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 53822 6304b12c7627
parent 53811 2967fa35d89e
child 53830 ed2eb7df2aac
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Sep 24 14:07:23 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Sep 24 15:16:59 2013 +0200
@@ -12,6 +12,9 @@
   val add_primcorecursive_cmd: bool ->
     (binding * string option * mixfix) list * (Attrib.binding * string) list -> Proof.context ->
     Proof.state
+  val add_primcorec_cmd: bool ->
+    (binding * string option * mixfix) list * (Attrib.binding * string) list -> local_theory ->
+    local_theory
 end;
 
 structure BNF_FP_Rec_Sugar : BNF_FP_REC_SUGAR =
@@ -387,7 +390,7 @@
     if null eqns
     then error ("primrec_new error:\n  " ^ str)
     else error ("primrec_new error:\n  " ^ str ^ "\nin\n  " ^
-      space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns))
+      space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
 
 
 
@@ -401,6 +404,7 @@
   ctr_no: int, (*###*)
   disc: term,
   prems: term list,
+  auto_gen: bool,
   user_eqn: term
 };
 type co_eqn_data_sel = {
@@ -469,6 +473,7 @@
       ctr_no = ctr_no,
       disc = #disc (nth ctr_specs ctr_no),
       prems = real_prems,
+      auto_gen = catch_all,
       user_eqn = user_eqn
     }, matchedsss')
   end;
@@ -649,10 +654,11 @@
 
     val exclss' =
       disc_eqnss
-      |> map (map (fn {fun_args, ctr_no, prems, ...} => (fun_args, ctr_no, prems))
+      |> map (map (fn x => (#fun_args x, #ctr_no x, #prems x, #auto_gen x))
         #> fst o (fn xs => fold_map (fn x => fn ys => ((x, ys), ys @ [x])) xs [])
         #> maps (uncurry (map o pair)
-          #> map (fn ((fun_args, c, x), (_, c', y)) => ((c, c'), (x, s_not (mk_conjs y)))
+          #> map (fn ((fun_args, c, x, a), (_, c', y, a')) =>
+              ((c, c', a orelse a'), (x, s_not (mk_conjs y)))
             ||> apfst (map HOLogic.mk_Trueprop) o apsnd HOLogic.mk_Trueprop
             ||> Logic.list_implies
             ||> curry Logic.list_all (map dest_Free fun_args))))
@@ -680,12 +686,13 @@
         ctr_no = n,
         disc = #disc (nth ctr_specs n),
         prems = maps (invert_prems o #prems) disc_eqns,
+        auto_gen = true,
         user_eqn = undef_const};
     in
       chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
     end;
 
-fun add_primcorec sequential fixes specs lthy =
+fun add_primcorec simple sequential fixes specs lthy =
   let
     val (bs, mxs) = map_split (apfst fst) fixes;
     val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
@@ -723,10 +730,16 @@
     val (defs, exclss') =
       co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;
 
-    (* try to prove (automatically generated) tautologies by ourselves *)
-    val exclss'' = exclss'
-      |> map (map (apsnd
-        (`(try (fn t => Goal.prove lthy [] [] t (mk_primcorec_assumption_tac lthy |> K))))));
+    fun prove_excl_tac (c, c', a) =
+      if a orelse c = c' orelse sequential then SOME (K (mk_primcorec_assumption_tac lthy))
+      else if simple then SOME (K (auto_tac lthy))
+      else NONE;
+
+val _ = tracing ("exclusiveness properties:\n    \<cdot> " ^
+ space_implode "\n    \<cdot> " (maps (map (Syntax.string_of_term lthy o snd)) exclss'));
+
+    val exclss'' = exclss' |> map (map (fn (idx, t) =>
+      (idx, (Option.map (Goal.prove lthy [] [] t) (prove_excl_tac idx), t))));
     val taut_thmss = map (map (apsnd (the o fst)) o filter (is_some o fst o snd)) exclss'';
     val (obligation_idxss, obligationss) = exclss''
       |> map (map (apsnd (rpair [] o snd)) o filter (is_none o fst o snd))
@@ -739,7 +752,7 @@
         val exclss' = map (op ~~) (obligation_idxss ~~ thmss');
         fun mk_exclsss excls n =
           (excls, map (fn k => replicate k [TrueI] @ replicate (n - k) []) (0 upto n - 1))
-          |-> fold (fn ((c, c'), thm) => nth_map c (nth_map c' (K [thm])));
+          |-> fold (fn ((c, c', _), thm) => nth_map c (nth_map c' (K [thm])));
         val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs)
           |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs));
 
@@ -799,8 +812,6 @@
 andalso (warning ("sel_eqn(s) missing for ctr " ^ Syntax.string_of_term lthy ctr); true)
           then [] else
             let
-val _ = tracing ("ctr = " ^ Syntax.string_of_term lthy ctr);
-val _ = tracing (the_default "no disc_eqn" (Option.map (curry (op ^) "disc = " o Syntax.string_of_term lthy o #disc) (find_first (equal ctr o #ctr) disc_eqns)));
               val (fun_name, fun_T, fun_args, prems) =
                 (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
                 |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x))
@@ -818,11 +829,6 @@
                 |> curry Logic.list_all (map dest_Free fun_args);
               val maybe_disc_thm = AList.lookup (op =) disc_alist disc;
               val sel_thms = map snd (filter (member (op =) sels o fst) sel_alist);
-val _ = tracing ("t = " ^ Syntax.string_of_term lthy t);
-val _ = tracing ("m = " ^ @{make_string} m);
-val _ = tracing ("collapse = " ^ @{make_string} collapse);
-val _ = tracing ("maybe_disc_thm = " ^ @{make_string} maybe_disc_thm);
-val _ = tracing ("sel_thms = " ^ @{make_string} sel_thms);
             in
               mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
               |> K |> Goal.prove lthy [] [] t
@@ -838,9 +844,7 @@
           (map #ctr_specs corec_specs);
 
         val safess = map (map (K false)) ctr_thmss; (* FIXME: "true" for non-corecursive theorems *)
-        val safe_ctr_thmss =
-          map2 (map_filter (fn (safe, thm) => if safe then SOME thm else NONE) oo curry (op ~~))
-            safess ctr_thmss;
+        val safe_ctr_thmss = map (map snd o filter fst o (op ~~)) (safess ~~ ctr_thmss);
 
         fun mk_simp_thms disc_thms sel_thms ctr_thms = disc_thms @ sel_thms @ ctr_thms;
 
@@ -875,24 +879,38 @@
       in
         lthy |> Local_Theory.notes (anonymous_notes @ notes @ common_notes) |> snd
       end;
+
+    fun after_qed thmss' = fold_map Local_Theory.define defs #-> prove thmss';
+
+    val _ = if not simple orelse forall null obligationss then () else
+      primrec_error "need exclusiveness proofs - use primcorecursive instead of primcorec";
   in
-    lthy'
-    |> Proof.theorem NONE (curry (op #->) (fold_map Local_Theory.define defs) o prove) obligationss
-    |> Proof.refine (Method.primitive_text I)
-    |> Seq.hd
-  end
+    if simple then
+      lthy'
+      |> after_qed (map (fn [] => []) obligationss)
+      |> pair NONE o SOME
+    else
+      lthy'
+      |> Proof.theorem NONE after_qed obligationss
+      |> Proof.refine (Method.primitive_text I)
+      |> Seq.hd
+      |> rpair NONE o SOME
+  end;
 
-fun add_primcorecursive_cmd seq (raw_fixes, raw_specs) lthy =
+fun add_primcorec_ursive_cmd simple seq (raw_fixes, raw_specs) lthy =
   let
     val (fixes, specs) = fst (Specification.read_spec raw_fixes raw_specs lthy);
   in
-    add_primcorec seq fixes specs lthy
+    add_primcorec simple seq fixes specs lthy
     handle ERROR str => primrec_error str
   end
   handle Primrec_Error (str, eqns) =>
     if null eqns
     then error ("primcorec error:\n  " ^ str)
     else error ("primcorec error:\n  " ^ str ^ "\nin\n  " ^
-      space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns))
+      space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
+
+val add_primcorecursive_cmd = (the o fst) ooo add_primcorec_ursive_cmd false;
+val add_primcorec_cmd = (the o snd) ooo add_primcorec_ursive_cmd true;
 
 end;