src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57005 33f3d2ea803d
parent 56995 61855ade6c7e
child 57006 20e5b110d19b
equal deleted inserted replaced
57004:c8288ce9676a 57005:33f3d2ea803d
    26   val relearn_isarN : string
    26   val relearn_isarN : string
    27   val relearn_proverN : string
    27   val relearn_proverN : string
    28   val fact_filters : string list
    28   val fact_filters : string list
    29   val encode_str : string -> string
    29   val encode_str : string -> string
    30   val encode_strs : string list -> string
    30   val encode_strs : string list -> string
    31   val unencode_str : string -> string
    31   val decode_str : string -> string
    32   val unencode_strs : string -> string list
    32   val decode_strs : string -> string list
    33   val encode_plain_features : string list list -> string
    33   val encode_unweighted_features : string list list -> string
    34   val encode_features : (string list * real) list -> string
    34   val encode_features : (string list * real) list -> string
    35   val extract_suggestions : string -> string * string list
    35   val extract_suggestions : string -> string * string list
    36 
    36 
    37   structure MaSh:
    37   structure MaSh:
    38   sig
    38   sig
   108 open Sledgehammer_Fact
   108 open Sledgehammer_Fact
   109 open Sledgehammer_Prover
   109 open Sledgehammer_Prover
   110 open Sledgehammer_Prover_Minimize
   110 open Sledgehammer_Prover_Minimize
   111 open Sledgehammer_MePo
   111 open Sledgehammer_MePo
   112 
   112 
   113 val trace =
   113 val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
   114   Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
       
   115 
   114 
   116 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
   115 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
   117 
   116 
   118 val MePoN = "MePo"
   117 val MePoN = "MePo"
   119 val MaShN = "MaSh"
   118 val MaShN = "MaSh"
   170       " --modelFile=" ^ model_dir ^ "/model.pickle\
   169       " --modelFile=" ^ model_dir ^ "/model.pickle\
   171       \ --dictsFile=" ^ model_dir ^ "/dict.pickle\
   170       \ --dictsFile=" ^ model_dir ^ "/dict.pickle\
   172       \ --log " ^ log_file ^
   171       \ --log " ^ log_file ^
   173       " --inputFile " ^ cmd_file ^
   172       " --inputFile " ^ cmd_file ^
   174       " --predictions " ^ sugg_file ^
   173       " --predictions " ^ sugg_file ^
   175       (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^
   174       (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^ " >& " ^ err_file ^
   176       " >& " ^ err_file ^
       
   177       (if background then " &" else "")
   175       (if background then " &" else "")
   178     fun run_on () =
   176     fun run_on () =
   179       (Isabelle_System.bash command
   177       (Isabelle_System.bash command
   180        |> tap (fn _ =>
   178        |> tap (fn _ =>
   181             (case try File.read (Path.explode err_file) |> the_default "" of
   179             (case try File.read (Path.explode err_file) |> the_default "" of
   182               "" => trace_msg ctxt (K "Done")
   180               "" => trace_msg ctxt (K "Done")
   183             | s => warning ("MaSh error: " ^ elide_string 1000 s)));
   181             | s => warning ("MaSh error: " ^ elide_string 1000 s)));
   184        read_suggs (fn () => try File.read_lines sugg_path |> these))
   182        read_suggs (fn () => try File.read_lines sugg_path |> these))
   185     fun clean_up () =
   183     fun clean_up () =
   186       if overlord then ()
   184       if overlord then () else List.app wipe_out_file [err_file, sugg_file, cmd_file]
   187       else List.app wipe_out_file [err_file, sugg_file, cmd_file]
       
   188   in
   185   in
   189     write_file (SOME "") ([], K "") sugg_path;
   186     write_file (SOME "") ([], K "") sugg_path;
   190     write_file (SOME "") write_cmds cmd_path;
   187     write_file (SOME "") write_cmds cmd_path;
   191     trace_msg ctxt (fn () => "Running " ^ command);
   188     trace_msg ctxt (fn () => "Running " ^ command);
   192     with_cleanup clean_up run_on ()
   189     with_cleanup clean_up run_on ()
   201     "%" ^ stringN_of_int 3 (Char.ord c)
   198     "%" ^ stringN_of_int 3 (Char.ord c)
   202 
   199 
   203 fun unmeta_chars accum [] = String.implode (rev accum)
   200 fun unmeta_chars accum [] = String.implode (rev accum)
   204   | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
   201   | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
   205     (case Int.fromString (String.implode [d1, d2, d3]) of
   202     (case Int.fromString (String.implode [d1, d2, d3]) of
   206        SOME n => unmeta_chars (Char.chr n :: accum) cs
   203       SOME n => unmeta_chars (Char.chr n :: accum) cs
   207      | NONE => "" (* error *))
   204     | NONE => "" (* error *))
   208   | unmeta_chars _ (#"%" :: _) = "" (* error *)
   205   | unmeta_chars _ (#"%" :: _) = "" (* error *)
   209   | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
   206   | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
   210 
   207 
   211 val encode_str = String.translate meta_char
   208 val encode_str = String.translate meta_char
       
   209 val decode_str = String.explode #> unmeta_chars []
       
   210 
   212 val encode_strs = map encode_str #> space_implode " "
   211 val encode_strs = map encode_str #> space_implode " "
   213 val unencode_str = String.explode #> unmeta_chars []
   212 val decode_strs = space_explode " " #> filter_out (curry (op =) "") #> map decode_str
   214 val unencode_strs =
       
   215   space_explode " " #> filter_out (curry (op =) "") #> map unencode_str
       
   216 
   213 
   217 (* Avoid scientific notation *)
   214 (* Avoid scientific notation *)
   218 fun safe_str_of_real r =
   215 fun safe_str_of_real r =
   219   if r < 0.00001 then "0.00001"
   216   if r < 0.00001 then "0.00001"
   220   else if r >= 1000000.0 then "1000000"
   217   else if r >= 1000000.0 then "1000000"
   221   else Markup.print_real r
   218   else Markup.print_real r
   222 
   219 
   223 val encode_plain_feature = space_implode "|" o map encode_str
   220 val encode_unweighted_feature = map encode_str #> space_implode "|"
       
   221 val decode_unweighted_feature = space_explode "|" #> map decode_str
   224 
   222 
   225 fun encode_feature (names, weight) =
   223 fun encode_feature (names, weight) =
   226   encode_plain_feature names ^ (if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight)
   224   encode_unweighted_feature names ^ (if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight)
   227 
   225 
   228 val encode_plain_features = map encode_plain_feature #> space_implode " "
   226 val encode_unweighted_features = map encode_unweighted_feature #> space_implode " "
       
   227 val decode_unweighted_features = space_explode " " #> map decode_unweighted_feature
       
   228 
   229 val encode_features = map encode_feature #> space_implode " "
   229 val encode_features = map encode_feature #> space_implode " "
   230 
   230 
   231 fun str_of_learn (name, parents, feats : string list list, deps) =
   231 fun str_of_learn (name, parents, feats, deps) =
   232   "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
   232   "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
   233   encode_plain_features feats ^ "; " ^ encode_strs deps ^ "\n"
   233   encode_unweighted_features feats ^ "; " ^ encode_strs deps ^ "\n"
   234 
   234 
   235 fun str_of_relearn (name, deps) =
   235 fun str_of_relearn (name, deps) = "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n"
   236   "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n"
       
   237 
   236 
   238 fun str_of_query max_suggs (learns, hints, parents, feats) =
   237 fun str_of_query max_suggs (learns, hints, parents, feats) =
   239   implode (map str_of_learn learns) ^
   238   implode (map str_of_learn learns) ^
   240   "? " ^ string_of_int max_suggs ^ " # " ^ encode_strs parents ^ "; " ^
   239   "? " ^ string_of_int max_suggs ^ " # " ^ encode_strs parents ^ "; " ^ encode_features feats ^
   241   encode_features feats ^
       
   242   (if null hints then "" else "; " ^ encode_strs hints) ^ "\n"
   240   (if null hints then "" else "; " ^ encode_strs hints) ^ "\n"
   243 
   241 
   244 (* The suggested weights don't make much sense. *)
   242 (* The suggested weights do not make much sense. *)
   245 fun extract_suggestion sugg =
   243 fun extract_suggestion sugg =
   246   (case space_explode "=" sugg of
   244   (case space_explode "=" sugg of
   247     [name, _ (* weight *)] =>
   245     [name, _ (* weight *)] => SOME (decode_str name)
   248     SOME (unencode_str name (* , Real.fromString weight |> the_default 1.0 *))
   246   | [name] => SOME (decode_str name)
   249   | [name] => SOME (unencode_str name (* , 1.0 *))
       
   250   | _ => NONE)
   247   | _ => NONE)
   251 
   248 
   252 fun extract_suggestions line =
   249 fun extract_suggestions line =
   253   (case space_explode ":" line of
   250   (case space_explode ":" line of
   254     [goal, suggs] => (unencode_str goal, map_filter extract_suggestion (space_explode " " suggs))
   251     [goal, suggs] => (decode_str goal, map_filter extract_suggestion (space_explode " " suggs))
   255   | _ => ("", []))
   252   | _ => ("", []))
   256 
   253 
   257 structure MaSh =
   254 structure MaSh =
   258 struct
   255 struct
   259 
   256 
   302 end;
   299 end;
   303 
   300 
   304 
   301 
   305 (*** Middle-level communication with MaSh ***)
   302 (*** Middle-level communication with MaSh ***)
   306 
   303 
   307 datatype proof_kind =
   304 datatype proof_kind = Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop
   308   Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop
       
   309 
   305 
   310 fun str_of_proof_kind Isar_Proof = "i"
   306 fun str_of_proof_kind Isar_Proof = "i"
   311   | str_of_proof_kind Automatic_Proof = "a"
   307   | str_of_proof_kind Automatic_Proof = "a"
   312   | str_of_proof_kind Isar_Proof_wegen_Prover_Flop = "x"
   308   | str_of_proof_kind Isar_Proof_wegen_Prover_Flop = "x"
   313 
   309 
   314 fun proof_kind_of_str "i" = Isar_Proof
   310 fun proof_kind_of_str "a" = Automatic_Proof
   315   | proof_kind_of_str "a" = Automatic_Proof
       
   316   | proof_kind_of_str "x" = Isar_Proof_wegen_Prover_Flop
   311   | proof_kind_of_str "x" = Isar_Proof_wegen_Prover_Flop
   317 
   312   | proof_kind_of_str _ (* "i" *) = Isar_Proof
   318 (* FIXME: Here a "Graph.update_node" function would be useful *)
       
   319 fun update_access_graph_node (name, kind) =
       
   320   Graph.default_node (name, Isar_Proof)
       
   321   #> kind <> Isar_Proof ? Graph.map_node name (K kind)
       
   322 
   313 
   323 fun try_graph ctxt when def f =
   314 fun try_graph ctxt when def f =
   324   f ()
   315   f ()
   325   handle Graph.CYCLES (cycle :: _) =>
   316   handle
   326          (trace_msg ctxt (fn () =>
   317     Graph.CYCLES (cycle :: _) =>
   327               "Cycle involving " ^ commas cycle ^ " when " ^ when); def)
   318     (trace_msg ctxt (fn () => "Cycle involving " ^ commas cycle ^ " when " ^ when); def)
   328        | Graph.DUP name =>
   319   | Graph.DUP name =>
   329          (trace_msg ctxt (fn () =>
   320     (trace_msg ctxt (fn () => "Duplicate fact " ^ quote name ^ " when " ^ when); def)
   330               "Duplicate fact " ^ quote name ^ " when " ^ when); def)
   321   | Graph.UNDEF name =>
   331        | Graph.UNDEF name =>
   322     (trace_msg ctxt (fn () => "Unknown fact " ^ quote name ^ " when " ^ when); def)
   332          (trace_msg ctxt (fn () =>
   323   | exn =>
   333               "Unknown fact " ^ quote name ^ " when " ^ when); def)
   324     if Exn.is_interrupt exn then
   334        | exn =>
   325       reraise exn
   335          if Exn.is_interrupt exn then
   326     else
   336            reraise exn
   327       (trace_msg ctxt (fn () => "Internal error when " ^ when ^ ":\n" ^ Runtime.exn_message exn);
   337          else
   328        def)
   338            (trace_msg ctxt (fn () =>
       
   339                 "Internal error when " ^ when ^ ":\n" ^
       
   340                 Runtime.exn_message exn); def)
       
   341 
   329 
   342 fun graph_info G =
   330 fun graph_info G =
   343   string_of_int (length (Graph.keys G)) ^ " node(s), " ^
   331   string_of_int (length (Graph.keys G)) ^ " node(s), " ^
   344   string_of_int (fold (Integer.add o length o snd) (Graph.dest G) 0) ^
   332   string_of_int (fold (Integer.add o length o snd) (Graph.dest G) 0) ^
   345   " edge(s), " ^
   333   " edge(s), " ^
   346   string_of_int (length (Graph.minimals G)) ^ " minimal, " ^
   334   string_of_int (length (Graph.minimals G)) ^ " minimal, " ^
   347   string_of_int (length (Graph.maximals G)) ^ " maximal"
   335   string_of_int (length (Graph.maximals G)) ^ " maximal"
   348 
   336 
   349 type mash_state =
   337 type mash_state =
   350   {access_G : unit Graph.T,
   338   {access_G : (proof_kind * string list list * string list) Graph.T,
   351    num_known_facts : int,
   339    num_known_facts : int,
   352    dirty : string list option}
   340    dirty : string list option}
   353 
   341 
   354 val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []}
   342 val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []} : mash_state
   355 
   343 
   356 local
   344 local
   357 
   345 
   358 val version = "*** MaSh version 20131206 ***"
   346 val version = "*** MaSh version 20140516 ***"
   359 
   347 
   360 exception FILE_VERSION_TOO_NEW of unit
   348 exception FILE_VERSION_TOO_NEW of unit
   361 
   349 
   362 fun extract_node line =
   350 fun extract_node line =
   363   (case space_explode ":" line of
   351   (case space_explode ":" line of
   364     [head, parents] =>
   352     [head, tail] =>
   365     (case space_explode " " head of
   353     (case (space_explode " " head, map (unprefix " ") (space_explode ";" tail)) of
   366       [kind, name] =>
   354       ([kind, name], [parents, feats, deps]) =>
   367       SOME (unencode_str name, unencode_strs parents,
   355       SOME (proof_kind_of_str kind, decode_str name, decode_strs parents,
   368         try proof_kind_of_str kind |> the_default Isar_Proof)
   356         decode_unweighted_features feats, decode_strs deps)
   369     | _ => NONE)
   357     | _ => NONE)
   370   | _ => NONE)
   358   | _ => NONE)
   371 
   359 
   372 fun load_state _ _ (state as (true, _)) = state
   360 fun load_state _ _ (state as (true, _)) = state
   373   | load_state ctxt overlord _ =
   361   | load_state ctxt overlord _ =
   375       (true,
   363       (true,
   376        (case try File.read_lines path of
   364        (case try File.read_lines path of
   377          SOME (version' :: node_lines) =>
   365          SOME (version' :: node_lines) =>
   378          let
   366          let
   379            fun add_edge_to name parent =
   367            fun add_edge_to name parent =
   380              Graph.default_node (parent, Isar_Proof)
   368              Graph.default_node (parent, (Isar_Proof, [], []))
   381              #> Graph.add_edge (parent, name)
   369              #> Graph.add_edge (parent, name)
   382            fun add_node line =
   370            fun add_node line =
   383              (case extract_node line of
   371              (case extract_node line of
   384                NONE => I (* shouldn't happen *)
   372                NONE => I (* should not happen *)
   385              | SOME (name, parents, kind) =>
   373              | SOME (kind, name, parents, feats, deps) =>
   386                update_access_graph_node (name, kind) #> fold (add_edge_to name) parents)
   374                Graph.default_node (name, (kind, feats, deps))
       
   375                #> Graph.map_node name (K (kind, feats, deps))
       
   376                #> fold (add_edge_to name) parents)
   387            val (access_G, num_known_facts) =
   377            val (access_G, num_known_facts) =
   388              (case string_ord (version', version) of
   378              (case string_ord (version', version) of
   389                EQUAL =>
   379                EQUAL =>
   390                (try_graph ctxt "loading state" Graph.empty (fn () =>
   380                (try_graph ctxt "loading state" Graph.empty (fn () =>
   391                   fold add_node node_lines Graph.empty),
   381                   fold add_node node_lines Graph.empty),
   392                 length node_lines)
   382                 length node_lines)
   393              | LESS =>
   383              | LESS => (MaSh.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
   394                (* can't parse old file *)
       
   395                (MaSh.unlearn ctxt overlord; (Graph.empty, 0))
       
   396              | GREATER => raise FILE_VERSION_TOO_NEW ())
   384              | GREATER => raise FILE_VERSION_TOO_NEW ())
   397          in
   385          in
   398            trace_msg ctxt (fn () =>
   386            trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
   399                "Loaded fact graph (" ^ graph_info access_G ^ ")");
   387            {access_G = access_G, num_known_facts = num_known_facts, dirty = SOME []}
   400            {access_G = access_G, num_known_facts = num_known_facts,
       
   401             dirty = SOME []}
       
   402          end
   388          end
   403        | _ => empty_state))
   389        | _ => empty_state))
   404     end
   390     end
   405 
   391 
       
   392 fun str_of_entry (kind, name, parents, feats, deps) =
       
   393   str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
       
   394   encode_unweighted_features feats ^ "; " ^ encode_strs deps ^ "\n"
       
   395 
   406 fun save_state _ (state as {dirty = SOME [], ...}) = state
   396 fun save_state _ (state as {dirty = SOME [], ...}) = state
   407   | save_state ctxt {access_G, num_known_facts, dirty} =
   397   | save_state ctxt {access_G, num_known_facts, dirty} =
   408     let
   398     let
   409       fun str_of_entry (name, parents, kind) =
   399       fun append_entry (name, ((kind, feats, deps), (parents, _))) =
   410         str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^
   400         cons (kind, name, Graph.Keys.dest parents, feats, deps)
   411         encode_strs parents ^ "\n"
   401 
   412       fun append_entry (name, (kind, (parents, _))) =
       
   413         cons (name, Graph.Keys.dest parents, kind)
       
   414       val (banner, entries) =
   402       val (banner, entries) =
   415         (case dirty of
   403         (case dirty of
   416           SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names [])
   404           SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names [])
   417         | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G []))
   405         | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G []))
   418     in
   406     in
   419       write_file banner (entries, str_of_entry) (mash_state_file ());
   407       write_file banner (entries, str_of_entry) (mash_state_file ());
   420       trace_msg ctxt (fn () =>
   408       trace_msg ctxt (fn () =>
   421           "Saved fact graph (" ^ graph_info access_G ^
   409         "Saved fact graph (" ^ graph_info access_G ^
   422           (case dirty of
   410         (case dirty of
   423              SOME dirty =>
   411           SOME dirty => "; " ^ string_of_int (length dirty) ^ " dirty fact(s)"
   424              "; " ^ string_of_int (length dirty) ^ " dirty fact(s)"
   412         | _ => "") ^  ")");
   425            | _ => "") ^  ")");
       
   426       {access_G = access_G, num_known_facts = num_known_facts, dirty = SOME []}
   413       {access_G = access_G, num_known_facts = num_known_facts, dirty = SOME []}
   427     end
   414     end
   428 
   415 
   429 val global_state =
   416 val global_state =
   430   Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)
   417   Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)
   431 
   418 
   432 in
   419 in
   433 
   420 
   434 fun map_state ctxt overlord f =
   421 fun map_state ctxt overlord f =
   435   Synchronized.change global_state
   422   Synchronized.change global_state (load_state ctxt overlord ##> (f #> save_state ctxt))
   436                       (load_state ctxt overlord ##> (f #> save_state ctxt))
       
   437   handle FILE_VERSION_TOO_NEW () => ()
   423   handle FILE_VERSION_TOO_NEW () => ()
   438 
   424 
   439 fun peek_state ctxt overlord f =
   425 fun peek_state ctxt overlord f =
   440   Synchronized.change_result global_state
   426   Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
   441       (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
       
   442 
   427 
   443 fun clear_state ctxt overlord =
   428 fun clear_state ctxt overlord =
   444   Synchronized.change global_state (fn _ =>
   429   (* "unlearn" also removes the state file *)
   445       (MaSh.unlearn ctxt overlord; (* also removes the state file *)
   430   Synchronized.change global_state (fn _ => (MaSh.unlearn ctxt overlord; (false, empty_state)))
   446        (false, empty_state)))
       
   447 
   431 
   448 end
   432 end
   449 
   433 
   450 fun mash_unlearn ctxt ({overlord, ...} : params) =
   434 fun mash_unlearn ctxt ({overlord, ...} : params) =
   451   (clear_state ctxt overlord; Output.urgent_message "Reset MaSh.")
   435   (clear_state ctxt overlord; Output.urgent_message "Reset MaSh.")
   745     term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth in_goal ts
   729     term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth in_goal ts
   746     |> scope <> Global ? cons local_feature
   730     |> scope <> Global ? cons local_feature
   747   end
   731   end
   748 
   732 
   749 (* Too many dependencies is a sign that a decision procedure is at work. There
   733 (* Too many dependencies is a sign that a decision procedure is at work. There
   750    isn't much to learn from such proofs. *)
   734    is not much to learn from such proofs. *)
   751 val max_dependencies = 20
   735 val max_dependencies = 20
   752 
   736 
   753 val prover_default_max_facts = 25
   737 val prover_default_max_facts = 25
   754 
   738 
   755 (* "type_definition_xxx" facts are characterized by their use of "CollectI". *)
   739 (* "type_definition_xxx" facts are characterized by their use of "CollectI". *)
   817         if exists (is_dep dep) accum then
   801         if exists (is_dep dep) accum then
   818           accum
   802           accum
   819         else
   803         else
   820           (case find_first (is_dep dep) facts of
   804           (case find_first (is_dep dep) facts of
   821             SOME ((_, status), th) => accum @ [(("", status), th)]
   805             SOME ((_, status), th) => accum @ [(("", status), th)]
   822           | NONE => accum (* shouldn't happen *))
   806           | NONE => accum (* should not happen *))
   823       val mepo_facts =
   807       val mepo_facts =
   824         facts
   808         facts
   825         |> mepo_suggested_facts ctxt params (max_facts |> the_default prover_default_max_facts) NONE
   809         |> mepo_suggested_facts ctxt params (max_facts |> the_default prover_default_max_facts) NONE
   826              hyp_ts concl_t
   810              hyp_ts concl_t
   827       val facts =
   811       val facts =
  1004           else
   988           else
  1005             let
   989             let
  1006               val parents = maximal_wrt_access_graph access_G facts
   990               val parents = maximal_wrt_access_graph access_G facts
  1007               val goal_feats =
   991               val goal_feats =
  1008                 features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
   992                 features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
  1009               val chained_feats =
   993               val chained_feats = chained
  1010                 chained
       
  1011                 |> map (rpair 1.0)
   994                 |> map (rpair 1.0)
  1012                 |> map (chained_or_extra_features_of chained_feature_factor)
   995                 |> map (chained_or_extra_features_of chained_feature_factor)
  1013                 |> rpair [] |-> fold (union (eq_fst (op =)))
   996                 |> rpair [] |-> fold (union (eq_fst (op =)))
  1014               val extra_feats =
   997               val extra_feats = facts
  1015                 facts
       
  1016                 |> take (Int.max (0, num_extra_feature_facts - length chained))
   998                 |> take (Int.max (0, num_extra_feature_facts - length chained))
  1017                 |> filter fact_has_right_theory
   999                 |> filter fact_has_right_theory
  1018                 |> weight_facts_steeply
  1000                 |> weight_facts_steeply
  1019                 |> map (chained_or_extra_features_of extra_feature_factor)
  1001                 |> map (chained_or_extra_features_of extra_feature_factor)
  1020                 |> rpair [] |-> fold (union (eq_fst (op =)))
  1002                 |> rpair [] |-> fold (union (eq_fst (op =)))
  1021               val feats =
  1003               val feats =
  1022                 fold (union (eq_fst (op =))) [chained_feats, extra_feats]
  1004                 fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
  1023                      goal_feats
       
  1024                 |> debug ? sort (Real.compare o swap o pairself snd)
  1005                 |> debug ? sort (Real.compare o swap o pairself snd)
  1025               val hints =
  1006               val hints = chained
  1026                 chained |> filter (is_fact_in_graph access_G o snd)
  1007                 |> filter (is_fact_in_graph access_G o snd)
  1027                         |> map (nickname_of_thm o snd)
  1008                 |> map (nickname_of_thm o snd)
  1028             in
  1009             in
  1029               (access_G, MaSh.query ctxt overlord max_facts
  1010               (access_G, MaSh.query ctxt overlord max_facts ([], hints, parents, feats))
  1030                                     ([], hints, parents, feats))
       
  1031             end)
  1011             end)
  1032     val unknown = facts |> filter_out (is_fact_in_graph access_G o snd)
  1012     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
  1033   in
  1013   in
  1034     find_mash_suggestions ctxt max_facts suggs facts chained unknown
  1014     find_mash_suggestions ctxt max_facts suggs facts chained unknown
  1035     |> pairself (map fact_of_raw_fact)
  1015     |> pairself (map fact_of_raw_fact)
  1036   end
  1016   end
  1037 
  1017 
  1038 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
  1018 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
  1039   let
  1019   let
  1040     fun maybe_learn_from from (accum as (parents, graph)) =
  1020     fun maybe_learn_from from (accum as (parents, graph)) =
  1041       try_graph ctxt "updating graph" accum (fn () =>
  1021       try_graph ctxt "updating graph" accum (fn () =>
  1042           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
  1022         (from :: parents, Graph.add_edge_acyclic (from, name) graph))
  1043     val graph = graph |> Graph.default_node (name, Isar_Proof)
  1023     val graph = graph |> Graph.default_node (name, (Isar_Proof, feats, deps))
  1044     val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
  1024     val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
  1045     val (deps, _) = ([], graph) |> fold maybe_learn_from deps
  1025     val (deps, _) = ([], graph) |> fold maybe_learn_from deps
  1046   in ((name, parents, feats, deps) :: learns, graph) end
  1026   in ((name, parents, feats, deps) :: learns, graph) end
  1047 
  1027 
  1048 fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
  1028 fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
  1049   let
  1029   let
  1050     fun maybe_relearn_from from (accum as (parents, graph)) =
  1030     fun maybe_relearn_from from (accum as (parents, graph)) =
  1051       try_graph ctxt "updating graph" accum (fn () =>
  1031       try_graph ctxt "updating graph" accum (fn () =>
  1052           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
  1032         (from :: parents, Graph.add_edge_acyclic (from, name) graph))
  1053     val graph = graph |> update_access_graph_node (name, Automatic_Proof)
  1033     val graph = graph |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
  1054     val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
  1034     val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
  1055   in ((name, deps) :: relearns, graph) end
  1035   in ((name, deps) :: relearns, graph) end
  1056 
  1036 
  1057 fun flop_wrt_access_graph name =
  1037 fun flop_wrt_access_graph name =
  1058   update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop)
  1038   Graph.map_node name (fn (_, feats, deps) => (Isar_Proof_wegen_Prover_Flop, feats, deps))
  1059 
  1039 
  1060 val learn_timeout_slack = 2.0
  1040 val learn_timeout_slack = 2.0
  1061 
  1041 
  1062 fun launch_thread timeout task =
  1042 fun launch_thread timeout task =
  1063   let
  1043   let
  1097 fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover save auto_level
  1077 fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover save auto_level
  1098     run_prover learn_timeout facts =
  1078     run_prover learn_timeout facts =
  1099   let
  1079   let
  1100     val timer = Timer.startRealTimer ()
  1080     val timer = Timer.startRealTimer ()
  1101     fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
  1081     fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
       
  1082 
  1102     val {access_G, ...} = peek_state ctxt overlord I
  1083     val {access_G, ...} = peek_state ctxt overlord I
  1103     val is_in_access_G = is_fact_in_graph access_G o snd
  1084     val is_in_access_G = is_fact_in_graph access_G o snd
  1104     val no_new_facts = forall is_in_access_G facts
  1085     val no_new_facts = forall is_in_access_G facts
  1105   in
  1086   in
  1106     if no_new_facts andalso not run_prover then
  1087     if no_new_facts andalso not run_prover then
  1107       if auto_level < 2 then
  1088       if auto_level < 2 then
  1108         "No new " ^ (if run_prover then "automatic" else "Isar") ^
  1089         "No new " ^ (if run_prover then "automatic" else "Isar") ^ " proofs to learn." ^
  1109         " proofs to learn." ^
       
  1110         (if auto_level = 0 andalso not run_prover then
  1090         (if auto_level = 0 andalso not run_prover then
  1111            "\n\nHint: Try " ^ sendback learn_proverN ^
  1091            "\n\nHint: Try " ^ sendback learn_proverN ^ " to learn from an automatic prover."
  1112            " to learn from an automatic prover."
       
  1113          else
  1092          else
  1114            "")
  1093            "")
  1115       else
  1094       else
  1116         ""
  1095         ""
  1117     else
  1096     else
  1118       let
  1097       let
  1119         val name_tabs = build_name_tables nickname_of_thm facts
  1098         val name_tabs = build_name_tables nickname_of_thm facts
       
  1099 
  1120         fun deps_of status th =
  1100         fun deps_of status th =
  1121           if no_dependencies_for_status status then
  1101           if no_dependencies_for_status status then
  1122             SOME []
  1102             SOME []
  1123           else if run_prover then
  1103           else if run_prover then
  1124             prover_dependencies_of ctxt params prover auto_level facts name_tabs
  1104             prover_dependencies_of ctxt params prover auto_level facts name_tabs th
  1125                                    th
  1105             |> (fn (false, _) => NONE | (true, deps) => trim_dependencies deps)
  1126             |> (fn (false, _) => NONE
       
  1127                  | (true, deps) => trim_dependencies deps)
       
  1128           else
  1106           else
  1129             isar_dependencies_of name_tabs th
  1107             isar_dependencies_of name_tabs th
  1130             |> trim_dependencies
  1108             |> trim_dependencies
       
  1109 
  1131         fun do_commit [] [] [] state = state
  1110         fun do_commit [] [] [] state = state
  1132           | do_commit learns relearns flops {access_G, num_known_facts, dirty} =
  1111           | do_commit learns relearns flops {access_G, num_known_facts, dirty} =
  1133             let
  1112             let
  1134               val was_empty = Graph.is_empty access_G
  1113               val was_empty = Graph.is_empty access_G
  1135               val (learns, access_G) =
  1114               val (learns, access_G) = ([], access_G) |> fold (learn_wrt_access_graph ctxt) learns
  1136                 ([], access_G) |> fold (learn_wrt_access_graph ctxt) learns
       
  1137               val (relearns, access_G) =
  1115               val (relearns, access_G) =
  1138                 ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
  1116                 ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
  1139               val access_G = access_G |> fold flop_wrt_access_graph flops
  1117               val access_G = access_G |> fold flop_wrt_access_graph flops
  1140               val num_known_facts = num_known_facts + length learns
  1118               val num_known_facts = num_known_facts + length learns
  1141               val dirty =
  1119               val dirty =
  1143                   (false, SOME names, []) => SOME (map #1 learns @ names)
  1121                   (false, SOME names, []) => SOME (map #1 learns @ names)
  1144                 | _ => NONE)
  1122                 | _ => NONE)
  1145             in
  1123             in
  1146               MaSh.learn ctxt overlord (save andalso null relearns) (rev learns);
  1124               MaSh.learn ctxt overlord (save andalso null relearns) (rev learns);
  1147               MaSh.relearn ctxt overlord save relearns;
  1125               MaSh.relearn ctxt overlord save relearns;
  1148               {access_G = access_G, num_known_facts = num_known_facts,
  1126               {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
  1149                dirty = dirty}
       
  1150             end
  1127             end
       
  1128 
  1151         fun commit last learns relearns flops =
  1129         fun commit last learns relearns flops =
  1152           (if debug andalso auto_level = 0 then
  1130           (if debug andalso auto_level = 0 then
  1153              Output.urgent_message "Committing..."
  1131              Output.urgent_message "Committing..."
  1154            else
  1132            else
  1155              ();
  1133              ();
  1162                string_of_time commit_timeout ^ "."
  1140                string_of_time commit_timeout ^ "."
  1163                |> Output.urgent_message
  1141                |> Output.urgent_message
  1164              end
  1142              end
  1165            else
  1143            else
  1166              ())
  1144              ())
       
  1145 
  1167         fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
  1146         fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
  1168           | learn_new_fact (parents, ((_, stature as (_, status)), th))
  1147           | learn_new_fact (parents, ((_, stature as (_, status)), th))
  1169                            (learns, (n, next_commit, _)) =
  1148               (learns, (n, next_commit, _)) =
  1170             let
  1149             let
  1171               val name = nickname_of_thm th
  1150               val name = nickname_of_thm th
  1172               val feats =
  1151               val feats =
  1173                 features_of ctxt (theory_of_thm th) 0 Symtab.empty stature false [prop_of th]
  1152                 features_of ctxt (theory_of_thm th) 0 Symtab.empty stature false [prop_of th]
  1174                 |> map fst
  1153                 |> map fst
  1179                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1158                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1180                   (commit false learns [] []; ([], next_commit_time ()))
  1159                   (commit false learns [] []; ([], next_commit_time ()))
  1181                 else
  1160                 else
  1182                   (learns, next_commit)
  1161                   (learns, next_commit)
  1183               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
  1162               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
  1184             in (learns, (n, next_commit, timed_out)) end
  1163             in
       
  1164               (learns, (n, next_commit, timed_out))
       
  1165             end
       
  1166 
  1185         val n =
  1167         val n =
  1186           if no_new_facts then
  1168           if no_new_facts then
  1187             0
  1169             0
  1188           else
  1170           else
  1189             let
  1171             let
  1190               val new_facts =
  1172               val new_facts = facts
  1191                 facts |> sort (crude_thm_ord o pairself snd)
  1173                 |> sort (crude_thm_ord o pairself snd)
  1192                       |> attach_parents_to_facts []
  1174                 |> attach_parents_to_facts []
  1193                       |> filter_out (is_in_access_G o snd)
  1175                 |> filter_out (is_in_access_G o snd)
  1194               val (learns, (n, _, _)) =
  1176               val (learns, (n, _, _)) =
  1195                 ([], (0, next_commit_time (), false))
  1177                 ([], (0, next_commit_time (), false))
  1196                 |> fold learn_new_fact new_facts
  1178                 |> fold learn_new_fact new_facts
  1197             in commit true learns [] []; n end
  1179             in
       
  1180               commit true learns [] []; n
       
  1181             end
       
  1182 
  1198         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
  1183         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
  1199           | relearn_old_fact ((_, (_, status)), th)
  1184           | relearn_old_fact ((_, (_, status)), th) ((relearns, flops), (n, next_commit, _)) =
  1200                              ((relearns, flops), (n, next_commit, _)) =
       
  1201             let
  1185             let
  1202               val name = nickname_of_thm th
  1186               val name = nickname_of_thm th
  1203               val (n, relearns, flops) =
  1187               val (n, relearns, flops) =
  1204                 (case deps_of status th of
  1188                 (case deps_of status th of
  1205                   SOME deps => (n + 1, (name, deps) :: relearns, flops)
  1189                   SOME deps => (n + 1, (name, deps) :: relearns, flops)
  1206                 | NONE => (n, relearns, name :: flops))
  1190                 | NONE => (n, relearns, name :: flops))
  1207               val (relearns, flops, next_commit) =
  1191               val (relearns, flops, next_commit) =
  1208                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1192                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1209                   (commit false [] relearns flops;
  1193                   (commit false [] relearns flops; ([], [], next_commit_time ()))
  1210                    ([], [], next_commit_time ()))
       
  1211                 else
  1194                 else
  1212                   (relearns, flops, next_commit)
  1195                   (relearns, flops, next_commit)
  1213               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
  1196               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
  1214             in ((relearns, flops), (n, next_commit, timed_out)) end
  1197             in
       
  1198               ((relearns, flops), (n, next_commit, timed_out))
       
  1199             end
       
  1200 
  1215         val n =
  1201         val n =
  1216           if not run_prover then
  1202           if not run_prover then
  1217             n
  1203             n
  1218           else
  1204           else
  1219             let
  1205             let
  1220               val max_isar = 1000 * max_dependencies
  1206               val max_isar = 1000 * max_dependencies
  1221               fun kind_of_proof th =
  1207 
  1222                 try (Graph.get_node access_G) (nickname_of_thm th)
  1208               val kind_of_proof =
  1223                 |> the_default Isar_Proof
  1209                 nickname_of_thm #> try (#1 o Graph.get_node access_G) #> the_default Isar_Proof
  1224               fun priority_of (_, th) =
  1210 
       
  1211               fun priority_of th =
  1225                 random_range 0 max_isar
  1212                 random_range 0 max_isar
  1226                 + (case kind_of_proof th of
  1213                 + (case kind_of_proof th of
  1227                      Isar_Proof => 0
  1214                      Isar_Proof => 0
  1228                    | Automatic_Proof => 2 * max_isar
  1215                    | Automatic_Proof => 2 * max_isar
  1229                    | Isar_Proof_wegen_Prover_Flop => max_isar)
  1216                    | Isar_Proof_wegen_Prover_Flop => max_isar)
  1230                 - 100 * length (isar_dependencies_of name_tabs th)
  1217                 - 100 * length (isar_dependencies_of name_tabs th)
  1231               val old_facts =
  1218 
  1232                 facts |> filter is_in_access_G
  1219               val old_facts = facts
  1233                       |> map (`priority_of)
  1220                 |> filter is_in_access_G
  1234                       |> sort (int_ord o pairself fst)
  1221                 |> map (`(priority_of o snd))
  1235                       |> map snd
  1222                 |> sort (int_ord o pairself fst)
       
  1223                 |> map snd
  1236               val ((relearns, flops), (n, _, _)) =
  1224               val ((relearns, flops), (n, _, _)) =
  1237                 (([], []), (n, next_commit_time (), false))
  1225                 (([], []), (n, next_commit_time (), false))
  1238                 |> fold relearn_old_fact old_facts
  1226                 |> fold relearn_old_fact old_facts
  1239             in commit true [] relearns flops; n end
  1227             in
       
  1228               commit true [] relearns flops; n
       
  1229             end
  1240       in
  1230       in
  1241         if verbose orelse auto_level < 2 then
  1231         if verbose orelse auto_level < 2 then
  1242           "Learned " ^ string_of_int n ^ " nontrivial " ^
  1232           "Learned " ^ string_of_int n ^ " nontrivial " ^
  1243           (if run_prover then "automatic and " else "") ^ "Isar proof" ^ plural_s n ^
  1233           (if run_prover then "automatic and " else "") ^ "Isar proof" ^ plural_s n ^
  1244           (if verbose then " in " ^ string_of_time (Timer.checkRealTimer timer)
  1234           (if verbose then " in " ^ string_of_time (Timer.checkRealTimer timer)
  1289 val mepo_weight = 0.5
  1279 val mepo_weight = 0.5
  1290 val mash_weight = 0.5
  1280 val mash_weight = 0.5
  1291 
  1281 
  1292 val max_facts_to_learn_before_query = 100
  1282 val max_facts_to_learn_before_query = 100
  1293 
  1283 
  1294 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
  1284 (* The threshold should be large enough so that MaSh does not get activated for Auto Sledgehammer
  1295    Sledgehammer and Try. *)
  1285    and Try. *)
  1296 val min_secs_for_learning = 15
  1286 val min_secs_for_learning = 15
  1297 
  1287 
  1298 fun relevant_facts ctxt (params as {overlord, blocking, learn, fact_filter, timeout, ...}) prover
  1288 fun relevant_facts ctxt (params as {overlord, blocking, learn, fact_filter, timeout, ...}) prover
  1299     max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
  1289     max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
  1300   if not (subset (op =) (the_list fact_filter, fact_filters)) then
  1290   if not (subset (op =) (the_list fact_filter, fact_filters)) then
  1379       | _ => [(effective_fact_filter, mesh)])
  1369       | _ => [(effective_fact_filter, mesh)])
  1380     end
  1370     end
  1381 
  1371 
  1382 fun kill_learners ctxt ({overlord, ...} : params) =
  1372 fun kill_learners ctxt ({overlord, ...} : params) =
  1383   (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord)
  1373   (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord)
       
  1374 
  1384 fun running_learners () = Async_Manager.running_threads MaShN "learner"
  1375 fun running_learners () = Async_Manager.running_threads MaShN "learner"
  1385 
  1376 
  1386 end;
  1377 end;