src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
author wenzelm
Sat Mar 22 18:19:57 2014 +0100 (2014-03-22)
changeset 56254 a2dd9200854d
parent 55642 63beb38e9258
child 56303 4cc3f4db3447
permissions -rw-r--r--
more antiquotations;
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3 
     4 Sledgehammer's machine-learning-based relevance filter (MaSh).
     5 *)
     6 
(* Interface of MaSh, Sledgehammer's machine-learning-based relevance filter,
   and of its combination with MePo (MeSh). *)
signature SLEDGEHAMMER_MASH =
sig
  type stature = ATP_Problem_Generate.stature
  type raw_fact = Sledgehammer_Fact.raw_fact
  type fact = Sledgehammer_Fact.fact
  type fact_override = Sledgehammer_Fact.fact_override
  type params = Sledgehammer_Prover.params
  type prover_result = Sledgehammer_Prover.prover_result

  (* Configuration option "sledgehammer_mash_trace" (default false); when set,
     MaSh emits tracing output. *)
  val trace : bool Config.T
  (* Display names ("MePo", "MaSh", "MeSh") and lower-case names ("mepo",
     "mash", "mesh") of the relevance filters. *)
  val MePoN : string
  val MaShN : string
  val MeShN : string
  val mepoN : string
  val mashN : string
  val meshN : string
  (* Names of the unlearn / learn / relearn operations. *)
  val unlearnN : string
  val learn_isarN : string
  val learn_proverN : string
  val relearn_isarN : string
  val relearn_proverN : string
  (* The lower-case filter names, MeSh first. *)
  val fact_filters : string list
  (* Escaping for the MaSh text protocol: alphanumerics and "_.()," are kept
     as is; any other character becomes "%" followed by its 3-digit character
     code. The list variants separate entries by single spaces. *)
  val encode_str : string -> string
  val encode_strs : string list -> string
  val unencode_str : string -> string
  val unencode_strs : string -> string list
  (* A feature is a list of names joined by "|"; in the weighted variant a
     weight other than 1.0 is appended as "=W". Features are space-separated. *)
  val encode_plain_features : string list list -> string
  val encode_features : (string list * real) list -> string
  (* Parse a "goal: sugg1 sugg2 ..." line into the goal name and the suggested
     names ("=weight" suffixes are dropped); malformed lines give ("", []). *)
  val extract_suggestions : string -> string * string list

  (* Low-level commands sent to the external MaSh tool. The first "bool" is the
     "overlord" flag: temporary files go to $ISABELLE_HOME_USER and are kept.
     For "learn" and "relearn" the second "bool" asks MaSh to save its models. *)
  structure MaSh:
  sig
    val unlearn : Proof.context -> bool -> unit
    val learn :
      Proof.context -> bool -> bool
      -> (string * string list * string list list * string list) list -> unit
    val relearn :
      Proof.context -> bool -> bool -> (string * string list) list -> unit
    (* Returns the suggested fact names from the last line of MaSh's output. *)
    val query :
      Proof.context -> bool -> int
      -> (string * string list * string list list * string list) list
         * string list * string list * (string list * real) list
      -> string list
  end

  (* Clears the persistent MaSh state and prints "Reset MaSh.". *)
  val mash_unlearn : Proof.context -> params -> unit
  (* True iff the environment variable MASH is "yes". *)
  val is_mash_enabled : unit -> bool
  (* Stable name of a theorem used to identify it in MaSh. *)
  val nickname_of_thm : thm -> string
  (* Facts whose nicknames appear in the given list, in list order. *)
  val find_suggested_facts :
    Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
  (* Merge several weighted, ranked fact lists into at most "max_facts" facts. *)
  val mesh_facts :
    ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list
    -> 'a list
  val crude_thm_ord : thm * thm -> order
  val thm_less : thm * thm -> bool
  val goal_of_thm : theory -> thm -> thm
  val run_prover_for_mash :
    Proof.context -> params -> string -> string -> fact list -> thm -> prover_result
  (* Weighted features of a list of terms, as sent to MaSh. *)
  val features_of :
    Proof.context -> theory -> int -> int Symtab.table -> stature -> bool -> term list ->
    (string list * real) list
  val trim_dependencies : string list -> string list option
  val isar_dependencies_of :
    string Symtab.table * string Symtab.table -> thm -> string list
  val prover_dependencies_of :
    Proof.context -> params -> string -> int -> raw_fact list
    -> string Symtab.table * string Symtab.table -> thm
    -> bool * string list
  val attach_parents_to_facts :
    ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list
  val num_extra_feature_facts : int
  val extra_feature_factor : real
  val weight_facts_smoothly : 'a list -> ('a * real) list
  val weight_facts_steeply : 'a list -> ('a * real) list
  val find_mash_suggestions :
    Proof.context -> int -> string list -> ('b * thm) list -> ('b * thm) list
    -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
  val add_const_counts : term -> int Symtab.table -> int Symtab.table
  val mash_suggested_facts :
    Proof.context -> params -> int -> term list -> term -> raw_fact list -> fact list * fact list
  val mash_learn_proof : Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit
  val mash_learn :
    Proof.context -> params -> fact_override -> thm list -> bool -> unit

  val mash_can_suggest_facts : Proof.context -> bool -> bool
  val generous_max_facts : int -> int
  val mepo_weight : real
  val mash_weight : real
  val relevant_facts :
    Proof.context -> params -> string -> int -> fact_override -> term list
    -> term -> raw_fact list -> (string * fact list) list
  val kill_learners : Proof.context -> params -> unit
  val running_learners : unit -> unit
end;
   101 
   102 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH =
   103 struct
   104 
   105 open ATP_Util
   106 open ATP_Problem_Generate
   107 open Sledgehammer_Util
   108 open Sledgehammer_Fact
   109 open Sledgehammer_Prover
   110 open Sledgehammer_Prover_Minimize
   111 open Sledgehammer_MePo
   112 
   113 val trace =
   114   Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
   115 
   116 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
   117 
   118 val MePoN = "MePo"
   119 val MaShN = "MaSh"
   120 val MeShN = "MeSh"
   121 
   122 val mepoN = "mepo"
   123 val mashN = "mash"
   124 val meshN = "mesh"
   125 
   126 val fact_filters = [meshN, mepoN, mashN]
   127 
   128 val unlearnN = "unlearn"
   129 val learn_isarN = "learn_isar"
   130 val learn_proverN = "learn_prover"
   131 val relearn_isarN = "relearn_isar"
   132 val relearn_proverN = "relearn_prover"
   133 
   134 fun mash_model_dir () =
   135   Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
   136 val mash_state_dir = mash_model_dir
   137 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
   138 
   139 
   140 (*** Low-level communication with MaSh ***)
   141 
   142 val save_models_arg = "--saveModels"
   143 val shutdown_server_arg = "--shutdownServer"
   144 
   145 fun wipe_out_file file = (try (File.rm o Path.explode) file; ())
   146 
   147 fun write_file banner (xs, f) path =
   148   (case banner of SOME s => File.write path s | NONE => ();
   149    xs |> chunk_list 500 |> List.app (File.append path o implode o map f))
   150   handle IO.Io _ => ()
   151 
(* Run the external MaSh tool (mash.py) once.
   "write_cmds" is a pair (entries, render) written to the command file that
   mash.py reads as input. "read_suggs" receives a thunk that returns the lines
   of the predictions file (or [] if it cannot be read); its result is the
   result of this function. "extra_args" are appended to the command line, and
   "background" makes bash run the tool with a trailing "&".
   In "overlord" mode the temporary files live in $ISABELLE_HOME_USER under
   fixed names and are kept. Otherwise they go to $ISABELLE_TMP with a fresh
   serial suffix and are deleted afterwards. *)
