--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200
@@ -1,5 +1,6 @@
(* Title: HOL/Tools/Sledgehammer/sledgehammer_mash.ML
Author: Jasmin Blanchette, TU Muenchen
+ Author: Cezary Kaliszyk, University of Innsbruck
Sledgehammer's machine-learning-based relevance filter (MaSh).
*)
@@ -253,9 +254,7 @@
fun learn _ _ _ [] = ()
| learn ctxt overlord save learns =
(trace_msg ctxt (fn () =>
- let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
- "MaSh_Py learn" ^ (if names = "" then "" else " " ^ names)
- end);
+ "MaSh_Py learn {" ^ elide_string 1000 (space_implode " " (map #1 learns)) ^ "}"); (* #1 = first component of each learn *)
run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
(K ()))
@@ -280,15 +279,203 @@
structure MaSh_SML =
struct
-fun learn_and_query ctxt (learns : (string * string list * string list list * string list) list)
- max_suggs (query as (_, _, _, feats)) =
- (trace_msg ctxt (fn () =>
- let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
- "MaSh_SML learn" ^ (if names = "" then "" else " " ^ names) ^ "\n" ^
- "MaSh_SML query " ^ encode_features feats
- end);
- (* Implementation missing *)
- [])
+fun max a b = Int.max (a, b) (* curried integer maximum *)
+
+exception BOTTOM of int (* carries the index of a node that has no child inside the heap prefix *)
+(* Partial in-place heap sort: moves the largest elements of array a (w.r.t. cmp) to its end, in ascending order. *)
+fun heap cmp bnd a = (* bnd bounds how many top elements are sorted: extraction stops at index max 2 (l - bnd) *)
+ let
+ fun maxson l i = (* index of the largest child of i within a[0..l-1] (ternary heap); raises BOTTOM i if there is none *)
+ let
+ val i31 = i + i + i + 1
+ in
+ if i31 + 2 < l then
+ let
+ val x = Unsynchronized.ref i31;
+ val () = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
+ val () = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
+ in
+ !x
+ end
+ else
+ if i31 + 1 < l andalso cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS
+ then i31 + 1 else if i31 < l then i31 else raise BOTTOM i
+ end
+ (* Sift e down from slot i: pull up the largest child while it is greater than e, then store e. *)
+ fun trickledown l i e =
+ let
+ val j = maxson l i
+ in
+ if cmp (Array.sub (a, j), e) = GREATER then
+ let val () = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
+ else Array.update (a, i, e)
+ end
+ (* As trickledown, but if a childless slot is reached (BOTTOM) store e there. *)
+ fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e)
+ (* Move the hole at i down by promoting the largest child each step; always ends by raising BOTTOM. *)
+ fun bubbledown l i =
+ let
+ val j = maxson l i
+ val () = Array.update (a, i, Array.sub (a, j))
+ in
+ bubbledown l j
+ end
+ (* Returns the index of the childless slot where the hole started at i ends up. *)
+ fun bubble l i = bubbledown l i handle BOTTOM i => i
+ (* Sift e up from slot i towards the root; must not be called with i = 0 (father would be index ~1). *)
+ fun trickleup i e =
+ let
+ val father = (i - 1) div 3
+ in
+ if cmp (Array.sub (a, father), e) = LESS then
+ let
+ val () = Array.update (a, i, Array.sub (a, father))
+ in
+ if father > 0 then trickleup father e else Array.update (a, 0, e)
+ end
+ else Array.update (a, i, e)
+ end
+ (* l = number of elements, which is also the initial heap size. *)
+ val l = Array.length a
+ (* Heapify helper: sift down every slot from i to 0. *)
+ fun for i =
+ if i < 0 then () else
+ let
+ val _ = trickle l i (Array.sub (a, i))
+ in
+ for (i - 1)
+ end
+ (* Build the max-heap, starting at ((l + 1) div 3) - 1 = (l - 2) div 3, the father of the last element. *)
+ val () = for (((l + 1) div 3) - 1)
+ (* Extraction: put the current maximum a[0] into slot i, then re-insert the displaced e into the heap a[0..i-1]. *)
+ fun for2 i =
+ if i < max 2 (l - bnd) then () else (* i >= 2 ensures the root has a child, so bubble never returns 0 *)
+ let
+ val e = Array.sub (a, i)
+ val () = Array.update (a, i, Array.sub (a, 0))
+ val () = trickleup (bubble i 0) e
+ in
+ for2 (i - 1)
+ end
+ (* Run the extraction from the last slot downwards. *)
+ val () = for2 (l - 1)
+ in
+ if l > 1 then (* heap root a[0] is not less than a[1]; swapping orders the first two slots *)
+ let
+ val e = Array.sub (a, 1)
+ val () = Array.update (a, 1, Array.sub (a, 0))
+ in
+ Array.update (a, 0, e)
+ end
+ else ()
+ end
+
+(*
+  avail_no = number of theorems considered; theorems are referred to by index < avail_no
+  get_deps = maps a theorem index to the indices (all < avail_no) of its dependencies
+  get_sym_ths = maps a symbol to the (theorem index, weight) pairs of theorems having it
+  knns = number of nearest neighbours
+  advno = number of predictions to return
+  syms = (symbol, weight) pairs of the conjecture
+  Returns at most advno (theorem index, score) pairs, highest score first. *)
+fun knn avail_no get_deps get_sym_ths knns advno syms =
+  let
+    (* Can be later used for TFIDF *)
+    fun sym_wght _ = 1.0;
+    val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)));
+    fun inc_overlap j v =
+      let
+        val ov = snd (Array.sub (overlaps_sqr, j))
+      in
+        Array.update (overlaps_sqr, j, (j, v + ov))
+      end;
+    fun do_sym (s, con_wght) =
+      let
+        val sw = sym_wght s;
+        val w2 = sw * sw * con_wght;
+        fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else ()
+      in
+        List.app do_th (get_sym_ths s) (* side effects only: no result list is built *)
+      end;
+    val () = List.app do_sym syms;
+    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr; (* best overlaps last *)
+    val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)));
+    fun inc_recommend j v =
+      let
+        val ov = snd (Array.sub (recommends, j))
+      in
+        Array.update (recommends, j, (j, v + ov))
+      end;
+    fun for k = (* k-th nearest neighbour sits at slot avail_no - k - 1 *)
+      if k = knns then () else
+      if k >= avail_no then () else
+        let
+          val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1);
+          val o1 = Math.sqrt o2;
+          val () = inc_recommend j o1; (* the neighbour itself gets its full weight *)
+          val ds = get_deps j;
+          val l = Real.fromInt (length ds);
+          val () = List.app (fn d => inc_recommend d (o1 / l)) ds (* weight is split evenly among its deps *)
+        in
+          for (k + 1)
+        end;
+    val () = for 0;
+    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
+    fun ret acc at = (* consing while walking upwards reverses the ascending tail *)
+      if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
+  in
+    ret [] (max 0 (avail_no - advno))
+  end
+
+val knns = 40 (* FUDGE *)
+(* Registers key under the next free index n; returns the bumped counter, table and key list (newest first). *)
+fun add_to_xtab key (n, table, ks) = (n + 1, table |> Symtab.update_new (key, n), key :: ks)
+(* Replaces the element of ary at index i by f applied to it. *)
+fun map_array_index ary f i = let val x = Array.sub (ary, i) in Array.update (ary, i, f x) end
+
+(* TODO: take weight components of "feats" into consideration *)
+fun learn_and_query ctxt parents access_G max_suggs hints feats = (* NOTE: "hints" is not used below *)
+ let
+ val str_of_feat = space_implode "|"
+ (* Walk the facts given by Graph.all_preds, giving facts and features consecutive indices starting at 0. *)
+ val (depss0, featss, (_, _, facts0), (num_feats, feat_tab, _)) =
+ fold_rev (fn fact => fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
+ let
+ val (_, feats, deps) = Graph.get_node access_G fact
+ (* Index of feat in the feature table; a new feature gets the next free index n. *)
+ fun add_feat feat (xtab as (n, tab, _)) =
+ (case Symtab.lookup tab feat of
+ SOME i => (i, xtab)
+ | NONE => (n, add_to_xtab feat xtab))
+ (* A feature is a string list; its table key is the components joined with "|". *)
+ val (feats', feat_xtab') = fold_map (add_feat o str_of_feat) feats feat_xtab
+ in
+ (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
+ add_to_xtab fact fact_xtab, feat_xtab') (* deps without an entry in fact_tab are dropped *)
+ end)
+ (Graph.all_preds access_G parents) ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
+ (* facts0 is in decreasing index order (add_to_xtab conses); reverse so that list position = fact index. *)
+ val facts = rev facts0
+ val fact_ary = Array.fromList facts
+ (* deps_ary: fact index -> dependency indices; facts_ary: feature index -> indices of facts having it. *)
+ val deps_ary = Array.fromList (rev depss0)
+ val facts_ary = Array.array (num_feats, [])
+ val _ =
+ fold (fn feats => fn fact =>
+ let val fact' = fact - 1 in
+ (List.app (map_array_index facts_ary (cons fact')) feats; fact')
+ end)
+ featss (length featss) (* counts down from the highest fact index, matching the order of featss *)
+ in
+ (trace_msg ctxt (fn () =>
+ "MaSh_SML query " ^ encode_features feats ^ " from {" ^
+ elide_string 1000 (space_implode " " facts) ^ "}");
+ knn (Array.length deps_ary) (curry Array.sub deps_ary)
+ (map (rpair 1.0) (* FIXME *) o curry Array.sub facts_ary) knns max_suggs
+ (map_filter (fn (feat, weight) =>
+ Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
+ |> map ((fn i => Array.sub (fact_ary, i)) o fst)) (* translate result indices back to facts *)
+ end
end;
@@ -574,7 +761,7 @@
fun sort_of_type alg T =
let
- val graph = Sorts.classes_of alg
+ val G = Sorts.classes_of alg
fun cls_of S [] = S
| cls_of S (cl :: cls) =
@@ -585,7 +772,7 @@
cls_of S (union (op =) cls' cls)
end
in
- cls_of [] (Graph.maximals graph)
+ cls_of [] (Graph.maximals G)
end
val generalize_goal = false
@@ -975,13 +1162,6 @@
fun add_const_counts t =
fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
-fun learn_of_graph graph =
- let
- fun sched parents (name, (kind, feats, deps)) = (name, map fst parents, feats, deps)
- in
- Graph.schedule sched graph
- end
-
fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
let
val thy = Proof_Context.theory_of ctxt
@@ -1025,9 +1205,10 @@
|> map (nickname_of_thm o snd)
in
(access_G,
- (if Config.get ctxt sml then MaSh_SML.learn_and_query ctxt (learn_of_graph access_G)
- else MaSh_Py.query ctxt overlord)
- max_facts ([], hints, parents, feats))
+ if Config.get ctxt sml then
+ MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats
+ else
+ MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats))
end)
val unknown = filter_out (is_fact_in_graph access_G o snd) facts
in
@@ -1035,27 +1216,27 @@
|> pairself (map fact_of_raw_fact)
end
-fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
+fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
let
- fun maybe_learn_from from (accum as (parents, graph)) =
- try_graph ctxt "updating graph" accum (fn () =>
- (from :: parents, Graph.add_edge_acyclic (from, name) graph))
- val graph = graph |> Graph.default_node (name, (Isar_Proof, feats, deps))
- val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
- val (deps, _) = ([], graph) |> fold maybe_learn_from deps
+ fun maybe_learn_from from (accum as (parents, G)) = (* tries to add the edge from -> name *)
+ try_graph ctxt "updating graph" accum (fn () =>
+ (from :: parents, Graph.add_edge_acyclic (from, name) G))
+ val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
+ val (parents, G) = ([], G) |> fold maybe_learn_from parents (* graph keeps the parent edges *)
+ val (deps, _) = ([], G) |> fold maybe_learn_from deps (* only the dep list is kept; this graph is dropped *)
in
- ((name, parents, feats, deps) :: learns, graph)
+ ((name, parents, feats, deps) :: learns, G) (* prepends the learn record for name *)
end
-fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
+fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
let
- fun maybe_relearn_from from (accum as (parents, graph)) =
+ fun maybe_relearn_from from (accum as (parents, G)) =
try_graph ctxt "updating graph" accum (fn () =>
- (from :: parents, Graph.add_edge_acyclic (from, name) graph))
- val graph = graph |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
- val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
+ (from :: parents, Graph.add_edge_acyclic (from, name) G))
+ val G = G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps)) (* feats are kept *)
+ val (deps, _) = ([], G) |> fold maybe_relearn_from deps (* only the dep list is kept; this graph is dropped *)
in
- ((name, deps) :: relearns, graph)
+ ((name, deps) :: relearns, G)
end
fun flop_wrt_access_graph name =