added weights to SMT problems
authorblanchet
Wed, 15 Dec 2010 16:42:07 +0100
changeset 41168 f6f1ffd51d87
parent 41167 b05014180288
child 41169 95167879f675
added weights to SMT problems
src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Dec 15 16:42:06 2010 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Dec 15 16:42:07 2010 +0100
@@ -50,6 +50,18 @@
 
   type prover = params -> minimize_command -> prover_problem -> prover_result
 
+  (* for experimentation purposes -- do not use in production code *)
+  val smt_max_iter : int Unsynchronized.ref
+  val smt_iter_fact_divisor : int Unsynchronized.ref
+  val smt_iter_min_msecs : int Unsynchronized.ref
+  val smt_iter_timeout_divisor : int Unsynchronized.ref
+  val smt_monomorph_limit : int Unsynchronized.ref
+  val smt_weights : bool Unsynchronized.ref
+  val smt_min_weight : int Unsynchronized.ref
+  val smt_max_weight : int Unsynchronized.ref
+  val smt_max_index : int Unsynchronized.ref
+  val smt_weight_curve : (int -> int) Unsynchronized.ref
+
   val das_Tool : string
   val is_smt_prover : Proof.context -> string -> bool
   val is_prover_available : Proof.context -> string -> bool
@@ -269,6 +281,8 @@
 fun overlord_file_location_for_prover prover =
   (getenv "ISABELLE_HOME_USER", "prob_" ^ prover)
 
+val atp_first_iter_frac = 0.67
+
 (* Important messages are important but not so important that users want to see
    them each time. *)
 val important_message_keep_factor = 0.1
@@ -353,11 +367,12 @@
             val run_twice = has_incomplete_mode andalso not auto
             val timer = Timer.startRealTimer ()
             val result =
-              run false (if run_twice then
-                           Time.fromMilliseconds
-                                         (2 * Time.toMilliseconds timeout div 3)
-                         else
-                           timeout)
+              run false
+                 (if run_twice then
+                    seconds (0.001 * atp_first_iter_frac
+                             * Real.fromInt (Time.toMilliseconds timeout))
+                  else
+                    timeout)
               |> run_twice
                  ? (fn (_, msecs0, _, SOME _) =>
                        run true (Time.- (timeout, Timer.checkRealTimer timer))
@@ -437,23 +452,23 @@
   | failure_from_smt_failure _ = UnknownError
 
 (* FUDGE *)
-val smt_max_iter = 8
-val smt_iter_fact_divisor = 2
-val smt_iter_min_msecs = 5000
-val smt_iter_timeout_divisor = 2
-val smt_monomorph_limit = 4
+val smt_max_iter = Unsynchronized.ref 8
+val smt_iter_fact_divisor = Unsynchronized.ref 2
+val smt_iter_min_msecs = Unsynchronized.ref 5000
+val smt_iter_timeout_divisor = Unsynchronized.ref 2
+val smt_monomorph_limit = Unsynchronized.ref 4
 
 fun smt_filter_loop ({verbose, timeout, ...} : params) remote state i =
   let
     val ctxt = Proof.context_of state
-    fun iter timeout iter_num outcome0 msecs_so_far facts =
+    fun iter timeout iter_num outcome0 time_so_far facts =
       let
         val timer = Timer.startRealTimer ()
         val ms = timeout |> Time.toMilliseconds
         val iter_timeout =
-          if iter_num < smt_max_iter then
-            Int.min (ms, Int.max (smt_iter_min_msecs,
-                                  ms div smt_iter_timeout_divisor))
+          if iter_num < !smt_max_iter then
+            Int.min (ms, Int.max (!smt_iter_min_msecs,
+                                  ms div !smt_iter_timeout_divisor))
             |> Time.fromMilliseconds
           else
             timeout
@@ -465,8 +480,10 @@
             |> Output.urgent_message
           else
             ()
-        val {outcome, used_facts, run_time_in_msecs} =
+        val birth = Timer.checkRealTimer timer
+        val {outcome, used_facts, ...} =
           SMT_Solver.smt_filter remote iter_timeout state facts i
+        val death = Timer.checkRealTimer timer
         val _ =
           if verbose andalso is_some outcome then
             "SMT outcome: " ^ SMT_Failure.string_of_failure ctxt (the outcome)
@@ -474,7 +491,7 @@
           else
             ()
         val outcome0 = if is_none outcome0 then SOME outcome else outcome0
-        val msecs_so_far = int_opt_add run_time_in_msecs msecs_so_far
+        val time_so_far = Time.+ (time_so_far, Time.- (death, birth))
         val too_many_facts_perhaps =
           case outcome of
             NONE => false
@@ -493,16 +510,17 @@
           | SOME _ => true
         val timeout = Time.- (timeout, Timer.checkRealTimer timer)
       in
-        if too_many_facts_perhaps andalso iter_num < smt_max_iter andalso
+        if too_many_facts_perhaps andalso iter_num < !smt_max_iter andalso
            num_facts > 0 andalso Time.> (timeout, Time.zeroTime) then
-          let val facts = take (num_facts div smt_iter_fact_divisor) facts in
-            iter timeout (iter_num + 1) outcome0 msecs_so_far facts
+          let val facts = take (num_facts div !smt_iter_fact_divisor) facts in
+            iter timeout (iter_num + 1) outcome0 time_so_far facts
           end
         else
           {outcome = if is_none outcome then NONE else the outcome0,
-           used_facts = used_facts, run_time_in_msecs = msecs_so_far}
+           used_facts = used_facts,
+           run_time_in_msecs = SOME (Time.toMilliseconds time_so_far)}
       end
