--- a/src/HOL/TPTP/mash_eval.ML Fri May 30 12:27:51 2014 +0200
+++ b/src/HOL/TPTP/mash_eval.ML Fri May 30 12:27:51 2014 +0200
@@ -97,7 +97,7 @@
mesh_isar_line), mesh_prover_line)) =
if in_range range j then
let
- val get_suggs = extract_suggestions ##> take max_suggs
+ val get_suggs = extract_suggestions ##> (take max_suggs #> map fst)
val (name1, mepo_suggs) = get_suggs mepo_line
val (name2, mash_isar_suggs) = get_suggs mash_isar_line
val (name3, mash_prover_suggs) = get_suggs mash_prover_line
--- a/src/HOL/TPTP/mash_export.ML Fri May 30 12:27:51 2014 +0200
+++ b/src/HOL/TPTP/mash_export.ML Fri May 30 12:27:51 2014 +0200
@@ -285,10 +285,10 @@
let
val (name, mash_suggs) =
extract_suggestions mash_line
- ||> weight_facts_steeply
+ ||> (map fst #> weight_facts_steeply)
val (name', mepo_suggs) =
extract_suggestions mepo_line
- ||> weight_facts_steeply
+ ||> (map fst #> weight_facts_steeply)
val _ = if name = name' then () else error "Input files out of sync."
val mess =
[(mepo_weight, (mepo_suggs, [])),
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri May 30 12:27:51 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri May 30 12:27:51 2014 +0200
@@ -32,9 +32,9 @@
val decode_str : string -> string
val decode_strs : string -> string list
val encode_features : (string * real) list -> string
- val extract_suggestions : string -> string * string list
+ val extract_suggestions : string -> string * (string * real) list
- datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB
+ datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py
val is_mash_enabled : unit -> bool
val the_mash_engine : unit -> mash_engine
@@ -48,16 +48,18 @@
val query : Proof.context -> bool -> int ->
(string * string list * string list * string list) list * string list * string list
* (string * real) list ->
- string list
+ (string * real) list
end
structure MaSh_SML :
sig
val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) ->
int -> (int * real) list -> (int * real) list
- val naive_bayes : int -> int -> (int -> int list) -> (int -> Inttab.key list) -> int -> int ->
- (Inttab.key * real) list -> (int * real) list
- val query : Proof.context -> mash_engine -> string list -> int ->
+ val naive_bayes : int -> int -> (int -> int list) -> (int -> int list) -> int -> int ->
+ (int * real) list -> (int * real) list
+ val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) ->
+ (int -> int list) -> int -> int -> (int * real) list -> (int * real) list
+ val query : Proof.context -> bool -> mash_engine -> string list -> int ->
(string * (string * real) list * string list) list * string list * (string * real) list ->
string list
end
@@ -144,7 +146,7 @@
()
end
-datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB
+datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py
fun mash_engine () =
let val flag1 = Options.default_string @{system_option MaSh} in
@@ -154,6 +156,7 @@
| "sml" => SOME MaSh_SML_NB
| "sml_knn" => SOME MaSh_SML_kNN
| "sml_nb" => SOME MaSh_SML_NB
+ | "sml_nb_py" => SOME MaSh_SML_NB_Py
| _ => NONE)
end
@@ -267,8 +270,8 @@
(* The suggested weights do not make much sense. *)
fun extract_suggestion sugg =
(case space_explode "=" sugg of
- [name, _ (* weight *)] => SOME (decode_str name)
- | [name] => SOME (decode_str name)
+ [name, weight] => SOME (decode_str name, Real.fromString weight |> the_default 1.0)
+ | [name] => SOME (decode_str name, 1.0) (* no weight: 1.0, as for an unparsable weight *)
| _ => NONE)
fun extract_suggestions line =
@@ -458,7 +461,7 @@
(* TODO: Either use IDF or don't use it. See commented out code portions below. *)
-fun naive_bayes_learn num_facts get_deps get_th_feats num_feats =
+fun naive_bayes_learn num_facts get_deps get_feats num_feats =
let
val tfreq = Array.array (num_facts, 0)
val sfreq = Array.array (num_facts, Inttab.empty)
@@ -483,7 +486,7 @@
end
fun for i =
- if i = num_facts then () else (learn i (get_th_feats i) (get_deps i); for (i + 1))
+ if i = num_facts then () else (learn i (get_feats i) (get_deps i); for (i + 1))
in
for 0; (Array.vector tfreq, Array.vector sfreq (*, Array.vector dffreq *))
end
@@ -536,15 +539,34 @@
ret [] (Integer.max 0 (num_visible_facts - max_suggs))
end
-fun naive_bayes num_facts num_visible_facts get_deps get_th_feats num_feats max_suggs feats =
- naive_bayes_learn num_facts get_deps get_th_feats num_feats
+fun naive_bayes num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
+ naive_bayes_learn num_facts get_deps get_feats num_feats
|> naive_bayes_query num_facts num_visible_facts max_suggs feats
+(* Experimental: takes the same arguments as naive_bayes, preceded by ctxt and overlord, but
+   replays the facts to MaSh_Py under synthetic names and lets MaSh_Py answer the query.
+   Fact j is sent as "f<j>" and feature j as "F<j>". The only parent of fact j is fact j - 1;
+   fact 0 has no parent. The query is parented on fact num_visible_facts - 1 (no parent if
+   num_visible_facts = 0) and carries no hints. MaSh_Py.unlearn is called before each query
+   (presumably to reset MaSh_Py's state -- confirm). num_feats is not used. Raises an
+   exception if a name returned by MaSh_Py.query is not "f" followed by an integer. *)
+fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs
+    feats =
+  let
+    fun fact_name i = "f" ^ string_of_int i
+    fun feat_name i = "F" ^ string_of_int i
+    fun fact_index name = the (Int.fromString (unprefix "f" name))
+
+    fun parents_of 0 = []
+      | parents_of i = [fact_name (i - 1)]
+
+    (* one learning record per fact: (name, parents, features, dependencies) *)
+    fun learn_entry i =
+      (fact_name i, parents_of i, map feat_name (get_feats i), map fact_name (get_deps i))
+
+    val learns = map learn_entry (0 upto num_facts - 1)
+    val query_parents = parents_of num_visible_facts
+    val query_feats = map (fn (feat, weight) => (feat_name feat, weight)) feats
+    val _ = MaSh_Py.unlearn ctxt overlord
+    val suggs = MaSh_Py.query ctxt overlord max_suggs (learns, [], query_parents, query_feats)
+  in
+    (* translate the suggested fact names back to fact indices, keeping the weights *)
+    map (fn (name, weight) => (fact_index name, weight)) suggs
+  end
+
fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
-fun query ctxt engine visible_facts max_suggs (learns, hints, feats) =
+fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) =
let
val visible_fact_set = Symtab.make_set visible_facts
@@ -602,8 +624,8 @@
val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
in
- naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs
- feats'
+ (if engine = MaSh_SML_NB then naive_bayes else naive_bayes_py ctxt overlord)
+ num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats'
end)
|> map (curry Vector.sub fact_vec o fst)
end
@@ -1301,6 +1323,7 @@
if engine = MaSh_Py then
let val (parents, hints, feats) = query_args access_G in
MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+ |> map fst
end
else
[]))
@@ -1315,7 +1338,7 @@
val learns =
Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
in
- MaSh_SML.query ctxt engine visible_facts max_facts (learns, hints, feats)
+ MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, feats)
end
val unknown = filter_out (is_fact_in_graph access_G o snd) facts