# HG changeset patch # User blanchet # Date 1401445671 -7200 # Node ID 2f620ef839ee9fced439160f0b88237c8a654f57 # Parent e4c2c792226fe4505cb122e4cf16d8e38b9caf83 added another way of invoking Python code, for experiments diff -r e4c2c792226f -r 2f620ef839ee src/HOL/TPTP/mash_eval.ML --- a/src/HOL/TPTP/mash_eval.ML Fri May 30 12:27:51 2014 +0200 +++ b/src/HOL/TPTP/mash_eval.ML Fri May 30 12:27:51 2014 +0200 @@ -97,7 +97,7 @@ mesh_isar_line), mesh_prover_line)) = if in_range range j then let - val get_suggs = extract_suggestions ##> take max_suggs + val get_suggs = extract_suggestions ##> (take max_suggs #> map fst) val (name1, mepo_suggs) = get_suggs mepo_line val (name2, mash_isar_suggs) = get_suggs mash_isar_line val (name3, mash_prover_suggs) = get_suggs mash_prover_line diff -r e4c2c792226f -r 2f620ef839ee src/HOL/TPTP/mash_export.ML --- a/src/HOL/TPTP/mash_export.ML Fri May 30 12:27:51 2014 +0200 +++ b/src/HOL/TPTP/mash_export.ML Fri May 30 12:27:51 2014 +0200 @@ -285,10 +285,10 @@ let val (name, mash_suggs) = extract_suggestions mash_line - ||> weight_facts_steeply + ||> (map fst #> weight_facts_steeply) val (name', mepo_suggs) = extract_suggestions mepo_line - ||> weight_facts_steeply + ||> (map fst #> weight_facts_steeply) val _ = if name = name' then () else error "Input files out of sync." val mess = [(mepo_weight, (mepo_suggs, [])), diff -r e4c2c792226f -r 2f620ef839ee src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri May 30 12:27:51 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri May 30 12:27:51 2014 +0200 @@ -32,9 +32,9 @@ val decode_str : string -> string val decode_strs : string -> string list val encode_features : (string * real) list -> string - val extract_suggestions : string -> string * string list + val extract_suggestions : string -> string * (string * real) list - datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB + datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py val is_mash_enabled : unit -> bool val the_mash_engine : unit -> mash_engine @@ -48,16 +48,18 @@ val query : Proof.context -> bool -> int -> (string * string list * string list * string list) list * string list * string list * (string * real) list -> - string list + (string * real) list end structure MaSh_SML : sig val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) -> int -> (int * real) list -> (int * real) list - val naive_bayes : int -> int -> (int -> int list) -> (int -> Inttab.key list) -> int -> int -> - (Inttab.key * real) list -> (int * real) list - val query : Proof.context -> mash_engine -> string list -> int -> + val naive_bayes : int -> int -> (int -> int list) -> (int -> int list) -> int -> int -> + (int * real) list -> (int * real) list + val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) -> + (int -> int list) -> int -> int -> (int * real) list -> (int * real) list + val query : Proof.context -> bool -> mash_engine -> string list -> int -> (string * (string * real) list * string list) list * string list * (string * real) list -> string list end @@ -144,7 +146,7 @@ () end -datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB +datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py fun mash_engine () = let val flag1 = Options.default_string @{system_option MaSh} in @@ -154,6 +156,7 @@ | "sml" => SOME MaSh_SML_NB | "sml_knn" => SOME MaSh_SML_kNN | "sml_nb" => SOME MaSh_SML_NB + | "sml_nb_py" => SOME MaSh_SML_NB_Py | _ => NONE) end @@ -267,8 +270,8 @@ (* The suggested weights do not make much sense. *) fun extract_suggestion sugg = (case space_explode "=" sugg of - [name, _ (* weight *)] => SOME (decode_str name) - | [name] => SOME (decode_str name) + [name, weight] => SOME (decode_str name, Real.fromString weight |> the_default 1.0) + | [name] => SOME (decode_str name, 1.0) | _ => NONE) fun extract_suggestions line = @@ -458,7 +461,7 @@ (* TODO: Either use IDF or don't use it. See commented out code portions below. *) -fun naive_bayes_learn num_facts get_deps get_th_feats num_feats = +fun naive_bayes_learn num_facts get_deps get_feats num_feats = let val tfreq = Array.array (num_facts, 0) val sfreq = Array.array (num_facts, Inttab.empty) @@ -483,7 +486,7 @@ end fun for i = - if i = num_facts then () else (learn i (get_th_feats i) (get_deps i); for (i + 1)) + if i = num_facts then () else (learn i (get_feats i) (get_deps i); for (i + 1)) in for 0; (Array.vector tfreq, Array.vector sfreq (*, Array.vector dffreq *)) end @@ -536,15 +539,34 @@ ret [] (Integer.max 0 (num_visible_facts - max_suggs)) end -fun naive_bayes num_facts num_visible_facts get_deps get_th_feats num_feats max_suggs feats = - naive_bayes_learn num_facts get_deps get_th_feats num_feats +fun naive_bayes num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats = + naive_bayes_learn num_facts get_deps get_feats num_feats |> naive_bayes_query num_facts num_visible_facts max_suggs feats +(* experimental *) +fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs + feats = + let + fun name_of_fact j = "f" ^ string_of_int j + fun fact_of_name s = the (Int.fromString (unprefix "f" s)) + fun name_of_feature j = "F" ^ string_of_int j + fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)] + + val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j), + map name_of_fact (get_deps j))) (0 upto num_facts - 1) + val parents' = parents_of num_visible_facts + val feats' = map (apfst name_of_feature) feats + in + MaSh_Py.unlearn ctxt overlord; + MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats') + |> map (apfst fact_of_name) + end + fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) -fun query ctxt engine visible_facts max_suggs (learns, hints, feats) = +fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) = let val visible_fact_set = Symtab.make_set visible_facts @@ -602,8 +624,8 @@ val get_unweighted_feats = curry Vector.sub unweighted_feats_ary val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats in - naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs - feats' + (if engine = MaSh_SML_NB then naive_bayes else naive_bayes_py ctxt overlord) + num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats' end) |> map (curry Vector.sub fact_vec o fst) end @@ -1301,6 +1323,7 @@ if engine = MaSh_Py then let val (parents, hints, feats) = query_args access_G in MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats) + |> map fst end else [])) @@ -1315,7 +1338,7 @@ val learns = Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G in - MaSh_SML.query ctxt engine visible_facts max_facts (learns, hints, feats) + MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, feats) end val unknown = filter_out (is_fact_in_graph access_G o snd) facts