src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML
changeset 54591 c822230fd22b
parent 54279 3ffb74b52ed6
child 54613 985f8b49c050
--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Mon Nov 25 18:18:58 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Mon Nov 25 20:25:22 2013 +0100
@@ -8,10 +8,13 @@
 
 signature BNF_GFP_REC_SUGAR =
 sig
-  val add_primcorecursive_cmd: bool ->
+  datatype primcorec_option =
+    Option_Sequential |
+    Option_Exhaustive
+  val add_primcorecursive_cmd: primcorec_option list ->
     (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
     Proof.context -> Proof.state
-  val add_primcorec_cmd: bool ->
+  val add_primcorec_cmd: primcorec_option list ->
     (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
     local_theory -> local_theory
 end;
@@ -43,6 +46,10 @@
 fun primcorec_error_eqn str eqn = raise Primcorec_Error (str, [eqn]);
 fun primcorec_error_eqns str eqns = raise Primcorec_Error (str, eqns);
 
+datatype primcorec_option =
+  Option_Sequential |
+  Option_Exhaustive
+
 datatype corec_call =
   Dummy_No_Corec of int |
   No_Corec of int |
@@ -824,7 +831,7 @@
     |> K |> nth_map sel_no |> AList.map_entry (op =) ctr
   end;
 
-fun add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy =
+fun add_primcorec_ursive maybe_tac opts fixes specs maybe_of_specs lthy =
   let
     val thy = Proof_Context.theory_of lthy;
 
@@ -835,6 +842,9 @@
         [] => ()
       | (b, _) :: _ => primcorec_error ("type of " ^ Binding.print b ^ " contains top sort"));
 
+    val seq = member (op =) opts Option_Sequential;
+    val exhaustive = member (op =) opts Option_Exhaustive;
+
     val fun_names = map Binding.name_of bs;
     val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts;
     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
@@ -898,6 +908,8 @@
       |> map (map (apsnd (rpair [] o snd)) o filter (is_none o fst o snd))
       |> split_list o map split_list;
 
+    val exhaustive_props = map (mk_disjs o map (mk_conjs o #prems)) disc_eqnss;
+
     fun prove thmss' def_thms' lthy =
       let
         val def_thms = map (snd o snd) def_thms';
@@ -1002,77 +1014,82 @@
 
         fun prove_code disc_eqns sel_eqns ctr_alist ctr_specs =
           let
-            val (fun_name, fun_T, fun_args, maybe_rhs) =
+            val fun_data =
               (find_first (member (op =) (map #ctr ctr_specs) o #ctr) disc_eqns,
                find_first (member (op =) (map #ctr ctr_specs) o #ctr) sel_eqns)
               |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #maybe_code_rhs x))
               ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, NONE))
-              |> the o merge_options;
-
-            val bound_Ts = List.rev (map fastype_of fun_args);
+              |> merge_options;
+          in
+            (case fun_data of
+              NONE => []
+            | SOME (fun_name, fun_T, fun_args, maybe_rhs) =>
+              let
+                val bound_Ts = List.rev (map fastype_of fun_args);
 
-            val lhs = list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0));
-            val maybe_rhs_info =
-              (case maybe_rhs of
-                SOME rhs =>
-                let
-                  val raw_rhs = expand_corec_code_rhs lthy has_call bound_Ts rhs;
-                  val cond_ctrs =
-                    fold_rev_corec_code_rhs lthy (K oo (cons oo pair)) bound_Ts raw_rhs [];
-                  val ctr_thms = map (the o AList.lookup (op =) ctr_alist o snd) cond_ctrs;
-                in SOME (rhs, raw_rhs, ctr_thms) end
-              | NONE =>
-                let
-                  fun prove_code_ctr {ctr, sels, ...} =
-                    if not (exists (equal ctr o fst) ctr_alist) then NONE else
-                      let
-                        val prems = find_first (equal ctr o #ctr) disc_eqns
-                          |> Option.map #prems |> the_default [];
-                        val t =
-                          filter (equal ctr o #ctr) sel_eqns
-                          |> fst o finds ((op =) o apsnd #sel) sels
-                          |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x))
-                            #-> abstract)
-                          |> curry list_comb ctr;
-                      in
-                        SOME (prems, t)
-                      end;
-                  val maybe_ctr_conds_argss = map prove_code_ctr ctr_specs;
-                in
-                  if exists is_none maybe_ctr_conds_argss then NONE else
+                val lhs = list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0));
+                val maybe_rhs_info =
+                  (case maybe_rhs of
+                    SOME rhs =>
+                    let
+                      val raw_rhs = expand_corec_code_rhs lthy has_call bound_Ts rhs;
+                      val cond_ctrs =
+                        fold_rev_corec_code_rhs lthy (K oo (cons oo pair)) bound_Ts raw_rhs [];
+                      val ctr_thms = map (the o AList.lookup (op =) ctr_alist o snd) cond_ctrs;
+                    in SOME (rhs, raw_rhs, ctr_thms) end
+                  | NONE =>
                     let
-                      val rhs = fold_rev (fn SOME (prems, u) => fn t => mk_If (s_conjs prems) u t)
-                        maybe_ctr_conds_argss
-                        (Const (@{const_name Code.abort}, @{typ String.literal} -->
-                            (@{typ unit} --> body_type fun_T) --> body_type fun_T) $
-                          HOLogic.mk_literal fun_name $
-                          absdummy @{typ unit} (incr_boundvars 1 lhs));
-                    in SOME (rhs, rhs, map snd ctr_alist) end
-                end);
-          in
-            (case maybe_rhs_info of
-              NONE => []
-            | SOME (rhs, raw_rhs, ctr_thms) =>
-              let
-                val ms = map (Logic.count_prems o prop_of) ctr_thms;
-                val (raw_t, t) = (raw_rhs, rhs)
-                  |> pairself
-                    (curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
-                      map Bound (length fun_args - 1 downto 0)))
-                    #> HOLogic.mk_Trueprop
-                    #> curry Logic.list_all (map dest_Free fun_args));
-                val (distincts, discIs, sel_splits, sel_split_asms) =
-                  case_thms_of_term lthy bound_Ts raw_rhs;
+                      fun prove_code_ctr {ctr, sels, ...} =
+                        if not (exists (equal ctr o fst) ctr_alist) then NONE else
+                          let
+                            val prems = find_first (equal ctr o #ctr) disc_eqns
+                              |> Option.map #prems |> the_default [];
+                            val t =
+                              filter (equal ctr o #ctr) sel_eqns
+                              |> fst o finds ((op =) o apsnd #sel) sels
+                              |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x))
+                                #-> abstract)
+                              |> curry list_comb ctr;
+                          in
+                            SOME (prems, t)
+                          end;
+                      val maybe_ctr_conds_argss = map prove_code_ctr ctr_specs;
+                    in
+                      if exists is_none maybe_ctr_conds_argss then NONE else
+                        let
+                          val rhs = fold_rev (fn SOME (prems, u) => fn t => mk_If (s_conjs prems) u t)
+                            maybe_ctr_conds_argss
+                            (Const (@{const_name Code.abort}, @{typ String.literal} -->
+                                (@{typ unit} --> body_type fun_T) --> body_type fun_T) $
+                              HOLogic.mk_literal fun_name $
+                              absdummy @{typ unit} (incr_boundvars 1 lhs));
+                        in SOME (rhs, rhs, map snd ctr_alist) end
+                    end);
+              in
+                (case maybe_rhs_info of
+                  NONE => []
+                | SOME (rhs, raw_rhs, ctr_thms) =>
+                  let
+                    val ms = map (Logic.count_prems o prop_of) ctr_thms;
+                    val (raw_t, t) = (raw_rhs, rhs)
+                      |> pairself
+                        (curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
+                          map Bound (length fun_args - 1 downto 0)))
+                        #> HOLogic.mk_Trueprop
+                        #> curry Logic.list_all (map dest_Free fun_args));
+                    val (distincts, discIs, sel_splits, sel_split_asms) =
+                      case_thms_of_term lthy bound_Ts raw_rhs;
 
-                val raw_code_thm = mk_primcorec_raw_code_of_ctr_tac lthy distincts discIs sel_splits
-                    sel_split_asms ms ctr_thms
-                  |> K |> Goal.prove lthy [] [] raw_t
-                  |> Thm.close_derivation;
-              in
-                mk_primcorec_code_of_raw_code_tac lthy distincts sel_splits raw_code_thm
-                |> K |> Goal.prove lthy [] [] t
-                |> Thm.close_derivation
-                |> single
+                    val raw_code_thm = mk_primcorec_raw_code_of_ctr_tac lthy distincts discIs sel_splits
+                        sel_split_asms ms ctr_thms
+                      |> K |> Goal.prove lthy [] [] raw_t
+                      |> Thm.close_derivation;
+                  in
+                    mk_primcorec_code_of_raw_code_tac lthy distincts sel_splits raw_code_thm
+                    |> K |> Goal.prove lthy [] [] t
+                    |> Thm.close_derivation
+                    |> single
+                  end)
               end)
           end;
 
@@ -1120,13 +1137,13 @@
     (goalss, after_qed, lthy')
   end;
 
-fun add_primcorec_ursive_cmd maybe_tac seq (raw_fixes, raw_specs') lthy =
+fun add_primcorec_ursive_cmd maybe_tac opts (raw_fixes, raw_specs') lthy =
   let
     val (raw_specs, maybe_of_specs) =
       split_list raw_specs' ||> map (Option.map (Syntax.read_term lthy));
     val ((fixes, specs), _) = Specification.read_spec raw_fixes raw_specs lthy;
   in
-    add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy
+    add_primcorec_ursive maybe_tac opts fixes specs maybe_of_specs lthy
     handle ERROR str => primcorec_error str
   end
   handle Primcorec_Error (str, eqns) =>