-  in iter timeout 1 NONE (SOME 0) end
+  in iter timeout 1 NONE Time.zeroTime end
 
 (* taken from "Mirabelle" and generalized *)
 fun can_apply timeout tac state i =
@@ -522,7 +540,26 @@
             (Config.put Metis_Tactics.verbose debug
              #> (fn ctxt => Metis_Tactics.metis_tac ctxt ths)) state i
 
-fun run_smt_solver auto name (params as {debug, overlord, ...}) minimize_command
+val smt_weights = Unsynchronized.ref true
+val smt_weight_min_facts = 20
+
+(* FUDGE *)
+val smt_min_weight = Unsynchronized.ref 0
+val smt_max_weight = Unsynchronized.ref 10
+val smt_max_index = Unsynchronized.ref 200
+val smt_weight_curve = Unsynchronized.ref (fn x : int => x * x)
+
+fun smt_fact_weight j num_facts =
+  if !smt_weights andalso num_facts >= smt_weight_min_facts then
+    SOME (!smt_max_weight
+          - (!smt_max_weight - !smt_min_weight)
+            * !smt_weight_curve (!smt_max_index - j)
+            div !smt_weight_curve (!smt_max_index))
+  else
+    NONE
+
+fun run_smt_solver auto name (params as {debug, verbose, overlord, ...})
+        minimize_command
         ({state, subgoal, subgoal_count, facts, ...} : prover_problem) =
   let
     val (remote, suffix) =
@@ -538,10 +575,16 @@
                         |> (fn (path, base) => path ^ "/" ^ base))
           else
             I)
-      #> Config.put SMT_Config.monomorph_limit smt_monomorph_limit
+      #> Config.put SMT_Config.monomorph_limit (!smt_monomorph_limit)
     val state = state |> Proof.map_context repair_context
     val thy = Proof.theory_of state
-    val facts = facts |> map (apsnd (pair NONE o Thm.transfer thy) o untranslated_fact)
+    val num_facts = length facts
+    val facts =
+      facts ~~ (0 upto num_facts - 1)
+      |> map (fn (fact, j) =>
+                 fact |> untranslated_fact
+                      |> apsnd (pair (smt_fact_weight j num_facts)
+                                o Thm.transfer thy))
     val {outcome, used_facts, run_time_in_msecs} =
       smt_filter_loop params remote state subgoal facts
     val (chained_lemmas, other_lemmas) = split_used_facts (map fst used_facts)
@@ -561,7 +604,13 @@
               (apply_on_subgoal settings subgoal subgoal_count ^
                command_call method (map fst other_lemmas)) ^
           minimize_line minimize_command
-                        (map fst (other_lemmas @ chained_lemmas))
+                        (map fst (other_lemmas @ chained_lemmas)) ^
+          (if verbose then
+             "\nSMT solver real CPU time: " ^
+             string_from_time (Time.fromMilliseconds (the run_time_in_msecs)) ^
+             "."
+           else
+             "")
         end
       | SOME failure => string_for_failure "SMT solver" failure
   in