src/HOL/Tools/Sledgehammer/sledgehammer_run.ML
author blanchet
Wed Dec 12 21:48:29 2012 +0100 (2012-12-12 ago)
changeset 50510 7e4f2f8d9b50
parent 50201 c26369c9eda6
child 50557 31313171deb5
permissions -rw-r--r--
export a pair of ML functions
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_run.ML
     2     Author:     Fabian Immler, TU Muenchen
     3     Author:     Makarius
     4     Author:     Jasmin Blanchette, TU Muenchen
     5 
     6 Sledgehammer's heart.
     7 *)
     8 
     (* Public interface of the Sledgehammer driver: the outcome-code strings
        and the main entry point "run_sledgehammer". *)
     9 signature SLEDGEHAMMER_RUN =
    10 sig
         (* Abbreviations for types declared in other Sledgehammer modules. *)
    11   type fact_override = Sledgehammer_Fact.fact_override
    12   type minimize_command = Sledgehammer_Reconstruct.minimize_command
    13   type mode = Sledgehammer_Provers.mode
    14   type params = Sledgehammer_Provers.params
    15 
         (* Outcome codes; the structure below binds them to "some", "none",
            "timeout" and "unknown" respectively. *)
    16   val someN : string
    17   val noneN : string
    18   val timeoutN : string
    19   val unknownN : string
         (* Arguments, in order: parameters, mode, subgoal number, fact override,
            a function producing the minimize command, and the proof state.
            The result pairs a flag (true iff the outcome code is "someN") with
            the outcome code and a proof state. *)
    20   val run_sledgehammer :
    21     params -> mode -> int -> fact_override
    22     -> ((string * string list) list -> string -> minimize_command)
    23     -> Proof.state -> bool * (string * Proof.state)
    24 end;
    25 
    26 structure Sledgehammer_Run : SLEDGEHAMMER_RUN =
    27 struct
    28 
    29 open ATP_Util
    30 open ATP_Problem_Generate
    31 open ATP_Proof_Reconstruct
    32 open Sledgehammer_Util
    33 open Sledgehammer_Fact
    34 open Sledgehammer_Provers
    35 open Sledgehammer_Minimize
    36 open Sledgehammer_MaSh
    37 
    (* Outcome codes reported for a prover run. *)
    val (someN, noneN, timeoutN, unknownN) =
      ("some", "none", "timeout", "unknown")

    (* Outcome codes in order of preference: earlier entries take precedence
       when several outcomes are combined. *)
    val ordered_outcome_codes = someN :: unknownN :: timeoutN :: noneN :: []
    44 
    (* Combines several outcome codes into one: the result is the first entry
       of "ordered_outcome_codes" that occurs in "codes". If none of them
       occurs (in particular if "codes" is empty), the result is "unknownN". *)
    fun max_outcome_code codes =
      case find_first (member (op =) codes) ordered_outcome_codes of
        SOME best => best
      | NONE => unknownN
    53 
    (* Builds a pair (quoted prover name, description) for a prover run on
       subgoal "i" out of "n". The description mentions the number of facts
       only in verbose mode. In blocking mode it ends with a period; otherwise
       the printed subgoal follows on a new line. *)
    fun prover_description ctxt ({verbose, blocking, ...} : params) name num_facts i
                           n goal =
      let
        val facts_part =
          if verbose then
            " with " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts
          else
            ""
        val goal_part = if n = 1 then "goal" else "subgoal " ^ string_of_int i
        val ending =
          if blocking then
            "."
          else
            "\n" ^ Syntax.string_of_term ctxt (Thm.term_of (Thm.cprem_of goal i))
      in
        (quote name, facts_part ^ " on " ^ goal_part ^ ending)
      end
    64 
    (* Launches the single prover "name" on the problem given by the record
       argument and returns a pair (outcome code, proof state).

       One of three execution paths is chosen at the bottom of the function:
       - mode = Auto_Try: "go" runs under a time limit of "timeout"; if the
         outcome is "someN", the prover message is attached to the returned
         state through "Proof.goal_message".
       - blocking: "go" runs under the doubled "hard_timeout"; the prover
         message is printed if the outcome is "someN" or mode = Normal, and
         the state is returned unchanged.
       - otherwise: "go" is handed to "Async_Manager.launch" and the function
         returns "unknownN" with the unchanged state right away. *)
    65 fun launch_prover (params as {debug, verbose, blocking, max_facts, slice,
    66                               timeout, expect, ...})
    67                   mode minimize_command only learn
    68                   {state, goal, subgoal, subgoal_count, facts} name =
    69   let
    70     val ctxt = Proof.context_of state
           (* Twice the nominal timeout. *)
    71     val hard_timeout = Time.+ (timeout, timeout)
           (* "birth_time" and "death_time" are only passed to
              "Async_Manager.launch" in the asynchronous path below. *)
    72     val birth_time = Time.now ()
    73     val death_time = Time.+ (birth_time, hard_timeout)
           (* An explicitly given "max_facts" takes precedence over the default
              computed for this prover. *)
    74     val max_facts =
    75       max_facts |> the_default (default_max_facts_for_prover ctxt slice name)
           (* With "only", all supplied facts count; otherwise the count is
              capped at "max_facts". *)
    76     val num_facts = length facts |> not only ? Integer.min max_facts
    77     fun desc () =
    78       prover_description ctxt params name num_facts subgoal subgoal_count goal
           (* The facts actually given to the prover: unless "is_ho_atp" holds
              for this prover, facts whose tag equals "Induction" are removed;
              then at most "num_facts" facts are kept. *)
    79     val problem =
    80       {state = state, goal = goal, subgoal = subgoal,
    81        subgoal_count = subgoal_count,
    82        facts = facts
    83                |> not (Sledgehammer_Provers.is_ho_atp ctxt name)
    84                   ? filter_out (curry (op =) Induction o snd o snd o fst
    85                                 o untranslated_fact)
    86                |> take num_facts}
           (* Prints the used facts as "name@j", where j is the tag assigned by
              "tag_list 1" over the full "facts" list.
              NOTE(review): the tags, the plural suffix and the "(of N)" count
              are all computed from the unfiltered "facts" rather than from the
              facts of "problem" or from the used facts -- confirm that this is
              intended. *)
    87     fun print_used_facts used_facts =
    88       tag_list 1 facts
    89       |> map (fn (j, fact) => fact |> untranslated_fact |> apsnd (K j))
    90       |> filter_used_facts false used_facts
    91       |> map (fn ((name, _), j) => name ^ "@" ^ string_of_int j)
    92       |> commas
    93       |> enclose ("Fact" ^ plural_s (length facts) ^ " in " ^ quote name ^
    94                   " proof (of " ^ string_of_int (length facts) ^ "): ") "."
    95       |> Output.urgent_message
           (* Runs the (minimizing) prover and maps its result to an outcome
              code: a "TimedOut" failure gives "timeoutN", any other failure
              gives "noneN", and no failure gives "someN". The message is
              returned as a thunk, so "preplay" is only invoked when the
              message is demanded. In verbose mode, the used facts are printed
              when the prover succeeded with a nonempty list of used facts. *)
    96     fun really_go () =
    97       problem
    98       |> get_minimizing_prover ctxt mode learn name params minimize_command
    99       |> verbose ? tap (fn {outcome = NONE, used_facts as _ :: _, ...} =>
   100                            print_used_facts used_facts
   101                          | _ => ())
   102       |> (fn {outcome, preplay, message, message_tail, ...} =>
   103              (if outcome = SOME ATP_Proof.TimedOut then timeoutN
   104               else if is_some outcome then noneN
   105               else someN, fn () => message (preplay ()) ^ message_tail))
           (* Wraps "really_go". In debug mode exceptions propagate unchanged.
              Otherwise "ERROR" and any other non-interrupt exception become an
              "unknownN" outcome carrying the error text, while interrupts are
              re-raised. Afterwards the outcome is checked against "expect". *)
   106     fun go () =
   107       let
   108         val (outcome_code, message) =
   109           if debug then
   110             really_go ()
   111           else
   112             (really_go ()
   113              handle ERROR msg => (unknownN, fn () => "Error: " ^ msg ^ "\n")
   114                   | exn =>
   115                     if Exn.is_interrupt exn then
   116                       reraise exn
   117                     else
   118                       (unknownN, fn () => "Internal error:\n" ^
   119                                           ML_Compiler.exn_message exn ^ "\n"))
   120         val _ =
   121           (* The "expect" argument is deliberately ignored if the prover is
   122              missing so that the "Metis_Examples" can be processed on any
   123              machine. *)
   124           if expect = "" orelse outcome_code = expect orelse
   125              not (is_prover_installed ctxt name) then
   126             ()
   127           else if blocking then
   128             error ("Unexpected outcome: " ^ quote outcome_code ^ ".")
   129           else
   130             warning ("Unexpected outcome: " ^ quote outcome_code ^ ".");
   131       in (outcome_code, message) end
   132   in
           (* Auto_Try: limited by the nominal "timeout" (not the doubled one);
              the message is only forced if it gets attached to the state. *)
   133     if mode = Auto_Try then
   134       let val (outcome_code, message) = TimeLimit.timeLimit timeout go () in
   135         (outcome_code,
   136          state
   137          |> outcome_code = someN
   138             ? Proof.goal_message (fn () =>
   139                   [Pretty.str "",
   140                    Pretty.mark Markup.intensify (Pretty.str (message ()))]
   141                   |> Pretty.chunks))
   142       end
           (* Blocking: limited by "hard_timeout"; output goes through
              "Output.urgent_message" chunk by chunk. *)
   143     else if blocking then
   144       let
   145         val (outcome_code, message) = TimeLimit.timeLimit hard_timeout go ()
   146       in
   147         (if outcome_code = someN orelse mode = Normal then
   148            quote name ^ ": " ^ message ()
   149          else
   150            "")
   151         |> Async_Manager.break_into_chunks
   152         |> List.app Output.urgent_message;
   153         (outcome_code, state)
   154       end
           (* Asynchronous: the job passed to "Async_Manager.launch" yields a
              pair of a flag (verbose or success) and the message text; the
              caller gets "unknownN" without waiting for the prover. *)
   155     else
   156       (Async_Manager.launch SledgehammerN birth_time death_time (desc ())
   157                             ((fn (outcome_code, message) =>
   158                                  (verbose orelse outcome_code = someN,
   159                                   message ())) o go);
   160        (unknownN, state))
   161   end
   162 
   (* Returns, as a string, the solver class of the SMT solver "name": the
      solver is selected in "ctxt" and the class of the resulting context is
      rendered by "SMT_Utils.string_of_class". *)
   fun class_of_smt_solver ctxt name =
     let
       val solver_ctxt = select_smt_solver name ctxt
       val solver_class = SMT_Config.solver_class_of solver_ctxt
     in
       SMT_Utils.string_of_class solver_class
     end
   166 
   (* In Auto_Try mode, "run_sledgehammer" divides the default maximum number
      of facts (the one used when no explicit "max_facts" is given) by this
      constant. *)
   167 val auto_try_max_facts_divisor = 2 (* FUDGE *)
   168 
   (* Main entry point: runs the provers listed in "params" on subgoal "i" of
      "state". The result is a pair of a success flag (true iff the outcome
      code is "someN") and a pair (outcome code, proof state).

      Raises an error if no prover is set or if one of the provers is not
      supported. If the state has no subgoal, prints "No subgoal!" and returns
      (false, (noneN, state)).

      In blocking mode the provers run before the function returns; otherwise
      the work is forked as a future and the initial (unknownN, state) is
      returned immediately. *)
   169 fun run_sledgehammer (params as {debug, verbose, blocking, provers, max_facts,
   170                                  slice, ...})
   171         mode i (fact_override as {only, ...}) minimize_command state =
   172   if null provers then
   173     error "No prover is set."
   174   else case subgoal_count state of
   175     0 => (Output.urgent_message "No subgoal!"; (false, (noneN, state)))
   176   | n =>
   177     let
   178       val _ = Proof.assert_backward state
             (* Progress messages are only shown in Normal mode. *)
   179       val print = if mode = Normal then Output.urgent_message else K ()
             (* The SMT verbosity flag follows the "debug" parameter. *)
   180       val state =
   181         state |> Proof.map_context (Config.put SMT_Config.verbose debug)
   182       val ctxt = Proof.context_of state
   183       val {facts = chained, goal, ...} = Proof.goal state
   184       val (_, hyp_ts, concl_t) = strip_subgoal ctxt goal i
   185       val ho_atp = exists (Sledgehammer_Provers.is_ho_atp ctxt) provers
   186       val reserved = reserved_isar_keyword_table ()
   187       val css = clasimpset_rule_table_of ctxt
   188       val all_facts =
   189         nearly_all_facts ctxt ho_atp fact_override reserved css chained hyp_ts
   190                          concl_t
             (* In non-blocking mode "kill_provers" is called before launching;
                presumably this stops provers left over from an earlier
                asynchronous run -- confirm against its definition. *)
   191       val _ = () |> not blocking ? kill_provers
   192       val _ = case find_first (not o is_prover_supported ctxt) provers of
   193                 SOME name => error ("No such prover: " ^ name ^ ".")
   194               | NONE => ()
   195       val _ = print "Sledgehammering..."
             (* Split the provers into SMT solvers, unit-equational ATPs and the
                remaining ("full") ATPs. *)
   196       val (smts, (ueq_atps, full_atps)) =
   197         provers |> List.partition (is_smt_prover ctxt)
   198                 ||> List.partition (is_unit_equational_atp ctxt)
             (* Launches the given provers on one shared problem. Each fact is
                paired with its 0-based position and passed through
                "translate num_facts". In Auto_Try and Try modes the provers
                are tried one after the other and no further prover is launched
                once the outcome is "someN". In the other modes all provers are
                launched (via "Par_List.map" when blocking) and their outcome
                codes are combined with "max_outcome_code"; the state is
                returned unchanged. *)
   199       fun launch_provers state get_facts translate provers =
   200         let
   201           val facts = get_facts ()
   202           val num_facts = length facts
   203           val facts = facts ~~ (0 upto num_facts - 1)
   204                       |> map (translate num_facts)
   205           val problem =
   206             {state = state, goal = goal, subgoal = i, subgoal_count = n,
   207              facts = facts}
   208           fun learn prover =
   209             mash_learn_proof ctxt params prover (prop_of goal) all_facts
   210           val launch = launch_prover params mode minimize_command only learn
   211         in
   212           if mode = Auto_Try orelse mode = Try then
   213             (unknownN, state)
   214             |> fold (fn prover => fn accum as (outcome_code, _) =>
   215                         if outcome_code = someN then accum
   216                         else launch problem prover)
   217                     provers
   218           else
   219             provers
   220             |> (if blocking then Par_List.map else map)
   221                    (launch problem #> fst)
   222             |> max_outcome_code |> rpair state
   223         end
             (* Selects the relevant facts for a group of provers; "provers"
                must be nonempty because of "hd provers". The fact limit is the
                explicit "max_facts" if given; otherwise it is the largest
                default among the provers, divided by
                "auto_try_max_facts_divisor" in Auto_Try mode (the division is
                part of the NONE branch only). In verbose mode a summary of the
                selected facts is passed to "print". *)
   224       fun get_facts label is_appropriate_prop provers =
   225         let
   226           val max_max_facts =
   227             case max_facts of
   228               SOME n => n
   229             | NONE =>
   230               0 |> fold (Integer.max o default_max_facts_for_prover ctxt slice)
   231                         provers
   232                 |> mode = Auto_Try ? (fn n => n div auto_try_max_facts_divisor)
   233         in
   234           all_facts
   235           |> (case is_appropriate_prop of
   236                 SOME is_app => filter (is_app o prop_of o snd)
   237               | NONE => I)
   238           |> relevant_facts ctxt params (hd provers) max_max_facts fact_override
   239                             hyp_ts concl_t
   240           |> map (apfst (apfst (fn name => name ())))
   241           |> tap (fn facts =>
   242                      if verbose then
   243                        label ^ plural_s (length provers) ^ ": " ^
   244                        (if null facts then
   245                           "Found no relevant facts."
   246                         else
   247                           "Including (up to) " ^ string_of_int (length facts) ^
   248                           " relevant fact" ^ plural_s (length facts) ^ ":\n" ^
   249                           (facts |> map (fst o fst) |> space_implode " ") ^ ".")
   250                        |> print
   251                      else
   252                        ())
   253         end
             (* Launches a group of ATPs. Returns "accum" unchanged if the group
                is empty, or if a predicate is given and the goal's conclusion
                does not satisfy it (then a "Goal outside the scope" message is
                printed when verbose or when the group comprises all provers).
                Otherwise the result of "launch_provers" replaces "accum". *)
   254       fun launch_atps label is_appropriate_prop atps accum =
   255         if null atps then
   256           accum
   257         else if is_some is_appropriate_prop andalso
   258                 not (the is_appropriate_prop concl_t) then
   259           (if verbose orelse length atps = length provers then
   260              "Goal outside the scope of " ^
   261              space_implode " " (serial_commas "and" (map quote atps)) ^ "."
   262              |> Output.urgent_message
   263            else
   264              ();
   265            accum)
   266         else
   267           launch_provers state (get_facts label is_appropriate_prop o K atps)
   268                          (K (Untranslated_Fact o fst)) atps
             (* Launches the SMT solvers. Facts are selected once for all of
                them; the solvers are then grouped by solver class, each group
                is launched separately with weighted facts, and the outcome
                codes are combined with "max_outcome_code". *)
   269       fun launch_smts accum =
   270         if null smts then
   271           accum
   272         else
   273           let
   274             val facts = get_facts "SMT solver" NONE smts
   275             val weight = SMT_Weighted_Fact oo weight_smt_fact ctxt
   276           in
   277             smts |> map (`(class_of_smt_solver ctxt))
   278                  |> AList.group (op =)
   279                  |> map (snd #> launch_provers state (K facts) weight #> fst)
   280                  |> max_outcome_code |> rpair state
   281           end
   282       val launch_full_atps = launch_atps "ATP" NONE full_atps
   283       val launch_ueq_atps =
   284         launch_atps "Unit equational provers" (SOME is_unit_equality) ueq_atps
             (* Used in non-blocking mode: runs the three launchers via
                "Par_List.map", each starting from (unknownN, state), and
                discards their results. An "ERROR" is printed and then raised
                again through "error". *)
   285       fun launch_atps_and_smt_solvers () =
   286         [launch_full_atps, launch_smts, launch_ueq_atps]
   287         |> Par_List.map (fn f => ignore (f (unknownN, state)))
   288         handle ERROR msg => (print ("Error: " ^ msg); error msg)
             (* Applies "f" only in Normal mode or while the outcome so far is
                not "someN". *)
   289       fun maybe f (accum as (outcome_code, _)) =
   290         accum |> (mode = Normal orelse outcome_code <> someN) ? f
   291     in
             (* Blocking: full ATPs first, then (except in Auto_Try mode) the
                unit-equational ATPs and the SMT solvers, each guarded by
                "maybe". Non-blocking: fork the launchers as a future and return
                the initial pair. A "TimeLimit.TimeOut" raised here yields
                (unknownN, state). *)
   292       (unknownN, state)
   293       |> (if blocking then
   294             launch_full_atps
   295             #> mode <> Auto_Try ? (maybe launch_ueq_atps #> maybe launch_smts)
   296           else
   297             (fn p => Future.fork (tap launch_atps_and_smt_solvers) |> K p))
   298       handle TimeLimit.TimeOut =>
   299              (print "Sledgehammer ran out of time."; (unknownN, state))
   300     end
           (* Prefix the result with the success flag. *)
   301     |> `(fn (outcome_code, _) => outcome_code = someN)
   302 
   303 end;