--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 16:28:05 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 16:52:59 2014 +0200
@@ -15,7 +15,6 @@
type prover_result = Sledgehammer_Prover.prover_result
val trace : bool Config.T
- val sml : bool Config.T
val MePoN : string
val MaShN : string
val MeShN : string
@@ -37,7 +36,6 @@
val extract_suggestions : string -> string * string list
val mash_unlearn : Proof.context -> params -> unit
- val is_mash_enabled : unit -> bool
val nickname_of_thm : thm -> string
val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
@@ -88,7 +86,6 @@
open Sledgehammer_MePo
val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
-val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false)
fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
@@ -118,6 +115,25 @@
()
end
+datatype mash_flavor = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB (* backend selected via "MASH" *)
+
+(* Backend requested through the "MASH" environment variable; NONE for any other value. *)
+fun mash_flavor () =
+ let val flavor = getenv "MASH" in
+ if flavor = "yes" orelse flavor = "py" then SOME MaSh_Py
+ else if flavor = "sml" orelse flavor = "sml_knn" then SOME MaSh_SML_KNN
+ else if flavor = "sml_nb" then SOME MaSh_SML_NB
+ else NONE
+ end
+
+fun is_mash_enabled () = is_some (mash_flavor ()) (* any recognized flavor enables MaSh *)
+
+fun is_mash_sml_enabled () =
+ (* true for either SML flavor; false for the Python flavor and when no flavor is selected *)
+ (case mash_flavor () of
+ NONE => false
+ | SOME flavor => flavor <> MaSh_Py)
+
(*** Low-level communication with Python version of MaSh ***)
@@ -284,22 +300,16 @@
structure MaSh_SML =
struct
-fun max a b = if a > b then a else b
-
exception BOTTOM of int
fun heap cmp bnd a =
let
fun maxson l i =
- let
- val i31 = i + i + i + 1
- in
+ let val i31 = i + i + i + 1 in
if i31 + 2 < l then
- let
- val x = Unsynchronized.ref i31;
- val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
- val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
- in
+ let val x = Unsynchronized.ref i31 in
+ if cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS then x := i31 + 1 else ();
+ if cmp (Array.sub (a, !x), Array.sub (a, i31 + 2)) = LESS then x := i31 + 2 else ();
!x
end
else
@@ -354,7 +364,7 @@
val _ = for (((l + 1) div 3) - 1)
fun for2 i =
- if i < max 2 (l - bnd) then () else
+ if i < Integer.max 2 (l - bnd) then () else
let
val e = Array.sub (a, i)
val _ = Array.update (a, i, Array.sub (a, 0))
@@ -387,51 +397,57 @@
fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
let
(* Can be later used for TFIDF *)
- fun sym_wght _ = 1.0;
- val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)));
+ fun sym_wght _ = 1.0
+ (* nearest-neighbour ranking: consult "knns" neighbours, return up to "advno" candidates *)
+ val overlaps_sqr = Array.tabulate (avail_num, rpair 0.0)
+
fun inc_overlap j v =
let
- val ov = snd (Array.sub (overlaps_sqr,j))
+ val ov = snd (Array.sub (overlaps_sqr, j))
in
Array.update (overlaps_sqr, j, (j, v + ov))
- end;
+ end
+ (* a symbol adds sw * sw * con_wght * prem_wght to every j < avail_num listed by get_sym_ths *)
fun do_sym (s, con_wght) =
let
- val sw = sym_wght s;
- val w2 = sw * sw * con_wght;
+ val sw = sym_wght s
+ val w2 = sw * sw * con_wght
+
fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
in
- ignore (map do_th (get_sym_ths s))
- end;
- val () = ignore (map do_sym syms);
- val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
- val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0)));
+ List.app do_th (get_sym_ths s)
+ end
+ (* "heap" is expected to leave the "knns" largest overlaps at the end of the array *)
+ val _ = List.app do_sym syms
+ val _ = heap (Real.compare o pairself snd) knns overlaps_sqr
+ val recommends = Array.tabulate (adv_max, rpair 0.0)
+
fun inc_recommend j v =
- if j >= adv_max then () else
- let
- val ov = snd (Array.sub (recommends,j))
- in
- Array.update (recommends, j, (j, v + ov))
- end;
+ if j >= adv_max then ()
+ else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
+ (* each neighbour votes o1 for itself and o1 / l for each of its l dependencies *)
fun for k =
- if k = knns then () else
- if k >= adv_max then () else
- let
- val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1);
- val o1 = Math.sqrt o2;
- val () = inc_recommend j o1;
- val ds = get_deps j;
- val l = Real.fromInt (length ds);
- val _ = map (fn d => inc_recommend d (o1 / l)) ds
- in
- for (k + 1)
- end;
- val () = for 0;
- val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
+ if k = knns orelse k >= adv_max then
+ ()
+ else
+ let
+ val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1)
+ val o1 = Math.sqrt o2
+ val _ = inc_recommend j o1
+ val ds = get_deps j
+ val l = Real.fromInt (length ds)
+ val _ = List.app (fn d => inc_recommend d (o1 / l)) ds
+ in
+ for (k + 1)
+ end
+
+ val _ = for 0
+ val _ = heap (Real.compare o pairself snd) advno recommends
+ (* collect the last "advno" entries of "recommends"; the final array element comes first *)
fun ret acc at =
- if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
+ if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
in
- ret [] (max 0 (adv_max - advno))
+ ret [] (Integer.max 0 (adv_max - advno))
end
val knns = 40 (* FUDGE *)
@@ -440,7 +456,7 @@
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
-fun learn_and_query ctxt parents access_G max_suggs hints feats =
+fun query ctxt parents access_G max_suggs hints feats =
let
val str_of_feat = space_implode "|"
@@ -469,9 +485,9 @@
all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
val facts = rev rev_facts
- val fact_ary = Array.fromList facts
+ val fact_vec = Vector.fromList facts
- val deps_ary = Array.fromList (rev rev_depss)
+ val deps_vec = Vector.fromList (rev rev_depss)
val facts_ary = Array.array (num_feats, [])
val _ =
fold (fn feats => fn fact =>
@@ -484,11 +500,11 @@
in
trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
elide_string 1000 (space_implode " " facts) ^ "}");
- knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary)
+ knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
(curry Array.sub facts_ary) knns max_suggs
(map_filter (fn (feat, weight) =>
Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
- |> map ((fn i => Array.sub (fact_ary, i)) o fst)
+ |> map (curry Vector.sub fact_vec o fst)
end
end;
@@ -578,7 +594,7 @@
fold extract_line_and_add_node node_lines Graph.empty),
length node_lines)
| LESS =>
- (if Config.get ctxt sml then wipe_out_mash_state_dir ()
+ (if is_mash_sml_enabled () then wipe_out_mash_state_dir ()
else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
| GREATER => raise FILE_VERSION_TOO_NEW ())
in
@@ -625,10 +641,10 @@
Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
fun clear_state ctxt overlord =
- (* "unlearn" also removes the state file *)
+ (* "MaSh_Py.unlearn" also removes the state file; the SML flavors wipe the state directory *)
Synchronized.change global_state (fn _ =>
- (if Config.get ctxt sml then wipe_out_mash_state_dir ()
- else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
+ (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord;
+ (false, empty_state)))
end
@@ -638,8 +654,6 @@
(*** Isabelle helpers ***)
-fun is_mash_enabled () = (getenv "MASH" = "yes")
-
val local_prefix = "local" ^ Long_Name.separator
fun elided_backquote_thm threshold th =
@@ -971,9 +985,6 @@
| NONE => false)
| is_size_def _ _ = false
-fun no_dependencies_for_status status =
- status = Non_Rec_Def orelse status = Rec_Def
-
fun trim_dependencies deps =
if length deps > max_dependencies then NONE else SOME deps
@@ -1022,18 +1033,17 @@
val num_isar_deps = length isar_deps
in
if verbose andalso auto_level = 0 then
- "MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^
- " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts."
- |> Output.urgent_message
+ Output.urgent_message ("MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^
+ string_of_int num_isar_deps ^ " + " ^ string_of_int (length facts - num_isar_deps) ^
+ " facts.")
else
();
(case run_prover_for_mash ctxt params prover name facts goal of
{outcome = NONE, used_facts, ...} =>
(if verbose andalso auto_level = 0 then
let val num_facts = length used_facts in
- "Found proof with " ^ string_of_int num_facts ^ " fact" ^
- plural_s num_facts ^ "."
- |> Output.urgent_message
+ Output.urgent_message ("Found proof with " ^ string_of_int num_facts ^ " fact" ^
+ plural_s num_facts ^ ".")
end
else
();
@@ -1187,40 +1197,57 @@
|> features_of ctxt (theory_of_thm th) num_facts const_tab stature false
|> map (apsnd (fn r => weight * factor * r))
- val (access_G, suggs) =
+ fun query_args access_G = (* inputs shared by the Python and the SML query *)
+ let
+ val parents = maximal_wrt_access_graph access_G facts
+ val hints = chained
+ |> filter (is_fact_in_graph access_G o snd)
+ |> map (nickname_of_thm o snd)
+ (* feats: goal features merged (by first component) with chained and extra-fact features *)
+ val goal_feats =
+ features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
+ val chained_feats = chained
+ |> map (rpair 1.0)
+ |> map (chained_or_extra_features_of chained_feature_factor)
+ |> rpair [] |-> fold (union (eq_fst (op =)))
+ val extra_feats = facts
+ |> take (Int.max (0, num_extra_feature_facts - length chained))
+ |> filter fact_has_right_theory
+ |> weight_facts_steeply
+ |> map (chained_or_extra_features_of extra_feature_factor)
+ |> rpair [] |-> fold (union (eq_fst (op =)))
+ val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
+ |> debug ? sort (Real.compare o swap o pairself snd) (* largest second component first *)
+ in
+ (parents, hints, feats)
+ end
+
+ val sml = is_mash_sml_enabled ()
+
+ val (access_G, py_suggs) =
peek_state ctxt overlord (fn {access_G, ...} =>
if Graph.is_empty access_G then
(trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
else
- let
- val parents = maximal_wrt_access_graph access_G facts
- val goal_feats =
- features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
- val chained_feats = chained
- |> map (rpair 1.0)
- |> map (chained_or_extra_features_of chained_feature_factor)
- |> rpair [] |-> fold (union (eq_fst (op =)))
- val extra_feats = facts
- |> take (Int.max (0, num_extra_feature_facts - length chained))
- |> filter fact_has_right_theory
- |> weight_facts_steeply
- |> map (chained_or_extra_features_of extra_feature_factor)
- |> rpair [] |-> fold (union (eq_fst (op =)))
- val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
- |> debug ? sort (Real.compare o swap o pairself snd)
- val hints = chained
- |> filter (is_fact_in_graph access_G o snd)
- |> map (nickname_of_thm o snd)
- in
- (access_G,
- if Config.get ctxt sml then
- MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats
- else
- MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats))
- end)
+ (access_G,
+ if sml then
+ [] (* SML suggestions are computed separately, after the state has been peeked *)
+ else
+ let val (parents, hints, feats) = query_args access_G in
+ MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+ end))
+
+ val sml_suggs =
+ if sml andalso not (Graph.is_empty access_G) then (* nothing learned yet: skip the query *)
+ let val (parents, hints, feats) = query_args access_G in
+ MaSh_SML.query ctxt parents access_G max_facts hints feats
+ end
+ else
+ []
+
val unknown = filter_out (is_fact_in_graph access_G o snd) facts
in
- find_mash_suggestions ctxt max_facts suggs facts chained unknown
+ find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown
|> pairself (map fact_of_raw_fact)
end
@@ -1280,7 +1307,7 @@
|> filter (is_fact_in_graph access_G)
|> map nickname_of_thm
in
- if Config.get ctxt sml then
+ if is_mash_sml_enabled () then
let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
{access_G = access_G, num_known_facts = num_known_facts + 1,
dirty = Option.map (cons name) dirty}
@@ -1305,6 +1332,7 @@
val timer = Timer.startRealTimer ()
fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
+ val sml = is_mash_sml_enabled ()
val {access_G, ...} = peek_state ctxt overlord I
val is_in_access_G = is_fact_in_graph access_G o snd
val no_new_facts = forall is_in_access_G facts
@@ -1323,7 +1351,7 @@
val name_tabs = build_name_tables nickname_of_thm facts
fun deps_of status th =
- if no_dependencies_for_status status then
+ if status = Non_Rec_Def orelse status = Rec_Def then
SOME []
else if run_prover then
prover_dependencies_of ctxt params prover auto_level facts name_tabs th
@@ -1346,7 +1374,7 @@
(false, SOME names, []) => SOME (map #1 learns @ names)
| _ => NONE)
in
- if Config.get ctxt sml then
+ if sml then
()
else
(MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
@@ -1355,18 +1383,13 @@
end
fun commit last learns relearns flops =
- (if debug andalso auto_level = 0 then
- Output.urgent_message "Committing..."
- else
- ();
+ (if debug andalso auto_level = 0 then Output.urgent_message "Committing..." else ();
map_state ctxt overlord (do_commit (rev learns) relearns flops);
if not last andalso auto_level = 0 then
let val num_proofs = length learns + length relearns in
- "Learned " ^ string_of_int num_proofs ^ " " ^
- (if run_prover then "automatic" else "Isar") ^ " proof" ^
- plural_s num_proofs ^ " in the last " ^
- string_of_time commit_timeout ^ "."
- |> Output.urgent_message
+ Output.urgent_message ("Learned " ^ string_of_int num_proofs ^ " " ^ (* non-final commits only *)
+ (if run_prover then "automatic" else "Isar") ^ " proof" ^
+ plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout ^ ".")
end
else
())
@@ -1478,14 +1501,12 @@
|> Output.urgent_message
in
if run_prover then
- ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
- " for automatic proofs (" ^ quote prover ^ " timeout: " ^ string_of_time timeout ^
- ").\n\nCollecting Isar proofs first..."
- |> Output.urgent_message;
+ (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
+ plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^ " timeout: " ^
+ string_of_time timeout ^ ").\n\nCollecting Isar proofs first...");
learn 1 false;
- "Now collecting automatic proofs. This may take several hours. You can safely stop the \
- \learning process at any point."
- |> Output.urgent_message;
+ Output.urgent_message "Now collecting automatic proofs. This may take several hours. You \
+ \can safely stop the learning process at any point.";
learn 0 true)
else
(Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
@@ -1530,6 +1551,7 @@
end
else
()
+
fun maybe_learn () =
if is_mash_enabled () andalso learn then
let
@@ -1551,6 +1573,7 @@
end
else
false
+
val (save, effective_fact_filter) =
(case fact_filter of
SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
@@ -1565,18 +1588,22 @@
val add_ths = Attrib.eval_thms ctxt add
fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
+
fun add_and_take accepts =
(case add_ths of
[] => accepts
| _ => (unique_facts |> filter in_add |> map fact_of_raw_fact) @
(accepts |> filter_out in_add))
|> take max_facts
+
fun mepo () =
(mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts
|> weight_facts_steeply, [])
+
fun mash () =
mash_suggested_facts ctxt params (generous_max_facts max_facts) hyp_ts concl_t facts
|>> weight_facts_steeply
+
val mess =
(* the order is important for the "case" expression below *)
[] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash)
@@ -1584,7 +1611,7 @@
|> Par_List.map (apsnd (fn f => f ()))
val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
in
- if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord;
+ if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord;
(case (fact_filter, mess) of
(NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
[(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
@@ -1594,7 +1621,7 @@
fun kill_learners ctxt ({overlord, ...} : params) =
(Async_Manager.kill_threads MaShN "learner";
- if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord)
+ if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord) (* skipped for SML flavors *)
fun running_learners () = Async_Manager.running_threads MaShN "learner"