src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57373 e9d47cd3239b
parent 57372 24738b4f8c6b
child 57374 cb6667e7cbc1
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:36:00 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:36:06 2014 +0200
@@ -398,16 +398,8 @@
 
 exception EXIT of unit
 
-fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts num_feats feats =
+fun k_nearest_neighbors dffreq num_facts deps_vec get_sym_ths max_suggs visible_facts conj_feats =
   let
-    val dffreq = Array.array (num_feats, 0)
-
-    val add_sym = map_array_at dffreq (Integer.add 1)
-    fun for1 i =
-      if i = num_feats then () else
-      (List.app (fn _ => add_sym i) (get_sym_ths i); for1 (i + 1))
-    val _ = for1 0
-
     val ln_afreq = Math.ln (Real.fromInt num_facts)
     fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Array.sub (dffreq, feat)))
 
@@ -427,7 +419,7 @@
         List.app do_th (get_sym_ths s)
       end
 
-    val _ = List.app do_feat feats
+    val _ = List.app do_feat conj_feats
     val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
     val no_recommends = Unsynchronized.ref 0
     val recommends = Array.tabulate (num_facts, rpair 0.0)
@@ -447,7 +439,7 @@
           val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
           val o1 = Math.sqrt o2
           val _ = inc_recommend j o1
-          val ds = get_deps j
+          val ds = Vector.sub (deps_vec, j)
           val l = Real.fromInt (length ds)
         in
           List.app (fn d => inc_recommend d (o1 / l)) ds
@@ -474,7 +466,7 @@
 
 val nb_def_prior_weight = 21 (* FUDGE *)
 
-fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats =
+fun learn_facts tfreq sfreq dffreq num_facts depss featss =
   let
     fun learn_fact th feats deps =
       let
@@ -495,35 +487,26 @@
       end
 
     fun for i =
-      if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
+      if i = num_facts then ()
+      else (learn_fact i (Vector.sub (featss, i)) (Vector.sub (depss, i)); for (i + 1))
   in
-    for 0;
-    (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
+    for 0
   end
 
-fun learn num_facts get_deps get_feats num_feats =
-  let
-    val tfreq = Array.array (num_facts, 0)
-    val sfreq = Array.array (num_facts, Inttab.empty)
-    val dffreq = Array.array (num_feats, 0)
-  in
-    learn_facts tfreq sfreq dffreq num_facts get_deps get_feats
-  end
-
-fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) =
+fun naive_bayes_query tfreq sfreq dffreq num_facts max_suggs visible_facts conj_feats =
   let
     val tau = 0.05 (* FUDGE *)
     val pos_weight = 10.0 (* FUDGE *)
     val def_val = ~15.0 (* FUDGE *)
 
     val ln_afreq = Math.ln (Real.fromInt num_facts)
-    val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) dffreq
+    val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq)
 
     fun tfidf feat = Vector.sub (idf, feat)
 
     fun log_posterior i =
       let
-        val tfreq = Real.fromInt (Vector.sub (tfreq, i))
+        val tfreq = Real.fromInt (Array.sub (tfreq, i))
 
         fun fold_feats f (res, sfh) =
           (case Inttab.lookup sfh f of
@@ -532,7 +515,7 @@
              Inttab.delete f sfh)
           | NONE => (res + tfidf f * def_val, sfh))
 
-        val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i))
+        val (res, sfh) = fold fold_feats conj_feats (Math.ln tfreq, Array.sub (sfreq, i))
 
         fun fold_sfh (f, sf) sow = sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)
 
@@ -551,26 +534,23 @@
     ret (Integer.max 0 (num_facts - max_suggs)) []
   end
 
-fun naive_bayes num_facts get_deps get_feats num_feats max_suggs visible_facts feats =
-  learn num_facts get_deps get_feats num_feats
-  |> naive_bayes_query num_facts max_suggs visible_facts feats
-
 (* experimental *)
-fun naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs feats =
+fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs conj_feats =
   let
     fun name_of_fact j = "f" ^ string_of_int j
     fun fact_of_name s = the (Int.fromString (unprefix "f" s))
     fun name_of_feature j = "F" ^ string_of_int j
     fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)]
 
-    val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
-      map name_of_fact (get_deps j))) (0 upto num_facts - 1)
+    val learns = map (fn j => (name_of_fact j, parents_of j,
+      map name_of_feature (Vector.sub (featss, j)),
+      map name_of_fact (Vector.sub (depss, j)))) (0 upto num_facts - 1)
     val parents' = parents_of num_facts
-    val feats' = map (rpair 1.0 o name_of_feature) feats
+    val conj_feats' = map (rpair 1.0 o name_of_feature) conj_feats
   in
     MaSh_Py.unlearn ctxt overlord;
     OS.Process.sleep (seconds 2.0); (* hack *)
-    MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats')
+    MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', conj_feats')
     |> map (apfst fact_of_name)
   end
 
@@ -633,9 +613,14 @@
       val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
 
       val fact_vec = Vector.fromList facts
+      val feats_vec = Vector.fromList featss
       val deps_vec = Vector.fromList depss
 
-      val get_deps = curry Vector.sub deps_vec
+      val tfreq = Array.array (num_facts, 0)
+      val sfreq = Array.array (num_facts, Inttab.empty)
+      val dffreq = Array.array (num_feats, 0)
+
+      val _ = learn_facts tfreq sfreq dffreq num_facts deps_vec feats_vec
 
       val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
       val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
@@ -652,17 +637,11 @@
               featss 0
           val get_facts = curry Array.sub facts_ary
         in
-          k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
+          k_nearest_neighbors dffreq num_facts deps_vec get_facts max_suggs int_visible_facts
             int_conj_feats
         end
       | MaSh_SML_NB =>
-        let
-          val feats_ary = Vector.fromList featss
-          val get_feats = curry Vector.sub feats_ary
-        in
-          naive_bayes num_facts get_deps get_feats num_feats max_suggs int_visible_facts
-            int_conj_feats
-        end)
+        naive_bayes_query tfreq sfreq dffreq num_facts max_suggs int_visible_facts int_conj_feats)
       |> map (curry Vector.sub fact_vec o fst)
     end