src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML
author blanchet
Thu Mar 13 14:48:20 2014 +0100 (2014-03-13 ago)
changeset 56104 fd6e132ee4fb
parent 56099 bc036c1cf111
child 56128 c106ac2ff76d
permissions -rw-r--r--
correctly reconstruct helper facts (e.g. 'nat_int') in Isar proofs
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML
     2     Author:     Fabian Immler, TU Muenchen
     3     Author:     Makarius
     4     Author:     Jasmin Blanchette, TU Muenchen
     5 
     6 SMT solvers as Sledgehammer provers.
     7 *)
     8 
     9 signature SLEDGEHAMMER_PROVER_SMT2 =
    10 sig
    11   type stature = ATP_Problem_Generate.stature
    12   type mode = Sledgehammer_Prover.mode
    13   type prover = Sledgehammer_Prover.prover
    14 
    15   val smt2_builtins : bool Config.T
    16   val smt2_triggers : bool Config.T
    17   val smt2_weights : bool Config.T
    18   val smt2_weight_min_facts : int Config.T
    19   val smt2_min_weight : int Config.T
    20   val smt2_max_weight : int Config.T
    21   val smt2_max_weight_index : int Config.T
    22   val smt2_weight_curve : (int -> int) Unsynchronized.ref
    23   val smt2_max_slices : int Config.T
    24   val smt2_slice_fact_frac : real Config.T
    25   val smt2_slice_time_frac : real Config.T
    26   val smt2_slice_min_secs : int Config.T
    27 
    28   val is_smt2_prover : Proof.context -> string -> bool
    29   val run_smt2_solver : mode -> string -> prover
    30 end;
    31 
    32 structure Sledgehammer_Prover_SMT2 : SLEDGEHAMMER_PROVER_SMT2 =
    33 struct
    34 
    35 open ATP_Util
    36 open ATP_Proof
    37 open ATP_Systems
    38 open ATP_Problem_Generate
    39 open ATP_Proof_Reconstruct
    40 open Sledgehammer_Util
    41 open Sledgehammer_Proof_Methods
    42 open Sledgehammer_Isar
    43 open Sledgehammer_Prover
    44 
    45 val smt2_builtins = Attrib.setup_config_bool @{binding sledgehammer_smt2_builtins} (K true)
    46 val smt2_triggers = Attrib.setup_config_bool @{binding sledgehammer_smt2_triggers} (K true)
    47 val smt2_weights = Attrib.setup_config_bool @{binding sledgehammer_smt2_weights} (K true)
    48 val smt2_weight_min_facts =
    49   Attrib.setup_config_int @{binding sledgehammer_smt2_weight_min_facts} (K 20)
    50 
    51 fun is_smt2_prover ctxt = member (op =) (SMT2_Solver.available_solvers_of ctxt)
    52 
    53 (* FUDGE *)
    54 val smt2_min_weight = Attrib.setup_config_int @{binding sledgehammer_smt2_min_weight} (K 0)
    55 val smt2_max_weight = Attrib.setup_config_int @{binding sledgehammer_smt2_max_weight} (K 10)
    56 val smt2_max_weight_index =
    57   Attrib.setup_config_int @{binding sledgehammer_smt2_max_weight_index} (K 200)
    58 val smt2_weight_curve = Unsynchronized.ref (fn x : int => x * x)
    59 
    60 fun smt2_fact_weight ctxt j num_facts =
    61   if Config.get ctxt smt2_weights andalso num_facts >= Config.get ctxt smt2_weight_min_facts then
    62     let
    63       val min = Config.get ctxt smt2_min_weight
    64       val max = Config.get ctxt smt2_max_weight
    65       val max_index = Config.get ctxt smt2_max_weight_index
    66       val curve = !smt2_weight_curve
    67     in
    68       SOME (max - (max - min + 1) * curve (Int.max (0, max_index - j - 1)) div curve max_index)
    69     end
    70   else
    71     NONE
    72 
    73 fun weight_smt2_fact ctxt num_facts ((info, th), j) =
    74   let val thy = Proof_Context.theory_of ctxt in
    75     (info, (smt2_fact_weight ctxt j num_facts, Thm.transfer thy th (* TODO: needed? *)))
    76   end
    77 
    78 (* "SMT2_Failure.Abnormal_Termination" carries the solver's return code. Until these are sorted out
    79    properly in the SMT module, we must interpret these here. *)
    80 val z3_failures =
    81   [(101, OutOfResources),
    82    (103, MalformedInput),
    83    (110, MalformedInput),
    84    (112, TimedOut)]
    85 val unix_failures =
    86   [(138, Crashed),
    87    (139, Crashed)]
    88 val smt2_failures = z3_failures @ unix_failures
    89 
    90 fun failure_of_smt2_failure (SMT2_Failure.Counterexample {is_real_cex, ...}) =
    91     if is_real_cex then Unprovable else GaveUp
    92   | failure_of_smt2_failure SMT2_Failure.Time_Out = TimedOut
    93   | failure_of_smt2_failure (SMT2_Failure.Abnormal_Termination code) =
    94     (case AList.lookup (op =) smt2_failures code of
    95       SOME failure => failure
    96     | NONE => UnknownError ("Abnormal termination with exit code " ^ string_of_int code ^ "."))
    97   | failure_of_smt2_failure SMT2_Failure.Out_Of_Memory = OutOfResources
    98   | failure_of_smt2_failure (SMT2_Failure.Other_Failure s) = UnknownError s
    99 
   100 (* FUDGE *)
   101 val smt2_max_slices = Attrib.setup_config_int @{binding sledgehammer_smt2_max_slices} (K 8)
   102 val smt2_slice_fact_frac =
   103   Attrib.setup_config_real @{binding sledgehammer_smt2_slice_fact_frac} (K 0.667)
   104 val smt2_slice_time_frac =
   105   Attrib.setup_config_real @{binding sledgehammer_smt2_slice_time_frac} (K 0.333)
   106 val smt2_slice_min_secs = Attrib.setup_config_int @{binding sledgehammer_smt2_slice_min_secs} (K 3)
   107 
   108 val is_boring_builtin_typ =
   109   not o exists_subtype (member (op =) [@{typ nat}, @{typ int}, HOLogic.realT])
   110 
   111 fun smt2_filter_loop name ({debug, overlord, max_mono_iters, max_new_mono_instances, timeout, slice,
   112       ...} : params) state goal i =
   113   let
   114     fun repair_context ctxt =
   115       ctxt |> Context.proof_map (SMT2_Config.select_solver name)
   116            |> Config.put SMT2_Config.verbose debug
   117            |> (if overlord then
   118                  Config.put SMT2_Config.debug_files
   119                    (overlord_file_location_of_prover name |> (fn (path, name) => path ^ "/" ^ name))
   120                else
   121                  I)
   122            |> Config.put SMT2_Config.infer_triggers (Config.get ctxt smt2_triggers)
   123            |> not (Config.get ctxt smt2_builtins)
   124               ? (SMT2_Builtin.filter_builtins is_boring_builtin_typ
   125                  #> Config.put SMT2_Systems.z3_extensions false)
   126            |> repair_monomorph_context max_mono_iters default_max_mono_iters max_new_mono_instances
   127                 default_max_new_mono_instances
   128 
   129     val state = Proof.map_context (repair_context) state
   130     val ctxt = Proof.context_of state
   131     val max_slices = if slice then Config.get ctxt smt2_max_slices else 1
   132 
   133     fun do_slice timeout slice outcome0 time_so_far
   134         (weighted_factss as (fact_filter, weighted_facts) :: _) =
   135       let
   136         val timer = Timer.startRealTimer ()
   137         val slice_timeout =
   138           if slice < max_slices then
   139             let val ms = Time.toMilliseconds timeout in
   140               Int.min (ms, Int.max (1000 * Config.get ctxt smt2_slice_min_secs,
   141                 Real.ceil (Config.get ctxt smt2_slice_time_frac * Real.fromInt ms)))
   142               |> Time.fromMilliseconds
   143             end
   144           else
   145             timeout
   146         val num_facts = length weighted_facts
   147         val _ =
   148           if debug then
   149             quote name ^ " slice " ^ string_of_int slice ^ " with " ^ string_of_int num_facts ^
   150             " fact" ^ plural_s num_facts ^ " for " ^ string_of_time slice_timeout
   151             |> Output.urgent_message
   152           else
   153             ()
   154         val birth = Timer.checkRealTimer timer
   155         val _ = if debug then Output.urgent_message "Invoking SMT solver..." else ()
   156 
   157         val filter_result as {outcome, ...} =
   158           SMT2_Solver.smt2_filter ctxt goal weighted_facts i slice_timeout
   159           handle exn =>
   160             if Exn.is_interrupt exn orelse debug then
   161               reraise exn
   162             else
   163               {outcome = SOME (SMT2_Failure.Other_Failure (ML_Compiler.exn_message exn)),
   164                conjecture_id = ~1, helper_ids = [], fact_ids = [], z3_proof = []}
   165 
   166         val death = Timer.checkRealTimer timer
   167         val outcome0 = if is_none outcome0 then SOME outcome else outcome0
   168         val time_so_far = Time.+ (time_so_far, Time.- (death, birth))
   169         val timeout = Time.- (timeout, Timer.checkRealTimer timer)
   170 
   171         val too_many_facts_perhaps =
   172           (case outcome of
   173             NONE => false
   174           | SOME (SMT2_Failure.Counterexample _) => false
   175           | SOME SMT2_Failure.Time_Out => slice_timeout <> timeout
   176           | SOME (SMT2_Failure.Abnormal_Termination _) => true (* kind of *)
   177           | SOME SMT2_Failure.Out_Of_Memory => true
   178           | SOME (SMT2_Failure.Other_Failure _) => true)
   179       in
   180         if too_many_facts_perhaps andalso slice < max_slices andalso num_facts > 0 andalso
   181            Time.> (timeout, Time.zeroTime) then
   182           let
   183             val new_num_facts =
   184               Real.ceil (Config.get ctxt smt2_slice_fact_frac * Real.fromInt num_facts)
   185             val weighted_factss as (new_fact_filter, _) :: _ =
   186               weighted_factss
   187               |> (fn (x :: xs) => xs @ [x])
   188               |> app_hd (apsnd (take new_num_facts))
   189             val show_filter = fact_filter <> new_fact_filter
   190 
   191             fun num_of_facts fact_filter num_facts =
   192               string_of_int num_facts ^ (if show_filter then " " ^ quote fact_filter else "") ^
   193               " fact" ^ plural_s num_facts
   194 
   195             val _ =
   196               if debug then
   197                 quote name ^ " invoked with " ^
   198                 num_of_facts fact_filter num_facts ^ ": " ^
   199                 string_of_atp_failure (failure_of_smt2_failure (the outcome)) ^
   200                 " Retrying with " ^ num_of_facts new_fact_filter new_num_facts ^
   201                 "..."
   202                 |> Output.urgent_message
   203               else
   204                 ()
   205           in
   206             do_slice timeout (slice + 1) outcome0 time_so_far weighted_factss
   207           end
   208         else
   209           {outcome = if is_none outcome then NONE else the outcome0, filter_result = filter_result,
   210            used_from = map (apsnd snd) weighted_facts, run_time = time_so_far}
   211       end
   212   in
   213     do_slice timeout 1 NONE Time.zeroTime
   214   end
   215 
   216 fun run_smt2_solver mode name (params as {debug, verbose, isar_proofs, compress_isar,
   217       try0_isar, smt_proofs, minimize, preplay_timeout, ...})
   218     minimize_command ({state, goal, subgoal, subgoal_count, factss, ...} : prover_problem) =
   219   let
   220     val thy = Proof.theory_of state
   221     val ctxt = Proof.context_of state
   222 
   223     fun weight_facts facts =
   224       let val num_facts = length facts in
   225         map (weight_smt2_fact ctxt num_facts) (facts ~~ (0 upto num_facts - 1))
   226       end
   227 
   228     val weighted_factss = map (apsnd weight_facts) factss
   229     val {outcome, filter_result = {conjecture_id, helper_ids, fact_ids, z3_proof, ...},
   230          used_from, run_time} = smt2_filter_loop name params state goal subgoal weighted_factss
   231     val used_named_facts = map snd fact_ids
   232     val used_facts = map fst used_named_facts
   233     val outcome = Option.map failure_of_smt2_failure outcome
   234 
   235     val (preplay, message, message_tail) =
   236       (case outcome of
   237         NONE =>
   238         (Lazy.lazy (fn () =>
   239            play_one_line_proof mode debug verbose preplay_timeout used_named_facts state subgoal
   240              SMT2_Method (bunch_of_proof_methods (smt_proofs <> SOME false) false liftingN)),
   241          fn preplay =>
   242             let
   243               val fact_ids =
   244                 map (fn (id, th) => (id, short_thm_name ctxt th)) helper_ids @
   245                 map (fn (id, ((name, _), _)) => (id, name)) fact_ids
   246               val atp_proof = Z3_New_Isar.atp_proof_of_z3_proof thy conjecture_id fact_ids z3_proof
   247               val isar_params =
   248                 K (verbose, (NONE, NONE), preplay_timeout, compress_isar, try0_isar,
   249                    minimize <> SOME false, atp_proof, goal)
   250               val one_line_params =
   251                 (preplay, proof_banner mode name, used_facts,
   252                  choose_minimize_command thy params minimize_command name preplay, subgoal,
   253                  subgoal_count)
   254               val num_chained = length (#facts (Proof.goal state))
   255             in
   256               proof_text ctxt debug isar_proofs smt_proofs isar_params num_chained one_line_params
   257             end,
   258          if verbose then "\nSMT solver real CPU time: " ^ string_of_time run_time ^ "." else "")
   259       | SOME failure =>
   260         (Lazy.value (Metis_Method (NONE, NONE), Play_Failed),
   261          fn _ => string_of_atp_failure failure, ""))
   262   in
   263     {outcome = outcome, used_facts = used_facts, used_from = used_from, run_time = run_time,
   264      preplay = preplay, message = message, message_tail = message_tail}
   265   end
   266 
   267 end;