internally allow different values for 'exhaustive' for different constructors
authorblanchet
Thu, 02 Jan 2014 09:50:22 +0100
changeset 54903 c664bd02bf94
parent 54902 a9291e4d2366
child 54904 5d965f17b0e4
internally allow different values for 'exhaustive' for different constructors
src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Thu Jan 02 09:50:22 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Thu Jan 02 09:50:22 2014 +0100
@@ -858,8 +858,6 @@
 
     val actual_nn = length bs;
 
-    val exhaustive = member (op =) opts Exhaustive_Option; (*###*)
-
     val sequentials = replicate actual_nn (member (op =) opts Sequential_Option);
     val exhaustives = replicate actual_nn (member (op =) opts Exhaustive_Option);
 
@@ -929,23 +927,26 @@
       |> map (map (apsnd (rpair [] o snd)) o filter (is_none o fst o snd))
       |> split_list o map split_list;
 
+    val de_facto_exhaustives =
+      map2 (fn exhaustive => fn disc_eqns =>
+          exhaustive orelse forall (null o #prems) disc_eqns orelse exists #auto_gen disc_eqns)
+        exhaustives disc_eqnss;
+      
     val list_all_fun_args =
-      map2 ((fn {fun_args, ...} => curry Logic.list_all (map dest_Free fun_args)) o hd) disc_eqnss;
+      map2 ((fn {fun_args, ...} => map (curry Logic.list_all (map dest_Free fun_args))) o hd)
+        disc_eqnss;
 
-    val nchotomy_goals =
-      if exhaustive then
-        map (HOLogic.mk_Trueprop o mk_dnf o map #prems) disc_eqnss |> list_all_fun_args
-      else
-        [];
-    val nchotomy_taut_thms =
-      if exhaustive andalso is_some maybe_tac then
-        map (fn goal => Goal.prove lthy [] [] goal (the maybe_tac) |> Thm.close_derivation)
-          nchotomy_goals
-      else
-        [];
-    val goalss =
-      if exhaustive andalso is_none maybe_tac then map (rpair []) nchotomy_goals :: goalss'
-      else goalss';
+    val nchotomy_goalss =
+      map2 (fn false => K []
+        | true => single o HOLogic.mk_Trueprop o mk_dnf o map #prems) exhaustives disc_eqnss
+      |> list_all_fun_args
+    val nchotomy_taut_thmss =
+      (case maybe_tac of
+        SOME tac => map (map (fn goal => Goal.prove lthy [] [] goal tac |> Thm.close_derivation))
+          nchotomy_goalss
+      | NONE => []);
+    val goalss = goalss'
+      |> (if is_none maybe_tac then append (map (map (rpair [])) nchotomy_goalss) else I);
 
     val p = Var (("P", 0), HOLogic.boolT); (* safe since there are no other variables around *)
 
@@ -955,23 +956,21 @@
       let
         val def_thms = map (snd o snd) def_thms';
 
-        val maybe_nchotomy_thms =
-          if exhaustive then map SOME (if is_none maybe_tac then hd thmss'' else nchotomy_taut_thms)
-          else map (K NONE) def_thms;
-        val exclude_thmss = if exhaustive andalso is_none maybe_tac then tl thmss'' else thmss'';
+        val nchotomy_thmss =
+          if is_none maybe_tac then take actual_nn thmss'' else nchotomy_taut_thmss;
+        val exclude_thmss = thmss'' |> is_none maybe_tac ? drop actual_nn;
 
-        val maybe_exhaust_thms =
-          if exhaustive then
-            map (mk_imp_p o map (mk_imp_p o map HOLogic.mk_Trueprop o #prems)) disc_eqnss
-            |> list_all_fun_args
-            |> map3 (fn disc_eqns => fn SOME nchotomy_thm => fn goal =>
-                mk_primcorec_exhaust_tac (length disc_eqns) nchotomy_thm
-                |> K |> Goal.prove lthy [] [] goal
-                |> Thm.close_derivation
-                |> SOME)
-              disc_eqnss maybe_nchotomy_thms
-          else
-            map (K NONE) def_thms;
+        val exhaust_thmss =
+          map2 (fn false => K []
+              | true => single o mk_imp_p o map (mk_imp_p o map HOLogic.mk_Trueprop o #prems))
+            exhaustives disc_eqnss
+          |> list_all_fun_args
+          |> map3 (fn disc_eqns => fn [] => K []
+              | [nchotomy_thm] => fn [goal] =>
+                [mk_primcorec_exhaust_tac (length disc_eqns) nchotomy_thm
+                 |> K |> Goal.prove lthy [] [] goal
+                 |> Thm.close_derivation])
+            disc_eqnss nchotomy_thmss;
 
         val excludess' = map (op ~~) (goal_idxss ~~ exclude_thmss);
         fun mk_excludesss excludes n =
@@ -1075,7 +1074,7 @@
                 |> single
             end;
 
-        fun prove_code disc_eqns sel_eqns maybe_nchotomy ctr_alist ctr_specs =
+        fun prove_code de_facto_exhaustive disc_eqns sel_eqns nchotomys ctr_alist ctr_specs =
           let
             val maybe_fun_data =
               (find_first (member (op =) (map #ctr ctr_specs) o #ctr) disc_eqns,
@@ -1117,14 +1116,10 @@
                             SOME (prems, t)
                           end;
                       val maybe_ctr_conds_argss = map prove_code_ctr ctr_specs;
-                      fun is_syntactically_exhaustive () =
-                        forall null (map_filter (try (fst o the)) maybe_ctr_conds_argss)
-                        orelse forall is_some maybe_ctr_conds_argss
-                          andalso exists #auto_gen disc_eqns
                     in
                       let
                         val rhs =
-                          (if exhaustive orelse is_syntactically_exhaustive () then
+                          (if de_facto_exhaustive then
                              split_last (map_filter I maybe_ctr_conds_argss) ||> snd
                            else
                              Const (@{const_name Code.abort}, @{typ String.literal} -->
@@ -1148,7 +1143,7 @@
                       case_thms_of_term lthy bound_Ts raw_rhs;
 
                     val raw_code_thm = mk_primcorec_raw_code_of_ctr_tac lthy distincts discIs
-                        sel_splits sel_split_asms ms ctr_thms maybe_nchotomy
+                        sel_splits sel_split_asms ms ctr_thms (try the_single nchotomys)
                       |> K |> Goal.prove lthy [] [] raw_goal
                       |> Thm.close_derivation;
                   in
@@ -1165,9 +1160,9 @@
         val disc_thmss = map (map snd) disc_alists;
         val sel_thmss = map (map snd) sel_alists;
 
-        fun prove_disc_iff ({ctr_specs, ...} : corec_spec) maybe_exhaust_thm discs
+        fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms discs
             ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
-          if null discs orelse is_none maybe_exhaust_thm then
+          if null discs orelse null exhaust_thms then
             []
           else
             let
@@ -1181,22 +1176,22 @@
               if prems = [@{term False}] then
                 []
               else
-                mk_primcorec_disc_iff_tac lthy (ctr_no + 1) (the maybe_exhaust_thm) discs
+                mk_primcorec_disc_iff_tac lthy (ctr_no + 1) (the_single exhaust_thms) discs
                   disc_excludess
                 |> K |> Goal.prove lthy [] [] goal
                 |> Thm.close_derivation
                 |> single
             end;
 
-        val disc_iff_thmss = map4 (maps ooo prove_disc_iff) corec_specs maybe_exhaust_thms
-          disc_thmss disc_eqnss;
+        val disc_iff_thmss = map4 (maps ooo prove_disc_iff) corec_specs exhaust_thmss disc_thmss
+          disc_eqnss;
 
         val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
           ctr_specss;
         val ctr_thmss = map (map snd) ctr_alists;
 
-        val code_thmss = map5 prove_code disc_eqnss sel_eqnss maybe_nchotomy_thms ctr_alists
-          ctr_specss;
+        val code_thmss = map6 prove_code de_facto_exhaustives disc_eqnss sel_eqnss nchotomy_thmss
+          ctr_alists ctr_specss;
 
         val simp_thmss = map2 append disc_thmss sel_thmss
 
@@ -1209,8 +1204,8 @@
            (discN, disc_thmss, simp_attrs),
            (disc_iffN, disc_iff_thmss, []),
            (excludeN, exclude_thmss, []),
-           (exhaustN, map the_list maybe_exhaust_thms, []),
-           (nchotomyN, map the_list maybe_nchotomy_thms, []),
+           (exhaustN, exhaust_thmss, []),
+           (nchotomyN, nchotomy_thmss, []),
            (selN, sel_thmss, simp_attrs),
            (simpsN, simp_thmss, []),
            (strong_coinductN, map (if n2m then single else K []) strong_coinduct_thms, [])]