refactoring mode inference so that the theory is not changed in the mode inference procedure
authorbulwahn
Fri, 10 Sep 2010 10:59:07 +0200
changeset 39273 92aa2a0f7399
parent 39272 0b61951d2682
child 39274 b17ffa965223
refactoring mode inference so that the theory is not changed in the mode inference procedure
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Fri Sep 10 10:21:25 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Fri Sep 10 10:59:07 2010 +0200
@@ -73,9 +73,12 @@
   type 'a pred_mode_table = (string * ((bool * mode) * 'a) list) list
 
   val infer_modes : 
-    mode_analysis_options -> options -> compilation -> (string * typ) list -> (string * mode list) list ->
-      string list -> (string *  (Term.term list * Predicate_Compile_Aux.indprem list) list) list ->
-      theory -> ((moded_clause list pred_mode_table * string list) * theory)
+    mode_analysis_options -> options ->
+     (string -> Predicate_Compile_Aux.mode list) * (string -> Predicate_Compile_Aux.mode list)
+       * (string -> Predicate_Compile_Aux.mode -> bool) -> Proof.context -> (string * typ) list ->
+      (string * mode list) list ->
+      string list -> (string * (Term.term list * Predicate_Compile_Aux.indprem list) list) list ->
+      ((moded_clause list pred_mode_table * (string * mode list) list) * string list)
 end;
 
 structure Predicate_Compile_Core : PREDICATE_COMPILE_CORE =
@@ -1498,13 +1501,13 @@
       cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map string_of_ext_mode ms)) modes))
   else ()
 
-fun infer_modes mode_analysis_options options compilation preds all_modes param_vs clauses thy =
+fun infer_modes mode_analysis_options options (lookup_mode, lookup_neg_mode, needs_random) ctxt
+  preds all_modes param_vs clauses =
   let
-    val ctxt = ProofContext.init_global thy  
     val collect_errors = false
     fun appair f (x1, x2) (y1, y2) = (f x1 y1, f x2 y2)
     fun add_needs_random s (false, m) = ((false, m), false)
-      | add_needs_random s (true, m) = ((true, m), needs_random ctxt s m)
+      | add_needs_random s (true, m) = ((true, m), needs_random s m)
     fun add_polarity_and_random_bit s b ms = map (fn m => add_needs_random s (b, m)) ms
     val prednames = map fst preds
     (* extramodes contains all modes of all constants, should we only use the necessary ones
@@ -1516,18 +1519,16 @@
       | predname_of _ = I
     val relevant_prednames = fold (fn (_, clauses') =>
       fold (fn (_, ps) => fold Term.add_const_names (map dest_indprem ps)) clauses') clauses []
+      |> filter_out (fn name => member (op =) prednames name)
     val extra_modes =
       if #infer_pos_and_neg_modes mode_analysis_options then
         let
           val pos_extra_modes =
-            map_filter (fn name => Option.map (pair name) (try (modes_of compilation ctxt) name))
+            map_filter (fn name => Option.map (pair name) (try lookup_mode name))
             relevant_prednames
-            |> filter_out (fn (name, _) => member (op =) prednames name)
           val neg_extra_modes =
-          map_filter (fn name => Option.map (pair name)
-            (try (modes_of (negative_compilation_of compilation) ctxt) name))
-            relevant_prednames
-            |> filter_out (fn (name, _) => member (op =) prednames name)
+            map_filter (fn name => Option.map (pair name) (try lookup_neg_mode name))
+              relevant_prednames
         in
           map (fn (s, ms) => (s, (add_polarity_and_random_bit s true ms)
                 @ add_polarity_and_random_bit s false (the (AList.lookup (op =) neg_extra_modes s))))
@@ -1535,9 +1536,8 @@
         end
       else
         map (fn (s, ms) => (s, (add_polarity_and_random_bit s true ms)))
-          (map_filter (fn name => Option.map (pair name) (try (modes_of compilation ctxt) name))
-            relevant_prednames
-          |> filter_out (fn (name, _) => member (op =) prednames name))
+          (map_filter (fn name => Option.map (pair name) (try lookup_mode name))
+            relevant_prednames)
     val _ = print_extra_modes options extra_modes
     val start_modes =
       if #infer_pos_and_neg_modes mode_analysis_options then
@@ -1559,11 +1559,11 @@
           (fixp (fn modes => map fst (iteration modes)) start_modes, []))
     val moded_clauses = map (get_modes_pred' mode_analysis_options ctxt param_vs clauses
       (modes @ extra_modes)) modes
-    val thy' = fold (fn (s, ms) => if member (op =) (map fst preds) s then
-      set_needs_random s (map_filter (fn ((true, m), true) => SOME m | _ => NONE) ms) else I)
-      modes thy
+    val need_random = fold (fn (s, ms) => if member (op =) (map fst preds) s then
+      cons (s, (map_filter (fn ((true, m), true) => SOME m | _ => NONE) ms)) else I)
+      modes []
   in
-    ((moded_clauses, errors), thy')
+    ((moded_clauses, need_random), errors)
   end;
 
 (* term construction *)
@@ -2845,10 +2845,13 @@
     val (preds, all_vs, param_vs, all_modes, clauses) =
       prepare_intrs options compilation thy prednames (maps (intros_of ctxt) prednames)
     val _ = print_step options "Infering modes..."
-    val ((moded_clauses, errors), thy') =
+    val (lookup_mode, lookup_neg_mode, needs_random) = (modes_of compilation ctxt,
+      modes_of (negative_compilation_of compilation) ctxt, needs_random ctxt)
+    val ((moded_clauses, needs_random), errors) =
       Output.cond_timeit (!Quickcheck.timing) "Infering modes"
       (fn _ => infer_modes mode_analysis_options
-        options compilation preds all_modes param_vs clauses thy)
+        options (lookup_mode, lookup_neg_mode, needs_random) ctxt preds all_modes param_vs clauses)
+    val thy' = fold (fn (s, ms) => set_needs_random s ms) needs_random thy
     val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
     val _ = check_expected_modes preds options modes
     val _ = check_proposed_modes preds options modes errors