src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML
author blanchet
Fri May 16 19:13:50 2014 +0200 (2014-05-16)
changeset 56981 3ef45ce002b5
parent 56303 4cc3f4db3447
child 56983 132142089ea6
permissions -rw-r--r--
honor original format of conjecture or hypotheses in Z3-to-Isar proofs
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML
     2     Author:     Fabian Immler, TU Muenchen
     3     Author:     Makarius
     4     Author:     Jasmin Blanchette, TU Muenchen
     5 
     6 SMT solvers as Sledgehammer provers.
     7 *)
     8 
(* Interface of the SMT-solver backend of Sledgehammer: configuration options plus the two entry
   points used to recognize and to run an SMT solver as a prover. *)
signature SLEDGEHAMMER_PROVER_SMT2 =
sig
  (* abbreviations for types declared in other structures *)
  type stature = ATP_Problem_Generate.stature
  type mode = Sledgehammer_Prover.mode
  type prover = Sledgehammer_Prover.prover

  (* switches and fudge factors controlling how facts are passed to the solver *)
  val smt2_builtins : bool Config.T
  val smt2_triggers : bool Config.T
  val smt2_weights : bool Config.T
  val smt2_weight_min_facts : int Config.T
  val smt2_min_weight : int Config.T
  val smt2_max_weight : int Config.T
  val smt2_max_weight_index : int Config.T
  (* mutable reference (not a context option) holding the curve used to compute fact weights *)
  val smt2_weight_curve : (int -> int) Unsynchronized.ref
  (* fudge factors controlling time slicing *)
  val smt2_max_slices : int Config.T
  val smt2_slice_fact_frac : real Config.T
  val smt2_slice_time_frac : real Config.T
  val smt2_slice_min_secs : int Config.T

  (* checks whether the given name is among the solvers available in the context *)
  val is_smt2_prover : Proof.context -> string -> bool
  (* builds a Sledgehammer prover from a mode and a solver name *)
  val run_smt2_solver : mode -> string -> prover
