# HG changeset patch # User blanchet # Date 1400535833 -7200 # Node ID 8cb6a5f1ae8435712b4b2dfae9acdbb212e09584 # Parent 10f68b83b474ea3cc7f7d02cdc3fe6d47829416e added SML implementation of MaSh diff -r 10f68b83b474 -r 8cb6a5f1ae84 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Mon May 19 23:43:53 2014 +0200 @@ -1,5 +1,6 @@ (* Title: HOL/Tools/Sledgehammer/sledgehammer_mash.ML Author: Jasmin Blanchette, TU Muenchen + Author: Cezary Kaliszyk, University of Innsbruck Sledgehammer's machine-learning-based relevance filter (MaSh). *) @@ -253,9 +254,7 @@ fun learn _ _ _ [] = () | learn ctxt overlord save learns = (trace_msg ctxt (fn () => - let val names = elide_string 1000 (space_implode " " (map #1 learns)) in - "MaSh_Py learn" ^ (if names = "" then "" else " " ^ names) - end); + "MaSh_Py learn {" ^ elide_string 1000 (space_implode " " (map #1 learns)) ^ "}"); run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn) (K ())) @@ -280,15 +279,203 @@ structure MaSh_SML = struct -fun learn_and_query ctxt (learns : (string * string list * string list list * string list) list) - max_suggs (query as (_, _, _, feats)) = - (trace_msg ctxt (fn () => - let val names = elide_string 1000 (space_implode " " (map #1 learns)) in - "MaSh_SML learn" ^ (if names = "" then "" else " " ^ names) ^ "\n" ^ - "MaSh_SML query " ^ encode_features feats - end); - (* Implementation missing *) - []) +fun max a b = if a > b then a else b + +exception BOTTOM of int + +fun heap cmp bnd a = + let + fun maxson l i = + let + val i31 = i + i + i + 1 + in + if i31 + 2 < l then + let + val x = Unsynchronized.ref i31; + val () = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else (); + val () = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else () + in + !x + end + else + if i31 + 1 < l andalso cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS + then i31 + 1 else if i31 < l then i31 else raise BOTTOM i + end + + fun trickledown l i e = + let + val j = maxson l i + in + if cmp (Array.sub (a, j), e) = GREATER then + let val () = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end + else Array.update (a, i, e) + end + + fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e) + + fun bubbledown l i = + let + val j = maxson l i + val () = Array.update (a, i, Array.sub (a, j)) + in + bubbledown l j + end + + fun bubble l i = bubbledown l i handle BOTTOM i => i + + fun trickleup i e = + let + val father = (i - 1) div 3 + in + if cmp (Array.sub (a, father), e) = LESS then + let + val () = Array.update (a, i, Array.sub (a, father)) + in + if father > 0 then trickleup father e else Array.update (a, 0, e) + end + else Array.update (a, i, e) + end + + val l = Array.length a + + fun for i = + if i < 0 then () else + let + val _ = trickle l i (Array.sub (a, i)) + in + for (i - 1) + end + + val () = for (((l + 1) div 3) - 1) + + fun for2 i = + if i < max 2 (l - bnd) then () else + let + val e = Array.sub (a, i) + val () = Array.update (a, i, Array.sub (a, 0)) + val () = trickleup (bubble i 0) e + in + for2 (i - 1) + end + + val () = for2 (l - 1) + in + if l > 1 then + let + val e = Array.sub (a, 1) + val () = Array.update (a, 1, Array.sub (a, 0)) + in + Array.update (a, 0, e) + end + else () + end + +(* + avail_no = maximum number of theorems to check dependencies and symbols + get_deps = returns dependencies of a theorem + get_sym_ths = get theorems that have this feature + knns = number of nearest neighbours + advno = number of predictions to return + syms = symbols of the conjecture +*) +fun knn avail_no get_deps get_sym_ths knns advno syms = + let + (* Can be later used for TFIDF *) + fun sym_wght _ = 1.0; + val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0))); + fun inc_overlap j v = + let + val ov = snd (Array.sub (overlaps_sqr,j)) + in + Array.update (overlaps_sqr, j, (j, v + ov)) + end; + fun do_sym (s, con_wght) = + let + val sw = sym_wght s; + val w2 = sw * sw * con_wght; + fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else () + in + ignore (map do_th (get_sym_ths s)) + end; + val () = ignore (map do_sym syms); + val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr; + val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0))); + fun inc_recommend j v = + let + val ov = snd (Array.sub (recommends,j)) + in + Array.update (recommends, j, (j, v + ov)) + end; + fun for k = + if k = knns then () else + if k >= avail_no then () else + let + val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1); + val o1 = Math.sqrt o2; + val () = inc_recommend j o1; + val ds = get_deps j; + val l = Real.fromInt (length ds); + val _ = map (fn d => inc_recommend d (o1 / l)) ds + in + for (k + 1) + end; + val () = for 0; + val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends; + fun ret acc at = + if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1) + in + ret [] (max 0 (avail_no - advno)) + end + +val knns = 40 (* FUDGE *) + +fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) + +fun map_array_index ary f i = Array.update (ary, i, f (Array.sub (ary, i))) + +(* TODO: take weight components of "feats" into consideration *) +fun learn_and_query ctxt parents access_G max_suggs hints feats = + let + val str_of_feat = space_implode "|" + + val (depss0, featss, (_, _, facts0), (num_feats, feat_tab, _)) = + fold_rev (fn fact => fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) => + let + val (_, feats, deps) = Graph.get_node access_G fact + + fun add_feat feat (xtab as (n, tab, _)) = + (case Symtab.lookup tab feat of + SOME i => (i, xtab) + | NONE => (n, add_to_xtab feat xtab)) + + val (feats', feat_xtab') = fold_map (add_feat o str_of_feat) feats feat_xtab + in + (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss, + add_to_xtab fact fact_xtab, feat_xtab') + end) + (Graph.all_preds access_G parents) ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) + + val facts = rev facts0 + val fact_ary = Array.fromList facts + + val deps_ary = Array.fromList (rev depss0) + val facts_ary = Array.array (num_feats, []) + val _ = + fold (fn feats => fn fact => + let val fact' = fact - 1 in + (List.app (map_array_index facts_ary (cons fact')) feats; fact') + end) + featss (length featss) + in + (trace_msg ctxt (fn () => + "MaSh_SML query " ^ encode_features feats ^ " from {" ^ + elide_string 1000 (space_implode " " facts) ^ "}"); + knn (Array.length deps_ary) (curry Array.sub deps_ary) + (map (rpair 1.0) (* FIXME *) o curry Array.sub facts_ary) knns max_suggs + (map_filter (fn (feat, weight) => + Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats) + |> map ((fn i => Array.sub (fact_ary, i)) o fst)) + end end; @@ -574,7 +761,7 @@ fun sort_of_type alg T = let - val graph = Sorts.classes_of alg + val G = Sorts.classes_of alg fun cls_of S [] = S | cls_of S (cl :: cls) = @@ -585,7 +772,7 @@ cls_of S (union (op =) cls' cls) end in - cls_of [] (Graph.maximals graph) + cls_of [] (Graph.maximals G) end val generalize_goal = false @@ -975,13 +1162,6 @@ fun add_const_counts t = fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t []) -fun learn_of_graph graph = - let - fun sched parents (name, (kind, feats, deps)) = (name, map fst parents, feats, deps) - in - Graph.schedule sched graph - end - fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts = let val thy = Proof_Context.theory_of ctxt @@ -1025,9 +1205,10 @@ |> map (nickname_of_thm o snd) in (access_G, - (if Config.get ctxt sml then MaSh_SML.learn_and_query ctxt (learn_of_graph access_G) - else MaSh_Py.query ctxt overlord) - max_facts ([], hints, parents, feats)) + if Config.get ctxt sml then + MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats + else + MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)) end) val unknown = filter_out (is_fact_in_graph access_G o snd) facts in @@ -1035,27 +1216,27 @@ |> pairself (map fact_of_raw_fact) end -fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) = +fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) = let - fun maybe_learn_from from (accum as (parents, graph)) = - try_graph ctxt "updating graph" accum (fn () => - (from :: parents, Graph.add_edge_acyclic (from, name) graph)) - val graph = graph |> Graph.default_node (name, (Isar_Proof, feats, deps)) - val (parents, graph) = ([], graph) |> fold maybe_learn_from parents - val (deps, _) = ([], graph) |> fold maybe_learn_from deps + fun maybe_learn_from from (accum as (parents, G)) = + try_graph ctxt "updating G" accum (fn () => + (from :: parents, Graph.add_edge_acyclic (from, name) G)) + val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps)) + val (parents, G) = ([], G) |> fold maybe_learn_from parents + val (deps, _) = ([], G) |> fold maybe_learn_from deps in - ((name, parents, feats, deps) :: learns, graph) + ((name, parents, feats, deps) :: learns, G) end -fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) = +fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) = let - fun maybe_relearn_from from (accum as (parents, graph)) = + fun maybe_relearn_from from (accum as (parents, G)) = try_graph ctxt "updating graph" accum (fn () => - (from :: parents, Graph.add_edge_acyclic (from, name) graph)) - val graph = graph |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps)) - val (deps, _) = ([], graph) |> fold maybe_relearn_from deps + (from :: parents, Graph.add_edge_acyclic (from, name) G)) + val G = G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps)) + val (deps, _) = ([], G) |> fold maybe_relearn_from deps in - ((name, deps) :: relearns, graph) + ((name, deps) :: relearns, G) end fun flop_wrt_access_graph name =