src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
author wenzelm
Thu Jul 18 20:53:22 2013 +0200 (2013-07-18 ago)
changeset 52697 6fb98a20c349
parent 52196 2281f33e8da6
child 53082 369e39511555
permissions -rw-r--r--
explicit padding on command boundary for "auto" generated sendback -- do not replace the corresponding goal command, but append to it;
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3 
     4 Sledgehammer's machine-learning-based relevance filter (MaSh).
     5 *)
     6 
     (* Interface of Sledgehammer's machine-learning-based relevance filter
        (MaSh). *)
     7 signature SLEDGEHAMMER_MASH =
     8 sig
         (* Abbreviations for types declared in other ATP/Sledgehammer modules. *)
     9   type stature = ATP_Problem_Generate.stature
    10   type raw_fact = Sledgehammer_Fact.raw_fact
    11   type fact = Sledgehammer_Fact.fact
    12   type fact_override = Sledgehammer_Fact.fact_override
    13   type params = Sledgehammer_Provers.params
    14   type relevance_fudge = Sledgehammer_Provers.relevance_fudge
    15   type prover_result = Sledgehammer_Provers.prover_result
    16 
         (* Configuration options. *)
    17   val trace : bool Config.T
    18   val snow : bool Config.T
         (* Names of the fact filters (display and option spellings) and of the
            learning commands. *)
    19   val MePoN : string
    20   val MaShN : string
    21   val MeShN : string
    22   val mepoN : string
    23   val mashN : string
    24   val meshN : string
    25   val unlearnN : string
    26   val learn_isarN : string
    27   val learn_proverN : string
    28   val relearn_isarN : string
    29   val relearn_proverN : string
    30   val fact_filters : string list
         (* Escaping and parsing of the text exchanged with the external tool. *)
    31   val encode_str : string -> string
    32   val encode_strs : string list -> string
    33   val unencode_str : string -> string
    34   val unencode_strs : string -> string list
    35   val encode_features : (string * real) list -> string
    36   val extract_suggestions : string -> string * string list
    37 
         (* Low-level communication with MaSh. *)
    38   structure MaSh:
    39   sig
    40     val unlearn : Proof.context -> unit
    41     val learn :
    42       Proof.context -> bool
    43       -> (string * string list * (string * real) list * string list) list -> unit
    44     val relearn :
    45       Proof.context -> bool -> (string * string list) list -> unit
    46     val query :
    47       Proof.context -> bool -> bool -> int
    48       -> string list * (string * real) list * string list -> string list
    49   end
    50 
         (* Middle-level state handling and Isabelle helpers. *)
    51   val mash_unlearn : Proof.context -> unit
    52   val nickname_of_thm : thm -> string
    53   val find_suggested_facts :
    54     Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
    55   val mesh_facts :
    56     ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list
    57     -> 'a list
    58   val crude_thm_ord : thm * thm -> order
    59   val thm_less : thm * thm -> bool
    60   val goal_of_thm : theory -> thm -> thm
    61   val run_prover_for_mash :
    62     Proof.context -> params -> string -> fact list -> thm -> prover_result
    63   val features_of :
    64     Proof.context -> string -> theory -> stature -> term list
    65     -> (string * real) list
    66   val trim_dependencies : string list -> string list option
    67   val isar_dependencies_of :
    68     string Symtab.table * string Symtab.table -> thm -> string list
    69   val prover_dependencies_of :
    70     Proof.context -> params -> string -> int -> raw_fact list
    71     -> string Symtab.table * string Symtab.table -> thm
    72     -> bool * string list
         (* High-level communication with MaSh. *)
    73   val weight_mepo_facts : 'a list -> ('a * real) list
    74   val weight_mash_facts : 'a list -> ('a * real) list
    75   val find_mash_suggestions :
    76     Proof.context -> int -> string list -> ('b * thm) list -> ('b * thm) list
    77     -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
    78   val mash_suggested_facts :
    79     Proof.context -> params -> string -> int -> term list -> term
    80     -> raw_fact list -> fact list * fact list
    81   val mash_learn_proof :
    82     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    83     -> unit
    84   val attach_parents_to_facts :
    85     ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list
    86   val mash_learn :
    87     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    88   val is_mash_enabled : unit -> bool
    89   val mash_can_suggest_facts : Proof.context -> bool
    90   val generous_max_facts : int -> int
    91   val mepo_weight : real
    92   val mash_weight : real
    93   val relevant_facts :
    94     Proof.context -> params -> string -> int -> fact_override -> term list
    95     -> term -> raw_fact list -> (string * fact list) list
    96   val kill_learners : unit -> unit
    97   val running_learners : unit -> unit
    98 end;
    99 
   100 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH =
   101 struct
   102 
   103 open ATP_Util
   104 open ATP_Problem_Generate
   105 open Sledgehammer_Util
   106 open Sledgehammer_Fact
   107 open Sledgehammer_Provers
   108 open Sledgehammer_Minimize
   109 open Sledgehammer_MePo
   110 
   (* Boolean configuration options, both defaulting to false.  "trace" gates
      the output of trace_msg; "snow" adds the "--snow" flag to the mash.py
      command line. *)
   111 val trace =
   112   Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
   113 val snow =
   114   Attrib.setup_config_bool @{binding sledgehammer_mash_snow} (K false)
   115 
   (* Emits the lazily computed message via "tracing" when the "trace" option
      is set in ctxt; otherwise msg is never evaluated. *)
   fun trace_msg ctxt msg =
     (case Config.get ctxt trace of
        true => tracing (msg ())
      | false => ())
   117 
   (* Mixed-case spellings of the three fact filter names. *)
   118 val MePoN = "MePo"
   119 val MaShN = "MaSh"
   120 val MeShN = "MeSh"
   121 
       (* Lower-case spellings of the same filter names. *)
   122 val mepoN = "mepo"
   123 val mashN = "mash"
   124 val meshN = "mesh"
   125 
       (* All fact filters, by lower-case name. *)
   126 val fact_filters = [meshN, mepoN, mashN]
   127 
       (* Names of the learning-related commands. *)
   128 val unlearnN = "unlearn"
   129 val learn_isarN = "learn_isar"
   130 val learn_proverN = "learn_prover"
   131 val relearn_isarN = "relearn_isar"
   132 val relearn_proverN = "relearn_prover"
   133 
   (* Directory "$ISABELLE_HOME_USER/mash" holding the MaSh model; it is
      created on every call via Isabelle_System.mkdir. *)
   fun mash_model_dir () =
     let val dir = Path.explode "$ISABELLE_HOME_USER/mash" in
       Isabelle_System.mkdir dir;
       dir
     end

   (* The ML-side state lives in the same directory as the model. *)
   val mash_state_dir = mash_model_dir

   (* Path of the file "state" inside the state directory. *)
   fun mash_state_file () =
     let val dir = mash_state_dir ()
     in Path.append dir (Path.explode "state") end
   138 
   139 
   140 (*** Low-level communication with MaSh ***)
   141 
   (* Best-effort removal of the file with the given (string) path; any failure
      caught by "try" is ignored. *)
   fun wipe_out_file file =
     let val _ = try (fn name => File.rm (Path.explode name)) file
     in () end
   143 
   (* Writes the items xs, rendered by f, to path.  If a banner is given, the
      file is first overwritten with it; the items are then appended in chunks
      of 500.  I/O errors are silently ignored. *)
   fun write_file banner (xs, f) path =
     let
       fun append_chunk chunk = File.append path (space_implode "" (map f chunk))
     in
       Option.app (File.write path) banner;
       List.app append_chunk (chunk_list 500 xs)
     end
     handle IO.Io _ => ()
   149 
   (* Runs the external "mash.py" script once.

      - overlord: if true, the work files live in $ISABELLE_HOME_USER under
        fixed names, a log file is kept, and nothing is deleted afterwards;
        otherwise serial-numbered files in $ISABELLE_TMP are used and removed
        when done.
      - save: passes "--saveModel" to the script.
      - max_suggs: value of "--numberOfPredictions".
      - write_cmds: pair (items, printer) written to the command file.
      - read_suggs: receives a thunk yielding the lines of the suggestion file
        (or [] if it cannot be read) and computes the overall result.

      All file names spliced into the shell command are quoted, so that
      directories containing blanks or shell metacharacters do not break the
      invocation. *)
   fun run_mash_tool ctxt overlord save max_suggs write_cmds read_suggs =
     let
       val (temp_dir, serial) =
         if overlord then (getenv "ISABELLE_HOME_USER", "")
         else (getenv "ISABELLE_TMP", serial_string ())
       val log_file = if overlord then temp_dir ^ "/mash_log" else "/dev/null"
       val err_file = temp_dir ^ "/mash_err" ^ serial
       val sugg_file = temp_dir ^ "/mash_suggs" ^ serial
       val sugg_path = Path.explode sugg_file
       val cmd_file = temp_dir ^ "/mash_commands" ^ serial
       val cmd_path = Path.explode cmd_file
       (* already shell-quoted; suffixes appended below stay in the same word *)
       val model_dir = File.shell_path (mash_model_dir ())
       val core =
         "--inputFile " ^ File.shell_quote cmd_file ^
         " --predictions " ^ File.shell_quote sugg_file ^
         " --numberOfPredictions " ^ string_of_int max_suggs ^
         (* " --learnTheories" ^ *) (if save then " --saveModel" else "")
       val command =
         "cd \"$ISABELLE_SLEDGEHAMMER_MASH\"/src; " ^
         "./mash.py --quiet" ^
         (if Config.get ctxt snow then " --snow" else "") ^
         " --outputDir " ^ model_dir ^
         " --modelFile=" ^ model_dir ^ "/model.pickle" ^
         " --dictsFile=" ^ model_dir ^ "/dict.pickle" ^
         " --theoryFile=" ^ model_dir ^ "/theory.pickle" ^
         " --log " ^ File.shell_quote log_file ^ " " ^ core ^
         " >& " ^ File.shell_quote err_file
       fun run_on () =
         (Isabelle_System.bash command
          (* when tracing, report whatever the script wrote to its error file *)
          |> tap (fn _ => trace_msg ctxt (fn () =>
                 case try File.read (Path.explode err_file) of
                   NONE => "Done"
                 | SOME "" => "Done"
                 | SOME s => "Error: " ^ elide_string 1000 s));
          read_suggs (fn () => try File.read_lines sugg_path |> these))
       fun clean_up () =
         if overlord then ()
         else List.app wipe_out_file [err_file, sugg_file, cmd_file]
     in
       (* start from an empty suggestion file, then write the commands *)
       write_file (SOME "") ([], K "") sugg_path;
       write_file (SOME "") write_cmds cmd_path;
       trace_msg ctxt (fn () => "Running " ^ command);
       with_cleanup clean_up run_on ()
     end
   193 
   (* Escapes one character: alphanumerics and the characters "_.()," are kept
      as they are; anything else becomes "%" followed by its character code
      rendered through stringN_of_int 3 (fixed width, in case more digits
      follow). *)
   fun meta_char c =
     if Char.isAlphaNum c orelse Char.contains "_.()," c then String.str c
     else "%" ^ stringN_of_int 3 (Char.ord c)
   201 
   (* Inverse of meta_char on character lists: decodes every "%ddd" escape and
      copies all other characters.  accum holds the decoded characters in
      reverse order.  Any malformed escape (fewer than three characters after
      "%", a non-digit among them, or a code above Char.maxOrd) yields the
      empty string as error value instead of raising an exception. *)
   fun unmeta_chars accum [] = String.implode (rev accum)
     | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
       if List.all Char.isDigit [d1, d2, d3] then
         (case Int.fromString (String.implode [d1, d2, d3]) of
            SOME n =>
            if n <= Char.maxOrd then unmeta_chars (Char.chr n :: accum) cs
            else "" (* error: code out of range *)
          | NONE => "" (* error *))
       else
         "" (* error: escape is not three decimal digits *)
     | unmeta_chars _ (#"%" :: _) = "" (* error *)
     | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
   209 
   (* Escapes every character of a string with meta_char. *)
   fun encode_str s = String.translate meta_char s

   (* Escapes each string and joins the results with single blanks. *)
   fun encode_strs ss = space_implode " " (map encode_str ss)

   (* Decodes a string produced by encode_str. *)
   fun unencode_str s = unmeta_chars [] (String.explode s)

   (* Splits at blanks, drops empty fields, and decodes the remaining ones. *)
   fun unencode_strs s =
     map unencode_str (filter_out (fn field => field = "") (space_explode " " s))
   215 
   (* Builds a name from the current local date and time followed by a fresh
      serial string. *)
   fun freshish_name () =
     let
       val timestamp =
         Date.fmt ".%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ()))
     in timestamp ^ serial_string () end
   219 
   (* Renders one feature: the escaped name, followed by "=<weight>" unless the
      weight equals 1.0. *)
   fun encode_feature (name, weight) =
     let
       val weight_suffix =
         if Real.== (weight, 1.0) then "" else "=" ^ Markup.print_real weight
     in encode_str name ^ weight_suffix end

   (* Renders a feature list, separating the features by single blanks. *)
   fun encode_features feats = space_implode " " (map encode_feature feats)
   225 
   (* Command line "! name: parents; features; deps". *)
   fun str_of_learn (name, parents, feats, deps) =
     String.concat
       ["! ", encode_str name, ": ", encode_strs parents, "; ",
        encode_features feats, "; ", encode_strs deps, "\n"]

   (* Command line "p name: deps". *)
   fun str_of_relearn (name, deps) =
     String.concat ["p ", encode_str name, ": ", encode_strs deps, "\n"]
   232 
   (* Query line "? parents; features[; hints]".  With learn_hints set and
      hints present, the hints are not appended to the query; instead a learn
      line under a freshish name precedes it. *)
   fun str_of_query learn_hints (parents, feats, hints) =
     let
       val have_hints = not (null hints)
       val learn_line =
         if learn_hints andalso have_hints then
           str_of_learn (freshish_name (), parents, feats, hints)
         else
           ""
       val hint_part =
         if not learn_hints andalso have_hints then "; " ^ encode_strs hints
         else ""
     in
       learn_line ^ "? " ^ encode_strs parents ^ "; " ^ encode_features feats ^
       hint_part ^ "\n"
     end
   239 
   (* Parses one suggestion of the form "name" or "name=weight" and returns the
      decoded name; anything else is rejected.  The weight is dropped, because
      the weights currently returned by "mash.py" are too spaced out to make
      any sense. *)
   fun extract_suggestion sugg =
     (case space_explode "=" sugg of
        [] => NONE
      | name :: rest =>
        (* at most one "=weight" part may follow the name *)
        if length rest <= 1 then SOME (unencode_str name) else NONE)
   248 
   (* Parses a line "goal: sugg1 sugg2 ..." into the decoded goal name and the
      list of well-formed suggestions; a line that does not contain exactly one
      ":" yields ("", []). *)
   fun extract_suggestions line =
     (case space_explode ":" line of
        [goal, suggs] =>
        let
          val names = map_filter extract_suggestion (space_explode " " suggs)
        in (unencode_str goal, names) end
      | _ => ("", []))
   254 
   (* Thin wrappers around run_mash_tool for the individual MaSh operations. *)
   structure MaSh =
   struct

   (* Tries to delete every file in the model directory; failures are ignored. *)
   fun unlearn ctxt =
     let
       val dir = mash_model_dir ()
       fun remove file _ = try File.rm (Path.append dir (Path.basic file))
     in
       trace_msg ctxt (K "MaSh unlearn");
       try (File.fold_dir remove dir) NONE;
       ()
     end

   (* Sends the given learn commands and has the model saved; no-op on []. *)
   fun learn _ _ [] = ()
     | learn ctxt overlord learns =
       let
         fun describe () =
           "MaSh learn " ^ elide_string 1000 (space_implode " " (map #1 learns))
       in
         trace_msg ctxt describe;
         run_mash_tool ctxt overlord true 0 (learns, str_of_learn) (K ())
       end

   (* Sends the given relearn commands and has the model saved; no-op on []. *)
   fun relearn _ _ [] = ()
     | relearn ctxt overlord relearns =
       let
         fun describe () =
           "MaSh relearn " ^
           elide_string 1000 (space_implode " " (map #1 relearns))
       in
         trace_msg ctxt describe;
         run_mash_tool ctxt overlord true 0 (relearns, str_of_relearn) (K ())
       end

   (* Asks for at most max_suggs suggestions and returns those found on the
      last line of the suggestion file.  The model is saved only if hints are
      actually being learned. *)
   fun query ctxt overlord learn_hints max_suggs (query as (_, feats, hints)) =
     let
       val save_model = learn_hints andalso not (null hints)
       fun read_last get_lines =
         (case get_lines () of
            [] => []
          | lines => snd (extract_suggestions (List.last lines)))
     in
       trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
       (run_mash_tool ctxt overlord save_model max_suggs
            ([query], str_of_query learn_hints) read_last
        handle List.Empty => [])
     end

   end;
   290 
   291 
   292 (*** Middle-level communication with MaSh ***)
   293 
   (* Kind of proof recorded for a fact in the access graph. *)
   datatype proof_kind =
     Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop

   (* One-letter code used for a proof kind in the state file. *)
   fun str_of_proof_kind kind =
     (case kind of
        Isar_Proof => "i"
      | Automatic_Proof => "a"
      | Isar_Proof_wegen_Prover_Flop => "x")

   (* Inverse of str_of_proof_kind; deliberately partial, i.e. any other string
      raises Match. *)
   fun proof_kind_of_str code =
     (case code of
        "i" => Isar_Proof
      | "a" => Automatic_Proof
      | "x" => Isar_Proof_wegen_Prover_Flop)
   304 
   (* Makes sure that the node "name" exists (created with kind Isar_Proof if
      absent) and overwrites its kind unless the new kind is Isar_Proof.
      FIXME: Here a "Graph.update_node" function would be useful *)
   fun update_access_graph_node (name, kind) G =
     let val G' = Graph.default_node (name, Isar_Proof) G in
       if kind <> Isar_Proof then Graph.map_node name (K kind) G' else G'
     end
   309 
   (* Evaluates f ().  Graph exceptions and all other non-interrupt exceptions
      are reported through trace_msg (mentioning the activity "when") and
      replaced by the default result def; interrupts are reraised. *)
   fun try_graph ctxt when def f =
     let
       fun give_up msg = (trace_msg ctxt msg; def)
     in
       f ()
       handle Graph.CYCLES (cycle :: _) =>
              give_up (fn () =>
                  "Cycle involving " ^ commas cycle ^ " when " ^ when)
            | Graph.DUP name =>
              give_up (fn () =>
                  "Duplicate fact " ^ quote name ^ " when " ^ when)
            | Graph.UNDEF name =>
              give_up (fn () =>
                  "Unknown fact " ^ quote name ^ " when " ^ when)
            | exn =>
              if Exn.is_interrupt exn then
                reraise exn
              else
                give_up (fn () =>
                    "Internal error when " ^ when ^ ":\n" ^
                    ML_Compiler.exn_message exn)
     end
   328 
   (* One-line summary of a graph: numbers of nodes, edges, minimal and maximal
      elements. *)
   fun graph_info G =
     let
       val num_nodes = length (Graph.keys G)
       val num_edges =
         fold (fn entry => fn n => length (snd entry) + n) (Graph.dest G) 0
       val num_minimal = length (Graph.minimals G)
       val num_maximal = length (Graph.maximals G)
     in
       string_of_int num_nodes ^ " node(s), " ^
       string_of_int num_edges ^ " edge(s), " ^
       string_of_int num_minimal ^ " minimal, " ^
       string_of_int num_maximal ^ " maximal"
     end
   335 
   (* ML-side MaSh state: the access graph over fact names, plus "dirty".  As
      used by "save", dirty = SOME names lists the nodes still to be appended to
      the state file, while dirty = NONE makes it rewrite the whole file.
      NOTE(review): the nodes created by update_access_graph_node carry a
      proof_kind, so the "unit Graph.T" annotation looks stale -- confirm. *)
   336 type mash_state = {access_G : unit Graph.T, dirty : string list option}
   337 
       (* Empty graph, nothing to be saved. *)
   338 val empty_state = {access_G = Graph.empty, dirty = SOME []}
   339 
   340 local
   341 
   (* Banner written as the first line of the state file; "load" compares it
      with the first line it reads. *)
   342 val version = "*** MaSh version 20130207a ***"
   343 
       (* Raised by "load" when the state file's version line is greater than
          "version". *)
   344 exception Too_New of unit
   345 
   (* Parses a state file line "kind name: parents" into the decoded name, the
      decoded parents, and the proof kind (Isar_Proof if the kind code is not
      recognized).  Lines of any other shape yield NONE. *)
   fun extract_node line =
     (case space_explode ":" line of
        [head, parents] =>
        (case space_explode " " head of
           [kind, name] =>
           let
             val proof_kind = the_default Isar_Proof (try proof_kind_of_str kind)
           in SOME (unencode_str name, unencode_strs parents, proof_kind) end
         | _ => NONE)
      | _ => NONE)
   355 
   (* Loads the MaSh state from the state file unless that has already
      happened.  The first component of the pair is the "already loaded" flag;
      the result always has it set to true. *)
   356 fun load _ (state as (true, _)) = state
   357   | load ctxt _ =
   358     let val path = mash_state_file () in
   359       (true,
   360        case try File.read_lines path of
   361          SOME (version' :: node_lines) =>
   362          let
                  (* make sure the parent node exists before linking it *)
   363            fun add_edge_to name parent =
   364              Graph.default_node (parent, Isar_Proof)
   365              #> Graph.add_edge (parent, name)
   366            fun add_node line =
   367              case extract_node line of
   368                NONE => I (* shouldn't happen *)
   369              | SOME (name, parents, kind) =>
   370                update_access_graph_node (name, kind)
   371                #> fold (add_edge_to name) parents
   372            val access_G =
   373              case string_ord (version', version) of
   374                EQUAL =>
                      (* same version: rebuild the graph; on a graph error
                         try_graph falls back to the empty graph *)
   375                try_graph ctxt "loading state" Graph.empty (fn () =>
   376                    fold add_node node_lines Graph.empty)
   377              | LESS =>
   378                (MaSh.unlearn ctxt; Graph.empty) (* can't parse old file *)
   379              | GREATER => raise Too_New ()
   380          in
   381            trace_msg ctxt (fn () =>
   382                "Loaded fact graph (" ^ graph_info access_G ^ ")");
   383            {access_G = access_G, dirty = SOME []}
   384          end
              (* unreadable or empty state file *)
   385        | _ => empty_state)
   386     end
   387 
   (* Writes pending changes to the state file and returns the state with an
      empty dirty list.  With dirty = SOME names only the entries of those
      nodes are appended; with dirty = NONE the file is rewritten from scratch,
      starting with the version banner.  Nothing happens if no node is dirty. *)
   fun save _ (state as {dirty = SOME [], ...}) = state
     | save ctxt {access_G, dirty} =
       let
         fun str_of_entry (name, parents, kind) =
           String.concat
             [str_of_proof_kind kind, " ", encode_str name, ": ",
              encode_strs parents, "\n"]
         fun append_entry (name, (kind, (parents, _))) =
           cons (name, Graph.Keys.dest parents, kind)
         val (banner, entries) =
           (case dirty of
              NONE =>
              (SOME (version ^ "\n"), Graph.fold append_entry access_G [])
            | SOME names =>
              (NONE, fold (append_entry o Graph.get_entry access_G) names []))
         val dirty_info =
           (case dirty of
              SOME names =>
              "; " ^ string_of_int (length names) ^ " dirty fact(s)"
            | NONE => "")
       in
         write_file banner (entries, str_of_entry) (mash_state_file ());
         trace_msg ctxt (fn () =>
             "Saved fact graph (" ^ graph_info access_G ^ dirty_info ^ ")");
         {access_G = access_G, dirty = SOME []}
       end
   411 
   (* Synchronized variable holding the pair (loaded flag, MaSh state);
      initially nothing is loaded and the state is empty. *)
   412 val global_state =
   413   Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)
   414 
   415 in
   416 
   (* Loads the state if necessary, applies f to it and saves the result, all
      within one update of global_state.  Does nothing if the state file is
      too new. *)
   fun map_state ctxt f =
     Synchronized.change global_state (fn state =>
         let val (loaded, mash_state) = load ctxt state
         in (loaded, save ctxt (f mash_state)) end)
     handle Too_New () => ()
   420 
   (* Applies f to the MaSh state, loading it first if possible; if loading
      fails, the state is used as it is. *)
   fun peek_state ctxt f =
     Synchronized.change_result global_state (fn state =>
         let val state' = the_default state (try (load ctxt) state)
         in (f (snd state'), state') end)
   424 
   (* Discards everything learned so far and resets the in-memory state to the
      empty, already loaded state. *)
   fun clear_state ctxt =
     Synchronized.change global_state (fn _ =>
         let val () = MaSh.unlearn ctxt (* also removes the state file *)
         in (true, empty_state) end)
   429 
   430 end
   431 
   (* Public entry point for forgetting all learned data; see clear_state. *)
   fun mash_unlearn ctxt = clear_state ctxt
   433 
   434 
   435 (*** Isabelle helpers ***)
   436 
   (* Prefix that marks the name hints of local facts. *)
   val local_prefix = "local" ^ Long_Name.separator

   (* Backquoted statement of th, printed in the global context of its theory
      and elided to the given threshold. *)
   fun elided_backquote_thm threshold th =
     let val ctxt = Proof_Context.init_global (Thm.theory_of_thm th)
     in elide_string threshold (backquote_thm ctxt th) end

   (* Name of the theory that th belongs to. *)
   fun thy_name_of_thm th = Context.theory_name (Thm.theory_of_thm th)
   444 
   (* Name under which MaSh knows a theorem: its name hint if it has one (for
      local facts, rebuilt from the theory name, the hint without the "local"
      prefix and the elided statement), otherwise its elided statement. *)
   fun nickname_of_thm th =
     if not (Thm.has_name_hint th) then
       elided_backquote_thm 200 th
     else
       let val hint = Thm.get_name_hint th in
         (* There must be a better way to detect local facts. *)
         (case try (unprefix local_prefix) hint of
            NONE => hint
          | SOME suf =>
            space_implode Long_Name.separator
              [thy_name_of_thm th, suf, elided_backquote_thm 50 th])
       end
   457 
   (* Maps suggested nicknames back to facts.  For facts sharing a nickname
      the first one wins; nicknames without a fact are traced and dropped. *)
   fun find_suggested_facts ctxt facts =
     let
       val tab =
         fold (fn fact as (_, th) => Symtab.default (nickname_of_thm th, fact))
           facts Symtab.empty
       fun lookup nick =
         (case Symtab.lookup tab nick of
            NONE =>
            (trace_msg ctxt (fn () => "Cannot find " ^ quote nick); NONE)
          | found => found)
     in map_filter lookup end
   467 
   (* Integer average: the sum is scaled by 100000000 and rounded up before
      the integer division by the number of elements; 0 for the empty list. *)
   fun scaled_avg [] = 0
     | scaled_avg xs =
       let val total = fold (fn (x : real) => fn acc => x + acc) xs 0.0
       in Real.ceil (100000000.0 * total) div length xs end

   (* Arithmetic mean; 0.0 for the empty list. *)
   fun avg [] = 0.0
     | avg xs =
       let val total = fold (fn (x : real) => fn acc => x + acc) xs 0.0
       in total / Real.fromInt (length xs) end
   474 
   (* Rescales all scores by the reciprocal of the average score of the first
      max_facts entries. *)
   fun normalize_scores _ [] = []
     | normalize_scores max_facts xs =
       let val factor = 1.0 / avg (map snd (take max_facts xs))
       in map (fn (x, score) => (x, factor * score)) xs end
   480 
   (* Combines several weighted rankings ("sels": selected facts with scores,
      "unks": facts a ranking has no score for) into one list of at most
      max_facts facts.  A single ranking is passed through, topped up with its
      unknown facts.  Otherwise the scores are normalized, every fact among the
      top max_facts of some ranking gets the scaled average of its weighted
      scores (0.0 where it is neither selected nor unknown; rankings where it
      is unknown do not count), and the facts are sorted by decreasing
      weight. *)
   fun mesh_facts _ max_facts [(_, (sels, unks))] =
       map fst (take max_facts sels) @ take (max_facts - length sels) unks
     | mesh_facts fact_eq max_facts mess =
       let
         val mess =
           map (fn (weight, (sels, unks)) =>
                   (weight, (normalize_scores max_facts sels, unks))) mess
         fun score_in fact (global_weight, (sels, unks)) =
           (case find_index (fn (fact', _) => fact_eq (fact, fact')) sels of
              ~1 =>
              if exists (fn fact' => fact_eq (fact, fact')) unks then NONE
              else SOME 0.0
            | rank =>
              Option.map (fn (_, score) => global_weight * score)
                (try (nth sels) rank))
         fun weight_of fact = scaled_avg (map_filter (score_in fact) mess)
         val facts =
           fold (fn (_, (sels, _)) =>
                    union fact_eq (map fst (take max_facts sels))) mess []
         val ranked =
           sort (fn ((w1, _), (w2, _)) => int_ord (w2, w1))
             (map (fn fact => (weight_of fact, fact)) facts)
       in take max_facts (map snd ranked) end
   506 
   (* Feature constructors: each pairs a feature name with a hand-tuned
      ("FUDGE") weight.  Theory, constant, free variable, type and class
      features carry a one-letter prefix ("y", "c", "f", "t", "s"); status
      features use string_of_status without a prefix. *)
   507 fun thy_feature_of s = ("y" ^ s, 8.0 (* FUDGE *))
   508 fun const_feature_of s = ("c" ^ s, 32.0 (* FUDGE *))
   509 fun free_feature_of s = ("f" ^ s, 40.0 (* FUDGE *))
   510 fun type_feature_of s = ("t" ^ s, 4.0 (* FUDGE *))
   511 fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *))
   512 fun status_feature_of status = (string_of_status status, 2.0 (* FUDGE *))
       (* Fixed features, added by features_of. *)
   513 val local_feature = ("local", 8.0 (* FUDGE *))
   514 val lams_feature = ("lams", 2.0 (* FUDGE *))
   515 val skos_feature = ("skos", 2.0 (* FUDGE *))
   516 
   (* The following "crude" functions should be progressively phased out, since
      they create visibility edges that do not exist in Isabelle, resulting in
      failed lookups later on. *)

   (* Orders theories: a subtheory comes before its proper supertheories;
      unrelated theories are compared by their number of ancestors and then by
      name. *)
   fun crude_theory_ord (thy1, thy2) =
     let
       val num_ancestors = length o Theory.ancestors_of
     in
       if Theory.subthy (thy1, thy2) then
         if Theory.eq_thy (thy1, thy2) then EQUAL else LESS
       else if Theory.subthy (thy2, thy1) then
         GREATER
       else
         (case int_ord (num_ancestors thy1, num_ancestors thy2) of
            EQUAL =>
            string_ord (Context.theory_name thy1, Context.theory_name thy2)
          | ord => ord)
     end
   529 
   (* Orders theorems by their theories (crude_theory_ord) and, within one
      theory, by nickname, with names ending in "_def" first. *)
   fun crude_thm_ord (th1, th2) =
     (case crude_theory_ord (theory_of_thm th1, theory_of_thm th2) of
        EQUAL =>
        let
          val nick1 = nickname_of_thm th1
          val nick2 = nickname_of_thm th2
          val is_def = String.isSuffix "_def"
        in
          (* Hack to put "xxx_def" before "xxxI" and "xxxE" *)
          (case bool_ord (is_def nick2, is_def nick1) of
             EQUAL => string_ord (nick1, nick2)
           | ord => ord)
        end
      | ord => ord)
   540 
   (* The theory of the first theorem is a subtheory of that of the second. *)
   fun thm_less_eq (th1, th2) =
     Theory.subthy (theory_of_thm th1, theory_of_thm th2)

   (* Strict version: the subtheory relation holds in one direction only. *)
   fun thm_less (th1, th2) =
     thm_less_eq (th1, th2) andalso not (thm_less_eq (th2, th1))
   543 
   val freezeT = Type.legacy_freeze_type

   (* Replaces every schematic variable by a free variable with the same base
      name and applies freezeT to all types occurring in the term. *)
   fun freeze t =
     (case t of
        f $ x => freeze f $ freeze x
      | Abs (s, T, body) => Abs (s, freezeT T, freeze body)
      | Var ((s, _), T) => Free (s, freezeT T)
      | Const (s, T) => Const (s, freezeT T)
      | Free (s, T) => Free (s, freezeT T)
      | _ => t)

   (* Initial goal state for the frozen proposition of a theorem. *)
   fun goal_of_thm thy th = Goal.init (cterm_of thy (freeze (prop_of th)))
   554 
   (* Runs the minimizing variant of the given prover in MaSh mode on the first
      (and only) subgoal of goal, offering it exactly the given facts. *)
   fun run_prover_for_mash ctxt params prover facts goal =
     {state = Proof.init ctxt, goal = goal, subgoal = 1, subgoal_count = 1,
      factss = [("", facts)]}
     |> get_minimizing_prover ctxt MaSh (K (K ())) prover params (K (K (K "")))
   564 
   (* Type constructors for which term_features_of emits no type feature. *)
   565 val bad_types = [@{type_name prop}, @{type_name bool}, @{type_name fun}]
   566 
       (* Constants that term_features_of always treats as built-in, in addition
          to whatever the prover itself regards as built-in. *)
   567 val logical_consts =
   568   [@{const_name prop}, @{const_name Pure.conjunction}] @ atp_logical_consts
   569 
       (* Upper bound on the number of patterns kept for each side of an
          application in patternify_term. *)
   570 val max_pattern_breadth = 10
   571 
   (* Collects the term, type and class features of the terms ts.  Features
      are accumulated as a list without duplicate names. *)
   572 fun term_features_of ctxt prover thy_name term_max_depth type_max_depth ts =
   573   let
   574     val thy = Proof_Context.theory_of ctxt
           (* logical constants are always built-in; otherwise ask the prover *)
   575     fun is_built_in (x as (s, _)) args =
   576       if member (op =) logical_consts s then (true, args)
   577       else is_built_in_const_of_prover ctxt prover x args
   578     val fixes = map snd (Variable.dest_fixes ctxt)
   579     val classes = Sign.classes_of thy
           (* one feature for each class of the sort and each of its
              superclasses, except for the classes of the default sort "type" *)
   580     fun add_classes @{sort type} = I
   581       | add_classes S =
   582         fold (`(Sorts.super_classes classes)
   583               #> swap #> op ::
   584               #> subtract (op =) @{sort type}
   585               #> map class_feature_of
   586               #> union (op = o pairself fst)) S
           (* one feature per type constructor outside bad_types, recursively;
              type variables contribute the classes of their sorts *)
   587     fun do_add_type (Type (s, Ts)) =
   588         (not (member (op =) bad_types s)
   589          ? insert (op = o pairself fst) (type_feature_of s))
   590         #> fold do_add_type Ts
   591       | do_add_type (TFree (_, S)) = add_classes S
   592       | do_add_type (TVar (_, S)) = add_classes S
   593     fun add_type T = type_max_depth >= 0 ? do_add_type T
           (* patterns "f(x)(y)..." of a term down to the given depth; args are
              the arguments the current head is applied to *)
   594     fun patternify_term _ 0 _ = []
   595       | patternify_term args _ (Const (x as (s, _))) =
   596         if fst (is_built_in x args) then [] else [s]
   597       | patternify_term _ depth (Free (s, _)) =
               (* fixed variables only count at the top level and are qualified
                  by the theory name *)
   598         if depth = term_max_depth andalso member (op =) fixes s then
   599           [thy_name ^ Long_Name.separator ^ s]
   600         else
   601           []
   602       | patternify_term args depth (t $ u) =
   603         let
   604           val ps =
   605             take max_pattern_breadth (patternify_term (u :: args) depth t)
                 (* "" stands for "argument omitted from the pattern" *)
   606           val qs =
   607             take max_pattern_breadth ("" :: patternify_term [] (depth - 1) u)
   608         in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
   609       | patternify_term _ _ _ = []
   610     fun add_term_pattern feature_of =
   611       union (op = o pairself fst) o map feature_of oo patternify_term []
           (* add the patterns of every depth from "depth" down to 1 *)
   612     fun add_term_patterns _ 0 _ = I
   613       | add_term_patterns feature_of depth t =
   614         add_term_pattern feature_of depth t
   615         #> add_term_patterns feature_of (depth - 1) t
   616     fun add_term feature_of = add_term_patterns feature_of term_max_depth
           (* traverse a term by head and arguments *)
   617     fun add_patterns t =
   618       case strip_comb t of
   619         (Const (x as (_, T)), args) =>
   620         let val (built_in, args) = is_built_in x args in
   621           (not built_in ? add_term const_feature_of t)
   622           #> add_type T
   623           #> fold add_patterns args
   624         end
   625       | (head, args) =>
   626         (case head of
                  (* NOTE(review): this Const case appears unreachable, since
                     Const heads are already caught by the branch above *)
   627            Const (_, T) => add_term const_feature_of t #> add_type T
   628          | Free (_, T) => add_term free_feature_of t #> add_type T
   629          | Var (_, T) => add_type T
   630          | Abs (_, T, body) => add_type T #> add_patterns body
   631          | _ => I)
   632         #> fold add_patterns args
   633   in [] |> fold add_patterns ts end
   634 
   (* Recognizes the constants "Ex" and "Ex1". *)
   fun is_exists (s, _) = member (op =) [@{const_name Ex}, @{const_name Ex1}] s

   val term_max_depth = 2
   val type_max_depth = 2

   (* Features of a fact or goal with the given stature and terms: the features
      of the terms plus a theory feature, and, where applicable, features for a
      non-general status, a non-global scope, the presence of lambdas and the
      presence of existential quantifiers. *)
   (* TODO: Generate type classes for types? *)
   fun features_of ctxt prover thy (scope, status) ts =
     let
       val thy_name = Context.theory_name thy
       val term_feats =
         term_features_of ctxt prover thy_name term_max_depth type_max_depth ts
       fun add_if cond feat feats = if cond then feat :: feats else feats
     in
       (thy_feature_of thy_name :: term_feats)
       |> add_if (status <> General) (status_feature_of status)
       |> add_if (scope <> Global) local_feature
       |> add_if (exists (not o is_lambda_free) ts) lams_feature
       |> add_if (exists (exists_Const is_exists) ts) skos_feature
     end
   650 
   651 (* Too many dependencies is a sign that a decision procedure is at work. There
   652    isn't much to learn from such proofs. *)
   653 val max_dependencies = 20
   654 
       (* Number of facts given to the prover in prover_dependencies_of when the
          "max_facts" parameter is unset. *)
   655 val prover_default_max_facts = 50
   656 
   657 (* "type_definition_xxx" facts are characterized by their use of "CollectI". *)
   658 val typedef_dep = nickname_of_thm @{thm CollectI}
   659 (* Mysterious parts of the class machinery create lots of proofs that refer
   660    exclusively to "someI_ex" (and to some internal constructions). *)
   661 val class_some_dep = nickname_of_thm @{thm someI_ex}
   662 
       (* Nicknames of the "fundef" facts; isar_dependencies_of discards the
          dependencies of any proof that mentions one of them. *)
   663 val fundef_ths =
   664   @{thms fundef_ex1_existence fundef_ex1_uniqueness fundef_ex1_iff
   665          fundef_default_value}
   666   |> map nickname_of_thm
   667 
   668 (* "Rep_xxx_inject", "Abs_xxx_inverse", etc., are derived using these facts. *)
   669 val typedef_ths =
   670   @{thms type_definition.Abs_inverse type_definition.Rep_inverse
   671          type_definition.Rep type_definition.Rep_inject
   672          type_definition.Abs_inject type_definition.Rep_cases
   673          type_definition.Abs_cases type_definition.Rep_induct
   674          type_definition.Abs_induct type_definition.Rep_range
   675          type_definition.Abs_image}
   676   |> map nickname_of_thm
   677 
   (* Holds if there is exactly one dependency, it contains ".recs", the
      nickname of th contains ".size", and the parts in front of these two
      markers coincide. *)
   fun is_size_def [dep] th =
       (case (first_field ".recs" dep,
              first_field ".size" (nickname_of_thm th)) of
          (SOME (pref, _), SOME (pref', _)) => pref = pref'
        | _ => false)
     | is_size_def _ _ = false
   686 
   (* Returns the dependencies unchanged, or NONE if there are more than
      max_dependencies of them. *)
   fun trim_dependencies deps =
     if length deps <= max_dependencies then SOME deps else NONE
   689 
   (* Facts used in the proof of th.  The result is empty for proofs that
      consist only of the typedef or class marker fact, that mention a fundef
      or typedef fact, or that are recognized by is_size_def. *)
   fun isar_dependencies_of name_tabs th =
     let
       val deps = thms_in_proof (SOME name_tabs) th
       fun mentions ths = exists (member (op =) ths) deps
       val uninformative =
         deps = [typedef_dep] orelse
         deps = [class_some_dep] orelse
         mentions fundef_ths orelse
         mentions typedef_ths orelse
         is_size_def deps th
     in
       if uninformative then [] else deps
     end
   701 
   (* Tries to reprove th with the given prover.  Returns (true, names of the
      facts the prover used) if a proof is found, (false, Isar dependencies) if
      not, and (false, []) if th has no Isar dependencies to begin with.
      Progress messages are printed if "verbose" is set and auto_level = 0. *)
   702 fun prover_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover
   703                            auto_level facts name_tabs th =
   704   case isar_dependencies_of name_tabs th of
   705     [] => (false, [])
   706   | isar_deps =>
   707     let
   708       val thy = Proof_Context.theory_of ctxt
   709       val goal = goal_of_thm thy th
   710       val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal 1 ctxt
             (* only facts from strictly smaller theories are eligible *)
   711       val facts = facts |> filter (fn (_, th') => thm_less (th', th))
   712       fun nickify ((_, stature), th) = ((nickname_of_thm th, stature), th)
   713       fun is_dep dep (_, th) = nickname_of_thm th = dep
             (* append an Isar dependency that the selection below has missed *)
   714       fun add_isar_dep facts dep accum =
   715         if exists (is_dep dep) accum then
   716           accum
   717         else case find_first (is_dep dep) facts of
   718           SOME ((_, status), th) => accum @ [(("", status), th)]
   719         | NONE => accum (* shouldn't happen *)
             (* MePo selection, extended by the Isar dependencies and renamed to
                nicknames *)
   720       val facts =
   721         facts
   722         |> mepo_suggested_facts ctxt params prover
   723                (max_facts |> the_default prover_default_max_facts) NONE hyp_ts
   724                concl_t
   725         |> fold (add_isar_dep facts) isar_deps
   726         |> map nickify
   727     in
   728       if verbose andalso auto_level = 0 then
   729         let val num_facts = length facts in
   730           "MaSh: " ^ quote prover ^ " on " ^ quote (nickname_of_thm th) ^
   731           " with " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
   732           "."
   733           |> Output.urgent_message
   734         end
   735       else
   736         ();
   737       case run_prover_for_mash ctxt params prover facts goal of
               (* outcome = NONE means that a proof was found *)
   738         {outcome = NONE, used_facts, ...} =>
   739         (if verbose andalso auto_level = 0 then
   740            let val num_facts = length used_facts in
   741              "Found proof with " ^ string_of_int num_facts ^ " fact" ^
   742              plural_s num_facts ^ "."
   743              |> Output.urgent_message
   744            end
   745          else
   746            ();
   747          (true, map fst used_facts))
   748       | _ => (false, isar_deps)
   749     end
   750 
   751 
   752 (*** High-level communication with MaSh ***)
   753 
   (* Walks graph G backwards from "Graph.maximals G" and collects those members
      of "keys" that are not predecessors of another collected key.  The result
      is the list of collected node names.
      NOTE(review): "Graph.maximals" presumably yields the nodes without
      successors -- confirm against the Graph structure. *)
   754 fun maximal_wrt_graph G keys =
   755   let
         (* Set of requested keys, for constant-time membership tests. *)
   756     val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
         (* Adds "name" to a work list unless it is recorded in "seen". *)
   757     fun insert_new seen name =
   758       not (Symtab.defined seen name) ? insert (op =) name
         (* Number of elements of a Graph.Keys set. *)
   759     fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
         (* "maxs" pairs each collected node with the list of all its
            predecessors; the second component is the work list. *)
   760     fun find_maxes _ (maxs, []) = map snd maxs
   761       | find_maxes seen (maxs, new :: news) =
   762         find_maxes
                 (* Only nodes with more than one immediate successor are
                    recorded as seen. *)
   763             (seen |> num_keys (Graph.imm_succs G new) > 1
   764                      ? Symtab.default (new, ()))
   765             (if Symtab.defined tab new then
   766                let
   767                  val newp = Graph.all_preds G [new]
                      (* "is_ancestor x yp" holds if x occurs in the predecessor
                         list yp. *)
   768                  fun is_ancestor x yp = member (op =) yp x
                      (* Drop collected nodes that are predecessors of "new". *)
   769                  val maxs =
   770                    maxs |> filter (fn (_, max) => not (is_ancestor max newp))
   771                in
                        (* "new" is below an already collected node: skip it. *)
   772                  if exists (is_ancestor new o fst) maxs then
   773                    (maxs, news)
   774                  else
                          (* NOTE(review): this filter_out repeats the filter
                             applied to "maxs" just above. *)
   775                    ((newp, new)
   776                     :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
   777                     news)
   778                end
   779              else
                      (* Not a requested key: continue with its immediate
                         predecessors, checked against the old "seen" table. *)
   780                (maxs, Graph.Keys.fold (insert_new seen)
   781                                       (Graph.imm_preds G new) news))
   782   in find_maxes Symtab.empty ([], Graph.maximals G) end
   783 
   (* Applies "maximal_wrt_graph" to the nicknames of the theorems carried by
      "facts" (the second component of each fact). *)
   fun maximal_wrt_access_graph access_G facts =
     maximal_wrt_graph access_G (map (fn (_, th) => nickname_of_thm th) facts)
   787 
   (* Tests whether the nickname of the theorem extracted from "fact" by
      "get_th" can be looked up as a node of "access_G". *)
   fun is_fact_in_graph access_G get_th fact =
     let val nick = nickname_of_thm (get_th fact)
     in is_some (try (Graph.get_node access_G) nick) end
   790 
   (* FUDGE *)
   (* Weight of the fact at 0-based position "rank": 0.62 raised to
      log2 (rank + 1), so rank 0 gets weight 1.0. *)
   fun weight_of_mepo_fact rank =
     let val position = Real.fromInt (rank + 1)
     in Math.pow (0.62, log2 position) end
   794 
   (* Pairs every fact with the MePo weight of its position in the list. *)
   fun weight_mepo_facts facts =
     map_index (fn (rank, fact) => (fact, weight_of_mepo_fact rank)) facts

   (* MaSh suggestions currently use the same weighting as MePo. *)
   val weight_raw_mash_facts = weight_mepo_facts
   val weight_mash_facts = weight_raw_mash_facts
   800 
   (* FUDGE *)
   (* Weight of the fact at 0-based position "rank" in the proximity ordering:
      15.0 plus 1.3 raised to (15.5 - 0.2 * rank). *)
   fun weight_of_proximity_fact rank =
     let val exponent = 15.5 - 0.2 * Real.fromInt rank
     in Math.pow (1.3, exponent) + 15.0 end
   804 
   (* Pairs every fact with the proximity weight of its position in the list. *)
   fun weight_proximity_facts facts =
     let val ranks = 0 upto (length facts - 1)
     in facts ~~ map weight_of_proximity_fact ranks end

   (* Upper bound on the number of facts taken for the proximity component. *)
   val max_proximity_facts = 100
   809 
   (* Combines MaSh's suggestions with chained and nearby facts.

      With no suggestions the result is ([], raw_unknown).  Otherwise three
      weighted lists are meshed into at most "max_facts" facts: the chained facts
      among "raw_unknown", the facts found for the suggestions "suggs", and the
      first "max_proximity_facts" facts in reverse "crude_thm_ord" order.  The
      second component is "raw_unknown" without the chained and proximity
      facts. *)
   fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown)
     | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
       let
         (* Facts are identified by the proposition of their theorem. *)
         fun same_prop fact_pair = Thm.eq_thm_prop (pairself snd fact_pair)
         (* "crude_thm_ord" with its arguments flipped. *)
         fun flipped_ord (fact1, fact2) = crude_thm_ord (snd fact2, snd fact1)

         val raw_mash = find_suggested_facts ctxt facts suggs
         val unknown_chained = inter same_prop chained raw_unknown
         val proximity = take max_proximity_facts (sort flipped_ord facts)

         val chained_part =
           (0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, []))
         val mash_part =
           (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown))
         val proximity_part =
           (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))

         val unknown =
           raw_unknown
           |> subtract same_prop unknown_chained
           |> subtract same_prop proximity
         val meshed =
           mesh_facts same_prop max_facts
                      [chained_part, mash_part, proximity_part]
       in (meshed, unknown) end
   828 
   (* Queries MaSh for facts relevant to the goal given by "hyp_ts" and
      "concl_t".

      If the access graph is empty no query is made.  Otherwise MaSh is queried
      with the maximal facts of the graph as parents, the features of the goal,
      and the chained facts that occur in the graph as hints.  The result is the
      pair computed by "find_mash_suggestions", with both components converted
      by "fact_of_raw_fact". *)
   fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts
                            hyp_ts concl_t facts =
     let
       val thy = Proof_Context.theory_of ctxt
       fun is_chained ((_, (scope, _)), _) = (scope = Chained)
       val chained = filter is_chained facts
       val (access_G, suggs) =
         peek_state ctxt (fn {access_G, ...} =>
           let
             val suggestions =
               if Graph.is_empty access_G then
                 []
               else
                 let
                   val parents = maximal_wrt_access_graph access_G facts
                   val goal_feats =
                     features_of ctxt prover thy (Local, General)
                                 (concl_t :: hyp_ts)
                   val hints =
                     map (nickname_of_thm o snd)
                         (filter (is_fact_in_graph access_G snd) chained)
                 in
                   MaSh.query ctxt overlord learn max_facts
                              (parents, goal_feats, hints)
                 end
           in (access_G, suggestions) end)
       (* Facts that have no node in the access graph. *)
       val unknown = filter_out (is_fact_in_graph access_G snd) facts
       val (suggested, still_unknown) =
         find_mash_suggestions ctxt max_facts suggs facts chained unknown
     in
       (map fact_of_raw_fact suggested, map fact_of_raw_fact still_unknown)
     end
   855 
   (* Registers "name" in the access graph (kind Isar_Proof unless the node
      already exists) and prepends the learning tuple to "learns".

      Edges are added from each parent to "name"; the returned graph contains
      those edges.  The dependencies are run through the same edge check, but
      the graph built from them is dropped: only the resulting list is kept.
      NOTE(review): presumably "try_graph" returns the unchanged accumulator
      when adding an edge fails (e.g. it would create a cycle) -- confirm. *)
   fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
     let
       fun link_from source (acc as (linked, G)) =
         try_graph ctxt "updating graph" acc (fn () =>
           (source :: linked, Graph.add_edge_acyclic (source, name) G))
       val graph_with_node = Graph.default_node (name, Isar_Proof) graph
       val (linked_parents, graph_with_parents) =
         fold link_from parents ([], graph_with_node)
       val (linked_deps, _) = fold link_from deps ([], graph_with_parents)
     in
       ((name, linked_parents, feats, linked_deps) :: learns,
        graph_with_parents)
     end
   865 
   (* Marks node "name" as Automatic_Proof in the access graph and prepends
      (name, dependencies) to "relearns".

      The dependencies are checked by tentatively adding edges to "name"; the
      graph carrying those edges is dropped, so the returned graph differs from
      the input only in the updated node. *)
   fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
     let
       fun link_from source (acc as (linked, G)) =
         try_graph ctxt "updating graph" acc (fn () =>
           (source :: linked, Graph.add_edge_acyclic (source, name) G))
       val updated_graph = update_access_graph_node (name, Automatic_Proof) graph
       val (linked_deps, _) = fold link_from deps ([], updated_graph)
     in ((name, linked_deps) :: relearns, updated_graph) end
   874 
   (* Marks node "name" as Isar_Proof_wegen_Prover_Flop in the access graph. *)
   fun flop_wrt_access_graph name graph =
     update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop) graph
   877 
   (* Factor by which a timeout is stretched to obtain a learner's deadline. *)
   val learn_timeout_slack = 2.0

   (* Starts "task" as an asynchronous MaSh thread whose death time is the
      current time plus "learn_timeout_slack" times "timeout". *)
   fun launch_thread timeout task =
     let
       val allowance = time_mult learn_timeout_slack timeout
       val started = Time.now ()
       val deadline = Time.+ (started, allowance)
     in
       Async_Manager.thread MaShN started deadline
         ("Machine learner for Sledgehammer", "") task
     end
   887 
   (* Learns, in a background thread, a proof of term "t" that used the theorems
      "used_ths".

      The proof is stored under a fresh name, with the maximal facts of the
      access graph as parents and those "used_ths" that occur in the graph as
      dependencies.  The thread's timeout defaults to "one_day". *)
   fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
                        used_ths =
     let
       fun learn_task () =
         let
           val thy = Proof_Context.theory_of ctxt
           val name = freshish_name ()
           val goal_feats = features_of ctxt prover thy (Local, General) [t]
           val _ =
             peek_state ctxt (fn {access_G, ...} =>
               let
                 val parents = maximal_wrt_access_graph access_G facts
                 val deps =
                   map nickname_of_thm
                       (filter (is_fact_in_graph access_G I) used_ths)
               in
                 MaSh.learn ctxt overlord [(name, parents, goal_feats, deps)]
               end)
         in (true, "") end
     in launch_thread (the_default one_day timeout) learn_task end
   907 
   908 (* In the following functions, chunks are risers w.r.t. "thm_less_eq". *)
   909 
   (* Given "chunks" (lists of theorems) and a theorem "th", returns a pair of
      (1) the chunks, some of them split in two, with an empty chunk prepended,
      and (2) the nicknames of the parent theorems selected for "th". *)
   910 fun chunks_and_parents_for chunks th =
   911   let
         (* Removes the parents that are "thm_less_eq" than "new", then adds
            "new" unless it is "thm_less_eq" than some remaining parent. *)
   912     fun insert_parent new parents =
   913       let val parents = parents |> filter_out (fn p => thm_less_eq (p, new)) in
   914         parents |> forall (fn p => not (thm_less_eq (new, p))) parents
   915                    ? cons new
   916       end
         (* Splits a chunk in front of its first element that is "thm_less_eq"
            than "th".  There is no clause for the empty list: the only caller
            first checks that the last element of the chunk qualifies. *)
   917     fun rechunk seen (rest as th' :: ths) =
   918       if thm_less_eq (th', th) then (rev seen, rest)
   919       else rechunk (th' :: seen) ths
         (* Empty chunks are dropped. *)
   920     fun do_chunk [] accum = accum
   921       | do_chunk (chunk as hd_chunk :: _) (chunks, parents) =
               (* The head of the chunk qualifies: keep the chunk whole. *)
   922         if thm_less_eq (hd_chunk, th) then
   923           (chunk :: chunks, insert_parent hd_chunk parents)
               (* Only a tail of the chunk qualifies: split the chunk there. *)
   924         else if thm_less_eq (List.last chunk, th) then
   925           let val (front, back as hd_back :: _) = rechunk [] chunk in
   926             (front :: back :: chunks, insert_parent hd_back parents)
   927           end
   928         else
               (* Nothing in this chunk contributes a parent. *)
   929           (chunk :: chunks, parents)
   930   in
   931     fold_rev do_chunk chunks ([], [])
   932     |>> cons []
   933     ||> map nickname_of_thm
   934   end
   935 
   (* Pairs every fact in "facts" with the nicknames of its parents.

      The facts are processed in list order, after "old_facts".  When a theorem
      is "thm_less_eq" than its successor and both have the same theory name,
      the successor's only parent is that theorem; otherwise the parents are
      recomputed by "chunks_and_parents_for".  The entries produced for
      "old_facts" are dropped from the result. *)
   fun attach_parents_to_facts _ [] = []
     | attach_parents_to_facts old_facts (facts as (_, first_th) :: _) =
       let
         fun same_theory (th1, th2) =
           (thy_name_of_thm th1 = thy_name_of_thm th2)
         (* The first argument holds the chunks seen so far and the parents of
            the fact at the head of the list. *)
         fun annotate _ [] = []
           | annotate (_, parents) [last_fact] = [(parents, last_fact)]
           | annotate (chunks, parents)
                      ((cur_fact as (_, cur_th))
                       :: (rest as (_, next_th) :: _)) =
             let
               val grown_chunks = app_hd (cons cur_th) chunks
               val next_state =
                 if thm_less_eq (cur_th, next_th) andalso
                    same_theory (cur_th, next_th) then
                   (grown_chunks, [nickname_of_thm cur_th])
                 else
                   chunks_and_parents_for grown_chunks next_th
             in (parents, cur_fact) :: annotate next_state rest end
         val all_facts = old_facts @ facts
         val annotated =
           annotate (chunks_and_parents_for [[]] first_th) all_facts
       in drop (length old_facts) annotated end
   957 
   (* Sendback markup, padded as a command, for the Sledgehammer subcommand
      "sub". *)
   fun sendback sub =
     let val command = sledgehammerN ^ " " ^ sub
     in Active.sendback_markup [Markup.padding_command] command end
   960 
   (* Interval after which the proofs accumulated so far are committed to the
      MaSh state. *)
   961 val commit_timeout = seconds 30.0
   962 
   963 (* The timeout is understood in a very relaxed fashion. *)
       (* Learns proofs for "facts" and returns a status message (possibly "").

          Facts without a node in the access graph are learned first, with the
          dependencies computed by "deps_of".  If "run_prover" is set, the facts
          that already have a node are then relearned from the prover's
          dependencies, or marked as flops when "deps_of" yields NONE.
          Accumulated work is committed once more than "commit_timeout" has
          passed since the previous commit, and processing stops after the first
          fact that finishes later than "learn_timeout" (if given). *)
   964 fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover
   965                      auto_level run_prover learn_timeout facts =
   966   let
   967     val timer = Timer.startRealTimer ()
   968     fun next_commit_time () =
   969       Time.+ (Timer.checkRealTimer timer, commit_timeout)
   970     val {access_G, ...} = peek_state ctxt I
   971     val is_in_access_G = is_fact_in_graph access_G snd
   972     val no_new_facts = forall is_in_access_G facts
   973   in
         (* NOTE(review): "run_prover" is false throughout this branch, so the
            tests on it below always select the Isar wording, and the hint
            depends on "auto_level" alone. *)
   974     if no_new_facts andalso not run_prover then
   975       if auto_level < 2 then
   976         "No new " ^ (if run_prover then "automatic" else "Isar") ^
   977         " proofs to learn." ^
   978         (if auto_level = 0 andalso not run_prover then
   979            "\n\nHint: Try " ^ sendback learn_proverN ^
   980            " to learn from an automatic prover."
   981          else
   982            "")
   983       else
   984         ""
   985     else
   986       let
   987         val name_tabs = build_name_tables nickname_of_thm facts
             (* Dependencies of "th": SOME [] for definitions; NONE when the
                prover run reports failure; otherwise whatever
                "trim_dependencies" returns. *)
   988         fun deps_of status th =
   989           if status = Non_Rec_Def orelse status = Rec_Def then
   990             SOME []
   991           else if run_prover then
   992             prover_dependencies_of ctxt params prover auto_level facts name_tabs
   993                                    th
   994             |> (fn (false, _) => NONE
   995                  | (true, deps) => trim_dependencies deps)
   996           else
   997             isar_dependencies_of name_tabs th
   998             |> trim_dependencies
             (* State transformer that applies a batch to the access graph and
                sends it to MaSh.  With nothing to commit the state is returned
                unchanged. *)
   999         fun do_commit [] [] [] state = state
  1000           | do_commit learns relearns flops {access_G, dirty} =
  1001             let
  1002               val was_empty = Graph.is_empty access_G
  1003               val (learns, access_G) =
  1004                 ([], access_G) |> fold (learn_wrt_access_graph ctxt) learns
  1005               val (relearns, access_G) =
  1006                 ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
  1007               val access_G = access_G |> fold flop_wrt_access_graph flops
                   (* The dirty list is extended only if the graph was non-empty,
                      a dirty list already existed and nothing was relearned;
                      in every other case it becomes NONE. *)
  1008               val dirty =
  1009                 case (was_empty, dirty, relearns) of
  1010                   (false, SOME names, []) => SOME (map #1 learns @ names)
  1011                 | _ => NONE
  1012             in
  1013               MaSh.learn ctxt overlord (rev learns);
  1014               MaSh.relearn ctxt overlord relearns;
  1015               {access_G = access_G, dirty = dirty}
  1016             end
             (* Commits a batch; "last" suppresses the progress message that is
                otherwise printed when "auto_level" is 0. *)
  1017         fun commit last learns relearns flops =
  1018           (if debug andalso auto_level = 0 then
  1019              Output.urgent_message "Committing..."
  1020            else
  1021              ();
  1022            map_state ctxt (do_commit (rev learns) relearns flops);
  1023            if not last andalso auto_level = 0 then
  1024              let val num_proofs = length learns + length relearns in
  1025                "Learned " ^ string_of_int num_proofs ^ " " ^
  1026                (if run_prover then "automatic" else "Isar") ^ " proof" ^
  1027                plural_s num_proofs ^ " in the last " ^
  1028                string_of_time commit_timeout ^ "."
  1029                |> Output.urgent_message
  1030              end
  1031            else
  1032              ())
             (* Fold step for a new fact.  The accumulator is (pending learns,
                (count, next commit time, timed-out flag)); once the flag is set
                the remaining facts are skipped. *)
  1033         fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
  1034           | learn_new_fact (parents, ((_, stature as (_, status)), th))
  1035                            (learns, (n, next_commit, _)) =
  1036             let
  1037               val name = nickname_of_thm th
  1038               val feats =
  1039                 features_of ctxt prover (theory_of_thm th) stature [prop_of th]
  1040               val deps = deps_of status th |> these
                   (* Only proofs with at least one dependency are counted. *)
  1041               val n = n |> not (null deps) ? Integer.add 1
  1042               val learns = (name, parents, feats, deps) :: learns
  1043               val (learns, next_commit) =
  1044                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1045                   (commit false learns [] []; ([], next_commit_time ()))
  1046                 else
  1047                   (learns, next_commit)
  1048               val timed_out =
  1049                 case learn_timeout of
  1050                   SOME timeout => Time.> (Timer.checkRealTimer timer, timeout)
  1051                 | NONE => false
  1052             in (learns, (n, next_commit, timed_out)) end
             (* First pass: learn the facts that are not yet in the graph. *)
  1053         val n =
  1054           if no_new_facts then
  1055             0
  1056           else
  1057             let
  1058               val new_facts =
  1059                 facts |> sort (crude_thm_ord o pairself snd)
  1060                       |> attach_parents_to_facts []
  1061                       |> filter_out (is_in_access_G o snd)
  1062               val (learns, (n, _, _)) =
  1063                 ([], (0, next_commit_time (), false))
  1064                 |> fold learn_new_fact new_facts
  1065             in commit true learns [] []; n end
             (* Fold step for a fact already in the graph; same accumulator shape
                as above, with (relearns, flops) as pending work. *)
  1066         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
  1067           | relearn_old_fact ((_, (_, status)), th)
  1068                              ((relearns, flops), (n, next_commit, _)) =
  1069             let
  1070               val name = nickname_of_thm th
  1071               val (n, relearns, flops) =
  1072                 case deps_of status th of
  1073                   SOME deps => (n + 1, (name, deps) :: relearns, flops)
  1074                 | NONE => (n, relearns, name :: flops)
  1075               val (relearns, flops, next_commit) =
  1076                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1077                   (commit false [] relearns flops;
  1078                    ([], [], next_commit_time ()))
  1079                 else
  1080                   (relearns, flops, next_commit)
  1081               val timed_out =
  1082                 case learn_timeout of
  1083                   SOME timeout => Time.> (Timer.checkRealTimer timer, timeout)
  1084                 | NONE => false
  1085             in ((relearns, flops), (n, next_commit, timed_out)) end
             (* Second pass, only with "run_prover": relearn the old facts. *)
  1086         val n =
  1087           if not run_prover then
  1088             n
  1089           else
  1090             let
  1091               val max_isar = 1000 * max_dependencies
                   (* Kind recorded in the graph snapshot taken at the start;
                      facts without a node count as Isar_Proof. *)
  1092               fun kind_of_proof th =
  1093                 try (Graph.get_node access_G) (nickname_of_thm th)
  1094                 |> the_default Isar_Proof
                   (* Random component plus a bonus per proof kind, minus 500 per
                      Isar dependency.  NOTE(review): the sort below is
                      presumably ascending, so smaller values come first. *)
  1095               fun priority_of (_, th) =
  1096                 random_range 0 max_isar
  1097                 + (case kind_of_proof th of
  1098                      Isar_Proof => 0
  1099                    | Automatic_Proof => 2 * max_isar
  1100                    | Isar_Proof_wegen_Prover_Flop => max_isar)
  1101                 - 500 * length (isar_dependencies_of name_tabs th)
  1102               val old_facts =
  1103                 facts |> filter is_in_access_G
  1104                       |> map (`priority_of)
  1105                       |> sort (int_ord o pairself fst)
  1106                       |> map snd
  1107               val ((relearns, flops), (n, _, _)) =
  1108                 (([], []), (n, next_commit_time (), false))
  1109                 |> fold relearn_old_fact old_facts
  1110             in commit true [] relearns flops; n end
  1111       in
  1112         if verbose orelse auto_level < 2 then
  1113           "Learned " ^ string_of_int n ^ " nontrivial " ^
  1114           (if run_prover then "automatic" else "Isar") ^ " proof" ^ plural_s n ^
  1115           (if verbose then " in " ^ string_of_time (Timer.checkRealTimer timer)
  1116            else "") ^ "."
  1117         else
  1118           ""
  1119       end
  1120   end
  1121 
  (* Runs MaSh learning over nearly all facts and prints the progress and
     result messages.

     With "run_prover" the Isar proofs are collected first (auto level 1) and
     the automatic proofs afterwards (auto level 0), using the first prover of
     "provers".  Without it only the Isar proofs are collected. *)
  fun mash_learn ctxt (params as {provers, timeout, ...}) fact_override chained
                 run_prover =
    let
      (* The rule table is computed from the context as passed in, before
         "instantiate_inducts" is switched off. *)
      val css = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
      val ctxt = Config.put instantiate_inducts false ctxt
      val facts =
        nearly_all_facts ctxt false fact_override Symtab.empty css chained []
                         @{prop True}
      val num_facts = length facts
      val prover = hd provers
      val announce = Output.urgent_message
      val intro =
        "MaShing through " ^ string_of_int num_facts ^ " fact" ^
        plural_s num_facts
      fun learn auto_level with_prover =
        announce (mash_learn_facts ctxt params prover auto_level with_prover
                                   NONE facts)
    in
      if run_prover then
        let
          val timeout_note =
            (case timeout of
              NONE => ""
            | SOME limit => " timeout: " ^ string_of_time limit)
        in
          announce (intro ^ " for automatic proofs (" ^ quote prover ^
                    timeout_note ^ ").\n\nCollecting Isar proofs first...");
          learn 1 false;
          announce ("Now collecting automatic proofs. This may take several \
                    \hours. You can safely stop the learning process at any \
                    \point.");
          learn 0 true
        end
      else
        (announce (intro ^ " for Isar proofs...");
         learn 0 false)
    end
  1154 
  (* MaSh is enabled iff the setting MASH has the value "yes". *)
  fun is_mash_enabled () =
    (case getenv "MASH" of
      "yes" => true
    | _ => false)

  (* MaSh can suggest facts once the access graph is non-empty. *)
  fun mash_can_suggest_facts ctxt =
    let val {access_G, ...} = peek_state ctxt I
    in not (Graph.is_empty access_G) end
  1158 
  (* Generate more suggestions than requested, because some might be thrown out
     later for various reasons: the request is doubled, but at most 50 extra
     facts are added. *)
  fun generous_max_facts max_facts =
    if max_facts < 50 then 2 * max_facts else max_facts + 50
  1162 
  (* Weights given to the MePo and MaSh rankings when "relevant_facts" meshes
     them. *)
  1163 val mepo_weight = 0.5
  1164 val mash_weight = 0.5
  1165 
  1166 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
  1167    Sledgehammer and Try. *)
       (* Minimal timeout, in seconds, for which "relevant_facts" launches a
          learner thread. *)
  1168 val min_secs_for_learning = 15
  1169 
  (* Selects the facts relevant to the goal given by "hyp_ts" and "concl_t".

     The result is a list of (filter name, facts) pairs:
     - with the "only" override, all facts under the name "";
     - with "max_facts <= 0" or no facts, an empty selection under "";
     - with no explicit fact filter and both rankings computed, the meshed,
       MePo and MaSh selections under their respective names;
     - otherwise the selection of the effective filter under "".
     Facts named by the "add" override are put in front of each selection, and
     each selection is cut to "max_facts".  An unknown fact filter is an error.

     As a side effect a background learner may be launched; see "maybe_learn". *)
  fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
          max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
    if not (subset (op =) (the_list fact_filter, fact_filters)) then
      error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
    else if only then
      [("", map fact_of_raw_fact facts)]
    else if max_facts <= 0 orelse null facts then
      [("", [])]
    else
      let
        fun timeout_long_enough () =
          (case timeout of
            NONE => true
          | SOME limit => Time.toSeconds limit >= min_secs_for_learning)
        (* Launches a learner for Isar proofs if learning is requested, no MaSh
           thread is running and the timeout is long enough. *)
        fun maybe_learn () =
          if learn andalso not (Async_Manager.has_running_threads MaShN) andalso
             timeout_long_enough () then
            let
              val learn_timeout =
                Option.map (time_mult learn_timeout_slack) timeout
            in
              launch_thread (the_default one_day learn_timeout)
                  (fn () => (true, mash_learn_facts ctxt params prover 2 false
                                                    learn_timeout facts))
            end
          else
            ()
        val effective_fact_filter =
          (case fact_filter of
            SOME ff => (if ff <> mepoN then maybe_learn () else (); ff)
          | NONE =>
            if not (is_mash_enabled ()) then
              mepoN
            else
              (maybe_learn ();
               if mash_can_suggest_facts ctxt then meshN else mepoN))

        val add_ths = Attrib.eval_thms ctxt add
        fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
        (* Puts the explicitly added facts in front and cuts to "max_facts". *)
        fun add_and_take accepts =
          let
            val merged =
              if null add_ths then
                accepts
              else
                map fact_of_raw_fact (filter in_add facts) @
                filter_out in_add accepts
          in take max_facts merged end

        fun mepo () =
          weight_mepo_facts
            (mepo_suggested_facts ctxt params prover max_facts NONE hyp_ts
                                  concl_t facts)
        fun mash () =
          let
            val (suggested, unknown) =
              mash_suggested_facts ctxt params prover
                  (generous_max_facts max_facts) hyp_ts concl_t facts
          in (weight_mash_facts suggested, unknown) end

        (* MaSh is evaluated before MePo, but MePo comes first in the list; the
           order is important for the "case" expression below. *)
        val mash_part =
          if effective_fact_filter <> mepoN then [(mash_weight, mash ())]
          else []
        val mepo_part =
          if effective_fact_filter <> mashN then [(mepo_weight, (mepo (), []))]
          else []
        val mess = mepo_part @ mash_part
        val mesh =
          add_and_take
            (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess)
      in
        (case (fact_filter, mess) of
          (NONE, [(_, (mepo_ranked, _)), (_, (mash_ranked, _))]) =>
          [(meshN, mesh),
           (mepoN, add_and_take (map fst mepo_ranked)),
           (mashN, add_and_take (map fst mash_ranked))]
        | _ => [("", mesh)])
      end
  1240 
  (* Kills the asynchronous MaSh threads, described as "learner". *)
  fun kill_learners () =
    let val description = "learner"
    in Async_Manager.kill_threads MaShN description end

  (* Reports on the running asynchronous MaSh threads, described as
     "learner". *)
  fun running_learners () =
    let val description = "learner"
    in Async_Manager.running_threads MaShN description end
  1243 
  1244 end;