--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:36:22 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:36:25 2014 +0200
@@ -398,7 +398,7 @@
exception EXIT of unit
-fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts conj_feats =
+fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats =
let
val ln_afreq = Math.ln (Real.fromInt num_facts)
fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
@@ -419,7 +419,7 @@
List.app do_th (Array.sub (feat_facts, s))
end
- val _ = List.app do_feat conj_feats
+ val _ = List.app do_feat goal_feats
val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
val no_recommends = Unsynchronized.ref 0
val recommends = Array.tabulate (num_facts, rpair 0.0)
@@ -512,7 +512,7 @@
(Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
end
-fun naive_bayes tfreq sfreq dffreq num_facts max_suggs visible_facts conj_feats =
+fun naive_bayes (tfreq, sfreq, dffreq) num_facts max_suggs visible_facts goal_feats =
let
val tau = 0.05 (* FUDGE *)
val pos_weight = 10.0 (* FUDGE *)
@@ -534,7 +534,7 @@
Inttab.delete f sfh)
| NONE => (res + tfidf f * def_val, sfh))
- val (res, sfh) = fold fold_feats conj_feats (Math.ln tfreq, Vector.sub (sfreq, i))
+ val (res, sfh) = fold fold_feats goal_feats (Math.ln tfreq, Vector.sub (sfreq, i))
fun fold_sfh (f, sf) sow = sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)
@@ -554,7 +554,7 @@
end
(* experimental *)
-fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs conj_feats =
+fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs goal_feats =
let
fun name_of_fact j = "f" ^ string_of_int j
fun fact_of_name s = the (Int.fromString (unprefix "f" s))
@@ -565,16 +565,16 @@
map name_of_feature (Vector.sub (featss, j)),
map name_of_fact (Vector.sub (depss, j)))) (0 upto num_facts - 1)
val parents' = parents_of num_facts
- val conj_feats' = map (rpair 1.0 o name_of_feature) conj_feats
+ val goal_feats' = map (rpair 1.0 o name_of_feature) goal_feats
in
MaSh_Py.unlearn ctxt overlord;
OS.Process.sleep (seconds 2.0); (* hack *)
- MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', conj_feats')
+ MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', goal_feats')
|> map (apfst fact_of_name)
end
(* experimental *)
-fun experimental_external_tool tool max_suggs learns cfeats =
+fun external_tool tool max_suggs learns goal_feats =
let
val ser = string_of_int (serial ()) (* poor person's attempt at thread-safety *)
val ocs = TextIO.openOut ("adv_syms" ^ ser)
@@ -603,67 +603,38 @@
|> filter_out (curry (op =) "")
end
in
- (List.app do_learn learns; ol occ (os occ o quote) ", " cfeats;
+ (List.app do_learn learns; ol occ (os occ o quote) ", " goal_feats;
TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
forkexec max_suggs)
end
val k_nearest_neighbors_ext =
- experimental_external_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
-val naive_bayes_ext = experimental_external_tool "predict/nbayes"
-
-fun reorder_learns (num_facts, fact_tab) learns0 =
- let
- val learns = Array.array (num_facts, ("", [], []))
- in
- List.app (fn learn as (fact, _, _) =>
- Array.update (learns, the (Symtab.lookup fact_tab fact), learn))
- learns0;
- Array.foldr (op ::) [] learns
- end
+ external_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
+val naive_bayes_ext = external_tool "predict/nbayes"
-fun query ctxt engine (fact_xtab as (num_facts, fact_tab)) (num_feats, feat_tab) visible_facts
- max_suggs learns0 conj_feats =
- if engine = MaSh_SML_kNN_Ext then
- k_nearest_neighbors_ext max_suggs learns0 conj_feats
- else if engine = MaSh_SML_NB_Ext then
- naive_bayes_ext max_suggs learns0 conj_feats
- else
- let
- val learns = reorder_learns fact_xtab learns0
-
- val facts = Vector.fromList (map #1 learns)
- val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
- val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
+fun query_external ctxt engine max_suggs learns goal_feats = (* answer a query via one of the external-tool engines *)
+ (trace_msg ctxt (fn () => "MaSh_SML query external " ^ encode_strs goal_feats); (* trace the goal features before dispatching *)
+ (case engine of (* only the two external-engine constructors are matched; there is no branch for any other engine *)
+ MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
+ | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
- val tfreq = Vector.tabulate (num_facts, K 0)
- val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
- val dffreq = Vector.tabulate (num_feats, K 0)
-
- val (tfreq, sfreq, dffreq) =
- learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
-
- val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
- val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
- in
- trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^
- elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
- (case engine of
- MaSh_SML_kNN =>
- let
- val feat_facts = Array.array (num_feats, [])
- val _ =
- Vector.foldl (fn (feats, fact) =>
- (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1))
- 0 featss
- in
- k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs int_visible_facts
- int_conj_feats
- end
- | MaSh_SML_NB =>
- naive_bayes tfreq sfreq dffreq num_facts max_suggs int_visible_facts int_conj_feats)
- |> map (curry Vector.sub facts o fst)
- end
+fun query_internal ctxt engine num_facts num_feats (facts, featss, depss) (freqs as (_, _, dffreq)) (* answer a query via one of the in-process engines *)
+ visible_facts max_suggs goal_feats int_goal_feats = (* goal_feats is used only in the trace message; the engines receive int_goal_feats *)
+ (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_strs goal_feats ^ " from {" ^
+ elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
+ (case engine of (* only MaSh_SML_kNN and MaSh_SML_NB are matched; there is no branch for any other engine *)
+ MaSh_SML_kNN =>
+ let
+ val feat_facts = Array.array (num_feats, []) (* one initially empty list per feature index *)
+ val _ =
+ Vector.foldl (fn (feats, fact) => (* fact is a running index over featss, starting at 0 and incremented per entry *)
+ (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1)) (* presumably conses fact onto each of its features' slots; map_array_at is defined outside this view -- TODO confirm *)
+ 0 featss
+ in
+ k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats (* kNN uses only the dffreq component of freqs *)
+ end
+ | MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats) (* naive Bayes receives the whole freqs triple *)
+ |> map (curry Vector.sub facts o fst)) (* look up the first component of each suggestion in the facts vector *)
end;
@@ -1312,7 +1283,17 @@
fun add_const_counts t =
fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
-fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
+fun reorder_learns (num_facts, fact_tab) learns0 = (* rearrange the learn triples of learns0 into the index order given by fact_tab *)
+ let
+ val learns = Array.array (num_facts, ("", [], [])) (* every slot starts as the placeholder ("", [], []) *)
+ in
+ List.app (fn learn as (fact, _, _) =>
+ Array.update (learns, the (Symtab.lookup fact_tab fact), learn)) (* NOTE(review): "the" assumes every fact of learns0 has an entry in fact_tab, with an index below num_facts -- confirm for the hints entry added by the caller *)
+ learns0;
+ Array.foldr (op ::) [] learns (* return the slots as a list in index order; slots never written keep the placeholder *)
+ end
+
+fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts =
let
val thy = Proof_Context.theory_of ctxt
val thy_name = Context.theory_name thy
@@ -1360,14 +1341,14 @@
(parents, hints, feats)
end
- val ((access_G, (fact_xtab, feat_xtab)), py_suggs) =
+ val ((access_G, (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab))), py_suggs) =
peek_state ctxt overlord (fn {access_G, xtabs, ...} =>
((access_G, xtabs),
if Graph.is_empty access_G then
(trace_msg ctxt (K "Nothing has been learned yet"); [])
else if engine = MaSh_Py then
let val (parents, hints, feats) = query_args access_G in
- MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+ MaSh_Py.query ctxt overlord max_suggs ([], hints, parents, feats)
|> map fst
end
else
@@ -1378,19 +1359,46 @@
[]
else
let
- val (parents, hints, feats0) = query_args access_G
- val feats = map fst feats0
- val visible_facts = Graph.all_preds access_G parents
- val learns =
- (if null hints then [] else [(hintsN, feats, hints)]) @ (* ### FIXME *)
- Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+ val (parents, hints, goal_feats0) = query_args access_G
+ val goal_feats = map fst goal_feats0
+ val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
in
- MaSh_SML.query ctxt engine fact_xtab feat_xtab visible_facts max_facts learns feats
+ if engine = MaSh_SML_kNN_Ext orelse engine = MaSh_SML_NB_Ext then
+ let
+ val learns =
+ (if null hints then [] else [(hintsN, goal_feats, hints)]) @ (* ### FIXME *)
+ Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+ in
+ MaSh_SML.query_external ctxt engine max_suggs learns goal_feats
+ end
+ else
+ let
+ val learns0 =
+ (if null hints then [] else [(hintsN, goal_feats, hints)]) @ (* ### FIXME *)
+ Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+ val learns = reorder_learns fact_xtab learns0
+
+ val facts = Vector.fromList (map #1 learns)
+ val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
+ val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
+
+ val tfreq = Vector.tabulate (num_facts, K 0)
+ val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
+ val dffreq = Vector.tabulate (num_feats, K 0)
+
+ val freqs' =
+ MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
+
+ val int_goal_feats = map_filter (Symtab.lookup feat_tab) goal_feats
+ in
+ MaSh_SML.query_internal ctxt engine num_facts num_feats (facts, featss, depss) freqs'
+ visible_facts max_suggs goal_feats int_goal_feats
+ end
end
val unknown = filter_out (is_fact_in_graph access_G o snd) facts
in
- find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown
+ find_mash_suggestions ctxt max_suggs (py_suggs @ sml_suggs) facts chained unknown
|> pairself (map fact_of_raw_fact)
end