src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 33132 07efd452a698
parent 33131 cef39362ce56
child 33133 2eb7dfcf3bc3
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -7,8 +7,8 @@
 signature PREDICATE_COMPILE_CORE =
 sig
   val setup: theory -> theory
-  val code_pred: Predicate_Compile_Aux.options -> int list list option -> string -> Proof.context -> Proof.state
-  val code_pred_cmd: Predicate_Compile_Aux.options -> int list list option -> string -> Proof.context -> Proof.state
+  val code_pred: Predicate_Compile_Aux.options -> string -> Proof.context -> Proof.state
+  val code_pred_cmd: Predicate_Compile_Aux.options -> string -> Proof.context -> Proof.state
   type smode = (int * int list option) list
   type mode = smode option list * smode
   datatype tmode = Mode of mode * smode * tmode option list;
@@ -39,7 +39,7 @@
   val mk_casesrule : Proof.context -> int -> thm list -> term
   val analyze_compr: theory -> term -> term
   val eval_ref: (unit -> term Predicate.pred) option Unsynchronized.ref
-  val add_equations : Predicate_Compile_Aux.options -> int list list option -> string list -> theory -> theory
+  val add_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
   val code_pred_intros_attrib : attribute
   (* used by Quickcheck_Generator *) 
   (*val funT_of : mode -> typ -> typ
@@ -90,8 +90,8 @@
   val rpred_compfuns : compilation_funs
   val dest_funT : typ -> typ * typ
  (* val depending_preds_of : theory -> thm list -> string list *)
-  val add_quickcheck_equations : Predicate_Compile_Aux.options -> int list list option -> string list -> theory -> theory
-  val add_sizelim_equations : Predicate_Compile_Aux.options -> int list list option -> string list -> theory -> theory
+  val add_quickcheck_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
+  val add_sizelim_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
   val is_inductive_predicate : theory -> string -> bool
   val terms_vs : term list -> string list
   val subsets : int -> int -> int list list
@@ -398,9 +398,10 @@
      
 (* diagnostic display functions *)
 
-fun print_modes modes = tracing ("Inferred modes:\n" ^
-  cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
-    string_of_mode ms)) modes));
+fun print_modes modes =
+  tracing ("Inferred modes:\n" ^
+    cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
+      string_of_mode ms)) modes));
 
 fun print_pred_mode_table string_of_entry thy pred_mode_table =
   let
@@ -482,6 +483,19 @@
     fold print (all_modes_of thy) ()
   end
 
+(* validity checks *)
+
+fun check_expected_modes (options : Predicate_Compile_Aux.options) modes =
+  case expected_modes options of
+    SOME (s, ms) => (case AList.lookup (op =) modes s of
+      SOME modes =>
+        if not (eq_set (map (map (rpair NONE)) ms, map snd modes)) then
+          error ("expected modes were not inferred:"
+            ^ "infered modes for " ^ s ^ ": " ^ commas (map (string_of_smode o snd) modes))
+        else ()
+      | NONE => ())
+  | NONE => ()
+
 (* importing introduction rules *)   
 
 fun unify_consts thy cs intr_ts =
@@ -1470,7 +1484,7 @@
                      NONE => in_ts
                    | SOME size_t => in_ts @ [size_t]
                    val u = lift_pred compfuns
-                     (list_comb (compile_expr NONE size thy (mode, t), args))                     
+                     (list_comb (compile_expr NONE size thy (mode, t), args))
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -2305,22 +2319,18 @@
 
 (** main function of predicate compiler **)
 
-fun add_equations_of steps options expected_modes prednames thy =
+fun add_equations_of steps options prednames thy =
   let
     val _ = print_step options ("Starting predicate compiler for predicates " ^ commas prednames ^ "...")
-    val _ = Output.tracing (commas (map (Display.string_of_thm_global thy) (maps (intros_of thy) prednames)))
+    val _ = tracing (commas (map (Display.string_of_thm_global thy) (maps (intros_of thy) prednames)))
       (*val _ = check_intros_elim_match thy prednames*)
       (*val _ = map (check_format_of_intro_rule thy) (maps (intros_of thy) prednames)*)
     val (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) =
       prepare_intrs thy prednames (maps (intros_of thy) prednames)
     val _ = print_step options "Infering modes..."
     val moded_clauses = #infer_modes steps options thy extra_modes all_modes param_vs clauses 
-    val modes : (string * ((int * int list option) list option list * (int * int list option) list) list) list = map (fn (p, mps) => (p, map fst mps)) moded_clauses
-    val all_smodes : (((int * int list option) list) list) list = map (map snd) (map snd modes)
-    val _ = case expected_modes of
-      SOME ms => if not (forall (fn smodes => eq_set (map (map (rpair NONE)) ms, smodes)) all_smodes) then
-      error ("expected modes were not inferred - " ^ commas (map string_of_smode (flat all_smodes))) else ()
-      | NONE => ()
+    val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
+    val _ = check_expected_modes options modes
     val _ = print_modes modes
     val _ = print_moded_clauses thy moded_clauses
     val _ = print_step options "Defining executable functions..."
@@ -2359,7 +2369,7 @@
 
 fun extend value_of edges_of key G = fst (extend' value_of edges_of key (G, [])) 
   
-fun gen_add_equations steps options expected_modes names thy =
+fun gen_add_equations steps options names thy =
   let
     val thy' = PredData.map (fold (extend (fetch_pred_data thy) (depending_preds_of thy)) names) thy
       |> Theory.checkpoint;
@@ -2369,7 +2379,7 @@
     val thy'' = fold_rev
       (fn preds => fn thy =>
         if #are_not_defined steps thy preds then
-          add_equations_of steps options expected_modes preds thy else thy)
+          add_equations_of steps options preds thy else thy)
       scc thy' |> Theory.checkpoint
   in thy'' end
 
@@ -2417,15 +2427,11 @@
 val setup = PredData.put (Graph.empty) #>
   Attrib.setup @{binding code_pred_intros} (Scan.succeed (attrib add_intro))
     "adding alternative introduction rules for code generation of inductive predicates"
-(*  Attrib.setup @{binding code_ind_cases} (Scan.succeed add_elim_attrib)
-    "adding alternative elimination rules for code generation of inductive predicates";
-    *)
   (*FIXME name discrepancy in attribs and ML code*)
   (*FIXME intros should be better named intro*)
-  (*FIXME why distinguished attribute for cases?*)
 
 (* TODO: make TheoryDataFun to GenericDataFun & remove duplication of local theory and theory *)
-fun generic_code_pred prep_const options modes raw_const lthy =
+fun generic_code_pred prep_const options raw_const lthy =
   let
     val thy = ProofContext.theory_of lthy
     val const = prep_const thy raw_const
@@ -2455,9 +2461,12 @@
       in
         goal_ctxt |> LocalTheory.theory (fold set_elim global_thms #>
           (if is_rpred options then
-            (add_equations options NONE [const] #>
-             add_sizelim_equations options NONE [const] #> add_quickcheck_equations options NONE [const])
-           else add_equations options modes [const]))
+            (add_equations options [const] #>
+             add_sizelim_equations options [const] #> add_quickcheck_equations options [const])
+           else if is_sizelim options then
+             add_sizelim_equations options [const]
+           else
+             add_equations options [const]))
       end  
   in
     Proof.theorem_i NONE after_qed (map (single o (rpair [])) cases_rules) lthy''