end;
    31 
    32 structure Sledgehammer_Prover_SMT2 : SLEDGEHAMMER_PROVER_SMT2 =
    33 struct
    34 
    35 open ATP_Util
    36 open ATP_Proof
    37 open ATP_Systems
    38 open ATP_Problem_Generate
    39 open ATP_Proof_Reconstruct
    40 open Sledgehammer_Util
    41 open Sledgehammer_Proof_Methods
    42 open Sledgehammer_Isar
    43 open Sledgehammer_Prover
    44 
    45 val smt2_builtins = Attrib.setup_config_bool @{binding sledgehammer_smt2_builtins} (K true)
    46 val smt2_triggers = Attrib.setup_config_bool @{binding sledgehammer_smt2_triggers} (K true)
    47 val smt2_weights = Attrib.setup_config_bool @{binding sledgehammer_smt2_weights} (K true)
    48 val smt2_weight_min_facts =
    49   Attrib.setup_config_int @{binding sledgehammer_smt2_weight_min_facts} (K 20)
    50 
    51 val is_smt2_prover = member (op =) o SMT2_Config.available_solvers_of
    52 
    53 (* FUDGE *)
    54 val smt2_min_weight = Attrib.setup_config_int @{binding sledgehammer_smt2_min_weight} (K 0)
    55 val smt2_max_weight = Attrib.setup_config_int @{binding sledgehammer_smt2_max_weight} (K 10)
    56 val smt2_max_weight_index =
    57   Attrib.setup_config_int @{binding sledgehammer_smt2_max_weight_index} (K 200)
    58 val smt2_weight_curve = Unsynchronized.ref (fn x : int => x * x)
    59 
    60 fun smt2_fact_weight ctxt j num_facts =
    61   if Config.get ctxt smt2_weights andalso num_facts >= Config.get ctxt smt2_weight_min_facts then
    62     let
    63       val min = Config.get ctxt smt2_min_weight
    64       val max = Config.get ctxt smt2_max_weight
    65       val max_index = Config.get ctxt smt2_max_weight_index
    66       val curve = !smt2_weight_curve
    67     in
    68       SOME (max - (max - min + 1) * curve (Int.max (0, max_index - j - 1)) div curve max_index)
    69     end
    70   else
    71     NONE
    72 
    73 fun weight_smt2_fact ctxt num_facts ((info, th), j) =
    74   let val thy = Proof_Context.theory_of ctxt in
    75     (info, (smt2_fact_weight ctxt j num_facts, Thm.transfer thy th (* TODO: needed? *)))
    76   end
    77 
    78 (* "SMT2_Failure.Abnormal_Termination" carries the solver's return code. Until these are sorted out
    79    properly in the SMT module, we must interpret these here. *)
    80 val z3_failures =
    81   [(101, OutOfResources),
    82    (103, MalformedInput),
    83    (110, MalformedInput),
    84    (112, TimedOut)]
    85 val unix_failures =
    86   [(138, Crashed),
    87    (139, Crashed)]
    88 val smt2_failures = z3_failures @ unix_failures
    89 
    90 fun failure_of_smt2_failure (SMT2_Failure.Counterexample {is_real_cex, ...}) =
    91     if is_real_cex then Unprovable else GaveUp
    92   | failure_of_smt2_failure SMT2_Failure.Time_Out = TimedOut
    93   | failure_of_smt2_failure (SMT2_Failure.Abnormal_Termination code) =
    94     (case AList.lookup (op =) smt2_failures code of
    95       SOME failure => failure
    96     | NONE => UnknownError ("Abnormal termination with exit code " ^ string_of_int code ^ "."))
    97   | failure_of_smt2_failure SMT2_Failure.Out_Of_Memory = OutOfResources
    98   | failure_of_smt2_failure (SMT2_Failure.Other_Failure s) = UnknownError s
    99 
   100 (* FUDGE *)
   101 val smt2_max_slices = Attrib.setup_config_int @{binding sledgehammer_smt2_max_slices} (K 8)
   102 val smt2_slice_fact_frac =
   103   Attrib.setup_config_real @{binding sledgehammer_smt2_slice_fact_frac} (K 0.667)
   104 val smt2_slice_time_frac =
   105   Attrib.setup_config_real @{binding sledgehammer_smt2_slice_time_frac} (K 0.333)
   106 val smt2_slice_min_secs = Attrib.setup_config_int @{binding sledgehammer_smt2_slice_min_secs} (K 3)
   107 
   108 val is_boring_builtin_typ =
   109   not o exists_subtype (member (op =) [@{typ nat}, @{typ int}, HOLogic.realT])
   110 
(* Runs the SMT solver "name" on subgoal "i" of "goal", possibly several times ("slices") with
   fewer and fewer facts. The result is a function that still expects the nonempty list of
   (fact filter, weighted facts) pairs; it returns a record with the outcome, the raw filter
   result of the last slice, the facts used in the last slice, and the accumulated solver time. *)
