src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
author blanchet
Wed Dec 12 21:48:29 2012 +0100 (2012-12-12)
changeset 50510 7e4f2f8d9b50
parent 50485 3c6ac2da2f45
child 50557 31313171deb5
permissions -rw-r--r--
export a pair of ML functions
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3 
     4 Sledgehammer's machine-learning-based relevance filter (MaSh).
     5 *)
     6 
     7 signature SLEDGEHAMMER_MASH =
     8 sig
     9   type stature = ATP_Problem_Generate.stature
    10   type fact = Sledgehammer_Fact.fact
    11   type fact_override = Sledgehammer_Fact.fact_override
    12   type params = Sledgehammer_Provers.params
    13   type relevance_fudge = Sledgehammer_Provers.relevance_fudge
    14   type prover_result = Sledgehammer_Provers.prover_result
    15 
    16   val trace : bool Config.T
    17   val MaShN : string
    18   val mepoN : string
    19   val mashN : string
    20   val meshN : string
    21   val unlearnN : string
    22   val learn_isarN : string
    23   val learn_proverN : string
    24   val relearn_isarN : string
    25   val relearn_proverN : string
    26   val fact_filters : string list
    27   val escape_meta : string -> string
    28   val escape_metas : string list -> string
    29   val unescape_meta : string -> string
    30   val unescape_metas : string -> string list
    31   val encode_features : (string * real) list -> string
    32   val extract_query : string -> string * (string * real) list
    33   val mash_CLEAR : Proof.context -> unit
    34   val mash_ADD :
    35     Proof.context -> bool
    36     -> (string * string list * (string * real) list * string list) list -> unit
    37   val mash_REPROVE :
    38     Proof.context -> bool -> (string * string list) list -> unit
    39   val mash_QUERY :
    40     Proof.context -> bool -> int -> string list * (string * real) list
    41     -> (string * real) list
    42   val mash_unlearn : Proof.context -> unit
    43   val nickname_of : thm -> string
    44   val find_suggested_facts :
    45     (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
    46   val mesh_facts :
    47     int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list
    48     -> ('a * thm) list
    49   val theory_ord : theory * theory -> order
    50   val thm_ord : thm * thm -> order
    51   val goal_of_thm : theory -> thm -> thm
    52   val run_prover_for_mash :
    53     Proof.context -> params -> string -> fact list -> thm -> prover_result
    54   val features_of :
    55     Proof.context -> string -> theory -> stature -> term list
    56     -> (string * real) list
    57   val isar_dependencies_of : string Symtab.table -> thm -> string list option
    58   val prover_dependencies_of :
    59     Proof.context -> params -> string -> int -> fact list -> string Symtab.table
    60     -> thm -> bool * string list option
    61   val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list
    62   val find_mash_suggestions :
    63     int -> (Symtab.key * 'a) list -> ('b * thm) list -> ('b * thm) list
    64     -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
    65   val mash_suggested_facts :
    66     Proof.context -> params -> string -> int -> term list -> term -> fact list
    67     -> fact list * fact list
    68   val mash_learn_proof :
    69     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    70     -> unit
    71   val mash_learn :
    72     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    73   val is_mash_enabled : unit -> bool
    74   val mash_can_suggest_facts : Proof.context -> bool
    75   val generous_max_facts : int -> int
    76   val relevant_facts :
    77     Proof.context -> params -> string -> int -> fact_override -> term list
    78     -> term -> fact list -> fact list
    79   val kill_learners : unit -> unit
    80   val running_learners : unit -> unit
    81 end;
    82 
    83 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH =
    84 struct
    85 
    86 open ATP_Util
    87 open ATP_Problem_Generate
    88 open Sledgehammer_Util
    89 open Sledgehammer_Fact
    90 open Sledgehammer_Provers
    91 open Sledgehammer_Minimize
    92 open Sledgehammer_MePo
    93 
    94 val trace =
    95   Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
    96 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
    97 
    98 val MaShN = "MaSh"
    99 
   100 val mepoN = "mepo"
   101 val mashN = "mash"
   102 val meshN = "mesh"
   103 
   104 val fact_filters = [meshN, mepoN, mashN]
   105 
   106 val unlearnN = "unlearn"
   107 val learn_isarN = "learn_isar"
   108 val learn_proverN = "learn_prover"
   109 val relearn_isarN = "relearn_isar"
   110 val relearn_proverN = "relearn_prover"
   111 
   112 fun mash_model_dir () =
   113   Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
   114 val mash_state_dir = mash_model_dir
   115 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
   116 
   117 
   118 (*** Low-level communication with MaSh ***)
   119 
   120 fun wipe_out_file file = (try (File.rm o Path.explode) file; ())
   121 
   122 fun write_file banner (xs, f) path =
   123   (case banner of SOME s => File.write path s | NONE => ();
   124    xs |> chunk_list 500
   125       |> List.app (File.append path o space_implode "" o map f))
   126   handle IO.Io _ => ()
   127 
   128 fun run_mash_tool ctxt overlord save max_suggs write_cmds read_suggs =
   129   let
   130     val (temp_dir, serial) =
   131       if overlord then (getenv "ISABELLE_HOME_USER", "")
   132       else (getenv "ISABELLE_TMP", serial_string ())
   133     val log_file = if overlord then temp_dir ^ "/mash_log" else "/dev/null"
   134     val err_file = temp_dir ^ "/mash_err" ^ serial
   135     val sugg_file = temp_dir ^ "/mash_suggs" ^ serial
   136     val sugg_path = Path.explode sugg_file
   137     val cmd_file = temp_dir ^ "/mash_commands" ^ serial
   138     val cmd_path = Path.explode cmd_file
   139     val core =
   140       "--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^
   141       " --numberOfPredictions " ^ string_of_int max_suggs ^
   142       (if save then " --saveModel" else "")
   143     val command =
   144       "\"$ISABELLE_SLEDGEHAMMER_MASH/src/mash.py\" --quiet --outputDir " ^
   145       File.shell_path (mash_model_dir ()) ^ " --log " ^ log_file ^ " " ^ core ^
   146       " >& " ^ err_file
   147       |> tap (fn _ => trace_msg ctxt (fn () =>
   148              case try File.read (Path.explode err_file) of
   149                NONE => "Done"
   150              | SOME "" => "Done"
   151              | SOME s => "Error: " ^ elide_string 1000 s))
   152     fun run_on () =
   153       (Isabelle_System.bash command;
   154        read_suggs (fn () => try File.read_lines sugg_path |> these))
   155     fun clean_up () =
   156       if overlord then ()
   157       else List.app wipe_out_file [err_file, sugg_file, cmd_file]
   158   in
   159     write_file (SOME "") ([], K "") sugg_path;
   160     write_file (SOME "") write_cmds cmd_path;
   161     trace_msg ctxt (fn () => "Running " ^ command);
   162     with_cleanup clean_up run_on ()
   163   end
   164 
   165 fun meta_char c =
   166   if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse
   167      c = #")" orelse c = #"," then
   168     String.str c
   169   else
   170     (* fixed width, in case more digits follow *)
   171     "%" ^ stringN_of_int 3 (Char.ord c)
   172 
   173 fun unmeta_chars accum [] = String.implode (rev accum)
   174   | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
   175     (case Int.fromString (String.implode [d1, d2, d3]) of
   176        SOME n => unmeta_chars (Char.chr n :: accum) cs
   177      | NONE => "" (* error *))
   178   | unmeta_chars _ (#"%" :: _) = "" (* error *)
   179   | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
   180 
   181 val escape_meta = String.translate meta_char
   182 val escape_metas = map escape_meta #> space_implode " "
   183 val unescape_meta = String.explode #> unmeta_chars []
   184 val unescape_metas =
   185   space_explode " " #> filter_out (curry (op =) "") #> map unescape_meta
   186 
   187 fun encode_feature (name, weight) =
   188   escape_meta name ^
   189   (if Real.== (weight, 1.0) then "" else "=" ^ Real.toString weight)
   190 
   191 val encode_features = map encode_feature #> space_implode " "
   192 
   193 fun str_of_add (name, parents, feats, deps) =
   194   "! " ^ escape_meta name ^ ": " ^ escape_metas parents ^ "; " ^
   195   encode_features feats ^ "; " ^ escape_metas deps ^ "\n"
   196 
   197 fun str_of_reprove (name, deps) =
   198   "p " ^ escape_meta name ^ ": " ^ escape_metas deps ^ "\n"
   199 
   200 fun str_of_query (parents, feats) =
   201   "? " ^ escape_metas parents ^ "; " ^ encode_features feats ^ "\n"
   202 
   203 fun extract_suggestion sugg =
   204   case space_explode "=" sugg of
   205     [name, weight] =>
   206     SOME (unescape_meta name, Real.fromString weight |> the_default 1.0)
   207   | [name] => SOME (unescape_meta name, 1.0)
   208   | _ => NONE
   209 
   210 fun extract_query line =
   211   case space_explode ":" line of
   212     [goal, suggs] =>
   213     (unescape_meta goal,
   214      map_filter extract_suggestion (space_explode " " suggs))
   215   | _ => ("", [])
   216 
   217 fun mash_CLEAR ctxt =
   218   let val path = mash_model_dir () in
   219     trace_msg ctxt (K "MaSh CLEAR");
   220     try (File.fold_dir (fn file => fn _ =>
   221                            try File.rm (Path.append path (Path.basic file)))
   222                        path) NONE;
   223     ()
   224   end
   225 
   226 fun mash_ADD _ _ [] = ()
   227   | mash_ADD ctxt overlord adds =
   228     (trace_msg ctxt (fn () => "MaSh ADD " ^
   229          elide_string 1000 (space_implode " " (map #1 adds)));
   230      run_mash_tool ctxt overlord true 0 (adds, str_of_add) (K ()))
   231 
   232 fun mash_REPROVE _ _ [] = ()
   233   | mash_REPROVE ctxt overlord reps =
   234     (trace_msg ctxt (fn () => "MaSh REPROVE " ^
   235          elide_string 1000 (space_implode " " (map #1 reps)));
   236      run_mash_tool ctxt overlord true 0 (reps, str_of_reprove) (K ()))
   237 
   238 fun mash_QUERY ctxt overlord max_suggs (query as (_, feats)) =
   239   (trace_msg ctxt (fn () => "MaSh QUERY " ^ encode_features feats);
   240    run_mash_tool ctxt overlord false max_suggs
   241        ([query], str_of_query)
   242        (fn suggs =>
   243            case suggs () of
   244              [] => []
   245            | suggs => snd (extract_query (List.last suggs)))
   246    handle List.Empty => [])
   247 
   248 
   249 (*** Middle-level communication with MaSh ***)
   250 
   251 datatype proof_kind =
   252   Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop
   253 
   254 fun str_of_proof_kind Isar_Proof = "i"
   255   | str_of_proof_kind Automatic_Proof = "a"
   256   | str_of_proof_kind Isar_Proof_wegen_Prover_Flop = "x"
   257 
   258 fun proof_kind_of_str "i" = Isar_Proof
   259   | proof_kind_of_str "a" = Automatic_Proof
   260   | proof_kind_of_str "x" = Isar_Proof_wegen_Prover_Flop
   261 
   262 (* FIXME: Here a "Graph.update_node" function would be useful *)
   263 fun update_fact_graph_node (name, kind) =
   264   Graph.default_node (name, Isar_Proof)
   265   #> kind <> Isar_Proof ? Graph.map_node name (K kind)
   266 
   267 fun try_graph ctxt when def f =
   268   f ()
   269   handle Graph.CYCLES (cycle :: _) =>
   270          (trace_msg ctxt (fn () =>
   271               "Cycle involving " ^ commas cycle ^ " when " ^ when); def)
   272        | Graph.DUP name =>
   273          (trace_msg ctxt (fn () =>
   274               "Duplicate fact " ^ quote name ^ " when " ^ when); def)
   275        | Graph.UNDEF name =>
   276          (trace_msg ctxt (fn () =>
   277               "Unknown fact " ^ quote name ^ " when " ^ when); def)
   278        | exn =>
   279          if Exn.is_interrupt exn then
   280            reraise exn
   281          else
   282            (trace_msg ctxt (fn () =>
   283                 "Internal error when " ^ when ^ ":\n" ^
   284                 ML_Compiler.exn_message exn); def)
   285 
   286 fun graph_info G =
   287   string_of_int (length (Graph.keys G)) ^ " node(s), " ^
   288   string_of_int (fold (Integer.add o length o snd) (Graph.dest G) 0) ^
   289   " edge(s), " ^
   290   string_of_int (length (Graph.minimals G)) ^ " minimal, " ^
   291   string_of_int (length (Graph.maximals G)) ^ " maximal"
   292 
   293 type mash_state = {fact_G : unit Graph.T, dirty : string list option}
   294 
   295 val empty_state = {fact_G = Graph.empty, dirty = SOME []}
   296 
   297 local
   298 
   299 val version = "*** MaSh version 20121212a ***"
   300 
   301 exception Too_New of unit
   302 
   303 fun extract_node line =
   304   case space_explode ":" line of
   305     [head, parents] =>
   306     (case space_explode " " head of
   307        [kind, name] =>
   308        SOME (unescape_meta name, unescape_metas parents,
   309              try proof_kind_of_str kind |> the_default Isar_Proof)
   310      | _ => NONE)
   311   | _ => NONE
   312 
   313 fun load _ (state as (true, _)) = state
   314   | load ctxt _ =
   315     let val path = mash_state_file () in
   316       (true,
   317        case try File.read_lines path of
   318          SOME (version' :: node_lines) =>
   319          let
   320            fun add_edge_to name parent =
   321              Graph.default_node (parent, Isar_Proof)
   322              #> Graph.add_edge (parent, name)
   323            fun add_node line =
   324              case extract_node line of
   325                NONE => I (* shouldn't happen *)
   326              | SOME (name, parents, kind) =>
   327                update_fact_graph_node (name, kind)
   328                #> fold (add_edge_to name) parents
   329            val fact_G =
   330              case string_ord (version', version) of
   331                EQUAL =>
   332                try_graph ctxt "loading state" Graph.empty (fn () =>
   333                    fold add_node node_lines Graph.empty)
   334              | LESS => Graph.empty (* can't parse old file *)
   335              | GREATER => raise Too_New ()
   336          in
   337            trace_msg ctxt (fn () =>
   338                "Loaded fact graph (" ^ graph_info fact_G ^ ")");
   339            {fact_G = fact_G, dirty = SOME []}
   340          end
   341        | _ => empty_state)
   342     end
   343 
   344 fun save _ (state as {dirty = SOME [], ...}) = state
   345   | save ctxt {fact_G, dirty} =
   346     let
   347       fun str_of_entry (name, parents, kind) =
   348         str_of_proof_kind kind ^ " " ^ escape_meta name ^ ": " ^
   349         escape_metas parents ^ "\n"
   350       fun append_entry (name, (kind, (parents, _))) =
   351         cons (name, Graph.Keys.dest parents, kind)
   352       val (banner, entries) =
   353         case dirty of
   354           SOME names =>
   355           (NONE, fold (append_entry o Graph.get_entry fact_G) names [])
   356         | NONE => (SOME (version ^ "\n"), Graph.fold append_entry fact_G [])
   357     in
   358       write_file banner (entries, str_of_entry) (mash_state_file ());
   359       trace_msg ctxt (fn () =>
   360           "Saved fact graph (" ^ graph_info fact_G ^
   361           (case dirty of
   362              SOME dirty =>
   363              "; " ^ string_of_int (length dirty) ^ " dirty fact(s)"
   364            | _ => "") ^  ")");
   365       {fact_G = fact_G, dirty = SOME []}
   366     end
   367 
   368 val global_state =
   369   Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)
   370 
   371 in
   372 
   373 fun mash_map ctxt f =
   374   Synchronized.change global_state (load ctxt ##> (f #> save ctxt))
   375   handle Too_New () => ()
   376 
   377 fun mash_peek ctxt f =
   378   Synchronized.change_result global_state
   379       (perhaps (try (load ctxt)) #> `snd #>> f)
   380 
   381 fun mash_get ctxt =
   382   Synchronized.change_result global_state (perhaps (try (load ctxt)) #> `snd)
   383 
   384 fun mash_unlearn ctxt =
   385   Synchronized.change global_state (fn _ =>
   386       (mash_CLEAR ctxt; (* also removes the state file *)
   387        (true, empty_state)))
   388 
   389 end
   390 
   391 
   392 (*** Isabelle helpers ***)
   393 
   394 fun parent_of_local_thm th =
   395   let
   396     val thy = th |> Thm.theory_of_thm
   397     val facts = thy |> Global_Theory.facts_of
   398     val space = facts |> Facts.space_of
   399     fun id_of s = #id (Name_Space.the_entry space s)
   400     fun max_id (s', _) (s, id) =
   401       let val id' = id_of s' in if id > id' then (s, id) else (s', id') end
   402   in ("", ~1) |> Facts.fold_static max_id facts |> fst end
   403 
   404 val local_prefix = "local" ^ Long_Name.separator
   405 
   406 fun nickname_of th =
   407   if Thm.has_name_hint th then
   408     let val hint = Thm.get_name_hint th in
   409       (* FIXME: There must be a better way to detect local facts. *)
   410       case try (unprefix local_prefix) hint of
   411         SOME suf =>
   412         parent_of_local_thm th ^ Long_Name.separator ^ Long_Name.separator ^ suf
   413       | NONE => hint
   414     end
   415   else
   416     backquote_thm (Proof_Context.init_global (Thm.theory_of_thm th)) th
   417 
   418 fun find_suggested_facts suggs facts =
   419   let
   420     fun add_fact (fact as (_, th)) = Symtab.default (nickname_of th, fact)
   421     val tab = Symtab.empty |> fold add_fact facts
   422     fun find_sugg (name, weight) =
   423       Symtab.lookup tab name |> Option.map (rpair weight)
   424   in map_filter find_sugg suggs end
   425 
   426 fun scaled_avg [] = 0
   427   | scaled_avg xs =
   428     Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
   429 
   430 fun avg [] = 0.0
   431   | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
   432 
   433 fun normalize_scores _ [] = []
   434   | normalize_scores max_facts xs =
   435     let val avg = avg (map snd (take max_facts xs)) in
   436       map (apsnd (curry Real.* (1.0 / avg))) xs
   437     end
   438 
(* Combine several ranked fact lists into at most "max_facts" facts.
   Each entry of "mess" is (global_weight, (sels, unks)):
   - "sels" are facts selected by one source, with scores, in rank order.
   - "unks" are facts that source knows nothing about.
   With a single source, its selections are taken and padded with its
   unknown facts. *)
fun mesh_facts max_facts [(_, (sels, unks))] =
    map fst (take max_facts sels) @ take (max_facts - length sels) unks
  | mesh_facts max_facts mess =
    let
      (* Normalize each source's scores and pair its selections with their
         length. *)
      val mess =
        mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
      val fact_eq = Thm.eq_thm o pairself snd
      (* Contribution of one source to the score of "fact":
         - a selected fact gets global_weight * its score at its rank;
         - a fact unknown to the source contributes nothing;
         - any other fact gets the score of the last selection
           (NONE if there are no selections). *)
      fun score_in fact (global_weight, ((sel_len, sels), unks)) =
        let
          fun score_at j =
            case try (nth sels) j of
              SOME (_, score) => SOME (global_weight * score)
            | NONE => NONE
        in
          case find_index (curry fact_eq fact o fst) sels of
            ~1 => (case find_index (curry fact_eq fact) unks of
                     ~1 => score_at (sel_len - 1)
                   | _ => NONE)
          | rank => score_at rank
        end
      (* Overall weight: scaled integer average of the contributions. *)
      fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
      (* Candidates: union of every source's top "max_facts" selections. *)
      val facts =
        fold (union fact_eq o map fst o take max_facts o snd o fst o snd) mess
             []
    in
      (* Sort by decreasing weight and keep the best "max_facts". *)
      facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
            |> map snd |> take max_facts
    end
   467 
   468 fun thy_feature_of s = ("y" ^ s, 0.5 (* FUDGE *))
   469 fun const_feature_of s = ("c" ^ s, 4.0 (* FUDGE *))
   470 fun free_feature_of s = ("f" ^ s, 5.0 (* FUDGE *))
   471 fun type_feature_of s = ("t" ^ s, 0.5 (* FUDGE *))
   472 fun class_feature_of s = ("s" ^ s, 0.25 (* FUDGE *))
   473 fun status_feature_of status = (string_of_status status, 0.5 (* FUDGE *))
   474 val local_feature = ("local", 2.0 (* FUDGE *))
   475 val lams_feature = ("lams", 0.5 (* FUDGE *))
   476 val skos_feature = ("skos", 0.5 (* FUDGE *))
   477 
   478 fun theory_ord p =
   479   if Theory.eq_thy p then
   480     EQUAL
   481   else if Theory.subthy p then
   482     LESS
   483   else if Theory.subthy (swap p) then
   484     GREATER
   485   else case int_ord (pairself (length o Theory.ancestors_of) p) of
   486     EQUAL => string_ord (pairself Context.theory_name p)
   487   | order => order
   488 
   489 fun thm_ord p =
   490   case theory_ord (pairself theory_of_thm p) of
   491     EQUAL =>
   492     (* Hack to put "xxx_def" before "xxxI" and "xxxE" *)
   493     string_ord (pairself nickname_of (swap p))
   494   | ord => ord
   495 
   496 val freezeT = Type.legacy_freeze_type
   497 
   498 fun freeze (t $ u) = freeze t $ freeze u
   499   | freeze (Abs (s, T, t)) = Abs (s, freezeT T, freeze t)
   500   | freeze (Var ((s, _), T)) = Free (s, freezeT T)
   501   | freeze (Const (s, T)) = Const (s, freezeT T)
   502   | freeze (Free (s, T)) = Free (s, freezeT T)
   503   | freeze t = t
   504 
   505 fun goal_of_thm thy = prop_of #> freeze #> cterm_of thy #> Goal.init
   506 
   507 fun run_prover_for_mash ctxt params prover facts goal =
   508   let
   509     val problem =
   510       {state = Proof.init ctxt, goal = goal, subgoal = 1, subgoal_count = 1,
   511        facts = facts |> map (apfst (apfst (fn name => name ())))
   512                      |> map Untranslated_Fact}
   513   in
   514     get_minimizing_prover ctxt MaSh (K (K ())) prover params (K (K (K "")))
   515                           problem
   516   end
   517 
   518 val bad_types = [@{type_name prop}, @{type_name bool}, @{type_name fun}]
   519 
   520 val logical_consts =
   521   [@{const_name prop}, @{const_name Pure.conjunction}] @ atp_logical_consts
   522 
(* Collect the term, type and class features of the statements "ts".
   Term features are patterns headed by constants or free variables, up to
   nesting depth "term_max_depth". Type features name the type constructors
   occurring in the types of heads (except "bad_types"). Class features name
   the classes, with superclasses, of type variables. "type_max_depth" is
   only used as an on/off switch (types are considered iff it is >= 0).
   Features are deduplicated by name. *)
fun interesting_terms_types_and_classes ctxt thy_name prover term_max_depth
                                        type_max_depth ts =
  let
    val thy = Proof_Context.theory_of ctxt
    (* Names of the variables fixed in "ctxt". *)
    val fixes = map snd (Variable.dest_fixes ctxt)
    val classes = Sign.classes_of thy
    (* Logical constants and prover built-ins yield no term pattern. *)
    fun is_bad_const (x as (s, _)) args =
      member (op =) logical_consts s orelse
      fst (is_built_in_const_for_prover ctxt prover x args)
    (* Add each class of "S" with its superclasses, except "type". *)
    fun add_classes @{sort type} = I
      | add_classes S =
        fold (`(Sorts.super_classes classes)
              #> swap #> op ::
              #> subtract (op =) @{sort type}
              #> map class_feature_of
              #> union (op = o pairself fst)) S
    fun do_add_type (Type (s, Ts)) =
        (not (member (op =) bad_types s)
         ? insert (op = o pairself fst) (type_feature_of s))
        #> fold do_add_type Ts
      | do_add_type (TFree (_, S)) = add_classes S
      | do_add_type (TVar (_, S)) = add_classes S
    fun add_type T = type_max_depth >= 0 ? do_add_type T
    (* String patterns of a term up to "depth". Applications render as
       "f(arg)", and every pattern of the function part is also kept
       without the argument (the "" entry in "qs"). Free variables count
       only at the outermost depth and only when fixed in "ctxt"; they are
       qualified with "thy_name". *)
    fun patternify_term _ ~1 _ = []
      | patternify_term args _ (Const (x as (s, _))) =
        if is_bad_const x args then [] else [s]
      | patternify_term _ depth (Free (s, _)) =
        if depth = term_max_depth andalso member (op =) fixes s then
          [thy_name ^ Long_Name.separator ^ s]
        else
          []
      | patternify_term _ 0 _ = []
      | patternify_term args depth (t $ u) =
        let
          val ps = patternify_term (u :: args) depth t
          val qs = "" :: patternify_term [] (depth - 1) u
        in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
      | patternify_term _ _ _ = []
    fun add_term_pattern feature_of =
      union (op = o pairself fst) o map feature_of oo patternify_term []
    (* Add the patterns of every depth from "depth" down to 0. *)
    fun add_term_patterns _ ~1 _ = I
      | add_term_patterns feature_of depth t =
        add_term_pattern feature_of depth t
        #> add_term_patterns feature_of (depth - 1) t
    fun add_term feature_of = add_term_patterns feature_of term_max_depth
    (* Walk the term: each head contributes its patterns and type, then
       the arguments are visited recursively. *)
    fun add_patterns t =
      let val (head, args) = strip_comb t in
        (case head of
           Const (_, T) => add_term const_feature_of t #> add_type T
         | Free (_, T) => add_term free_feature_of t #> add_type T
         | Var (_, T) => add_type T
         | Abs (_, T, body) => add_type T #> add_patterns body
         | _ => I)
        #> fold add_patterns args
      end
  in [] |> fold add_patterns ts end
   579 
   580 fun is_exists (s, _) = (s = @{const_name Ex} orelse s = @{const_name Ex1})
   581 
   582 val term_max_depth = 1
   583 val type_max_depth = 1
   584 
   585 (* TODO: Generate type classes for types? *)
   586 fun features_of ctxt prover thy (scope, status) ts =
   587   let val thy_name = Context.theory_name thy in
   588     thy_feature_of thy_name ::
   589     interesting_terms_types_and_classes ctxt thy_name prover term_max_depth
   590         type_max_depth ts
   591     |> status <> General ? cons (status_feature_of status)
   592     |> scope <> Global ? cons local_feature
   593     |> exists (not o is_lambda_free) ts ? cons lams_feature
   594     |> exists (exists_Const is_exists) ts ? cons skos_feature
   595   end
   596 
   597 (* Too many dependencies is a sign that a decision procedure is at work. There
   598    isn't much to learn from such proofs. *)
   599 val max_dependencies = 20
   600 
   601 val prover_dependency_default_max_facts = 50
   602 
   603 (* "type_definition_xxx" facts are characterized by their use of "CollectI". *)
   604 val typedef_deps = [@{thm CollectI} |> nickname_of]
   605 
   606 (* "Rep_xxx_inject", "Abs_xxx_inverse", etc., are derived using these facts. *)
   607 val typedef_ths =
   608   @{thms type_definition.Abs_inverse type_definition.Rep_inverse
   609          type_definition.Rep type_definition.Rep_inject
   610          type_definition.Abs_inject type_definition.Rep_cases
   611          type_definition.Abs_cases type_definition.Rep_induct
   612          type_definition.Abs_induct type_definition.Rep_range
   613          type_definition.Abs_image}
   614   |> map nickname_of
   615 
   616 fun is_size_def [dep] th =
   617     (case first_field ".recs" dep of
   618        SOME (pref, _) =>
   619        (case first_field ".size" (nickname_of th) of
   620           SOME (pref', _) => pref = pref'
   621         | NONE => false)
   622      | NONE => false)
   623   | is_size_def _ _ = false
   624 
   625 fun trim_dependencies th deps =
   626   if length deps > max_dependencies then
   627     NONE
   628   else
   629     SOME (if deps = typedef_deps orelse
   630              exists (member (op =) typedef_ths) deps orelse
   631              is_size_def deps th then
   632             []
   633           else
   634             deps)
   635 
   636 fun isar_dependencies_of all_names th =
   637   th |> thms_in_proof (SOME all_names) |> trim_dependencies th
   638 
   639 fun prover_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover
   640                            auto_level facts all_names th =
   641   case isar_dependencies_of all_names th of
   642     SOME [] => (false, SOME [])
   643   | isar_deps =>
   644     let
   645       val thy = Proof_Context.theory_of ctxt
   646       val goal = goal_of_thm thy th
   647       val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal ctxt goal 1
   648       val facts = facts |> filter (fn (_, th') => thm_ord (th', th) = LESS)
   649       fun fix_name ((_, stature), th) = ((fn () => nickname_of th, stature), th)
   650       fun is_dep dep (_, th) = nickname_of th = dep
   651       fun add_isar_dep facts dep accum =
   652         if exists (is_dep dep) accum then
   653           accum
   654         else case find_first (is_dep dep) facts of
   655           SOME ((name, status), th) => accum @ [((name, status), th)]
   656         | NONE => accum (* shouldn't happen *)
   657       val facts =
   658         facts
   659         |> mepo_suggested_facts ctxt params prover
   660                (max_facts |> the_default prover_dependency_default_max_facts)
   661                NONE hyp_ts concl_t
   662         |> fold (add_isar_dep facts) (these isar_deps)
   663         |> map fix_name
   664     in
   665       if verbose andalso auto_level = 0 then
   666         let val num_facts = length facts in
   667           "MaSh: " ^ quote prover ^ " on " ^ quote (nickname_of th) ^
   668           " with " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
   669           "."
   670           |> Output.urgent_message
   671         end
   672       else
   673         ();
   674       case run_prover_for_mash ctxt params prover facts goal of
   675         {outcome = NONE, used_facts, ...} =>
   676         (if verbose andalso auto_level = 0 then
   677            let val num_facts = length used_facts in
   678              "Found proof with " ^ string_of_int num_facts ^ " fact" ^
   679              plural_s num_facts ^ "."
   680              |> Output.urgent_message
   681            end
   682          else
   683            ();
   684          case used_facts |> map fst |> trim_dependencies th of
   685            NONE => (false, isar_deps)
   686          | prover_deps => (true, prover_deps))
   687       | _ => (false, isar_deps)
   688     end
   689 
   690 
   691 (*** High-level communication with MaSh ***)
   692 
   693 fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
   694 
(* Nicknames of the given facts that are maximal in "fact_G": those that are
   not ancestors of another given fact. The search starts from the maximal
   nodes of the graph and walks backwards via immediate predecessors. It
   stops at nodes belonging to the given facts. *)
fun maximal_in_graph fact_G facts =
  let
    val facts = [] |> fold (cons o nickname_of o snd) facts
    (* Set of the nicknames of interest. *)
    val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) facts
    (* Queue "name" unless it is recorded in "seen". *)
    fun insert_new seen name =
      not (Symtab.defined seen name) ? insert (op =) name
    (* "maxs" pairs each candidate maximum with its set of predecessors;
       "news" is the work list. Nodes with more than one immediate successor
       are recorded in "seen" once processed. *)
    fun find_maxes _ (maxs, []) = map snd maxs
      | find_maxes seen (maxs, new :: news) =
        find_maxes
            (seen |> num_keys (Graph.imm_succs fact_G new) > 1
                     ? Symtab.default (new, ()))
            (if Symtab.defined tab new then
               let
                 val newp = Graph.all_preds fact_G [new]
                 fun is_ancestor x yp = member (op =) yp x
                 (* Drop earlier candidates that are ancestors of "new". *)
                 val maxs =
                   maxs |> filter (fn (_, max) => not (is_ancestor max newp))
               in
                 (* "new" is not maximal if it is an ancestor of a candidate.
                    NOTE(review): the "filter_out" below repeats the filter
                    above and appears redundant. *)
                 if exists (is_ancestor new o fst) maxs then
                   (maxs, news)
                 else
                   ((newp, new)
                    :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
                    news)
               end
             else
               (* Not a fact of interest: continue with its predecessors. *)
               (maxs, Graph.Keys.fold (insert_new seen)
                                      (Graph.imm_preds fact_G new) news))
  in find_maxes Symtab.empty ([], Graph.maximals fact_G) end
   724 
   725 fun is_fact_in_graph fact_G (_, th) =
   726   can (Graph.get_node fact_G) (nickname_of th)
   727 
   728 (* use MePo weights for now *)
   729 val weight_raw_mash_facts = weight_mepo_facts
   730 val weight_mash_facts = weight_raw_mash_facts
   731 
   732 (* FUDGE *)
   733 fun weight_of_proximity_fact rank =
   734   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   735 
   736 fun weight_proximity_facts facts =
   737   facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
   738 
   739 val max_proximity_facts = 100
   740 
   741 fun find_mash_suggestions max_facts suggs facts chained raw_unknown =
   742   let
   743     val raw_mash =
   744       facts |> find_suggested_facts suggs
   745             (* The weights currently returned by "mash.py" are too spaced out to
   746                make any sense. *)
   747             |> map fst
   748     val proximity =
   749       facts |> sort (thm_ord o pairself snd o swap)
   750             |> take max_proximity_facts
   751     val mess =
   752       [(0.80 (* FUDGE *), (map (rpair 1.0) chained, [])),
   753        (0.16 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
   754        (0.04 (* FUDGE *), (weight_proximity_facts proximity, []))]
   755     val unknown =
   756       raw_unknown
   757       |> fold (subtract (Thm.eq_thm_prop o pairself snd)) [chained, proximity]
   758   in (mesh_facts max_facts mess, unknown) end
   759 
   760 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   761                          concl_t facts =
   762   let
   763     val thy = Proof_Context.theory_of ctxt
   764     val (fact_G, suggs) =
   765       mash_peek ctxt (fn {fact_G, ...} =>
   766           if Graph.is_empty fact_G then
   767             (fact_G, [])
   768           else
   769             let
   770               val parents = maximal_in_graph fact_G facts
   771               val feats =
   772                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   773             in
   774               (fact_G, mash_QUERY ctxt overlord max_facts (parents, feats))
   775             end)
   776     val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
   777     val unknown = facts |> filter_out (is_fact_in_graph fact_G)
   778   in find_mash_suggestions max_facts suggs facts chained unknown end
   779 
   780 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   781   let
   782     fun maybe_add_from from (accum as (parents, graph)) =
   783       try_graph ctxt "updating graph" accum (fn () =>
   784           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   785     val graph = graph |> Graph.default_node (name, Isar_Proof)
   786     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   787     val (deps, _) = ([], graph) |> fold maybe_add_from deps
   788   in ((name, parents, feats, deps) :: adds, graph) end
   789 
   790 fun reprove_wrt_fact_graph ctxt (name, deps) (reps, graph) =
   791   let
   792     fun maybe_rep_from from (accum as (parents, graph)) =
   793       try_graph ctxt "updating graph" accum (fn () =>
   794           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   795     val graph = graph |> update_fact_graph_node (name, Automatic_Proof)
   796     val (deps, _) = ([], graph) |> fold maybe_rep_from deps
   797   in ((name, deps) :: reps, graph) end
   798 
   799 fun flop_wrt_fact_graph name =
   800   update_fact_graph_node (name, Isar_Proof_wegen_Prover_Flop)
   801 
   802 val learn_timeout_slack = 2.0
   803 
   804 fun launch_thread timeout task =
   805   let
   806     val hard_timeout = time_mult learn_timeout_slack timeout
   807     val birth_time = Time.now ()
   808     val death_time = Time.+ (birth_time, hard_timeout)
   809     val desc = ("Machine learner for Sledgehammer", "")
   810   in Async_Manager.launch MaShN birth_time death_time desc task end
   811 
   812 fun freshish_name () =
   813   Date.fmt ".%Y_%m_%d_%H_%M_%S__" (Date.fromTimeLocal (Time.now ())) ^
   814   serial_string ()
   815 
   816 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
   817                      used_ths =
   818   if is_smt_prover ctxt prover then
   819     ()
   820   else
   821     launch_thread timeout (fn () =>
   822         let
   823           val thy = Proof_Context.theory_of ctxt
   824           val name = freshish_name ()
   825           val feats = features_of ctxt prover thy (Local, General) [t]
   826           val deps = used_ths |> map nickname_of
   827         in
   828           mash_peek ctxt (fn {fact_G, ...} =>
   829               let val parents = maximal_in_graph fact_G facts in
   830                 mash_ADD ctxt overlord [(name, parents, feats, deps)]
   831               end);
   832           (true, "")
   833         end)
   834 
   835 fun sendback sub = Active.sendback_markup (sledgehammerN ^ " " ^ sub)
   836 
   837 val commit_timeout = seconds 30.0
   838 
(* Learns from the proofs of "facts" and returns a summary message.

   Facts not yet in the fact graph ("new_facts") are added together with
   their dependencies:
   - with "run_prover", the dependencies are found by "prover" (via
     "prover_dependencies_of");
   - otherwise, the Isar dependencies are used.
   With "run_prover", facts already in the graph ("old_facts") are also
   reproved by "prover". Definitions ("Non_Rec_Def", "Rec_Def") are always
   learned with no dependencies. A failed prover call gives a new fact no
   dependencies and records an old fact as a flop.

   Learned data is sent to MaSh in batches. A batch is committed whenever
   "commit_timeout" has elapsed since the last one, and once more at the end
   of each of the two phases (new facts, then old facts).

   Messages:
   - At "auto_level" 0, intermediate commits are reported.
   - At "auto_level" 2 or more, the returned message is empty, except for
     the final summary when "verbose" is set.

   The timeout is understood in a very relaxed fashion: "learn_timeout" is
   only checked after a fact has been fully processed, so a single slow fact
   (e.g. a long prover call) can overrun it. *)
fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover
                     auto_level run_prover learn_timeout facts =
  let
    val timer = Timer.startRealTimer ()
    fun next_commit_time () =
      Time.+ (Timer.checkRealTimer timer, commit_timeout)
    val {fact_G, ...} = mash_get ctxt
    (* facts are processed in "thm_ord" order *)
    val facts = facts |> sort (thm_ord o pairself snd)
    (* old facts already have a node in the graph; new ones do not *)
    val (old_facts, new_facts) =
      facts |> List.partition (is_fact_in_graph fact_G)
  in
    if null new_facts andalso (not run_prover orelse null old_facts) then
      if auto_level < 2 then
        "No new " ^ (if run_prover then "automatic" else "Isar") ^
        " proofs to learn." ^
        (if auto_level = 0 andalso not run_prover then
           "\n\nHint: Try " ^ sendback learn_proverN ^
           " to learn from automatic provers."
         else
           "")
      else
        ""
    else
      let
        val all_names = build_all_names nickname_of facts
        (* Dependencies of "th", or NONE if they could not be determined.
           A failed prover call (first component "false") always gives
           NONE. *)
        fun deps_of status th =
          if status = Non_Rec_Def orelse status = Rec_Def then
            SOME []
          else if run_prover then
            prover_dependencies_of ctxt params prover auto_level facts all_names
                                   th
            |> (fn (false, _) => NONE | (true, deps) => deps)
          else
            isar_dependencies_of all_names th
        (* Applies one batch to the fact graph and sends it to MaSh. An empty
           batch leaves the state untouched. "dirty" keeps its list of names
           (extended by the newly added ones) only if the graph was nonempty
           beforehand and nothing was reproved; otherwise it becomes NONE. *)
        fun do_commit [] [] [] state = state
          | do_commit adds reps flops {fact_G, dirty} =
            let
              val was_empty = Graph.is_empty fact_G
              val (adds, fact_G) =
                ([], fact_G) |> fold (add_wrt_fact_graph ctxt) adds
              val (reps, fact_G) =
                ([], fact_G) |> fold (reprove_wrt_fact_graph ctxt) reps
              val fact_G = fact_G |> fold flop_wrt_fact_graph flops
              val dirty =
                case (was_empty, dirty, reps) of
                  (false, SOME names, []) => SOME (map #1 adds @ names)
                | _ => NONE
            in
              mash_ADD ctxt overlord (rev adds);
              mash_REPROVE ctxt overlord reps;
              {fact_G = fact_G, dirty = dirty}
            end
        (* Commits a batch. "adds" arrives in reverse order, hence the "rev".
           "last" marks the final commit of a phase, which prints no progress
           message. *)
        fun commit last adds reps flops =
          (if debug andalso auto_level = 0 then
             Output.urgent_message "Committing..."
           else
             ();
           mash_map ctxt (do_commit (rev adds) reps flops);
           if not last andalso auto_level = 0 then
             let val num_proofs = length adds + length reps in
               "Learned " ^ string_of_int num_proofs ^ " " ^
               (if run_prover then "automatic" else "Isar") ^ " proof" ^
               plural_s num_proofs ^ " in the last " ^
               string_from_time commit_timeout ^ "."
               |> Output.urgent_message
             end
           else
             ())
        (* Accumulator: (adds, (parents, n, next_commit, timed_out)).
           - Once "timed_out" is set, the remaining facts are skipped.
           - Each learned fact becomes the sole parent of the next one.
           - "n" counts the facts with nonempty dependencies. *)
        fun learn_new_fact _ (accum as (_, (_, _, _, true))) = accum
          | learn_new_fact ((_, stature as (_, status)), th)
                           (adds, (parents, n, next_commit, _)) =
            let
              val name = nickname_of th
              val feats =
                features_of ctxt prover (theory_of_thm th) stature [prop_of th]
              val deps = deps_of status th |> these
              val n = n |> not (null deps) ? Integer.add 1
              val adds = (name, parents, feats, deps) :: adds
              val (adds, next_commit) =
                if Time.> (Timer.checkRealTimer timer, next_commit) then
                  (commit false adds [] []; ([], next_commit_time ()))
                else
                  (adds, next_commit)
              val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
            in (adds, ([name], n, next_commit, timed_out)) end
        (* Phase 1: learn the new facts. *)
        val n =
          if null new_facts then
            0
          else
            let
              val last_th = new_facts |> List.last |> snd
              (* crude approximation: the old facts that do not come after
                 the last new fact in "thm_ord" are taken as the potential
                 ancestors of the first new fact *)
              val ancestors =
                old_facts
                |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
              val parents = maximal_in_graph fact_G ancestors
              val (adds, (_, n, _, _)) =
                ([], (parents, 0, next_commit_time (), false))
                |> fold learn_new_fact new_facts
            in commit true adds [] []; n end
        (* Accumulator: ((reps, flops), (n, next_commit, timed_out)).
           - Facts whose dependencies cannot be determined are recorded as
             flops.
           - "n" counts the facts that were successfully reproved. *)
        fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
          | relearn_old_fact ((_, (_, status)), th)
                             ((reps, flops), (n, next_commit, _)) =
            let
              val name = nickname_of th
              val (n, reps, flops) =
                case deps_of status th of
                  SOME deps => (n + 1, (name, deps) :: reps, flops)
                | NONE => (n, reps, name :: flops)
              val (reps, flops, next_commit) =
                if Time.> (Timer.checkRealTimer timer, next_commit) then
                  (commit false [] reps flops; ([], [], next_commit_time ()))
                else
                  (reps, flops, next_commit)
              val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
            in ((reps, flops), (n, next_commit, timed_out)) end
        (* Phase 2 (only with "run_prover"): reprove the old facts. *)
        val n =
          if not run_prover orelse null old_facts then
            n
          else
            let
              val max_isar = 1000 * max_dependencies
              fun kind_of_proof th =
                try (Graph.get_node fact_G) (nickname_of th)
                |> the_default Isar_Proof
              (* Old facts are processed in ascending priority. The priority
                 is built from three terms:
                 - a random term ("random_range 0 max_isar");
                 - a penalty by proof kind: 0 for Isar proofs, "max_isar" for
                   earlier flops, "2 * max_isar" for facts that already have
                   an automatic proof;
                 - minus 500 per Isar dependency ("max_dependencies" if
                   unknown).
                 Facts with many Isar dependencies and no automatic proof
                 therefore tend to come first. *)
              fun priority_of (_, th) =
                random_range 0 max_isar
                + (case kind_of_proof th of
                     Isar_Proof => 0
                   | Automatic_Proof => 2 * max_isar
                   | Isar_Proof_wegen_Prover_Flop => max_isar)
                - 500 * (th |> isar_dependencies_of all_names
                            |> Option.map length
                            |> the_default max_dependencies)
              val old_facts =
                old_facts |> map (`priority_of)
                          |> sort (int_ord o pairself fst)
                          |> map snd
              val ((reps, flops), (n, _, _)) =
                (([], []), (n, next_commit_time (), false))
                |> fold relearn_old_fact old_facts
            in commit true [] reps flops; n end
      in
        if verbose orelse auto_level < 2 then
          "Learned " ^ string_of_int n ^ " nontrivial " ^
          (if run_prover then "automatic" else "Isar") ^ " proof" ^ plural_s n ^
          (if verbose then
             " in " ^ string_from_time (Timer.checkRealTimer timer)
           else
             "") ^ "."
        else
          ""
      end
  end
   994 
   995 fun mash_learn ctxt (params as {provers, timeout, ...}) fact_override chained
   996                run_prover =
   997   let
   998     val css = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
   999     val ctxt = ctxt |> Config.put instantiate_inducts false
  1000     val facts =
  1001       nearly_all_facts ctxt false fact_override Symtab.empty css chained []
  1002                        @{prop True}
  1003     val num_facts = length facts
  1004     val prover = hd provers
  1005     fun learn auto_level run_prover =
  1006       mash_learn_facts ctxt params prover auto_level run_prover infinite_timeout
  1007                        facts
  1008       |> Output.urgent_message
  1009   in
  1010     if run_prover then
  1011       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
  1012        plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^
  1013        " timeout: " ^ string_from_time timeout ^
  1014        ").\n\nCollecting Isar proofs first..."
  1015        |> Output.urgent_message;
  1016        learn 1 false;
  1017        "Now collecting automatic proofs. This may take several hours. You can \
  1018        \safely stop the learning process at any point."
  1019        |> Output.urgent_message;
  1020        learn 0 true)
  1021     else
  1022       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
  1023        plural_s num_facts ^ " for Isar proofs..."
  1024        |> Output.urgent_message;
  1025        learn 0 false)
  1026   end
  1027 
  1028 fun is_mash_enabled () = (getenv "MASH" = "yes")
  1029 fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
  1030 
  1031 (* Generate more suggestions than requested, because some might be thrown out
  1032    later for various reasons. *)
  1033 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
  1034 
  1035 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
  1036    Sledgehammer and Try. *)
  1037 val min_secs_for_learning = 15
  1038 
(* Selects at most "max_facts" facts from "facts" that are relevant to the
   goal with hypotheses "hyp_ts" and conclusion "concl_t".
   - An unknown "fact_filter" raises an error.
   - With "only", "facts" is returned unchanged.
   - Otherwise, the suggestions of MePo and/or MaSh (weight 0.5 each,
     depending on the chosen filter) are meshed. The facts of "add" that
     occur in "facts" are then moved to the front.
   A background MaSh learning thread may be launched ("maybe_learn") in two
   cases: when an explicit filter other than MePo is requested, or when no
   filter is given, MaSh is enabled, and the prover is not an SMT solver. *)
fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
        max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
  if not (subset (op =) (the_list fact_filter, fact_filters)) then
    error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
  else if only then
    facts
  else if max_facts <= 0 orelse null facts then
    []
  else
    let
      (* Learns from Isar proofs only ("run_prover" = false, "auto_level" 2)
         in the background. This happens only if learning is requested, no
         MaSh thread is running yet, and the prover timeout is at least
         "min_secs_for_learning". *)
      fun maybe_learn () =
        if learn andalso not (Async_Manager.has_running_threads MaShN) andalso
           Time.toSeconds timeout >= min_secs_for_learning then
          let val timeout = time_mult learn_timeout_slack timeout in
            launch_thread timeout
                (fn () => (true, mash_learn_facts ctxt params prover 2 false
                                                  timeout facts))
          end
        else
          ()
      (* Default filter: MePo for SMT solvers or when MaSh is disabled;
         otherwise MeSh once MaSh can suggest facts, MePo before that. *)
      val fact_filter =
        case fact_filter of
          SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
        | NONE =>
          if is_smt_prover ctxt prover then
            mepoN
          else if is_mash_enabled () then
            (maybe_learn ();
             if mash_can_suggest_facts ctxt then meshN else mepoN)
          else
            mepoN
      val add_ths = Attrib.eval_thms ctxt add
      (* Puts the facts whose theorems are among "ths" first, followed by the
         other accepted facts, truncated to "max_facts". *)
      fun prepend_facts ths accepts =
        ((facts |> filter (member Thm.eq_thm_prop ths o snd)) @
         (accepts |> filter_out (member Thm.eq_thm_prop ths o snd)))
        |> take max_facts
      fun mepo () =
        mepo_suggested_facts ctxt params prover max_facts NONE hyp_ts concl_t
                             facts
        |> weight_mepo_facts
      (* MaSh is asked for more facts than needed; see "generous_max_facts". *)
      fun mash () =
        mash_suggested_facts ctxt params prover (generous_max_facts max_facts)
            hyp_ts concl_t facts
        |>> weight_mash_facts
      (* MePo contributes unless the filter is MaSh; MaSh contributes unless
         the filter is MePo. MeSh uses both. *)
      val mess =
        [] |> (if fact_filter <> mashN then cons (0.5, (mepo (), [])) else I)
           |> (if fact_filter <> mepoN then cons (0.5, (mash ())) else I)
    in
      mesh_facts max_facts mess
      |> not (null add_ths) ? prepend_facts add_ths
    end
  1090 
  1091 fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
  1092 fun running_learners () = Async_Manager.running_threads MaShN "learner"
  1093 
  1094 end;