--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 16:41:30 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 16:41:30 2014 +0200
@@ -653,12 +653,13 @@
Graph.default_node (parent, (Isar_Proof, [], []))
#> Graph.add_edge (parent, name)
-fun add_node kind name parents feats deps (access_G, (fact_xtab, feat_xtab)) =
+fun add_node kind name parents feats deps (access_G, (fact_xtab, feat_xtab), learns) =
((Graph.new_node (name, (kind, feats, deps)) access_G
handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) access_G)
|> fold (add_edge_to name) parents,
(maybe_add_to_xtab name fact_xtab,
- fold maybe_add_to_xtab feats feat_xtab))
+ fold maybe_add_to_xtab feats feat_xtab),
+ (name, feats, deps) :: learns)
fun try_graph ctxt when def f =
f ()
@@ -691,7 +692,6 @@
val empty_xtabs = (empty_xtab, empty_xtab)
val empty_ffds = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
val empty_freqs = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
-val empty_graphxx = (Graph.empty, empty_xtabs)
val empty_state =
{access_G = Graph.empty,
@@ -700,6 +700,16 @@
freqs = empty_freqs,
dirty_facts = SOME []} : mash_state
+fun recompute_ffd_freqs_from_learns learns ((num_facts, fact_tab), (num_feats, feat_tab)) freqs =
+ let
+ val fact_names = Vector.fromList (map #1 learns)
+ val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
+ val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
+ in
+ ((fact_names, featss, depss),
+ MaSh_SML.learn_facts freqs 0 num_facts num_feats depss featss)
+ end
+
fun reorder_learns (num_facts, fact_tab) learns =
let val ary = Array.array (num_facts, ("", [], [])) in
List.app (fn learn as (fact, _, _) =>
@@ -708,22 +718,13 @@
Array.foldr (op ::) [] ary
end
-fun recompute_ffd_freqs access_G (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab)) =
+fun recompute_ffd_freqs_from_access_G access_G (xtabs as (fact_xtab, _)) =
let
val learns =
Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
|> reorder_learns fact_xtab
-
- val fact_names = Vector.fromList (map #1 learns)
- val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
- val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
-
- val tfreq = Vector.tabulate (num_facts, K 0)
- val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
- val dffreq = Vector.tabulate (num_feats, K 0)
in
- ((fact_names, featss, depss),
- MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss)
+ recompute_ffd_freqs_from_learns learns xtabs empty_freqs
end
local
@@ -759,19 +760,21 @@
NONE => I (* should not happen *)
| SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
- val (access_G, xtabs) =
+ val empty_G_etc = (Graph.empty, empty_xtabs, [])
+
+ val (access_G, xtabs, rev_learns) =
(case string_ord (version', version) of
EQUAL =>
- try_graph ctxt "loading state" empty_graphxx
- (fn () => fold extract_line_and_add_node node_lines empty_graphxx)
+ try_graph ctxt "loading state" empty_G_etc
+ (fn () => fold extract_line_and_add_node node_lines empty_G_etc)
| LESS =>
(* cannot parse old file *)
(if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
else wipe_out_mash_state_dir ();
- empty_graphxx)
+ empty_G_etc)
| GREATER => raise FILE_VERSION_TOO_NEW ())
- val (ffds, freqs) = recompute_ffd_freqs access_G xtabs
+ val (ffds, freqs) = recompute_ffd_freqs_from_learns (rev rev_learns) xtabs empty_freqs
in
trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
{access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []}
@@ -1470,10 +1473,10 @@
else
let
val name = learned_proof_name ()
- val (access_G', xtabs') =
- add_node Automatic_Proof name parents feats deps (access_G, xtabs)
+ val (access_G', xtabs', learns) =
+ add_node Automatic_Proof name parents feats deps (access_G, xtabs, [])
- val (ffds', freqs') = recompute_ffd_freqs access_G' xtabs'
+ val (ffds', freqs') = recompute_ffd_freqs_from_access_G access_G' xtabs'
in
{access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs',
dirty_facts = Option.map (cons name) dirty_facts}
@@ -1539,7 +1542,7 @@
(false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names)
| _ => NONE)
- val (ffds', freqs') = recompute_ffd_freqs access_G xtabs
+ val (ffds', freqs') = recompute_ffd_freqs_from_access_G access_G xtabs
in
if engine = MaSh_Py then
(MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);