# HG changeset patch # User blanchet # Date 1403793690 -7200 # Node ID fe96689f393bb6c1e134cdacd05d6219b95044d1 # Parent 73e9b858ec8de8f562aedde0db5519690bce6bee recompute learning data at learning time, not query time diff -r 73e9b858ec8d -r fe96689f393b src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 16:41:30 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 16:41:30 2014 +0200 @@ -616,10 +616,10 @@ MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats)) -fun query_internal ctxt engine num_facts num_feats (facts, featss, depss) (freqs as (_, _, dffreq)) - visible_facts max_suggs goal_feats int_goal_feats = +fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss) + (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats = (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_strs goal_feats ^ " from {" ^ - elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}"); + elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}"); (case engine of MaSh_SML_kNN => let @@ -632,7 +632,7 @@ k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats end | MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats) - |> map (curry Vector.sub facts o fst)) + |> map (curry Vector.sub fact_names o fst)) end; @@ -684,14 +684,47 @@ type mash_state = {access_G : (proof_kind * string list * string list) Graph.T, xtabs : xtab * xtab, + ffds : string vector * int list vector * int list vector, + freqs : int vector * int Inttab.table vector * int vector, dirty_facts : string list option} +val empty_xtabs = (empty_xtab, empty_xtab) +val empty_ffds = (Vector.fromList [], Vector.fromList [], Vector.fromList []) +val empty_freqs = (Vector.fromList [], Vector.fromList [], Vector.fromList []) +val empty_graphxx = (Graph.empty, empty_xtabs) + val empty_state = {access_G = Graph.empty, - xtabs = (empty_xtab, empty_xtab), + xtabs = empty_xtabs, + ffds = empty_ffds, + freqs = empty_freqs, dirty_facts = SOME []} : mash_state -val empty_graphxx = (Graph.empty, (empty_xtab, empty_xtab)) +fun reorder_learns (num_facts, fact_tab) learns = + let val ary = Array.array (num_facts, ("", [], [])) in + List.app (fn learn as (fact, _, _) => + Array.update (ary, the (Symtab.lookup fact_tab fact), learn)) + learns; + Array.foldr (op ::) [] ary + end + +fun recompute_ffd_freqs access_G (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab)) = + let + val learns = + Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G + |> reorder_learns fact_xtab + + val fact_names = Vector.fromList (map #1 learns) + val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns) + val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns) + + val tfreq = Vector.tabulate (num_facts, K 0) + val sfreq = Vector.tabulate (num_facts, K Inttab.empty) + val dffreq = Vector.tabulate (num_feats, K 0) + in + ((fact_names, featss, depss), + MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss) + end local @@ -737,9 +770,11 @@ else wipe_out_mash_state_dir (); empty_graphxx) | GREATER => raise FILE_VERSION_TOO_NEW ()) + + val (ffds, freqs) = recompute_ffd_freqs access_G xtabs in trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")"); - {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []} + {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []} end | _ => empty_state))) end @@ -749,7 +784,7 @@ encode_strs feats ^ "; " ^ encode_strs deps ^ "\n" fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state - | save_state ctxt (memory_time, {access_G, xtabs, dirty_facts}) = + | save_state ctxt (memory_time, {access_G, xtabs, ffds, freqs, dirty_facts}) = let fun append_entry (name, ((kind, feats, deps), (parents, _))) = cons (kind, name, Graph.Keys.dest parents, feats, deps) @@ -770,7 +805,8 @@ (case dirty_facts of SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)" | _ => "") ^ ")"); - (Time.now (), {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []}) + (Time.now (), + {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []}) end val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state) @@ -1275,16 +1311,6 @@ fun add_const_counts t = fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t []) -fun reorder_learns (num_facts, fact_tab) learns0 = - let - val learns = Array.array (num_facts, ("", [], [])) - in - List.app (fn learn as (fact, _, _) => - Array.update (learns, the (Symtab.lookup fact_tab fact), learn)) - learns0; - Array.foldr (op ::) [] learns - end - fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts = let val thy = Proof_Context.theory_of ctxt @@ -1333,9 +1359,9 @@ (parents, hints, feats) end - val ((access_G, (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab))), py_suggs) = - peek_state ctxt overlord (fn {access_G, xtabs, ...} => - ((access_G, xtabs), + val ((access_G, ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs), py_suggs) = + peek_state ctxt overlord (fn {access_G, xtabs, ffds, freqs, ...} => + ((access_G, xtabs, ffds, freqs), if Graph.is_empty access_G then (trace_msg ctxt (K "Nothing has been learned yet"); []) else if engine = MaSh_Py then @@ -1364,25 +1390,10 @@ end else let - val learns0 = - Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G - val learns = reorder_learns fact_xtab learns0 - - val facts = Vector.fromList (map #1 learns) - val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns) - val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns) - - val tfreq = Vector.tabulate (num_facts, K 0) - val sfreq = Vector.tabulate (num_facts, K Inttab.empty) - val dffreq = Vector.tabulate (num_feats, K 0) - - val freqs' = - MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss - val int_goal_feats = map_filter (Symtab.lookup feat_tab) goal_feats in - MaSh_SML.query_internal ctxt engine num_facts num_feats (facts, featss, depss) freqs' - visible_facts max_suggs goal_feats int_goal_feats + MaSh_SML.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts + max_suggs goal_feats int_goal_feats end end @@ -1447,7 +1458,7 @@ val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t]) in map_state ctxt overlord - (fn state as {access_G, xtabs, dirty_facts} => + (fn state as {access_G, xtabs, ffds, freqs, dirty_facts} => let val parents = maximal_wrt_access_graph access_G facts val deps = used_ths @@ -1459,10 +1470,12 @@ else let val name = learned_proof_name () - val (access_G, xtabs) = + val (access_G', xtabs') = add_node Automatic_Proof name parents feats deps (access_G, xtabs) + + val (ffds', freqs') = recompute_ffd_freqs access_G' xtabs' in - {access_G = access_G, xtabs = xtabs, + {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs', dirty_facts = Option.map (cons name) dirty_facts} end end); @@ -1510,26 +1523,31 @@ isar_dependencies_of name_tabs th fun do_commit [] [] [] state = state - | do_commit learns relearns flops {access_G, xtabs, dirty_facts} = + | do_commit learns relearns flops {access_G, xtabs, ffds, freqs, dirty_facts} = let + val was_empty = Graph.is_empty access_G + + (* TODO: use "fold_map" *) val (learns, (access_G, xtabs)) = fold (learn_wrt_access_graph ctxt) learns ([], (access_G, xtabs)) val (relearns, access_G) = fold (relearn_wrt_access_graph ctxt) relearns ([], access_G) - val was_empty = Graph.is_empty access_G val access_G = access_G |> fold flop_wrt_access_graph flops val dirty_facts = (case (was_empty, dirty_facts) of (false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names) | _ => NONE) + + val (ffds', freqs') = recompute_ffd_freqs access_G xtabs in if engine = MaSh_Py then (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns); MaSh_Py.relearn ctxt overlord save relearns) else (); - {access_G = access_G, xtabs = xtabs, dirty_facts = dirty_facts} + {access_G = access_G, xtabs = xtabs, ffds = ffds', freqs = freqs', + dirty_facts = dirty_facts} end fun commit last learns relearns flops =