reintroduced 'extra features' but with lower weight than before (to account for tfidf)
author blanchet
Fri, 27 Jun 2014 11:38:15 +0200
changeset 57401 02f56126b4e4
parent 57400 13b06c626163
child 57402 b532b879acd0
reintroduced 'extra features' but with lower weight than before (to account for tfidf)
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jun 27 10:49:52 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jun 27 11:38:15 2014 +0200
@@ -409,9 +409,9 @@
         Array.update (overlaps_sqr, j, (j, v + ov))
       end
 
-    fun do_feat s =
+    fun do_feat (s, sw0) =
       let
-        val sw = tfidf s
+        val sw = sw0 * tfidf s
         val w2 = sw * sw
         fun do_th j = if j < num_facts then inc_overlap j w2 else ()
       in
@@ -523,10 +523,10 @@
       let
         val tfreq = Real.fromInt (Vector.sub (tfreq, i))
 
-        fun fold_feats f (res, sfh) =
+        fun fold_feats (f, fw0) (res, sfh) =
           (case Inttab.lookup sfh f of
             SOME sf =>
-            (res + tfidf f * Math.ln (pos_weight * Real.fromInt sf / tfreq),
+            (res + fw0 * tfidf f * Math.ln (pos_weight * Real.fromInt sf / tfreq),
              Inttab.delete f sfh)
-          | NONE => (res + tfidf f * def_val, sfh))
+          | NONE => (res + fw0 * tfidf f * def_val, sfh)) (* fw0 also scales the default *)
 
@@ -561,11 +561,10 @@
       map name_of_feature (Vector.sub (featss, j)),
       map name_of_fact (Vector.sub (depss, j)))) (0 upto num_facts - 1)
     val parents' = parents_of num_facts
-    val goal_feats' = map (rpair 1.0 o name_of_feature) goal_feats
   in
     MaSh_Py.unlearn ctxt overlord;
     OS.Process.sleep (seconds 2.0); (* hack *)
-    MaSh_Py.query ctxt overlord max_suggs (learns, parents', goal_feats')
+    MaSh_Py.query ctxt overlord max_suggs (learns, parents', goal_feats)
     |> map (apfst fact_of_name)
   end
 
@@ -599,7 +598,7 @@
         |> filter_out (curry (op =) "")
       end
   in
-    (List.app do_learn learns; ol occ (os occ o quote) ", " goal_feats;
+    (List.app do_learn learns; ol occ (os occ o quote) ", " (map fst goal_feats);
      TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
      forkexec max_suggs)
   end
@@ -609,14 +608,14 @@
 val naive_bayes_ext = external_tool "predict/nbayes"
 
 fun query_external ctxt engine max_suggs learns goal_feats =
-  (trace_msg ctxt (fn () => "MaSh_SML query external " ^ encode_strs goal_feats);
+  (trace_msg ctxt (fn () => "MaSh_SML query external " ^ encode_features goal_feats);
    (case engine of
      MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
    | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
 
 fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
     (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
-  (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_strs goal_feats ^ " from {" ^
+  (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_features goal_feats ^ " from {" ^
      elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
    (case engine of
      MaSh_SML_kNN =>
@@ -1276,8 +1275,8 @@
 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
 
 val chained_feature_factor = 0.5 (* FUDGE *)
-val extra_feature_factor = 0.1 (* FUDGE *)
-val num_extra_feature_facts = 0 (* FUDGE *) (* TODO: keep or eliminate? *)
+val extra_feature_factor = 0.05 (* FUDGE *)
+val num_extra_feature_facts = 10 (* FUDGE *)
 
 (* FUDGE *)
 fun weight_of_proximity_fact rank =
@@ -1378,8 +1377,7 @@
         []
       else
         let
-          val (parents, goal_feats0) = query_args access_G
-          val goal_feats = map fst goal_feats0
+          val (parents, goal_feats) = query_args access_G
           val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
         in
           if engine = MaSh_SML_kNN_Ext orelse engine = MaSh_SML_NB_Ext then
@@ -1391,7 +1389,8 @@
             end
           else
             let
-              val int_goal_feats = map_filter (Symtab.lookup feat_tab) goal_feats
+              val int_goal_feats =
+                map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
             in
               MaSh_SML.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts
                 max_suggs goal_feats int_goal_feats