src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
author blanchet
Thu Nov 21 12:29:29 2013 +0100 (2013-11-21 ago)
changeset 54546 8b403a7a8c44
parent 54503 b490e15a5e19
child 54693 dd5874e4553f
permissions -rw-r--r--
fixed spying so that the environment variables are queried at run-time, not at build-time
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3 
     4 Sledgehammer's machine-learning-based relevance filter (MaSh).
     5 *)
     6 
     7 signature SLEDGEHAMMER_MASH =
     8 sig
     9   type stature = ATP_Problem_Generate.stature
    10   type raw_fact = Sledgehammer_Fact.raw_fact
    11   type fact = Sledgehammer_Fact.fact
    12   type fact_override = Sledgehammer_Fact.fact_override
    13   type params = Sledgehammer_Provers.params
    14   type prover_result = Sledgehammer_Provers.prover_result
    15 
    16   val trace : bool Config.T
    17   val MePoN : string
    18   val MaShN : string
    19   val MeShN : string
    20   val mepoN : string
    21   val mashN : string
    22   val meshN : string
    23   val unlearnN : string
    24   val learn_isarN : string
    25   val learn_proverN : string
    26   val relearn_isarN : string
    27   val relearn_proverN : string
    28   val fact_filters : string list
    29   val encode_str : string -> string
    30   val encode_strs : string list -> string
    31   val unencode_str : string -> string
    32   val unencode_strs : string -> string list
    33   val encode_features : (string * real) list -> string
    34   val extract_suggestions : string -> string * string list
    35 
    36   structure MaSh:
    37   sig
    38     val unlearn : Proof.context -> bool -> unit
    39     val learn :
    40       Proof.context -> bool -> bool
    41       -> (string * string list * string list * string list) list -> unit
    42     val relearn :
    43       Proof.context -> bool -> bool -> (string * string list) list -> unit
    44     val query :
    45       Proof.context -> bool -> int
    46       -> (string * string list * string list * string list) list
    47          * string list * string list * (string * real) list
    48       -> string list
    49   end
    50 
    51   val mash_unlearn : Proof.context -> params -> unit
    52   val is_mash_enabled : unit -> bool
    53   val nickname_of_thm : thm -> string
    54   val find_suggested_facts :
    55     Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
    56   val mesh_facts :
    57     ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list
    58     -> 'a list
    59   val crude_thm_ord : thm * thm -> order
    60   val thm_less : thm * thm -> bool
    61   val goal_of_thm : theory -> thm -> thm
    62   val run_prover_for_mash :
    63     Proof.context -> params -> string -> string -> fact list -> thm -> prover_result
    64   val features_of :
    65     Proof.context -> theory -> int -> int Symtab.table -> stature -> term list ->
    66     (string * real) list
    67   val trim_dependencies : string list -> string list option
    68   val isar_dependencies_of :
    69     string Symtab.table * string Symtab.table -> thm -> string list
    70   val prover_dependencies_of :
    71     Proof.context -> params -> string -> int -> raw_fact list
    72     -> string Symtab.table * string Symtab.table -> thm
    73     -> bool * string list
    74   val attach_parents_to_facts :
    75     ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list
    76   val num_extra_feature_facts : int
    77   val extra_feature_factor : real
    78   val weight_facts_smoothly : 'a list -> ('a * real) list
    79   val weight_facts_steeply : 'a list -> ('a * real) list
    80   val find_mash_suggestions :
    81     Proof.context -> int -> string list -> ('b * thm) list -> ('b * thm) list
    82     -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
    83   val add_const_counts : term -> int Symtab.table -> int Symtab.table
    84   val mash_suggested_facts :
    85     Proof.context -> params -> int -> term list -> term -> raw_fact list -> fact list * fact list
    86   val mash_learn_proof : Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit
    87   val mash_learn :
    88     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    89 
    90   val mash_can_suggest_facts : Proof.context -> bool -> bool
    91   val generous_max_facts : int -> int
    92   val mepo_weight : real
    93   val mash_weight : real
    94   val relevant_facts :
    95     Proof.context -> params -> string -> int -> fact_override -> term list
    96     -> term -> raw_fact list -> (string * fact list) list
    97   val kill_learners : Proof.context -> params -> unit
    98   val running_learners : unit -> unit
    99 end;
   100 
   101 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH =
   102 struct
   103 
   104 open ATP_Util
   105 open ATP_Problem_Generate
   106 open Sledgehammer_Util
   107 open Sledgehammer_Fact
   108 open Sledgehammer_Provers
   109 open Sledgehammer_Minimize
   110 open Sledgehammer_MePo
   111 
   112 val trace =
   113   Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
   114 
   115 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
   116 
   117 val MePoN = "MePo"
   118 val MaShN = "MaSh"
   119 val MeShN = "MeSh"
   120 
   121 val mepoN = "mepo"
   122 val mashN = "mash"
   123 val meshN = "mesh"
   124 
   125 val fact_filters = [meshN, mepoN, mashN]
   126 
   127 val unlearnN = "unlearn"
   128 val learn_isarN = "learn_isar"
   129 val learn_proverN = "learn_prover"
   130 val relearn_isarN = "relearn_isar"
   131 val relearn_proverN = "relearn_prover"
   132 
   133 fun mash_model_dir () =
   134   Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
   135 val mash_state_dir = mash_model_dir
   136 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
   137 
   138 
   139 (*** Low-level communication with MaSh ***)
   140 
   141 val save_models_arg = "--saveModels"
   142 val shutdown_server_arg = "--shutdownServer"
   143 
   144 fun wipe_out_file file = (try (File.rm o Path.explode) file; ())
   145 
   146 fun write_file banner (xs, f) path =
   147   (case banner of SOME s => File.write path s | NONE => ();
   148    xs |> chunk_list 500 |> List.app (File.append path o implode o map f))
   149   handle IO.Io _ => ()
   150 
   151 fun run_mash_tool ctxt overlord extra_args background write_cmds read_suggs =
   152   let
   153     val (temp_dir, serial) =
   154       if overlord then (getenv "ISABELLE_HOME_USER", "")
   155       else (getenv "ISABELLE_TMP", serial_string ())
   156     val log_file = temp_dir ^ "/mash_log" ^ serial
   157     val err_file = temp_dir ^ "/mash_err" ^ serial
   158     val sugg_file = temp_dir ^ "/mash_suggs" ^ serial
   159     val sugg_path = Path.explode sugg_file
   160     val cmd_file = temp_dir ^ "/mash_commands" ^ serial
   161     val cmd_path = Path.explode cmd_file
   162     val model_dir = File.shell_path (mash_model_dir ())
   163     val command =
   164       "cd \"$ISABELLE_SLEDGEHAMMER_MASH\"/src; \
   165       \PYTHONDONTWRITEBYTECODE=y ./mash.py\
   166       \ --quiet\
   167       \ --port=$MASH_PORT\
   168       \ --outputDir " ^ model_dir ^
   169       " --modelFile=" ^ model_dir ^ "/model.pickle\
   170       \ --dictsFile=" ^ model_dir ^ "/dict.pickle\
   171       \ --log " ^ log_file ^
   172       " --inputFile " ^ cmd_file ^
   173       " --predictions " ^ sugg_file ^
   174       (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^
   175       " >& " ^ err_file ^
   176       (if background then " &" else "")
   177     fun run_on () =
   178       (Isabelle_System.bash command
   179        |> tap (fn _ =>
   180             case try File.read (Path.explode err_file) |> the_default "" of
   181               "" => trace_msg ctxt (K "Done")
   182             | s => warning ("MaSh error: " ^ elide_string 1000 s));
   183        read_suggs (fn () => try File.read_lines sugg_path |> these))
   184     fun clean_up () =
   185       if overlord then ()
   186       else List.app wipe_out_file [err_file, sugg_file, cmd_file]
   187   in
   188     write_file (SOME "") ([], K "") sugg_path;
   189     write_file (SOME "") write_cmds cmd_path;
   190     trace_msg ctxt (fn () => "Running " ^ command);
   191     with_cleanup clean_up run_on ()
   192   end
   193 
   194 fun meta_char c =
   195   if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse
   196      c = #")" orelse c = #"," then
   197     String.str c
   198   else
   199     (* fixed width, in case more digits follow *)
   200     "%" ^ stringN_of_int 3 (Char.ord c)
   201 
   202 fun unmeta_chars accum [] = String.implode (rev accum)
   203   | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
   204     (case Int.fromString (String.implode [d1, d2, d3]) of
   205        SOME n => unmeta_chars (Char.chr n :: accum) cs
   206      | NONE => "" (* error *))
   207   | unmeta_chars _ (#"%" :: _) = "" (* error *)
   208   | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
   209 
   210 val encode_str = String.translate meta_char
   211 val encode_strs = map encode_str #> space_implode " "
   212 val unencode_str = String.explode #> unmeta_chars []
   213 val unencode_strs =
   214   space_explode " " #> filter_out (curry (op =) "") #> map unencode_str
   215 
   216 (* Avoid scientific notation *)
   217 fun safe_str_of_real r =
   218   if r < 0.00001 then "0.00001"
   219   else if r >= 1000000.0 then "1000000"
   220   else Markup.print_real r
   221 
   222 fun encode_feature (name, weight) =
   223   encode_str name ^
   224   (if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight)
   225 
   226 val encode_features = map encode_feature #> space_implode " "
   227 
   228 fun str_of_learn (name, parents, feats, deps) =
   229   "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
   230   encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
   231 
   232 fun str_of_relearn (name, deps) =
   233   "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n"
   234 
   235 fun str_of_query max_suggs (learns, hints, parents, feats) =
   236   implode (map str_of_learn learns) ^
   237   "? " ^ string_of_int max_suggs ^ " # " ^ encode_strs parents ^ "; " ^
   238   encode_features feats ^
   239   (if null hints then "" else "; " ^ encode_strs hints) ^ "\n"
   240 
   241 (* The suggested weights don't make much sense. *)
   242 fun extract_suggestion sugg =
   243   case space_explode "=" sugg of
   244     [name, _ (* weight *)] =>
   245     SOME (unencode_str name (* , Real.fromString weight |> the_default 1.0 *))
   246   | [name] => SOME (unencode_str name (* , 1.0 *))
   247   | _ => NONE
   248 
   249 fun extract_suggestions line =
   250   case space_explode ":" line of
   251     [goal, suggs] =>
   252     (unencode_str goal, map_filter extract_suggestion (space_explode " " suggs))
   253   | _ => ("", [])
   254 
   255 structure MaSh =
   256 struct
   257 
   258 fun shutdown ctxt overlord =
   259   (trace_msg ctxt (K "MaSh shutdown");
   260    run_mash_tool ctxt overlord [shutdown_server_arg] false ([], K "") (K ()))
   261 
   262 fun save ctxt overlord =
   263   (trace_msg ctxt (K "MaSh save");
   264    run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ()))
   265 
   266 fun unlearn ctxt overlord =
   267   let val path = mash_model_dir () in
   268     trace_msg ctxt (K "MaSh unlearn");
   269     shutdown ctxt overlord;
   270     try (File.fold_dir (fn file => fn _ =>
   271                            try File.rm (Path.append path (Path.basic file)))
   272                        path) NONE;
   273     ()
   274   end
   275 
   276 fun learn _ _ _ [] = ()
   277   | learn ctxt overlord save learns =
   278     let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
   279       (trace_msg ctxt (fn () => "MaSh learn" ^ (if names = "" then "" else " " ^ names));
   280        run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
   281                      (learns, str_of_learn) (K ()))
   282     end
   283 
   284 fun relearn _ _ _ [] = ()
   285   | relearn ctxt overlord save relearns =
   286     (trace_msg ctxt (fn () => "MaSh relearn " ^
   287          elide_string 1000 (space_implode " " (map #1 relearns)));
   288      run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
   289                    (relearns, str_of_relearn) (K ()))
   290 
   291 fun query ctxt overlord max_suggs (query as (_, _, _, feats)) =
   292   (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
   293    run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs)
   294        (fn suggs =>
   295            case suggs () of
   296              [] => []
   297            | suggs => snd (extract_suggestions (List.last suggs)))
   298    handle List.Empty => [])
   299 
   300 end;
   301 
   302 
   303 (*** Middle-level communication with MaSh ***)
   304 
   305 datatype proof_kind =
   306   Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop
   307 
   308 fun str_of_proof_kind Isar_Proof = "i"
   309   | str_of_proof_kind Automatic_Proof = "a"
   310   | str_of_proof_kind Isar_Proof_wegen_Prover_Flop = "x"
   311 
   312 fun proof_kind_of_str "i" = Isar_Proof
   313   | proof_kind_of_str "a" = Automatic_Proof
   314   | proof_kind_of_str "x" = Isar_Proof_wegen_Prover_Flop
   315 
   316 (* FIXME: Here a "Graph.update_node" function would be useful *)
   317 fun update_access_graph_node (name, kind) =
   318   Graph.default_node (name, Isar_Proof)
   319   #> kind <> Isar_Proof ? Graph.map_node name (K kind)
   320 
   321 fun try_graph ctxt when def f =
   322   f ()
   323   handle Graph.CYCLES (cycle :: _) =>
   324          (trace_msg ctxt (fn () =>
   325               "Cycle involving " ^ commas cycle ^ " when " ^ when); def)
   326        | Graph.DUP name =>
   327          (trace_msg ctxt (fn () =>
   328               "Duplicate fact " ^ quote name ^ " when " ^ when); def)
   329        | Graph.UNDEF name =>
   330          (trace_msg ctxt (fn () =>
   331               "Unknown fact " ^ quote name ^ " when " ^ when); def)
   332        | exn =>
   333          if Exn.is_interrupt exn then
   334            reraise exn
   335          else
   336            (trace_msg ctxt (fn () =>
   337                 "Internal error when " ^ when ^ ":\n" ^
   338                 ML_Compiler.exn_message exn); def)
   339 
   340 fun graph_info G =
   341   string_of_int (length (Graph.keys G)) ^ " node(s), " ^
   342   string_of_int (fold (Integer.add o length o snd) (Graph.dest G) 0) ^
   343   " edge(s), " ^
   344   string_of_int (length (Graph.minimals G)) ^ " minimal, " ^
   345   string_of_int (length (Graph.maximals G)) ^ " maximal"
   346 
   347 type mash_state =
   348   {access_G : unit Graph.T,
   349    num_known_facts : int,
   350    dirty : string list option}
   351 
   352 val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []}
   353 
   354 local
   355 
   356 val version = "*** MaSh version 20130820 ***"
   357 
   358 exception FILE_VERSION_TOO_NEW of unit
   359 
   360 fun extract_node line =
   361   case space_explode ":" line of
   362     [head, parents] =>
   363     (case space_explode " " head of
   364        [kind, name] =>
   365        SOME (unencode_str name, unencode_strs parents,
   366              try proof_kind_of_str kind |> the_default Isar_Proof)
   367      | _ => NONE)
   368   | _ => NONE
   369 
(* Load the access graph from the state file, unless it has already been loaded
   (first component "true"); the result is always "(true, state)".
   A missing or empty file gives the empty state. If the banner on the first
   line is older than "version", the file cannot be parsed and MaSh is
   unlearned instead. If it is newer, FILE_VERSION_TOO_NEW is raised. *)
fun load_state _ _ (state as (true, _)) = state
  | load_state ctxt overlord _ =
    let val path = mash_state_file () in
      (true,
       case try File.read_lines path of
         SOME (version' :: node_lines) =>
         let
           (* Add "parent" as a node if needed, then the edge parent -> name. *)
           fun add_edge_to name parent =
             Graph.default_node (parent, Isar_Proof)
             #> Graph.add_edge (parent, name)
           (* Add the node described by one state-file line, with its edges. *)
           fun add_node line =
             case extract_node line of
               NONE => I (* shouldn't happen *)
             | SOME (name, parents, kind) =>
               update_access_graph_node (name, kind)
               #> fold (add_edge_to name) parents
           (* Only a file with the current banner is parsed; graph errors while
              building yield the empty graph (see try_graph). The fact count is
              the number of node lines. *)
           val (access_G, num_known_facts) =
             case string_ord (version', version) of
               EQUAL =>
               (try_graph ctxt "loading state" Graph.empty (fn () =>
                    fold add_node node_lines Graph.empty),
                length node_lines)
             | LESS =>
               (* can't parse old file *)
               (MaSh.unlearn ctxt overlord; (Graph.empty, 0))
             | GREATER => raise FILE_VERSION_TOO_NEW ()
         in
           trace_msg ctxt (fn () =>
               "Loaded fact graph (" ^ graph_info access_G ^ ")");
           {access_G = access_G, num_known_facts = num_known_facts,
            dirty = SOME []}
         end
       | _ => empty_state)
    end
   404 
   405 fun save_state _ (state as {dirty = SOME [], ...}) = state
   406   | save_state ctxt {access_G, num_known_facts, dirty} =
   407     let
   408       fun str_of_entry (name, parents, kind) =
   409         str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^
   410         encode_strs parents ^ "\n"
   411       fun append_entry (name, (kind, (parents, _))) =
   412         cons (name, Graph.Keys.dest parents, kind)
   413       val (banner, entries) =
   414         case dirty of
   415           SOME names =>
   416           (NONE, fold (append_entry o Graph.get_entry access_G) names [])
   417         | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G [])
   418     in
   419       write_file banner (entries, str_of_entry) (mash_state_file ());
   420       trace_msg ctxt (fn () =>
   421           "Saved fact graph (" ^ graph_info access_G ^
   422           (case dirty of
   423              SOME dirty =>
   424              "; " ^ string_of_int (length dirty) ^ " dirty fact(s)"
   425            | _ => "") ^  ")");
   426       {access_G = access_G, num_known_facts = num_known_facts, dirty = SOME []}
   427     end
   428 
   429 val global_state =
   430   Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)
   431 
in

(* Load the global state if necessary, apply "f" to it, and save the result.
   If the state file was written by a newer MaSh version, nothing happens. *)
fun map_state ctxt overlord f =
  Synchronized.change global_state
                      (load_state ctxt overlord ##> (f #> save_state ctxt))
  handle FILE_VERSION_TOO_NEW () => ()

(* Return "f" applied to the state, loading the state first if necessary; if
   loading fails, the previously held state is used. The state is not saved. *)
fun peek_state ctxt overlord f =
  Synchronized.change_result global_state
      (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)

(* Unlearn everything and reset the global state to empty; the "false" flag
   makes the next access reload from disk. *)
fun clear_state ctxt overlord =
  Synchronized.change global_state (fn _ =>
      (MaSh.unlearn ctxt overlord; (* also removes the state file *)
       (false, empty_state)))

end
   449 
   450 fun mash_unlearn ctxt ({overlord, ...} : params) =
   451   (clear_state ctxt overlord; Output.urgent_message "Reset MaSh.")
   452 
   453 
   454 (*** Isabelle helpers ***)
   455 
   456 fun is_mash_enabled () = (getenv "MASH" = "yes")
   457 
   458 val local_prefix = "local" ^ Long_Name.separator
   459 
   460 fun elided_backquote_thm threshold th =
   461   elide_string threshold
   462     (backquote_thm (Proof_Context.init_global (Thm.theory_of_thm th)) th)
   463 
   464 val thy_name_of_thm = Context.theory_name o Thm.theory_of_thm
   465 
   466 fun nickname_of_thm th =
   467   if Thm.has_name_hint th then
   468     let val hint = Thm.get_name_hint th in
   469       (* There must be a better way to detect local facts. *)
   470       case try (unprefix local_prefix) hint of
   471         SOME suf =>
   472         thy_name_of_thm th ^ Long_Name.separator ^ suf ^
   473         Long_Name.separator ^ elided_backquote_thm 50 th
   474       | NONE => hint
   475     end
   476   else
   477     elided_backquote_thm 200 th
   478 
   479 fun find_suggested_facts ctxt facts =
   480   let
   481     fun add (fact as (_, th)) = Symtab.default (nickname_of_thm th, fact)
   482     val tab = fold add facts Symtab.empty
   483     fun lookup nick =
   484       Symtab.lookup tab nick
   485       |> tap (fn NONE => trace_msg ctxt (fn () => "Cannot find " ^ quote nick)
   486                | _ => ())
   487   in map_filter lookup end
   488 
   489 fun scaled_avg [] = 0
   490   | scaled_avg xs =
   491     Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
   492 
   493 fun avg [] = 0.0
   494   | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
   495 
   496 fun normalize_scores _ [] = []
   497   | normalize_scores max_facts xs =
   498     let val avg = avg (map snd (take max_facts xs)) in
   499       map (apsnd (curry Real.* (1.0 / avg))) xs
   500     end
   501 
   502 fun mesh_facts _ max_facts [(_, (sels, unks))] =
   503     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   504   | mesh_facts fact_eq max_facts mess =
   505     let
   506       val mess = mess |> map (apsnd (apfst (normalize_scores max_facts)))
   507       fun score_in fact (global_weight, (sels, unks)) =
   508         let
   509           val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score)
   510         in
   511           case find_index (curry fact_eq fact o fst) sels of
   512             ~1 => if member fact_eq unks fact then NONE else SOME 0.0
   513           | rank => score_at rank
   514         end
   515       fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
   516       val facts = fold (union fact_eq o map fst o take max_facts o fst o snd) mess []
   517     in
   518       facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
   519             |> map snd |> take max_facts
   520     end
   521 
   522 fun free_feature_of s = ("f" ^ s, 40.0 (* FUDGE *))
   523 fun thy_feature_of s = ("y" ^ s, 8.0 (* FUDGE *))
   524 fun type_feature_of s = ("t" ^ s, 4.0 (* FUDGE *))
   525 fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *))
   526 val local_feature = ("local", 16.0 (* FUDGE *))
   527 
   528 fun crude_theory_ord p =
   529   if Theory.subthy p then
   530     if Theory.eq_thy p then EQUAL else LESS
   531   else if Theory.subthy (swap p) then
   532     GREATER
   533   else case int_ord (pairself (length o Theory.ancestors_of) p) of
   534     EQUAL => string_ord (pairself Context.theory_name p)
   535   | order => order
   536 
   537 fun crude_thm_ord p =
   538   case crude_theory_ord (pairself theory_of_thm p) of
   539     EQUAL =>
   540     let val q = pairself nickname_of_thm p in
   541       (* Hack to put "xxx_def" before "xxxI" and "xxxE" *)
   542       case bool_ord (pairself (String.isSuffix "_def") (swap q)) of
   543         EQUAL => string_ord q
   544       | ord => ord
   545     end
   546   | ord => ord
   547 
   548 val thm_less_eq = Theory.subthy o pairself theory_of_thm
   549 fun thm_less p = thm_less_eq p andalso not (thm_less_eq (swap p))
   550 
   551 val freezeT = Type.legacy_freeze_type
   552 
   553 fun freeze (t $ u) = freeze t $ freeze u
   554   | freeze (Abs (s, T, t)) = Abs (s, freezeT T, freeze t)
   555   | freeze (Var ((s, _), T)) = Free (s, freezeT T)
   556   | freeze (Const (s, T)) = Const (s, freezeT T)
   557   | freeze (Free (s, T)) = Free (s, freezeT T)
   558   | freeze t = t
   559 
   560 fun goal_of_thm thy = prop_of #> freeze #> cterm_of thy #> Goal.init
   561 
   562 fun run_prover_for_mash ctxt params prover goal_name facts goal =
   563   let
   564     val problem =
   565       {comment = "Goal: " ^ goal_name, state = Proof.init ctxt, goal = goal, subgoal = 1,
   566        subgoal_count = 1, factss = [("", facts)]}
   567   in
   568     get_minimizing_prover ctxt MaSh (K ()) prover params (K (K (K ""))) problem
   569   end
   570 
   571 val bad_types = [@{type_name prop}, @{type_name bool}, @{type_name fun}]
   572 
   573 val pat_tvar_prefix = "_"
   574 val pat_var_prefix = "_"
   575 
   576 (* try "Long_Name.base_name" for shorter names *)
   577 fun massage_long_name s = s
   578 
   579 val crude_str_of_sort =
   580   space_implode ":" o map massage_long_name o subtract (op =) @{sort type}
   581 
   582 fun crude_str_of_typ (Type (s, [])) = massage_long_name s
   583   | crude_str_of_typ (Type (s, Ts)) =
   584     massage_long_name s ^ implode (map crude_str_of_typ Ts)
   585   | crude_str_of_typ (TFree (_, S)) = crude_str_of_sort S
   586   | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S
   587 
   588 fun maybe_singleton_str _ "" = []
   589   | maybe_singleton_str pref s = [pref ^ s]
   590 
   591 val max_pat_breadth = 10 (* FUDGE *)
   592 
(* Weighted features of the terms "ts": term patterns (up to "term_max_depth"),
   type patterns (up to "type_max_depth"), and type classes. The result is a
   list of (feature name, weight) pairs without duplicate names. "thy_name" is
   used to qualify features of fixed variables. "num_facts" and "const_tab"
   (constant occurrence counts) determine the constant weights. *)
fun term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth ts =
  let
    val thy = Proof_Context.theory_of ctxt

    val fixes = map snd (Variable.dest_fixes ctxt)
    val classes = Sign.classes_of thy

    (* Class features for each class in "S" and its superclasses, except "type" *)
    fun add_classes @{sort type} = I
      | add_classes S =
        fold (`(Sorts.super_classes classes)
              #> swap #> op ::
              #> subtract (op =) @{sort type} #> map massage_long_name
              #> map class_feature_of
              #> union (eq_fst (op =))) S

    (* Type patterns: a constructor with each argument either left out ("")
       or nested in parentheses; breadth is capped by max_pat_breadth. *)
    fun pattify_type 0 _ = []
      | pattify_type _ (Type (s, [])) =
        if member (op =) bad_types s then [] else [massage_long_name s]
      | pattify_type depth (Type (s, U :: Ts)) =
        let
          val T = Type (s, Ts)
          val ps = take max_pat_breadth (pattify_type depth T)
          val qs = take max_pat_breadth ("" :: pattify_type (depth - 1) U)
        in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
      | pattify_type _ (TFree (_, S)) =
        maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
      | pattify_type _ (TVar (_, S)) =
        maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
    fun add_type_pat depth T =
      union (eq_fst (op =)) (map type_feature_of (pattify_type depth T))
    (* Patterns of every depth from "depth" down to 1 *)
    fun add_type_pats 0 _ = I
      | add_type_pats depth t =
        add_type_pat depth t #> add_type_pats (depth - 1) t
    fun add_type T =
      add_type_pats type_max_depth T
      #> fold_atyps_sorts (add_classes o snd) T
    fun add_subtypes (T as Type (_, Ts)) = add_type T #> fold add_subtypes Ts
      | add_subtypes T = add_type T

    (* 16 plus, if any facts are known, num_facts / occurrence count (default
       1), so rarer constants weigh more. *)
    fun weight_of_const s =
      16.0 +
      (if num_facts = 0 then
         0.0
       else
         let val count = Symtab.lookup const_tab s |> the_default 1 in
           Real.fromInt num_facts / Real.fromInt count (* FUDGE *)
         end)
    (* Weighted term patterns. An application combines the head's patterns
       with the argument's patterns (one level shallower), or with none ("");
       weights add up. Fixed free variables also give a weighted free-variable
       feature; other variables give weight-0.0 patterns named after their
       types. *)
    fun pattify_term _ 0 _ = []
      | pattify_term _ _ (Const (s, _)) =
        if is_widely_irrelevant_const s then [] else [(massage_long_name s, weight_of_const s)]
      | pattify_term _ _ (Free (s, T)) =
        maybe_singleton_str pat_var_prefix (crude_str_of_typ T)
        |> map (rpair 0.0)
        |> (if member (op =) fixes s then
              cons (free_feature_of (massage_long_name
                  (thy_name ^ Long_Name.separator ^ s)))
            else
              I)
      | pattify_term _ _ (Var (_, T)) =
        maybe_singleton_str pat_var_prefix (crude_str_of_typ T) |> map (rpair 0.0)
      | pattify_term Ts _ (Bound j) =
        maybe_singleton_str pat_var_prefix (crude_str_of_typ (nth Ts j)) |> map (rpair 0.0)
      | pattify_term Ts depth (t $ u) =
        let
          val ps = take max_pat_breadth (pattify_term Ts depth t)
          val qs = take max_pat_breadth (("", 0.0) :: pattify_term Ts (depth - 1) u)
        in
          map_product (fn ppw as (p, pw) =>
              fn ("", _) => ppw
               | (q, qw) => (p ^ "(" ^ q ^ ")", pw + qw)) ps qs
        end
      | pattify_term _ _ _ = []
    fun add_term_pat Ts = union (eq_fst (op =)) oo pattify_term Ts
    fun add_term_pats _ 0 _ = I
      | add_term_pats Ts depth t =
        add_term_pat Ts depth t #> add_term_pats Ts (depth - 1) t
    fun add_term Ts = add_term_pats Ts term_max_depth
    (* Walk the term: constant-headed and free-headed applications give term
       patterns (except widely irrelevant constants); head types give type
       features; arguments and abstraction bodies are visited recursively
       ("Ts" holds the types of the enclosing bound variables). *)
    fun add_subterms Ts t =
      case strip_comb t of
        (Const (s, T), args) =>
        (not (is_widely_irrelevant_const s) ? add_term Ts t)
        #> add_subtypes T
        #> fold (add_subterms Ts) args
      | (head, args) =>
        (case head of
           Free (_, T) => add_term Ts t #> add_subtypes T
         | Var (_, T) => add_subtypes T
         | Abs (_, T, body) => add_subtypes T #> add_subterms (T :: Ts) body
         | _ => I)
        #> fold (add_subterms Ts) args
  in [] |> fold (add_subterms []) ts end
   684 
   685 val term_max_depth = 2
   686 val type_max_depth = 1
   687 
   688 (* TODO: Generate type classes for types? *)
   689 fun features_of ctxt thy num_facts const_tab (scope, _) ts =
   690   let val thy_name = Context.theory_name thy in
   691     thy_feature_of thy_name ::
   692     term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth ts
   693     |> scope <> Global ? cons local_feature
   694   end
   695 
   696 (* Too many dependencies is a sign that a decision procedure is at work. There
   697    isn't much to learn from such proofs. *)
   698 val max_dependencies = 20
   699 
   700 val prover_default_max_facts = 25
   701 
   702 (* "type_definition_xxx" facts are characterized by their use of "CollectI". *)
   703 val typedef_dep = nickname_of_thm @{thm CollectI}
   704 (* Mysterious parts of the class machinery create lots of proofs that refer
   705    exclusively to "someI_ex" (and to some internal constructions). *)
   706 val class_some_dep = nickname_of_thm @{thm someI_ex}
   707 
   708 val fundef_ths =
   709   @{thms fundef_ex1_existence fundef_ex1_uniqueness fundef_ex1_iff
   710          fundef_default_value}
   711   |> map nickname_of_thm
   712 
   713 (* "Rep_xxx_inject", "Abs_xxx_inverse", etc., are derived using these facts. *)
   714 val typedef_ths =
   715   @{thms type_definition.Abs_inverse type_definition.Rep_inverse
   716          type_definition.Rep type_definition.Rep_inject
   717          type_definition.Abs_inject type_definition.Rep_cases
   718          type_definition.Abs_cases type_definition.Rep_induct
   719          type_definition.Abs_induct type_definition.Rep_range
   720          type_definition.Abs_image}
   721   |> map nickname_of_thm
   722 
   723 fun is_size_def [dep] th =
   724     (case first_field ".recs" dep of
   725        SOME (pref, _) =>
   726        (case first_field ".size" (nickname_of_thm th) of
   727           SOME (pref', _) => pref = pref'
   728         | NONE => false)
   729      | NONE => false)
   730   | is_size_def _ _ = false
   731 
   732 fun no_dependencies_for_status status =
   733   status = Non_Rec_Def orelse status = Rec_Def
   734 
   735 fun trim_dependencies deps =
   736   if length deps > max_dependencies then NONE else SOME deps
   737 
   738 fun isar_dependencies_of name_tabs th =
   739   let val deps = thms_in_proof (SOME name_tabs) th in
   740     if deps = [typedef_dep] orelse
   741        deps = [class_some_dep] orelse
   742        exists (member (op =) fundef_ths) deps orelse
   743        exists (member (op =) typedef_ths) deps orelse
   744        is_size_def deps th then
   745       []
   746     else
   747       deps
   748   end
   749 
(* Runs "prover" on the goal of "th" to find dependencies to learn from. The
   prover gets the facts suggested by MePo, restricted to facts "thm_less" than
   "th", plus the Isar dependencies of "th".

   Returns:
   - "(false, [])" if "th" has no Isar dependencies worth learning from;
   - "(true, map fst used_facts)" if the prover found a proof;
   - "(false, isar_deps)" otherwise. *)
fun prover_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover
                           auto_level facts name_tabs th =
  case isar_dependencies_of name_tabs th of
    [] => (false, [])
  | isar_deps =>
    let
      val thy = Proof_Context.theory_of ctxt
      val goal = goal_of_thm thy th
      val name = nickname_of_thm th
      val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal 1 ctxt
      (* only facts strictly before "th" may be used *)
      val facts = facts |> filter (fn (_, th') => thm_less (th', th))
      fun nickify ((_, stature), th) = ((nickname_of_thm th, stature), th)
      fun is_dep dep (_, th) = nickname_of_thm th = dep
      (* appends the fact named "dep" unless it is already among "accum" *)
      fun add_isar_dep facts dep accum =
        if exists (is_dep dep) accum then
          accum
        else case find_first (is_dep dep) facts of
          SOME ((_, status), th) => accum @ [(("", status), th)]
        | NONE => accum (* shouldn't happen *)
      val mepo_facts =
        facts
        |> mepo_suggested_facts ctxt params (max_facts |> the_default prover_default_max_facts) NONE
             hyp_ts concl_t
      val facts =
        mepo_facts
        |> fold (add_isar_dep facts) isar_deps
        |> map nickify
      val num_isar_deps = length isar_deps
    in
      if verbose andalso auto_level = 0 then
        "MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^
        " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts."
        |> Output.urgent_message
      else
        ();
      case run_prover_for_mash ctxt params prover name facts goal of
        {outcome = NONE, used_facts, ...} =>
        (if verbose andalso auto_level = 0 then
           let val num_facts = length used_facts in
             "Found proof with " ^ string_of_int num_facts ^ " fact" ^
             plural_s num_facts ^ "."
             |> Output.urgent_message
           end
         else
           ();
         (true, map fst used_facts))
      | _ => (false, isar_deps)
    end
   798 
   799 
   800 (*** High-level communication with MaSh ***)
   801 
   802 (* In the following functions, chunks are risers w.r.t. "thm_less_eq". *)
   803 
(* Splits "chunks" with respect to "th" and computes the parents of "th".

   Each chunk is scanned from its head. If its head is "thm_less_eq" than
   "th", the head becomes a parent candidate. If only its last element is, the
   chunk is split at the first element "thm_less_eq" than "th", and that
   element becomes the candidate. "insert_parent" keeps only candidates that
   are not "thm_less_eq" than another candidate.

   Returns the (possibly split) chunks, with a fresh empty chunk in front, and
   the nicknames of the parents. *)
fun chunks_and_parents_for chunks th =
  let
    (* drops parents below "new"; adds "new" unless it is below a parent *)
    fun insert_parent new parents =
      let val parents = parents |> filter_out (fn p => thm_less_eq (p, new)) in
        parents |> forall (fn p => not (thm_less_eq (new, p))) parents
                   ? cons new
      end
    (* Splits a chunk before its first element that is "thm_less_eq" than
       "th". Only called when such an element exists (the last one is), so the
       missing "[]" case is never reached. *)
    fun rechunk seen (rest as th' :: ths) =
      if thm_less_eq (th', th) then (rev seen, rest)
      else rechunk (th' :: seen) ths
    fun do_chunk [] accum = accum
      | do_chunk (chunk as hd_chunk :: _) (chunks, parents) =
        if thm_less_eq (hd_chunk, th) then
          (chunk :: chunks, insert_parent hd_chunk parents)
        else if thm_less_eq (List.last chunk, th) then
          let val (front, back as hd_back :: _) = rechunk [] chunk in
            (front :: back :: chunks, insert_parent hd_back parents)
          end
        else
          (chunk :: chunks, parents)
  in
    fold_rev do_chunk chunks ([], [])
    |>> cons []
    ||> map nickname_of_thm
  end
   829 
(* Pairs every fact of "facts" with the nicknames of its parents. "old_facts"
   are processed first so that they shape the chunks, but they are dropped from
   the result. Facts are presumably expected in ascending "crude_thm_ord"
   order (the only caller in this chunk sorts them that way) -- TODO confirm.

   Shortcut: if a fact is "thm_less_eq" than the next one and both come from
   the same theory, it is the next fact's sole parent. Otherwise the parents
   are computed by "chunks_and_parents_for". *)
fun attach_parents_to_facts _ [] = []
  | attach_parents_to_facts old_facts (facts as (_, th) :: _) =
    let
      fun do_facts _ [] = []
        | do_facts (_, parents) [fact] = [(parents, fact)]
        | do_facts (chunks, parents)
                   ((fact as (_, th)) :: (facts as (_, th') :: _)) =
          let
            (* the current fact joins the first chunk *)
            val chunks = app_hd (cons th) chunks
            val chunks_and_parents' =
              if thm_less_eq (th, th') andalso
                 thy_name_of_thm th = thy_name_of_thm th' then
                (chunks, [nickname_of_thm th])
              else
                chunks_and_parents_for chunks th'
          in (parents, fact) :: do_facts chunks_and_parents' facts end
    in
      old_facts @ facts
      |> do_facts (chunks_and_parents_for [[]] th)
      |> drop (length old_facts)
    end
   851 
   852 fun maximal_wrt_graph G keys =
   853   let
   854     val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
   855     fun insert_new seen name =
   856       not (Symtab.defined seen name) ? insert (op =) name
   857     fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
   858     fun find_maxes _ (maxs, []) = map snd maxs
   859       | find_maxes seen (maxs, new :: news) =
   860         find_maxes
   861             (seen |> num_keys (Graph.imm_succs G new) > 1
   862                      ? Symtab.default (new, ()))
   863             (if Symtab.defined tab new then
   864                let
   865                  val newp = Graph.all_preds G [new]
   866                  fun is_ancestor x yp = member (op =) yp x
   867                  val maxs =
   868                    maxs |> filter (fn (_, max) => not (is_ancestor max newp))
   869                in
   870                  if exists (is_ancestor new o fst) maxs then
   871                    (maxs, news)
   872                  else
   873                    ((newp, new)
   874                     :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
   875                     news)
   876                end
   877              else
   878                (maxs, Graph.Keys.fold (insert_new seen)
   879                                       (Graph.imm_preds G new) news))
   880   in find_maxes Symtab.empty ([], Graph.maximals G) end
   881 
   882 fun maximal_wrt_access_graph access_G =
   883   map (nickname_of_thm o snd)
   884   #> maximal_wrt_graph access_G
   885 
   886 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   887 
(* In "mash_suggested_facts", feature weights of chained facts and of "extra"
   facts are scaled by these factors before being merged with the goal's
   features. At most "num_extra_feature_facts" minus the number of chained
   facts of the leading facts are considered as extra facts. *)
val chained_feature_factor = 0.5 (* FUDGE *)
val extra_feature_factor = 0.1 (* FUDGE *)
val num_extra_feature_facts = 10 (* FUDGE *)
   891 
   892 (* FUDGE *)
   893 fun weight_of_proximity_fact rank =
   894   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   895 
   896 fun weight_facts_smoothly facts =
   897   facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
   898 
   899 (* FUDGE *)
   900 fun steep_weight_of_fact rank =
   901   Math.pow (0.62, log2 (Real.fromInt (rank + 1)))
   902 
   903 fun weight_facts_steeply facts =
   904   facts ~~ map steep_weight_of_fact (0 upto length facts - 1)
   905 
   906 val max_proximity_facts = 100
   907 
   908 fun find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
   909   let
   910     val inter_fact = inter (eq_snd Thm.eq_thm_prop)
   911     val raw_mash = find_suggested_facts ctxt facts suggs
   912     val proximate = take max_proximity_facts facts
   913     val unknown_chained = inter_fact raw_unknown chained
   914     val unknown_proximate = inter_fact raw_unknown proximate
   915     val mess =
   916       [(0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
   917        (0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, [])),
   918        (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))]
   919     val unknown =
   920       raw_unknown
   921       |> fold (subtract (eq_snd Thm.eq_thm_prop))
   922               [unknown_chained, unknown_proximate]
   923   in (mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess, unknown) end
   924 
   925 fun add_const_counts t =
   926   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1))
   927        (Term.add_const_names t [])
   928 
(* Asks MaSh for up to "max_facts" suggestions for the goal given by "hyp_ts"
   and "concl_t".

   The query's features are those of the goal, merged with the scaled
   features of the chained facts and of a few leading facts from the current
   theory. If nothing has been learned yet, MaSh is not queried. Returns the
   result of "find_mash_suggestions" (meshed suggestions and remaining unknown
   facts), converted with "fact_of_raw_fact". *)
fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
  let
    val thy = Proof_Context.theory_of ctxt
    val thy_name = Context.theory_name thy
    (* reverse "crude_thm_ord" order *)
    val facts = facts |> sort (crude_thm_ord o pairself snd o swap)
    val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
    val num_facts = length facts
    (* counts of constant names over all facts, passed to "features_of" *)
    val const_tab = fold (add_const_counts o prop_of o snd) facts Symtab.empty

    fun fact_has_right_theory (_, th) =
      thy_name = Context.theory_name (theory_of_thm th)
    (* features of a fact, with weights multiplied by "weight * factor" *)
    fun chained_or_extra_features_of factor (((_, stature), th), weight) =
      [prop_of th]
      |> features_of ctxt (theory_of_thm th) num_facts const_tab stature
      |> map (apsnd (fn r => weight * factor * r))

    val (access_G, suggs) =
      peek_state ctxt overlord (fn {access_G, ...} =>
          if Graph.is_empty access_G then
            (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
          else
            let
              val parents = maximal_wrt_access_graph access_G facts
              val goal_feats =
                features_of ctxt thy num_facts const_tab (Local, General) (concl_t :: hyp_ts)
              val chained_feats =
                chained
                |> map (rpair 1.0)
                |> map (chained_or_extra_features_of chained_feature_factor)
                |> rpair [] |-> fold (union (eq_fst (op =)))
              val extra_feats =
                facts
                |> take (Int.max (0, num_extra_feature_facts - length chained))
                |> filter fact_has_right_theory
                |> weight_facts_steeply
                |> map (chained_or_extra_features_of extra_feature_factor)
                |> rpair [] |-> fold (union (eq_fst (op =)))
              val feats =
                fold (union (eq_fst (op =))) [chained_feats, extra_feats]
                     goal_feats
                (* sorted by decreasing weight in debug mode *)
                |> debug ? sort (Real.compare o swap o pairself snd)
              (* chained facts that MaSh knows are passed as hints *)
              val hints =
                chained |> filter (is_fact_in_graph access_G o snd)
                        |> map (nickname_of_thm o snd)
            in
              (access_G, MaSh.query ctxt overlord max_facts
                                    ([], hints, parents, feats))
            end)
    val unknown = facts |> filter_out (is_fact_in_graph access_G o snd)
  in
    find_mash_suggestions ctxt max_facts suggs facts chained unknown
    |> pairself (map fact_of_raw_fact)
  end
   982 
   983 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
   984   let
   985     fun maybe_learn_from from (accum as (parents, graph)) =
   986       try_graph ctxt "updating graph" accum (fn () =>
   987           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   988     val graph = graph |> Graph.default_node (name, Isar_Proof)
   989     val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
   990     val (deps, _) = ([], graph) |> fold maybe_learn_from deps
   991   in ((name, parents, feats, deps) :: learns, graph) end
   992 
   993 fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
   994   let
   995     fun maybe_relearn_from from (accum as (parents, graph)) =
   996       try_graph ctxt "updating graph" accum (fn () =>
   997           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   998     val graph = graph |> update_access_graph_node (name, Automatic_Proof)
   999     val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
  1000   in ((name, deps) :: relearns, graph) end
  1001 
  1002 fun flop_wrt_access_graph name =
  1003   update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop)
  1004 
  1005 val learn_timeout_slack = 2.0
  1006 
  1007 fun launch_thread timeout task =
  1008   let
  1009     val hard_timeout = time_mult learn_timeout_slack timeout
  1010     val birth_time = Time.now ()
  1011     val death_time = Time.+ (birth_time, hard_timeout)
  1012     val desc = ("Machine learner for Sledgehammer", "")
  1013   in Async_Manager.thread MaShN birth_time death_time desc task end
  1014 
  1015 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths =
  1016   if is_mash_enabled () then
  1017     launch_thread (timeout |> the_default one_day) (fn () =>
  1018         let
  1019           val thy = Proof_Context.theory_of ctxt
  1020           val feats = features_of ctxt thy 0 Symtab.empty (Local, General) [t] |> map fst
  1021         in
  1022           peek_state ctxt overlord (fn {access_G, ...} =>
  1023               let
  1024                 val parents = maximal_wrt_access_graph access_G facts
  1025                 val deps =
  1026                   used_ths |> filter (is_fact_in_graph access_G)
  1027                            |> map nickname_of_thm
  1028               in
  1029                 MaSh.learn ctxt overlord true [("", parents, feats, deps)]
  1030               end);
  1031           (true, "")
  1032         end)
  1033   else
  1034     ()
  1035 
  1036 fun sendback sub =
  1037   Active.sendback_markup [Markup.padding_command] (sledgehammerN ^ " " ^ sub)
  1038 
  1039 val commit_timeout = seconds 30.0
  1040 
  1041 (* The timeout is understood in a very relaxed fashion. *)
  1042 fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover
  1043                      save auto_level run_prover learn_timeout facts =
  1044   let
  1045     val timer = Timer.startRealTimer ()
  1046     fun next_commit_time () =
  1047       Time.+ (Timer.checkRealTimer timer, commit_timeout)
  1048     val {access_G, ...} = peek_state ctxt overlord I
  1049     val is_in_access_G = is_fact_in_graph access_G o snd
  1050     val no_new_facts = forall is_in_access_G facts
  1051   in
  1052     if no_new_facts andalso not run_prover then
  1053       if auto_level < 2 then
  1054         "No new " ^ (if run_prover then "automatic" else "Isar") ^
  1055         " proofs to learn." ^
  1056         (if auto_level = 0 andalso not run_prover then
  1057            "\n\nHint: Try " ^ sendback learn_proverN ^
  1058            " to learn from an automatic prover."
  1059          else
  1060            "")
  1061       else
  1062         ""
  1063     else
  1064       let
  1065         val name_tabs = build_name_tables nickname_of_thm facts
  1066         fun deps_of status th =
  1067           if no_dependencies_for_status status then
  1068             SOME []
  1069           else if run_prover then
  1070             prover_dependencies_of ctxt params prover auto_level facts name_tabs
  1071                                    th
  1072             |> (fn (false, _) => NONE
  1073                  | (true, deps) => trim_dependencies deps)
  1074           else
  1075             isar_dependencies_of name_tabs th
  1076             |> trim_dependencies
  1077         fun do_commit [] [] [] state = state
  1078           | do_commit learns relearns flops {access_G, num_known_facts, dirty} =
  1079             let
  1080               val was_empty = Graph.is_empty access_G
  1081               val (learns, access_G) =
  1082                 ([], access_G) |> fold (learn_wrt_access_graph ctxt) learns
  1083               val (relearns, access_G) =
  1084                 ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
  1085               val access_G = access_G |> fold flop_wrt_access_graph flops
  1086               val num_known_facts = num_known_facts + length learns
  1087               val dirty =
  1088                 case (was_empty, dirty, relearns) of
  1089                   (false, SOME names, []) => SOME (map #1 learns @ names)
  1090                 | _ => NONE
  1091             in
  1092               MaSh.learn ctxt overlord (save andalso null relearns) (rev learns);
  1093               MaSh.relearn ctxt overlord save relearns;
  1094               {access_G = access_G, num_known_facts = num_known_facts,
  1095                dirty = dirty}
  1096             end
  1097         fun commit last learns relearns flops =
  1098           (if debug andalso auto_level = 0 then
  1099              Output.urgent_message "Committing..."
  1100            else
  1101              ();
  1102            map_state ctxt overlord (do_commit (rev learns) relearns flops);
  1103            if not last andalso auto_level = 0 then
  1104              let val num_proofs = length learns + length relearns in
  1105                "Learned " ^ string_of_int num_proofs ^ " " ^
  1106                (if run_prover then "automatic" else "Isar") ^ " proof" ^
  1107                plural_s num_proofs ^ " in the last " ^
  1108                string_of_time commit_timeout ^ "."
  1109                |> Output.urgent_message
  1110              end
  1111            else
  1112              ())
  1113         fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
  1114           | learn_new_fact (parents, ((_, stature as (_, status)), th))
  1115                            (learns, (n, next_commit, _)) =
  1116             let
  1117               val name = nickname_of_thm th
  1118               val feats =
  1119                 features_of ctxt (theory_of_thm th) 0 Symtab.empty stature [prop_of th] |> map fst
  1120               val deps = deps_of status th |> these
  1121               val n = n |> not (null deps) ? Integer.add 1
  1122               val learns = (name, parents, feats, deps) :: learns
  1123               val (learns, next_commit) =
  1124                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1125                   (commit false learns [] []; ([], next_commit_time ()))
  1126                 else
  1127                   (learns, next_commit)
  1128               val timed_out =
  1129                 case learn_timeout of
  1130                   SOME timeout => Time.> (Timer.checkRealTimer timer, timeout)
  1131                 | NONE => false
  1132             in (learns, (n, next_commit, timed_out)) end
  1133         val n =
  1134           if no_new_facts then
  1135             0
  1136           else
  1137             let
  1138               val new_facts =
  1139                 facts |> sort (crude_thm_ord o pairself snd)
  1140                       |> attach_parents_to_facts []
  1141                       |> filter_out (is_in_access_G o snd)
  1142               val (learns, (n, _, _)) =
  1143                 ([], (0, next_commit_time (), false))
  1144                 |> fold learn_new_fact new_facts
  1145             in commit true learns [] []; n end
  1146         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
  1147           | relearn_old_fact ((_, (_, status)), th)
  1148                              ((relearns, flops), (n, next_commit, _)) =
  1149             let
  1150               val name = nickname_of_thm th
  1151               val (n, relearns, flops) =
  1152                 case deps_of status th of
  1153                   SOME deps => (n + 1, (name, deps) :: relearns, flops)
  1154                 | NONE => (n, relearns, name :: flops)
  1155               val (relearns, flops, next_commit) =
  1156                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1157                   (commit false [] relearns flops;
  1158                    ([], [], next_commit_time ()))
  1159                 else
  1160                   (relearns, flops, next_commit)
  1161               val timed_out =
  1162                 case learn_timeout of
  1163                   SOME timeout => Time.> (Timer.checkRealTimer timer, timeout)
  1164                 | NONE => false
  1165             in ((relearns, flops), (n, next_commit, timed_out)) end
  1166         val n =
  1167           if not run_prover then
  1168             n
  1169           else
  1170             let
  1171               val max_isar = 1000 * max_dependencies
  1172               fun kind_of_proof th =
  1173                 try (Graph.get_node access_G) (nickname_of_thm th)
  1174                 |> the_default Isar_Proof
  1175               fun priority_of (_, th) =
  1176                 random_range 0 max_isar
  1177                 + (case kind_of_proof th of
  1178                      Isar_Proof => 0
  1179                    | Automatic_Proof => 2 * max_isar
  1180                    | Isar_Proof_wegen_Prover_Flop => max_isar)
  1181                 - 100 * length (isar_dependencies_of name_tabs th)
  1182               val old_facts =
  1183                 facts |> filter is_in_access_G
  1184                       |> map (`priority_of)
  1185                       |> sort (int_ord o pairself fst)
  1186                       |> map snd
  1187               val ((relearns, flops), (n, _, _)) =
  1188                 (([], []), (n, next_commit_time (), false))
  1189                 |> fold relearn_old_fact old_facts
  1190             in commit true [] relearns flops; n end
  1191       in
  1192         if verbose orelse auto_level < 2 then
  1193           "Learned " ^ string_of_int n ^ " nontrivial " ^
  1194           (if run_prover then "automatic and " else "") ^ "Isar proof" ^ plural_s n ^
  1195           (if verbose then " in " ^ string_of_time (Timer.checkRealTimer timer)
  1196            else "") ^ "."
  1197         else
  1198           ""
  1199       end
  1200   end
  1201 
  1202 fun mash_learn ctxt (params as {provers, timeout, ...}) fact_override chained run_prover =
  1203   let
  1204     val css = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
  1205     val ctxt = ctxt |> Config.put instantiate_inducts false
  1206     val facts =
  1207       nearly_all_facts ctxt false fact_override Symtab.empty css chained []
  1208                        @{prop True}
  1209       |> sort (crude_thm_ord o pairself snd o swap)
  1210     val num_facts = length facts
  1211     val prover = hd provers
  1212     fun learn auto_level run_prover =
  1213       mash_learn_facts ctxt params prover true auto_level run_prover NONE facts
  1214       |> Output.urgent_message
  1215   in
  1216     if run_prover then
  1217       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
  1218        plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^
  1219        (case timeout of
  1220           SOME timeout => " timeout: " ^ string_of_time timeout
  1221         | NONE => "") ^ ").\n\nCollecting Isar proofs first..."
  1222        |> Output.urgent_message;
  1223        learn 1 false;
  1224        "Now collecting automatic proofs. This may take several hours. You can \
  1225        \safely stop the learning process at any point."
  1226        |> Output.urgent_message;
  1227        learn 0 true)
  1228     else
  1229       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
  1230        plural_s num_facts ^ " for Isar proofs..."
  1231        |> Output.urgent_message;
  1232        learn 0 false)
  1233   end
  1234 
  1235 fun mash_can_suggest_facts ctxt overlord =
  1236   not (Graph.is_empty (#access_G (peek_state ctxt overlord I)))
  1237 
  1238 (* Generate more suggestions than requested, because some might be thrown out
  1239    later for various reasons. *)
  1240 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
  1241 
(* Weights of the MePo and MaSh suggestion lists when "relevant_facts" meshes
   them. *)
val mepo_weight = 0.5
val mash_weight = 0.5

(* "relevant_facts" learns synchronously before querying MaSh only if both the
   number of facts exceeding MaSh's known-fact count and the number of facts
   missing from the access graph are at most this value. Otherwise it may
   start a background learner instead. *)
val max_facts_to_learn_before_query = 100

(* The threshold should be large enough so that MaSh doesn't kick in for Auto
   Sledgehammer and Try. A background learner is started only if there is no
   timeout or the timeout is at least this many seconds. *)
val min_secs_for_learning = 15
  1250 
(* Selects relevant facts for the goal "hyp_ts ==> concl_t" using the
   requested fact filter (MePo, MaSh, or their mesh). Facts from "add" are put
   first, and every list is cut to "max_facts".

   Returns a list of (filter name, facts) pairs:
   - with "only", all "facts" under "";
   - with no fact filter given and both MePo and MaSh run, three entries:
     mesh, MePo, and MaSh;
   - otherwise a single entry for the effective filter.

   As a side effect, MaSh may learn from "facts" first, either synchronously
   or in a background learner thread. Raises "error" on an unknown fact
   filter. *)
fun relevant_facts ctxt
        (params as {overlord, blocking, learn, fact_filter, timeout, ...})
        prover max_facts ({add, only, ...} : fact_override) hyp_ts concl_t
        facts =
  if not (subset (op =) (the_list fact_filter, fact_filters)) then
    error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
  else if only then
    let val facts = facts |> map fact_of_raw_fact in
      [("", facts)]
    end
  else if max_facts <= 0 orelse null facts then
    [("", [])]
  else
    let
      (* Starts a background learner. This happens only when not blocking,
         when no MaSh thread is running, and when the timeout is absent or
         long enough. *)
      fun maybe_launch_thread () =
        if not blocking andalso
           not (Async_Manager.has_running_threads MaShN) andalso
           (timeout = NONE orelse
            Time.toSeconds (the timeout) >= min_secs_for_learning) then
          let
            val timeout = Option.map (time_mult learn_timeout_slack) timeout
          in
            launch_thread (timeout |> the_default one_day)
                (fn () => (true, mash_learn_facts ctxt params prover true 2
                                                  false timeout facts))
          end
        else
          ()
      (* Learns synchronously if only a few facts are new. Returns whether
         anything was learned this way, in which case the state should be
         saved. *)
      fun maybe_learn () =
        if is_mash_enabled () andalso learn then
          let
            val {access_G, num_known_facts, ...} = peek_state ctxt overlord I
            val is_in_access_G = is_fact_in_graph access_G o snd
          in
            if length facts - num_known_facts
               <= max_facts_to_learn_before_query then
              case length (filter_out is_in_access_G facts) of
                0 => false
              | num_facts_to_learn =>
                if num_facts_to_learn <= max_facts_to_learn_before_query then
                  (mash_learn_facts ctxt params prover false 2 false timeout facts
                   |> (fn "" => () | s => Output.urgent_message (MaShN ^ ": " ^ s));
                   true)
                else
                  (maybe_launch_thread (); false)
            else
              (maybe_launch_thread (); false)
          end
        else
          false
      val (save, effective_fact_filter) =
        case fact_filter of
          SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
        | NONE =>
          if is_mash_enabled () then
            (maybe_learn (),
             if mash_can_suggest_facts ctxt overlord then meshN else mepoN)
          else
            (false, mepoN)

      val unique_facts = drop_duplicate_facts facts
      val add_ths = Attrib.eval_thms ctxt add

      fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
      (* puts the "add" facts first, then keeps at most "max_facts" *)
      fun add_and_take accepts =
        (case add_ths of
           [] => accepts
         | _ => (unique_facts |> filter in_add |> map fact_of_raw_fact) @
                (accepts |> filter_out in_add))
        |> take max_facts
      fun mepo () =
        (mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts
         |> weight_facts_steeply, [])
      fun mash () =
        mash_suggested_facts ctxt params (generous_max_facts max_facts) hyp_ts concl_t facts
        |>> weight_facts_steeply
      (* the selected filters are run via "Par_List.map" *)
      val mess =
        (* the order is important for the "case" expression below *)
        [] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash)
           |> effective_fact_filter <> mashN ? cons (mepo_weight, mepo)
           |> Par_List.map (apsnd (fn f => f ()))
      val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
    in
      if save then MaSh.save ctxt overlord else ();
      case (fact_filter, mess) of
        (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
        [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
         (mashN, mash |> map fst |> add_and_take)]
      | _ => [(effective_fact_filter, mesh)]
    end
  1341 
  1342 fun kill_learners ctxt ({overlord, ...} : params) =
  1343   (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord)
  1344 fun running_learners () = Async_Manager.running_threads MaShN "learner"
  1345 
  1346 end;