src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML
changeset 54910 0ec2cccbf8ad
parent 54909 63db983c6953
child 54911 6a6980245ce0
--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Thu Jan 02 09:50:22 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Thu Jan 02 09:50:22 2014 +0100
@@ -809,30 +809,32 @@
 
 fun mk_actual_disc_eqns fun_binding arg_Ts exhaustive ({ctr_specs, ...} : corec_spec)
     (sel_eqns : coeqn_data_sel list) (disc_eqns : coeqn_data_disc list) =
-  if exhaustive orelse length disc_eqns <> length ctr_specs - 1 then
-    disc_eqns
-  else
-    let
-      val n = 0 upto length ctr_specs
-        |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
-      val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
-        |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
-      val maybe_sel_eqn = find_first (equal (Binding.name_of fun_binding) o #fun_name) sel_eqns;
-      val extra_disc_eqn = {
-        fun_name = Binding.name_of fun_binding,
-        fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
-        fun_args = fun_args,
-        ctr = #ctr (nth ctr_specs n),
-        ctr_no = n,
-        disc = #disc (nth ctr_specs n),
-        prems = maps (s_not_conj o #prems) disc_eqns,
-        auto_gen = true,
-        maybe_ctr_rhs = Option.map #maybe_ctr_rhs maybe_sel_eqn |> the_default NONE,
-        maybe_code_rhs = Option.map #maybe_ctr_rhs maybe_sel_eqn |> the_default NONE,
-        user_eqn = undef_const};
-    in
-      chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
-    end;
+  let val num_disc_eqns = length disc_eqns in
+    if num_disc_eqns < length ctr_specs - 1 andalso num_disc_eqns > 1 then
+      disc_eqns
+    else
+      let
+        val n = 0 upto length ctr_specs
+          |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
+        val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
+          |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
+        val maybe_sel_eqn = find_first (equal (Binding.name_of fun_binding) o #fun_name) sel_eqns;
+        val extra_disc_eqn = {
+          fun_name = Binding.name_of fun_binding,
+          fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
+          fun_args = fun_args,
+          ctr = #ctr (nth ctr_specs n),
+          ctr_no = n,
+          disc = #disc (nth ctr_specs n),
+          prems = maps (s_not_conj o #prems) disc_eqns,
+          auto_gen = true,
+          maybe_ctr_rhs = Option.map #maybe_ctr_rhs maybe_sel_eqn |> the_default NONE,
+          maybe_code_rhs = Option.map #maybe_ctr_rhs maybe_sel_eqn |> the_default NONE,
+          user_eqn = undef_const};
+      in
+        chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
+      end
+  end;
 
 fun find_corec_calls ctxt has_call basic_ctr_specs ({ctr, sel, rhs_term, ...} : coeqn_data_sel) =
   let
@@ -931,7 +933,8 @@
       |> split_list o map split_list;
 
     val list_all_fun_args =
-      map2 ((fn {fun_args, ...} => map (curry Logic.list_all (map dest_Free fun_args))) o hd)
+      map2 (fn [] => I
+          | {fun_args, ...} :: _ => map (curry Logic.list_all (map dest_Free fun_args)))
         disc_eqnss;
 
     val syntactic_exhaustives =
@@ -1180,14 +1183,16 @@
               end)
           end;
 
-        val disc_alists = map3 (maps oo prove_disc) corec_specs excludessss disc_eqnss;
+        val disc_alistss = map3 (map oo prove_disc) corec_specs excludessss disc_eqnss;
+        val disc_alists = map flat disc_alistss;
         val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss excludessss sel_eqnss;
-        val disc_thmss = map (map snd) disc_alists;
+        val disc_thmsss = map (map (map snd)) disc_alistss;
+        val disc_thmss = map flat disc_thmsss;
         val sel_thmss = map (map snd) sel_alists;
 
-        fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms discs
+        fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms disc_thmss disc_thms
             ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
-          if null discs orelse null exhaust_thms then
+          if null disc_thms orelse null exhaust_thms then
             []
           else
             let
@@ -1200,15 +1205,15 @@
               if prems = [@{term False}] then
                 []
               else
-                mk_primcorec_disc_iff_tac lthy (ctr_no + 1) (the_single exhaust_thms) discs
-                  (flat disc_excludess)
+                mk_primcorec_disc_iff_tac lthy (the_single exhaust_thms) (the_single disc_thms)
+                  (flat disc_thmss) (flat disc_excludess)
                 |> K |> Goal.prove lthy [] [] goal
                 |> Thm.close_derivation
                 |> single
             end;
 
-        val disc_iff_thmss = map4 (maps ooo prove_disc_iff) corec_specs exhaust_thmss disc_thmss
-          disc_eqnss;
+        val disc_iff_thmss = map5 (flat ooo map2 ooo prove_disc_iff) corec_specs exhaust_thmss
+          disc_thmsss disc_thmsss disc_eqnss;
 
         val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
           ctr_specss;