src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML
changeset 57165 7b1bf424ec5f
parent 57159 24cbdebba35a
child 57243 8c261f0a9b32
equal deleted inserted replaced
57164:eb5f27ec3987 57165:7b1bf424ec5f
    12   type mode = Sledgehammer_Prover.mode
    12   type mode = Sledgehammer_Prover.mode
    13   type prover = Sledgehammer_Prover.prover
    13   type prover = Sledgehammer_Prover.prover
    14 
    14 
    15   val smt2_builtins : bool Config.T
    15   val smt2_builtins : bool Config.T
    16   val smt2_triggers : bool Config.T
    16   val smt2_triggers : bool Config.T
    17   val smt2_weights : bool Config.T
       
    18   val smt2_weight_min_facts : int Config.T
       
    19   val smt2_min_weight : int Config.T
       
    20   val smt2_max_weight : int Config.T
       
    21   val smt2_max_weight_index : int Config.T
       
    22   val smt2_weight_curve : (int -> int) Unsynchronized.ref
       
    23   val smt2_max_slices : int Config.T
    17   val smt2_max_slices : int Config.T
    24   val smt2_slice_fact_frac : real Config.T
    18   val smt2_slice_fact_frac : real Config.T
    25   val smt2_slice_time_frac : real Config.T
    19   val smt2_slice_time_frac : real Config.T
    26   val smt2_slice_min_secs : int Config.T
    20   val smt2_slice_min_secs : int Config.T
    27 
    21 
    42 open Sledgehammer_Isar
    36 open Sledgehammer_Isar
    43 open Sledgehammer_Prover
    37 open Sledgehammer_Prover
    44 
    38 
    45 val smt2_builtins = Attrib.setup_config_bool @{binding sledgehammer_smt2_builtins} (K true)
    39 val smt2_builtins = Attrib.setup_config_bool @{binding sledgehammer_smt2_builtins} (K true)
    46 val smt2_triggers = Attrib.setup_config_bool @{binding sledgehammer_smt2_triggers} (K true)
    40 val smt2_triggers = Attrib.setup_config_bool @{binding sledgehammer_smt2_triggers} (K true)
    47 val smt2_weights = Attrib.setup_config_bool @{binding sledgehammer_smt2_weights} (K true)
       
    48 val smt2_weight_min_facts =
       
    49   Attrib.setup_config_int @{binding sledgehammer_smt2_weight_min_facts} (K 20)
       
    50 
    41 
    51 val is_smt2_prover = member (op =) o SMT2_Config.available_solvers_of
    42 val is_smt2_prover = member (op =) o SMT2_Config.available_solvers_of
    52 
       
    53 (* FUDGE *)
       
    54 val smt2_min_weight = Attrib.setup_config_int @{binding sledgehammer_smt2_min_weight} (K 0)
       
    55 val smt2_max_weight = Attrib.setup_config_int @{binding sledgehammer_smt2_max_weight} (K 10)
       
    56 val smt2_max_weight_index =
       
    57   Attrib.setup_config_int @{binding sledgehammer_smt2_max_weight_index} (K 200)
       
    58 val smt2_weight_curve = Unsynchronized.ref (fn x : int => x * x)
       
    59 
       
    60 fun smt2_fact_weight ctxt j num_facts =
       
    61   if Config.get ctxt smt2_weights andalso num_facts >= Config.get ctxt smt2_weight_min_facts then
       
    62     let
       
    63       val min = Config.get ctxt smt2_min_weight
       
    64       val max = Config.get ctxt smt2_max_weight
       
    65       val max_index = Config.get ctxt smt2_max_weight_index
       
    66       val curve = !smt2_weight_curve
       
    67     in
       
    68       SOME (max - (max - min + 1) * curve (Int.max (0, max_index - j - 1)) div curve max_index)
       
    69     end
       
    70   else
       
    71     NONE
       
    72 
       
    73 fun weight_smt2_fact ctxt num_facts ((info, th), j) =
       
    74   (info, (smt2_fact_weight ctxt j num_facts, th))
       
    75 
    43 
    76 (* "SMT2_Failure.Abnormal_Termination" carries the solver's return code. Until these are sorted out
    44 (* "SMT2_Failure.Abnormal_Termination" carries the solver's return code. Until these are sorted out
    77    properly in the SMT module, we must interpret these here. *)
    45    properly in the SMT module, we must interpret these here. *)
    78 val z3_failures =
    46 val z3_failures =
    79   [(101, OutOfResources),
    47   [(101, OutOfResources),
   126 
    94 
   127     val state = Proof.map_context (repair_context) state
    95     val state = Proof.map_context (repair_context) state
   128     val ctxt = Proof.context_of state
    96     val ctxt = Proof.context_of state
   129     val max_slices = if slice then Config.get ctxt smt2_max_slices else 1
    97     val max_slices = if slice then Config.get ctxt smt2_max_slices else 1
   130 
    98 
   131     fun do_slice timeout slice outcome0 time_so_far
    99     fun do_slice timeout slice outcome0 time_so_far (factss as (fact_filter, facts) :: _) =
   132         (weighted_factss as (fact_filter, weighted_facts) :: _) =
       
   133       let
   100       let
   134         val timer = Timer.startRealTimer ()
   101         val timer = Timer.startRealTimer ()
   135         val slice_timeout =
   102         val slice_timeout =
   136           if slice < max_slices then
   103           if slice < max_slices then
   137             let val ms = Time.toMilliseconds timeout in
   104             let val ms = Time.toMilliseconds timeout in
   139                 Real.ceil (Config.get ctxt smt2_slice_time_frac * Real.fromInt ms)))
   106                 Real.ceil (Config.get ctxt smt2_slice_time_frac * Real.fromInt ms)))
   140               |> Time.fromMilliseconds
   107               |> Time.fromMilliseconds
   141             end
   108             end
   142           else
   109           else
   143             timeout
   110             timeout
   144         val num_facts = length weighted_facts
   111         val num_facts = length facts
   145         val _ =
   112         val _ =
   146           if debug then
   113           if debug then
   147             quote name ^ " slice " ^ string_of_int slice ^ " with " ^ string_of_int num_facts ^
   114             quote name ^ " slice " ^ string_of_int slice ^ " with " ^ string_of_int num_facts ^
   148             " fact" ^ plural_s num_facts ^ " for " ^ string_of_time slice_timeout
   115             " fact" ^ plural_s num_facts ^ " for " ^ string_of_time slice_timeout
   149             |> Output.urgent_message
   116             |> Output.urgent_message
   150           else
   117           else
   151             ()
   118             ()
   152         val birth = Timer.checkRealTimer timer
   119         val birth = Timer.checkRealTimer timer
   153 
   120 
   154         val filter_result as {outcome, ...} =
   121         val filter_result as {outcome, ...} =
   155           SMT2_Solver.smt2_filter ctxt goal weighted_facts i slice_timeout
   122           SMT2_Solver.smt2_filter ctxt goal facts i slice_timeout
   156           handle exn =>
   123           handle exn =>
   157             if Exn.is_interrupt exn orelse debug then
   124             if Exn.is_interrupt exn orelse debug then
   158               reraise exn
   125               reraise exn
   159             else
   126             else
   160               {outcome = SOME (SMT2_Failure.Other_Failure (Runtime.exn_message exn)),
   127               {outcome = SOME (SMT2_Failure.Other_Failure (Runtime.exn_message exn)),
   177         if too_many_facts_perhaps andalso slice < max_slices andalso num_facts > 0 andalso
   144         if too_many_facts_perhaps andalso slice < max_slices andalso num_facts > 0 andalso
   178            Time.> (timeout, Time.zeroTime) then
   145            Time.> (timeout, Time.zeroTime) then
   179           let
   146           let
   180             val new_num_facts =
   147             val new_num_facts =
   181               Real.ceil (Config.get ctxt smt2_slice_fact_frac * Real.fromInt num_facts)
   148               Real.ceil (Config.get ctxt smt2_slice_fact_frac * Real.fromInt num_facts)
   182             val weighted_factss as (new_fact_filter, _) :: _ =
   149             val factss as (new_fact_filter, _) :: _ =
   183               weighted_factss
   150               factss
   184               |> (fn (x :: xs) => xs @ [x])
   151               |> (fn (x :: xs) => xs @ [x])
   185               |> app_hd (apsnd (take new_num_facts))
   152               |> app_hd (apsnd (take new_num_facts))
   186             val show_filter = fact_filter <> new_fact_filter
   153             val show_filter = fact_filter <> new_fact_filter
   187 
   154 
   188             fun num_of_facts fact_filter num_facts =
   155             fun num_of_facts fact_filter num_facts =
   198                 "..."
   165                 "..."
   199                 |> Output.urgent_message
   166                 |> Output.urgent_message
   200               else
   167               else
   201                 ()
   168                 ()
   202           in
   169           in
   203             do_slice timeout (slice + 1) outcome0 time_so_far weighted_factss
   170             do_slice timeout (slice + 1) outcome0 time_so_far factss
   204           end
   171           end
   205         else
   172         else
   206           {outcome = if is_none outcome then NONE else the outcome0, filter_result = filter_result,
   173           {outcome = if is_none outcome then NONE else the outcome0, filter_result = filter_result,
   207            used_from = map (apsnd snd) weighted_facts, run_time = time_so_far}
   174            used_from = facts, run_time = time_so_far}
   208       end
   175       end
   209   in
   176   in
   210     do_slice timeout 1 NONE Time.zeroTime
   177     do_slice timeout 1 NONE Time.zeroTime
   211   end
   178   end
   212 
   179 
   215     minimize_command ({state, goal, subgoal, subgoal_count, factss, ...} : prover_problem) =
   182     minimize_command ({state, goal, subgoal, subgoal_count, factss, ...} : prover_problem) =
   216   let
   183   let
   217     val thy = Proof.theory_of state
   184     val thy = Proof.theory_of state
   218     val ctxt = Proof.context_of state
   185     val ctxt = Proof.context_of state
   219 
   186 
   220     val (_, hyp_ts, concl_t) = strip_subgoal goal subgoal ctxt
       
   221 
       
   222     fun weight_facts facts =
       
   223       let val num_facts = length facts in
       
   224         map (weight_smt2_fact ctxt num_facts) (facts ~~ (0 upto num_facts - 1))
       
   225       end
       
   226 
       
   227     val weighted_factss = map (apsnd weight_facts) factss
       
   228     val {outcome, filter_result = {fact_ids, atp_proof, ...}, used_from, run_time} =
   187     val {outcome, filter_result = {fact_ids, atp_proof, ...}, used_from, run_time} =
   229       smt2_filter_loop name params state goal subgoal weighted_factss
   188       smt2_filter_loop name params state goal subgoal factss
   230     val used_named_facts = map snd fact_ids
   189     val used_named_facts = map snd fact_ids
   231     val used_facts = map fst used_named_facts
   190     val used_facts = map fst used_named_facts
   232     val outcome = Option.map failure_of_smt2_failure outcome
   191     val outcome = Option.map failure_of_smt2_failure outcome
   233 
   192 
   234     val (preplay, message, message_tail) =
   193     val (preplay, message, message_tail) =