# HG changeset patch # User blanchet # Date 1400546843 -7200 # Node ID 43fd82a537a3f36b974fa419654a3cfb56350de1 # Parent a4428f517f468a7a2b0e171c464421d1e978bb2e cleaner handling of learned proofs diff -r a4428f517f46 -r 43fd82a537a3 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 00:13:31 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 02:47:23 2014 +0200 @@ -108,6 +108,12 @@ val relearn_isarN = "relearn_isar" val relearn_proverN = "relearn_prover" +val learned_proof_prefix = ".." + +fun learned_proof_name () = + learned_proof_prefix ^ Date.fmt "%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^ + serial_string () + fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state") @@ -297,8 +303,8 @@ if i31 + 2 < l then let val x = Unsynchronized.ref i31; - val () = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else (); - val () = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else () + val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else (); + val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else () in !x end @@ -312,7 +318,7 @@ val j = maxson l i in if cmp (Array.sub (a, j), e) = GREATER then - let val () = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end + let val _ = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end else Array.update (a, i, e) end @@ -321,7 +327,7 @@ fun bubbledown l i = let val j = maxson l i - val () = Array.update (a, i, Array.sub (a, j)) + val _ = Array.update (a, i, Array.sub (a, j)) in bubbledown l j end @@ -334,7 +340,7 @@ in if cmp (Array.sub (a, father), e) = LESS then let - val () = Array.update (a, i, Array.sub (a, father)) + val _ = Array.update (a, i, Array.sub (a, father)) in if father > 0 then trickleup father e else Array.update (a, 0, e) end @@ -351,24 +357,24 @@ for (i - 1) end - val () = for (((l + 1) div 3) - 1) + val _ = for (((l + 1) div 3) - 1) fun for2 i = if i < max 2 (l - bnd) then () else let val e = Array.sub (a, i) - val () = Array.update (a, i, Array.sub (a, 0)) - val () = trickleup (bubble i 0) e + val _ = Array.update (a, i, Array.sub (a, 0)) + val _ = trickleup (bubble i 0) e in for2 (i - 1) end - val () = for2 (l - 1) + val _ = for2 (l - 1) in if l > 1 then let val e = Array.sub (a, 1) - val () = Array.update (a, 1, Array.sub (a, 0)) + val _ = Array.update (a, 1, Array.sub (a, 0)) in Array.update (a, 0, e) end @@ -386,46 +392,54 @@ fun knn avail_no get_deps get_sym_ths knns advno syms = let (* Can be later used for TFIDF *) - fun sym_wght _ = 1.0; - val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0))); + fun sym_wght _ = 1.0 + + val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0))) + fun inc_overlap j v = let val ov = snd (Array.sub (overlaps_sqr,j)) in Array.update (overlaps_sqr, j, (j, v + ov)) - end; + end + fun do_sym (s, con_wght) = let - val sw = sym_wght s; - val w2 = sw * sw * con_wght; + val sw = sym_wght s + val w2 = sw * sw * con_wght fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else () in ignore (map do_th (get_sym_ths s)) - end; - val () = ignore (map do_sym syms); - val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr; - val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0))); + end + + val _ = ignore (map do_sym syms) + val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr + val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0))) + fun inc_recommend j v = let val ov = snd (Array.sub (recommends,j)) in Array.update (recommends, j, (j, v + ov)) - end; + end + fun for k = if k = knns then () else if k >= avail_no then () else let - val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1); - val o1 = Math.sqrt o2; - val () = inc_recommend j o1; - val ds = get_deps j; - val l = Real.fromInt (length ds); + val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1) + val o1 = Math.sqrt o2 + val _ = inc_recommend j o1 + val ds = get_deps j + val l = Real.fromInt (length ds) val _ = map (fn d => inc_recommend d (o1 / l)) ds in for (k + 1) - end; - val () = for 0; - val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends; + end + + val _ = for 0 + val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends + fun ret acc at = if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1) in @@ -473,14 +487,15 @@ end) featss (length featss) in - (trace_msg ctxt (fn () => - "MaSh_SML query " ^ encode_features feats ^ " from {" ^ - elide_string 1000 (space_implode " " facts) ^ "}"); - knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns - max_suggs - (map_filter (fn (feat, weight) => - Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats) - |> map ((fn i => Array.sub (fact_ary, i)) o fst)) + trace_msg ctxt (fn () => + "MaSh_SML query " ^ encode_features feats ^ " from {" ^ + elide_string 1000 (space_implode " " facts) ^ "}"); + knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns + max_suggs + (map_filter (fn (feat, weight) => + Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats) + |> map ((fn i => Array.sub (fact_ary, i)) o fst) + |> filter_out (String.isPrefix learned_proof_prefix) end end; @@ -502,10 +517,9 @@ Graph.default_node (parent, (Isar_Proof, [], [])) #> Graph.add_edge (parent, name) -fun add_node kind name parents feats deps = +fun add_node kind name feats deps = Graph.default_node (name, (kind, feats, deps)) #> Graph.map_node name (K (kind, feats, deps)) - #> fold (add_edge_to name) parents; fun try_graph ctxt when def f = f () @@ -526,7 +540,6 @@ fun graph_info G = string_of_int (length (Graph.keys G)) ^ " node(s), " ^ string_of_int (fold (Integer.add o length o snd) (Graph.dest G) 0) ^ " edge(s), " ^ - string_of_int (length (Graph.minimals G)) ^ " minimal, " ^ string_of_int (length (Graph.maximals G)) ^ " maximal" type mash_state = @@ -562,7 +575,9 @@ fun extract_line_and_add_node line = (case extract_node line of NONE => I (* should not happen *) - | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps) + | SOME (kind, name, parents, feats, deps) => + add_node kind name feats deps + #> fold (add_edge_to name) parents) val (access_G, num_known_facts) = (case string_ord (version', version) of @@ -1095,40 +1110,36 @@ let val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys - fun insert_new seen name = - not (Symtab.defined seen name) ? insert (op =) name + fun insert_new seen name = not (Symtab.defined seen name) ? insert (op =) name fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0 fun find_maxes _ (maxs, []) = map snd maxs | find_maxes seen (maxs, new :: news) = - find_maxes - (seen |> num_keys (Graph.imm_succs G new) > 1 - ? Symtab.default (new, ())) - (if Symtab.defined tab new then - let - val newp = Graph.all_preds G [new] - fun is_ancestor x yp = member (op =) yp x - val maxs = - maxs |> filter (fn (_, max) => not (is_ancestor max newp)) - in - if exists (is_ancestor new o fst) maxs then - (maxs, news) - else - ((newp, new) - :: filter_out (fn (_, max) => is_ancestor max newp) maxs, - news) - end - else - (maxs, Graph.Keys.fold (insert_new seen) - (Graph.imm_preds G new) news)) + find_maxes (seen |> num_keys (Graph.imm_succs G new) > 1 ? Symtab.default (new, ())) + (if Symtab.defined tab new then + let + val newp = Graph.all_preds G [new] + fun is_ancestor x yp = member (op =) yp x + val maxs = maxs |> filter (fn (_, max) => not (is_ancestor max newp)) + in + if exists (is_ancestor new o fst) maxs then (maxs, news) + else ((newp, new) :: filter_out (fn (_, max) => is_ancestor max newp) maxs, news) + end + else + (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news)) in find_maxes Symtab.empty ([], Graph.maximals G) end -fun maximal_wrt_access_graph access_G = - map (nickname_of_thm o snd) - #> maximal_wrt_graph access_G +fun graph_islands G = + Graph.fold (fn (m, (_, (preds, succs))) => + (Graph.Keys.is_empty preds andalso Graph.Keys.is_empty succs) ? cons m) G []; + +(* islands represent learned proofs associated with no facts *) +fun maximal_wrt_access_graph access_G facts = + map (nickname_of_thm o snd) facts @ graph_islands access_G + |> maximal_wrt_graph access_G fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm @@ -1264,9 +1275,6 @@ Async_Manager.thread MaShN birth_time death_time desc task end -fun fresh_enough_name () = - Date.fmt ".%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^ serial_string () - fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths = if is_mash_enabled () then launch_thread timeout (fn () => @@ -1276,18 +1284,20 @@ in map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} => let - val name = fresh_enough_name () - val parents = maximal_wrt_access_graph access_G facts + val name = learned_proof_name () val deps = used_ths |> filter (is_fact_in_graph access_G) |> map nickname_of_thm in if Config.get ctxt sml then - {access_G = add_node Automatic_Proof name parents feats deps access_G, - num_known_facts = num_known_facts + 1, - dirty = Option.map (cons name) dirty} + let val access_G = access_G |> add_node Automatic_Proof name feats deps in + {access_G = access_G, num_known_facts = num_known_facts + 1, + dirty = Option.map (cons name) dirty} + end else - (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state) + let val parents = maximal_wrt_access_graph access_G facts in + (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state) + end end); (true, "") end)