src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
author blanchet
Wed Dec 05 13:25:06 2012 +0100 (2012-12-05 ago)
changeset 50383 4274b25ff4e7
parent 50382 cb564ff43c28
child 50389 ad0ac9112d2c
permissions -rw-r--r--
take proximity into account for MaSh + fix a debilitating bug in feature generation
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3 
     4 Sledgehammer's machine-learning-based relevance filter (MaSh).
     5 *)
     6 
     (* Interface of the MaSh relevance filter.  It exposes the filter and
        command name constants, the escaping helpers used for the textual
        "mash.py" command format, the low-level mash_XXX commands, and the
        fact-selection and learning functions built on top of them. *)
     7 signature SLEDGEHAMMER_MASH =
     8 sig
     9   type stature = ATP_Problem_Generate.stature
    10   type fact = Sledgehammer_Fact.fact
    11   type fact_override = Sledgehammer_Fact.fact_override
    12   type params = Sledgehammer_Provers.params
    13   type relevance_fudge = Sledgehammer_Provers.relevance_fudge
    14   type prover_result = Sledgehammer_Provers.prover_result
    15 
    (* tracing flag and name constants *)
    16   val trace : bool Config.T
    17   val MaShN : string
    18   val mepoN : string
    19   val mashN : string
    20   val meshN : string
    21   val unlearnN : string
    22   val learn_isarN : string
    23   val learn_atpN : string
    24   val relearn_isarN : string
    25   val relearn_atpN : string
    26   val fact_filters : string list
    (* escaping and encoding for the command file format *)
    27   val escape_meta : string -> string
    28   val escape_metas : string list -> string
    29   val unescape_meta : string -> string
    30   val unescape_metas : string -> string list
    31   val encode_features : (string * real) list -> string
    32   val extract_query : string -> string * (string * real) list
    (* low-level commands; the bool argument is the "overlord" flag *)
    33   val mash_CLEAR : Proof.context -> unit
    34   val mash_ADD :
    35     Proof.context -> bool
    36     -> (string * string list * (string * real) list * string list) list -> unit
    37   val mash_REPROVE :
    38     Proof.context -> bool -> (string * string list) list -> unit
    39   val mash_QUERY :
    40     Proof.context -> bool -> int -> string list * (string * real) list
    41     -> (string * real) list
    42   val mash_unlearn : Proof.context -> unit
    (* fact naming, ordering, and merging of suggestions *)
    43   val nickname_of : thm -> string
    44   val suggested_facts :
    45     (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
    46   val mesh_facts :
    47     int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list
    48     -> ('a * thm) list
    49   val theory_ord : theory * theory -> order
    50   val thm_ord : thm * thm -> order
    51   val goal_of_thm : theory -> thm -> thm
    52   val run_prover_for_mash :
    53     Proof.context -> params -> string -> fact list -> thm -> prover_result
    (* features and dependencies *)
    54   val features_of :
    55     Proof.context -> string -> theory -> stature -> term list
    56     -> (string * real) list
    57   val isar_dependencies_of : unit Symtab.table -> thm -> string list option
    58   val atp_dependencies_of :
    59     Proof.context -> params -> string -> int -> fact list -> unit Symtab.table
    60     -> thm -> bool * string list option
    (* high-level suggestion and learning entry points *)
    61   val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list
    62   val mash_suggested_facts :
    63     Proof.context -> params -> string -> int -> term list -> term -> fact list
    64     -> fact list
    65   val mash_learn_proof :
    66     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    67     -> unit
    68   val mash_learn :
    69     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    70   val is_mash_enabled : unit -> bool
    71   val mash_can_suggest_facts : Proof.context -> bool
    72   val relevant_facts :
    73     Proof.context -> params -> string -> int -> fact_override -> term list
    74     -> term -> fact list -> fact list
    75   val kill_learners : unit -> unit
    76   val running_learners : unit -> unit
    77 end;
    78 
    79 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH =
    80 struct
    81 
    82 open ATP_Util
    83 open ATP_Problem_Generate
    84 open Sledgehammer_Util
    85 open Sledgehammer_Fact
    86 open Sledgehammer_Provers
    87 open Sledgehammer_Minimize
    88 open Sledgehammer_MePo
    89 
    (* Boolean configuration option "sledgehammer_mash_trace"; false by
       default. *)
    val trace =
      Attrib.setup_config_bool @{binding sledgehammer_mash_trace}
        (fn _ => false)

    (* Passes the lazily computed message to "tracing", but only when the
       "trace" option is set in "ctxt"; "msg" is not forced otherwise. *)
    fun trace_msg ctxt msg =
      (case Config.get ctxt trace of
         true => tracing (msg ())
       | false => ())
    93 
    (* String constants exported through the signature: the tool name, the
       names of the fact filters (collected in "fact_filters"), and the names
       of the learning-related commands. *)
    94 val MaShN = "MaSh"
    95 
    96 val mepoN = "mepo"
    97 val mashN = "mash"
    98 val meshN = "mesh"
    99 
   100 val fact_filters = [meshN, mepoN, mashN]
   101 
   102 val unlearnN = "unlearn"
   103 val learn_isarN = "learn_isar"
   104 val learn_atpN = "learn_atp"
   105 val relearn_isarN = "relearn_isar"
   106 val relearn_atpN = "relearn_atp"
   107 
   (* Path "$ISABELLE_HOME_USER/mash"; "Isabelle_System.mkdir" is invoked on
      it on every call before the path is returned. *)
   fun mash_model_dir () =
     let val dir = Path.explode "$ISABELLE_HOME_USER/mash" in
       Isabelle_System.mkdir dir;
       dir
     end

   (* The state is kept in the same directory as the model. *)
   val mash_state_dir = mash_model_dir

   (* Path of the file named "state" inside the state directory. *)
   fun mash_state_file () =
     let val dir = mash_state_dir () in
       Path.append dir (Path.explode "state")
     end
   112 
   113 
   114 (*** Low-level communication with MaSh ***)
   115 
   (* Best-effort removal of the file whose name is the string "file"; a
      failure to explode the path or to remove the file is swallowed by
      "try". *)
   fun wipe_out_file file =
     let val _ = try (fn name => File.rm (Path.explode name)) file in () end
   117 
   (* Writes "banner" (if any) to "path" with "File.write", then appends the
      strings "f x" for all "x" in "xs", 500 entries per "File.append" call.
      I/O errors are deliberately ignored. *)
   fun write_file banner (xs, f) path =
     let
       fun write_banner (SOME s) = File.write path s
         | write_banner NONE = ()
       fun append_chunk chunk =
         File.append path (space_implode "" (map f chunk))
     in
       write_banner banner;
       List.app append_chunk (chunk_list 500 xs)
     end
     handle IO.Io _ => ()
   123 
   (* Runs "mash.py" once.

      The commands obtained from "write_cmds" (a pair of a list and a
      formatting function, as expected by "write_file") are written to a
      command file, the tool is invoked on it via bash, and "read_suggs" is
      applied to a thunk that reads the lines of the suggestion file (the
      empty list if the file cannot be read).

      With "overlord", the files live in $ISABELLE_HOME_USER under fixed
      names, a log file is requested, and nothing is removed afterwards.
      Otherwise the files get a serial-number suffix in $ISABELLE_TMP, the
      log goes to "/dev/null", and the files are wiped by "clean_up".

      "save" adds "--saveModel" to the command line; "max_suggs" is passed as
      "--numberOfPredictions". *)
   fun run_mash_tool ctxt overlord save max_suggs write_cmds read_suggs =
     let
       val (temp_dir, serial) =
         if overlord then (getenv "ISABELLE_HOME_USER", "")
         else (getenv "ISABELLE_TMP", serial_string ())
       val log_file = if overlord then temp_dir ^ "/mash_log" else "/dev/null"
       val err_file = temp_dir ^ "/mash_err" ^ serial
       val sugg_file = temp_dir ^ "/mash_suggs" ^ serial
       val sugg_path = Path.explode sugg_file
       val cmd_file = temp_dir ^ "/mash_commands" ^ serial
       val cmd_path = Path.explode cmd_file
       val core =
         "--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^
         " --numberOfPredictions " ^ string_of_int max_suggs ^
         (if save then " --saveModel" else "")
       (* NOTE(review): the file names are spliced into the shell command
          unquoted; only the model directory goes through
          "File.shell_path". *)
       val command =
         "\"$ISABELLE_SLEDGEHAMMER_MASH/src/mash.py\" --quiet --outputDir " ^
         File.shell_path (mash_model_dir ()) ^ " --log " ^ log_file ^ " " ^
         core ^ " >& " ^ err_file
       (* Reports what the tool wrote to its error file.  This must happen
          after the tool has run: previously the report was attached to the
          construction of the command string and hence was produced before
          "mash.py" was even started. *)
       fun trace_outcome () =
         trace_msg ctxt (fn () =>
             case try File.read (Path.explode err_file) of
               NONE => "Done"
             | SOME "" => "Done"
             | SOME s => "Error: " ^ elide_string 1000 s)
       fun run_on () =
         (Isabelle_System.bash command;
          trace_outcome ();
          read_suggs (fn () => try File.read_lines sugg_path |> these))
       fun clean_up () =
         if overlord then ()
         else List.app wipe_out_file [err_file, sugg_file, cmd_file]
     in
       (* start from an empty suggestion file and a fresh command file *)
       write_file (SOME "") ([], K "") sugg_path;
       write_file (SOME "") write_cmds cmd_path;
       trace_msg ctxt (fn () => "Running " ^ command);
       with_cleanup clean_up run_on ()
     end
   160 
   (* Escapes one character: alphanumeric characters and "_", ".", "(", ")",
      "," are kept; anything else becomes "%" followed by the output of
      "stringN_of_int 3" on its character code. *)
   fun meta_char c =
     let
       val is_plain = Char.isAlphaNum c orelse Char.contains "_.()," c
     in
       if is_plain then
         String.str c
       else
         (* fixed width, in case more digits follow *)
         "%" ^ stringN_of_int 3 (Char.ord c)
     end
   168 
   (* Inverse of "meta_char" on a list of characters: "%" followed by three
      decimal digits is replaced by the character with that code; all other
      characters are copied.  "accum" holds the decoded prefix in reverse.

      Returns "" on any malformed escape.  The three characters after "%"
      must all be digits and denote a valid character code; without these
      checks an input such as "%999" or "%~12" would raise "Chr" instead of
      yielding the error value, and "%12a" would be silently accepted. *)
   fun unmeta_chars accum [] = String.implode (rev accum)
     | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
       if List.all Char.isDigit [d1, d2, d3] then
         (case Int.fromString (String.implode [d1, d2, d3]) of
            SOME n =>
            if n <= Char.maxOrd then unmeta_chars (Char.chr n :: accum) cs
            else "" (* error: code out of range *)
          | NONE => "" (* error *))
       else
         "" (* error: not three digits *)
     | unmeta_chars _ (#"%" :: _) = "" (* error: truncated escape *)
     | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
   176 
   (* Escapes every character of a string with "meta_char". *)
   fun escape_meta s = String.translate meta_char s

   (* Escapes each string and joins the results with single spaces. *)
   fun escape_metas names = space_implode " " (map escape_meta names)

   (* Decodes a string produced by "escape_meta". *)
   fun unescape_meta s = unmeta_chars [] (String.explode s)

   (* Splits at spaces, drops empty fields, and decodes each field. *)
   fun unescape_metas s =
     space_explode " " s
     |> filter_out (fn field => "" = field)
     |> map unescape_meta
   182 
   (* Selects the "name=weight" output format below; currently off. *)
   val supports_weights = false

   (* Encodes one weighted feature.  A weight equal to 1.0 gives just the
      escaped name.  Otherwise, with "supports_weights" the result is
      "name=weight"; without it the feature is repeated as "name.1",
      "name.2", ..., up to the ceiling of the weight, separated by spaces. *)
   fun encode_feature (name, weight) =
     let
       val escaped = escape_meta name
       fun numbered i = (escaped ^ ".") ^ string_of_int i
     in
       if Real.== (weight, 1.0) then
         escaped
       else if supports_weights then
         escaped ^ "=" ^ Real.toString weight
       else
         space_implode " " (map numbered (1 upto Real.ceil weight))
     end

   (* Encodes all features and joins them with single spaces. *)
   fun encode_features feats = space_implode " " (map encode_feature feats)
   197 
   (* Command line "! name: parents; features; deps". *)
   fun str_of_add (name, parents, feats, deps) =
     String.concat
       ["! ", escape_meta name, ": ", escape_metas parents, "; ",
        encode_features feats, "; ", escape_metas deps, "\n"]

   (* Command line "p name: deps". *)
   fun str_of_reprove (name, deps) =
     String.concat ["p ", escape_meta name, ": ", escape_metas deps, "\n"]

   (* Command line "? parents; features". *)
   fun str_of_query (parents, feats) =
     String.concat
       ["? ", escape_metas parents, "; ", encode_features feats, "\n"]
   207 
   (* Parses one "name=weight" item.  An unparsable weight counts as 0.0;
      anything not consisting of exactly two "="-separated fields yields
      NONE. *)
   fun extract_suggestion sugg =
     (case space_explode "=" sugg of
        [name, weight] =>
        let val w = the_default 0.0 (Real.fromString weight) in
          SOME (unescape_meta name, w)
        end
      | _ => NONE)
   213 
   (* Parses a line "goal: sugg1 sugg2 ..." into the decoded goal name and
      the suggestions accepted by "extract_suggestion".  A line that does not
      split into exactly two ":"-separated fields yields ("", []). *)
   fun extract_query line =
     (case space_explode ":" line of
        [goal, suggs] =>
        let val items = space_explode " " suggs in
          (unescape_meta goal, map_filter extract_suggestion items)
        end
      | _ => ("", []))
   220 
   (* Attempts to remove every entry of the model directory.  Both the
      directory traversal and each individual removal are wrapped in "try",
      so failures are ignored. *)
   fun mash_CLEAR ctxt =
     let
       val dir = mash_model_dir ()
       fun remove_entry file _ =
         try File.rm (Path.append dir (Path.basic file))
       val _ = trace_msg ctxt (fn () => "MaSh CLEAR")
       val _ = try (File.fold_dir remove_entry dir) NONE
     in () end
   229 
   (* Sends the given additions to MaSh with the model being saved; does
      nothing for an empty list.  The suggestions are not read back. *)
   fun mash_ADD ctxt overlord adds =
     if null adds then
       ()
     else
       (trace_msg ctxt (fn () =>
            "MaSh ADD " ^
            (adds |> map #1 |> space_implode " " |> elide_string 1000));
        run_mash_tool ctxt overlord true 0 (adds, str_of_add) (K ()))
   235 
   (* Sends the given reproofs to MaSh with the model being saved; does
      nothing for an empty list.  The suggestions are not read back. *)
   fun mash_REPROVE ctxt overlord reps =
     if null reps then
       ()
     else
       (trace_msg ctxt (fn () =>
            "MaSh REPROVE " ^
            (reps |> map #1 |> space_implode " " |> elide_string 1000));
        run_mash_tool ctxt overlord true 0 (reps, str_of_reprove) (K ()))
   241 
   (* Asks MaSh for at most "max_suggs" suggestions for the query, without
      saving the model.  Only the last line of the suggestion file is
      parsed; an empty file gives []. *)
   fun mash_QUERY ctxt overlord max_suggs (query as (_, feats)) =
     let
       fun last_answer read_lines =
         (case read_lines () of
            [] => []
          | lines => lines |> List.last |> extract_query |> snd)
     in
       trace_msg ctxt (fn () => "MaSh QUERY " ^ encode_features feats);
       run_mash_tool ctxt overlord false max_suggs ([query], str_of_query)
         last_answer
       handle List.Empty => []
     end
   251 
   252 
   253 (*** Middle-level communication with MaSh ***)
   254 
   (* Kind of proof recorded for a node of the fact graph.  Each kind is
      serialized as a one-letter tag ("i", "a", "x") by "str_of_proof_kind". *)
   255 datatype proof_kind = Isar_Proof | ATP_Proof | Isar_Proof_wegen_ATP_Flop
   256 
   257 fun str_of_proof_kind Isar_Proof = "i"
   258   | str_of_proof_kind ATP_Proof = "a"
   259   | str_of_proof_kind Isar_Proof_wegen_ATP_Flop = "x"
   260 
   (* Inverse of "str_of_proof_kind".  The match is not exhaustive: any other
      string raises Match, so callers must guard the call (e.g. with "try"). *)
   261 fun proof_kind_of_str "i" = Isar_Proof
   262   | proof_kind_of_str "a" = ATP_Proof
   263   | proof_kind_of_str "x" = Isar_Proof_wegen_ATP_Flop
   264 
   (* FIXME: Here a "Graph.update_node" function would be useful *)
   (* Makes sure the node "name" exists (created with kind Isar_Proof if it
      is missing) and, unless "kind" is Isar_Proof, overwrites its kind. *)
   fun update_fact_graph_node (name, kind) graph =
     let val graph' = Graph.default_node (name, Isar_Proof) graph in
       if kind = Isar_Proof then graph'
       else Graph.map_node name (fn _ => kind) graph'
     end
   269 
   (* Evaluates "f ()" and returns its result.  If it raises Graph.CYCLES
      (with a nonempty cycle list), Graph.DUP, Graph.UNDEF, or any other
      exception that is not an interrupt, a trace message mentioning "when"
      is emitted and "def" is returned instead.  Interrupts are reraised. *)
   270 fun try_graph ctxt when def f =
   271   f ()
   272   handle Graph.CYCLES (cycle :: _) =>
   273          (trace_msg ctxt (fn () =>
   274               "Cycle involving " ^ commas cycle ^ " when " ^ when); def)
   275        | Graph.DUP name =>
   276          (trace_msg ctxt (fn () =>
   277               "Duplicate fact " ^ quote name ^ " when " ^ when); def)
   278        | Graph.UNDEF name =>
   279          (trace_msg ctxt (fn () =>
   280               "Unknown fact " ^ quote name ^ " when " ^ when); def)
   281        | exn =>
   282          if Exn.is_interrupt exn then
   283            reraise exn
   284          else
   285            (trace_msg ctxt (fn () =>
   286                 "Internal error when " ^ when ^ ":\n" ^
   287                 ML_Compiler.exn_message exn); def)
   288 
   (* One-line summary of a graph for trace messages: number of keys, sum of
      the lengths of the adjacency lists returned by "Graph.dest", and the
      numbers of minimal and maximal nodes. *)
   fun graph_info G =
     let
       val num_nodes = length (Graph.keys G)
       val num_edges =
         fold (fn (_, targets) => Integer.add (length targets)) (Graph.dest G)
              0
       val num_minimal = length (Graph.minimals G)
       val num_maximal = length (Graph.maximals G)
     in
       String.concat
         [string_of_int num_nodes, " node(s), ",
          string_of_int num_edges, " edge(s), ",
          string_of_int num_minimal, " minimal, ",
          string_of_int num_maximal, " maximal"]
     end
   295 
   (* In-memory MaSh state: the fact graph plus the "dirty" marker consumed
      by "save".  There, "SOME names" means only the entries for "names"
      still have to be appended to the state file ("SOME []" means nothing
      to do), whereas NONE causes the whole file to be rewritten. *)
   296 type mash_state = {fact_G : unit Graph.T, dirty : string list option}
   297 
   298 val empty_state = {fact_G = Graph.empty, dirty = SOME []}
   299 
   300 local
   301 
   (* Banner written as the first line of the state file and compared (with
      "string_ord") against the first line of an existing file on loading. *)
   302 val version = "*** MaSh version 20121205a ***"
   303 
   (* Raised by "load" when the state file's banner is greater than
      "version". *)
   304 exception Too_New of unit
   305 
   (* Parses a state-file line "kind name: parent1 parent2 ..." into
      (name, parents, kind).  An unknown kind tag falls back to Isar_Proof;
      a line of any other shape yields NONE. *)
   fun extract_node line =
     let
       fun parse_head parents [kind, name] =
           SOME (unescape_meta name, unescape_metas parents,
                 the_default Isar_Proof (try proof_kind_of_str kind))
         | parse_head _ _ = NONE
     in
       (case space_explode ":" line of
          [head, parents] => parse_head parents (space_explode " " head)
        | _ => NONE)
     end
   315 
   (* Loads the MaSh state from the state file unless that has already been
      done.  The argument and result are pairs (loaded flag, state); a pair
      whose flag is true is returned unchanged.

      Otherwise the first line of the file is compared with "version":
        - equal: the graph is rebuilt from the remaining lines (an empty
          graph results if "try_graph" intercepts an exception);
        - smaller: the file is treated as unparsable and the graph is empty;
        - greater: Too_New is raised.
      A missing, unreadable, or empty file gives "empty_state". *)
   316 fun load _ (state as (true, _)) = state
   317   | load ctxt _ =
   318     let val path = mash_state_file () in
   319       (true,
   320        case try File.read_lines path of
   321          SOME (version' :: node_lines) =>
   322          let
   (* parents mentioned before their own line are created with kind
      Isar_Proof *)
   323            fun add_edge_to name parent =
   324              Graph.default_node (parent, Isar_Proof)
   325              #> Graph.add_edge (parent, name)
   326            fun add_node line =
   327              case extract_node line of
   328                NONE => I (* shouldn't happen *)
   329              | SOME (name, parents, kind) =>
   330                update_fact_graph_node (name, kind)
   331                #> fold (add_edge_to name) parents
   332            val fact_G =
   333              case string_ord (version', version) of
   334                EQUAL =>
   335                try_graph ctxt "loading state" Graph.empty (fn () =>
   336                    fold add_node node_lines Graph.empty)
   337              | LESS => Graph.empty (* can't parse old file *)
   338              | GREATER => raise Too_New ()
   339          in
   340            trace_msg ctxt (fn () =>
   341                "Loaded fact graph (" ^ graph_info fact_G ^ ")");
   342            {fact_G = fact_G, dirty = SOME []}
   343          end
   344        | _ => empty_state)
   345     end
   346 
   (* Writes pending changes of the state to the state file and returns the
      state with "dirty" reset to "SOME []".

      - dirty = SOME []: nothing to write; the state is returned as is.
      - dirty = SOME names: the entries of "names" are appended to the file,
        without a banner.
      - dirty = NONE: the file is written afresh, starting with the
        "version" banner, followed by all entries of the graph.

      Each entry is the line "kind name: parents" as produced by
      "str_of_entry". *)
   347 fun save _ (state as {dirty = SOME [], ...}) = state
   348   | save ctxt {fact_G, dirty} =
   349     let
   350       fun str_of_entry (name, parents, kind) =
   351         str_of_proof_kind kind ^ " " ^ escape_meta name ^ ": " ^
   352         escape_metas parents ^ "\n"
   353       fun append_entry (name, (kind, (parents, _))) =
   354         cons (name, Graph.Keys.dest parents, kind)
   355       val (banner, entries) =
   356         case dirty of
   357           SOME names =>
   358           (NONE, fold (append_entry o Graph.get_entry fact_G) names [])
   359         | NONE => (SOME (version ^ "\n"), Graph.fold append_entry fact_G [])
   360     in
   361       write_file banner (entries, str_of_entry) (mash_state_file ());
   362       trace_msg ctxt (fn () =>
   363           "Saved fact graph (" ^ graph_info fact_G ^
   364           (case dirty of
   365              SOME dirty =>
   366              "; " ^ string_of_int (length dirty) ^ " dirty fact(s)"
   367            | _ => "") ^  ")");
   368       {fact_G = fact_G, dirty = SOME []}
   369     end
   370 
   (* Synchronized variable holding the pair (loaded flag, state); the flag
      starts out false, so the first "load" reads the state file. *)
   371 val global_state =
   372   Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)
   373 
   374 in
   375 
   (* Updates the global state: loads it if necessary, applies "f" to the
      state component, and saves the result.  If "load" raises Too_New, the
      exception is swallowed here and nothing is returned but (). *)
   376 fun mash_map ctxt f =
   377   Synchronized.change global_state (load ctxt ##> (f #> save ctxt))
   378   handle Too_New () => ()
   379 
   (* Returns "f" applied to the state component after attempting to load
      it; the load is wrapped in "perhaps (try ...)", so a failing load does
      not propagate an exception. *)
   380 fun mash_peek ctxt f =
   381   Synchronized.change_result global_state
   382       (perhaps (try (load ctxt)) #> `snd #>> f)
   383 
   (* Like "mash_peek", but returns the state component itself. *)
   384 fun mash_get ctxt =
   385   Synchronized.change_result global_state (perhaps (try (load ctxt)) #> `snd)
   386 
   (* Discards everything learned: clears the model directory and resets the
      global state to the empty state, marked as loaded. *)
   387 fun mash_unlearn ctxt =
   388   Synchronized.change global_state (fn _ =>
   389       (mash_CLEAR ctxt; (* also removes the state file *)
   390        (true, empty_state)))
   391 
   392 end
   393 
   394 
   395 (*** Isabelle helpers ***)
   396 
   (* Among the static facts of the theorem's theory, returns the name whose
      name-space entry has the largest "id"; "" if there are no such facts.
      On equal ids the fact visited later wins. *)
   fun parent_of_local_thm th =
     let
       val facts = Global_Theory.facts_of (Thm.theory_of_thm th)
       val space = Facts.space_of facts
       fun newer (cand, _) (best as (_, best_id)) =
         let val cand_id = #id (Name_Space.the_entry space cand) in
           if best_id > cand_id then best else (cand, cand_id)
         end
     in
       fst (Facts.fold_static newer facts ("", ~1))
     end
   406 
   (* Name hints starting with this prefix are treated as local facts. *)
   val local_prefix = "local" ^ Long_Name.separator

   (* Name under which a theorem is known to MaSh.  A theorem without a name
      hint is identified by its backquoted statement.  A hint with the
      "local" prefix is re-rooted under the result of "parent_of_local_thm",
      joined by a doubled name separator; any other hint is used as is. *)
   fun nickname_of th =
     if not (Thm.has_name_hint th) then
       backquote_thm (Proof_Context.init_global (Thm.theory_of_thm th)) th
     else
       let val hint = Thm.get_name_hint th in
         (* FIXME: There must be a better way to detect local facts. *)
         (case try (unprefix local_prefix) hint of
            NONE => hint
          | SOME suf =>
            String.concat
              [parent_of_local_thm th, Long_Name.separator,
               Long_Name.separator, suf])
       end
   420 
   (* Maps the suggestions (nickname, weight) to the corresponding facts, in
      suggestion order, dropping suggestions that match no fact.  If several
      facts share a nickname, the first one in "facts" is used. *)
   fun suggested_facts suggs facts =
     let
       val tab =
         fold (fn (fact as (_, th)) => Symtab.default (nickname_of th, fact))
              facts Symtab.empty
       fun find_sugg (name, weight) =
         (case Symtab.lookup tab name of
            SOME fact => SOME (fact, weight)
          | NONE => NONE)
     in
       map_filter find_sugg suggs
     end
   428 
   (* Integer average of the scores after scaling their sum by 10^8 and
      rounding it up; 0 for the empty list. *)
   fun scaled_avg xs =
     if null xs then
       0
     else
       let val total = fold (fn (x : real) => fn acc => x + acc) xs 0.0 in
         Real.ceil (100000000.0 * total) div length xs
       end

   (* Arithmetic mean; 0.0 for the empty list. *)
   fun avg xs =
     (case xs of
        [] => 0.0
      | _ =>
        fold (fn (x : real) => fn acc => x + acc) xs 0.0 /
        Real.fromInt (length xs))
   435 
   (* Rescales all scores so that the average score of the first "max_facts"
      entries becomes 1.0.

      If that average is zero (e.g. all leading scores are 0.0, or
      "max_facts" is 0), no rescaling is possible: dividing would turn the
      scores into infinities or NaNs, on which a later "Real.ceil" raises an
      exception.  In that case the scores are returned unchanged. *)
   fun normalize_scores _ [] = []
     | normalize_scores max_facts xs =
       let val avg = avg (map snd (take max_facts xs)) in
         if Real.== (avg, 0.0) then xs
         else map (apsnd (curry Real.* (1.0 / avg))) xs
       end
   441 
   (* Merges the rankings of several fact filters into a list of at most
      "max_facts" facts.  Each element of the second argument is
      (global weight, (selected facts with scores, facts unknown to the
      filter)).

      With a single filter, its selected facts are taken in order and the
      remaining room, if any, is filled with its unknown facts.  The room is
      clamped at 0 so that a list of selected facts longer than "max_facts"
      never leads to a negative count being passed to "take".

      With several filters, scores are first normalized per filter; each
      fact among the leading "max_facts" selections of any filter is then
      weighted by the "scaled_avg" of the scores of the filters that have an
      opinion on it, and the facts are returned by decreasing weight. *)
   fun mesh_facts max_facts [(_, (sels, unks))] =
       map fst (take max_facts sels) @
       take (Int.max (0, max_facts - length sels)) unks
     | mesh_facts max_facts mess =
       let
         (* pair each normalized selection with its length *)
         val mess =
           mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
         val fact_eq = Thm.eq_thm o pairself snd
         fun score_in fact (global_weight, ((sel_len, sels), unks)) =
           let
             fun score_at j =
               case try (nth sels) j of
                 SOME (_, score) => SOME (global_weight * score)
               | NONE => NONE
           in
             case find_index (curry fact_eq fact o fst) sels of
               ~1 => (case find_index (curry fact_eq fact) unks of
                        (* NOTE(review): "sels" has "sel_len" elements, so
                           "nth sels sel_len" always fails and this branch
                           yields NONE as well -- confirm that this is
                           intended. *)
                        ~1 => score_at sel_len
                      | _ => NONE)
             | rank => score_at rank
           end
         fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
         val facts =
           fold (union fact_eq o map fst o take max_facts o snd o fst o snd)
                mess []
       in
         facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
               |> map snd |> take max_facts
       end
   470 
   (* Feature constructors.  Each feature is a (name, weight) pair; names
      derived from theories, constants, types, and classes get the one-letter
      prefixes "y", "c", "t", and "s", respectively.  All weights are
      currently 1.0. *)
   471 fun thy_feature_of s = ("y" ^ s, 1.0 (* FUDGE *))
   472 fun term_feature_of s = ("c" ^ s, 1.0 (* FUDGE *))
   473 fun type_feature_of s = ("t" ^ s, 1.0 (* FUDGE *))
   474 fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *))
   475 fun status_feature_of status = (string_of_status status, 1.0 (* FUDGE *))
   476 val local_feature = ("local", 1.0 (* FUDGE *))
   477 val lams_feature = ("lams", 1.0 (* FUDGE *))
   478 val skos_feature = ("skos", 1.0 (* FUDGE *))
   479 
   (* Order on theories: equal theories are EQUAL; a theory for which
      "Theory.subthy" holds with respect to the other comes first.  Unrelated
      theories are compared by their number of ancestors and then by name. *)
   fun theory_ord (p as (thy1, thy2)) =
     if Theory.eq_thy p then
       EQUAL
     else if Theory.subthy p then
       LESS
     else if Theory.subthy (thy2, thy1) then
       GREATER
     else
       (case int_ord (length (Theory.ancestors_of thy1),
                      length (Theory.ancestors_of thy2)) of
          EQUAL =>
          string_ord (Context.theory_name thy1, Context.theory_name thy2)
        | order => order)
   490 
   (* Order on theorems: by "theory_ord" on their theories, then by reverse
      string order on their nicknames. *)
   fun thm_ord (th1, th2) =
     (case theory_ord (theory_of_thm th1, theory_of_thm th2) of
        EQUAL =>
        (* Hack to put "xxx_def" before "xxxI" and "xxxE" *)
        string_ord (nickname_of th2, nickname_of th1)
      | ord => ord)
   497 
   val freezeT = Type.legacy_freeze_type

   (* Replaces every schematic variable by a free variable of the same base
      name (the index is dropped) and applies "freezeT" to all types occurring
      in abstractions, variables, and constants. *)
   fun freeze t =
     (case t of
        t1 $ t2 => freeze t1 $ freeze t2
      | Abs (s, T, body) => Abs (s, freezeT T, freeze body)
      | Var ((s, _), T) => Free (s, freezeT T)
      | Const (s, T) => Const (s, freezeT T)
      | Free (s, T) => Free (s, freezeT T)
      | _ => t)

   (* Initial goal state for the frozen statement of a theorem. *)
   fun goal_of_thm thy th = Goal.init (cterm_of thy (freeze (prop_of th)))
   508 
   (* Runs the minimizing variant of "prover" in MaSh mode on "goal", which
      is treated as a problem with a single subgoal.  The lazily named facts
      are forced and passed as untranslated facts. *)
   fun run_prover_for_mash ctxt params prover facts goal =
     let
       fun force_name ((name, stature), th) = ((name (), stature), th)
       val problem =
         {state = Proof.init ctxt, goal = goal, subgoal = 1,
          subgoal_count = 1,
          facts = map Untranslated_Fact (map force_name facts)}
       val minimizing_prover =
         get_minimizing_prover ctxt MaSh (K (K ())) prover
     in
       minimizing_prover params (K (K (K ""))) problem
     end
   519 
   (* Type constructors that never give rise to a type feature. *)
   520 val bad_types = [@{type_name prop}, @{type_name bool}, @{type_name fun}]
   521 
   (* Constants that never give rise to a term feature. *)
   522 val logical_consts =
   523   [@{const_name prop}, @{const_name Pure.conjunction}] @ atp_logical_consts
   524 
   (* Collects the term, type, and class features occurring in the terms
      "ts".  Duplicates are avoided by comparing feature names only.

      - Type features: one per type constructor not in "bad_types"; class
        features come from the sorts of type variables, except for the sort
        "type".  Types are considered only if "type_max_depth" is
        nonnegative; the traversal itself is not depth-limited.
      - Term features: patterns of constants with nested arguments, built by
        "patternify_term" for every depth from "term_max_depth" down to 0.
        Constants in "logical_consts" or built into "prover" are skipped. *)
   525 fun interesting_terms_types_and_classes ctxt prover term_max_depth
   526                                         type_max_depth ts =
   527   let
   528     fun is_bad_const (x as (s, _)) args =
   529       member (op =) logical_consts s orelse
   530       fst (is_built_in_const_for_prover ctxt prover x args)
   531     fun add_classes @{sort type} = I
   532       | add_classes S = union (op = o pairself fst) (map class_feature_of S)
   533     fun do_add_type (Type (s, Ts)) =
   534         (not (member (op =) bad_types s)
   535          ? insert (op = o pairself fst) (type_feature_of s))
   536         #> fold do_add_type Ts
   537       | do_add_type (TFree (_, S)) = add_classes S
   538       | do_add_type (TVar (_, S)) = add_classes S
   539     fun add_type T = type_max_depth >= 0 ? do_add_type T
   (* "args" are the arguments the current head is applied to; "depth" limits
      how deeply arguments are rendered.  An application yields the head's
      patterns both bare and extended by "(q)" for each pattern q of the
      argument at depth - 1. *)
   540     fun patternify_term _ ~1 _ = []
   541       | patternify_term args _ (Const (x as (s, _))) =
   542         if is_bad_const x args then [] else [s]
   543       | patternify_term _ 0 _ = []
   544       | patternify_term args depth (t $ u) =
   545         let
   546           val ps = patternify_term (u :: args) depth t
   547           val qs = "" :: patternify_term [] (depth - 1) u
   548         in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
   549       | patternify_term _ _ _ = []
   550     val add_term_pattern =
   551       union (op = o pairself fst) o map term_feature_of oo patternify_term []
   552     fun add_term_patterns ~1 _ = I
   553       | add_term_patterns depth t =
   554         add_term_pattern depth t #> add_term_patterns (depth - 1) t
   555     val add_term = add_term_patterns term_max_depth
   (* Dispatches on the head of the term; term patterns are generated only
      for applications headed by a constant.  The arguments are always
      traversed. *)
   556     fun add_patterns t =
   557       let val (head, args) = strip_comb t in
   558         (case head of
   559            Const (_, T) => add_term t #> add_type T
   560          | Free (_, T) => add_type T
   561          | Var (_, T) => add_type T
   562          | Abs (_, T, body) => add_type T #> add_patterns body
   563          | _ => I)
   564         #> fold add_patterns args
   565       end
   566   in [] |> fold add_patterns ts end
   567 
   (* Recognizes the constants "Ex" and "Ex1". *)
   568 fun is_exists (s, _) = (s = @{const_name Ex} orelse s = @{const_name Ex1})
   569 
   (* Depth limits passed to "interesting_terms_types_and_classes". *)
   570 val term_max_depth = 2
   571 val type_max_depth = 2
   572 
   (* Features of a fact or goal with statement(s) "ts": the theory feature,
      the term/type/class features of "ts", and, where applicable, the status
      feature (status other than General), "local" (scope other than Global),
      "lams" (some term is not lambda-free), and "skos" (some term contains
      "Ex" or "Ex1").  The optional features are consed onto the front of the
      whole list, the last applicable one first. *)
   573 (* TODO: Generate type classes for types? *)
   574 fun features_of ctxt prover thy (scope, status) ts =
   575   thy_feature_of (Context.theory_name thy) ::
   576   interesting_terms_types_and_classes ctxt prover term_max_depth type_max_depth
   577                                       ts
   578   |> status <> General ? cons (status_feature_of status)
   579   |> scope <> Global ? cons local_feature
   580   |> exists (not o is_lambda_free) ts ? cons lams_feature
   581   |> exists (exists_Const is_exists) ts ? cons skos_feature
   582 
   583 (* Too many dependencies is a sign that a crazy decision procedure is at work.
   584    There isn't much to learn from such proofs. *)
   585 val max_dependencies = 100
   (* Number of facts used by "atp_dependencies_of" when the "max_facts"
      parameter is not set. *)
   586 val atp_dependency_default_max_facts = 50
   587 
   588 (* "type_definition_xxx" facts are characterized by their use of "CollectI". *)
   589 val typedef_deps = [@{thm CollectI} |> nickname_of]
   590 
   591 (* "Rep_xxx_inject", "Abs_xxx_inverse", etc., are derived using these facts. *)
   (* The list holds the nicknames of these theorems. *)
   592 val typedef_ths =
   593   @{thms type_definition.Abs_inverse type_definition.Rep_inverse
   594          type_definition.Rep type_definition.Rep_inject
   595          type_definition.Abs_inject type_definition.Rep_cases
   596          type_definition.Abs_cases type_definition.Rep_induct
   597          type_definition.Abs_induct type_definition.Rep_range
   598          type_definition.Abs_image}
   599   |> map nickname_of
   599   |> map nickname_of
   600 
   (* Holds if there is exactly one dependency, it contains ".recs", the
      theorem's nickname contains ".size", and the parts before these two
      markers (as split off by "first_field") coincide. *)
   fun is_size_def deps th =
     let
       fun same_prefix_as pref =
         (case first_field ".size" (nickname_of th) of
            SOME (pref', _) => pref = pref'
          | NONE => false)
     in
       (case deps of
          [dep] =>
          (case first_field ".recs" dep of
             SOME (pref, _) => same_prefix_as pref
           | NONE => false)
        | _ => false)
     end
   609 
   (* Post-processes a dependency list.  More than "max_dependencies"
      dependencies give NONE.  Otherwise the result is SOME of the list
      itself, or SOME [] if the list equals "typedef_deps", mentions one of
      "typedef_ths", or satisfies "is_size_def". *)
   fun trim_dependencies th deps =
     if length deps > max_dependencies then
       NONE
     else
       let
         val uninformative =
           deps = typedef_deps orelse
           exists (member (op =) typedef_ths) deps orelse
           is_size_def deps th
       in
         SOME (if uninformative then [] else deps)
       end
   620 
   (* Dependencies obtained from the theorem's proof via "thms_in_proof",
      restricted by "all_names" and post-processed by
      "trim_dependencies". *)
   fun isar_dependencies_of all_names th =
     let val deps = thms_in_proof (SOME all_names) th in
       trim_dependencies th deps
     end
   623 
   (* Tries to obtain the dependencies of "th" by re-proving it with
      "prover".  The result is (true, deps) if the prover succeeds and the
      used facts survive "trim_dependencies"; otherwise it is
      (false, Isar dependencies).

      No proof attempt is made if the Isar dependencies are "SOME []".  The
      prover is given the MePo selection among the facts that precede "th"
      in "thm_ord", extended with any Isar dependencies not already
      selected.  With "verbose" and "auto_level = 0", progress messages are
      printed. *)
   624 fun atp_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover
   625                         auto_level facts all_names th =
   626   case isar_dependencies_of all_names th of
   627     SOME [] => (false, SOME [])
   628   | isar_deps =>
   629     let
   630       val thy = Proof_Context.theory_of ctxt
   631       val goal = goal_of_thm thy th
   632       val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal ctxt goal 1
   (* only facts that come before "th" may be used to prove it *)
   633       val facts = facts |> filter (fn (_, th') => thm_ord (th', th) = LESS)
   634       fun fix_name ((_, stature), th) = ((fn () => nickname_of th, stature), th)
   635       fun is_dep dep (_, th) = nickname_of th = dep
   636       fun add_isar_dep facts dep accum =
   637         if exists (is_dep dep) accum then
   638           accum
   639         else case find_first (is_dep dep) facts of
   640           SOME ((name, status), th) => accum @ [((name, status), th)]
   641         | NONE => accum (* shouldn't happen *)
   642       val facts =
   643         facts |> mepo_suggested_facts ctxt params prover
   644                      (max_facts |> the_default atp_dependency_default_max_facts)
   645                      NONE hyp_ts concl_t
   646               |> fold (add_isar_dep facts) (these isar_deps)
   647               |> map fix_name
   648     in
   649       if verbose andalso auto_level = 0 then
   650         let val num_facts = length facts in
   651           "MaSh: " ^ quote prover ^ " on " ^ quote (nickname_of th) ^
   652           " with " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
   653           "."
   654           |> Output.urgent_message
   655         end
   656       else
   657         ();
   (* an outcome of NONE means that the prover found a proof *)
   658       case run_prover_for_mash ctxt params prover facts goal of
   659         {outcome = NONE, used_facts, ...} =>
   660         (if verbose andalso auto_level = 0 then
   661            let val num_facts = length used_facts in
   662              "Found proof with " ^ string_of_int num_facts ^ " fact" ^
   663              plural_s num_facts ^ "."
   664              |> Output.urgent_message
   665            end
   666          else
   667            ();
   668          case used_facts |> map fst |> trim_dependencies th of
   669            NONE => (false, isar_deps)
   670          | atp_deps => (true, atp_deps))
   671       | _ => (false, isar_deps)
   672     end
   673 
   674 
   675 (*** High-level communication with MaSh ***)
   676 
   (* Number of elements of a "Graph.Keys" set. *)
   fun num_keys keys = Graph.Keys.fold (fn _ => fn n => 1 + n) keys 0
   678 
   (* Computes the nicknames of graph nodes that correspond to the given
      facts and are not ancestors of one another, searching from the maximal
      nodes of "fact_G" towards their predecessors.

      "tab" is the set of nicknames of "facts".  The work list starts with
      the maximal nodes.  A node that is in "tab" becomes a candidate
      maximum, stored together with its predecessors according to
      "Graph.all_preds", unless it is an ancestor of an existing candidate;
      candidates that are ancestors of the new node are dropped.  A node that
      is not in "tab" is replaced in the work list by those of its immediate
      predecessors that are not marked in "seen". *)
   679 fun maximal_in_graph fact_G facts =
   680   let
   681     val facts = [] |> fold (cons o nickname_of o snd) facts
   682     val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) facts
   683     fun insert_new seen name =
   684       not (Symtab.defined seen name) ? insert (op =) name
   685     fun find_maxes _ (maxs, []) = map snd maxs
   686       | find_maxes seen (maxs, new :: news) =
   687         find_maxes
   (* only nodes with more than one immediate successor are recorded in
      "seen" *)
   688             (seen |> num_keys (Graph.imm_succs fact_G new) > 1
   689                      ? Symtab.default (new, ()))
   690             (if Symtab.defined tab new then
   691                let
   692                  val newp = Graph.all_preds fact_G [new]
   693                  fun is_ancestor x yp = member (op =) yp x
   694                  val maxs =
   695                    maxs |> filter (fn (_, max) => not (is_ancestor max newp))
   696                in
   697                  if exists (is_ancestor new o fst) maxs then
   698                    (maxs, news)
   699                  else
   700                    ((newp, new)
   701                     :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
   702                     news)
   703                end
   704              else
   705                (maxs, Graph.Keys.fold (insert_new seen)
   706                                       (Graph.imm_preds fact_G new) news))
   707   in find_maxes Symtab.empty ([], Graph.maximals fact_G) end
   708 
   (* Tests whether the fact's nickname is a node of the fact graph, i.e.
      whether looking up the node succeeds. *)
   fun is_fact_in_graph fact_G (_, th) =
     let val name = nickname_of th in
       is_some (try (Graph.get_node fact_G) name)
     end
   711 
   712 (* factor that controls whether unknown global facts should be included *)
   713 val include_unk_global_factor = 15
   714 
   715 (* use MePo weights for now *)
   716 val weight_raw_mash_facts = weight_mepo_facts
   717 val weight_mash_facts = weight_raw_mash_facts
   718 
   (* Weight of the fact at position "rank" (starting from 0) in the
      proximity ordering: 1.3 ^ (15.5 - 0.05 * rank) + 15.0, which decreases
      as the rank grows. *)
   719 (* FUDGE *)
   720 fun weight_of_proximity_fact rank =
   721   Math.pow (1.3, 15.5 - 0.05 * Real.fromInt rank) + 15.0
   722 
   (* Pairs each fact with the proximity weight of its position in the
      list. *)
   723 fun weight_proximity_facts facts =
   724   facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
   725 
   (* Selects at most "max_facts" of "facts" for the goal with hypotheses
      "hyp_ts" and conclusion "concl_t" by meshing two rankings:

      - the MaSh ranking (weight 0.667): the facts suggested by "mash_QUERY"
        for the goal's features, with the maximal known facts as parents.
        No query is issued if the fact graph is empty.  Facts that are not in
        the graph are passed along as unknown.
      - the proximity ranking (weight 0.333): the chained facts first,
        followed by the remaining facts sorted by reverse "thm_ord". *)
   726 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   727                          concl_t facts =
   728   let
   729     val thy = Proof_Context.theory_of ctxt
   730     val (fact_G, suggs) =
   731       mash_peek ctxt (fn {fact_G, ...} =>
   732           if Graph.is_empty fact_G then
   733             (fact_G, [])
   734           else
   735             let
   736               val parents = maximal_in_graph fact_G facts
   737               val feats =
   738                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   739             in
   740               (fact_G, mash_QUERY ctxt overlord max_facts (parents, feats))
   741             end)
   (* NOTE(review): "unchained" is bound here but not used below. *)
   742     val (chained, unchained) =
   743       List.partition (fn ((_, (scope, _)), _) => scope = Chained) facts
   744     val raw_mash =
   745       facts |> suggested_facts suggs
   746             (* The weights currently returned by "mash.py" are too spaced out to
   747                make any sense. *)
   748             |> map fst
   749     val proximity =
   750       chained @ (facts |> subtract (Thm.eq_thm_prop o pairself snd) chained
   751                        |> sort (thm_ord o pairself snd o swap))
   752     val unknown = facts |> filter_out (is_fact_in_graph fact_G)
   753     val mess =
   754       [(0.667 (* FUDGE *), (weight_raw_mash_facts raw_mash, unknown)),
   755        (0.333 (* FUDGE *), (weight_proximity_facts proximity, []))]
   756   in mesh_facts max_facts mess end
   757 
   (* Registers the new fact "name" in the fact graph and records it in "adds".

      A node for "name" is created if necessary (with kind "Isar_Proof"), and an
      edge is added from each parent to "name". The dependencies are run through
      the same edge insertion, but the graph resulting from that is thrown away:
      only the list of dependencies that went through is kept. Both lists come
      out in reverse order with respect to the input. *)
   fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
     let
       (* Adds the edge from "from" to "name" under the protection of "try_graph",
          which is given the unchanged accumulator as well (presumably its
          fallback result if the edge cannot be added -- confirm against
          "try_graph"). *)
       fun link_from from (acc as (kept, G)) =
         try_graph ctxt "updating graph" acc (fn () =>
             (from :: kept, Graph.add_edge_acyclic (from, name) G))
       val G0 = Graph.default_node (name, Isar_Proof) graph
       val (parents', G1) = fold link_from parents ([], G0)
       val deps' = fst (fold link_from deps ([], G1))
     in ((name, parents', feats, deps') :: adds, G1) end
   767 
   (* Marks the fact "name" as having an ATP proof in the fact graph and records
      it in "reps".

      The dependencies are run through an edge insertion towards "name", but the
      graph resulting from that is thrown away: only the list of dependencies that
      went through is kept (in reverse order with respect to the input). *)
   fun reprove_wrt_fact_graph ctxt (name, deps) (reps, graph) =
     let
       (* Same protected edge insertion as used when adding new facts; "try_graph"
          is given the unchanged accumulator as well (presumably its fallback
          result -- confirm against "try_graph"). *)
       fun link_from from (acc as (kept, G)) =
         try_graph ctxt "updating graph" acc (fn () =>
             (from :: kept, Graph.add_edge_acyclic (from, name) G))
       val G' = update_fact_graph_node (name, ATP_Proof) graph
       val deps' = fst (fold link_from deps ([], G'))
     in ((name, deps') :: reps, G') end
   776 
   (* Marks the fact "name" in the graph as "Isar_Proof_wegen_ATP_Flop". *)
   fun flop_wrt_fact_graph name graph =
     graph |> update_fact_graph_node (name, Isar_Proof_wegen_ATP_Flop)
   779 
   (* Factor by which a timeout is stretched to obtain a hard time limit. *)
   val learn_timeout_slack = 2.0

   (* Launches "task" through the asynchronous manager under the "MaShN" tool name.
      The death time handed to the manager is the current time plus "timeout"
      multiplied by "learn_timeout_slack". *)
   fun launch_thread timeout task =
     let
       val hard_timeout = timeout |> time_mult learn_timeout_slack
       val birth_time = Time.now ()
     in
       Async_Manager.launch MaShN birth_time (Time.+ (birth_time, hard_timeout))
           ("Machine learner for Sledgehammer", "") task
     end
   789 
   (* Produces a name made of a local timestamp (down to the second) followed by a
      serial number. *)
   fun freshish_name () =
     let
       val timestamp =
         Time.now () |> Date.fromTimeLocal |> Date.fmt ".%Y_%m_%d_%H_%M_%S__"
     in timestamp ^ serial_string () end
   793 
   (* Learns a single proof found for the goal "t": a fresh name is registered with
      MaSh, with the features of "t", the maximal graph nodes among "facts" as
      parents, and the nicknames of "used_ths" as dependencies. The work is done in
      a separate thread. Nothing happens for SMT provers. *)
   fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
                        used_ths =
     let
       fun learn () =
         let
           val thy = Proof_Context.theory_of ctxt
           val name = freshish_name ()
           val feats = features_of ctxt prover thy (Local, General) [t]
           val deps = map nickname_of used_ths
         in
           mash_peek ctxt (fn {fact_G, ...} =>
               mash_ADD ctxt overlord
                        [(name, maximal_in_graph fact_G facts, feats, deps)]);
           (true, "")
         end
     in
       if is_smt_prover ctxt prover then () else launch_thread timeout learn
     end
   812 
   (* Sendback markup for the Sledgehammer command followed by subcommand "sub". *)
   fun sendback sub = [sledgehammerN, sub] |> space_implode " " |> Sendback.markup

   (* Time span after which the data learned so far is committed. *)
   val commit_timeout = seconds 30.0
   816 
   (* Learns from the given facts and returns a message for the user (possibly
      empty).

      Facts without a node in the fact graph are new. They are learned in
      "thm_ord" order, each one with its predecessor in that order as sole parent;
      the first one gets the maximal graph nodes among the old facts that are not
      greater than the last new fact. If "atp" is set, the facts that already are
      in the graph are relearned as well, sorted by increasing value of a partly
      random priority. The data accumulated so far is committed to the MaSh state
      whenever "commit_timeout" has elapsed since the last commit time was set,
      and once more at the end of each of the two phases.

      Progress messages are printed only if "auto_level" is 0. The returned
      message is empty if "auto_level" is 2 or more, unless something was learned
      and "verbose" is set.

      The timeout is understood in a very slack fashion: "learn_timeout" is only
      looked at after a fact has been processed. *)
   fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover
                        auto_level atp learn_timeout facts =
     let
       val timer = Timer.startRealTimer ()
       fun elapsed () = Timer.checkRealTimer timer
       fun is_past deadline = Time.> (elapsed (), deadline)
       fun next_commit_time () = Time.+ (elapsed (), commit_timeout)
       val proof_kind = if atp then "ATP" else "Isar"
       val {fact_G, ...} = mash_get ctxt
       val (old_facts, unsorted_new_facts) =
         List.partition (is_fact_in_graph fact_G) facts
       val new_facts = sort (thm_ord o pairself snd) unsorted_new_facts
       val nothing_to_learn =
         null new_facts andalso (not atp orelse null old_facts)
     in
       if nothing_to_learn then
         if auto_level >= 2 then
           ""
         else
           "No new " ^ proof_kind ^ " proofs to learn." ^
           (if auto_level = 0 andalso not atp then
              "\n\nHint: Try " ^ sendback learn_atpN ^ " to learn from ATP proofs."
            else
              "")
       else
         let
           (* the nicknames of all facts, as a set *)
           val all_names =
             fold (fn (_, th) => Symtab.update (nickname_of th, ())) facts
                  Symtab.empty
           (* Definitions have no dependencies; otherwise the dependencies come
              from an ATP or from the Isar proof, depending on "atp". *)
           fun deps_of status th =
             if status = Non_Rec_Def orelse status = Rec_Def then
               SOME []
             else if atp then
               (case atp_dependencies_of ctxt params prover auto_level facts
                                         all_names th of
                  (true, deps) => deps
                | (false, _) => NONE)
             else
               isar_dependencies_of all_names th
           (* Updates the fact graph and notifies MaSh. The "dirty" names are
              extended only if the graph was nonempty and nothing was reproved;
              otherwise they are reset to "NONE". *)
           fun do_commit [] [] [] state = state
             | do_commit adds reps flops {fact_G = G, dirty} =
               let
                 val was_empty = Graph.is_empty G
                 val (adds', G) = fold (add_wrt_fact_graph ctxt) adds ([], G)
                 val (reps', G) = fold (reprove_wrt_fact_graph ctxt) reps ([], G)
                 val G = fold flop_wrt_fact_graph flops G
                 val dirty' =
                   (case dirty of
                      SOME names =>
                      if was_empty orelse not (null reps') then NONE
                      else SOME (map #1 adds' @ names)
                    | NONE => NONE)
               in
                 mash_ADD ctxt overlord (rev adds');
                 mash_REPROVE ctxt overlord reps';
                 {fact_G = G, dirty = dirty'}
               end
           (* "adds" is expected most recent first; "last" suppresses the progress
              message. *)
           fun commit last adds reps flops =
             let
               val chatty = (auto_level = 0)
               val _ =
                 if debug andalso chatty then
                   Output.urgent_message "Committing..."
                 else
                   ()
               val _ = mash_map ctxt (do_commit (rev adds) reps flops)
             in
               if last orelse not chatty then
                 ()
               else
                 let val num_proofs = length adds + length reps in
                   Output.urgent_message
                       ("Learned " ^ string_of_int num_proofs ^ " " ^ proof_kind ^
                        " proof" ^ plural_s num_proofs ^ " in the last " ^
                        string_from_time commit_timeout ^ ".")
                 end
             end
           (* Fold step for new facts; does nothing once the timeout flag is set.
              The accumulator carries the pending additions, the parents for the
              next fact, the number of facts with nonempty dependencies, the next
              commit time and the timeout flag. *)
           fun learn_new_fact _ (accum as (_, (_, _, _, true))) = accum
             | learn_new_fact ((_, stature as (_, status)), th)
                              (adds, (parents, n, next_commit, _)) =
               let
                 val name = nickname_of th
                 val feats =
                   features_of ctxt prover (theory_of_thm th) stature [prop_of th]
                 val deps = these (deps_of status th)
                 val n' = if null deps then n else n + 1
                 val pending = (name, parents, feats, deps) :: adds
                 val (pending', next_commit') =
                   if is_past next_commit then
                     (commit false pending [] []; ([], next_commit_time ()))
                   else
                     (pending, next_commit)
                 val timed_out = is_past learn_timeout
               in (pending', ([name], n', next_commit', timed_out)) end
           (* first phase: the new facts *)
           val num_new =
             (case new_facts of
                [] => 0
              | _ =>
                let
                  val last_th = snd (List.last new_facts)
                  (* crude approximation *)
                  val ancestors =
                    old_facts
                    |> filter_out (fn (_, th) => thm_ord (th, last_th) = GREATER)
                  val parents = maximal_in_graph fact_G ancestors
                  val (adds, (_, n, _, _)) =
                    fold learn_new_fact new_facts
                         ([], (parents, 0, next_commit_time (), false))
                in commit true adds [] []; n end)
           (* Fold step for old facts; does nothing once the timeout flag is set.
              A fact for which no dependencies are obtained is recorded as a
              flop. *)
           fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
             | relearn_old_fact ((_, (_, status)), th)
                                ((reps, flops), (n, next_commit, _)) =
               let
                 val name = nickname_of th
                 val (n', reps', flops') =
                   (case deps_of status th of
                      NONE => (n, reps, name :: flops)
                    | SOME deps => (n + 1, (name, deps) :: reps, flops))
                 val (reps'', flops'', next_commit') =
                   if is_past next_commit then
                     (commit false [] reps' flops'; ([], [], next_commit_time ()))
                   else
                     (reps', flops', next_commit)
                 val timed_out = is_past learn_timeout
               in ((reps'', flops''), (n', next_commit', timed_out)) end
           (* second phase: the old facts (ATP learning only) *)
           val num_total =
             if atp andalso not (null old_facts) then
               let
                 val max_isar = 1000 * max_dependencies
                 (* looked up in the graph as it was when learning started *)
                 fun kind_of_proof th =
                   the_default Isar_Proof
                       (try (Graph.get_node fact_G) (nickname_of th))
                 fun priority_of (_, th) =
                   let
                     val jitter = random_range 0 max_isar
                     val kind_bonus =
                       (case kind_of_proof th of
                          Isar_Proof => 0
                        | ATP_Proof => 2 * max_isar
                        | Isar_Proof_wegen_ATP_Flop => max_isar)
                     val num_deps =
                       (case isar_dependencies_of all_names th of
                          SOME deps => length deps
                        | NONE => max_dependencies)
                   in jitter + kind_bonus - 500 * num_deps end
                 val prioritized_facts =
                   old_facts |> map (fn fact => (priority_of fact, fact))
                             |> sort (int_ord o pairself fst)
                             |> map snd
                 val ((reps, flops), (n, _, _)) =
                   fold relearn_old_fact prioritized_facts
                        (([], []), (num_new, next_commit_time (), false))
               in commit true [] reps flops; n end
             else
               num_new
         in
           if verbose orelse auto_level < 2 then
             "Learned " ^ string_of_int num_total ^ " nontrivial " ^ proof_kind ^
             " proof" ^ plural_s num_total ^
             (if verbose then " in " ^ string_from_time (elapsed ()) else "") ^
             "."
           else
             ""
         end
     end
   971 
   (* Entry point of the explicit learning commands: collects nearly all facts of
      the context and learns from them, printing progress and result messages.

      If "atp" is set, Isar proofs are collected first and ATP proofs afterwards;
      otherwise only Isar proofs are learned. The first of "provers" is used
      throughout; an error is raised if there is none. *)
   fun mash_learn ctxt (params as {provers, timeout, ...}) fact_override chained
                  atp =
     let
       val css = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
       val ctxt = ctxt |> Config.put instantiate_inducts false
       val facts =
         nearly_all_facts ctxt false fact_override Symtab.empty css chained []
                          @{prop True}
       val num_facts = length facts
       (* A proper error instead of the uninformative "Empty" exception that
          "hd provers" would raise if no prover is configured. *)
       val prover =
         (case provers of
            prover :: _ => prover
          | [] => error "No prover is set.")
       fun learn auto_level atp =
         mash_learn_facts ctxt params prover auto_level atp infinite_timeout facts
         |> Output.urgent_message
     in
       if atp then
         ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
          plural_s num_facts ^ " for ATP proofs (" ^ quote prover ^ " timeout: " ^
          string_from_time timeout ^ ").\n\nCollecting Isar proofs first..."
          |> Output.urgent_message;
          learn 1 false;
          "Now collecting ATP proofs. This may take several hours. You can \
          \safely stop the learning process at any point."
          |> Output.urgent_message;
          learn 0 true)
       else
         ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
          plural_s num_facts ^ " for Isar proofs..."
          |> Output.urgent_message;
          learn 0 false)
     end
  1002 
  (* MaSh is enabled iff the environment variable "MASH" has the value "yes". *)
  fun is_mash_enabled () = ("yes" = getenv "MASH")
  (* MaSh can suggest facts as soon as its fact graph is nonempty. *)
  fun mash_can_suggest_facts ctxt =
    let val {fact_G, ...} = mash_get ctxt in not (Graph.is_empty fact_G) end
  1005 
  (* Ask for more suggestions than requested (twice as many, but at most 50
     extra ones), because some might be thrown out later for various reasons. *)
  fun generous_max_facts max_facts =
    if max_facts < 50 then 2 * max_facts else max_facts + 50
  1009 
  (* The threshold should be large enough so that MaSh doesn't kick in for Auto
     Sledgehammer and Try. *)
  val min_secs_for_learning = 15

  (* Selects at most "max_facts" relevant facts for the given goal.

     An explicitly given fact filter must be one of "fact_filters". If "only" is
     set, the facts are returned unchanged. Without an explicit filter, MePo is
     used for SMT provers and whenever MaSh is disabled or has nothing to
     suggest; otherwise MePo and MaSh are meshed. Unless the filter is MePo,
     background learning is launched if "learn" is set, no MaSh thread is
     running, and the timeout is at least "min_secs_for_learning" seconds.
     Facts given via "add" are moved to the front of the result. *)
  fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
          max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
    let
      val unknown_filter =
        (case fact_filter of
           SOME ff => not (member (op =) fact_filters ff)
         | NONE => false)
    in
      if unknown_filter then
        error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
      else if only then
        facts
      else if max_facts <= 0 orelse null facts then
        []
      else
        let
          fun maybe_learn () =
            let
              val worthwhile =
                learn andalso
                not (Async_Manager.has_running_threads MaShN) andalso
                Time.toSeconds timeout >= min_secs_for_learning
            in
              if not worthwhile then
                ()
              else
                let val slack_timeout = time_mult learn_timeout_slack timeout in
                  launch_thread slack_timeout (fn () =>
                      (true, mash_learn_facts ctxt params prover 2 false
                                              slack_timeout facts))
                end
            end
          val filter_name =
            (case fact_filter of
               SOME ff => (if ff <> mepoN then maybe_learn () else (); ff)
             | NONE =>
               if is_smt_prover ctxt prover orelse not (is_mash_enabled ()) then
                 mepoN
               else
                 (maybe_learn ();
                  if mash_can_suggest_facts ctxt then meshN else mepoN))
          val add_ths = Attrib.eval_thms ctxt add
          (* puts the facts among "ths" first and truncates to "max_facts" *)
          fun prepend_facts ths accepts =
            let fun is_added (_, th) = member Thm.eq_thm_prop ths th in
              take max_facts (filter is_added facts @ filter_out is_added accepts)
            end
          fun mepo () =
            weight_mepo_facts
                (mepo_suggested_facts ctxt params prover max_facts NONE hyp_ts
                                      concl_t facts)
          fun mash () =
            weight_mash_facts
                (mash_suggested_facts ctxt params prover
                                      (generous_max_facts max_facts) hyp_ts
                                      concl_t facts)
          (* MePo is run before MaSh, but MaSh's entry comes first in the list *)
          val mepo_mess =
            if filter_name <> mashN then [(0.5, (mepo (), []))] else []
          val mash_mess =
            if filter_name <> mepoN then [(0.5, (mash (), []))] else []
          val accepts = mesh_facts max_facts (mash_mess @ mepo_mess)
        in
          if null add_ths then accepts else prepend_facts add_ths accepts
        end
    end
  1065 
  (* Killing and listing of the threads registered under the "MaShN" tool name;
     both pass the same description word to the asynchronous manager. *)
  local
    val learner_desc = "learner"
  in
  fun kill_learners () = Async_Manager.kill_threads MaShN learner_desc
  fun running_learners () = Async_Manager.running_threads MaShN learner_desc
  end
  1068 
  1069 end;