--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon Jan 21 21:28:16 2019 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon Jan 21 22:29:41 2019 +0100
@@ -70,7 +70,7 @@
raw_fact list -> fact list * fact list
val mash_unlearn : Proof.context -> unit
- val mash_learn_proof : Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit
+ val mash_learn_proof : Proof.context -> params -> term -> thm list -> unit
val mash_learn_facts : Proof.context -> params -> string -> int -> bool -> Time.time ->
raw_fact list -> string
val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit
@@ -291,7 +291,7 @@
structure MaSh =
struct
-fun select_visible_facts (big_number : real) recommends =
+fun select_fact_idxs (big_number : real) recommends =
List.app (fn at =>
let val (j, ov) = Array.sub (recommends, at) in
Array.update (recommends, at, (j, big_number + ov))
@@ -337,7 +337,7 @@
(Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
end
-fun naive_bayes (tfreq, sfreq, dffreq) num_facts max_suggs visible_facts goal_feats =
+fun naive_bayes (tfreq, sfreq, dffreq) num_facts max_suggs fact_idxs goal_feats =
let
val tau = 0.2 (* FUDGE *)
val pos_weight = 5.0 (* FUDGE *)
@@ -375,14 +375,14 @@
fun ret at acc =
if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc)
in
- select_visible_facts 100000.0 posterior visible_facts;
+ select_fact_idxs 100000.0 posterior fact_idxs;
sort_array_suffix (Real.compare o apply2 snd) max_suggs posterior;
ret (Integer.max 0 (num_facts - max_suggs)) []
end
val initial_k = 0
-fun k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts goal_feats =
+fun k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs fact_idxs goal_feats =
let
exception EXIT of unit
@@ -451,7 +451,7 @@
in
while1 ();
while2 ();
- select_visible_facts 1000000000.0 recommends visible_facts;
+ select_fact_idxs 1000000000.0 recommends fact_idxs;
sort_array_suffix (Real.compare o apply2 snd) max_suggs recommends;
ret [] (Integer.max 0 (num_facts - max_suggs))
end
@@ -502,14 +502,13 @@
| MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats))
fun query_internal ctxt algorithm num_facts num_feats (fact_names, featss, depss)
- (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
+ (freqs as (_, _, dffreq)) fact_idxs max_suggs goal_feats int_goal_feats =
let
fun nb () =
- naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats
+ naive_bayes freqs num_facts max_suggs fact_idxs int_goal_feats
|> map fst
fun knn () =
- k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts
- int_goal_feats
+ k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs fact_idxs int_goal_feats
|> map fst
in
(trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
@@ -652,7 +651,7 @@
local
-val version = "*** MaSh version 20161123 ***"
+val version = "*** MaSh version 20190121 ***"
exception FILE_VERSION_TOO_NEW of unit
@@ -904,7 +903,7 @@
#> union (op =)) S
fun pattify_type 0 _ = []
- | pattify_type depth (Type (s, [])) = if member (op =) bad_types s then [] else [s]
+ | pattify_type _ (Type (s, [])) = if member (op =) bad_types s then [] else [s]
| pattify_type depth (Type (s, U :: Ts)) =
let
val T = Type (s, Ts)
@@ -930,7 +929,7 @@
| add_subtypes T = add_type T
fun pattify_term _ 0 _ = []
- | pattify_term _ depth (Const (s, _)) =
+ | pattify_term _ _ (Const (s, _)) =
if is_widely_irrelevant_const s then [] else [s]
| pattify_term _ _ (Free (s, T)) =
maybe_singleton_str (crude_str_of_typ T)
@@ -1136,64 +1135,6 @@
|> drop (length old_facts)
end
-fun maximal_wrt_graph _ [] = []
- | maximal_wrt_graph G keys =
- if can (Graph.get_node G o the_single) keys then
- keys
- else
- let
- val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
-
- fun insert_new seen name = not (Symtab.defined seen name) ? insert (op =) name
-
- fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
-
- fun find_maxes _ (maxs, []) = map snd maxs
- | find_maxes seen (maxs, new :: news) =
- find_maxes (seen |> num_keys (Graph.imm_succs G new) > 1 ? Symtab.default (new, ()))
- (if Symtab.defined tab new then
- let
- val newp = Graph.all_preds G [new]
- fun is_ancestor x yp = member (op =) yp x
- val maxs = maxs |> filter (fn (_, max) => not (is_ancestor max newp))
- in
- if exists (is_ancestor new o fst) maxs then (maxs, news)
- else ((newp, new) :: filter_out (fn (_, max) => is_ancestor max newp) maxs, news)
- end
- else
- (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news))
- in
- find_maxes Symtab.empty ([],
- G |> Graph.restrict (not o String.isPrefix anonymous_proof_prefix) |> Graph.maximals)
- end
-
-val max_facts_for_shuffle_cleanup = 20
-
-fun maximal_wrt_access_graph _ [] = []
- | maximal_wrt_access_graph access_G (fact :: facts) =
- let
- fun cleanup_wrt (_, th) =
- let val thy_id = Thm.theory_id th in
- filter_out (fn (_, th') =>
- Context.proper_subthy_id (Thm.theory_id th', thy_id))
- end
-
- fun shuffle_cleanup accum [] = accum
- | shuffle_cleanup accum (fact :: facts) =
- let
- val accum' = accum |> cleanup_wrt fact
- val facts' = facts |> cleanup_wrt fact
- in
- shuffle_cleanup accum' facts'
- end
- in
- fact :: cleanup_wrt fact facts
- |> (fn facts => facts
- |> length facts <= max_facts_for_shuffle_cleanup ? shuffle_cleanup [])
- |> map (nickname_of_thm o snd)
- |> maximal_wrt_graph access_G
- end
-
fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
val chained_feature_factor = 0.5 (* FUDGE *)
@@ -1257,8 +1198,7 @@
fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats0)
|> debug ? sort (Real.compare o swap o apply2 snd)
- val parents = maximal_wrt_access_graph access_G facts
- val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
+ val fact_idxs = map_filter (Symtab.lookup fact_tab o nickname_of_thm o snd) facts
val suggs =
if algorithm = MaSh_NB_Ext orelse algorithm = MaSh_kNN_Ext then
@@ -1274,8 +1214,8 @@
val int_goal_feats =
map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
in
- MaSh.query_internal ctxt algorithm num_facts num_feats ffds freqs visible_facts
- max_suggs goal_feats int_goal_feats
+ MaSh.query_internal ctxt algorithm num_facts num_feats ffds freqs fact_idxs max_suggs
+ goal_feats int_goal_feats
end
val unknown = filter_out (is_fact_in_graph access_G o snd) facts
@@ -1336,25 +1276,24 @@
Date.fmt (anonymous_proof_prefix ^ "%Y%m%d.%H%M%S.") (Date.fromTimeLocal (Time.now ())) ^
serial_string ()
-fun mash_learn_proof ctxt ({timeout, ...} : params) t facts used_ths =
+fun mash_learn_proof ctxt ({timeout, ...} : params) t used_ths =
if not (null used_ths) andalso is_mash_enabled () then
launch_thread timeout (fn () =>
let
val thy = Proof_Context.theory_of ctxt
val feats = features_of ctxt (Context.theory_name thy) (Local, General) [t]
- val facts = rev_sort_list_prefix (crude_thm_ord ctxt o apply2 snd) 1 facts
in
map_state ctxt
(fn {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} =>
let
- val parents = maximal_wrt_access_graph access_G facts
val deps = used_ths
|> filter (is_fact_in_graph access_G)
|> map nickname_of_thm
val name = anonymous_proof_name ()
val (access_G', xtabs', rev_learns) =
- add_node Automatic_Proof name parents feats deps (access_G, xtabs, [])
+ add_node Automatic_Proof name [] (* ignore parents *) feats deps
+ (access_G, xtabs, [])
val (ffds', freqs') =
recompute_ffds_freqs_from_learns (rev rev_learns) xtabs' num_facts0 ffds freqs
@@ -1475,7 +1414,7 @@
let
val new_facts = facts
|> sort (crude_thm_ord ctxt o apply2 snd)
- |> attach_parents_to_facts []
+ |> map (pair []) (* ignore parents *)
|> filter_out (is_in_access_G o snd)
val (learns, (num_nontrivial, _, _)) =
([], (0, next_commit_time (), false))
@@ -1582,7 +1521,7 @@
(* Generate more suggestions than requested, because some might be thrown out later for various
reasons (e.g., duplicates). *)
-fun generous_max_suggestions max_facts = 3 * max_facts div 2 + 25
+fun generous_max_suggestions max_facts = 2 * max_facts + 25 (* FUDGE *)
val mepo_weight = 0.5 (* FUDGE *)
val mash_weight = 0.5 (* FUDGE *)