--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200
@@ -207,10 +207,15 @@
encode_unweighted_feature names ^
(if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight)
+fun decode_feature s =
+ (case space_explode "=" s of
+ [feat, weight] => (decode_unweighted_feature feat, Real.fromString weight |> the_default 1.0)
+ | _ => (decode_unweighted_feature s, 1.0))
+
val encode_unweighted_features = map encode_unweighted_feature #> space_implode " "
-val decode_unweighted_features = space_explode " " #> map decode_unweighted_feature
val encode_features = map encode_feature #> space_implode " "
+val decode_features = space_explode " " #> map decode_feature
fun str_of_learn (name, parents, feats, deps) =
"! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
@@ -431,9 +436,8 @@
fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
-fun map_array_index ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
+fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
-(* TODO: take weight components of "feats" into consideration *)
fun learn_and_query ctxt parents access_G max_suggs hints feats =
let
val str_of_feat = space_implode "|"
@@ -443,12 +447,12 @@
let
val (_, feats, deps) = Graph.get_node access_G fact
- fun add_feat feat (xtab as (n, tab, _)) =
+ fun add_feat (feat, weight) (xtab as (n, tab, _)) =
(case Symtab.lookup tab feat of
- SOME i => (i, xtab)
- | NONE => (n, add_to_xtab feat xtab))
+ SOME i => ((i, weight), xtab)
+ | NONE => ((n, weight), add_to_xtab feat xtab))
- val (feats', feat_xtab') = fold_map (add_feat o str_of_feat) feats feat_xtab
+ val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab
in
(map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
add_to_xtab fact fact_xtab, feat_xtab')
@@ -463,15 +467,17 @@
val _ =
fold (fn feats => fn fact =>
let val fact' = fact - 1 in
- (List.app (map_array_index facts_ary (cons fact')) feats; fact')
+ List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
+ feats;
+ fact'
end)
featss (length featss)
in
(trace_msg ctxt (fn () =>
"MaSh_SML query " ^ encode_features feats ^ " from {" ^
elide_string 1000 (space_implode " " facts) ^ "}");
- knn (Array.length deps_ary) (curry Array.sub deps_ary)
- (map (rpair 1.0) (* FIXME *) o curry Array.sub facts_ary) knns max_suggs
+ knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
+ max_suggs
(map_filter (fn (feat, weight) =>
Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
|> map ((fn i => Array.sub (fact_ary, i)) o fst))
@@ -515,7 +521,7 @@
string_of_int (length (Graph.maximals G)) ^ " maximal"
type mash_state =
- {access_G : (proof_kind * string list list * string list) Graph.T,
+ {access_G : (proof_kind * (string list * real) list * string list) Graph.T,
num_known_facts : int,
dirty : string list option}
@@ -523,7 +529,7 @@
local
-val version = "*** MaSh version 20140516 ***"
+val version = "*** MaSh version 20140519 ***"
exception FILE_VERSION_TOO_NEW of unit
@@ -532,8 +538,8 @@
[head, tail] =>
(case (space_explode " " head, map (unprefix " ") (space_explode ";" tail)) of
([kind, name], [parents, feats, deps]) =>
- SOME (proof_kind_of_str kind, decode_str name, decode_strs parents,
- decode_unweighted_features feats, decode_strs deps)
+ SOME (proof_kind_of_str kind, decode_str name, decode_strs parents, decode_features feats,
+ decode_strs deps)
| _ => NONE)
| _ => NONE)
@@ -573,7 +579,7 @@
fun str_of_entry (kind, name, parents, feats, deps) =
str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
- encode_unweighted_features feats ^ "; " ^ encode_strs deps ^ "\n"
+ encode_features feats ^ "; " ^ encode_strs deps ^ "\n"
fun save_state _ (state as {dirty = SOME [], ...}) = state
| save_state ctxt {access_G, num_known_facts, dirty} =
@@ -1225,7 +1231,7 @@
val (parents, G) = ([], G) |> fold maybe_learn_from parents
val (deps, _) = ([], G) |> fold maybe_learn_from deps
in
- ((name, parents, feats, deps) :: learns, G)
+ ((name, parents, map fst feats, deps) :: learns, G)
end
fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
@@ -1259,7 +1265,7 @@
launch_thread timeout (fn () =>
let
val thy = Proof_Context.theory_of ctxt
- val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst
+ val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t]
in
peek_state ctxt overlord (fn {access_G, ...} =>
let
@@ -1269,7 +1275,7 @@
|> map nickname_of_thm
in
if Config.get ctxt sml then () (* TODO: implement *)
- else MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]
+ else MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]
end);
(true, "")
end)
@@ -1361,7 +1367,6 @@
val name = nickname_of_thm th
val feats =
features_of ctxt (theory_of_thm th) 0 Symtab.empty stature false [prop_of th]
- |> map fst
val deps = deps_of status th |> these
val n = n |> not (null deps) ? Integer.add 1
val learns = (name, parents, feats, deps) :: learns