new version of adaptive k-NN with TFIDF
author blanchet
Thu, 26 Jun 2014 13:34:28 +0200
changeset 57357 30ee18eb23ac
parent 57356 9816f692b0ca
child 57358 545d02691b32
new version of adaptive k-NN with TFIDF
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:33:50 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:34:28 2014 +0200
@@ -60,15 +60,14 @@
 
   structure MaSh_SML :
   sig
-    val k_nearest_neighbors : int -> (int -> int list) -> (int -> (int * real) list) -> int ->
-      int list -> (int * real) list -> (int * real) list
+    val k_nearest_neighbors : int -> (int -> int list) -> (int -> int list) -> int -> int list ->
+      int -> int list -> (int * real) list
     val naive_bayes : (bool * bool) -> int -> (int -> int list) -> (int -> int list) -> int ->
-      int -> (int * real) list -> (int * real) list
+      int -> int list -> (int * real) list
     val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
-      int -> int -> (int * real) list -> (int * real) list
+      int -> int -> int list -> (int * real) list
     val query : Proof.context -> bool -> mash_engine -> string list -> int ->
-      (string * (string * real) list * string list) list * string list * (string * real) list ->
-      string list
+      (string * string list * string list) list * string list * string list -> string list
   end
 
   val mash_unlearn : Proof.context -> params -> unit
@@ -428,10 +427,18 @@
   max_suggs = number of suggestions to return
   feats = features of the goal
 *)
-fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts feats =
+fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts num_feats feats =
   let
-    (* Can be later used for TFIDF *)
-    fun sym_wght _ = 1.0
+    val dffreq = Array.array (num_feats, 0)
+
+    fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
+    fun for1 i =
+      if i = num_feats then () else
+      (List.app (fn _ => add_sym i) (get_sym_ths i); for1 (i + 1))
+    val _ = for1 0
+
+    val ln_afreq = Math.ln (Real.fromInt num_facts)
+    fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Array.sub (dffreq, feat))) handle Subscript => ln_afreq
 
     val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
 
@@ -442,12 +449,11 @@
         Array.update (overlaps_sqr, j, (j, v + ov))
       end
 
-    fun do_feat (s, con_wght) =
+    fun do_feat s =
       let
-        val sw = sym_wght s
-        val w2 = sw * sw * con_wght
-
-        fun do_th (j, prem_wght) = if j < num_facts then inc_overlap j (w2 * prem_wght) else ()
+        val sw = tfidf s
+        val w2 = sw * sw
+        fun do_th j = if j < num_facts then inc_overlap j w2 else ()
       in
         List.app do_th (get_sym_ths s)
       end
@@ -460,11 +466,8 @@
 
     fun inc_recommend j v =
       let val ov = snd (Array.sub (recommends, j)) in
-        if ov <= 0.0 then
-          (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
-        else
-          (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ())
-      end
+      if ov <= 0.0 then (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
+      else (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ()) end
 
     val k = Unsynchronized.ref 0
     fun do_k k =
@@ -482,12 +485,13 @@
         end
 
     fun while1 () =
-      if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ())
+      (if !k = number_of_nearest_neighbors then () else
+      (do_k (!k); k := !k + 1; while1 ()))
       handle EXIT () => ()
 
     fun while2 () =
-      if !no_recommends >= max_suggs then ()
-      else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ())
+      (if !no_recommends >= max_suggs then () else
+      (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ()))
       handle EXIT () => ()
 
     fun ret acc at =
@@ -553,7 +557,7 @@
       let
         val tfreq = Real.fromInt (Vector.sub (tfreq, i))
 
-        fun fold_feats (f, _) (res, sfh) =
+        fun fold_feats f (res, sfh) =
           (case Inttab.lookup sfh f of
             SOME sf =>
             (res + tfidf f * Math.ln (pos_weight * Real.fromInt sf / tfreq),
@@ -598,7 +602,7 @@
     val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
       map name_of_fact (get_deps j))) (0 upto num_facts - 1)
     val parents' = parents_of num_facts
-    val feats' = map (apfst name_of_feature) feats
+    val feats' = map (rpair 1.0 o name_of_feature) feats
   in
     MaSh_Py.unlearn ctxt overlord;
     OS.Process.sleep (seconds 2.0); (* hack *)
@@ -622,8 +626,7 @@
       | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
 
     fun do_learn (name, feats, deps) =
-      (os ocs name; os ocs ":";
-       ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n";
+      (os ocs name; os ocs ":"; ol ocs (os ocs o quote) ", " feats; os ocs "\n";
        os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n"; os ocq name; os ocq "\n")
 
     fun forkexec no =
@@ -637,7 +640,7 @@
         |> filter_out (curry (op =) "")
       end
   in
-    (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats;
+    (List.app do_learn learns; ol occ (os occ o quote) ", " cfeats;
      TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
      forkexec max_suggs)
   end
@@ -655,19 +658,19 @@
     val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
   in
     if engine = MaSh_SML_kNN_Cpp then
-      k_nearest_neighbors_cpp max_suggs learns (map fst feats)
+      k_nearest_neighbors_cpp max_suggs learns feats
     else if engine = MaSh_SML_NB_Cpp then
-      naive_bayes_cpp max_suggs learns (map fst feats)
+      naive_bayes_cpp max_suggs learns feats
     else
       let
         val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) =
           fold (fn (fact, feats, deps) =>
                 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
               let
-                fun add_feat (feat, weight) (xtab as (n, tab, _)) =
+                fun add_feat feat (xtab as (n, tab, _)) =
                   (case Symtab.lookup tab feat of
-                    SOME i => ((i, weight), xtab)
-                  | NONE => ((n, weight), add_to_xtab feat xtab))
+                    SOME i => (i, xtab)
+                  | NONE => (n, add_to_xtab feat xtab))
 
                 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
               in
@@ -678,14 +681,15 @@
 
         val facts = rev rev_facts
         val fact_vec = Vector.fromList facts
+        val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts
 
         val deps_vec = Vector.fromList (rev rev_depss)
 
         val get_deps = curry Vector.sub deps_vec
 
-        val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts
+        val int_feats = map (the_default ~1 o Symtab.lookup feat_tab) feats
       in
-        trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
+        trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs feats ^ " from {" ^
           elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
         (if engine = MaSh_SML_kNN then
            let
@@ -693,22 +697,18 @@
              val _ =
                fold (fn feats => fn fact =>
                    let val fact' = fact - 1 in
-                     List.app (fn (feat, weight) =>
-                       map_array_at facts_ary (cons (fact', weight)) feat) feats;
-                     fact'
+                     List.app (map_array_at facts_ary (cons fact')) feats; fact'
                    end)
                  rev_featss num_facts
              val get_facts = curry Array.sub facts_ary
-             val int_feats = map_filter (fn (feat, weight) =>
-               Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
            in
-             k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts int_feats
+             k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
+               int_feats
            end
          else
            let
-             val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
+             val unweighted_feats_ary = Vector.fromList (rev rev_featss)
              val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
-             val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
            in
              (case engine of
                MaSh_SML_NB opts =>
@@ -773,6 +773,8 @@
 
 val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []} : mash_state
 
+(* TODO: get rid of weights in data structure *)
+
 local
 
 val version = "*** MaSh version 20140519 ***"
@@ -1428,9 +1430,10 @@
           val (parents, hints, feats) = query_args access_G
           val visible_facts = Graph.all_preds access_G parents
           val learns =
-            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, map fst feats, deps))
+              access_G
         in
-          MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, feats)
+          MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, map fst feats)
         end
 
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts