--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:39 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:46 2014 +0200
@@ -46,30 +46,6 @@
val is_mash_enabled : unit -> bool
val the_mash_engine : unit -> mash_engine
- structure MaSh_Py :
- sig
- val unlearn : Proof.context -> bool -> unit
- val learn : Proof.context -> bool -> bool ->
- (string * string list * string list * string list) list -> unit
- val relearn : Proof.context -> bool -> bool -> (string * string list) list -> unit
- val query : Proof.context -> bool -> int ->
- (string * string list * string list * string list) list * string list * string list
- * (string * real) list ->
- (string * real) list
- end
-
- structure MaSh_SML :
- sig
- val k_nearest_neighbors : int -> (int -> int list) -> (int -> int list) -> int -> int list ->
- int -> int list -> (int * real) list
- val naive_bayes : int -> (int -> int list) -> (int -> int list) -> int -> int -> int list ->
- int list -> (int * real) list
- val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
- int -> int -> int list -> (int * real) list
- val query : Proof.context -> bool -> mash_engine -> string list -> int ->
- (string * string list * string list) list * string list * string list -> string list
- end
-
val mash_unlearn : Proof.context -> params -> unit
val nickname_of_thm : thm -> string
val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
@@ -492,7 +468,7 @@
val nb_def_prior_weight = 21 (* FUDGE *)
-fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
+fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats =
let
fun learn_fact th feats deps =
let
@@ -525,7 +501,7 @@
val sfreq = Array.array (num_facts, Inttab.empty)
val dffreq = Array.array (num_feats, 0)
in
- learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
+ learn_facts tfreq sfreq dffreq num_facts get_deps get_feats
end
fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) =
@@ -574,7 +550,7 @@
|> naive_bayes_query num_facts max_suggs visible_facts feats
(* experimental *)
-fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats =
+fun naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs feats =
let
fun name_of_fact j = "f" ^ string_of_int j
fun fact_of_name s = the (Int.fromString (unprefix "f" s))
@@ -631,66 +607,54 @@
c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
-val empty_xtab = (0, Symtab.empty)
-
-fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
-fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
+fun query ctxt overlord engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs
+ learns conj_feats = (* the two pair arguments are (size, name -> index table); dispatches on engine *)
+ if engine = MaSh_SML_kNN_Cpp then
+ k_nearest_neighbors_cpp max_suggs learns conj_feats
+ else if engine = MaSh_SML_NB_Cpp then
+ naive_bayes_cpp max_suggs learns conj_feats
+ else (* remaining engines work on integer indices taken from fact_tab/feat_tab *)
+ let
+ val facts = map #1 learns
+ val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
+ val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
-fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
- let
- val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
- in
- if engine = MaSh_SML_kNN_Cpp then
- k_nearest_neighbors_cpp max_suggs learns feats
- else if engine = MaSh_SML_NB_Cpp then
- naive_bayes_cpp max_suggs learns feats
- else
- let
- val facts = map #1 learns
- val fact_vec = Vector.fromList facts
+ val fact_vec = Vector.fromList facts
+ val deps_vec = Vector.fromList depss
- val fact_xtab as (num_facts, fact_tab) = fold add_to_xtab facts empty_xtab
- val feat_xtab as (num_feats, feat_tab) = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
-
- val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
-
- val deps_vec = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
-
- val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
-
- val get_deps = curry Vector.sub deps_vec
+ val get_deps = curry Vector.sub deps_vec
- val int_feats = map_filter (Symtab.lookup feat_tab) feats
- in
- trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs feats ^ " from {" ^
- elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
- (if engine = MaSh_SML_kNN then
- let
- val facts_ary = Array.array (num_feats, [])
- val _ =
- fold (fn feats => fn fact =>
- (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
- featss 0
- val get_facts = curry Array.sub facts_ary
- in
- k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
- int_feats
- end
- else
- let
- val unweighted_feats_ary = Vector.fromList featss
- val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
- in
- (case engine of
- MaSh_SML_NB =>
- naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs
- int_visible_facts int_feats
- | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps
- get_unweighted_feats num_feats max_suggs int_feats)
- end)
- |> map (curry Vector.sub fact_vec o fst)
- end
- end
+ val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
+ val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
+ in
+ trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^
+ elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
+ (if engine = MaSh_SML_kNN then
+ let
+ val facts_ary = Array.array (num_feats, []) (* per feature index: facts having that feature (assuming map_array_at updates in place) *)
+ val _ =
+ fold (fn feats => fn fact =>
+ (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
+ featss 0
+ val get_facts = curry Array.sub facts_ary
+ in
+ k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
+ int_conj_feats
+ end
+ else
+ let
+ val feats_ary = Vector.fromList featss
+ val get_feats = curry Vector.sub feats_ary
+ in
+ (case engine of
+ MaSh_SML_NB =>
+ naive_bayes num_facts get_deps get_feats num_feats max_suggs int_visible_facts
+ int_conj_feats
+ | MaSh_SML_NB_Py =>
+ naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs int_conj_feats)
+ end)
+ |> map (curry Vector.sub fact_vec o fst) (* translate resulting fact indices back to fact names *)
+ end
end;
@@ -1328,6 +1292,11 @@
fun add_const_counts t =
fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
+val empty_xtab = (0, Symtab.empty) (* (next free index, table from key to index) *)
+(* add_to_xtab gives key the index "next"; maybe_add_to_xtab keeps the xtab as is if that raises *)
+fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
+fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
+
fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
let
val thy = Proof_Context.theory_of ctxt
@@ -1395,12 +1364,18 @@
[]
else
let
- val (parents, hints, feats) = query_args access_G
+ val (parents, hints, feats0) = query_args access_G
+ val feats = map fst feats0
val visible_facts = Graph.all_preds access_G parents
val learns =
- Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+ Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G @
+ (if null hints then [] else [(".hints", feats, hints)])
+
+ val fact_xtab = fold (add_to_xtab o #1) learns empty_xtab
+ val feat_xtab = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
in
- MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, map fst feats)
+ MaSh_SML.query ctxt overlord engine fact_xtab feat_xtab visible_facts max_facts learns
+ feats
end
val unknown = filter_out (is_fact_in_graph access_G o snd) facts