# HG changeset patch # User blanchet # Date 1400535833 -7200 # Node ID d3eed0518882b6de98449a579a295acd19367c8b # Parent 20e5b110d19b1d3c641a70ce5346eca18ef6bdca started work on MaSh/SML diff -r 20e5b110d19b -r d3eed0518882 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200 @@ -14,6 +14,7 @@ type prover_result = Sledgehammer_Prover.prover_result val trace : bool Config.T + val sml : bool Config.T val MePoN : string val MaShN : string val MeShN : string @@ -34,18 +35,6 @@ val encode_features : (string list * real) list -> string val extract_suggestions : string -> string * string list - structure MaSh: - sig - val unlearn : Proof.context -> bool -> unit - val learn : Proof.context -> bool -> bool -> - (string * string list * string list list * string list) list -> unit - val relearn : Proof.context -> bool -> bool -> (string * string list) list -> unit - val query : Proof.context -> bool -> int -> - (string * string list * string list list * string list) list * string list * string list * - (string list * real) list -> - string list - end - val mash_unlearn : Proof.context -> params -> unit val is_mash_enabled : unit -> bool val nickname_of_thm : thm -> string @@ -98,6 +87,7 @@ open Sledgehammer_MePo val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false) +val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false) fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else () @@ -117,12 +107,18 @@ val relearn_isarN = "relearn_isar" val relearn_proverN = "relearn_prover" -fun mash_model_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir -val mash_state_dir = mash_model_dir +fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state") +fun wipe_out_mash_state_dir () = + let val path = mash_state_dir () in + try (File.fold_dir (fn file => fn _ => try File.rm (Path.append path (Path.basic file))) path) + NONE; + () + end -(*** Low-level communication with MaSh ***) + +(*** Low-level communication with Python version of MaSh ***) val save_models_arg = "--saveModels" val shutdown_server_arg = "--shutdownServer" @@ -145,7 +141,7 @@ val sugg_path = Path.explode sugg_file val cmd_file = temp_dir ^ "/mash_commands" ^ serial val cmd_path = Path.explode cmd_file - val model_dir = File.shell_path (mash_model_dir ()) + val model_dir = File.shell_path (mash_state_dir ()) val command = "cd \"$ISABELLE_SLEDGEHAMMER_MASH\"/src; \ \PYTHONDONTWRITEBYTECODE=y ./mash.py\ @@ -238,52 +234,65 @@ [goal, suggs] => (decode_str goal, map_filter extract_suggestion (space_explode " " suggs)) | _ => ("", [])) -structure MaSh = +structure MaSh_Py = struct fun shutdown ctxt overlord = - (trace_msg ctxt (K "MaSh shutdown"); + (trace_msg ctxt (K "MaSh_Py shutdown"); run_mash_tool ctxt overlord [shutdown_server_arg] false ([], K "") (K ())) fun save ctxt overlord = - (trace_msg ctxt (K "MaSh save"); + (trace_msg ctxt (K "MaSh_Py save"); run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ())) fun unlearn ctxt overlord = - let val path = mash_model_dir () in - trace_msg ctxt (K "MaSh unlearn"); - shutdown ctxt overlord; - try (File.fold_dir (fn file => fn _ => try File.rm (Path.append path (Path.basic file))) path) - NONE; - () - end + (trace_msg ctxt (K "MaSh_Py unlearn"); + shutdown ctxt overlord; + wipe_out_mash_state_dir ()) fun learn _ _ _ [] = () | learn ctxt overlord save learns = - let val names = elide_string 1000 (space_implode " " (map #1 learns)) in - (trace_msg ctxt (fn () => "MaSh learn" ^ (if names = "" then "" else " " ^ names)); - run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn) - (K ())) - end + (trace_msg ctxt (fn () => + let val names = elide_string 1000 (space_implode " " (map #1 learns)) in + "MaSh_Py learn" ^ (if names = "" then "" else " " ^ names) + end); + run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn) + (K ())) fun relearn _ _ _ [] = () | relearn ctxt overlord save relearns = - (trace_msg ctxt (fn () => "MaSh relearn " ^ + (trace_msg ctxt (fn () => "MaSh_Py relearn " ^ elide_string 1000 (space_implode " " (map #1 relearns))); run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (relearns, str_of_relearn) (K ())) fun query ctxt overlord max_suggs (query as (_, _, _, feats)) = - (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats); + (trace_msg ctxt (fn () => "MaSh_Py query " ^ encode_features feats); run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs) (fn suggs => - (case suggs () of - [] => [] - | suggs => snd (extract_suggestions (List.last suggs)))) + (case suggs () of [] => [] | suggs => snd (extract_suggestions (List.last suggs)))) handle List.Empty => []) end; +(*** Standard ML version of MaSh ***) + +structure MaSh_SML = +struct + +fun learn_and_query ctxt (learns : (string * string list * string list list * string list) list) + max_suggs (query as (_, _, _, feats)) = + (trace_msg ctxt (fn () => + let val names = elide_string 1000 (space_implode " " (map #1 learns)) in + "MaSh_SML learn" ^ (if names = "" then "" else " " ^ names) ^ "\n" ^ + "MaSh_SML query " ^ encode_features feats + end); + (* Implementation missing *) + []) + +end; + + (*** Middle-level communication with MaSh ***) datatype proof_kind = Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop @@ -364,7 +373,9 @@ (try_graph ctxt "loading state" Graph.empty (fn () => fold add_node node_lines Graph.empty), length node_lines) - | LESS => (MaSh.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *) + | LESS => + (if Config.get ctxt sml then wipe_out_mash_state_dir () + else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *) | GREATER => raise FILE_VERSION_TOO_NEW ()) in trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")"); @@ -411,7 +422,9 @@ fun clear_state ctxt overlord = (* "unlearn" also removes the state file *) - Synchronized.change global_state (fn _ => (MaSh.unlearn ctxt overlord; (false, empty_state))) + Synchronized.change global_state (fn _ => + (if Config.get ctxt sml then wipe_out_mash_state_dir () + else MaSh_Py.unlearn ctxt overlord; (false, empty_state))) end @@ -953,10 +966,8 @@ [(0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])), (0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, [])), (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))] - val unknown = - raw_unknown - |> fold (subtract (eq_snd Thm.eq_thm_prop)) - [unknown_chained, unknown_proximate] + val unknown = raw_unknown + |> fold (subtract (eq_snd Thm.eq_thm_prop)) [unknown_chained, unknown_proximate] in (mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess, unknown) end @@ -964,6 +975,13 @@ fun add_const_counts t = fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t []) +fun learn_of_graph graph = + let + fun sched parents (name, (kind, feats, deps)) = (name, map fst parents, feats, deps) + in + Graph.schedule sched graph + end + fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts = let val thy = Proof_Context.theory_of ctxt @@ -983,31 +1001,34 @@ val (access_G, suggs) = peek_state ctxt overlord (fn {access_G, ...} => - if Graph.is_empty access_G then - (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, [])) - else - let - val parents = maximal_wrt_access_graph access_G facts - val goal_feats = - features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts) - val chained_feats = chained - |> map (rpair 1.0) - |> map (chained_or_extra_features_of chained_feature_factor) - |> rpair [] |-> fold (union (eq_fst (op =))) - val extra_feats = facts - |> take (Int.max (0, num_extra_feature_facts - length chained)) - |> filter fact_has_right_theory - |> weight_facts_steeply - |> map (chained_or_extra_features_of extra_feature_factor) - |> rpair [] |-> fold (union (eq_fst (op =))) - val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats - |> debug ? sort (Real.compare o swap o pairself snd) - val hints = chained - |> filter (is_fact_in_graph access_G o snd) - |> map (nickname_of_thm o snd) - in - (access_G, MaSh.query ctxt overlord max_facts ([], hints, parents, feats)) - end) + if Graph.is_empty access_G then + (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, [])) + else + let + val parents = maximal_wrt_access_graph access_G facts + val goal_feats = + features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts) + val chained_feats = chained + |> map (rpair 1.0) + |> map (chained_or_extra_features_of chained_feature_factor) + |> rpair [] |-> fold (union (eq_fst (op =))) + val extra_feats = facts + |> take (Int.max (0, num_extra_feature_facts - length chained)) + |> filter fact_has_right_theory + |> weight_facts_steeply + |> map (chained_or_extra_features_of extra_feature_factor) + |> rpair [] |-> fold (union (eq_fst (op =))) + val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats + |> debug ? sort (Real.compare o swap o pairself snd) + val hints = chained + |> filter (is_fact_in_graph access_G o snd) + |> map (nickname_of_thm o snd) + in + (access_G, + (if Config.get ctxt sml then MaSh_SML.learn_and_query ctxt (learn_of_graph access_G) + else MaSh_Py.query ctxt overlord) + max_facts ([], hints, parents, feats)) + end) val unknown = filter_out (is_fact_in_graph access_G o snd) facts in find_mash_suggestions ctxt max_facts suggs facts chained unknown @@ -1060,14 +1081,15 @@ val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst in peek_state ctxt overlord (fn {access_G, ...} => - let - val parents = maximal_wrt_access_graph access_G facts - val deps = - used_ths |> filter (is_fact_in_graph access_G) - |> map nickname_of_thm - in - MaSh.learn ctxt overlord true [("", parents, feats, deps)] - end); + let + val parents = maximal_wrt_access_graph access_G facts + val deps = + used_ths |> filter (is_fact_in_graph access_G) + |> map nickname_of_thm + in + if Config.get ctxt sml then () (* TODO: implement *) + else MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)] + end); (true, "") end) else @@ -1126,8 +1148,11 @@ (false, SOME names, []) => SOME (map #1 learns @ names) | _ => NONE) in - MaSh.learn ctxt overlord (save andalso null relearns) (rev learns); - MaSh.relearn ctxt overlord save relearns; + if Config.get ctxt sml then + () + else + (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns); + MaSh_Py.relearn ctxt overlord save relearns); {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty} end @@ -1363,7 +1388,7 @@ |> Par_List.map (apsnd (fn f => f ())) val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take in - if save then MaSh.save ctxt overlord else (); + if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord; (case (fact_filter, mess) of (NONE, [(_, (mepo, _)), (_, (mash, _))]) => [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take), @@ -1372,7 +1397,8 @@ end fun kill_learners ctxt ({overlord, ...} : params) = - (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord) + (Async_Manager.kill_threads MaShN "learner"; + if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord) fun running_learners () = Async_Manager.running_threads MaShN "learner"