fun smt2_filter_loop name ({debug, overlord, max_mono_iters, max_new_mono_instances, timeout, slice,
      ...} : params) state goal i =
  let
    (* Configures the context for this run: selects the solver, propagates the debug and overlord
       settings, the trigger option, and the monomorphization limits. If "smt2_builtins" is
       disabled, builtins are filtered by "is_boring_builtin_typ" and Z3 extensions are
       switched off. *)
    fun repair_context ctxt =
      ctxt |> Context.proof_map (SMT2_Config.select_solver name)
           |> Config.put SMT2_Config.verbose debug
           |> (if overlord then
                 Config.put SMT2_Config.debug_files
                   (overlord_file_location_of_prover name |> (fn (path, name) => path ^ "/" ^ name))
               else
                 I)
           |> Config.put SMT2_Config.infer_triggers (Config.get ctxt smt2_triggers)
           |> not (Config.get ctxt smt2_builtins)
              ? (SMT2_Builtin.filter_builtins is_boring_builtin_typ
                 #> Config.put SMT2_Systems.z3_extensions false)
           |> repair_monomorph_context max_mono_iters default_max_mono_iters max_new_mono_instances
                default_max_new_mono_instances

    val state = Proof.map_context (repair_context) state
    val ctxt = Proof.context_of state
    (* without slicing, exactly one solver run is performed *)
    val max_slices = if slice then Config.get ctxt smt2_max_slices else 1

    (* One solver run. "timeout" is the time still available, "slice" the 1-based slice number,
       "outcome0" the outcome of the first slice (NONE before the first run), and "time_so_far"
       the solver time accumulated by earlier slices. The head of "weighted_factss" supplies the
       facts for this run. *)
    fun do_slice timeout slice outcome0 time_so_far
        (weighted_factss as (fact_filter, weighted_facts) :: _) =
      let
        val timer = Timer.startRealTimer ()
        (* A non-final slice gets a fraction of the available time, but at least
           "smt2_slice_min_secs" seconds and never more than what is available; the final slice
           gets all of it. *)
        val slice_timeout =
          if slice < max_slices then
            let val ms = Time.toMilliseconds timeout in
              Int.min (ms, Int.max (1000 * Config.get ctxt smt2_slice_min_secs,
                Real.ceil (Config.get ctxt smt2_slice_time_frac * Real.fromInt ms)))
              |> Time.fromMilliseconds
            end
          else
            timeout
        val num_facts = length weighted_facts
        val _ =
          if debug then
            quote name ^ " slice " ^ string_of_int slice ^ " with " ^ string_of_int num_facts ^
            " fact" ^ plural_s num_facts ^ " for " ^ string_of_time slice_timeout
            |> Output.urgent_message
          else
            ()
        val birth = Timer.checkRealTimer timer
        val _ = if debug then Output.urgent_message "Invoking SMT solver..." else ()

        (* Interrupts are always reraised, and so is any exception in debug mode. All other
           exceptions are turned into an "Other_Failure" outcome with otherwise empty fields. *)
        val filter_result as {outcome, ...} =
          SMT2_Solver.smt2_filter ctxt goal weighted_facts i slice_timeout
          handle exn =>
            if Exn.is_interrupt exn orelse debug then
              reraise exn
            else
              {outcome = SOME (SMT2_Failure.Other_Failure (Runtime.exn_message exn)),
               rewrite_rules = [], conjecture_id = ~1, prem_ids = [], helper_ids = [],
               fact_ids = [], z3_proof = []}

        val death = Timer.checkRealTimer timer
        (* remember the outcome of the very first slice; later slices leave it unchanged *)
        val outcome0 = if is_none outcome0 then SOME outcome else outcome0
        (* only the interval between "birth" and "death" counts as solver time *)
        val time_so_far = Time.+ (time_so_far, Time.- (death, birth))
        (* from here on, "timeout" denotes the time remaining after this slice *)
        val timeout = Time.- (timeout, Timer.checkRealTimer timer)

        (* Heuristic: could the failure be due to an excessive number of facts? *)
        val too_many_facts_perhaps =
          (case outcome of
            NONE => false
          | SOME (SMT2_Failure.Counterexample _) => false
          (* NOTE(review): "timeout" below is the remaining time rebound just above, not the time
             that was available when the slice started -- confirm that this is intended. *)
          | SOME SMT2_Failure.Time_Out => slice_timeout <> timeout
          | SOME (SMT2_Failure.Abnormal_Termination _) => true (* kind of *)
          | SOME SMT2_Failure.Out_Of_Memory => true
          | SOME (SMT2_Failure.Other_Failure _) => true)
      in
        (* retry only if slices, facts, and time are left *)
        if too_many_facts_perhaps andalso slice < max_slices andalso num_facts > 0 andalso
           Time.> (timeout, Time.zeroTime) then
          let
            val new_num_facts =
              Real.ceil (Config.get ctxt smt2_slice_fact_frac * Real.fromInt num_facts)
            (* rotate the list so that the next entry becomes the head, then truncate the new
               head to "new_num_facts" facts *)
            val weighted_factss as (new_fact_filter, _) :: _ =
              weighted_factss
              |> (fn (x :: xs) => xs @ [x])
              |> app_hd (apsnd (take new_num_facts))
            (* mention the fact filters in the debug message only if they differ *)
            val show_filter = fact_filter <> new_fact_filter

            fun num_of_facts fact_filter num_facts =
              string_of_int num_facts ^ (if show_filter then " " ^ quote fact_filter else "") ^
              " fact" ^ plural_s num_facts

            val _ =
              if debug then
                quote name ^ " invoked with " ^
                num_of_facts fact_filter num_facts ^ ": " ^
                string_of_atp_failure (failure_of_smt2_failure (the outcome)) ^
                " Retrying with " ^ num_of_facts new_fact_filter new_num_facts ^
                "..."
                |> Output.urgent_message
              else
                ()
          in
            do_slice timeout (slice + 1) outcome0 time_so_far weighted_factss
          end
        else
          (* success of this slice yields NONE; otherwise the outcome of the first slice is
             reported *)
          {outcome = if is_none outcome then NONE else the outcome0, filter_result = filter_result,
           used_from = map (apsnd snd) weighted_facts, run_time = time_so_far}
      end
  in
    (* start with the full timeout, slice 1, no previous outcome, and no time spent *)
    do_slice timeout 1 NONE Time.zeroTime
  end
   216 
(* Runs the SMT solver "name" as a Sledgehammer prover on the given problem: weights the facts,
   invokes "smt2_filter_loop", translates the outcome, and assembles the prover result, including
   a lazily computed preplay and a function producing the message shown to the user. *)
fun run_smt2_solver mode name (params as {debug, verbose, isar_proofs, compress_isar,
      try0_isar, smt_proofs, minimize, preplay_timeout, ...})
    minimize_command ({state, goal, subgoal, subgoal_count, factss, ...} : prover_problem) =
  let
    val thy = Proof.theory_of state
    val ctxt = Proof.context_of state

    (* attaches to each fact its 0-based position and derives its weight from it *)
    fun weight_facts facts =
      let val num_facts = length facts in
        map (weight_smt2_fact ctxt num_facts) (facts ~~ (0 upto num_facts - 1))
      end

    val weighted_factss = map (apsnd weight_facts) factss
    val {outcome, filter_result = {rewrite_rules, conjecture_id, prem_ids, helper_ids, fact_ids,
           z3_proof, ...}, used_from, run_time} =
      smt2_filter_loop name params state goal subgoal weighted_factss
    val used_named_facts = map snd fact_ids
    val used_facts = map fst used_named_facts
    (* from here on, "outcome" is an optional ATP failure; NONE means success *)
    val outcome = Option.map failure_of_smt2_failure outcome

    val (preplay, message, message_tail) =
      (case outcome of
        NONE =>
        (* success: the one-line proof is replayed lazily *)
        (Lazy.lazy (fn () =>
           play_one_line_proof mode debug verbose preplay_timeout used_named_facts state subgoal
             SMT2_Method (bunch_of_proof_methods (smt_proofs <> SOME false) false liftingN)),
         fn preplay =>
            let
              (* this binding shadows the "fact_ids" of the filter result, which is still what
                 the right-hand side refers to; the new list pairs ids with names of helpers
                 and facts *)
              val fact_ids =
                map (fn (id, th) => (id, short_thm_name ctxt th)) helper_ids @
                map (fn (id, ((name, _), _)) => (id, name)) fact_ids
              val atp_proof = Z3_New_Isar.atp_proof_of_z3_proof thy rewrite_rules prem_ids conjecture_id
                fact_ids z3_proof
              val isar_params =
                K (verbose, (NONE, NONE), preplay_timeout, compress_isar, try0_isar,
                   minimize <> SOME false, atp_proof, goal)
              val one_line_params =
                (preplay, proof_banner mode name, used_facts,
                 choose_minimize_command thy params minimize_command name preplay, subgoal,
                 subgoal_count)
              val num_chained = length (#facts (Proof.goal state))
            in
              proof_text ctxt debug isar_proofs smt_proofs isar_params num_chained one_line_params
            end,
         (* the accumulated solver time is reported only in verbose mode *)
         if verbose then "\nSMT solver real CPU time: " ^ string_of_time run_time ^ "." else "")
      | SOME failure =>
        (* failure: a fixed "Play_Failed" preplay value, and the failure text as message *)
        (Lazy.value (Metis_Method (NONE, NONE), Play_Failed),
         fn _ => string_of_atp_failure failure, ""))
  in
    {outcome = outcome, used_facts = used_facts, used_from = used_from, run_time = run_time,
     preplay = preplay, message = message, message_tail = message_tail}
  end
   269 
   270 end;