added another way of invoking Python code, for experiments
author blanchet
Fri, 30 May 2014 12:27:51 +0200
changeset 57125 2f620ef839ee
parent 57124 e4c2c792226f
child 57128 4874411752fe
added another way of invoking Python code, for experiments
src/HOL/TPTP/mash_eval.ML
src/HOL/TPTP/mash_export.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/TPTP/mash_eval.ML	Fri May 30 12:27:51 2014 +0200
+++ b/src/HOL/TPTP/mash_eval.ML	Fri May 30 12:27:51 2014 +0200
@@ -97,7 +97,7 @@
                          mesh_isar_line), mesh_prover_line)) =
       if in_range range j then
         let
-          val get_suggs = extract_suggestions ##> take max_suggs
+          val get_suggs = extract_suggestions ##> (take max_suggs #> map fst)
           val (name1, mepo_suggs) = get_suggs mepo_line
           val (name2, mash_isar_suggs) = get_suggs mash_isar_line
           val (name3, mash_prover_suggs) = get_suggs mash_prover_line
--- a/src/HOL/TPTP/mash_export.ML	Fri May 30 12:27:51 2014 +0200
+++ b/src/HOL/TPTP/mash_export.ML	Fri May 30 12:27:51 2014 +0200
@@ -285,10 +285,10 @@
       let
         val (name, mash_suggs) =
           extract_suggestions mash_line
-          ||> weight_facts_steeply
+          ||> (map fst #> weight_facts_steeply)
         val (name', mepo_suggs) =
           extract_suggestions mepo_line
-          ||> weight_facts_steeply
+          ||> (map fst #> weight_facts_steeply)
         val _ = if name = name' then () else error "Input files out of sync."
         val mess =
           [(mepo_weight, (mepo_suggs, [])),
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri May 30 12:27:51 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri May 30 12:27:51 2014 +0200
@@ -32,9 +32,9 @@
   val decode_str : string -> string
   val decode_strs : string -> string list
   val encode_features : (string * real) list -> string
-  val extract_suggestions : string -> string * string list
+  val extract_suggestions : string -> string * (string * real) list
 
-  datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB
+  datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py
 
   val is_mash_enabled : unit -> bool
   val the_mash_engine : unit -> mash_engine
@@ -48,16 +48,18 @@
     val query : Proof.context -> bool -> int ->
       (string * string list * string list * string list) list * string list * string list
         * (string * real) list ->
-      string list
+      (string * real) list
   end
 
   structure MaSh_SML :
   sig
     val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) ->
       int -> (int * real) list -> (int * real) list
-    val naive_bayes : int -> int -> (int -> int list) -> (int -> Inttab.key list) -> int -> int ->
-      (Inttab.key * real) list -> (int * real) list
-    val query : Proof.context -> mash_engine -> string list -> int ->
+    val naive_bayes : int -> int -> (int -> int list) -> (int -> int list) -> int -> int ->
+      (int * real) list -> (int * real) list
+    val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) ->
+      (int -> int list) -> int -> int -> (int * real) list -> (int * real) list
+    val query : Proof.context -> bool -> mash_engine -> string list -> int ->
       (string * (string * real) list * string list) list * string list * (string * real) list ->
       string list
   end
@@ -144,7 +146,7 @@
     ()
   end
 
-datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB
+datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py
 
 fun mash_engine () =
   let val flag1 = Options.default_string @{system_option MaSh} in
@@ -154,6 +156,7 @@
     | "sml" => SOME MaSh_SML_NB
     | "sml_knn" => SOME MaSh_SML_kNN
     | "sml_nb" => SOME MaSh_SML_NB
+    | "sml_nb_py" => SOME MaSh_SML_NB_Py
     | _ => NONE)
   end
 
@@ -267,8 +270,8 @@
 (* The suggested weights do not make much sense. *)
 fun extract_suggestion sugg =
   (case space_explode "=" sugg of
-    [name, _ (* weight *)] => SOME (decode_str name)
-  | [name] => SOME (decode_str name)
+    [name, weight] => SOME (decode_str name, Real.fromString weight |> the_default 1.0)
+  | [name] => SOME (decode_str name, 1.0)
   | _ => NONE)
 
 fun extract_suggestions line =
@@ -458,7 +461,7 @@
 
 (* TODO: Either use IDF or don't use it. See commented out code portions below. *)
 
-fun naive_bayes_learn num_facts get_deps get_th_feats num_feats =
+fun naive_bayes_learn num_facts get_deps get_feats num_feats =
   let
     val tfreq = Array.array (num_facts, 0)
     val sfreq = Array.array (num_facts, Inttab.empty)
@@ -483,7 +486,7 @@
       end
 
     fun for i =
-      if i = num_facts then () else (learn i (get_th_feats i) (get_deps i); for (i + 1))
+      if i = num_facts then () else (learn i (get_feats i) (get_deps i); for (i + 1))
   in
     for 0; (Array.vector tfreq, Array.vector sfreq (*, Array.vector dffreq *))
   end
@@ -536,15 +539,34 @@
     ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   end
 
-fun naive_bayes num_facts num_visible_facts get_deps get_th_feats num_feats max_suggs feats =
-  naive_bayes_learn num_facts get_deps get_th_feats num_feats
+fun naive_bayes num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
+  naive_bayes_learn num_facts get_deps get_feats num_feats
   |> naive_bayes_query num_facts num_visible_facts max_suggs feats
 
+(* experimental *)
+fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs
+    feats =
+  let
+    fun name_of_fact j = "f" ^ string_of_int j
+    fun fact_of_name s = the (Int.fromString (unprefix "f" s))
+    fun name_of_feature j = "F" ^ string_of_int j
+    fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)]
+
+    val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
+      map name_of_fact (get_deps j))) (0 upto num_facts - 1)
+    val parents' = parents_of num_visible_facts
+    val feats' = map (apfst name_of_feature) feats
+  in
+    MaSh_Py.unlearn ctxt overlord;
+    MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats')
+    |> map (apfst fact_of_name)
+  end
+
 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
 
-fun query ctxt engine visible_facts max_suggs (learns, hints, feats) =
+fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) =
   let
     val visible_fact_set = Symtab.make_set visible_facts
 
@@ -602,8 +624,8 @@
          val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
          val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
        in
-         naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs
-           feats'
+         (if engine = MaSh_SML_NB then naive_bayes else naive_bayes_py ctxt overlord)
+           num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats'
        end)
     |> map (curry Vector.sub fact_vec o fst)
   end
@@ -1301,6 +1323,7 @@
            if engine = MaSh_Py then
              let val (parents, hints, feats) = query_args access_G in
                MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+               |> map fst
              end
            else
              []))
@@ -1315,7 +1338,7 @@
           val learns =
             Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
         in
-          MaSh_SML.query ctxt engine visible_facts max_facts (learns, hints, feats)
+          MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, feats)
         end
 
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts