refactoring
author blanchet
Thu, 26 Jun 2014 13:33:50 +0200
changeset 57356 9816f692b0ca
parent 57355 a9e0f9d35125
child 57357 30ee18eb23ac
refactoring
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:33:27 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:33:50 2014 +0200
@@ -60,12 +60,12 @@
 
   structure MaSh_SML :
   sig
-    val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) ->
+    val k_nearest_neighbors : int -> (int -> int list) -> (int -> (int * real) list) -> int ->
+      int list -> (int * real) list -> (int * real) list
+    val naive_bayes : (bool * bool) -> int -> (int -> int list) -> (int -> int list) -> int ->
       int -> (int * real) list -> (int * real) list
-    val naive_bayes : (bool * bool) -> int -> int -> (int -> int list) -> (int -> int list) ->
+    val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
       int -> int -> (int * real) list -> (int * real) list
-    val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) ->
-      (int -> int list) -> int -> int -> (int * real) list -> (int * real) list
     val query : Proof.context -> bool -> mash_engine -> string list -> int ->
       (string * (string * real) list * string list) list * string list * (string * real) list ->
       string list
@@ -423,14 +423,12 @@
 
 (*
   num_facts = maximum number of theorems to check dependencies and symbols
-  num_visible_facts = do not return theorems over or equal to this number.
-    Must satisfy: num_visible_facts <= num_facts.
   get_deps = returns dependencies of a theorem
   get_sym_ths = get theorems that have this feature
   max_suggs = number of suggestions to return
   feats = features of the goal
 *)
-fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats =
+fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts feats =
   let
     (* Can be later used for TFIDF *)
     fun sym_wght _ = 1.0
@@ -457,7 +455,7 @@
     val _ = List.app do_feat feats
     val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
     val no_recommends = Unsynchronized.ref 0
-    val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
+    val recommends = Array.tabulate (num_facts, rpair 0.0)
     val age = Unsynchronized.ref 1000000000.0
 
     fun inc_recommend j v =
@@ -470,7 +468,7 @@
 
     val k = Unsynchronized.ref 0
     fun do_k k =
-      if k >= num_visible_facts then
+      if k >= num_facts then
         raise EXIT ()
       else
         let
@@ -496,8 +494,8 @@
       if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   in
     while1 (); while2 ();
-    heap (Real.compare o pairself snd) max_suggs num_visible_facts recommends;
-    ret [] (Integer.max 0 (num_visible_facts - max_suggs))
+    heap (Real.compare o pairself snd) max_suggs num_facts recommends;
+    ret [] (Integer.max 0 (num_facts - max_suggs))
   end
 
 val nb_def_prior_weight = 21 (* FUDGE *)
@@ -541,7 +539,7 @@
     learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
   end
 
-fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts num_visible_facts max_suggs feats
+fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts max_suggs feats
     (tfreq, sfreq, idf) =
   let
     val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *)
@@ -576,22 +574,21 @@
         res + tau * sum_of_weights
       end
 
-    val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j)))
+    val posterior = Array.tabulate (num_facts, (fn j => (j, log_posterior j)))
 
     fun ret acc at =
-      if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
+      if at = num_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
   in
-    heap (Real.compare o pairself snd) max_suggs num_visible_facts posterior;
-    ret [] (Integer.max 0 (num_visible_facts - max_suggs))
+    heap (Real.compare o pairself snd) max_suggs num_facts posterior;
+    ret [] (Integer.max 0 (num_facts - max_suggs))
   end
 
-fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
+fun naive_bayes opts num_facts get_deps get_feats num_feats max_suggs feats =
   learn num_facts get_deps get_feats num_feats
-  |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats
+  |> naive_bayes_query opts num_facts max_suggs feats
 
 (* experimental *)
-fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs
-    feats =
+fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats =
   let
     fun name_of_fact j = "f" ^ string_of_int j
     fun fact_of_name s = the (Int.fromString (unprefix "f" s))
@@ -600,7 +597,7 @@
 
     val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
       map name_of_fact (get_deps j))) (0 upto num_facts - 1)
-    val parents' = parents_of num_visible_facts
+    val parents' = parents_of num_facts
     val feats' = map (apfst name_of_feature) feats
   in
     MaSh_Py.unlearn ctxt overlord;
@@ -655,10 +652,7 @@
 
 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
   let
-    val visible_fact_set = Symtab.make_set visible_facts
-    val learns =
-      (learns0 |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
-      (if null hints then [] else [(".hints", feats, hints)])
+    val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
   in
     if engine = MaSh_SML_kNN_Cpp then
       k_nearest_neighbors_cpp max_suggs learns (map fst feats)
@@ -666,7 +660,7 @@
       naive_bayes_cpp max_suggs learns (map fst feats)
     else
       let
-        val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
+        val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) =
           fold (fn (fact, feats, deps) =>
                 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
               let
@@ -687,11 +681,12 @@
 
         val deps_vec = Vector.fromList (rev rev_depss)
 
-        val num_visible_facts = length visible_facts
         val get_deps = curry Vector.sub deps_vec
+
+        val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts
       in
         trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
-          elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
+          elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
         (if engine = MaSh_SML_kNN then
            let
              val facts_ary = Array.array (num_feats, [])
@@ -704,10 +699,10 @@
                    end)
                  rev_featss num_facts
              val get_facts = curry Array.sub facts_ary
-             val feats' = map_filter (fn (feat, weight) =>
+             val int_feats = map_filter (fn (feat, weight) =>
                Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
            in
-             k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats'
+             k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts int_feats
            end
          else
            let
@@ -717,9 +712,9 @@
            in
              (case engine of
                MaSh_SML_NB opts =>
-               naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
-                 max_suggs int_feats
-             | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
+               naive_bayes opts num_facts get_deps get_unweighted_feats num_feats max_suggs
+                 int_feats
+             | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps
                  get_unweighted_feats num_feats max_suggs int_feats)
            end)
         |> map (curry Vector.sub fact_vec o fst)