# HG changeset patch # User blanchet # Date 1400535833 -7200 # Node ID 121b63d7bcdbe139c7b98d341e648659e4ce0ffe # Parent 8cb6a5f1ae8435712b4b2dfae9acdbb212e09584 take weights into consideration in knn diff -r 8cb6a5f1ae84 -r 121b63d7bcdb src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200 @@ -207,10 +207,15 @@ encode_unweighted_feature names ^ (if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight) +fun decode_feature s = + (case space_explode "=" s of + [feat, weight] => (decode_unweighted_feature feat, Real.fromString weight |> the_default 1.0) + | _ => (decode_unweighted_feature s, 1.0)) + val encode_unweighted_features = map encode_unweighted_feature #> space_implode " " -val decode_unweighted_features = space_explode " " #> map decode_unweighted_feature val encode_features = map encode_feature #> space_implode " " +val decode_features = space_explode " " #> map decode_feature fun str_of_learn (name, parents, feats, deps) = "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^ @@ -431,9 +436,8 @@ fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) -fun map_array_index ary f i = Array.update (ary, i, f (Array.sub (ary, i))) +fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) -(* TODO: take weight components of "feats" into consideration *) fun learn_and_query ctxt parents access_G max_suggs hints feats = let val str_of_feat = space_implode "|" @@ -443,12 +447,12 @@ let val (_, feats, deps) = Graph.get_node access_G fact - fun add_feat feat (xtab as (n, tab, _)) = + fun add_feat (feat, weight) (xtab as (n, tab, _)) = (case Symtab.lookup tab feat of - SOME i => (i, xtab) - | NONE => (n, add_to_xtab feat xtab)) + SOME i => ((i, weight), xtab) + | NONE => ((n, weight), add_to_xtab feat xtab)) - val (feats', feat_xtab') = fold_map (add_feat o str_of_feat) feats feat_xtab + val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab in (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss, add_to_xtab fact fact_xtab, feat_xtab') @@ -463,15 +467,17 @@ val _ = fold (fn feats => fn fact => let val fact' = fact - 1 in - (List.app (map_array_index facts_ary (cons fact')) feats; fact') + List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat) + feats; + fact' end) featss (length featss) in (trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^ elide_string 1000 (space_implode " " facts) ^ "}"); - knn (Array.length deps_ary) (curry Array.sub deps_ary) - (map (rpair 1.0) (* FIXME *) o curry Array.sub facts_ary) knns max_suggs + knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns + max_suggs (map_filter (fn (feat, weight) => Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats) |> map ((fn i => Array.sub (fact_ary, i)) o fst)) @@ -515,7 +521,7 @@ string_of_int (length (Graph.maximals G)) ^ " maximal" type mash_state = - {access_G : (proof_kind * string list list * string list) Graph.T, + {access_G : (proof_kind * (string list * real) list * string list) Graph.T, num_known_facts : int, dirty : string list option} @@ -523,7 +529,7 @@ local -val version = "*** MaSh version 20140516 ***" +val version = "*** MaSh version 20140519 ***" exception FILE_VERSION_TOO_NEW of unit @@ -532,8 +538,8 @@ [head, tail] => (case (space_explode " " head, map (unprefix " ") (space_explode ";" tail)) of ([kind, name], [parents, feats, deps]) => - SOME (proof_kind_of_str kind, decode_str name, decode_strs parents, - decode_unweighted_features feats, decode_strs deps) + SOME (proof_kind_of_str kind, decode_str name, decode_strs parents, decode_features feats, + decode_strs deps) | _ => NONE) | _ => NONE) @@ -573,7 +579,7 @@ fun str_of_entry (kind, name, parents, feats, deps) = str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^ - encode_unweighted_features feats ^ "; " ^ encode_strs deps ^ "\n" + encode_features feats ^ "; " ^ encode_strs deps ^ "\n" fun save_state _ (state as {dirty = SOME [], ...}) = state | save_state ctxt {access_G, num_known_facts, dirty} = @@ -1225,7 +1231,7 @@ val (parents, G) = ([], G) |> fold maybe_learn_from parents val (deps, _) = ([], G) |> fold maybe_learn_from deps in - ((name, parents, feats, deps) :: learns, G) + ((name, parents, map fst feats, deps) :: learns, G) end fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) = @@ -1259,7 +1265,7 @@ launch_thread timeout (fn () => let val thy = Proof_Context.theory_of ctxt - val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst + val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] in peek_state ctxt overlord (fn {access_G, ...} => let @@ -1269,7 +1275,7 @@ |> map nickname_of_thm in if Config.get ctxt sml then () (* TODO: implement *) - else MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)] + else MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)] end); (true, "") end) @@ -1361,7 +1367,6 @@ val name = nickname_of_thm th val feats = features_of ctxt (theory_of_thm th) 0 Symtab.empty stature false [prop_of th] - |> map fst val deps = deps_of status th |> these val n = n |> not (null deps) ? Integer.add 1 val learns = (name, parents, feats, deps) :: learns