use 'disc_exhausts' property both from types on which 'case's take place and on return type
authorblanchet
Tue, 14 Jan 2014 18:41:24 +0100
changeset 55008 b5b2e193ca33
parent 55007 0c07990363a3
child 55009 d4b69107a86a
use 'disc_exhausts' property both from types on which 'case's take place and on return type
src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Mon Jan 13 20:20:44 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Tue Jan 14 18:41:24 2014 +0100
@@ -100,7 +100,7 @@
 val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
 val mk_dnf = mk_disjs o map mk_conjs;
 
-val conjuncts_s = filter_out (curry (op =) @{const True}) o HOLogic.conjuncts;
+val conjuncts_s = filter_out (curry (op aconv) @{const True}) o HOLogic.conjuncts;
 
 fun s_not @{const True} = @{const False}
   | s_not @{const False} = @{const True}
@@ -344,8 +344,8 @@
 
 fun case_thms_of_term ctxt bound_Ts t =
   let val (ctr_sugars, _) = fold_rev_let_if_case ctxt (K (K I)) bound_Ts t () in
-    (maps #distincts ctr_sugars, maps #discIs ctr_sugars, maps #sel_splits ctr_sugars,
-     maps #sel_split_asms ctr_sugars)
+    (maps #distincts ctr_sugars, maps #discIs ctr_sugars, maps #disc_exhausts ctr_sugars,
+     maps #sel_splits ctr_sugars, maps #sel_split_asms ctr_sugars)
   end;
 
 fun basic_corec_specs_of ctxt res_T =
@@ -534,7 +534,7 @@
       in find end;
 
     val applied_fun = concl
-      |> find_subterm (member ((op =) o apsnd SOME) fun_names o try (fst o dest_Free o head_of))
+      |> find_subterm (member (op = o apsnd SOME) fun_names o try (fst o dest_Free o head_of))
       |> the
       handle Option.Option => primcorec_error_eqn "malformed discriminator formula" concl;
     val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free;
@@ -665,7 +665,7 @@
     val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name;
 
     val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
-        if member ((op =) o apsnd #ctr) basic_ctr_specs ctr then cons (ctr, cs)
+        if member (op = o apsnd #ctr) basic_ctr_specs ctr then cons (ctr, cs)
         else primcorec_error_eqn "not a constructor" ctr) [] rhs' []
       |> AList.group (op =);
 
@@ -918,7 +918,7 @@
 
     val callssss =
       map_filter (try (fn Sel x => x)) eqns_data
-      |> partition_eq ((op =) o pairself #fun_name)
+      |> partition_eq (op = o pairself #fun_name)
       |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
       |> map (flat o snd)
       |> map2 (fold o find_corec_calls lthy has_call) basic_ctr_specss
@@ -936,16 +936,16 @@
     val ctr_specss = map #ctr_specs corec_specs;
 
     val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
-      |> partition_eq ((op =) o pairself #fun_name)
+      |> partition_eq (op = o pairself #fun_name)
       |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
-      |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd);
+      |> map (sort (op < o pairself #ctr_no |> make_ord) o flat o snd);
     val _ = disc_eqnss' |> map (fn x =>
-      let val d = duplicates ((op =) o pairself #ctr_no) x in null d orelse
+      let val d = duplicates (op = o pairself #ctr_no) x in null d orelse
         primcorec_error_eqns "excess discriminator formula in definition"
           (maps (fn t => filter (curry (op =) (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end);
 
     val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data
-      |> partition_eq ((op =) o pairself #fun_name)
+      |> partition_eq (op = o pairself #fun_name)
       |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
       |> map (flat o snd);
 
@@ -987,23 +987,33 @@
     val de_facto_exhaustives =
       map2 (fn b => fn b' => b orelse b') exhaustives syntactic_exhaustives;
 
-    fun map_prove_with_tac tac =
-      map (fn goal => Goal.prove_sorry lthy [] [] goal tac |> Thm.close_derivation);
-
     val nchotomy_goalss =
       map2 (fn false => K [] | true => single o HOLogic.mk_Trueprop o mk_dnf o map #prems)
         de_facto_exhaustives disc_eqnss
       |> list_all_fun_args []
     val nchotomy_taut_thmss =
-      map3 (fn {disc_exhausts, ...} => fn syntactic_exhaustive =>
-          if syntactic_exhaustive then
-            map_prove_with_tac (fn {context = ctxt, ...} =>
-              mk_primcorec_nchotomy_tac ctxt disc_exhausts)
-          else
-            (case tac_opt of
-              SOME tac => map_prove_with_tac tac
-            | NONE => K []))
-        corec_specs syntactic_exhaustives nchotomy_goalss;
+      map5 (fn {disc_exhausts = res_disc_exhausts, ...} => fn arg_Ts =>
+          fn {code_rhs_opt, ...} :: _ => fn [] => K []
+            | [goal] => fn true =>
+              let
+                val (_, _, arg_disc_exhausts, _, _) =
+                  case_thms_of_term lthy arg_Ts (the_default Term.dummy code_rhs_opt);
+              in
+                [Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, ...} =>
+                   mk_primcorec_nchotomy_tac ctxt (res_disc_exhausts @ arg_disc_exhausts))
+                 |> Thm.close_derivation]
+              end
+            | false =>
+              (case tac_opt of
+                SOME tac => [Goal.prove_sorry lthy [] [] goal tac |> Thm.close_derivation]
+              | NONE => []))
+        corec_specs arg_Tss disc_eqnss nchotomy_goalss syntactic_exhaustives;
+
+    val syntactic_exhaustives =
+      map (fn disc_eqns => forall (null o #prems orf is_some o #code_rhs_opt) disc_eqns
+          orelse exists #auto_gen disc_eqns)
+        disc_eqnss;
+
     val goalss = goalss'
       |> (if is_none tac_opt then
           append (map2 (fn true => K [] | false => map (rpair [])) syntactic_exhaustives
@@ -1098,7 +1108,7 @@
               |> HOLogic.mk_Trueprop o HOLogic.mk_eq
               |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
               |> curry Logic.list_all (map dest_Free fun_args);
-            val (distincts, _, sel_splits, sel_split_asms) = case_thms_of_term lthy [] rhs_term;
+            val (distincts, _, _, sel_splits, sel_split_asms) = case_thms_of_term lthy [] rhs_term;
           in
             mk_primcorec_sel_tac lthy def_thms distincts sel_splits sel_split_asms nested_maps
               nested_map_idents nested_map_comps sel_corec k m excludesss
@@ -1115,7 +1125,7 @@
               andalso not (exists (curry (op =) ctr o #ctr) sel_eqns)
             orelse
               filter (curry (op =) ctr o #ctr) sel_eqns
-              |> fst o finds ((op =) o apsnd #sel) sels
+              |> fst o finds (op = o apsnd #sel) sels
               |> exists (null o snd) then
             []
           else
@@ -1134,7 +1144,7 @@
                   SOME rhs => rhs
                 | NONE =>
                   filter (curry (op =) ctr o #ctr) sel_eqns
-                  |> fst o finds ((op =) o apsnd #sel) sels
+                  |> fst o finds (op = o apsnd #sel) sels
                   |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
                   |> curry list_comb ctr)
                 |> curry mk_Trueprop_eq (applied_fun_of fun_name fun_T fun_args)
@@ -1187,7 +1197,7 @@
                               |> Option.map #prems |> the_default [];
                             val t =
                               filter (curry (op =) ctr o #ctr) sel_eqns
-                              |> fst o finds ((op =) o apsnd #sel) sels
+                              |> fst o finds (op = o apsnd #sel) sels
                               |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x))
                                 #-> abstract)
                               |> curry list_comb ctr;
@@ -1222,7 +1232,7 @@
                     val (raw_goal, goal) = (raw_rhs, rhs)
                       |> pairself (curry mk_Trueprop_eq (applied_fun_of fun_name fun_T fun_args)
                         #> curry Logic.list_all (map dest_Free fun_args));
-                    val (distincts, discIs, sel_splits, sel_split_asms) =
+                    val (distincts, discIs, _, sel_splits, sel_split_asms) =
                       case_thms_of_term lthy bound_Ts raw_rhs;
 
                     val raw_code_thm = mk_primcorec_raw_code_tac lthy distincts discIs sel_splits