tuned ML interface
authorblanchet
Thu, 13 Mar 2014 13:18:13 +0100
changeset 56082 ffd99d397a9f
parent 56081 72fad75baf7e
child 56083 b5d1d9c60341
tuned ML interface
src/HOL/Tools/SMT2/smt2_solver.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML
--- a/src/HOL/Tools/SMT2/smt2_solver.ML	Thu Mar 13 13:18:13 2014 +0100
+++ b/src/HOL/Tools/SMT2/smt2_solver.ML	Thu Mar 13 13:18:13 2014 +0100
@@ -30,12 +30,8 @@
   val default_max_relevant: Proof.context -> string -> int
 
   (*filter*)
-  type 'a smt2_filter_data =
-    ('a * thm) list * ((int * thm) list * Proof.context)
-  val smt2_filter_preprocess: Proof.context -> thm list -> thm ->
-    ('a * (int option * thm)) list -> int -> 'a smt2_filter_data
-  val smt2_filter_apply: Time.time -> 'a smt2_filter_data ->
-    {outcome: SMT2_Failure.failure option, used_facts: ('a * thm) list}
+  val smt2_filter: Proof.context -> thm list -> thm -> ('a * (int option * thm)) list -> int ->
+    Time.time -> {outcome: SMT2_Failure.failure option, used_facts: ('a * thm) list}
 
   (*tactic*)
   val smt2_tac: Proof.context -> thm list -> int -> tactic
@@ -225,19 +221,15 @@
   let val name = solver_name_of ctxt
   in (name, get_info ctxt name) end
 
-fun gen_preprocess ctxt iwthms = SMT2_Normalize.normalize iwthms ctxt
-
-fun gen_apply (ithms, ctxt) =
-  let val (name, {command, replay, ...}) = name_and_info_of ctxt
+fun apply_solver ctxt0 iwthms =
+  let
+    val (ithms, ctxt) = SMT2_Normalize.normalize iwthms ctxt0
+    val (name, {command, replay, ...}) = name_and_info_of ctxt
   in
-    (ithms, ctxt)
-    |-> invoke name command
+    invoke name command ithms ctxt
     |> replay ctxt
-    |>> distinct (op =)
   end
 
-fun apply_solver ctxt = gen_apply o gen_preprocess ctxt
-
 val default_max_relevant = #default_max_relevant oo get_info
 
 val supports_filter = #supports_filter o snd o name_and_info_of 
@@ -266,43 +258,33 @@
 
 fun mk_result outcome xrules = { outcome = outcome, used_facts = xrules }
 
-type 'a smt2_filter_data = ('a * thm) list * ((int * thm) list * Proof.context)
-
-fun smt2_filter_preprocess ctxt facts goal xwthms i =
+fun smt2_filter ctxt facts goal xwthms i time_limit =
   let
     val ctxt =
       ctxt
       |> Config.put SMT2_Config.oracle false
       |> Config.put SMT2_Config.filter_only_facts true
+      |> Config.put SMT2_Config.timeout (Time.toReal time_limit)
 
-    val ({context=ctxt', prems, concl, ...}, _) = Subgoal.focus ctxt i goal
+    val ({context=ctxt, prems, concl, ...}, _) = Subgoal.focus ctxt i goal
     fun negate ct = Thm.dest_comb ct ||> Thm.apply cnot |-> Thm.apply
     val cprop =
-      (case try negate (Thm.rhs_of (SMT2_Normalize.atomize_conv ctxt' concl)) of
+      (case try negate (Thm.rhs_of (SMT2_Normalize.atomize_conv ctxt concl)) of
         SOME ct => ct
       | NONE => raise SMT2_Failure.SMT (SMT2_Failure.Other_Failure (
           "goal is not a HOL term")))
+
+    val xthms = map (apsnd snd) xwthms
+
+    fun filter_thms false = K xthms
+      | filter_thms true = map_filter (try (nth xthms)) o fst
   in
     map snd xwthms
     |> map_index I
     |> append (map (pair ~1 o pair NONE) (Thm.assume cprop :: prems @ facts))
-    |> check_topsorts ctxt'
-    |> gen_preprocess ctxt'
-    |> pair (map (apsnd snd) xwthms)
-  end
-
-fun smt2_filter_apply time_limit (xthms, (ithms, ctxt)) =
-  let
-    val ctxt' =
-      ctxt
-      |> Config.put SMT2_Config.timeout (Time.toReal time_limit)
-
-    fun filter_thms false = K xthms
-      | filter_thms true = map_filter (try (nth xthms)) o fst
-  in
-    (ithms, ctxt')
-    |> gen_apply
-    |> filter_thms (supports_filter ctxt')
+    |> check_topsorts ctxt
+    |> apply_solver ctxt
+    |> filter_thms (supports_filter ctxt)
     |> mk_result NONE
   end
   handle SMT2_Failure.SMT fail => mk_result (SOME fail) []
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML	Thu Mar 13 13:18:13 2014 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML	Thu Mar 13 13:18:13 2014 +0100
@@ -154,8 +154,7 @@
         val _ = if debug then Output.urgent_message "Invoking SMT solver..." else ()
 
         val (outcome, used_facts) =
-          SMT2_Solver.smt2_filter_preprocess ctxt [] goal weighted_facts i
-          |> SMT2_Solver.smt2_filter_apply slice_timeout
+          SMT2_Solver.smt2_filter ctxt [] goal weighted_facts i slice_timeout
           |> (fn {outcome, used_facts} => (outcome, used_facts))
           handle exn =>
             if Exn.is_interrupt exn then reraise exn