fun run_mash_tool ctxt overlord extra_args background write_cmds read_suggs =
  let
    val (temp_dir, serial) =
      if overlord then (getenv "ISABELLE_HOME_USER", "")
      else (getenv "ISABELLE_TMP", serial_string ())
    val log_file = temp_dir ^ "/mash_log" ^ serial
    val err_file = temp_dir ^ "/mash_err" ^ serial
    val sugg_file = temp_dir ^ "/mash_suggs" ^ serial
    val sugg_path = Path.explode sugg_file
    val cmd_file = temp_dir ^ "/mash_commands" ^ serial
    val cmd_path = Path.explode cmd_file
    val model_dir = File.shell_path (mash_model_dir ())
    (* stdout and stderr of mash.py both go to "err_file" (">&") *)
    val command =
      "cd \"$ISABELLE_SLEDGEHAMMER_MASH\"/src; \
      \PYTHONDONTWRITEBYTECODE=y ./mash.py\
      \ --quiet\
      \ --port=$MASH_PORT\
      \ --outputDir " ^ model_dir ^
      " --modelFile=" ^ model_dir ^ "/model.pickle\
      \ --dictsFile=" ^ model_dir ^ "/dict.pickle\
      \ --log " ^ log_file ^
      " --inputFile " ^ cmd_file ^
      " --predictions " ^ sugg_file ^
      (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^
      " >& " ^ err_file ^
      (if background then " &" else "")
    (* Run the command. Any non-empty error output is reported as a warning
       (elided to 1000 characters). Then let "read_suggs" read the predictions. *)
    fun run_on () =
      (Isabelle_System.bash command
       |> tap (fn _ =>
            (case try File.read (Path.explode err_file) |> the_default "" of
              "" => trace_msg ctxt (K "Done")
            | s => warning ("MaSh error: " ^ elide_string 1000 s)));
       read_suggs (fn () => try File.read_lines sugg_path |> these))
    (* NOTE(review): "log_file" is not in the clean-up list, so in non-overlord
       mode each call leaves a serial-suffixed log in $ISABELLE_TMP; confirm
       that this is intended. When "background" is true, the clean-up may also
       delete "cmd_file" before the backgrounded tool has read it -- TODO confirm. *)
    fun clean_up () =
      if overlord then ()
      else List.app wipe_out_file [err_file, sugg_file, cmd_file]
  in
    (* start from an empty predictions file so stale output is never read *)
    write_file (SOME "") ([], K "") sugg_path;
    write_file (SOME "") write_cmds cmd_path;
    trace_msg ctxt (fn () => "Running " ^ command);
    with_cleanup clean_up run_on ()
  end
   194 
   195 fun meta_char c =
   196   if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse
   197      c = #")" orelse c = #"," then
   198     String.str c
   199   else
   200     (* fixed width, in case more digits follow *)
   201     "%" ^ stringN_of_int 3 (Char.ord c)
   202 
   203 fun unmeta_chars accum [] = String.implode (rev accum)
   204   | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
   205     (case Int.fromString (String.implode [d1, d2, d3]) of
   206        SOME n => unmeta_chars (Char.chr n :: accum) cs
   207      | NONE => "" (* error *))
   208   | unmeta_chars _ (#"%" :: _) = "" (* error *)
   209   | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
   210 
   211 val encode_str = String.translate meta_char
   212 val encode_strs = map encode_str #> space_implode " "
   213 val unencode_str = String.explode #> unmeta_chars []
   214 val unencode_strs =
   215   space_explode " " #> filter_out (curry (op =) "") #> map unencode_str
   216 
   217 (* Avoid scientific notation *)
   218 fun safe_str_of_real r =
   219   if r < 0.00001 then "0.00001"
   220   else if r >= 1000000.0 then "1000000"
   221   else Markup.print_real r
   222 
   223 val encode_plain_feature = space_implode "|" o map encode_str
   224 
   225 fun encode_feature (names, weight) =
   226   encode_plain_feature names ^ (if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight)
   227 
   228 val encode_plain_features = map encode_plain_feature #> space_implode " "
   229 val encode_features = map encode_feature #> space_implode " "
   230 
   231 fun str_of_learn (name, parents, feats : string list list, deps) =
   232   "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
   233   encode_plain_features feats ^ "; " ^ encode_strs deps ^ "\n"
   234 
   235 fun str_of_relearn (name, deps) =
   236   "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n"
   237 
   238 fun str_of_query max_suggs (learns, hints, parents, feats) =
   239   implode (map str_of_learn learns) ^
   240   "? " ^ string_of_int max_suggs ^ " # " ^ encode_strs parents ^ "; " ^
   241   encode_features feats ^
   242   (if null hints then "" else "; " ^ encode_strs hints) ^ "\n"
   243 
   244 (* The suggested weights don't make much sense. *)
   245 fun extract_suggestion sugg =
   246   (case space_explode "=" sugg of
   247     [name, _ (* weight *)] =>
   248     SOME (unencode_str name (* , Real.fromString weight |> the_default 1.0 *))
   249   | [name] => SOME (unencode_str name (* , 1.0 *))
   250   | _ => NONE)
   251 
   252 fun extract_suggestions line =
   253   (case space_explode ":" line of
   254     [goal, suggs] => (unencode_str goal, map_filter extract_suggestion (space_explode " " suggs))
   255   | _ => ("", []))
   256 
   257 structure MaSh =
   258 struct
   259 
   260 fun shutdown ctxt overlord =
   261   (trace_msg ctxt (K "MaSh shutdown");
   262    run_mash_tool ctxt overlord [shutdown_server_arg] false ([], K "") (K ()))
   263 
   264 fun save ctxt overlord =
   265   (trace_msg ctxt (K "MaSh save");
   266    run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ()))
   267 
   268 fun unlearn ctxt overlord =
   269   let val path = mash_model_dir () in
   270     trace_msg ctxt (K "MaSh unlearn");
   271     shutdown ctxt overlord;
   272     try (File.fold_dir (fn file => fn _ =>
   273                            try File.rm (Path.append path (Path.basic file)))
   274                        path) NONE;
   275     ()
   276   end
   277 
   278 fun learn _ _ _ [] = ()
   279   | learn ctxt overlord save (learns : (string * string list * string list list * string list) list) (*##*)
   280    =
   281     let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
   282       (trace_msg ctxt (fn () => "MaSh learn" ^ (if names = "" then "" else " " ^ names));
   283        run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
   284                      (learns, str_of_learn) (K ()))
   285     end
   286 
   287 fun relearn _ _ _ [] = ()
   288   | relearn ctxt overlord save relearns =
   289     (trace_msg ctxt (fn () => "MaSh relearn " ^
   290          elide_string 1000 (space_implode " " (map #1 relearns)));
   291      run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
   292                    (relearns, str_of_relearn) (K ()))
   293 
   294 fun query ctxt overlord max_suggs (query as (_, _, _, feats)) =
   295   (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
   296    run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs) (fn suggs =>
   297      (case suggs () of
   298        [] => []
   299      | suggs => snd (extract_suggestions (List.last suggs))))
   300    handle List.Empty => [])
   301 
   302 end;
   303 
   304 
   305 (*** Middle-level communication with MaSh ***)
   306 
   307 datatype proof_kind =
   308   Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop
   309 
   310 fun str_of_proof_kind Isar_Proof = "i"
   311   | str_of_proof_kind Automatic_Proof = "a"
   312   | str_of_proof_kind Isar_Proof_wegen_Prover_Flop = "x"
   313 
   314 fun proof_kind_of_str "i" = Isar_Proof
   315   | proof_kind_of_str "a" = Automatic_Proof
   316   | proof_kind_of_str "x" = Isar_Proof_wegen_Prover_Flop
   317 
   318 (* FIXME: Here a "Graph.update_node" function would be useful *)
   319 fun update_access_graph_node (name, kind) =
   320   Graph.default_node (name, Isar_Proof)
   321   #> kind <> Isar_Proof ? Graph.map_node name (K kind)
   322 
   323 fun try_graph ctxt when def f =
   324   f ()
   325   handle Graph.CYCLES (cycle :: _) =>
   326          (trace_msg ctxt (fn () =>
   327               "Cycle involving " ^ commas cycle ^ " when " ^ when); def)
   328        | Graph.DUP name =>
   329          (trace_msg ctxt (fn () =>
   330               "Duplicate fact " ^ quote name ^ " when " ^ when); def)
   331        | Graph.UNDEF name =>
   332          (trace_msg ctxt (fn () =>
   333               "Unknown fact " ^ quote name ^ " when " ^ when); def)
   334        | exn =>
   335          if Exn.is_interrupt exn then
   336            reraise exn
   337          else
   338            (trace_msg ctxt (fn () =>
   339                 "Internal error when " ^ when ^ ":\n" ^
   340                 ML_Compiler.exn_message exn); def)
   341 
   342 fun graph_info G =
   343   string_of_int (length (Graph.keys G)) ^ " node(s), " ^
   344   string_of_int (fold (Integer.add o length o snd) (Graph.dest G) 0) ^
   345   " edge(s), " ^
   346   string_of_int (length (Graph.minimals G)) ^ " minimal, " ^
   347   string_of_int (length (Graph.maximals G)) ^ " maximal"
   348 
(* In-memory MaSh state:
   - "access_G": graph of known facts, with an edge from each parent to its child;
   - "num_known_facts": number of facts read from the state file;
   - "dirty": facts not yet written out. "SOME names" makes "save_state"
     append just those entries; "NONE" makes it rewrite the whole file,
     banner included. *)
(* NOTE(review): the graph built when loading stores a "proof_kind" in each
   node (via "update_access_graph_node"), not "unit" as declared below --
   confirm how this abbreviation is used before relying on it. *)
type mash_state =
  {access_G : unit Graph.T,
   num_known_facts : int,
   dirty : string list option}

(* State with no facts and nothing to save. *)
val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []}
   355 
(* Persistent MaSh state, kept in a synchronized global variable holding a pair
   (loaded_from_disk, state) and mirrored in the file "mash_state_file ()".
   File format: a version banner line, then one line per graph node,
   "KIND NAME: PARENT1 PARENT2 ..." (escaped with "encode_str"). *)
local

(* First line of the state file; compared with "string_ord" when loading. *)
val version = "*** MaSh version 20131206 ***"

(* Raised when the file's banner sorts after "version" (written by a newer MaSh). *)
exception FILE_VERSION_TOO_NEW of unit

(* Parse one node line into (name, parents, kind). An unknown kind defaults to
   "Isar_Proof"; a malformed line gives NONE. *)
fun extract_node line =
  (case space_explode ":" line of
    [head, parents] =>
    (case space_explode " " head of
      [kind, name] =>
      SOME (unencode_str name, unencode_strs parents,
        try proof_kind_of_str kind |> the_default Isar_Proof)
    | _ => NONE)
  | _ => NONE)

(* Load the state file unless already loaded (flag "true"). An equal banner
   means the graph is rebuilt from the node lines. An older banner makes MaSh
   unlearn everything and start empty. A newer banner raises
   FILE_VERSION_TOO_NEW. A missing or empty file gives "empty_state". *)
(* NOTE(review): states produced here (and "empty_state") have dirty = SOME [],
   so "save_state" will append entries without writing the banner unless code
   elsewhere sets "dirty" to NONE. If the file was just removed (older banner,
   or no file at all), this could leave a bannerless file -- TODO confirm. *)
fun load_state _ _ (state as (true, _)) = state
  | load_state ctxt overlord _ =
    let val path = mash_state_file () in
      (true,
       (case try File.read_lines path of
         SOME (version' :: node_lines) =>
         let
           fun add_edge_to name parent =
             Graph.default_node (parent, Isar_Proof)
             #> Graph.add_edge (parent, name)
           fun add_node line =
             (case extract_node line of
               NONE => I (* shouldn't happen *)
             | SOME (name, parents, kind) =>
               update_access_graph_node (name, kind) #> fold (add_edge_to name) parents)
           val (access_G, num_known_facts) =
             (case string_ord (version', version) of
               EQUAL =>
               (try_graph ctxt "loading state" Graph.empty (fn () =>
                  fold add_node node_lines Graph.empty),
                length node_lines)
             | LESS =>
               (* can't parse old file *)
               (MaSh.unlearn ctxt overlord; (Graph.empty, 0))
             | GREATER => raise FILE_VERSION_TOO_NEW ())
         in
           trace_msg ctxt (fn () =>
               "Loaded fact graph (" ^ graph_info access_G ^ ")");
           {access_G = access_G, num_known_facts = num_known_facts,
            dirty = SOME []}
         end
       | _ => empty_state))
    end

(* Write pending changes and return the state with "dirty" reset to SOME [].
   With "SOME names" only those nodes are appended (no banner); with "NONE"
   the file is overwritten with the banner and all nodes. *)
fun save_state _ (state as {dirty = SOME [], ...}) = state
  | save_state ctxt {access_G, num_known_facts, dirty} =
    let
      fun str_of_entry (name, parents, kind) =
        str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^
        encode_strs parents ^ "\n"
      fun append_entry (name, (kind, (parents, _))) =
        cons (name, Graph.Keys.dest parents, kind)
      val (banner, entries) =
        (case dirty of
          SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names [])
        | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G []))
    in
      write_file banner (entries, str_of_entry) (mash_state_file ());
      trace_msg ctxt (fn () =>
          "Saved fact graph (" ^ graph_info access_G ^
          (case dirty of
             SOME dirty =>
             "; " ^ string_of_int (length dirty) ^ " dirty fact(s)"
           | _ => "") ^  ")");
      {access_G = access_G, num_known_facts = num_known_facts, dirty = SOME []}
    end

(* (loaded_from_disk, state); starts out unloaded. *)
val global_state =
  Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)

in

(* Load the state if needed, apply "f", and save the result. A state file from
   a newer MaSh version turns the whole operation into a no-op. *)
fun map_state ctxt overlord f =
  Synchronized.change global_state
                      (load_state ctxt overlord ##> (f #> save_state ctxt))
  handle FILE_VERSION_TOO_NEW () => ()

(* Load the state if possible (on failure the current pair is kept) and return
   "f state". The possibly loaded pair is stored back; nothing is saved. *)
fun peek_state ctxt overlord f =
  Synchronized.change_result global_state
      (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)

(* Unlearn everything and reset the in-memory state to "not loaded, empty". *)
fun clear_state ctxt overlord =
  Synchronized.change global_state (fn _ =>
      (MaSh.unlearn ctxt overlord; (* also removes the state file *)
       (false, empty_state)))

end
   449 
   450 fun mash_unlearn ctxt ({overlord, ...} : params) =
   451   (clear_state ctxt overlord; Output.urgent_message "Reset MaSh.")
   452 
   453 
   454 (*** Isabelle helpers ***)
   455 
   456 fun is_mash_enabled () = (getenv "MASH" = "yes")
   457 
   458 val local_prefix = "local" ^ Long_Name.separator
   459 
   460 fun elided_backquote_thm threshold th =
   461   elide_string threshold
   462     (backquote_thm (Proof_Context.init_global (Thm.theory_of_thm th)) th)
   463 
   464 val thy_name_of_thm = Context.theory_name o Thm.theory_of_thm
   465 
   466 fun nickname_of_thm th =
   467   if Thm.has_name_hint th then
   468     let val hint = Thm.get_name_hint th in
   469       (* There must be a better way to detect local facts. *)
   470       (case try (unprefix local_prefix) hint of
   471         SOME suf =>
   472         thy_name_of_thm th ^ Long_Name.separator ^ suf ^ Long_Name.separator ^
   473         elided_backquote_thm 50 th
   474       | NONE => hint)
   475     end
   476   else
   477     elided_backquote_thm 200 th
   478 
   479 fun find_suggested_facts ctxt facts =
   480   let
   481     fun add (fact as (_, th)) = Symtab.default (nickname_of_thm th, fact)
   482     val tab = fold add facts Symtab.empty
   483     fun lookup nick =
   484       Symtab.lookup tab nick
   485       |> tap (fn NONE => trace_msg ctxt (fn () => "Cannot find " ^ quote nick)
   486                | _ => ())
   487   in map_filter lookup end
   488 
   489 fun scaled_avg [] = 0
   490   | scaled_avg xs =
   491     Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
   492 
   493 fun avg [] = 0.0
   494   | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
   495 
   496 fun normalize_scores _ [] = []
   497   | normalize_scores max_facts xs =
   498     let val avg = avg (map snd (take max_facts xs)) in
   499       map (apsnd (curry Real.* (1.0 / avg))) xs
   500     end
   501 
   502 fun mesh_facts _ max_facts [(_, (sels, unks))] =
   503     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   504   | mesh_facts fact_eq max_facts mess =
   505     let
   506       val mess = mess |> map (apsnd (apfst (normalize_scores max_facts)))
   507       fun score_in fact (global_weight, (sels, unks)) =
   508         let
   509           val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score)
   510         in
   511           (case find_index (curry fact_eq fact o fst) sels of
   512             ~1 => if member fact_eq unks fact then NONE else SOME 0.0
   513           | rank => score_at rank)
   514         end
   515       fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
   516       val facts = fold (union fact_eq o map fst o take max_facts o fst o snd) mess []
   517     in
   518       facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
   519             |> map snd |> take max_facts
   520     end
   521 
   522 val default_weight = 1.0
   523 fun free_feature_of s = (["f" ^ s], 40.0 (* FUDGE *))
   524 fun thy_feature_of s = (["y" ^ s], 8.0 (* FUDGE *))
   525 fun type_feature_of s = (["t" ^ s], 4.0 (* FUDGE *))
   526 fun var_feature_of s = ([s], 1.0 (* FUDGE *))
   527 fun class_feature_of s = (["s" ^ s], 1.0 (* FUDGE *))
   528 val local_feature = (["local"], 16.0 (* FUDGE *))
   529 
   530 fun crude_theory_ord p =
   531   if Theory.subthy p then
   532     if Theory.eq_thy p then EQUAL else LESS
   533   else if Theory.subthy (swap p) then
   534     GREATER
   535   else
   536     (case int_ord (pairself (length o Theory.ancestors_of) p) of
   537       EQUAL => string_ord (pairself Context.theory_name p)
   538     | order => order)
   539 
   540 fun crude_thm_ord p =
   541   (case crude_theory_ord (pairself theory_of_thm p) of
   542     EQUAL =>
   543     let val q = pairself nickname_of_thm p in
   544       (* Hack to put "xxx_def" before "xxxI" and "xxxE" *)
   545       (case bool_ord (pairself (String.isSuffix "_def") (swap q)) of
   546         EQUAL => string_ord q
   547       | ord => ord)
   548     end
   549   | ord => ord)
   550 
   551 val thm_less_eq = Theory.subthy o pairself theory_of_thm
   552 fun thm_less p = thm_less_eq p andalso not (thm_less_eq (swap p))
   553 
   554 val freezeT = Type.legacy_freeze_type
   555 
   556 fun freeze (t $ u) = freeze t $ freeze u
   557   | freeze (Abs (s, T, t)) = Abs (s, freezeT T, freeze t)
   558   | freeze (Var ((s, _), T)) = Free (s, freezeT T)
   559   | freeze (Const (s, T)) = Const (s, freezeT T)
   560   | freeze (Free (s, T)) = Free (s, freezeT T)
   561   | freeze t = t
   562 
   563 fun goal_of_thm thy = prop_of #> freeze #> cterm_of thy #> Goal.init
   564 
   565 fun run_prover_for_mash ctxt params prover goal_name facts goal =
   566   let
   567     val problem =
   568       {comment = "Goal: " ^ goal_name, state = Proof.init ctxt, goal = goal, subgoal = 1,
   569        subgoal_count = 1, factss = [("", facts)]}
   570   in
   571     get_minimizing_prover ctxt MaSh (K ()) prover params (K (K (K ""))) problem
   572   end
   573 
   574 val bad_types = [@{type_name prop}, @{type_name bool}, @{type_name fun}]
   575 
   576 val pat_tvar_prefix = "_"
   577 val pat_var_prefix = "_"
   578 
   579 (* try "Long_Name.base_name" for shorter names *)
   580 fun massage_long_name s = if s = @{class type} then "T" else s
   581 
   582 val crude_str_of_sort =
   583   space_implode ":" o map massage_long_name o subtract (op =) @{sort type}
   584 
   585 fun crude_str_of_typ (Type (s, [])) = massage_long_name s
   586   | crude_str_of_typ (Type (s, Ts)) =
   587     massage_long_name s ^ implode (map crude_str_of_typ Ts)
   588   | crude_str_of_typ (TFree (_, S)) = crude_str_of_sort S
   589   | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S
   590 
   591 fun maybe_singleton_str _ "" = []
   592   | maybe_singleton_str pref s = [pref ^ s]
   593 
   594 val max_pat_breadth = 10 (* FUDGE *)
   595 
   596 fun keep m xs =
   597   let val n = length xs in
   598     if n <= m then xs else take (m div 2) xs @ drop (n - (m + 1) div 2) xs
   599   end
   600 
   601 fun sort_of_type alg T =
   602   let
   603     val graph = Sorts.classes_of alg
   604     fun cls_of S [] = S
   605       | cls_of S (cl :: cls) =
   606         if Sorts.of_sort alg (T, [cl]) then
   607           cls_of (insert (op =) cl S) cls
   608         else
   609           let val cls' = Sorts.minimize_sort alg (Sorts.super_classes alg cl) in
   610             cls_of S (union (op =) cls' cls)
   611           end
   612   in
   613     cls_of [] (Graph.maximals graph)
   614   end
   615 
(* When true (and "in_goal" holds), constant patterns use complete sorts of
   type arguments; currently disabled. *)
val generalize_goal = false

(* Weighted features of the terms "ts", with duplicate names merged (the first
   occurrence wins):
   - constant patterns up to depth "term_max_depth", e.g. "c(arg)", including
     type-argument annotations and weighted by "weight_of_const";
   - variable/free patterns rendered by their types; fixed frees of "ctxt"
     additionally get a name-based "free" feature qualified by "thy_name";
   - type patterns up to depth "type_max_depth", plus class features for the
     sorts occurring in types.
   "num_facts" and "const_tab" (constant occurrence counts) control constant
   weights: rarer constants weigh more. *)
fun term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth in_goal ts =
  let
    val thy = Proof_Context.theory_of ctxt
    val alg = Sign.classes_of thy

    val fixes = map snd (Variable.dest_fixes ctxt)
    val classes = Sign.classes_of thy

    (* add class features for every class in "S" and its superclasses,
       skipping "type" *)
    fun add_classes @{sort type} = I
      | add_classes S =
        fold (`(Sorts.super_classes classes)
              #> swap #> op ::
              #> subtract (op =) @{sort type} #> map massage_long_name
              #> map class_feature_of
              #> union (eq_fst (op =))) S

    (* string patterns of a type; arguments are nested up to "depth", and the
       breadth of each level is bounded by "keep max_pat_breadth" *)
    fun pattify_type 0 _ = []
      | pattify_type _ (Type (s, [])) =
        if member (op =) bad_types s then [] else [massage_long_name s]
      | pattify_type depth (Type (s, U :: Ts)) =
        let
          val T = Type (s, Ts)
          val ps = keep max_pat_breadth (pattify_type depth T)
          val qs = keep max_pat_breadth ("" :: pattify_type (depth - 1) U)
        in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
      | pattify_type _ (TFree (_, S)) =
        maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
      | pattify_type _ (TVar (_, S)) =
        maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
    fun add_type_pat depth T =
      union (eq_fst (op =)) (map type_feature_of (pattify_type depth T))
    fun add_type_pats 0 _ = I
      | add_type_pats depth t =
        add_type_pat depth t #> add_type_pats (depth - 1) t
    fun add_type T =
      add_type_pats type_max_depth T
      #> fold_atyps_sorts (add_classes o snd) T
    fun add_subtypes (T as Type (_, Ts)) = add_type T #> fold add_subtypes Ts
      | add_subtypes T = add_type T

    (* 16.0 plus num_facts / occurrences of "s" (occurrences default to 1) *)
    fun weight_of_const s =
      16.0 +
      (if num_facts = 0 then
         0.0
       else
         let val count = Symtab.lookup const_tab s |> the_default 1 in
           Real.fromInt num_facts / Real.fromInt count (* FUDGE *)
         end)
    (* weighted patterns of a term; "Ts" are the types of the enclosing bound
       variables (innermost first) *)
    fun pattify_term _ 0 _ = ([] : (string list * real) list)
      | pattify_term _ _ (Const (x as (s, _))) =
        if is_widely_irrelevant_const s then
          []
        else
          let
            (* NOTE(review): "single o hd" raises List.Empty for an empty
               sort -- confirm empty sorts cannot reach this point. *)
            val strs_of_sort =
              (if generalize_goal andalso in_goal then Sorts.complete_sort alg
               else single o hd)
              #> map massage_long_name
            fun strs_of_type_arg (T as Type (s, _)) =
                massage_long_name s ::
                (if generalize_goal andalso in_goal then strs_of_sort (sort_of_type alg T) else [])
              | strs_of_type_arg (TFree (_, S)) = strs_of_sort S
              | strs_of_type_arg (TVar (_, S)) = strs_of_sort S

            val typargss =
              these (try (Sign.const_typargs thy) x)
              |> map strs_of_type_arg
              |> n_fold_cartesian_product
              |> keep max_pat_breadth
            val s' = massage_long_name s
            val w = weight_of_const s

            fun str_of_type_args [] = ""
              | str_of_type_args ss = "(" ^ space_implode "," ss ^ ")"
          in
            [(map (curry (op ^) s' o str_of_type_args) typargss, w)]
          end
      | pattify_term _ _ (Free (s, T)) =
        maybe_singleton_str pat_var_prefix (crude_str_of_typ T)
        |> map var_feature_of
        |> (if member (op =) fixes s then
              cons (free_feature_of (massage_long_name (thy_name ^ Long_Name.separator ^ s)))
            else
              I)
      | pattify_term _ _ (Var (_, T)) =
        maybe_singleton_str pat_var_prefix (crude_str_of_typ T)
        |> map var_feature_of
      | pattify_term Ts _ (Bound j) =
        maybe_singleton_str pat_var_prefix (crude_str_of_typ (nth Ts j))
        |> map var_feature_of
      | pattify_term Ts depth (t $ u) =
        let
          val ps = keep max_pat_breadth (pattify_term Ts depth t)
          (* the empty entry keeps the head pattern without an argument *)
          val qs = keep max_pat_breadth (([], default_weight) :: pattify_term Ts (depth - 1) u)
        in
          map_product (fn ppw as (p :: _, pw) =>
              fn ([], _) => ppw
               | (q :: _, qw) => ([p ^ "(" ^ q ^ ")"], pw + qw)) ps qs
        end
      | pattify_term _ _ _ = []
    fun add_term_pat Ts = union (eq_fst (op =)) oo pattify_term Ts
    fun add_term_pats _ 0 _ = I
      | add_term_pats Ts depth t =
        add_term_pat Ts depth t #> add_term_pats Ts (depth - 1) t
    fun add_term Ts = add_term_pats Ts term_max_depth
    (* walk all applications in "t", adding term patterns for constant- and
       free-headed subterms and type features for the heads' types *)
    fun add_subterms Ts t =
      (case strip_comb t of
        (Const (s, T), args) =>
        (not (is_widely_irrelevant_const s) ? add_term Ts t)
        #> add_subtypes T
        #> fold (add_subterms Ts) args
      | (head, args) =>
        (case head of
           Free (_, T) => add_term Ts t #> add_subtypes T
         | Var (_, T) => add_subtypes T
         | Abs (_, T, body) => add_subtypes T #> add_subterms (T :: Ts) body
         | _ => I)
        #> fold (add_subterms Ts) args)
  in [] |> fold (add_subterms []) ts end
   737 
   738 val term_max_depth = 2
   739 val type_max_depth = 1
   740 
   741 (* TODO: Generate type classes for types? *)
   742 fun features_of ctxt thy num_facts const_tab (scope, _) in_goal ts =
   743   let val thy_name = Context.theory_name thy in
   744     thy_feature_of thy_name ::
   745     term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth in_goal ts
   746     |> scope <> Global ? cons local_feature
   747   end
   748 
   749 (* Too many dependencies is a sign that a decision procedure is at work. There
   750    isn't much to learn from such proofs. *)
   751 val max_dependencies = 20
   752 
   753 val prover_default_max_facts = 25
   754 
   755 (* "type_definition_xxx" facts are characterized by their use of "CollectI". *)
   756 val typedef_dep = nickname_of_thm @{thm CollectI}
   757 (* Mysterious parts of the class machinery create lots of proofs that refer
   758    exclusively to "someI_ex" (and to some internal constructions). *)
   759 val class_some_dep = nickname_of_thm @{thm someI_ex}
   760 
   761 val fundef_ths =
   762   @{thms fundef_ex1_existence fundef_ex1_uniqueness fundef_ex1_iff
   763          fundef_default_value}
   764   |> map nickname_of_thm
   765 
   766 (* "Rep_xxx_inject", "Abs_xxx_inverse", etc., are derived using these facts. *)
   767 val typedef_ths =
   768   @{thms type_definition.Abs_inverse type_definition.Rep_inverse
   769          type_definition.Rep type_definition.Rep_inject
   770          type_definition.Abs_inject type_definition.Rep_cases
   771          type_definition.Abs_cases type_definition.Rep_induct
   772          type_definition.Abs_induct type_definition.Rep_range
   773          type_definition.Abs_image}
   774   |> map nickname_of_thm
   775 
   776 fun is_size_def [dep] th =
   777     (case first_field ".rec" dep of
   778        SOME (pref, _) =>
   779        (case first_field ".size" (nickname_of_thm th) of
   780           SOME (pref', _) => pref = pref'
   781         | NONE => false)
   782      | NONE => false)
   783   | is_size_def _ _ = false
   784 
   785 fun no_dependencies_for_status status =
   786   status = Non_Rec_Def orelse status = Rec_Def
   787 
   788 fun trim_dependencies deps =
   789   if length deps > max_dependencies then NONE else SOME deps
   790 
   791 fun isar_dependencies_of name_tabs th =
   792   let val deps = thms_in_proof (SOME name_tabs) th in
   793     if deps = [typedef_dep] orelse
   794        deps = [class_some_dep] orelse
   795        exists (member (op =) fundef_ths) deps orelse
   796        exists (member (op =) typedef_ths) deps orelse
   797        is_size_def deps th then
   798       []
   799     else
   800       deps
   801   end
   802 
   803 fun prover_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover
   804                            auto_level facts name_tabs th =
   805   (case isar_dependencies_of name_tabs th of
   806     [] => (false, [])
   807   | isar_deps =>
   808     let
   809       val thy = Proof_Context.theory_of ctxt
   810       val goal = goal_of_thm thy th
   811       val name = nickname_of_thm th
   812       val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal 1 ctxt
   813       val facts = facts |> filter (fn (_, th') => thm_less (th', th))
   814       fun nickify ((_, stature), th) = ((nickname_of_thm th, stature), th)
   815       fun is_dep dep (_, th) = nickname_of_thm th = dep
   816       fun add_isar_dep facts dep accum =
   817         if exists (is_dep dep) accum then
   818           accum
   819         else
   820           (case find_first (is_dep dep) facts of
   821             SOME ((_, status), th) => accum @ [(("", status), th)]
   822           | NONE => accum (* shouldn't happen *))
   823       val mepo_facts =
   824         facts
   825         |> mepo_suggested_facts ctxt params (max_facts |> the_default prover_default_max_facts) NONE
   826              hyp_ts concl_t
   827       val facts =
   828         mepo_facts
   829         |> fold (add_isar_dep facts) isar_deps
   830         |> map nickify
   831       val num_isar_deps = length isar_deps
   832     in
   833       if verbose andalso auto_level = 0 then
   834         "MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^
   835         " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts."
   836         |> Output.urgent_message
   837       else
   838         ();
   839       (case run_prover_for_mash ctxt params prover name facts goal of
   840         {outcome = NONE, used_facts, ...} =>
   841         (if verbose andalso auto_level = 0 then
   842            let val num_facts = length used_facts in
   843              "Found proof with " ^ string_of_int num_facts ^ " fact" ^
   844              plural_s num_facts ^ "."
   845              |> Output.urgent_message
   846            end
   847          else
   848            ();
   849          (true, map fst used_facts))
   850       | _ => (false, isar_deps))
   851     end)
   852 
   853 
   854 (*** High-level communication with MaSh ***)
   855 
   856 (* In the following functions, chunks are risers w.r.t. "thm_less_eq". *)
   857 
   858 fun chunks_and_parents_for chunks th =
   859   let
   860     fun insert_parent new parents =
   861       let val parents = parents |> filter_out (fn p => thm_less_eq (p, new)) in
   862         parents |> forall (fn p => not (thm_less_eq (new, p))) parents
   863                    ? cons new
   864       end
   865     fun rechunk seen (rest as th' :: ths) =
   866       if thm_less_eq (th', th) then (rev seen, rest)
   867       else rechunk (th' :: seen) ths
   868     fun do_chunk [] accum = accum
   869       | do_chunk (chunk as hd_chunk :: _) (chunks, parents) =
   870         if thm_less_eq (hd_chunk, th) then
   871           (chunk :: chunks, insert_parent hd_chunk parents)
   872         else if thm_less_eq (List.last chunk, th) then
   873           let val (front, back as hd_back :: _) = rechunk [] chunk in
   874             (front :: back :: chunks, insert_parent hd_back parents)
   875           end
   876         else
   877           (chunk :: chunks, parents)
   878   in
   879     fold_rev do_chunk chunks ([], [])
   880     |>> cons []
   881     ||> map nickname_of_thm
   882   end
   883 
   884 fun attach_parents_to_facts _ [] = []
   885   | attach_parents_to_facts old_facts (facts as (_, th) :: _) =
   886     let
   887       fun do_facts _ [] = []
   888         | do_facts (_, parents) [fact] = [(parents, fact)]
   889         | do_facts (chunks, parents)
   890                    ((fact as (_, th)) :: (facts as (_, th') :: _)) =
   891           let
   892             val chunks = app_hd (cons th) chunks
   893             val chunks_and_parents' =
   894               if thm_less_eq (th, th') andalso
   895                  thy_name_of_thm th = thy_name_of_thm th' then
   896                 (chunks, [nickname_of_thm th])
   897               else
   898                 chunks_and_parents_for chunks th'
   899           in (parents, fact) :: do_facts chunks_and_parents' facts end
   900     in
   901       old_facts @ facts
   902       |> do_facts (chunks_and_parents_for [[]] th)
   903       |> drop (length old_facts)
   904     end
   905 
   906 fun maximal_wrt_graph G keys =
   907   let
   908     val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
   909     fun insert_new seen name =
   910       not (Symtab.defined seen name) ? insert (op =) name
   911     fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
   912     fun find_maxes _ (maxs, []) = map snd maxs
   913       | find_maxes seen (maxs, new :: news) =
   914         find_maxes
   915             (seen |> num_keys (Graph.imm_succs G new) > 1
   916                      ? Symtab.default (new, ()))
   917             (if Symtab.defined tab new then
   918                let
   919                  val newp = Graph.all_preds G [new]
   920                  fun is_ancestor x yp = member (op =) yp x
   921                  val maxs =
   922                    maxs |> filter (fn (_, max) => not (is_ancestor max newp))
   923                in
   924                  if exists (is_ancestor new o fst) maxs then
   925                    (maxs, news)
   926                  else
   927                    ((newp, new)
   928                     :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
   929                     news)
   930                end
   931              else
   932                (maxs, Graph.Keys.fold (insert_new seen)
   933                                       (Graph.imm_preds G new) news))
   934   in find_maxes Symtab.empty ([], Graph.maximals G) end
   935 
   936 fun maximal_wrt_access_graph access_G =
   937   map (nickname_of_thm o snd)
   938   #> maximal_wrt_graph access_G
   939 
   940 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   941 
   942 val chained_feature_factor = 0.5 (* FUDGE *)
   943 val extra_feature_factor = 0.1 (* FUDGE *)
   944 val num_extra_feature_facts = 10 (* FUDGE *)
   945 
   946 (* FUDGE *)
   947 fun weight_of_proximity_fact rank =
   948   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   949 
   950 fun weight_facts_smoothly facts =
   951   facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
   952 
   953 (* FUDGE *)
   954 fun steep_weight_of_fact rank =
   955   Math.pow (0.62, log2 (Real.fromInt (rank + 1)))
   956 
   957 fun weight_facts_steeply facts =
   958   facts ~~ map steep_weight_of_fact (0 upto length facts - 1)
   959 
   960 val max_proximity_facts = 100
   961 
   962 fun find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
   963   let
   964     val inter_fact = inter (eq_snd Thm.eq_thm_prop)
   965     val raw_mash = find_suggested_facts ctxt facts suggs
   966     val proximate = take max_proximity_facts facts
   967     val unknown_chained = inter_fact raw_unknown chained
   968     val unknown_proximate = inter_fact raw_unknown proximate
   969     val mess =
   970       [(0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
   971        (0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, [])),
   972        (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))]
   973     val unknown =
   974       raw_unknown
   975       |> fold (subtract (eq_snd Thm.eq_thm_prop))
   976               [unknown_chained, unknown_proximate]
   977   in (mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess, unknown) end
   978 
   979 fun add_const_counts t =
   980   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1))
   981        (Term.add_const_names t [])
   982 
   983 fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
   984   let
   985     val thy = Proof_Context.theory_of ctxt
   986     val thy_name = Context.theory_name thy
   987     val facts = facts |> sort (crude_thm_ord o pairself snd o swap)
   988     val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
   989     val num_facts = length facts
   990     val const_tab = fold (add_const_counts o prop_of o snd) facts Symtab.empty
   991 
   992     fun fact_has_right_theory (_, th) =
   993       thy_name = Context.theory_name (theory_of_thm th)
   994     fun chained_or_extra_features_of factor (((_, stature), th), weight) =
   995       [prop_of th]
   996       |> features_of ctxt (theory_of_thm th) num_facts const_tab stature false
   997       |> map (apsnd (fn r => weight * factor * r))
   998 
   999     val (access_G, suggs) =
  1000       peek_state ctxt overlord (fn {access_G, ...} =>
  1001           if Graph.is_empty access_G then
  1002             (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
  1003           else
  1004             let
  1005               val parents = maximal_wrt_access_graph access_G facts
  1006               val goal_feats =
  1007                 features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
  1008               val chained_feats =
  1009                 chained
  1010                 |> map (rpair 1.0)
  1011                 |> map (chained_or_extra_features_of chained_feature_factor)
  1012                 |> rpair [] |-> fold (union (eq_fst (op =)))
  1013               val extra_feats =
  1014                 facts
  1015                 |> take (Int.max (0, num_extra_feature_facts - length chained))
  1016                 |> filter fact_has_right_theory
  1017                 |> weight_facts_steeply
  1018                 |> map (chained_or_extra_features_of extra_feature_factor)
  1019                 |> rpair [] |-> fold (union (eq_fst (op =)))
  1020               val feats =
  1021                 fold (union (eq_fst (op =))) [chained_feats, extra_feats]
  1022                      goal_feats
  1023                 |> debug ? sort (Real.compare o swap o pairself snd)
  1024               val hints =
  1025                 chained |> filter (is_fact_in_graph access_G o snd)
  1026                         |> map (nickname_of_thm o snd)
  1027             in
  1028               (access_G, MaSh.query ctxt overlord max_facts
  1029                                     ([], hints, parents, feats))
  1030             end)
  1031     val unknown = facts |> filter_out (is_fact_in_graph access_G o snd)
  1032   in
  1033     find_mash_suggestions ctxt max_facts suggs facts chained unknown
  1034     |> pairself (map fact_of_raw_fact)
  1035   end
  1036 
  1037 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
  1038   let
  1039     fun maybe_learn_from from (accum as (parents, graph)) =
  1040       try_graph ctxt "updating graph" accum (fn () =>
  1041           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
  1042     val graph = graph |> Graph.default_node (name, Isar_Proof)
  1043     val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
  1044     val (deps, _) = ([], graph) |> fold maybe_learn_from deps
  1045   in ((name, parents, feats, deps) :: learns, graph) end
  1046 
  1047 fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
  1048   let
  1049     fun maybe_relearn_from from (accum as (parents, graph)) =
  1050       try_graph ctxt "updating graph" accum (fn () =>
  1051           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
  1052     val graph = graph |> update_access_graph_node (name, Automatic_Proof)
  1053     val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
  1054   in ((name, deps) :: relearns, graph) end
  1055 
  1056 fun flop_wrt_access_graph name =
  1057   update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop)
  1058 
  1059 val learn_timeout_slack = 2.0
  1060 
  1061 fun launch_thread timeout task =
  1062   let
  1063     val hard_timeout = time_mult learn_timeout_slack timeout
  1064     val birth_time = Time.now ()
  1065     val death_time = Time.+ (birth_time, hard_timeout)
  1066     val desc = ("Machine learner for Sledgehammer", "")
  1067   in Async_Manager.thread MaShN birth_time death_time desc task end
  1068 
  1069 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths =
  1070   if is_mash_enabled () then
  1071     launch_thread timeout (fn () =>
  1072         let
  1073           val thy = Proof_Context.theory_of ctxt
  1074           val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst
  1075         in
  1076           peek_state ctxt overlord (fn {access_G, ...} =>
  1077               let
  1078                 val parents = maximal_wrt_access_graph access_G facts
  1079                 val deps =
  1080                   used_ths |> filter (is_fact_in_graph access_G)
  1081                            |> map nickname_of_thm
  1082               in
  1083                 MaSh.learn ctxt overlord true [("", parents, feats, deps)]
  1084               end);
  1085           (true, "")
  1086         end)
  1087   else
  1088     ()
  1089 
  1090 fun sendback sub =
  1091   Active.sendback_markup [Markup.padding_command] (sledgehammerN ^ " " ^ sub)
  1092 
  1093 val commit_timeout = seconds 30.0
  1094 
  1095 (* The timeout is understood in a very relaxed fashion. *)
  1096 fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover save auto_level
  1097     run_prover learn_timeout facts =
  1098   let
  1099     val timer = Timer.startRealTimer ()
  1100     fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
  1101     val {access_G, ...} = peek_state ctxt overlord I
  1102     val is_in_access_G = is_fact_in_graph access_G o snd
  1103     val no_new_facts = forall is_in_access_G facts
  1104   in
  1105     if no_new_facts andalso not run_prover then
  1106       if auto_level < 2 then
  1107         "No new " ^ (if run_prover then "automatic" else "Isar") ^
  1108         " proofs to learn." ^
  1109         (if auto_level = 0 andalso not run_prover then
  1110            "\n\nHint: Try " ^ sendback learn_proverN ^
  1111            " to learn from an automatic prover."
  1112          else
  1113            "")
  1114       else
  1115         ""
  1116     else
  1117       let
  1118         val name_tabs = build_name_tables nickname_of_thm facts
  1119         fun deps_of status th =
  1120           if no_dependencies_for_status status then
  1121             SOME []
  1122           else if run_prover then
  1123             prover_dependencies_of ctxt params prover auto_level facts name_tabs
  1124                                    th
  1125             |> (fn (false, _) => NONE
  1126                  | (true, deps) => trim_dependencies deps)
  1127           else
  1128             isar_dependencies_of name_tabs th
  1129             |> trim_dependencies
  1130         fun do_commit [] [] [] state = state
  1131           | do_commit learns relearns flops {access_G, num_known_facts, dirty} =
  1132             let
  1133               val was_empty = Graph.is_empty access_G
  1134               val (learns, access_G) =
  1135                 ([], access_G) |> fold (learn_wrt_access_graph ctxt) learns
  1136               val (relearns, access_G) =
  1137                 ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
  1138               val access_G = access_G |> fold flop_wrt_access_graph flops
  1139               val num_known_facts = num_known_facts + length learns
  1140               val dirty =
  1141                 (case (was_empty, dirty, relearns) of
  1142                   (false, SOME names, []) => SOME (map #1 learns @ names)
  1143                 | _ => NONE)
  1144             in
  1145               MaSh.learn ctxt overlord (save andalso null relearns) (rev learns);
  1146               MaSh.relearn ctxt overlord save relearns;
  1147               {access_G = access_G, num_known_facts = num_known_facts,
  1148                dirty = dirty}
  1149             end
  1150         fun commit last learns relearns flops =
  1151           (if debug andalso auto_level = 0 then
  1152              Output.urgent_message "Committing..."
  1153            else
  1154              ();
  1155            map_state ctxt overlord (do_commit (rev learns) relearns flops);
  1156            if not last andalso auto_level = 0 then
  1157              let val num_proofs = length learns + length relearns in
  1158                "Learned " ^ string_of_int num_proofs ^ " " ^
  1159                (if run_prover then "automatic" else "Isar") ^ " proof" ^
  1160                plural_s num_proofs ^ " in the last " ^
  1161                string_of_time commit_timeout ^ "."
  1162                |> Output.urgent_message
  1163              end
  1164            else
  1165              ())
  1166         fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
  1167           | learn_new_fact (parents, ((_, stature as (_, status)), th))
  1168                            (learns, (n, next_commit, _)) =
  1169             let
  1170               val name = nickname_of_thm th
  1171               val feats =
  1172                 features_of ctxt (theory_of_thm th) 0 Symtab.empty stature false [prop_of th]
  1173                 |> map fst
  1174               val deps = deps_of status th |> these
  1175               val n = n |> not (null deps) ? Integer.add 1
  1176               val learns = (name, parents, feats, deps) :: learns
  1177               val (learns, next_commit) =
  1178                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1179                   (commit false learns [] []; ([], next_commit_time ()))
  1180                 else
  1181                   (learns, next_commit)
  1182               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
  1183             in (learns, (n, next_commit, timed_out)) end
  1184         val n =
  1185           if no_new_facts then
  1186             0
  1187           else
  1188             let
  1189               val new_facts =
  1190                 facts |> sort (crude_thm_ord o pairself snd)
  1191                       |> attach_parents_to_facts []
  1192                       |> filter_out (is_in_access_G o snd)
  1193               val (learns, (n, _, _)) =
  1194                 ([], (0, next_commit_time (), false))
  1195                 |> fold learn_new_fact new_facts
  1196             in commit true learns [] []; n end
  1197         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
  1198           | relearn_old_fact ((_, (_, status)), th)
  1199                              ((relearns, flops), (n, next_commit, _)) =
  1200             let
  1201               val name = nickname_of_thm th
  1202               val (n, relearns, flops) =
  1203                 (case deps_of status th of
  1204                   SOME deps => (n + 1, (name, deps) :: relearns, flops)
  1205                 | NONE => (n, relearns, name :: flops))
  1206               val (relearns, flops, next_commit) =
  1207                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1208                   (commit false [] relearns flops;
  1209                    ([], [], next_commit_time ()))
  1210                 else
  1211                   (relearns, flops, next_commit)
  1212               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
  1213             in ((relearns, flops), (n, next_commit, timed_out)) end
  1214         val n =
  1215           if not run_prover then
  1216             n
  1217           else
  1218             let
  1219               val max_isar = 1000 * max_dependencies
  1220               fun kind_of_proof th =
  1221                 try (Graph.get_node access_G) (nickname_of_thm th)
  1222                 |> the_default Isar_Proof
  1223               fun priority_of (_, th) =
  1224                 random_range 0 max_isar
  1225                 + (case kind_of_proof th of
  1226                      Isar_Proof => 0
  1227                    | Automatic_Proof => 2 * max_isar
  1228                    | Isar_Proof_wegen_Prover_Flop => max_isar)
  1229                 - 100 * length (isar_dependencies_of name_tabs th)
  1230               val old_facts =
  1231                 facts |> filter is_in_access_G
  1232                       |> map (`priority_of)
  1233                       |> sort (int_ord o pairself fst)
  1234                       |> map snd
  1235               val ((relearns, flops), (n, _, _)) =
  1236                 (([], []), (n, next_commit_time (), false))
  1237                 |> fold relearn_old_fact old_facts
  1238             in commit true [] relearns flops; n end
  1239       in
  1240         if verbose orelse auto_level < 2 then
  1241           "Learned " ^ string_of_int n ^ " nontrivial " ^
  1242           (if run_prover then "automatic and " else "") ^ "Isar proof" ^ plural_s n ^
  1243           (if verbose then " in " ^ string_of_time (Timer.checkRealTimer timer)
  1244            else "") ^ "."
  1245         else
  1246           ""
  1247       end
  1248   end
  1249 
  1250 fun mash_learn ctxt (params as {provers, timeout, ...}) fact_override chained run_prover =
  1251   let
  1252     val css = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
  1253     val ctxt = ctxt |> Config.put instantiate_inducts false
  1254     val facts =
  1255       nearly_all_facts ctxt false fact_override Symtab.empty css chained []
  1256                        @{prop True}
  1257       |> sort (crude_thm_ord o pairself snd o swap)
  1258     val num_facts = length facts
  1259     val prover = hd provers
  1260     fun learn auto_level run_prover =
  1261       mash_learn_facts ctxt params prover true auto_level run_prover one_year facts
  1262       |> Output.urgent_message
  1263   in
  1264     if run_prover then
  1265       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
  1266        " for automatic proofs (" ^ quote prover ^ " timeout: " ^ string_of_time timeout ^
  1267        ").\n\nCollecting Isar proofs first..."
  1268        |> Output.urgent_message;
  1269        learn 1 false;
  1270        "Now collecting automatic proofs. This may take several hours. You can \
  1271        \safely stop the learning process at any point."
  1272        |> Output.urgent_message;
  1273        learn 0 true)
  1274     else
  1275       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
  1276        plural_s num_facts ^ " for Isar proofs..."
  1277        |> Output.urgent_message;
  1278        learn 0 false)
  1279   end
  1280 
  1281 fun mash_can_suggest_facts ctxt overlord =
  1282   not (Graph.is_empty (#access_G (peek_state ctxt overlord I)))
  1283 
  1284 (* Generate more suggestions than requested, because some might be thrown out
  1285    later for various reasons. *)
  1286 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
  1287 
  1288 val mepo_weight = 0.5
  1289 val mash_weight = 0.5
  1290 
  1291 val max_facts_to_learn_before_query = 100
  1292 
  1293 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
  1294    Sledgehammer and Try. *)
  1295 val min_secs_for_learning = 15
  1296 
  1297 fun relevant_facts ctxt (params as {overlord, blocking, learn, fact_filter, timeout, ...}) prover
  1298     max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
  1299   if not (subset (op =) (the_list fact_filter, fact_filters)) then
  1300     error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
  1301   else if only then
  1302     let val facts = facts |> map fact_of_raw_fact in
  1303       [("", facts)]
  1304     end
  1305   else if max_facts <= 0 orelse null facts then
  1306     [("", [])]
  1307   else
  1308     let
  1309       fun maybe_launch_thread () =
  1310         if not blocking andalso not (Async_Manager.has_running_threads MaShN) andalso
  1311            Time.toSeconds timeout >= min_secs_for_learning then
  1312           let val timeout = time_mult learn_timeout_slack timeout in
  1313             launch_thread timeout
  1314               (fn () => (true, mash_learn_facts ctxt params prover true 2 false timeout facts))
  1315           end
  1316         else
  1317           ()
  1318       fun maybe_learn () =
  1319         if is_mash_enabled () andalso learn then
  1320           let
  1321             val {access_G, num_known_facts, ...} = peek_state ctxt overlord I
  1322             val is_in_access_G = is_fact_in_graph access_G o snd
  1323           in
  1324             if length facts - num_known_facts
  1325                <= max_facts_to_learn_before_query then
  1326               (case length (filter_out is_in_access_G facts) of
  1327                 0 => false
  1328               | num_facts_to_learn =>
  1329                 if num_facts_to_learn <= max_facts_to_learn_before_query then
  1330                   (mash_learn_facts ctxt params prover false 2 false timeout facts
  1331                    |> (fn "" => () | s => Output.urgent_message (MaShN ^ ": " ^ s));
  1332                    true)
  1333                 else
  1334                   (maybe_launch_thread (); false))
  1335             else
  1336               (maybe_launch_thread (); false)
  1337           end
  1338         else
  1339           false
  1340       val (save, effective_fact_filter) =
  1341         (case fact_filter of
  1342           SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
  1343         | NONE =>
  1344           if is_mash_enabled () then
  1345             (maybe_learn (),
  1346              if mash_can_suggest_facts ctxt overlord then meshN else mepoN)
  1347           else
  1348             (false, mepoN))
  1349 
  1350       val unique_facts = drop_duplicate_facts facts
  1351       val add_ths = Attrib.eval_thms ctxt add
  1352 
  1353       fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
  1354       fun add_and_take accepts =
  1355         (case add_ths of
  1356            [] => accepts
  1357          | _ => (unique_facts |> filter in_add |> map fact_of_raw_fact) @
  1358                 (accepts |> filter_out in_add))
  1359         |> take max_facts
  1360       fun mepo () =
  1361         (mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts
  1362          |> weight_facts_steeply, [])
  1363       fun mash () =
  1364         mash_suggested_facts ctxt params (generous_max_facts max_facts) hyp_ts concl_t facts
  1365         |>> weight_facts_steeply
  1366       val mess =
  1367         (* the order is important for the "case" expression below *)
  1368         [] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash)
  1369            |> effective_fact_filter <> mashN ? cons (mepo_weight, mepo)
  1370            |> Par_List.map (apsnd (fn f => f ()))
  1371       val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
  1372     in
  1373       if save then MaSh.save ctxt overlord else ();
  1374       (case (fact_filter, mess) of
  1375         (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
  1376         [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
  1377          (mashN, mash |> map fst |> add_and_take)]
  1378       | _ => [(effective_fact_filter, mesh)])
  1379     end
  1380 
  1381 fun kill_learners ctxt ({overlord, ...} : params) =
  1382   (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord)
  1383 fun running_learners () = Async_Manager.running_threads MaShN "learner"
  1384 
  1385 end;