refactoring
authorblanchet
Thu, 26 Jun 2014 13:36:25 +0200
changeset 57376 f40ac83d076c
parent 57375 b75438e23925
child 57377 73e9b858ec8d
child 57389 eb96243a25c5
refactoring
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:36:22 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:36:25 2014 +0200
@@ -398,7 +398,7 @@
 
 exception EXIT of unit
 
-fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts conj_feats =
+fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats =
   let
     val ln_afreq = Math.ln (Real.fromInt num_facts)
     fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
@@ -419,7 +419,7 @@
         List.app do_th (Array.sub (feat_facts, s))
       end
 
-    val _ = List.app do_feat conj_feats
+    val _ = List.app do_feat goal_feats
     val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
     val no_recommends = Unsynchronized.ref 0
     val recommends = Array.tabulate (num_facts, rpair 0.0)
@@ -512,7 +512,7 @@
     (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
   end
 
-fun naive_bayes tfreq sfreq dffreq num_facts max_suggs visible_facts conj_feats =
+fun naive_bayes (tfreq, sfreq, dffreq) num_facts max_suggs visible_facts goal_feats =
   let
     val tau = 0.05 (* FUDGE *)
     val pos_weight = 10.0 (* FUDGE *)
@@ -534,7 +534,7 @@
              Inttab.delete f sfh)
           | NONE => (res + tfidf f * def_val, sfh))
 
-        val (res, sfh) = fold fold_feats conj_feats (Math.ln tfreq, Vector.sub (sfreq, i))
+        val (res, sfh) = fold fold_feats goal_feats (Math.ln tfreq, Vector.sub (sfreq, i))
 
         fun fold_sfh (f, sf) sow = sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)
 
@@ -554,7 +554,7 @@
   end
 
 (* experimental *)
-fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs conj_feats =
+fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs goal_feats =
   let
     fun name_of_fact j = "f" ^ string_of_int j
     fun fact_of_name s = the (Int.fromString (unprefix "f" s))
@@ -565,16 +565,16 @@
       map name_of_feature (Vector.sub (featss, j)),
       map name_of_fact (Vector.sub (depss, j)))) (0 upto num_facts - 1)
     val parents' = parents_of num_facts
-    val conj_feats' = map (rpair 1.0 o name_of_feature) conj_feats
+    val goal_feats' = map (rpair 1.0 o name_of_feature) goal_feats
   in
     MaSh_Py.unlearn ctxt overlord;
     OS.Process.sleep (seconds 2.0); (* hack *)
-    MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', conj_feats')
+    MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', goal_feats')
     |> map (apfst fact_of_name)
   end
 
 (* experimental *)
