--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200
@@ -14,6 +14,7 @@
type prover_result = Sledgehammer_Prover.prover_result
val trace : bool Config.T
+ val sml : bool Config.T
val MePoN : string
val MaShN : string
val MeShN : string
@@ -34,18 +35,6 @@
val encode_features : (string list * real) list -> string
val extract_suggestions : string -> string * string list
- structure MaSh:
- sig
- val unlearn : Proof.context -> bool -> unit
- val learn : Proof.context -> bool -> bool ->
- (string * string list * string list list * string list) list -> unit
- val relearn : Proof.context -> bool -> bool -> (string * string list) list -> unit
- val query : Proof.context -> bool -> int ->
- (string * string list * string list list * string list) list * string list * string list *
- (string list * real) list ->
- string list
- end
-
val mash_unlearn : Proof.context -> params -> unit
val is_mash_enabled : unit -> bool
val nickname_of_thm : thm -> string
@@ -98,6 +87,7 @@
open Sledgehammer_MePo
val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
+val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false)
fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
@@ -117,12 +107,18 @@
val relearn_isarN = "relearn_isar"
val relearn_proverN = "relearn_prover"
-fun mash_model_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
-val mash_state_dir = mash_model_dir
+fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
+fun wipe_out_mash_state_dir () =
+ let val path = mash_state_dir () in
+ try (File.fold_dir (fn file => fn _ => try File.rm (Path.append path (Path.basic file))) path)
+ NONE;
+ ()
+ end
-(*** Low-level communication with MaSh ***)
+
+(*** Low-level communication with Python version of MaSh ***)
val save_models_arg = "--saveModels"
val shutdown_server_arg = "--shutdownServer"
@@ -145,7 +141,7 @@
val sugg_path = Path.explode sugg_file
val cmd_file = temp_dir ^ "/mash_commands" ^ serial
val cmd_path = Path.explode cmd_file
- val model_dir = File.shell_path (mash_model_dir ())
+ val model_dir = File.shell_path (mash_state_dir ())
val command =
"cd \"$ISABELLE_SLEDGEHAMMER_MASH\"/src; \
\PYTHONDONTWRITEBYTECODE=y ./mash.py\
@@ -238,52 +234,65 @@
[goal, suggs] => (decode_str goal, map_filter extract_suggestion (space_explode " " suggs))
| _ => ("", []))
-structure MaSh =
+structure MaSh_Py =
struct
fun shutdown ctxt overlord =
- (trace_msg ctxt (K "MaSh shutdown");
+ (trace_msg ctxt (K "MaSh_Py shutdown");
run_mash_tool ctxt overlord [shutdown_server_arg] false ([], K "") (K ()))
fun save ctxt overlord =
- (trace_msg ctxt (K "MaSh save");
+ (trace_msg ctxt (K "MaSh_Py save");
run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ()))
fun unlearn ctxt overlord =
- let val path = mash_model_dir () in
- trace_msg ctxt (K "MaSh unlearn");
- shutdown ctxt overlord;
- try (File.fold_dir (fn file => fn _ => try File.rm (Path.append path (Path.basic file))) path)
- NONE;
- ()
- end
+ (trace_msg ctxt (K "MaSh_Py unlearn");
+ shutdown ctxt overlord;
+ wipe_out_mash_state_dir ())
fun learn _ _ _ [] = ()
| learn ctxt overlord save learns =
- let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
- (trace_msg ctxt (fn () => "MaSh learn" ^ (if names = "" then "" else " " ^ names));
- run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
- (K ()))
- end
+ (trace_msg ctxt (fn () =>
+ let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
+ "MaSh_Py learn" ^ (if names = "" then "" else " " ^ names)
+ end);
+ run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
+ (K ()))
fun relearn _ _ _ [] = ()
| relearn ctxt overlord save relearns =
- (trace_msg ctxt (fn () => "MaSh relearn " ^
+ (trace_msg ctxt (fn () => "MaSh_Py relearn " ^
elide_string 1000 (space_implode " " (map #1 relearns)));
run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
(relearns, str_of_relearn) (K ()))
fun query ctxt overlord max_suggs (query as (_, _, _, feats)) =
- (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
+ (trace_msg ctxt (fn () => "MaSh_Py query " ^ encode_features feats);
run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs) (fn suggs =>
- (case suggs () of
- [] => []
- | suggs => snd (extract_suggestions (List.last suggs))))
+ (case suggs () of [] => [] | suggs => snd (extract_suggestions (List.last suggs))))
handle List.Empty => [])
end;
+(*** Standard ML version of MaSh ***)
+
+structure MaSh_SML =
+struct
+
+fun learn_and_query ctxt (learns : (string * string list * string list list * string list) list)
+ max_suggs (query as (_, _, _, feats)) =
+ (trace_msg ctxt (fn () =>
+ let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
+ "MaSh_SML learn" ^ (if names = "" then "" else " " ^ names) ^ "\n" ^
+ "MaSh_SML query " ^ encode_features feats
+ end);
+ (* Implementation missing *)
+ [])
+
+end;
+
+
(*** Middle-level communication with MaSh ***)
datatype proof_kind = Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop
@@ -364,7 +373,9 @@
(try_graph ctxt "loading state" Graph.empty (fn () =>
fold add_node node_lines Graph.empty),
length node_lines)
- | LESS => (MaSh.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
+ | LESS =>
+ (if Config.get ctxt sml then wipe_out_mash_state_dir ()
+ else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
| GREATER => raise FILE_VERSION_TOO_NEW ())
in
trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
@@ -411,7 +422,9 @@
fun clear_state ctxt overlord =
(* "unlearn" also removes the state file *)
- Synchronized.change global_state (fn _ => (MaSh.unlearn ctxt overlord; (false, empty_state)))
+ Synchronized.change global_state (fn _ =>
+ (if Config.get ctxt sml then wipe_out_mash_state_dir ()
+ else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
end
@@ -953,10 +966,8 @@
[(0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
(0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, [])),
(0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))]
- val unknown =
- raw_unknown
- |> fold (subtract (eq_snd Thm.eq_thm_prop))
- [unknown_chained, unknown_proximate]
+ val unknown = raw_unknown
+ |> fold (subtract (eq_snd Thm.eq_thm_prop)) [unknown_chained, unknown_proximate]
in
(mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess, unknown)
end
@@ -964,6 +975,13 @@
fun add_const_counts t =
fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
+fun learn_of_graph graph =
+ let
+ fun sched parents (name, (kind, feats, deps)) = (name, map fst parents, feats, deps)
+ in
+ Graph.schedule sched graph
+ end
+
fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
let
val thy = Proof_Context.theory_of ctxt
@@ -983,31 +1001,34 @@
val (access_G, suggs) =
peek_state ctxt overlord (fn {access_G, ...} =>
- if Graph.is_empty access_G then
- (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
- else
- let
- val parents = maximal_wrt_access_graph access_G facts
- val goal_feats =
- features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
- val chained_feats = chained
- |> map (rpair 1.0)
- |> map (chained_or_extra_features_of chained_feature_factor)
- |> rpair [] |-> fold (union (eq_fst (op =)))
- val extra_feats = facts
- |> take (Int.max (0, num_extra_feature_facts - length chained))
- |> filter fact_has_right_theory
- |> weight_facts_steeply
- |> map (chained_or_extra_features_of extra_feature_factor)
- |> rpair [] |-> fold (union (eq_fst (op =)))
- val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
- |> debug ? sort (Real.compare o swap o pairself snd)
- val hints = chained
- |> filter (is_fact_in_graph access_G o snd)
- |> map (nickname_of_thm o snd)
- in
- (access_G, MaSh.query ctxt overlord max_facts ([], hints, parents, feats))
- end)
+ if Graph.is_empty access_G then
+ (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
+ else
+ let
+ val parents = maximal_wrt_access_graph access_G facts
+ val goal_feats =
+ features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
+ val chained_feats = chained
+ |> map (rpair 1.0)
+ |> map (chained_or_extra_features_of chained_feature_factor)
+ |> rpair [] |-> fold (union (eq_fst (op =)))
+ val extra_feats = facts
+ |> take (Int.max (0, num_extra_feature_facts - length chained))
+ |> filter fact_has_right_theory
+ |> weight_facts_steeply
+ |> map (chained_or_extra_features_of extra_feature_factor)
+ |> rpair [] |-> fold (union (eq_fst (op =)))
+ val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
+ |> debug ? sort (Real.compare o swap o pairself snd)
+ val hints = chained
+ |> filter (is_fact_in_graph access_G o snd)
+ |> map (nickname_of_thm o snd)
+ in
+ (access_G,
+ (if Config.get ctxt sml then MaSh_SML.learn_and_query ctxt (learn_of_graph access_G)
+ else MaSh_Py.query ctxt overlord)
+ max_facts ([], hints, parents, feats))
+ end)
val unknown = filter_out (is_fact_in_graph access_G o snd) facts
in
find_mash_suggestions ctxt max_facts suggs facts chained unknown
@@ -1060,14 +1081,15 @@
val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst
in
peek_state ctxt overlord (fn {access_G, ...} =>
- let
- val parents = maximal_wrt_access_graph access_G facts
- val deps =
- used_ths |> filter (is_fact_in_graph access_G)
- |> map nickname_of_thm
- in
- MaSh.learn ctxt overlord true [("", parents, feats, deps)]
- end);
+ let
+ val parents = maximal_wrt_access_graph access_G facts
+ val deps =
+ used_ths |> filter (is_fact_in_graph access_G)
+ |> map nickname_of_thm
+ in
+ if Config.get ctxt sml then () (* TODO: implement *)
+ else MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]
+ end);
(true, "")
end)
else
@@ -1126,8 +1148,11 @@
(false, SOME names, []) => SOME (map #1 learns @ names)
| _ => NONE)
in
- MaSh.learn ctxt overlord (save andalso null relearns) (rev learns);
- MaSh.relearn ctxt overlord save relearns;
+ if Config.get ctxt sml then
+ ()
+ else
+ (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
+ MaSh_Py.relearn ctxt overlord save relearns);
{access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
end
@@ -1363,7 +1388,7 @@
|> Par_List.map (apsnd (fn f => f ()))
val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
in
- if save then MaSh.save ctxt overlord else ();
+ if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord;
(case (fact_filter, mess) of
(NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
[(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
@@ -1372,7 +1397,8 @@
end
fun kill_learners ctxt ({overlord, ...} : params) =
- (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord)
+ (Async_Manager.kill_threads MaShN "learner";
+ if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord)
fun running_learners () = Async_Manager.running_threads MaShN "learner"