--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:52 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:56 2014 +0200
@@ -121,8 +121,17 @@
val relearn_isarN = "relearn_isar"
val relearn_proverN = "relearn_prover"
+val hintsN = ".hints" (* name of the pseudo-fact entry that carries the query hints *)
+
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
+type xtab = int * int Symtab.table
+(* An xtab pairs the next index to hand out with a table from keys to the indices already assigned. *)
+val empty_xtab = (0, Symtab.empty)
+(* add_to_xtab stores key under the next index; maybe_add_to_xtab presumably keeps the xtab as is if that fails. *)
+fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
+fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
+
fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
@@ -384,10 +393,10 @@
val number_of_nearest_neighbors = 10 (* FUDGE *)
-fun select_visible_facts recommends =
+fun select_visible_facts big_number recommends = (* bumps the entry at every index of the list it is applied to *)
List.app (fn at =>
let val (j, ov) = Array.sub (recommends, at) in
- Array.update (recommends, at, (j, 1000000000.0 + ov))
+ Array.update (recommends, at, (j, big_number + ov)) (* j is kept; big_number is added to the second component *)
end)
exception EXIT of unit
@@ -461,7 +470,7 @@
in
while1 ();
while2 ();
- select_visible_facts recommends visible_facts;
+ select_visible_facts 1000000000.0 recommends visible_facts;
heap (Real.compare o pairself snd) max_suggs num_facts recommends;
ret [] (Integer.max 0 (num_facts - max_suggs))
end
@@ -540,7 +549,7 @@
fun ret at acc =
if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc)
in
- select_visible_facts posterior visible_facts;
+ select_visible_facts 100000.0 posterior visible_facts;
heap (Real.compare o pairself snd) max_suggs num_facts posterior;
ret (Integer.max 0 (num_facts - max_suggs)) []
end
@@ -608,13 +617,20 @@
val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
fun query ctxt overlord engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs
- learns conj_feats =
+ learns0 conj_feats =
if engine = MaSh_SML_kNN_Cpp then
- k_nearest_neighbors_cpp max_suggs learns conj_feats
+ k_nearest_neighbors_cpp max_suggs learns0 conj_feats
else if engine = MaSh_SML_NB_Cpp then
- naive_bayes_cpp max_suggs learns conj_feats
+ naive_bayes_cpp max_suggs learns0 conj_feats
else
let
+ val learn_ary = Array.array (num_facts, ("", [], []))
+ val _ =
+ List.app (fn entry as (fact, _, _) =>
+ Array.update (learn_ary, the (Symtab.lookup fact_tab fact), entry))
+ learns0
+ val learns = Array.foldr (op ::) [] learn_ary
+
val facts = map #1 learns
val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
@@ -675,10 +691,12 @@
Graph.default_node (parent, (Isar_Proof, [], []))
#> Graph.add_edge (parent, name)
-fun add_node kind name parents feats deps G =
- (Graph.new_node (name, (kind, feats, deps)) G
- handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) G)
- |> fold (add_edge_to name) parents
+fun add_node kind name parents feats deps (access_G, fact_xtab, feat_xtab) =
+ ((Graph.new_node (name, (kind, feats, deps)) access_G
+ handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) access_G) (* on DUP, overwrite the node's data *)
+ |> fold (add_edge_to name) parents,
+ maybe_add_to_xtab name fact_xtab, (* record the fact name in the fact index *)
+ fold maybe_add_to_xtab feats feat_xtab) (* record each feature in the feature index *)
fun try_graph ctxt when def f =
f ()
@@ -703,10 +721,19 @@
type mash_state =
{access_G : (proof_kind * string list * string list) Graph.T,
- num_known_facts : int,
+ fact_xtab : xtab, (* index table of fact names *)
+ feat_xtab : xtab, (* index table of feature names *)
+ num_known_facts : int, (* ### FIXME: kill *)
dirty_facts : string list option}
-val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty_facts = SOME []} : mash_state
+val empty_state =
+ {access_G = Graph.empty,
+ fact_xtab = empty_xtab,
+ feat_xtab = empty_xtab,
+ num_known_facts = 0,
+ dirty_facts = SOME []} : mash_state
+(* Empty (graph, fact index, feature index) triple: the shape that add_node takes and returns. *)
+val empty_graphxx = (Graph.empty, empty_xtab, empty_xtab)
local
@@ -741,21 +768,22 @@
NONE => I (* should not happen *)
| SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
- val (access_G, num_known_facts) =
+ val ((access_G, fact_xtab, feat_xtab), num_known_facts) =
(case string_ord (version', version) of
EQUAL =>
- (try_graph ctxt "loading state" Graph.empty (fn () =>
- fold extract_line_and_add_node node_lines Graph.empty),
+ (try_graph ctxt "loading state" empty_graphxx (fn () =>
+ fold extract_line_and_add_node node_lines empty_graphxx),
length node_lines)
| LESS =>
(* cannot parse old file *)
(if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
else wipe_out_mash_state_dir ();
- (Graph.empty, 0))
+ (empty_graphxx, 0))
| GREATER => raise FILE_VERSION_TOO_NEW ())
in
trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
- {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = SOME []}
+ {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
+ num_known_facts = num_known_facts, dirty_facts = SOME []}
end
| _ => empty_state)))
end
@@ -765,7 +793,7 @@
encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
- | save_state ctxt (memory_time, {access_G, num_known_facts, dirty_facts}) =
+ | save_state ctxt (memory_time, {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts}) =
let
fun append_entry (name, ((kind, feats, deps), (parents, _))) =
cons (kind, name, Graph.Keys.dest parents, feats, deps)
@@ -786,7 +814,9 @@
(case dirty_facts of
SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
| _ => "") ^ ")");
- (Time.now (), {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = SOME []})
+ (Time.now (),
+ {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
+ num_known_facts = num_known_facts, dirty_facts = SOME []})
end
val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
@@ -1291,11 +1321,6 @@
fun add_const_counts t =
fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
-val empty_xtab = (0, Symtab.empty)
-
-fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
-fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
-
fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
let
val thy = Proof_Context.theory_of ctxt
@@ -1344,19 +1369,18 @@
(parents, hints, feats)
end
- val (access_G, py_suggs) =
- peek_state ctxt overlord (fn {access_G, ...} =>
- if Graph.is_empty access_G then
- (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
- else
- (access_G,
- if engine = MaSh_Py then
- let val (parents, hints, feats) = query_args access_G in
- MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
- |> map fst
- end
- else
- []))
+ val ((access_G, fact_xtab, feat_xtab), py_suggs) =
+ peek_state ctxt overlord (fn {access_G, fact_xtab, feat_xtab, ...} =>
+ ((access_G, fact_xtab, feat_xtab),
+ if Graph.is_empty access_G then
+ (trace_msg ctxt (K "Nothing has been learned yet"); [])
+ else if engine = MaSh_Py then
+ let val (parents, hints, feats) = query_args access_G in
+ MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+ |> map fst
+ end
+ else
+ []))
val sml_suggs =
if engine = MaSh_Py then
@@ -1367,11 +1391,8 @@
val feats = map fst feats0
val visible_facts = Graph.all_preds access_G parents
val learns =
- Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G @
- (if null hints then [] else [(".hints", feats, hints)])
-
- val fact_xtab = fold (add_to_xtab o #1) learns empty_xtab
- val feat_xtab = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
+ (if null hints then [] else [(hintsN, feats, hints)]) @ (* ### FIXME *)
+ Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
in
MaSh_SML.query ctxt overlord engine fact_xtab feat_xtab visible_facts max_facts learns
feats
@@ -1383,27 +1404,33 @@
|> pairself (map fact_of_raw_fact)
end
-fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
+fun learn_wrt_access_graph ctxt (name, parents, feats, deps)
+ (learns, (access_G, fact_xtab, feat_xtab)) =
let
- fun maybe_learn_from from (accum as (parents, G)) =
+ fun maybe_learn_from from (accum as (parents, access_G)) =
try_graph ctxt "updating graph" accum (fn () =>
- (from :: parents, Graph.add_edge_acyclic (from, name) G))
- val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
- val (parents, G) = ([], G) |> fold maybe_learn_from parents
- val (deps, _) = ([], G) |> fold maybe_learn_from deps
+ (from :: parents, Graph.add_edge_acyclic (from, name) access_G))
+ (* The parents fold keeps its updated graph; the deps fold only collects deps and discards its graph. *)
+ val access_G = access_G |> Graph.default_node (name, (Isar_Proof, feats, deps))
+ val (parents, access_G) = ([], access_G) |> fold maybe_learn_from parents
+ val (deps, _) = ([], access_G) |> fold maybe_learn_from deps
+ (* Index the learned fact name and its features. *)
+ val fact_xtab = maybe_add_to_xtab name fact_xtab
+ val feat_xtab = fold maybe_add_to_xtab feats feat_xtab
in
- ((name, parents, feats, deps) :: learns, G)
+ ((name, parents, feats, deps) :: learns, (access_G, fact_xtab, feat_xtab))
end
-fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
+fun relearn_wrt_access_graph ctxt (name, deps) (relearns, access_G) =
let
- fun maybe_relearn_from from (accum as (parents, G)) =
+ fun maybe_relearn_from from (accum as (parents, access_G)) =
try_graph ctxt "updating graph" accum (fn () =>
- (from :: parents, Graph.add_edge_acyclic (from, name) G))
- val G = G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
- val (deps, _) = ([], G) |> fold maybe_relearn_from deps
+ (from :: parents, Graph.add_edge_acyclic (from, name) access_G))
+ val access_G = (* keep the node's feats, replace its kind and deps *)
+ access_G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
+ val (deps, _) = ([], access_G) |> fold maybe_relearn_from deps (* the graph from this fold is discarded *)
in
- ((name, deps) :: relearns, G)
+ ((name, deps) :: relearns, access_G)
end
fun flop_wrt_access_graph name =
@@ -1431,24 +1458,28 @@
val thy = Proof_Context.theory_of ctxt
val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
in
- map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty_facts} =>
- let
- val parents = maximal_wrt_access_graph access_G facts
- val deps = used_ths
- |> filter (is_fact_in_graph access_G)
- |> map nickname_of_thm
- in
- if the_mash_engine () = MaSh_Py then
- (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state)
- else
- let
- val name = learned_proof_name ()
- val access_G = access_G |> add_node Automatic_Proof name parents feats deps
- in
- {access_G = access_G, num_known_facts = num_known_facts + 1,
- dirty_facts = Option.map (cons name) dirty_facts}
- end
- end);
+ map_state ctxt overlord
+ (fn state as {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =>
+ let
+ val parents = maximal_wrt_access_graph access_G facts
+ val deps = used_ths
+ |> filter (is_fact_in_graph access_G)
+ |> map nickname_of_thm
+ in
+ if the_mash_engine () = MaSh_Py then
+ (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state)
+ else
+ let
+ val name = learned_proof_name ()
+ val (access_G, fact_xtab, feat_xtab) =
+ add_node Automatic_Proof name parents feats deps
+ (access_G, fact_xtab, feat_xtab)
+ in
+ {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
+ num_known_facts = num_known_facts + 1,
+ dirty_facts = Option.map (cons name) dirty_facts}
+ end
+ end);
(true, "")
end)
else
@@ -1466,7 +1497,7 @@
fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
val engine = the_mash_engine ()
- val {access_G, ...} = peek_state ctxt overlord I
+ val {access_G, fact_xtab, feat_xtab, ...} = peek_state ctxt overlord I
val is_in_access_G = is_fact_in_graph access_G o snd
val no_new_facts = forall is_in_access_G facts
in
@@ -1493,12 +1524,15 @@
isar_dependencies_of name_tabs th
fun do_commit [] [] [] state = state
- | do_commit learns relearns flops {access_G, num_known_facts, dirty_facts} =
+ | do_commit learns relearns flops
+ {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =
let
val was_empty = Graph.is_empty access_G
+ (* was_empty above must describe the graph before the learns below add their nodes to it *)
+ val (learns, (access_G, fact_xtab, feat_xtab)) =
+ fold (learn_wrt_access_graph ctxt) learns ([], (access_G, fact_xtab, feat_xtab))
+ val (relearns, access_G) =
+ fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
- val (learns, access_G) = ([], access_G) |> fold (learn_wrt_access_graph ctxt) learns
- val (relearns, access_G) =
- ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
val access_G = access_G |> fold flop_wrt_access_graph flops
val num_known_facts = num_known_facts + length learns
val dirty_facts =
@@ -1511,7 +1545,8 @@
MaSh_Py.relearn ctxt overlord save relearns)
else
();
- {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = dirty_facts}
+ {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
+ num_known_facts = num_known_facts, dirty_facts = dirty_facts}
end
fun commit last learns relearns flops =