src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57369 6d422f19cefb
parent 57368 b89937ed6099
child 57370 9d420da6c7e2
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:39 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:46 2014 +0200
@@ -46,30 +46,6 @@
   val is_mash_enabled : unit -> bool
   val the_mash_engine : unit -> mash_engine
 
-  structure MaSh_Py :
-  sig
-    val unlearn : Proof.context -> bool -> unit
-    val learn : Proof.context -> bool -> bool ->
-      (string * string list * string list * string list) list -> unit
-    val relearn : Proof.context -> bool -> bool -> (string * string list) list -> unit
-    val query : Proof.context -> bool -> int ->
-      (string * string list * string list * string list) list * string list * string list
-        * (string * real) list ->
-      (string * real) list
-  end
-
-  structure MaSh_SML :
-  sig
-    val k_nearest_neighbors : int -> (int -> int list) -> (int -> int list) -> int -> int list ->
-      int -> int list -> (int * real) list
-    val naive_bayes : int -> (int -> int list) -> (int -> int list) -> int -> int -> int list ->
-      int list -> (int * real) list
-    val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
-      int -> int -> int list -> (int * real) list
-    val query : Proof.context -> bool -> mash_engine -> string list -> int ->
-      (string * string list * string list) list * string list * string list -> string list
-  end
-
   val mash_unlearn : Proof.context -> params -> unit
   val nickname_of_thm : thm -> string
   val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
@@ -492,7 +468,7 @@
 
 val nb_def_prior_weight = 21 (* FUDGE *)
 
-fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
+fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats =
   let
     fun learn_fact th feats deps =
       let
@@ -525,7 +501,7 @@
     val sfreq = Array.array (num_facts, Inttab.empty)
     val dffreq = Array.array (num_feats, 0)
   in
-    learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
+    learn_facts tfreq sfreq dffreq num_facts get_deps get_feats
   end
 
 fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) =
@@ -574,7 +550,7 @@
   |> naive_bayes_query num_facts max_suggs visible_facts feats
 
 (* experimental *)
-fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats =
+fun naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs feats =
   let
     fun name_of_fact j = "f" ^ string_of_int j
     fun fact_of_name s = the (Int.fromString (unprefix "f" s))
@@ -631,66 +607,54 @@
   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
 
-val empty_xtab = (0, Symtab.empty)
-
-fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
-fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
+fun query ctxt overlord engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs
+    learns conj_feats =
+  if engine = MaSh_SML_kNN_Cpp then
+    k_nearest_neighbors_cpp max_suggs learns conj_feats
+  else if engine = MaSh_SML_NB_Cpp then
+    naive_bayes_cpp max_suggs learns conj_feats
+  else
+    let
+      val facts = map #1 learns
+      val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
+      val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
 
-fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
-  let
-    val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
-  in
-    if engine = MaSh_SML_kNN_Cpp then
-      k_nearest_neighbors_cpp max_suggs learns feats
-    else if engine = MaSh_SML_NB_Cpp then
-      naive_bayes_cpp max_suggs learns feats
-    else
-      let
-        val facts = map #1 learns
-        val fact_vec = Vector.fromList facts
+      val fact_vec = Vector.fromList facts
+      val deps_vec = Vector.fromList depss
 
-        val fact_xtab as (num_facts, fact_tab) = fold add_to_xtab facts empty_xtab
-        val feat_xtab as (num_feats, feat_tab) = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
-
-        val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
-
-        val deps_vec = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
-
-        val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
-
-        val get_deps = curry Vector.sub deps_vec
+      val get_deps = curry Vector.sub deps_vec
 
-        val int_feats = map_filter (Symtab.lookup feat_tab) feats
-      in
-        trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs feats ^ " from {" ^
-          elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
-        (if engine = MaSh_SML_kNN then
-           let
-             val facts_ary = Array.array (num_feats, [])
-             val _ =
-               fold (fn feats => fn fact =>
-                   (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
-                 featss 0
-             val get_facts = curry Array.sub facts_ary
-           in
-             k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
-               int_feats
-           end
-         else
-           let
-             val unweighted_feats_ary = Vector.fromList featss
-             val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
-           in
-             (case engine of
-               MaSh_SML_NB =>
-               naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs
-                 int_visible_facts int_feats
-             | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps
-                 get_unweighted_feats num_feats max_suggs int_feats)
-           end)
-        |> map (curry Vector.sub fact_vec o fst)
-      end
-  end
+      val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
+      val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
+    in
+      trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^
+        elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
+      (if engine = MaSh_SML_kNN then
+         let
+           val facts_ary = Array.array (num_feats, [])
+           val _ =
+             fold (fn feats => fn fact =>
+                 (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
+               featss 0
+           val get_facts = curry Array.sub facts_ary
+         in
+           k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
+             int_conj_feats
+         end
+       else
+         let
+           val feats_ary = Vector.fromList featss
+           val get_feats = curry Vector.sub feats_ary
+         in
+           (case engine of
+             MaSh_SML_NB =>
+             naive_bayes num_facts get_deps get_feats num_feats max_suggs int_visible_facts
+               int_conj_feats
+           | MaSh_SML_NB_Py =>
+             naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs int_conj_feats)
+         end)
+      |> map (curry Vector.sub fact_vec o fst)
+    end
 
 end;
 
@@ -1328,6 +1292,11 @@
 fun add_const_counts t =
   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
 
+val empty_xtab = (0, Symtab.empty)
+
+fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
+fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
+
 fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
   let
     val thy = Proof_Context.theory_of ctxt
@@ -1395,12 +1364,18 @@
         []
       else
         let
-          val (parents, hints, feats) = query_args access_G
+          val (parents, hints, feats0) = query_args access_G
+          val feats = map fst feats0
           val visible_facts = Graph.all_preds access_G parents
           val learns =
-            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G @
+            (if null hints then [] else [(".hints", feats, hints)])
+
+          val fact_xtab = fold (add_to_xtab o #1) learns empty_xtab
+          val feat_xtab = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
         in
-          MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, map fst feats)
+          MaSh_SML.query ctxt overlord engine fact_xtab feat_xtab visible_facts max_facts learns
+            feats
         end
 
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts