take weights into consideration in knn
author blanchet
Mon, 19 May 2014 23:43:53 +0200
changeset 57010 121b63d7bcdb
parent 57009 8cb6a5f1ae84
child 57011 a4428f517f46
take weights into consideration in knn
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
@@ -207,10 +207,15 @@
   encode_unweighted_feature names ^
   (if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight)
 
+fun decode_feature s =
+  (case space_explode "=" s of
+    [feat, weight] => (decode_unweighted_feature feat, Real.fromString weight |> the_default 1.0)
+  | _ => (decode_unweighted_feature s, 1.0))
+
 val encode_unweighted_features = map encode_unweighted_feature #> space_implode " "
-val decode_unweighted_features = space_explode " " #> map decode_unweighted_feature
 
 val encode_features = map encode_feature #> space_implode " "
+val decode_features = space_explode " " #> map decode_feature
 
 fun str_of_learn (name, parents, feats, deps) =
   "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
@@ -431,9 +436,8 @@
 
 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
 
-fun map_array_index ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
+fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
 
-(* TODO: take weight components of "feats" into consideration *)
 fun learn_and_query ctxt parents access_G max_suggs hints feats =
   let
     val str_of_feat = space_implode "|"
@@ -443,12 +447,12 @@
           let
             val (_, feats, deps) = Graph.get_node access_G fact
 
-            fun add_feat feat (xtab as (n, tab, _)) =
+            fun add_feat (feat, weight) (xtab as (n, tab, _)) =
               (case Symtab.lookup tab feat of
-                SOME i => (i, xtab)
-              | NONE => (n, add_to_xtab feat xtab))
+                SOME i => ((i, weight), xtab)
+              | NONE => ((n, weight), add_to_xtab feat xtab))
 
-            val (feats', feat_xtab') = fold_map (add_feat o str_of_feat) feats feat_xtab
+            val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab
           in
             (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
              add_to_xtab fact fact_xtab, feat_xtab')
@@ -463,15 +467,17 @@
     val _ =
       fold (fn feats => fn fact =>
           let val fact' = fact - 1 in
-            (List.app (map_array_index facts_ary (cons fact')) feats; fact')
+            List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
+              feats;
+            fact'
           end)
         featss (length featss)
   in
     (trace_msg ctxt (fn () =>
        "MaSh_SML query " ^ encode_features feats ^ " from {" ^
         elide_string 1000 (space_implode " " facts) ^ "}");
-     knn (Array.length deps_ary) (curry Array.sub deps_ary)
-       (map (rpair 1.0) (* FIXME *) o curry Array.sub facts_ary) knns max_suggs
+     knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
+       max_suggs
        (map_filter (fn (feat, weight) =>
           Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
      |> map ((fn i => Array.sub (fact_ary, i)) o fst))
@@ -515,7 +521,7 @@
   string_of_int (length (Graph.maximals G)) ^ " maximal"
 
 type mash_state =
-  {access_G : (proof_kind * string list list * string list) Graph.T,
+  {access_G : (proof_kind * (string list * real) list * string list) Graph.T,
    num_known_facts : int,
    dirty : string list option}
 
@@ -523,7 +529,7 @@
 
 local
 
-val version = "*** MaSh version 20140516 ***"
+val version = "*** MaSh version 20140519 ***"
 
 exception FILE_VERSION_TOO_NEW of unit
 
@@ -532,8 +538,8 @@
     [head, tail] =>
     (case (space_explode " " head, map (unprefix " ") (space_explode ";" tail)) of
       ([kind, name], [parents, feats, deps]) =>
-      SOME (proof_kind_of_str kind, decode_str name, decode_strs parents,
-        decode_unweighted_features feats, decode_strs deps)
+      SOME (proof_kind_of_str kind, decode_str name, decode_strs parents, decode_features feats,
+        decode_strs deps)
     | _ => NONE)
   | _ => NONE)
 
@@ -573,7 +579,7 @@
 
 fun str_of_entry (kind, name, parents, feats, deps) =
   str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
-  encode_unweighted_features feats ^ "; " ^ encode_strs deps ^ "\n"
+  encode_features feats ^ "; " ^ encode_strs deps ^ "\n"
 
 fun save_state _ (state as {dirty = SOME [], ...}) = state
   | save_state ctxt {access_G, num_known_facts, dirty} =
@@ -1225,7 +1231,7 @@
     val (parents, G) = ([], G) |> fold maybe_learn_from parents
     val (deps, _) = ([], G) |> fold maybe_learn_from deps
   in
-    ((name, parents, feats, deps) :: learns, G)
+    ((name, parents, map fst feats, deps) :: learns, G)
   end
 
 fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
@@ -1259,7 +1265,7 @@
     launch_thread timeout (fn () =>
       let
         val thy = Proof_Context.theory_of ctxt
-        val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst
+        val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t]
       in
         peek_state ctxt overlord (fn {access_G, ...} =>
           let
@@ -1269,7 +1275,7 @@
                        |> map nickname_of_thm
           in
             if Config.get ctxt sml then () (* TODO: implement *)
-            else MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]
+            else MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]
           end);
         (true, "")
       end)
@@ -1361,7 +1367,6 @@
               val name = nickname_of_thm th
               val feats =
                 features_of ctxt (theory_of_thm th) 0 Symtab.empty stature false [prop_of th]
-                |> map fst
               val deps = deps_of status th |> these
               val n = n |> not (null deps) ? Integer.add 1
               val learns = (name, parents, feats, deps) :: learns