split "smt_filter" into head and tail
authorblanchet
Fri, 17 Dec 2010 12:02:46 +0100
changeset 41239 d6e804ff29c3
parent 41238 78e4508d2e54
child 41240 5965c8c97210
split "smt_filter" into head and tail
src/HOL/Tools/SMT/smt_solver.ML
src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML
--- a/src/HOL/Tools/SMT/smt_solver.ML	Fri Dec 17 12:01:49 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_solver.ML	Fri Dec 17 12:02:46 2010 +0100
@@ -33,10 +33,12 @@
   val default_max_relevant: Proof.context -> string -> int
 
   (*filter*)
-  val smt_filter: bool -> Time.time -> Proof.state ->
-    ('a * (int option * thm)) list -> int ->
-    {outcome: SMT_Failure.failure option, used_facts: ('a * thm) list,
-    run_time_in_msecs: int option}
+  type 'a smt_filter_head_result = 'a list * (int option * thm) list *
+    (((int * thm) list * Proof.context) * (int * (int option * thm)) list)
+  val smt_filter_head: Time.time -> Proof.state ->
+    ('a * (int option * thm)) list -> int -> 'a smt_filter_head_result
+  val smt_filter_tail: bool -> 'a smt_filter_head_result ->
+    {outcome: SMT_Failure.failure option, used_facts: ('a * thm) list}
 
   (*tactic*)
   val smt_tac': bool -> Proof.context -> thm list -> int -> Tactical.tactic
@@ -212,16 +214,19 @@
     int list * thm,
   default_max_relevant: int }
 
-fun gen_solver name (info : solver_info) rm ctxt iwthms =
+fun gen_solver_head ctxt iwthms =
+  SMT_Normalize.normalize ctxt iwthms
+  |> rpair ctxt
+  |-> SMT_Monomorph.monomorph
+
+fun gen_solver_tail (name, info : solver_info) rm (iwthms', ctxt) iwthms =
   let
     val {env_var, is_remote, options, reconstruct, ...} = info
     val cmd = (rm, env_var, is_remote, name)
   in
-    SMT_Normalize.normalize ctxt iwthms
-    |> rpair ctxt
-    |-> SMT_Monomorph.monomorph
-    |> (fn (iwthms', ctxt') => invoke name cmd options iwthms' ctxt'
-    |> reconstruct ctxt')
+    (iwthms', ctxt)
+    |-> invoke name cmd options
+    |> reconstruct ctxt
     |> (fn (idxs, thm) => thm
     |> tap (fn _ => trace_assumptions ctxt iwthms idxs)
     |> pair idxs)
@@ -284,9 +289,8 @@
   in (name, get_info ctxt name) end
 
 val solver_name_of = fst o name_and_solver_of
-fun solver_of ctxt =
-  let val (name, raw_solver) = name_and_solver_of ctxt
-  in gen_solver name raw_solver end
+fun solver_of ctxt rm ctxt' =
+  `(gen_solver_head ctxt') #-> gen_solver_tail (name_and_solver_of ctxt) rm
 
 val available_solvers_of = Symtab.keys o Solvers.get o Context.Proof
 
@@ -306,19 +310,22 @@
   | TVar (_, []) => true
   | _ => false))
 
-fun smt_solver rm ctxt iwthms =
-  (* without this test, we would run into problems when atomizing the rules: *)
+(* without this test, we would run into problems when atomizing the rules: *)
+fun check_topsort iwthms =
   if exists (has_topsort o Thm.prop_of o snd o snd) iwthms then
     raise SMT_Failure.SMT (SMT_Failure.Other_Failure ("proof state " ^
       "contains the universal sort {}"))
-  else solver_of ctxt rm ctxt iwthms
+  else
+    ()
 
 val cnot = Thm.cterm_of @{theory} @{const Not}
 
-fun mk_result outcome xrules =
-  { outcome = outcome, used_facts = xrules, run_time_in_msecs = NONE }
+fun mk_result outcome xrules = { outcome = outcome, used_facts = xrules }
 
-fun smt_filter run_remote time_limit st xwrules i =
+type 'a smt_filter_head_result = 'a list * (int option * thm) list *
+  (((int * thm) list * Proof.context) * (int * (int option * thm)) list)
+
+fun smt_filter_head time_limit st xwrules i =
   let
     val ctxt =
       Proof.context_of st
@@ -333,20 +340,25 @@
     val cprop = negate (Thm.rhs_of (SMT_Normalize.atomize_conv ctxt' concl))
 
     val (xs, wthms) = split_list xwrules
-    val xrules = xs ~~ map snd wthms
   in
-    wthms
-    |> map_index I
-    |> append (map (pair ~1 o pair NONE) (Thm.assume cprop :: prems @ facts))
-    |> smt_solver (SOME run_remote) ctxt'
+    (xs, wthms,
+     wthms
+     |> map_index I
+     |> append (map (pair ~1 o pair NONE) (Thm.assume cprop :: prems @ facts))
+     |> tap check_topsort
+     |> `(gen_solver_head ctxt'))
+  end
+
+fun smt_filter_tail run_remote (xs, wthms, head as ((_, ctxt), _)) =
+  let val xrules = xs ~~ map snd wthms in
+    head
+    |-> gen_solver_tail (name_and_solver_of ctxt) (SOME run_remote)
     |> distinct (op =) o fst
     |> map_filter (try (nth xrules))
     |> (if solver_name_of ctxt = "z3" (* FIXME *) then I else K xrules)
     |> mk_result NONE
   end
   handle SMT_Failure.SMT fail => mk_result (SOME fail) []
-  (* FIXME: measure runtime *)
-
 
 
 (* SMT tactic *)
@@ -356,7 +368,7 @@
   THEN' Tactic.rtac @{thm ccontr}
   THEN' SUBPROOF (fn {context=ctxt', prems, ...} =>
     let
-      fun solve iwthms = snd (smt_solver NONE ctxt' iwthms)
+      val solve = snd o solver_of ctxt' NONE ctxt' o tap check_topsort
       val tag = "Solver " ^ C.solver_of ctxt' ^ ": "
       val str_of = prefix tag o SMT_Failure.string_of_failure ctxt'
       fun safe_solve iwthms =
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Fri Dec 17 12:01:49 2010 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Fri Dec 17 12:02:46 2010 +0100
@@ -498,8 +498,9 @@
         val _ =
           if debug then Output.urgent_message "Invoking SMT solver..." else ()
         val (outcome, used_facts) =
-          SMT_Solver.smt_filter remote iter_timeout state facts i
-          |> (fn {outcome, used_facts, ...} => (outcome, used_facts))
+          SMT_Solver.smt_filter_head iter_timeout state facts i
+          |> SMT_Solver.smt_filter_tail remote
+          |> (fn {outcome, used_facts} => (outcome, used_facts))
           handle exn => if Exn.is_interrupt exn then
                           reraise exn
                         else