-fun experimental_external_tool tool max_suggs learns cfeats =
+fun external_tool tool max_suggs learns goal_feats =
   let
     val ser = string_of_int (serial ()) (* poor person's attempt at thread-safety *)
     val ocs = TextIO.openOut ("adv_syms" ^ ser)
@@ -603,67 +603,38 @@
         |> filter_out (curry (op =) "")
       end
   in
-    (List.app do_learn learns; ol occ (os occ o quote) ", " cfeats;
+    (List.app do_learn learns; ol occ (os occ o quote) ", " goal_feats;
      TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
      forkexec max_suggs)
   end
 
 val k_nearest_neighbors_ext =
-  experimental_external_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
-val naive_bayes_ext = experimental_external_tool "predict/nbayes"
-
-fun reorder_learns (num_facts, fact_tab) learns0 =
-  let
-    val learns = Array.array (num_facts, ("", [], []))
-  in
-    List.app (fn learn as (fact, _, _) =>
-        Array.update (learns, the (Symtab.lookup fact_tab fact), learn))
-      learns0;
-    Array.foldr (op ::) [] learns
-  end
+  external_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
+val naive_bayes_ext = external_tool "predict/nbayes"
 
-fun query ctxt engine (fact_xtab as (num_facts, fact_tab)) (num_feats, feat_tab) visible_facts
-    max_suggs learns0 conj_feats =
-  if engine = MaSh_SML_kNN_Ext then
-    k_nearest_neighbors_ext max_suggs learns0 conj_feats
-  else if engine = MaSh_SML_NB_Ext then
-    naive_bayes_ext max_suggs learns0 conj_feats
-  else
-    let
-      val learns = reorder_learns fact_xtab learns0
-
-      val facts = Vector.fromList (map #1 learns)
-      val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
-      val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
+fun query_external ctxt engine max_suggs learns goal_feats =
+  (trace_msg ctxt (fn () => "MaSh_SML query external " ^ encode_strs goal_feats);
+   (case engine of
+     MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
+   | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
 
-      val tfreq = Vector.tabulate (num_facts, K 0)
-      val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
-      val dffreq = Vector.tabulate (num_feats, K 0)
-
-      val (tfreq, sfreq, dffreq) =
-        learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
-
-      val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
-      val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
-    in
-      trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^
-        elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
-      (case engine of
-        MaSh_SML_kNN =>
-        let
-          val feat_facts = Array.array (num_feats, [])
-          val _ =
-            Vector.foldl (fn (feats, fact) =>
-                (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1))
-              0 featss
-        in
-          k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs int_visible_facts
-            int_conj_feats
-        end
-      | MaSh_SML_NB =>
-        naive_bayes tfreq sfreq dffreq num_facts max_suggs int_visible_facts int_conj_feats)
-      |> map (curry Vector.sub facts o fst)
-    end
+fun query_internal ctxt engine num_facts num_feats (facts, featss, depss) (freqs as (_, _, dffreq))
+    visible_facts max_suggs goal_feats int_goal_feats =
+  (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_strs goal_feats ^ " from {" ^
+     elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
+   (case engine of
+     MaSh_SML_kNN =>
+     let
+       val feat_facts = Array.array (num_feats, [])
+       val _ =
+         Vector.foldl (fn (feats, fact) =>
+             (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1))
+           0 featss
+     in
+       k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
+     end
+   | MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats)
+   |> map (curry Vector.sub facts o fst))
 
 end;
 
@@ -1312,7 +1283,17 @@
 fun add_const_counts t =
   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
 
-fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
+fun reorder_learns (num_facts, fact_tab) learns0 =
+  let
+    val learns = Array.array (num_facts, ("", [], []))
+  in
+    List.app (fn learn as (fact, _, _) =>
+        Array.update (learns, the (Symtab.lookup fact_tab fact), learn))
+      learns0;
+    Array.foldr (op ::) [] learns
+  end
+
+fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts =
   let
     val thy = Proof_Context.theory_of ctxt
     val thy_name = Context.theory_name thy
@@ -1360,14 +1341,14 @@
         (parents, hints, feats)
       end
 
-    val ((access_G, (fact_xtab, feat_xtab)), py_suggs) =
+    val ((access_G, (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab))), py_suggs) =
       peek_state ctxt overlord (fn {access_G, xtabs, ...} =>
         ((access_G, xtabs),
          if Graph.is_empty access_G then
            (trace_msg ctxt (K "Nothing has been learned yet"); [])
          else if engine = MaSh_Py then
            let val (parents, hints, feats) = query_args access_G in
-             MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+             MaSh_Py.query ctxt overlord max_suggs ([], hints, parents, feats)
              |> map fst
            end
          else
@@ -1378,19 +1359,46 @@
         []
       else
         let
-          val (parents, hints, feats0) = query_args access_G
-          val feats = map fst feats0
-          val visible_facts = Graph.all_preds access_G parents
-          val learns =
-            (if null hints then [] else [(hintsN, feats, hints)]) @ (* ### FIXME *)
-            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+          val (parents, hints, goal_feats0) = query_args access_G
+          val goal_feats = map fst goal_feats0
+          val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
         in
-          MaSh_SML.query ctxt engine fact_xtab feat_xtab visible_facts max_facts learns feats
+          if engine = MaSh_SML_kNN_Ext orelse engine = MaSh_SML_NB_Ext then
+            let
+              val learns =
+                (if null hints then [] else [(hintsN, goal_feats, hints)]) @ (* ### FIXME *)
+                Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+            in
+              MaSh_SML.query_external ctxt engine max_suggs learns goal_feats
+            end
+          else
+            let
+              val learns0 =
+                (if null hints then [] else [(hintsN, goal_feats, hints)]) @ (* ### FIXME *)
+                Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+              val learns = reorder_learns fact_xtab learns0
+
+              val facts = Vector.fromList (map #1 learns)
+              val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
+              val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
+
+              val tfreq = Vector.tabulate (num_facts, K 0)
+              val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
+              val dffreq = Vector.tabulate (num_feats, K 0)
+
+              val freqs' =
+                MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
+
+              val int_goal_feats = map_filter (Symtab.lookup feat_tab) goal_feats
+            in
+              MaSh_SML.query_internal ctxt engine num_facts num_feats (facts, featss, depss) freqs'
+                visible_facts max_suggs goal_feats int_goal_feats
+            end
         end
 
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   in
-    find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown
+    find_mash_suggestions ctxt max_suggs (py_suggs @ sml_suggs) facts chained unknown
     |> pairself (map fact_of_raw_fact)
   end