--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 16:00:00 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 16:11:37 2014 +0200
@@ -284,22 +284,16 @@
structure MaSh_SML =
struct
-fun max a b = if a > b then a else b
-
exception BOTTOM of int
fun heap cmp bnd a =
let
fun maxson l i =
- let
- val i31 = i + i + i + 1
- in
+ let val i31 = i + i + i + 1 in
if i31 + 2 < l then
- let
- val x = Unsynchronized.ref i31;
- val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
- val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
- in
+ let val x = Unsynchronized.ref i31 in
+ if cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS then x := i31 + 1 else ();
+ if cmp (Array.sub (a, !x), Array.sub (a, i31 + 2)) = LESS then x := i31 + 2 else ();
!x
end
else
@@ -354,7 +348,7 @@
val _ = for (((l + 1) div 3) - 1)
fun for2 i =
- if i < max 2 (l - bnd) then () else
+ if i < Integer.max 2 (l - bnd) then () else
let
val e = Array.sub (a, i)
val _ = Array.update (a, i, Array.sub (a, 0))
@@ -387,51 +381,57 @@
fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
let
(* Can be later used for TFIDF *)
- fun sym_wght _ = 1.0;
- val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)));
+ fun sym_wght _ = 1.0
+
+ val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)))
+
fun inc_overlap j v =
let
- val ov = snd (Array.sub (overlaps_sqr,j))
+ val ov = snd (Array.sub (overlaps_sqr, j))
in
Array.update (overlaps_sqr, j, (j, v + ov))
- end;
+ end
+
fun do_sym (s, con_wght) =
let
- val sw = sym_wght s;
- val w2 = sw * sw * con_wght;
+ val sw = sym_wght s
+ val w2 = sw * sw * con_wght
+
fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
in
- ignore (map do_th (get_sym_ths s))
- end;
- val () = ignore (map do_sym syms);
- val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
- val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0)));
+ List.app do_th (get_sym_ths s)
+ end
+
+ val _ = List.app do_sym syms
+ val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
+ val recommends = Array.tabulate (adv_max, rpair 0.0)
+
fun inc_recommend j v =
- if j >= adv_max then () else
- let
- val ov = snd (Array.sub (recommends,j))
- in
- Array.update (recommends, j, (j, v + ov))
- end;
+ if j >= adv_max then ()
+ else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
+
fun for k =
- if k = knns then () else
- if k >= adv_max then () else
- let
- val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1);
- val o1 = Math.sqrt o2;
- val () = inc_recommend j o1;
- val ds = get_deps j;
- val l = Real.fromInt (length ds);
- val _ = map (fn d => inc_recommend d (o1 / l)) ds
- in
- for (k + 1)
- end;
- val () = for 0;
- val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
+ if k = knns orelse k >= adv_max then
+ ()
+ else
+ let
+ val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1)
+ val o1 = Math.sqrt o2
+ val _ = inc_recommend j o1
+ val ds = get_deps j
+ val l = Real.fromInt (length ds)
+ val _ = map (fn d => inc_recommend d (o1 / l)) ds
+ in
+ for (k + 1)
+ end
+
+ val _ = for 0
+ val _ = heap (Real.compare o pairself snd) advno recommends
+
fun ret acc at =
- if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
+ if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
in
- ret [] (max 0 (adv_max - advno))
+ ret [] (Integer.max 0 (adv_max - advno))
end
val knns = 40 (* FUDGE *)
@@ -440,7 +440,7 @@
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
-fun learn_and_query ctxt parents access_G max_suggs hints feats =
+fun query ctxt parents access_G max_suggs hints feats =
let
val str_of_feat = space_implode "|"
@@ -469,9 +469,9 @@
all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
val facts = rev rev_facts
- val fact_ary = Array.fromList facts
+ val fact_vec = Vector.fromList facts
- val deps_ary = Array.fromList (rev rev_depss)
+ val deps_vec = Vector.fromList (rev rev_depss)
val facts_ary = Array.array (num_feats, [])
val _ =
fold (fn feats => fn fact =>
@@ -484,11 +484,11 @@
in
trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
elide_string 1000 (space_implode " " facts) ^ "}");
- knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary)
+ knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
(curry Array.sub facts_ary) knns max_suggs
(map_filter (fn (feat, weight) =>
Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
- |> map ((fn i => Array.sub (fact_ary, i)) o fst)
+ |> map (curry Vector.sub fact_vec o fst)
end
end;
@@ -625,7 +625,7 @@
Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
fun clear_state ctxt overlord =
- (* "unlearn" also removes the state file *)
+ (* "MaSh_Py.unlearn" also removes the state file *)
Synchronized.change global_state (fn _ =>
(if Config.get ctxt sml then wipe_out_mash_state_dir ()
else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
@@ -971,9 +971,6 @@
| NONE => false)
| is_size_def _ _ = false
-fun no_dependencies_for_status status =
- status = Non_Rec_Def orelse status = Rec_Def
-
fun trim_dependencies deps =
if length deps > max_dependencies then NONE else SOME deps
@@ -1022,18 +1019,17 @@
val num_isar_deps = length isar_deps
in
if verbose andalso auto_level = 0 then
- "MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^
- " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts."
- |> Output.urgent_message
+ Output.urgent_message ("MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^
+ string_of_int num_isar_deps ^ " + " ^ string_of_int (length facts - num_isar_deps) ^
+ " facts.")
else
();
(case run_prover_for_mash ctxt params prover name facts goal of
{outcome = NONE, used_facts, ...} =>
(if verbose andalso auto_level = 0 then
let val num_facts = length used_facts in
- "Found proof with " ^ string_of_int num_facts ^ " fact" ^
- plural_s num_facts ^ "."
- |> Output.urgent_message
+ Output.urgent_message ("Found proof with " ^ string_of_int num_facts ^ " fact" ^
+ plural_s num_facts ^ ".")
end
else
();
@@ -1187,40 +1183,57 @@
|> features_of ctxt (theory_of_thm th) num_facts const_tab stature false
|> map (apsnd (fn r => weight * factor * r))
- val (access_G, suggs) =
+ fun query_args access_G =
+ let
+ val parents = maximal_wrt_access_graph access_G facts
+ val hints = chained
+ |> filter (is_fact_in_graph access_G o snd)
+ |> map (nickname_of_thm o snd)
+
+ val goal_feats =
+ features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
+ val chained_feats = chained
+ |> map (rpair 1.0)
+ |> map (chained_or_extra_features_of chained_feature_factor)
+ |> rpair [] |-> fold (union (eq_fst (op =)))
+ val extra_feats = facts
+ |> take (Int.max (0, num_extra_feature_facts - length chained))
+ |> filter fact_has_right_theory
+ |> weight_facts_steeply
+ |> map (chained_or_extra_features_of extra_feature_factor)
+ |> rpair [] |-> fold (union (eq_fst (op =)))
+ val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
+ |> debug ? sort (Real.compare o swap o pairself snd)
+ in
+ (parents, hints, feats)
+ end
+
+ val sml = Config.get ctxt sml
+
+ val (access_G, py_suggs) =
peek_state ctxt overlord (fn {access_G, ...} =>
if Graph.is_empty access_G then
(trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
else
- let
- val parents = maximal_wrt_access_graph access_G facts
- val goal_feats =
- features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
- val chained_feats = chained
- |> map (rpair 1.0)
- |> map (chained_or_extra_features_of chained_feature_factor)
- |> rpair [] |-> fold (union (eq_fst (op =)))
- val extra_feats = facts
- |> take (Int.max (0, num_extra_feature_facts - length chained))
- |> filter fact_has_right_theory
- |> weight_facts_steeply
- |> map (chained_or_extra_features_of extra_feature_factor)
- |> rpair [] |-> fold (union (eq_fst (op =)))
- val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
- |> debug ? sort (Real.compare o swap o pairself snd)
- val hints = chained
- |> filter (is_fact_in_graph access_G o snd)
- |> map (nickname_of_thm o snd)
- in
- (access_G,
- if Config.get ctxt sml then
- MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats
- else
- MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats))
- end)
+ (access_G,
+ if sml then
+ []
+ else
+ let val (parents, hints, feats) = query_args access_G in
+ MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+ end))
+
+ val sml_suggs =
+ if sml then
+ let val (parents, hints, feats) = query_args access_G in
+ MaSh_SML.query ctxt parents access_G max_facts hints feats
+ end
+ else
+ []
+
val unknown = filter_out (is_fact_in_graph access_G o snd) facts
in
- find_mash_suggestions ctxt max_facts suggs facts chained unknown
+ find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown
|> pairself (map fact_of_raw_fact)
end
@@ -1323,7 +1336,7 @@
val name_tabs = build_name_tables nickname_of_thm facts
fun deps_of status th =
- if no_dependencies_for_status status then
+ if status = Non_Rec_Def orelse status = Rec_Def then
SOME []
else if run_prover then
prover_dependencies_of ctxt params prover auto_level facts name_tabs th
@@ -1355,18 +1368,13 @@
end
fun commit last learns relearns flops =
- (if debug andalso auto_level = 0 then
- Output.urgent_message "Committing..."
- else
- ();
+ (if debug andalso auto_level = 0 then Output.urgent_message "Committing..." else ();
map_state ctxt overlord (do_commit (rev learns) relearns flops);
if not last andalso auto_level = 0 then
let val num_proofs = length learns + length relearns in
- "Learned " ^ string_of_int num_proofs ^ " " ^
- (if run_prover then "automatic" else "Isar") ^ " proof" ^
- plural_s num_proofs ^ " in the last " ^
- string_of_time commit_timeout ^ "."
- |> Output.urgent_message
+ Output.urgent_message ("Learned " ^ string_of_int num_proofs ^ " " ^
+ (if run_prover then "automatic" else "Isar") ^ " proof" ^
+ plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout ^ ".")
end
else
())
@@ -1478,14 +1486,12 @@
|> Output.urgent_message
in
if run_prover then
- ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
- " for automatic proofs (" ^ quote prover ^ " timeout: " ^ string_of_time timeout ^
- ").\n\nCollecting Isar proofs first..."
- |> Output.urgent_message;
+ (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
+ plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^ " timeout: " ^
+ string_of_time timeout ^ ").\n\nCollecting Isar proofs first...");
learn 1 false;
- "Now collecting automatic proofs. This may take several hours. You can safely stop the \
- \learning process at any point."
- |> Output.urgent_message;
+ Output.urgent_message "Now collecting automatic proofs. This may take several hours. You \
+ \can safely stop the learning process at any point.";
learn 0 true)
else
(Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^