src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 53722 e176d6d3345f
parent 53720 03fac7082137
child 53725 9e64151359e8
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Wed Sep 18 18:11:32 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Thu Sep 19 00:32:33 2013 +0200
@@ -661,15 +661,17 @@
     |> rpair exclss'
   end;
 
-fun mk_real_disc_eqns fun_binding arg_Ts {ctr_specs, ...} disc_eqns =
+fun mk_real_disc_eqns fun_binding arg_Ts {ctr_specs, ...} sel_eqns disc_eqns =
   if length disc_eqns <> length ctr_specs - 1 then disc_eqns else
     let
       val n = 0 upto length ctr_specs
         |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
+      val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
+        |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
       val extra_disc_eqn = {
         fun_name = Binding.name_of fun_binding,
         fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
-        fun_args = the_default (map (curry Free Name.uu) arg_Ts) (try (#fun_args o hd) disc_eqns),
+        fun_args = fun_args,
         ctr = #ctr (nth ctr_specs n),
         ctr_no = n,
         disc = #disc (nth ctr_specs n),
@@ -718,7 +720,7 @@
 
     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
     val arg_Tss = map (binder_types o snd o fst) fixes;
-    val disc_eqnss = map4 mk_real_disc_eqns bs arg_Tss corec_specs disc_eqnss';
+    val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss';
     val (defs, exclss') =
       co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;
 
@@ -744,13 +746,12 @@
 
         fun prove_disc {ctr_specs, ...} exclsss
             {fun_name, fun_T, fun_args, ctr_no, prems, user_eqn, ...} =
-          if user_eqn = undef_const then [] else
+          if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\<lambda>x. x = x"}) then [] else
             let
-              val disc_corec = nth ctr_specs ctr_no |> #disc_corec;
+              val {disc_corec, ...} = nth ctr_specs ctr_no;
               val k = 1 + ctr_no;
               val m = length prems;
               val t =
-                (* FIXME use applied_fun from dissect_\<dots> instead? *)
                 list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
                 |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*)
                 |> HOLogic.mk_Trueprop
@@ -790,22 +791,24 @@
 
         fun prove_ctr (_, disc_thms) (_, sel_thms') disc_eqns sel_eqns
             {ctr, disc, sels, collapse, ...} =
+let val _ = tracing ("disc = " ^ @{make_string} disc); in
           if not (exists (equal ctr o #ctr) disc_eqns)
-andalso (warning ("no disc_eqn for ctr " ^ Syntax.string_of_term lthy ctr); true)
-            orelse (* don't try to prove theorems where some sel_eqns are missing *)
+              andalso not (exists (equal ctr o #ctr) sel_eqns)
+andalso (warning ("no eqns for ctr " ^ Syntax.string_of_term lthy ctr); true)
+            orelse (* don't try to prove theorems when some sel_eqns are missing *)
               filter (equal ctr o #ctr) sel_eqns
               |> fst o finds ((op =) o apsnd #sel) sels
               |> exists (null o snd)
 andalso (warning ("sel_eqn(s) missing for ctr " ^ Syntax.string_of_term lthy ctr); true)
-            orelse
-              #user_eqn (the (find_first (equal ctr o #ctr) disc_eqns)) = undef_const
-andalso (warning ("auto-generated disc_eqn for ctr " ^ Syntax.string_of_term lthy ctr); true)
           then [] else
             let
 val _ = tracing ("ctr = " ^ Syntax.string_of_term lthy ctr);
-val _ = tracing ("disc = " ^ Syntax.string_of_term lthy (#disc (the (find_first (equal ctr o #ctr) disc_eqns))));
-              val {fun_name, fun_T, fun_args, prems, ...} =
-                the (find_first (equal ctr o #ctr) disc_eqns);
+val _ = tracing (the_default "NO disc_eqn" (Option.map (curry (op ^) "disc = " o Syntax.string_of_term lthy o #disc) (find_first (equal ctr o #ctr) disc_eqns)));
+              val (fun_name, fun_T, fun_args, prems) =
+                (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
+                |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x))
+                ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, []))
+                |> the o merge_options;
               val m = length prems;
               val t = sel_eqns
                 |> fst o finds ((op =) o apsnd #sel) sels
@@ -816,17 +819,19 @@
                 |> HOLogic.mk_Trueprop
                 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
                 |> curry Logic.list_all (map dest_Free fun_args);
-              val disc_thm = the_default TrueI (AList.lookup (op =) disc_thms disc);
+              val maybe_disc_thm = AList.lookup (op =) disc_thms disc;
               val sel_thms = map snd (filter (member (op =) sels o fst) sel_thms');
 val _ = tracing ("t = " ^ Syntax.string_of_term lthy t);
 val _ = tracing ("m = " ^ @{make_string} m);
 val _ = tracing ("collapse = " ^ @{make_string} collapse);
-val _ = tracing ("disc_thm = " ^ @{make_string} disc_thm);
+val _ = tracing ("maybe_disc_thm = " ^ @{make_string} maybe_disc_thm);
 val _ = tracing ("sel_thms = " ^ @{make_string} sel_thms);
             in
-              mk_primcorec_ctr_of_dtr_tac lthy m collapse disc_thm sel_thms
+              mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
               |> K |> Goal.prove lthy [] [] t
               |> single
+(*handle ERROR x => (warning x; []))*)
+end
           end;
 
         val (disc_notes, disc_thmss) =