# HG changeset patch # User blanchet # Date 1376991772 -7200 # Node ID 667717a5ad806b0b5140aa7d956e5d3098021339 # Parent e33d77814a9212125c17f0e38990f876f6b0e0bd learn MaSh facts on the fly diff -r e33d77814a92 -r 667717a5ad80 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Aug 20 11:42:51 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Aug 20 11:42:52 2013 +0200 @@ -43,7 +43,7 @@ val relearn : Proof.context -> bool -> (string * string list) list -> unit val query : - Proof.context -> bool -> bool -> int + Proof.context -> bool -> int -> (string * string list * (string * real) list * string list) list * string list * string list * (string * real) list -> string list @@ -71,6 +71,8 @@ Proof.context -> params -> string -> int -> raw_fact list -> string Symtab.table * string Symtab.table -> thm -> bool * string list + val attach_parents_to_facts : + ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list val weight_mepo_facts : 'a list -> ('a * real) list val weight_mash_facts : 'a list -> ('a * real) list val find_mash_suggestions : @@ -82,8 +84,6 @@ val mash_learn_proof : Proof.context -> params -> string -> term -> ('a * thm) list -> thm list -> unit - val attach_parents_to_facts : - ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit val is_mash_enabled : unit -> bool @@ -227,16 +227,10 @@ fun str_of_relearn (name, deps) = "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n" -fun str_of_query learn (learns, hints, parents, feats) = - (if learn then - implode (map str_of_learn learns) ^ - (if null hints then "" - else str_of_learn (freshish_name (), parents, feats, hints)) - else - "") ^ +fun str_of_query (learns, hints, parents, feats) = + implode (map str_of_learn learns) ^ "? " ^ encode_strs parents ^ "; " ^ encode_features feats ^ - (if learn orelse null hints then "" else "; " ^ encode_strs hints) ^ - "\n" + (if null hints then "" else "; " ^ encode_strs hints) ^ "\n" (* The weights currently returned by "mash.py" are too spaced out to make any sense. *) @@ -277,11 +271,9 @@ elide_string 1000 (space_implode " " (map #1 relearns))); run_mash_tool ctxt overlord true 0 (relearns, str_of_relearn) (K ())) -fun query ctxt overlord learn max_suggs (query as (learns, hints, _, feats)) = +fun query ctxt overlord max_suggs (query as (learns, hints, _, feats)) = (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats); - run_mash_tool ctxt overlord - (learn andalso not (null learns) andalso not (null hints)) - max_suggs ([query], str_of_query learn) + run_mash_tool ctxt overlord false max_suggs ([query], str_of_query) (fn suggs => case suggs () of [] => [] @@ -335,15 +327,18 @@ string_of_int (length (Graph.minimals G)) ^ " minimal, " ^ string_of_int (length (Graph.maximals G)) ^ " maximal" -type mash_state = {access_G : unit Graph.T, dirty : string list option} +type mash_state = + {access_G : unit Graph.T, + num_known_facts : int, + dirty : string list option} -val empty_state = {access_G = Graph.empty, dirty = SOME []} +val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []} local -val version = "*** MaSh version 20130819a ***" +val version = "*** MaSh version 20130820 ***" -exception Too_New of unit +exception FILE_VERSION_TOO_NEW of unit fun extract_node line = case space_explode ":" line of @@ -371,24 +366,27 @@ | SOME (name, parents, kind) => update_access_graph_node (name, kind) #> fold (add_edge_to name) parents - val access_G = + val (access_G, num_known_facts) = case string_ord (version', version) of EQUAL => - try_graph ctxt "loading state" Graph.empty (fn () => - fold add_node node_lines Graph.empty) + (try_graph ctxt "loading state" Graph.empty (fn () => + fold add_node node_lines Graph.empty), + length node_lines) | LESS => - (MaSh.unlearn ctxt; Graph.empty) (* can't parse old file *) - | GREATER => raise Too_New () + (* can't parse old file *) + (MaSh.unlearn ctxt; (Graph.empty, 0)) + | GREATER => raise FILE_VERSION_TOO_NEW () in trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")"); - {access_G = access_G, dirty = SOME []} + {access_G = access_G, num_known_facts = num_known_facts, + dirty = SOME []} end | _ => empty_state) end fun save _ (state as {dirty = SOME [], ...}) = state - | save ctxt {access_G, dirty} = + | save ctxt {access_G, num_known_facts, dirty} = let fun str_of_entry (name, parents, kind) = str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^ @@ -408,7 +406,7 @@ SOME dirty => "; " ^ string_of_int (length dirty) ^ " dirty fact(s)" | _ => "") ^ ")"); - {access_G = access_G, dirty = SOME []} + {access_G = access_G, num_known_facts = num_known_facts, dirty = SOME []} end val global_state = @@ -418,7 +416,7 @@ fun map_state ctxt f = Synchronized.change global_state (load ctxt ##> (f #> save ctxt)) - handle Too_New () => () + handle FILE_VERSION_TOO_NEW () => () fun peek_state ctxt f = Synchronized.change_result global_state @@ -723,6 +721,9 @@ | NONE => false) | is_size_def _ _ = false +fun no_dependencies_for_status status = + status = Non_Rec_Def orelse status = Rec_Def + fun trim_dependencies deps = if length deps > max_dependencies then NONE else SOME deps @@ -790,159 +791,9 @@ (*** High-level communication with MaSh ***) -fun maximal_wrt_graph G keys = - let - val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys - fun insert_new seen name = - not (Symtab.defined seen name) ? insert (op =) name - fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0 - fun find_maxes _ (maxs, []) = map snd maxs - | find_maxes seen (maxs, new :: news) = - find_maxes - (seen |> num_keys (Graph.imm_succs G new) > 1 - ? Symtab.default (new, ())) - (if Symtab.defined tab new then - let - val newp = Graph.all_preds G [new] - fun is_ancestor x yp = member (op =) yp x - val maxs = - maxs |> filter (fn (_, max) => not (is_ancestor max newp)) - in - if exists (is_ancestor new o fst) maxs then - (maxs, news) - else - ((newp, new) - :: filter_out (fn (_, max) => is_ancestor max newp) maxs, - news) - end - else - (maxs, Graph.Keys.fold (insert_new seen) - (Graph.imm_preds G new) news)) - in find_maxes Symtab.empty ([], Graph.maximals G) end - -fun maximal_wrt_access_graph access_G = - map (nickname_of_thm o snd) - #> maximal_wrt_graph access_G - -fun is_fact_in_graph access_G get_th fact = - can (Graph.get_node access_G) (nickname_of_thm (get_th fact)) - -(* FUDGE *) -fun weight_of_mepo_fact rank = - Math.pow (0.62, log2 (Real.fromInt (rank + 1))) - -fun weight_mepo_facts facts = - facts ~~ map weight_of_mepo_fact (0 upto length facts - 1) - -val weight_raw_mash_facts = weight_mepo_facts -val weight_mash_facts = weight_raw_mash_facts - -(* FUDGE *) -fun weight_of_proximity_fact rank = - Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0 - -fun weight_proximity_facts facts = - facts ~~ map weight_of_proximity_fact (0 upto length facts - 1) - -val max_proximity_facts = 100 - -fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown) - | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown = - let - val raw_mash = find_suggested_facts ctxt facts suggs - val unknown_chained = - inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown - val proximity = - facts |> sort (crude_thm_ord o pairself snd o swap) - |> take max_proximity_facts - val mess = - [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])), - (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)), - (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))] - val unknown = - raw_unknown - |> fold (subtract (Thm.eq_thm_prop o pairself snd)) - [unknown_chained, proximity] - in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end - -fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts - hyp_ts concl_t facts = - let - val thy = Proof_Context.theory_of ctxt - val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained) - val (access_G, suggs) = - peek_state ctxt (fn {access_G, ...} => - if Graph.is_empty access_G then - (access_G, []) - else - let - val parents = maximal_wrt_access_graph access_G facts - val feats = - features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts) - val hints = - chained |> filter (is_fact_in_graph access_G snd) - |> map (nickname_of_thm o snd) - in - (access_G, MaSh.query ctxt overlord learn max_facts - ([], hints, parents, feats)) - end) - val unknown = facts |> filter_out (is_fact_in_graph access_G snd) - in - find_mash_suggestions ctxt max_facts suggs facts chained unknown - |> pairself (map fact_of_raw_fact) - end - -fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) = - let - fun maybe_learn_from from (accum as (parents, graph)) = - try_graph ctxt "updating graph" accum (fn () => - (from :: parents, Graph.add_edge_acyclic (from, name) graph)) - val graph = graph |> Graph.default_node (name, Isar_Proof) - val (parents, graph) = ([], graph) |> fold maybe_learn_from parents - val (deps, _) = ([], graph) |> fold maybe_learn_from deps - in ((name, parents, feats, deps) :: learns, graph) end - -fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) = - let - fun maybe_relearn_from from (accum as (parents, graph)) = - try_graph ctxt "updating graph" accum (fn () => - (from :: parents, Graph.add_edge_acyclic (from, name) graph)) - val graph = graph |> update_access_graph_node (name, Automatic_Proof) - val (deps, _) = ([], graph) |> fold maybe_relearn_from deps - in ((name, deps) :: relearns, graph) end - -fun flop_wrt_access_graph name = - update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop) - -val learn_timeout_slack = 2.0 - -fun launch_thread timeout task = - let - val hard_timeout = time_mult learn_timeout_slack timeout - val birth_time = Time.now () - val death_time = Time.+ (birth_time, hard_timeout) - val desc = ("Machine learner for Sledgehammer", "") - in Async_Manager.thread MaShN birth_time death_time desc task end - -fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts - used_ths = - launch_thread (timeout |> the_default one_day) (fn () => - let - val thy = Proof_Context.theory_of ctxt - val name = freshish_name () - val feats = features_of ctxt prover thy (Local, General) [t] - in - peek_state ctxt (fn {access_G, ...} => - let - val parents = maximal_wrt_access_graph access_G facts - val deps = - used_ths |> filter (is_fact_in_graph access_G I) - |> map nickname_of_thm - in - MaSh.learn ctxt overlord [(name, parents, feats, deps)] - end); - (true, "") - end) +fun attach_crude_parents_to_facts _ [] = [] + | attach_crude_parents_to_facts parents ((fact as (_, th)) :: facts) = + (parents, fact) :: attach_crude_parents_to_facts [nickname_of_thm th] facts (* In the following functions, chunks are risers w.r.t. "thm_less_eq". *) @@ -994,6 +845,191 @@ |> drop (length old_facts) end +fun maximal_wrt_graph G keys = + let + val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys + fun insert_new seen name = + not (Symtab.defined seen name) ? insert (op =) name + fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0 + fun find_maxes _ (maxs, []) = map snd maxs + | find_maxes seen (maxs, new :: news) = + find_maxes + (seen |> num_keys (Graph.imm_succs G new) > 1 + ? Symtab.default (new, ())) + (if Symtab.defined tab new then + let + val newp = Graph.all_preds G [new] + fun is_ancestor x yp = member (op =) yp x + val maxs = + maxs |> filter (fn (_, max) => not (is_ancestor max newp)) + in + if exists (is_ancestor new o fst) maxs then + (maxs, news) + else + ((newp, new) + :: filter_out (fn (_, max) => is_ancestor max newp) maxs, + news) + end + else + (maxs, Graph.Keys.fold (insert_new seen) + (Graph.imm_preds G new) news)) + in find_maxes Symtab.empty ([], Graph.maximals G) end + +fun maximal_wrt_access_graph access_G = + map (nickname_of_thm o snd) + #> maximal_wrt_graph access_G + +fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm + +(* FUDGE *) +fun weight_of_mepo_fact rank = + Math.pow (0.62, log2 (Real.fromInt (rank + 1))) + +fun weight_mepo_facts facts = + facts ~~ map weight_of_mepo_fact (0 upto length facts - 1) + +val weight_raw_mash_facts = weight_mepo_facts +val weight_mash_facts = weight_raw_mash_facts + +(* FUDGE *) +fun weight_of_proximity_fact rank = + Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0 + +fun weight_proximity_facts facts = + facts ~~ map weight_of_proximity_fact (0 upto length facts - 1) + +val max_proximity_facts = 100 + +fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown) + | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown = + let + val raw_mash = find_suggested_facts ctxt facts suggs + val unknown_chained = + inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown + val proximity = + facts |> sort (crude_thm_ord o pairself snd o swap) + |> take max_proximity_facts + val mess = + [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])), + (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)), + (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))] + val unknown = + raw_unknown + |> fold (subtract (Thm.eq_thm_prop o pairself snd)) + [unknown_chained, proximity] + in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end + +val max_learn_on_query = 500 + +fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts + hyp_ts concl_t facts = + let + val thy = Proof_Context.theory_of ctxt + val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained) + val (access_G, suggs) = + peek_state ctxt (fn {access_G, num_known_facts, ...} => + if Graph.is_empty access_G then + (access_G, []) + else + let + val parents = maximal_wrt_access_graph access_G facts + val feats = + features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts) + val hints = + chained |> filter (is_fact_in_graph access_G o snd) + |> map (nickname_of_thm o snd) + val (learns, parents) = + if length facts - num_known_facts <= max_learn_on_query then + let + val name_tabs = build_name_tables nickname_of_thm facts + fun deps_of status th = + if no_dependencies_for_status status then + SOME [] + else + isar_dependencies_of name_tabs th + |> trim_dependencies + fun learn_new_fact (parents, + ((_, stature as (_, status)), th)) = + let + val name = nickname_of_thm th + val feats = + features_of ctxt prover (theory_of_thm th) stature + [prop_of th] + val deps = deps_of status th |> these + in (name, parents, feats, deps) end + val new_facts = + facts |> filter_out (is_fact_in_graph access_G o snd) + |> sort (crude_thm_ord o pairself snd) + |> attach_crude_parents_to_facts parents + val learns = new_facts |> map learn_new_fact + val parents = + if null new_facts then parents + else [#1 (List.last learns)] + in (learns, parents) end + else + ([], parents) + in + (access_G, MaSh.query ctxt overlord max_facts + (learns, hints, parents, feats)) + end) + val unknown = facts |> filter_out (is_fact_in_graph access_G o snd) + in + find_mash_suggestions ctxt max_facts suggs facts chained unknown + |> pairself (map fact_of_raw_fact) + end + +fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) = + let + fun maybe_learn_from from (accum as (parents, graph)) = + try_graph ctxt "updating graph" accum (fn () => + (from :: parents, Graph.add_edge_acyclic (from, name) graph)) + val graph = graph |> Graph.default_node (name, Isar_Proof) + val (parents, graph) = ([], graph) |> fold maybe_learn_from parents + val (deps, _) = ([], graph) |> fold maybe_learn_from deps + in ((name, parents, feats, deps) :: learns, graph) end + +fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) = + let + fun maybe_relearn_from from (accum as (parents, graph)) = + try_graph ctxt "updating graph" accum (fn () => + (from :: parents, Graph.add_edge_acyclic (from, name) graph)) + val graph = graph |> update_access_graph_node (name, Automatic_Proof) + val (deps, _) = ([], graph) |> fold maybe_relearn_from deps + in ((name, deps) :: relearns, graph) end + +fun flop_wrt_access_graph name = + update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop) + +val learn_timeout_slack = 2.0 + +fun launch_thread timeout task = + let + val hard_timeout = time_mult learn_timeout_slack timeout + val birth_time = Time.now () + val death_time = Time.+ (birth_time, hard_timeout) + val desc = ("Machine learner for Sledgehammer", "") + in Async_Manager.thread MaShN birth_time death_time desc task end + +fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts + used_ths = + launch_thread (timeout |> the_default one_day) (fn () => + let + val thy = Proof_Context.theory_of ctxt + val name = freshish_name () + val feats = features_of ctxt prover thy (Local, General) [t] + in + peek_state ctxt (fn {access_G, ...} => + let + val parents = maximal_wrt_access_graph access_G facts + val deps = + used_ths |> filter (is_fact_in_graph access_G) + |> map nickname_of_thm + in + MaSh.learn ctxt overlord [(name, parents, feats, deps)] + end); + (true, "") + end) + fun sendback sub = Active.sendback_markup [Markup.padding_command] (sledgehammerN ^ " " ^ sub) @@ -1007,7 +1043,7 @@ fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout) val {access_G, ...} = peek_state ctxt I - val is_in_access_G = is_fact_in_graph access_G snd + val is_in_access_G = is_fact_in_graph access_G o snd val no_new_facts = forall is_in_access_G facts in if no_new_facts andalso not run_prover then @@ -1025,7 +1061,7 @@ let val name_tabs = build_name_tables nickname_of_thm facts fun deps_of status th = - if status = Non_Rec_Def orelse status = Rec_Def then + if no_dependencies_for_status status then SOME [] else if run_prover then prover_dependencies_of ctxt params prover auto_level facts name_tabs @@ -1036,7 +1072,7 @@ isar_dependencies_of name_tabs th |> trim_dependencies fun do_commit [] [] [] state = state - | do_commit learns relearns flops {access_G, dirty} = + | do_commit learns relearns flops {access_G, num_known_facts, dirty} = let val was_empty = Graph.is_empty access_G val (learns, access_G) = @@ -1044,6 +1080,7 @@ val (relearns, access_G) = ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns val access_G = access_G |> fold flop_wrt_access_graph flops + val num_known_facts = num_known_facts + length learns val dirty = case (was_empty, dirty, relearns) of (false, SOME names, []) => SOME (map #1 learns @ names) @@ -1051,7 +1088,8 @@ in MaSh.learn ctxt overlord (rev learns); MaSh.relearn ctxt overlord relearns; - {access_G = access_G, dirty = dirty} + {access_G = access_G, num_known_facts = num_known_facts, + dirty = dirty} end fun commit last learns relearns flops = (if debug andalso auto_level = 0 then