# HG changeset patch # User blanchet # Date 1292596243 -3600 # Node ID 8edeb1dbbc767b91d8e9e63f2b313dc69ec70222 # Parent 7d07736aaaf6448a24af97d7b3bf7340e6087ce2 run the SMT relevance filter only once, then run the normalization/monomorphization code once _per class_ of SMT solvers diff -r 7d07736aaaf6 -r 8edeb1dbbc76 src/HOL/Tools/SMT/smt_solver.ML --- a/src/HOL/Tools/SMT/smt_solver.ML Fri Dec 17 12:10:08 2010 +0100 +++ b/src/HOL/Tools/SMT/smt_solver.ML Fri Dec 17 15:30:43 2010 +0100 @@ -33,7 +33,7 @@ val default_max_relevant: Proof.context -> string -> int (*filter*) - type 'a smt_filter_head_result = 'a list * (int option * thm) list * + type 'a smt_filter_head_result = ('a list * (int option * thm) list) * (((int * thm) list * Proof.context) * (int * (int option * thm)) list) val smt_filter_head: Proof.state -> ('a * (int option * thm)) list -> int -> 'a smt_filter_head_result @@ -322,7 +322,7 @@ fun mk_result outcome xrules = { outcome = outcome, used_facts = xrules } -type 'a smt_filter_head_result = 'a list * (int option * thm) list * +type 'a smt_filter_head_result = ('a list * (int option * thm) list) * (((int * thm) list * Proof.context) * (int * (int option * thm)) list) fun smt_filter_head st xwrules i = @@ -340,7 +340,7 @@ val (xs, wthms) = split_list xwrules in - (xs, wthms, + ((xs, wthms), wthms |> map_index I |> append (map (pair ~1 o pair NONE) (Thm.assume cprop :: prems @ facts)) @@ -349,7 +349,7 @@ end fun smt_filter_tail time_limit run_remote - (xs, wthms, ((iwthms', ctxt), iwthms)) = + ((xs, wthms), ((iwthms', ctxt), iwthms)) = let val ctxt = ctxt |> Config.put C.timeout (Time.toReal time_limit) val xrules = xs ~~ map snd wthms diff -r 7d07736aaaf6 -r 8edeb1dbbc76 src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML Fri Dec 17 12:10:08 2010 +0100 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML Fri Dec 17 15:30:43 2010 +0100 @@ -64,7 +64,7 @@ val {goal, ...} = Proof.goal state val problem = {state = state, goal = goal, subgoal = i, subgoal_count = n, - facts = facts} + facts = facts, smt_head = NONE} val result as {outcome, used_facts, ...} = prover params (K "") problem in print silent diff -r 7d07736aaaf6 -r 8edeb1dbbc76 src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML Fri Dec 17 12:10:08 2010 +0100 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML Fri Dec 17 15:30:43 2010 +0100 @@ -33,14 +33,16 @@ datatype prover_fact = Untranslated_Fact of (string * locality) * thm | ATP_Translated_Fact of - translated_formula option * ((string * locality) * thm) + translated_formula option * ((string * locality) * thm) | + SMT_Weighted_Fact of (string * locality) * (int option * thm) type prover_problem = {state: Proof.state, goal: thm, subgoal: int, subgoal_count: int, - facts: prover_fact list} + facts: prover_fact list, + smt_head: (string * locality) SMT_Solver.smt_filter_head_result option} type prover_result = {outcome: failure option, @@ -56,13 +58,9 @@ val smt_iter_time_frac : real Unsynchronized.ref val smt_iter_min_msecs : int Unsynchronized.ref val smt_monomorph_limit : int Unsynchronized.ref - val smt_weights : bool Unsynchronized.ref - val smt_min_weight : int Unsynchronized.ref - val smt_max_weight : int Unsynchronized.ref - val smt_max_index : int Unsynchronized.ref - val smt_weight_curve : (int -> int) Unsynchronized.ref val das_Tool : string + val select_smt_solver : string -> Proof.context -> Proof.context val is_smt_prover : Proof.context -> string -> bool val is_prover_available : Proof.context -> string -> bool val is_prover_installed : Proof.context -> string -> bool @@ -76,6 +74,8 @@ val problem_prefix : string Config.T val measure_run_time : bool Config.T val untranslated_fact : prover_fact -> (string * locality) * thm + val smt_weighted_fact : + prover_fact -> (string * locality) * (int option * thm) val available_provers : Proof.context -> unit val kill_provers : unit -> unit val running_provers : unit -> unit @@ -102,11 +102,16 @@ "Async_Manager". *) val das_Tool = "Sledgehammer" +val unremotify = perhaps (try (unprefix remote_prefix)) + +val select_smt_solver = + Context.proof_map o SMT_Config.select_solver o unremotify + fun is_smt_prover ctxt name = let val smts = SMT_Solver.available_solvers_of ctxt in case try (unprefix remote_prefix) name of - SOME suffix => member (op =) smts suffix andalso - SMT_Solver.is_remotely_available ctxt suffix + SOME base => member (op =) smts base andalso + SMT_Solver.is_remotely_available ctxt base | NONE => member (op =) smts name end @@ -133,8 +138,7 @@ fun default_max_relevant_for_prover ctxt name = let val thy = ProofContext.theory_of ctxt in if is_smt_prover ctxt name then - SMT_Solver.default_max_relevant ctxt - (perhaps (try (unprefix remote_prefix)) name) + SMT_Solver.default_max_relevant ctxt (unremotify name) else #default_max_relevant (get_atp thy name) end @@ -146,9 +150,11 @@ @{const_name conj}, @{const_name disj}, @{const_name implies}, @{const_name HOL.eq}, @{const_name If}, @{const_name Let}] -fun is_built_in_const_for_prover ctxt name (s, T) args = - if is_smt_prover ctxt name then SMT_Builtin.is_builtin_ext ctxt (s, T) args - else member (op =) atp_irrelevant_consts s +fun is_built_in_const_for_prover ctxt name = + if is_smt_prover ctxt name then + ctxt |> select_smt_solver name |> SMT_Builtin.is_builtin_ext + else + K o member (op =) atp_irrelevant_consts o fst (* FUDGE *) val atp_relevance_fudge = @@ -230,14 +236,17 @@ datatype prover_fact = Untranslated_Fact of (string * locality) * thm | - ATP_Translated_Fact of translated_formula option * ((string * locality) * thm) + ATP_Translated_Fact of + translated_formula option * ((string * locality) * thm) | + SMT_Weighted_Fact of (string * locality) * (int option * thm) type prover_problem = {state: Proof.state, goal: thm, subgoal: int, subgoal_count: int, - facts: prover_fact list} + facts: prover_fact list, + smt_head: (string * locality) SMT_Solver.smt_filter_head_result option} type prover_result = {outcome: failure option, @@ -272,8 +281,12 @@ fun untranslated_fact (Untranslated_Fact p) = p | untranslated_fact (ATP_Translated_Fact (_, p)) = p -fun atp_translated_fact ctxt (Untranslated_Fact p) = translate_atp_fact ctxt p - | atp_translated_fact _ (ATP_Translated_Fact q) = q + | untranslated_fact (SMT_Weighted_Fact (info, (_, th))) = (info, th) +fun atp_translated_fact _ (ATP_Translated_Fact p) = p + | atp_translated_fact ctxt fact = + translate_atp_fact ctxt (untranslated_fact fact) +fun smt_weighted_fact (SMT_Weighted_Fact p) = p + | smt_weighted_fact fact = untranslated_fact fact |> apsnd (pair NONE) fun int_opt_add (SOME m) (SOME n) = SOME (m + n) | int_opt_add _ _ = NONE @@ -471,9 +484,26 @@ val smt_iter_min_msecs = Unsynchronized.ref 5000 val smt_monomorph_limit = Unsynchronized.ref 4 -fun smt_filter_loop ({debug, verbose, timeout, ...} : params) remote state i = +fun smt_filter_loop name ({debug, verbose, overlord, timeout, ...} : params) + state i smt_head = let val ctxt = Proof.context_of state + val (remote, base) = + case try (unprefix remote_prefix) name of + SOME base => (true, base) + | NONE => (false, name) + val repair_context = + select_smt_solver base + #> Config.put SMT_Config.verbose debug + #> (if overlord then + Config.put SMT_Config.debug_files + (overlord_file_location_for_prover name + |> (fn (path, base) => path ^ "/" ^ base)) + else + I) + #> Config.put SMT_Config.monomorph_limit (!smt_monomorph_limit) + val state = state |> Proof.map_context repair_context + fun iter timeout iter_num outcome0 time_so_far facts = let val timer = Timer.startRealTimer () @@ -498,7 +528,9 @@ val _ = if debug then Output.urgent_message "Invoking SMT solver..." else () val (outcome, used_facts) = - SMT_Solver.smt_filter_head state facts i + (case (iter_num, smt_head) of + (1, SOME head) => head |> apsnd (apfst (apsnd repair_context)) + | _ => SMT_Solver.smt_filter_head state facts i) |> SMT_Solver.smt_filter_tail iter_timeout remote |> (fn {outcome, used_facts} => (outcome, used_facts)) handle exn => if Exn.is_interrupt exn then @@ -565,53 +597,14 @@ (Config.put Metis_Tactics.verbose debug #> (fn ctxt => Metis_Tactics.metis_tac ctxt ths)) state i -val smt_weights = Unsynchronized.ref true -val smt_weight_min_facts = 20 - -(* FUDGE *) -val smt_min_weight = Unsynchronized.ref 0 -val smt_max_weight = Unsynchronized.ref 10 -val smt_max_index = Unsynchronized.ref 200 -val smt_weight_curve = Unsynchronized.ref (fn x : int => x * x) - -fun smt_fact_weight j num_facts = - if !smt_weights andalso num_facts >= smt_weight_min_facts then - SOME (!smt_max_weight - - (!smt_max_weight - !smt_min_weight + 1) - * !smt_weight_curve (Int.max (0, !smt_max_index - j - 1)) - div !smt_weight_curve (!smt_max_index)) - else - NONE - -fun run_smt_solver auto name (params as {debug, verbose, overlord, ...}) - minimize_command - ({state, subgoal, subgoal_count, facts, ...} : prover_problem) = +fun run_smt_solver auto name (params as {debug, verbose, ...}) minimize_command + ({state, subgoal, subgoal_count, facts, smt_head, ...} + : prover_problem) = let - val (remote, suffix) = - case try (unprefix remote_prefix) name of - SOME suffix => (true, suffix) - | NONE => (false, name) - val repair_context = - Context.proof_map (SMT_Config.select_solver suffix) - #> Config.put SMT_Config.verbose debug - #> (if overlord then - Config.put SMT_Config.debug_files - (overlord_file_location_for_prover name - |> (fn (path, base) => path ^ "/" ^ base)) - else - I) - #> Config.put SMT_Config.monomorph_limit (!smt_monomorph_limit) - val state = state |> Proof.map_context repair_context - val thy = Proof.theory_of state - val num_facts = length facts - val facts = - facts ~~ (0 upto num_facts - 1) - |> map (fn (fact, j) => - fact |> untranslated_fact - |> apsnd (pair (smt_fact_weight j num_facts) - o Thm.transfer thy)) + val ctxt = Proof.context_of state + val facts = facts |> map smt_weighted_fact val {outcome, used_facts, run_time_in_msecs} = - smt_filter_loop params remote state subgoal facts + smt_filter_loop name params state subgoal smt_head facts val (chained_lemmas, other_lemmas) = split_used_facts (map fst used_facts) val outcome = outcome |> Option.map failure_from_smt_failure val message = @@ -622,8 +615,10 @@ if can_apply_metis debug state subgoal (map snd used_facts) then ("metis", "") else - ("smt", if suffix = SMT_Config.solver_of @{context} then "" - else "smt_solver = " ^ maybe_quote suffix) + let val base = unremotify name in + ("smt", if base = SMT_Config.solver_of ctxt then "" + else "smt_solver = " ^ maybe_quote base) + end in try_command_line (proof_banner auto) (apply_on_subgoal settings subgoal subgoal_count ^ diff -r 7d07736aaaf6 -r 8edeb1dbbc76 src/HOL/Tools/Sledgehammer/sledgehammer_run.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML Fri Dec 17 12:10:08 2010 +0100 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML Fri Dec 17 15:30:43 2010 +0100 @@ -43,7 +43,7 @@ fun run_prover (params as {debug, blocking, max_relevant, timeout, expect, ...}) auto minimize_command only - {state, goal, subgoal, subgoal_count, facts} name = + {state, goal, subgoal, subgoal_count, facts, smt_head} name = let val ctxt = Proof.context_of state val birth_time = Time.now () @@ -56,7 +56,8 @@ val prover = get_prover ctxt auto name val problem = {state = state, goal = goal, subgoal = subgoal, - subgoal_count = subgoal_count, facts = take num_facts facts} + subgoal_count = subgoal_count, facts = take num_facts facts, + smt_head = smt_head} fun go () = let fun really_go () = @@ -115,6 +116,36 @@ (false, state)) end +val smt_weights = Unsynchronized.ref true +val smt_weight_min_facts = 20 + +(* FUDGE *) +val smt_min_weight = Unsynchronized.ref 0 +val smt_max_weight = Unsynchronized.ref 10 +val smt_max_index = Unsynchronized.ref 200 +val smt_weight_curve = Unsynchronized.ref (fn x : int => x * x) + +fun smt_fact_weight j num_facts = + if !smt_weights andalso num_facts >= smt_weight_min_facts then + SOME (!smt_max_weight + - (!smt_max_weight - !smt_min_weight + 1) + * !smt_weight_curve (Int.max (0, !smt_max_index - j - 1)) + div !smt_weight_curve (!smt_max_index)) + else + NONE + +fun weight_smt_fact thy num_facts (fact, j) = + fact |> apsnd (pair (smt_fact_weight j num_facts) o Thm.transfer thy) + +fun class_of_smt_solver ctxt name = + ctxt |> select_smt_solver name + |> SMT_Config.solver_class_of |> SMT_Utils.string_of_class + +(* Makes backtraces more transparent and might be more efficient as well. *) +fun smart_par_list_map _ [] = [] + | smart_par_list_map f [x] = [f x] + | smart_par_list_map f xs = Par_List.map f xs + (* FUDGE *) val auto_max_relevant_divisor = 2 @@ -129,7 +160,10 @@ | n => let val _ = Proof.assert_backward state + val state = + state |> Proof.map_context (Config.put SMT_Config.verbose debug) val ctxt = Proof.context_of state + val thy = ProofContext.theory_of ctxt val {facts = chained_ths, goal, ...} = Proof.goal state val (_, hyp_ts, concl_t) = strip_subgoal goal i val no_dangerous_types = types_dangerous_types type_sys @@ -139,60 +173,83 @@ | NONE => () val _ = if auto then () else Output.urgent_message "Sledgehammering..." val (smts, atps) = provers |> List.partition (is_smt_prover ctxt) - fun run_provers label no_dangerous_types relevance_fudge maybe_translate - provers (res as (success, state)) = + fun run_provers get_facts translate maybe_smt_head provers + (res as (success, state)) = if success orelse null provers then res else let - val max_max_relevant = - case max_relevant of - SOME n => n - | NONE => - 0 |> fold (Integer.max o default_max_relevant_for_prover ctxt) - provers - |> auto ? (fn n => n div auto_max_relevant_divisor) - val is_built_in_const = - is_built_in_const_for_prover ctxt (hd provers) - val facts = - relevant_facts ctxt no_dangerous_types relevance_thresholds - max_max_relevant is_built_in_const relevance_fudge - relevance_override chained_ths hyp_ts concl_t - |> map maybe_translate + val facts = get_facts () + val num_facts = length facts + val facts = facts ~~ (0 upto num_facts - 1) + |> map (translate num_facts) val problem = {state = state, goal = goal, subgoal = i, subgoal_count = n, - facts = facts} + facts = facts, + smt_head = maybe_smt_head (map smt_weighted_fact facts) i} val run_prover = run_prover params auto minimize_command only in - if debug then - Output.urgent_message (label ^ plural_s (length provers) ^ ": " ^ - (if null facts then - "Found no relevant facts." - else - "Including (up to) " ^ string_of_int (length facts) ^ - " relevant fact" ^ plural_s (length facts) ^ ":\n" ^ - (facts |> map (fst o fst o untranslated_fact) - |> space_implode " ") ^ ".")) - else - (); if auto then fold (fn prover => fn (true, state) => (true, state) | (false, _) => run_prover problem prover) provers (false, state) else provers - |> (if blocking andalso length provers > 1 then Par_List.map - else map) + |> (if blocking then smart_par_list_map else map) (run_prover problem) |> exists fst |> rpair state end + fun get_facts label no_dangerous_types relevance_fudge provers = + let + val max_max_relevant = + case max_relevant of + SOME n => n + | NONE => + 0 |> fold (Integer.max o default_max_relevant_for_prover ctxt) + provers + |> auto ? (fn n => n div auto_max_relevant_divisor) + val is_built_in_const = + is_built_in_const_for_prover ctxt (hd provers) + in + relevant_facts ctxt no_dangerous_types relevance_thresholds + max_max_relevant is_built_in_const relevance_fudge + relevance_override chained_ths hyp_ts concl_t + |> tap (fn facts => + if debug then + label ^ plural_s (length provers) ^ ": " ^ + (if null facts then + "Found no relevant facts." + else + "Including (up to) " ^ string_of_int (length facts) ^ + " relevant fact" ^ plural_s (length facts) ^ ":\n" ^ + (facts |> map (fst o fst) |> space_implode " ") ^ ".") + |> Output.urgent_message + else + ()) + end val run_atps = - run_provers "ATP" no_dangerous_types atp_relevance_fudge - (ATP_Translated_Fact o translate_atp_fact ctxt) atps - val run_smts = - run_provers "SMT solver" true smt_relevance_fudge Untranslated_Fact smts + run_provers + (get_facts "ATP" no_dangerous_types atp_relevance_fudge o K atps) + (ATP_Translated_Fact oo K (translate_atp_fact ctxt o fst)) + (K (K NONE)) atps + fun run_smts (accum as (success, _)) = + if success orelse null smts then + accum + else + let + val facts = get_facts "SMT solver" true smt_relevance_fudge smts + val translate = SMT_Weighted_Fact oo weight_smt_fact thy + val maybe_smt_head = try o SMT_Solver.smt_filter_head state + in + smts |> map (`(class_of_smt_solver ctxt)) + |> AList.group (op =) + |> map (fn (_, smts) => run_provers (K facts) translate + maybe_smt_head smts accum) + |> exists fst |> rpair state + end fun run_atps_and_smt_solvers () = - [run_atps, run_smts] |> Par_List.map (fn f => f (false, state) |> K ()) + [run_atps, run_smts] + |> smart_par_list_map (fn f => f (false, state) |> K ()) handle ERROR msg => (Output.urgent_message ("Error: " ^ msg); error msg) in (false, state)