src/HOL/TPTP/mash_import.ML
changeset 48285 902ab51dd12a
parent 48284 a3cb8901d60c
child 48286 788c66a40b32
equal deleted inserted replaced
48284:a3cb8901d60c 48285:902ab51dd12a
     1 (*  Title:      HOL/TPTP/mash_import.ML
       
     2     Author:     Jasmin Blanchette, TU Muenchen
       
     3     Copyright   2012
       
     4 
       
     5 Import proof suggestions from MaSh (Machine-learning for Sledgehammer) and
       
     6 evaluate them.
       
     7 *)
       
     8 
       
     9 signature MASH_IMPORT =
       
    10 sig
       
    11   type params = Sledgehammer_Provers.params
       
    12 
       
    13   val import_and_evaluate_mash_suggestions :
       
    14     Proof.context -> params -> theory -> string -> unit
       
    15 end;
       
    16 
       
    17 structure MaSh_Import : MASH_IMPORT =
       
    18 struct
       
    19 
       
    20 open Sledgehammer_Filter_MaSh
       
    21 
       
    22 val unescape_meta =
       
    23   let
       
    24     fun un accum [] = String.implode (rev accum)
       
    25       | un accum (#"\\" :: d1 :: d2 :: d3 :: cs) =
       
    26         (case Int.fromString (String.implode [d1, d2, d3]) of
       
    27            SOME n => un (Char.chr n :: accum) cs
       
    28          | NONE => "" (* error *))
       
    29       | un _ (#"\\" :: _) = "" (* error *)
       
    30       | un accum (c :: cs) = un (c :: accum) cs
       
    31   in un [] o String.explode end
       
    32 
       
    33 val of_fact_name = unescape_meta
       
    34 
       
    35 val isaN = "Isabelle"
       
    36 val iterN = "Iterative"
       
    37 val mashN = "MaSh"
       
    38 val iter_mashN = "Iter+MaSh"
       
    39 
       
    40 val max_relevant_slack = 2
       
    41 
       
    42 fun import_and_evaluate_mash_suggestions ctxt params thy file_name =
       
    43   let
       
    44     val {provers, max_relevant, slice, type_enc, lam_trans, timeout, ...} =
       
    45       Sledgehammer_Isar.default_params ctxt []
       
    46     val prover_name = hd provers
       
    47     val path = file_name |> Path.explode
       
    48     val lines = path |> File.read_lines
       
    49     val facts = all_non_tautological_facts_of thy
       
    50     val all_names = facts |> map (Thm.get_name_hint o snd)
       
    51     val iter_ok = Unsynchronized.ref 0
       
    52     val mash_ok = Unsynchronized.ref 0
       
    53     val iter_mash_ok = Unsynchronized.ref 0
       
    54     val isa_ok = Unsynchronized.ref 0
       
    55     fun find_sugg facts sugg =
       
    56       find_first (fn (_, th) => Thm.get_name_hint th = sugg) facts
       
    57     fun sugg_facts hyp_ts concl_t facts =
       
    58       map_filter (find_sugg facts o of_fact_name)
       
    59       #> take (max_relevant_slack * the max_relevant)
       
    60       #> Sledgehammer_Fact.maybe_instantiate_inducts ctxt hyp_ts concl_t
       
    61       #> map (apfst (apfst (fn name => name ())))
       
    62     fun iter_mash_facts fs1 fs2 =
       
    63       let
       
    64         val fact_eq = (op =) o pairself fst
       
    65         fun score_in f fs =
       
    66           case find_index (curry fact_eq f) fs of
       
    67             ~1 => length fs
       
    68           | j => j
       
    69         fun score_of f = score_in f fs1 + score_in f fs2
       
    70       in
       
    71         union fact_eq fs1 fs2
       
    72         |> map (`score_of) |> sort (int_ord o pairself fst) |> map snd
       
    73         |> take (the max_relevant)
       
    74       end
       
    75     fun with_index facts s =
       
    76       (find_index (fn ((s', _), _) => s = s') facts + 1, s)
       
    77     fun index_string (j, s) = s ^ "@" ^ string_of_int j
       
    78     fun str_of_res label facts {outcome, run_time, used_facts, ...} =
       
    79       "  " ^ label ^ ": " ^
       
    80       (if is_none outcome then
       
    81          "Success (" ^ ATP_Util.string_from_time run_time ^ "): " ^
       
    82          (used_facts |> map (with_index facts o fst)
       
    83                      |> sort (int_ord o pairself fst)
       
    84                      |> map index_string
       
    85                      |> space_implode " ") ^
       
    86          (if length facts < the max_relevant then
       
    87             " (of " ^ string_of_int (length facts) ^ ")"
       
    88           else
       
    89             "")
       
    90        else
       
    91          "Failure: " ^
       
    92          (facts |> map (fst o fst)
       
    93                 |> take (the max_relevant)
       
    94                 |> tag_list 1
       
    95                 |> map index_string
       
    96                 |> space_implode " "))
       
    97     fun prove ok heading facts goal =
       
    98       let
       
    99         val facts = facts |> take (the max_relevant)
       
   100         val res as {outcome, ...} = run_prover ctxt params facts goal
       
   101         val _ = if is_none outcome then ok := !ok + 1 else ()
       
   102       in str_of_res heading facts res end
       
   103     fun solve_goal j name suggs =
       
   104       let
       
   105         val name = of_fact_name name
       
   106         val th =
       
   107           case find_first (fn (_, th) => Thm.get_name_hint th = name) facts of
       
   108             SOME (_, th) => th
       
   109           | NONE => error ("No fact called \"" ^ name ^ "\"")
       
   110         val isa_deps = isabelle_dependencies_of all_names th
       
   111         val goal = goal_of_thm thy th
       
   112         val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal ctxt goal 1
       
   113         val facts = facts |> filter (fn (_, th') => thm_ord (th', th) = LESS)
       
   114         val isa_facts = sugg_facts hyp_ts concl_t facts isa_deps
       
   115         val iter_facts =
       
   116           iter_facts ctxt params (max_relevant_slack * the max_relevant) goal
       
   117                      facts
       
   118         val mash_facts = sugg_facts hyp_ts concl_t facts suggs
       
   119         val iter_mash_facts = iter_mash_facts iter_facts mash_facts
       
   120         val iter_s = prove iter_ok iterN iter_facts goal
       
   121         val mash_s = prove mash_ok mashN mash_facts goal
       
   122         val iter_mash_s = prove iter_mash_ok iter_mashN iter_mash_facts goal
       
   123         val isa_s = prove isa_ok isaN isa_facts goal
       
   124       in
       
   125         ["Goal " ^ string_of_int j ^ ": " ^ name, iter_s, mash_s, iter_mash_s,
       
   126          isa_s]
       
   127         |> cat_lines |> tracing
       
   128       end
       
   129     val explode_suggs = space_explode " " #> filter_out (curry (op =) "")
       
   130     fun do_line (j, line) =
       
   131       case space_explode ":" line of
       
   132         [goal_name, suggs] => solve_goal j goal_name (explode_suggs suggs)
       
   133       | _ => ()
       
   134     fun total_of heading ok n =
       
   135       " " ^ heading ^ ": " ^ string_of_int (!ok) ^ " (" ^
       
   136       Real.fmt (StringCvt.FIX (SOME 1))
       
   137                (100.0 * Real.fromInt (!ok) / Real.fromInt n) ^ "%)"
       
   138     val inst_inducts = Config.get ctxt Sledgehammer_Fact.instantiate_inducts
       
   139     val options =
       
   140       [prover_name, string_of_int (the max_relevant) ^ " facts",
       
   141        "slice" |> not slice ? prefix "dont_", the_default "smart" type_enc,
       
   142        the_default "smart" lam_trans, ATP_Util.string_from_time timeout,
       
   143        "instantiate_inducts" |> not inst_inducts ? prefix "dont_"]
       
   144     val n = length lines
       
   145   in
       
   146     tracing " * * *";
       
   147     tracing ("Options: " ^ commas options);
       
   148     List.app do_line (tag_list 1 lines);
       
   149     ["Successes (of " ^ string_of_int n ^ " goals)",
       
   150      total_of iterN iter_ok n,
       
   151      total_of mashN mash_ok n,
       
   152      total_of iter_mashN iter_mash_ok n,
       
   153      total_of isaN isa_ok n]
       
   154     |> cat_lines |> tracing
       
   155   end
       
   156 
       
   157 end;