--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 16:41:30 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 16:41:30 2014 +0200
@@ -616,10 +616,10 @@
MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
| MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
-fun query_internal ctxt engine num_facts num_feats (facts, featss, depss) (freqs as (_, _, dffreq))
- visible_facts max_suggs goal_feats int_goal_feats =
+fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
+ (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
(trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_strs goal_feats ^ " from {" ^
- elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
+ elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
(case engine of
MaSh_SML_kNN =>
let
@@ -632,7 +632,7 @@
k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
end
| MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats)
- |> map (curry Vector.sub facts o fst))
+ |> map (curry Vector.sub fact_names o fst))
end;
@@ -684,14 +684,47 @@
type mash_state =
{access_G : (proof_kind * string list * string list) Graph.T,
xtabs : xtab * xtab,
+ ffds : string vector * int list vector * int list vector,
+ freqs : int vector * int Inttab.table vector * int vector,
dirty_facts : string list option}
+val empty_xtabs = (empty_xtab, empty_xtab)
+val empty_ffds = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
+val empty_freqs = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
+val empty_graphxx = (Graph.empty, empty_xtabs)
+
val empty_state =
{access_G = Graph.empty,
- xtabs = (empty_xtab, empty_xtab),
+ xtabs = empty_xtabs,
+ ffds = empty_ffds,
+ freqs = empty_freqs,
dirty_facts = SOME []} : mash_state
-val empty_graphxx = (Graph.empty, (empty_xtab, empty_xtab))
+fun reorder_learns (num_facts, fact_tab) learns =
+ let val ary = Array.array (num_facts, ("", [], [])) in
+ List.app (fn learn as (fact, _, _) =>
+ Array.update (ary, the (Symtab.lookup fact_tab fact), learn))
+ learns;
+ Array.foldr (op ::) [] ary
+ end
+
+fun recompute_ffd_freqs access_G (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab)) =
+ let
+ val learns =
+ Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+ |> reorder_learns fact_xtab
+
+ val fact_names = Vector.fromList (map #1 learns)
+ val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
+ val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
+
+ val tfreq = Vector.tabulate (num_facts, K 0)
+ val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
+ val dffreq = Vector.tabulate (num_feats, K 0)
+ in
+ ((fact_names, featss, depss),
+ MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss)
+ end
local
@@ -737,9 +770,11 @@
else wipe_out_mash_state_dir ();
empty_graphxx)
| GREATER => raise FILE_VERSION_TOO_NEW ())
+
+ val (ffds, freqs) = recompute_ffd_freqs access_G xtabs
in
trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
- {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []}
+ {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []}
end
| _ => empty_state)))
end
@@ -749,7 +784,7 @@
encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
- | save_state ctxt (memory_time, {access_G, xtabs, dirty_facts}) =
+ | save_state ctxt (memory_time, {access_G, xtabs, ffds, freqs, dirty_facts}) =
let
fun append_entry (name, ((kind, feats, deps), (parents, _))) =
cons (kind, name, Graph.Keys.dest parents, feats, deps)
@@ -770,7 +805,8 @@
(case dirty_facts of
SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
| _ => "") ^ ")");
- (Time.now (), {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []})
+ (Time.now (),
+ {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []})
end
val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
@@ -1275,16 +1311,6 @@
fun add_const_counts t =
fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
-fun reorder_learns (num_facts, fact_tab) learns0 =
- let
- val learns = Array.array (num_facts, ("", [], []))
- in
- List.app (fn learn as (fact, _, _) =>
- Array.update (learns, the (Symtab.lookup fact_tab fact), learn))
- learns0;
- Array.foldr (op ::) [] learns
- end
-
fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts =
let
val thy = Proof_Context.theory_of ctxt
@@ -1333,9 +1359,9 @@
(parents, hints, feats)
end
- val ((access_G, (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab))), py_suggs) =
- peek_state ctxt overlord (fn {access_G, xtabs, ...} =>
- ((access_G, xtabs),
+ val ((access_G, ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs), py_suggs) =
+ peek_state ctxt overlord (fn {access_G, xtabs, ffds, freqs, ...} =>
+ ((access_G, xtabs, ffds, freqs),
if Graph.is_empty access_G then
(trace_msg ctxt (K "Nothing has been learned yet"); [])
else if engine = MaSh_Py then
@@ -1364,25 +1390,10 @@
end
else
let
- val learns0 =
- Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
- val learns = reorder_learns fact_xtab learns0
-
- val facts = Vector.fromList (map #1 learns)
- val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
- val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
-
- val tfreq = Vector.tabulate (num_facts, K 0)
- val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
- val dffreq = Vector.tabulate (num_feats, K 0)
-
- val freqs' =
- MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
-
val int_goal_feats = map_filter (Symtab.lookup feat_tab) goal_feats
in
- MaSh_SML.query_internal ctxt engine num_facts num_feats (facts, featss, depss) freqs'
- visible_facts max_suggs goal_feats int_goal_feats
+ MaSh_SML.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts
+ max_suggs goal_feats int_goal_feats
end
end
@@ -1447,7 +1458,7 @@
val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
in
map_state ctxt overlord
- (fn state as {access_G, xtabs, dirty_facts} =>
+ (fn state as {access_G, xtabs, ffds, freqs, dirty_facts} =>
let
val parents = maximal_wrt_access_graph access_G facts
val deps = used_ths
@@ -1459,10 +1470,12 @@
else
let
val name = learned_proof_name ()
- val (access_G, xtabs) =
+ val (access_G', xtabs') =
add_node Automatic_Proof name parents feats deps (access_G, xtabs)
+
+ val (ffds', freqs') = recompute_ffd_freqs access_G' xtabs'
in
- {access_G = access_G, xtabs = xtabs,
+ {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs',
dirty_facts = Option.map (cons name) dirty_facts}
end
end);
@@ -1510,26 +1523,31 @@
isar_dependencies_of name_tabs th
fun do_commit [] [] [] state = state
- | do_commit learns relearns flops {access_G, xtabs, dirty_facts} =
+ | do_commit learns relearns flops {access_G, xtabs, ffds, freqs, dirty_facts} =
let
+ val was_empty = Graph.is_empty access_G
+
+ (* TODO: use "fold_map" *)
val (learns, (access_G, xtabs)) =
fold (learn_wrt_access_graph ctxt) learns ([], (access_G, xtabs))
val (relearns, access_G) =
fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
- val was_empty = Graph.is_empty access_G
val access_G = access_G |> fold flop_wrt_access_graph flops
val dirty_facts =
(case (was_empty, dirty_facts) of
(false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names)
| _ => NONE)
+
+ val (ffds', freqs') = recompute_ffd_freqs access_G xtabs
in
if engine = MaSh_Py then
(MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
MaSh_Py.relearn ctxt overlord save relearns)
else
();
- {access_G = access_G, xtabs = xtabs, dirty_facts = dirty_facts}
+ {access_G = access_G, xtabs = xtabs, ffds = ffds', freqs = freqs',
+ dirty_facts = dirty_facts}
end
fun commit last learns relearns flops =