src/HOL/TPTP/mash_eval.ML
author blanchet
Sun Dec 16 14:19:08 2012 +0100 (2012-12-16)
changeset 50563 3a4785d64ecb
parent 50562 0a7c7e121bd8
child 50587 bd6582be1562
permissions -rw-r--r--
escape nicknames
     1 (*  Title:      HOL/TPTP/mash_eval.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2012
     4 
     5 Evaluate proof suggestions from MaSh (Machine-learning for Sledgehammer).
     6 *)
     7 
     8 signature MASH_EVAL =
     9 sig
    10   type params = Sledgehammer_Provers.params
    11 
    12   val evaluate_mash_suggestions :
    13     Proof.context -> params -> int * int option -> string option -> string
    14     -> string -> unit
    15 end;
    16 
    17 structure MaSh_Eval : MASH_EVAL =
    18 struct
    19 
    20 open Sledgehammer_Util
    21 open Sledgehammer_Fact
    22 open Sledgehammer_MePo
    23 open Sledgehammer_MaSh
    24 open Sledgehammer_Provers
    25 open Sledgehammer_Isar
    26 
    27 val MePoN = "MePo"
    28 val MaShN = "MaSh"
    29 val MeShN = "MeSh"
    30 val IsarN = "Isar"
    31 
    32 fun in_range (from, to) j =
    33   j >= from andalso (to = NONE orelse j <= the to)
    34 
    35 fun evaluate_mash_suggestions ctxt params range prob_dir_name sugg_file_name
    36                               report_file_name =
    37   let
    38     val report_path = report_file_name |> Path.explode
    39     val _ = File.write report_path ""
    40     fun print s = (tracing s; File.append report_path (s ^ "\n"))
    41     val {provers, max_facts, slice, type_enc, lam_trans, timeout, ...} =
    42       default_params ctxt []
    43     val prover = hd provers
    44     val slack_max_facts = generous_max_facts (the max_facts)
    45     val sugg_path = sugg_file_name |> Path.explode
    46     val lines = sugg_path |> File.read_lines
    47     val css = clasimpset_rule_table_of ctxt
    48     val facts = all_facts ctxt true false Symtab.empty [] [] css
    49     val all_names = build_all_names nickname_of facts
    50     val mepo_ok = Unsynchronized.ref 0
    51     val mash_ok = Unsynchronized.ref 0
    52     val mesh_ok = Unsynchronized.ref 0
    53     val isar_ok = Unsynchronized.ref 0
    54     fun with_index facts s = (find_index (curry (op =) s) facts + 1, s)
    55     fun index_string (j, s) = s ^ "@" ^ string_of_int j
    56     fun str_of_res label facts ({outcome, run_time, used_facts, ...}
    57                                 : prover_result) =
    58       let val facts = facts |> map (fn ((name, _), _) => name ()) in
    59         "  " ^ label ^ ": " ^
    60         (if is_none outcome then
    61            "Success (" ^ ATP_Util.string_from_time run_time ^ "): " ^
    62            (used_facts |> map (with_index facts o fst)
    63                        |> sort (int_ord o pairself fst)
    64                        |> map index_string
    65                        |> space_implode " ") ^
    66            (if length facts < the max_facts then
    67               " (of " ^ string_of_int (length facts) ^ ")"
    68             else
    69               "")
    70          else
    71            "Failure: " ^
    72            (facts |> take (the max_facts) |> tag_list 1
    73                   |> map index_string
    74                   |> space_implode " "))
    75       end
    76     fun solve_goal (j, line) =
    77       if in_range range j then
    78         let
    79           val (name, suggs) = extract_query line
    80           val th =
    81             case find_first (fn (_, th) => nickname_of th = name) facts of
    82               SOME (_, th) => th
    83             | NONE => error ("No fact called \"" ^ name ^ "\".")
    84           val goal = goal_of_thm (Proof_Context.theory_of ctxt) th
    85           val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal ctxt goal 1
    86           val isar_deps = isar_dependencies_of all_names th |> these
    87           val facts = facts |> filter (fn (_, th') => thm_ord (th', th) = LESS)
    88           val mepo_facts =
    89             mepo_suggested_facts ctxt params prover slack_max_facts NONE hyp_ts
    90                 concl_t facts
    91             |> weight_mepo_facts
    92           val (mash_facts, mash_unks) =
    93             find_mash_suggestions slack_max_facts suggs facts [] []
    94             |>> weight_mash_facts
    95           val mess = [(0.5, (mepo_facts, [])), (0.5, (mash_facts, mash_unks))]
    96           val mesh_facts = mesh_facts slack_max_facts mess
    97           val isar_facts =
    98             find_suggested_facts (map (rpair 1.0) isar_deps) facts
    99           (* adapted from "mirabelle_sledgehammer.ML" *)
   100           fun set_file_name heading (SOME dir) =
   101               let
   102                 val prob_prefix =
   103                   "goal_" ^ string_of_int j ^ "__" ^ escape_meta name ^ "__" ^
   104                   heading
   105               in
   106                 Config.put dest_dir dir
   107                 #> Config.put problem_prefix (prob_prefix ^ "__")
   108                 #> Config.put SMT_Config.debug_files (dir ^ "/" ^ prob_prefix)
   109               end
   110             | set_file_name _ NONE = I
   111           fun prove ok heading get facts =
   112             let
   113               fun nickify ((_, stature), th) =
   114                 ((K (escape_meta (nickname_of th)), stature), th)
   115               val facts =
   116                 facts
   117                 |> map (get #> nickify)
   118                 |> maybe_instantiate_inducts ctxt hyp_ts concl_t
   119                 |> take (the max_facts)
   120               val ctxt = ctxt |> set_file_name heading prob_dir_name
   121               val res as {outcome, ...} =
   122                 run_prover_for_mash ctxt params prover facts goal
   123               val _ = if is_none outcome then ok := !ok + 1 else ()
   124             in str_of_res heading facts res end
   125           val [mepo_s, mash_s, mesh_s, isar_s] =
   126             [fn () => prove mepo_ok MePoN fst mepo_facts,
   127              fn () => prove mash_ok MaShN fst mash_facts,
   128              fn () => prove mesh_ok MeShN I mesh_facts,
   129              fn () => prove isar_ok IsarN fst isar_facts]
   130             |> (* Par_List. *) map (fn f => f ())
   131         in
   132           ["Goal " ^ string_of_int j ^ ": " ^ name, mepo_s, mash_s, mesh_s,
   133            isar_s]
   134           |> cat_lines |> print
   135         end
   136       else
   137         ()
   138     fun total_of heading ok n =
   139       "  " ^ heading ^ ": " ^ string_of_int (!ok) ^ " (" ^
   140       Real.fmt (StringCvt.FIX (SOME 1))
   141                (100.0 * Real.fromInt (!ok) / Real.fromInt n) ^ "%)"
   142     val inst_inducts = Config.get ctxt instantiate_inducts
   143     val options =
   144       [prover, string_of_int (the max_facts) ^ " facts",
   145        "slice" |> not slice ? prefix "dont_", the_default "smart" type_enc,
   146        the_default "smart" lam_trans,
   147        ATP_Util.string_from_time (timeout |> the_default one_year),
   148        "instantiate_inducts" |> not inst_inducts ? prefix "dont_"]
   149     val n = length lines
   150   in
   151     print " * * *";
   152     print ("Options: " ^ commas options);
   153     Par_List.map solve_goal (tag_list 1 lines);
   154     ["Successes (of " ^ string_of_int n ^ " goals)",
   155      total_of MePoN mepo_ok n,
   156      total_of MaShN mash_ok n,
   157      total_of MeShN mesh_ok n,
   158      total_of IsarN isar_ok n]
   159     |> cat_lines |> print
   160   end
   161 
   162 end;