run the SMT relevance filter only once, then run the normalization/monomorphization code once _per class_ of SMT solvers
--- a/src/HOL/Tools/SMT/smt_solver.ML Fri Dec 17 12:10:08 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_solver.ML Fri Dec 17 15:30:43 2010 +0100
@@ -33,7 +33,7 @@
val default_max_relevant: Proof.context -> string -> int
(*filter*)
- type 'a smt_filter_head_result = 'a list * (int option * thm) list *
+ type 'a smt_filter_head_result = ('a list * (int option * thm) list) *
(((int * thm) list * Proof.context) * (int * (int option * thm)) list)
val smt_filter_head: Proof.state ->
('a * (int option * thm)) list -> int -> 'a smt_filter_head_result
@@ -322,7 +322,7 @@
fun mk_result outcome xrules = { outcome = outcome, used_facts = xrules }
-type 'a smt_filter_head_result = 'a list * (int option * thm) list *
+type 'a smt_filter_head_result = ('a list * (int option * thm) list) *
(((int * thm) list * Proof.context) * (int * (int option * thm)) list)
fun smt_filter_head st xwrules i =
@@ -340,7 +340,7 @@
val (xs, wthms) = split_list xwrules
in
- (xs, wthms,
+ ((xs, wthms),
wthms
|> map_index I
|> append (map (pair ~1 o pair NONE) (Thm.assume cprop :: prems @ facts))
@@ -349,7 +349,7 @@
end
fun smt_filter_tail time_limit run_remote
- (xs, wthms, ((iwthms', ctxt), iwthms)) =
+ ((xs, wthms), ((iwthms', ctxt), iwthms)) =
let
val ctxt = ctxt |> Config.put C.timeout (Time.toReal time_limit)
val xrules = xs ~~ map snd wthms
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML Fri Dec 17 12:10:08 2010 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML Fri Dec 17 15:30:43 2010 +0100
@@ -64,7 +64,7 @@
val {goal, ...} = Proof.goal state
val problem =
{state = state, goal = goal, subgoal = i, subgoal_count = n,
- facts = facts}
+ facts = facts, smt_head = NONE}
val result as {outcome, used_facts, ...} = prover params (K "") problem
in
print silent
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML Fri Dec 17 12:10:08 2010 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML Fri Dec 17 15:30:43 2010 +0100
@@ -33,14 +33,16 @@
datatype prover_fact =
Untranslated_Fact of (string * locality) * thm |
ATP_Translated_Fact of
- translated_formula option * ((string * locality) * thm)
+ translated_formula option * ((string * locality) * thm) |
+ SMT_Weighted_Fact of (string * locality) * (int option * thm)
type prover_problem =
{state: Proof.state,
goal: thm,
subgoal: int,
subgoal_count: int,
- facts: prover_fact list}
+ facts: prover_fact list,
+ smt_head: (string * locality) SMT_Solver.smt_filter_head_result option}
type prover_result =
{outcome: failure option,
@@ -56,13 +58,9 @@
val smt_iter_time_frac : real Unsynchronized.ref
val smt_iter_min_msecs : int Unsynchronized.ref
val smt_monomorph_limit : int Unsynchronized.ref
- val smt_weights : bool Unsynchronized.ref
- val smt_min_weight : int Unsynchronized.ref
- val smt_max_weight : int Unsynchronized.ref
- val smt_max_index : int Unsynchronized.ref
- val smt_weight_curve : (int -> int) Unsynchronized.ref
val das_Tool : string
+ val select_smt_solver : string -> Proof.context -> Proof.context
val is_smt_prover : Proof.context -> string -> bool
val is_prover_available : Proof.context -> string -> bool
val is_prover_installed : Proof.context -> string -> bool
@@ -76,6 +74,8 @@
val problem_prefix : string Config.T
val measure_run_time : bool Config.T
val untranslated_fact : prover_fact -> (string * locality) * thm
+ val smt_weighted_fact :
+ prover_fact -> (string * locality) * (int option * thm)
val available_provers : Proof.context -> unit
val kill_provers : unit -> unit
val running_provers : unit -> unit
@@ -102,11 +102,16 @@
"Async_Manager". *)
val das_Tool = "Sledgehammer"
+val unremotify = perhaps (try (unprefix remote_prefix))
+
+val select_smt_solver =
+ Context.proof_map o SMT_Config.select_solver o unremotify
+
fun is_smt_prover ctxt name =
let val smts = SMT_Solver.available_solvers_of ctxt in
case try (unprefix remote_prefix) name of
- SOME suffix => member (op =) smts suffix andalso
- SMT_Solver.is_remotely_available ctxt suffix
+ SOME base => member (op =) smts base andalso
+ SMT_Solver.is_remotely_available ctxt base
| NONE => member (op =) smts name
end
@@ -133,8 +138,7 @@
fun default_max_relevant_for_prover ctxt name =
let val thy = ProofContext.theory_of ctxt in
if is_smt_prover ctxt name then
- SMT_Solver.default_max_relevant ctxt
- (perhaps (try (unprefix remote_prefix)) name)
+ SMT_Solver.default_max_relevant ctxt (unremotify name)
else
#default_max_relevant (get_atp thy name)
end
@@ -146,9 +150,11 @@
@{const_name conj}, @{const_name disj}, @{const_name implies},
@{const_name HOL.eq}, @{const_name If}, @{const_name Let}]
-fun is_built_in_const_for_prover ctxt name (s, T) args =
- if is_smt_prover ctxt name then SMT_Builtin.is_builtin_ext ctxt (s, T) args
- else member (op =) atp_irrelevant_consts s
+fun is_built_in_const_for_prover ctxt name =
+ if is_smt_prover ctxt name then
+ ctxt |> select_smt_solver name |> SMT_Builtin.is_builtin_ext
+ else
+ K o member (op =) atp_irrelevant_consts o fst
(* FUDGE *)
val atp_relevance_fudge =
@@ -230,14 +236,17 @@
datatype prover_fact =
Untranslated_Fact of (string * locality) * thm |
- ATP_Translated_Fact of translated_formula option * ((string * locality) * thm)
+ ATP_Translated_Fact of
+ translated_formula option * ((string * locality) * thm) |
+ SMT_Weighted_Fact of (string * locality) * (int option * thm)
type prover_problem =
{state: Proof.state,
goal: thm,
subgoal: int,
subgoal_count: int,
- facts: prover_fact list}
+ facts: prover_fact list,
+ smt_head: (string * locality) SMT_Solver.smt_filter_head_result option}
type prover_result =
{outcome: failure option,
@@ -272,8 +281,12 @@
fun untranslated_fact (Untranslated_Fact p) = p
| untranslated_fact (ATP_Translated_Fact (_, p)) = p
-fun atp_translated_fact ctxt (Untranslated_Fact p) = translate_atp_fact ctxt p
- | atp_translated_fact _ (ATP_Translated_Fact q) = q
+ | untranslated_fact (SMT_Weighted_Fact (info, (_, th))) = (info, th)
+fun atp_translated_fact _ (ATP_Translated_Fact p) = p
+ | atp_translated_fact ctxt fact =
+ translate_atp_fact ctxt (untranslated_fact fact)
+fun smt_weighted_fact (SMT_Weighted_Fact p) = p
+ | smt_weighted_fact fact = untranslated_fact fact |> apsnd (pair NONE)
fun int_opt_add (SOME m) (SOME n) = SOME (m + n)
| int_opt_add _ _ = NONE
@@ -471,9 +484,26 @@
val smt_iter_min_msecs = Unsynchronized.ref 5000
val smt_monomorph_limit = Unsynchronized.ref 4
-fun smt_filter_loop ({debug, verbose, timeout, ...} : params) remote state i =
+fun smt_filter_loop name ({debug, verbose, overlord, timeout, ...} : params)
+ state i smt_head =
let
val ctxt = Proof.context_of state
+ val (remote, base) =
+ case try (unprefix remote_prefix) name of
+ SOME base => (true, base)
+ | NONE => (false, name)
+ val repair_context =
+ select_smt_solver base
+ #> Config.put SMT_Config.verbose debug
+ #> (if overlord then
+ Config.put SMT_Config.debug_files
+ (overlord_file_location_for_prover name
+ |> (fn (path, base) => path ^ "/" ^ base))
+ else
+ I)
+ #> Config.put SMT_Config.monomorph_limit (!smt_monomorph_limit)
+ val state = state |> Proof.map_context repair_context
+
fun iter timeout iter_num outcome0 time_so_far facts =
let
val timer = Timer.startRealTimer ()
@@ -498,7 +528,9 @@
val _ =
if debug then Output.urgent_message "Invoking SMT solver..." else ()
val (outcome, used_facts) =
- SMT_Solver.smt_filter_head state facts i
+ (case (iter_num, smt_head) of
+ (1, SOME head) => head |> apsnd (apfst (apsnd repair_context))
+ | _ => SMT_Solver.smt_filter_head state facts i)
|> SMT_Solver.smt_filter_tail iter_timeout remote
|> (fn {outcome, used_facts} => (outcome, used_facts))
handle exn => if Exn.is_interrupt exn then
@@ -565,53 +597,14 @@
(Config.put Metis_Tactics.verbose debug
#> (fn ctxt => Metis_Tactics.metis_tac ctxt ths)) state i
-val smt_weights = Unsynchronized.ref true
-val smt_weight_min_facts = 20
-
-(* FUDGE *)
-val smt_min_weight = Unsynchronized.ref 0
-val smt_max_weight = Unsynchronized.ref 10
-val smt_max_index = Unsynchronized.ref 200
-val smt_weight_curve = Unsynchronized.ref (fn x : int => x * x)
-
-fun smt_fact_weight j num_facts =
- if !smt_weights andalso num_facts >= smt_weight_min_facts then
- SOME (!smt_max_weight
- - (!smt_max_weight - !smt_min_weight + 1)
- * !smt_weight_curve (Int.max (0, !smt_max_index - j - 1))
- div !smt_weight_curve (!smt_max_index))
- else
- NONE
-
-fun run_smt_solver auto name (params as {debug, verbose, overlord, ...})
- minimize_command
- ({state, subgoal, subgoal_count, facts, ...} : prover_problem) =
+fun run_smt_solver auto name (params as {debug, verbose, ...}) minimize_command
+ ({state, subgoal, subgoal_count, facts, smt_head, ...}
+ : prover_problem) =
let
- val (remote, suffix) =
- case try (unprefix remote_prefix) name of
- SOME suffix => (true, suffix)
- | NONE => (false, name)
- val repair_context =
- Context.proof_map (SMT_Config.select_solver suffix)
- #> Config.put SMT_Config.verbose debug
- #> (if overlord then
- Config.put SMT_Config.debug_files
- (overlord_file_location_for_prover name
- |> (fn (path, base) => path ^ "/" ^ base))
- else
- I)
- #> Config.put SMT_Config.monomorph_limit (!smt_monomorph_limit)
- val state = state |> Proof.map_context repair_context
- val thy = Proof.theory_of state
- val num_facts = length facts
- val facts =
- facts ~~ (0 upto num_facts - 1)
- |> map (fn (fact, j) =>
- fact |> untranslated_fact
- |> apsnd (pair (smt_fact_weight j num_facts)
- o Thm.transfer thy))
+ val ctxt = Proof.context_of state
+ val facts = facts |> map smt_weighted_fact
val {outcome, used_facts, run_time_in_msecs} =
- smt_filter_loop params remote state subgoal facts
+ smt_filter_loop name params state subgoal smt_head facts
val (chained_lemmas, other_lemmas) = split_used_facts (map fst used_facts)
val outcome = outcome |> Option.map failure_from_smt_failure
val message =
@@ -622,8 +615,10 @@
if can_apply_metis debug state subgoal (map snd used_facts) then
("metis", "")
else
- ("smt", if suffix = SMT_Config.solver_of @{context} then ""
- else "smt_solver = " ^ maybe_quote suffix)
+ let val base = unremotify name in
+ ("smt", if base = SMT_Config.solver_of ctxt then ""
+ else "smt_solver = " ^ maybe_quote base)
+ end
in
try_command_line (proof_banner auto)
(apply_on_subgoal settings subgoal subgoal_count ^
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML Fri Dec 17 12:10:08 2010 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML Fri Dec 17 15:30:43 2010 +0100
@@ -43,7 +43,7 @@
fun run_prover (params as {debug, blocking, max_relevant, timeout, expect, ...})
auto minimize_command only
- {state, goal, subgoal, subgoal_count, facts} name =
+ {state, goal, subgoal, subgoal_count, facts, smt_head} name =
let
val ctxt = Proof.context_of state
val birth_time = Time.now ()
@@ -56,7 +56,8 @@
val prover = get_prover ctxt auto name
val problem =
{state = state, goal = goal, subgoal = subgoal,
- subgoal_count = subgoal_count, facts = take num_facts facts}
+ subgoal_count = subgoal_count, facts = take num_facts facts,
+ smt_head = smt_head}
fun go () =
let
fun really_go () =
@@ -115,6 +116,36 @@
(false, state))
end
+val smt_weights = Unsynchronized.ref true
+val smt_weight_min_facts = 20
+
+(* FUDGE *)
+val smt_min_weight = Unsynchronized.ref 0
+val smt_max_weight = Unsynchronized.ref 10
+val smt_max_index = Unsynchronized.ref 200
+val smt_weight_curve = Unsynchronized.ref (fn x : int => x * x)
+
+fun smt_fact_weight j num_facts =
+ if !smt_weights andalso num_facts >= smt_weight_min_facts then
+ SOME (!smt_max_weight
+ - (!smt_max_weight - !smt_min_weight + 1)
+ * !smt_weight_curve (Int.max (0, !smt_max_index - j - 1))
+ div !smt_weight_curve (!smt_max_index))
+ else
+ NONE
+
+fun weight_smt_fact thy num_facts (fact, j) =
+ fact |> apsnd (pair (smt_fact_weight j num_facts) o Thm.transfer thy)
+
+fun class_of_smt_solver ctxt name =
+ ctxt |> select_smt_solver name
+ |> SMT_Config.solver_class_of |> SMT_Utils.string_of_class
+
+(* Makes backtraces more transparent and might be more efficient as well. *)
+fun smart_par_list_map _ [] = []
+ | smart_par_list_map f [x] = [f x]
+ | smart_par_list_map f xs = Par_List.map f xs
+
(* FUDGE *)
val auto_max_relevant_divisor = 2
@@ -129,7 +160,10 @@
| n =>
let
val _ = Proof.assert_backward state
+ val state =
+ state |> Proof.map_context (Config.put SMT_Config.verbose debug)
val ctxt = Proof.context_of state
+ val thy = ProofContext.theory_of ctxt
val {facts = chained_ths, goal, ...} = Proof.goal state
val (_, hyp_ts, concl_t) = strip_subgoal goal i
val no_dangerous_types = types_dangerous_types type_sys
@@ -139,60 +173,83 @@
| NONE => ()
val _ = if auto then () else Output.urgent_message "Sledgehammering..."
val (smts, atps) = provers |> List.partition (is_smt_prover ctxt)
- fun run_provers label no_dangerous_types relevance_fudge maybe_translate
- provers (res as (success, state)) =
+ fun run_provers get_facts translate maybe_smt_head provers
+ (res as (success, state)) =
if success orelse null provers then
res
else
let
- val max_max_relevant =
- case max_relevant of
- SOME n => n
- | NONE =>
- 0 |> fold (Integer.max o default_max_relevant_for_prover ctxt)
- provers
- |> auto ? (fn n => n div auto_max_relevant_divisor)
- val is_built_in_const =
- is_built_in_const_for_prover ctxt (hd provers)
- val facts =
- relevant_facts ctxt no_dangerous_types relevance_thresholds
- max_max_relevant is_built_in_const relevance_fudge
- relevance_override chained_ths hyp_ts concl_t
- |> map maybe_translate
+ val facts = get_facts ()
+ val num_facts = length facts
+ val facts = facts ~~ (0 upto num_facts - 1)
+ |> map (translate num_facts)
val problem =
{state = state, goal = goal, subgoal = i, subgoal_count = n,
- facts = facts}
+ facts = facts,
+ smt_head = maybe_smt_head (map smt_weighted_fact facts) i}
val run_prover = run_prover params auto minimize_command only
in
- if debug then
- Output.urgent_message (label ^ plural_s (length provers) ^ ": " ^
- (if null facts then
- "Found no relevant facts."
- else
- "Including (up to) " ^ string_of_int (length facts) ^
- " relevant fact" ^ plural_s (length facts) ^ ":\n" ^
- (facts |> map (fst o fst o untranslated_fact)
- |> space_implode " ") ^ "."))
- else
- ();
if auto then
fold (fn prover => fn (true, state) => (true, state)
| (false, _) => run_prover problem prover)
provers (false, state)
else
provers
- |> (if blocking andalso length provers > 1 then Par_List.map
- else map)
+ |> (if blocking then smart_par_list_map else map)
(run_prover problem)
|> exists fst |> rpair state
end
+ fun get_facts label no_dangerous_types relevance_fudge provers =
+ let
+ val max_max_relevant =
+ case max_relevant of
+ SOME n => n
+ | NONE =>
+ 0 |> fold (Integer.max o default_max_relevant_for_prover ctxt)
+ provers
+ |> auto ? (fn n => n div auto_max_relevant_divisor)
+ val is_built_in_const =
+ is_built_in_const_for_prover ctxt (hd provers)
+ in
+ relevant_facts ctxt no_dangerous_types relevance_thresholds
+ max_max_relevant is_built_in_const relevance_fudge
+ relevance_override chained_ths hyp_ts concl_t
+ |> tap (fn facts =>
+ if debug then
+ label ^ plural_s (length provers) ^ ": " ^
+ (if null facts then
+ "Found no relevant facts."
+ else
+ "Including (up to) " ^ string_of_int (length facts) ^
+ " relevant fact" ^ plural_s (length facts) ^ ":\n" ^
+ (facts |> map (fst o fst) |> space_implode " ") ^ ".")
+ |> Output.urgent_message
+ else
+ ())
+ end
val run_atps =
- run_provers "ATP" no_dangerous_types atp_relevance_fudge
- (ATP_Translated_Fact o translate_atp_fact ctxt) atps
- val run_smts =
- run_provers "SMT solver" true smt_relevance_fudge Untranslated_Fact smts
+ run_provers
+ (get_facts "ATP" no_dangerous_types atp_relevance_fudge o K atps)
+ (ATP_Translated_Fact oo K (translate_atp_fact ctxt o fst))
+ (K (K NONE)) atps
+ fun run_smts (accum as (success, _)) =
+ if success orelse null smts then
+ accum
+ else
+ let
+ val facts = get_facts "SMT solver" true smt_relevance_fudge smts
+ val translate = SMT_Weighted_Fact oo weight_smt_fact thy
+ val maybe_smt_head = try o SMT_Solver.smt_filter_head state
+ in
+ smts |> map (`(class_of_smt_solver ctxt))
+ |> AList.group (op =)
+ |> map (fn (_, smts) => run_provers (K facts) translate
+ maybe_smt_head smts accum)
+ |> exists fst |> rpair state
+ end
fun run_atps_and_smt_solvers () =
- [run_atps, run_smts] |> Par_List.map (fn f => f (false, state) |> K ())
+ [run_atps, run_smts]
+ |> smart_par_list_map (fn f => f (false, state) |> K ())
handle ERROR msg => (Output.urgent_message ("Error: " ^ msg); error msg)
in
(false, state)