# HG changeset patch # User blanchet # Date 1400571519 -7200 # Node ID ed95456499e6714fad842aaefa228dd1fb25115c # Parent 43fd82a537a3f36b974fa419654a3cfb56350de1 better way to take invisible facts into account than 'island' business diff -r 43fd82a537a3 -r ed95456499e6 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 02:47:23 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 09:38:39 2014 +0200 @@ -108,12 +108,6 @@ val relearn_isarN = "relearn_isar" val relearn_proverN = "relearn_prover" -val learned_proof_prefix = ".." - -fun learned_proof_name () = - learned_proof_prefix ^ Date.fmt "%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^ - serial_string () - fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state") @@ -382,68 +376,62 @@ end (* - avail_no = maximum number of theorems to check dependencies and symbols + avail_num = maximum number of theorems to check dependencies and symbols + adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num get_deps = returns dependencies of a theorem get_sym_ths = get theorems that have this feature - knns = number of nearest neighbours - advno = number of predictions to return - syms = symbols of the conjecture + knns = number of nearest neighbours + advno = number of predictions to return + syms = symbols of the conjecture *) -fun knn avail_no get_deps get_sym_ths knns advno syms = +fun knn avail_num adv_max get_deps get_sym_ths knns advno syms = let (* Can be later used for TFIDF *) - fun sym_wght _ = 1.0 - - val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0))) - + fun sym_wght _ = 1.0; + val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0))); fun inc_overlap j v = let val ov = snd (Array.sub (overlaps_sqr,j)) in Array.update (overlaps_sqr, j, (j, v + ov)) - end - + end; fun do_sym (s, con_wght) = let - val sw = sym_wght s - val w2 = sw * sw * con_wght - fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else () + val sw = sym_wght s; + val w2 = sw * sw * con_wght; + fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else () in ignore (map do_th (get_sym_ths s)) - end - - val _ = ignore (map do_sym syms) - val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr - val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0))) - + end; + val () = ignore (map do_sym syms); + val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr; + val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0))); fun inc_recommend j v = + if j >= adv_max then () else let val ov = snd (Array.sub (recommends,j)) in Array.update (recommends, j, (j, v + ov)) - end - + end; fun for k = if k = knns then () else - if k >= avail_no then () else + if k >= adv_max then () else let - val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1) - val o1 = Math.sqrt o2 - val _ = inc_recommend j o1 - val ds = get_deps j - val l = Real.fromInt (length ds) + val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1); + val o1 = Math.sqrt o2; + val () = inc_recommend j o1; + val ds = get_deps j; + val l = Real.fromInt (length ds); val _ = map (fn d => inc_recommend d (o1 / l)) ds in for (k + 1) - end - - val _ = for 0 - val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends - + end; + val () = for 0; + val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends; fun ret acc at = if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1) in - ret [] (max 0 (avail_no - advno)) + ret [] (max 0 (adv_max - advno)) end val knns = 40 (* FUDGE *) @@ -456,11 +444,18 @@ let val str_of_feat = space_implode "|" - val (depss0, featss, (_, _, facts0), (num_feats, feat_tab, _)) = - fold_rev (fn fact => fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) => + val visible_facts = Graph.all_preds access_G parents + val visible_fact_set = Symtab.make_set visible_facts + + val all_nodes = + Graph.schedule (K I) access_G + |> List.partition (Symtab.defined visible_fact_set o fst) + |> op @ + + val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) = + fold (fn (fact, (_, feats, deps)) => + fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) => let - val (_, feats, deps) = Graph.get_node access_G fact - fun add_feat (feat, weight) (xtab as (n, tab, _)) = (case Symtab.lookup tab feat of SOME i => ((i, weight), xtab) @@ -471,12 +466,12 @@ (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss, add_to_xtab fact fact_xtab, feat_xtab') end) - (Graph.all_preds access_G parents) ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) + all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) - val facts = rev facts0 + val facts = rev rev_facts val fact_ary = Array.fromList facts - val deps_ary = Array.fromList (rev depss0) + val deps_ary = Array.fromList (rev rev_depss) val facts_ary = Array.array (num_feats, []) val _ = fold (fn feats => fn fact => @@ -487,15 +482,13 @@ end) featss (length featss) in - trace_msg ctxt (fn () => - "MaSh_SML query " ^ encode_features feats ^ " from {" ^ - elide_string 1000 (space_implode " " facts) ^ "}"); - knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns - max_suggs + trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^ + elide_string 1000 (space_implode " " facts) ^ "}"); + knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary) + (curry Array.sub facts_ary) knns max_suggs (map_filter (fn (feat, weight) => Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats) |> map ((fn i => Array.sub (fact_ary, i)) o fst) - |> filter_out (String.isPrefix learned_proof_prefix) end end; @@ -517,9 +510,10 @@ Graph.default_node (parent, (Isar_Proof, [], [])) #> Graph.add_edge (parent, name) -fun add_node kind name feats deps = +fun add_node kind name parents feats deps = Graph.default_node (name, (kind, feats, deps)) #> Graph.map_node name (K (kind, feats, deps)) + #> fold (add_edge_to name) parents fun try_graph ctxt when def f = f () @@ -575,9 +569,7 @@ fun extract_line_and_add_node line = (case extract_node line of NONE => I (* should not happen *) - | SOME (kind, name, parents, feats, deps) => - add_node kind name feats deps - #> fold (add_edge_to name) parents) + | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps) val (access_G, num_known_facts) = (case string_ord (version', version) of @@ -1132,13 +1124,8 @@ find_maxes Symtab.empty ([], Graph.maximals G) end -fun graph_islands G = - Graph.fold (fn (m, (_, (preds, succs))) => - (Graph.Keys.is_empty preds andalso Graph.Keys.is_empty succs) ? cons m) G []; - -(* islands represent learned proofs associated with no facts *) fun maximal_wrt_access_graph access_G facts = - map (nickname_of_thm o snd) facts @ graph_islands access_G + map (nickname_of_thm o snd) facts |> maximal_wrt_graph access_G fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm @@ -1240,7 +1227,7 @@ fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) = let fun maybe_learn_from from (accum as (parents, G)) = - try_graph ctxt "updating G" accum (fn () => + try_graph ctxt "updating graph" accum (fn () => (from :: parents, Graph.add_edge_acyclic (from, name) G)) val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps)) val (parents, G) = ([], G) |> fold maybe_learn_from parents @@ -1275,6 +1262,9 @@ Async_Manager.thread MaShN birth_time death_time desc task end +fun learned_proof_name () = + Date.fmt ".%Y%m%d.%H%M%S." (Date.fromTimeLocal (Time.now ())) ^ serial_string () + fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths = if is_mash_enabled () then launch_thread timeout (fn () => @@ -1285,19 +1275,18 @@ map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} => let val name = learned_proof_name () + val parents = maximal_wrt_access_graph access_G facts val deps = used_ths |> filter (is_fact_in_graph access_G) |> map nickname_of_thm in if Config.get ctxt sml then - let val access_G = access_G |> add_node Automatic_Proof name feats deps in + let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in {access_G = access_G, num_known_facts = num_known_facts + 1, dirty = Option.map (cons name) dirty} end else - let val parents = maximal_wrt_access_graph access_G facts in - (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state) - end + (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state) end); (true, "") end)