--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Aug 20 11:42:51 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Aug 20 11:42:52 2013 +0200
@@ -43,7 +43,7 @@
val relearn :
Proof.context -> bool -> (string * string list) list -> unit
val query :
- Proof.context -> bool -> bool -> int
+ Proof.context -> bool -> int
-> (string * string list * (string * real) list * string list) list
* string list * string list * (string * real) list
-> string list
@@ -71,6 +71,8 @@
Proof.context -> params -> string -> int -> raw_fact list
-> string Symtab.table * string Symtab.table -> thm
-> bool * string list
+ val attach_parents_to_facts :
+ ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list
val weight_mepo_facts : 'a list -> ('a * real) list
val weight_mash_facts : 'a list -> ('a * real) list
val find_mash_suggestions :
@@ -82,8 +84,6 @@
val mash_learn_proof :
Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
-> unit
- val attach_parents_to_facts :
- ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list
val mash_learn :
Proof.context -> params -> fact_override -> thm list -> bool -> unit
val is_mash_enabled : unit -> bool
@@ -227,16 +227,10 @@
fun str_of_relearn (name, deps) =
"p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n"
-fun str_of_query learn (learns, hints, parents, feats) =
- (if learn then
- implode (map str_of_learn learns) ^
- (if null hints then ""
- else str_of_learn (freshish_name (), parents, feats, hints))
- else
- "") ^
+fun str_of_query (learns, hints, parents, feats) =
+ implode (map str_of_learn learns) ^
"? " ^ encode_strs parents ^ "; " ^ encode_features feats ^
- (if learn orelse null hints then "" else "; " ^ encode_strs hints) ^
- "\n"
+ (if null hints then "" else "; " ^ encode_strs hints) ^ "\n"
(* The weights currently returned by "mash.py" are too spaced out to make any
sense. *)
@@ -277,11 +271,9 @@
elide_string 1000 (space_implode " " (map #1 relearns)));
run_mash_tool ctxt overlord true 0 (relearns, str_of_relearn) (K ()))
-fun query ctxt overlord learn max_suggs (query as (learns, hints, _, feats)) =
+fun query ctxt overlord max_suggs (query as (learns, hints, _, feats)) =
(trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
- run_mash_tool ctxt overlord
- (learn andalso not (null learns) andalso not (null hints))
- max_suggs ([query], str_of_query learn)
+ run_mash_tool ctxt overlord false max_suggs ([query], str_of_query)
(fn suggs =>
case suggs () of
[] => []
@@ -335,15 +327,18 @@
string_of_int (length (Graph.minimals G)) ^ " minimal, " ^
string_of_int (length (Graph.maximals G)) ^ " maximal"
-type mash_state = {access_G : unit Graph.T, dirty : string list option}
+type mash_state =
+ {access_G : unit Graph.T,
+ num_known_facts : int, (* "length node_lines" after loading a state file; each commit adds "length learns" *)
+ dirty : string list option}
-val empty_state = {access_G = Graph.empty, dirty = SOME []}
+val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []} (* "save" returns a state with "dirty = SOME []" unchanged *)
local
-val version = "*** MaSh version 20130819a ***"
+val version = "*** MaSh version 20130820 ***"
-exception Too_New of unit
+exception FILE_VERSION_TOO_NEW of unit
fun extract_node line =
case space_explode ":" line of
@@ -371,24 +366,27 @@
| SOME (name, parents, kind) =>
update_access_graph_node (name, kind)
#> fold (add_edge_to name) parents
- val access_G =
+ val (access_G, num_known_facts) =
case string_ord (version', version) of
EQUAL =>
- try_graph ctxt "loading state" Graph.empty (fn () =>
- fold add_node node_lines Graph.empty)
+ (try_graph ctxt "loading state" Graph.empty (fn () =>
+ fold add_node node_lines Graph.empty),
+ length node_lines)
| LESS =>
- (MaSh.unlearn ctxt; Graph.empty) (* can't parse old file *)
- | GREATER => raise Too_New ()
+ (* can't parse old file *)
+ (MaSh.unlearn ctxt; (Graph.empty, 0))
+ | GREATER => raise FILE_VERSION_TOO_NEW ()
in
trace_msg ctxt (fn () =>
"Loaded fact graph (" ^ graph_info access_G ^ ")");
- {access_G = access_G, dirty = SOME []}
+ {access_G = access_G, num_known_facts = num_known_facts,
+ dirty = SOME []}
end
| _ => empty_state)
end
fun save _ (state as {dirty = SOME [], ...}) = state
- | save ctxt {access_G, dirty} =
+ | save ctxt {access_G, num_known_facts, dirty} =
let
fun str_of_entry (name, parents, kind) =
str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^
@@ -408,7 +406,7 @@
SOME dirty =>
"; " ^ string_of_int (length dirty) ^ " dirty fact(s)"
| _ => "") ^ ")");
- {access_G = access_G, dirty = SOME []}
+ {access_G = access_G, num_known_facts = num_known_facts, dirty = SOME []}
end
val global_state =
@@ -418,7 +416,7 @@
fun map_state ctxt f =
Synchronized.change global_state (load ctxt ##> (f #> save ctxt))
- handle Too_New () => ()
+ handle FILE_VERSION_TOO_NEW () => ()
fun peek_state ctxt f =
Synchronized.change_result global_state
@@ -723,6 +721,9 @@
| NONE => false)
| is_size_def _ _ = false
+(* True for the statuses "Non_Rec_Def" and "Rec_Def". Callers use an empty
+   dependency list ("SOME []") for facts with such a status. *)
+fun no_dependencies_for_status status =
+  member (op =) [Non_Rec_Def, Rec_Def] status
+
fun trim_dependencies deps =
if length deps > max_dependencies then NONE else SOME deps
@@ -790,159 +791,9 @@
(*** High-level communication with MaSh ***)
-fun maximal_wrt_graph G keys =
- let
- val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
- fun insert_new seen name =
- not (Symtab.defined seen name) ? insert (op =) name
- fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
- fun find_maxes _ (maxs, []) = map snd maxs
- | find_maxes seen (maxs, new :: news) =
- find_maxes
- (seen |> num_keys (Graph.imm_succs G new) > 1
- ? Symtab.default (new, ()))
- (if Symtab.defined tab new then
- let
- val newp = Graph.all_preds G [new]
- fun is_ancestor x yp = member (op =) yp x
- val maxs =
- maxs |> filter (fn (_, max) => not (is_ancestor max newp))
- in
- if exists (is_ancestor new o fst) maxs then
- (maxs, news)
- else
- ((newp, new)
- :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
- news)
- end
- else
- (maxs, Graph.Keys.fold (insert_new seen)
- (Graph.imm_preds G new) news))
- in find_maxes Symtab.empty ([], Graph.maximals G) end
-
-fun maximal_wrt_access_graph access_G =
- map (nickname_of_thm o snd)
- #> maximal_wrt_graph access_G
-
-fun is_fact_in_graph access_G get_th fact =
- can (Graph.get_node access_G) (nickname_of_thm (get_th fact))
-
-(* FUDGE *)
-fun weight_of_mepo_fact rank =
- Math.pow (0.62, log2 (Real.fromInt (rank + 1)))
-
-fun weight_mepo_facts facts =
- facts ~~ map weight_of_mepo_fact (0 upto length facts - 1)
-
-val weight_raw_mash_facts = weight_mepo_facts
-val weight_mash_facts = weight_raw_mash_facts
-
-(* FUDGE *)
-fun weight_of_proximity_fact rank =
- Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
-
-fun weight_proximity_facts facts =
- facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
-
-val max_proximity_facts = 100
-
-fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown)
- | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
- let
- val raw_mash = find_suggested_facts ctxt facts suggs
- val unknown_chained =
- inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown
- val proximity =
- facts |> sort (crude_thm_ord o pairself snd o swap)
- |> take max_proximity_facts
- val mess =
- [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
- (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
- (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))]
- val unknown =
- raw_unknown
- |> fold (subtract (Thm.eq_thm_prop o pairself snd))
- [unknown_chained, proximity]
- in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
-
-fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts
- hyp_ts concl_t facts =
- let
- val thy = Proof_Context.theory_of ctxt
- val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
- val (access_G, suggs) =
- peek_state ctxt (fn {access_G, ...} =>
- if Graph.is_empty access_G then
- (access_G, [])
- else
- let
- val parents = maximal_wrt_access_graph access_G facts
- val feats =
- features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
- val hints =
- chained |> filter (is_fact_in_graph access_G snd)
- |> map (nickname_of_thm o snd)
- in
- (access_G, MaSh.query ctxt overlord learn max_facts
- ([], hints, parents, feats))
- end)
- val unknown = facts |> filter_out (is_fact_in_graph access_G snd)
- in
- find_mash_suggestions ctxt max_facts suggs facts chained unknown
- |> pairself (map fact_of_raw_fact)
- end
-
-fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
- let
- fun maybe_learn_from from (accum as (parents, graph)) =
- try_graph ctxt "updating graph" accum (fn () =>
- (from :: parents, Graph.add_edge_acyclic (from, name) graph))
- val graph = graph |> Graph.default_node (name, Isar_Proof)
- val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
- val (deps, _) = ([], graph) |> fold maybe_learn_from deps
- in ((name, parents, feats, deps) :: learns, graph) end
-
-fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
- let
- fun maybe_relearn_from from (accum as (parents, graph)) =
- try_graph ctxt "updating graph" accum (fn () =>
- (from :: parents, Graph.add_edge_acyclic (from, name) graph))
- val graph = graph |> update_access_graph_node (name, Automatic_Proof)
- val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
- in ((name, deps) :: relearns, graph) end
-
-fun flop_wrt_access_graph name =
- update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop)
-
-val learn_timeout_slack = 2.0
-
-fun launch_thread timeout task =
- let
- val hard_timeout = time_mult learn_timeout_slack timeout
- val birth_time = Time.now ()
- val death_time = Time.+ (birth_time, hard_timeout)
- val desc = ("Machine learner for Sledgehammer", "")
- in Async_Manager.thread MaShN birth_time death_time desc task end
-
-fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
- used_ths =
- launch_thread (timeout |> the_default one_day) (fn () =>
- let
- val thy = Proof_Context.theory_of ctxt
- val name = freshish_name ()
- val feats = features_of ctxt prover thy (Local, General) [t]
- in
- peek_state ctxt (fn {access_G, ...} =>
- let
- val parents = maximal_wrt_access_graph access_G facts
- val deps =
- used_ths |> filter (is_fact_in_graph access_G I)
- |> map nickname_of_thm
- in
- MaSh.learn ctxt overlord [(name, parents, feats, deps)]
- end);
- (true, "")
- end)
+(* Pairs every fact with a crude parent list: the first fact gets the given
+   "parents", each later fact gets just the nickname of the theorem of the
+   fact that precedes it in the list. *)
+fun attach_crude_parents_to_facts parents facts =
+  let
+    fun attach (fact as (_, th)) prev = ((prev, fact), [nickname_of_thm th])
+  in fst (fold_map attach facts parents) end
(* In the following functions, chunks are risers w.r.t. "thm_less_eq". *)
@@ -994,6 +845,191 @@
|> drop (length old_facts)
end
+fun maximal_wrt_graph G keys = (* the result consists of names taken from "keys" *)
+ let
+ val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys (* "keys" as a table, for membership tests *)
+ fun insert_new seen name =
+ not (Symtab.defined seen name) ? insert (op =) name
+ fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0 (* number of elements of a key set *)
+ fun find_maxes _ (maxs, []) = map snd maxs
+ | find_maxes seen (maxs, new :: news) =
+ find_maxes
+ (seen |> num_keys (Graph.imm_succs G new) > 1
+ ? Symtab.default (new, ())) (* record "new" as seen when it has more than one immediate successor *)
+ (if Symtab.defined tab new then
+ let
+ val newp = Graph.all_preds G [new]
+ fun is_ancestor x yp = member (op =) yp x
+ val maxs =
+ maxs |> filter (fn (_, max) => not (is_ancestor max newp)) (* drop collected entries that occur in "newp" *)
+ in
+ if exists (is_ancestor new o fst) maxs then
+ (maxs, news) (* "new" occurs in the "newp" of a kept entry: not added *)
+ else
+ ((newp, new)
+ :: filter_out (fn (_, max) => is_ancestor max newp) maxs, (* NOTE(review): "maxs" was already filtered by this predicate above *)
+ news)
+ end
+ else
+ (maxs, Graph.Keys.fold (insert_new seen) (* "new" is not in "keys": queue its unseen immediate predecessors *)
+ (Graph.imm_preds G new) news))
+ in find_maxes Symtab.empty ([], Graph.maximals G) end (* the walk starts from the maximal nodes of G *)
+
+(* Applies "maximal_wrt_graph" to the access graph and the nicknames of the
+   theorems of the given facts. *)
+fun maximal_wrt_access_graph access_G facts =
+  maximal_wrt_graph access_G (map (nickname_of_thm o snd) facts)
+
+(* True iff "Graph.get_node" succeeds on the nickname of the theorem, i.e. the
+   lookup in the access graph raises no exception. *)
+fun is_fact_in_graph access_G th =
+  can (Graph.get_node access_G) (nickname_of_thm th)
+
+(* FUDGE *)
+(* Weight of the fact at position "rank": 0.62 raised to "log2 (rank + 1)". *)
+fun weight_of_mepo_fact rank =
+  let val exponent = log2 (Real.fromInt (rank + 1))
+  in Math.pow (0.62, exponent) end
+
+(* Pairs each fact with "weight_of_mepo_fact" of its 0-based list position. *)
+fun weight_mepo_facts facts =
+  let val ranks = 0 upto length facts - 1
+  in facts ~~ map weight_of_mepo_fact ranks end
+
+val weight_raw_mash_facts = weight_mepo_facts (* MaSh facts are weighted by the same function as MePo facts *)
+val weight_mash_facts = weight_raw_mash_facts (* alias of the raw weighting *)
+
+(* FUDGE *)
+(* Weight of the proximity fact at position "rank":
+   1.3 raised to "15.5 - 0.2 * rank", plus 15.0. *)
+fun weight_of_proximity_fact rank =
+  let val exponent = 15.5 - 0.2 * Real.fromInt rank
+  in Math.pow (1.3, exponent) + 15.0 end
+
+(* Pairs each fact with "weight_of_proximity_fact" of its 0-based list
+   position. *)
+fun weight_proximity_facts facts =
+  let val ranks = 0 upto length facts - 1
+  in facts ~~ map weight_of_proximity_fact ranks end
+
+val max_proximity_facts = 100 (* "find_mash_suggestions" takes at most this many facts for its "proximity" list *)
+
+(* Meshes MaSh's suggestions with chained and proximity facts.
+
+   Without suggestions the result is "([], raw_unknown)". Otherwise three
+   weighted lists are passed to "mesh_facts" (the weights are FUDGE factors):
+   the members of "raw_unknown" that are also in "chained" (0.90), the result
+   of "find_suggested_facts" (0.08, accompanied by "raw_unknown"), and the
+   first "max_proximity_facts" facts in reversed "crude_thm_ord" order (0.02).
+   The second component of the result is "raw_unknown" without the chained
+   and proximity facts. Facts are compared via "Thm.eq_thm_prop" on their
+   theorems. *)
+fun find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
+  if null suggs then
+    ([], raw_unknown)
+  else
+    let
+      fun eq_fact pair = Thm.eq_thm_prop (pairself snd pair)
+      val raw_mash = find_suggested_facts ctxt facts suggs
+      val unknown_chained = inter eq_fact chained raw_unknown
+      val proximity =
+        take max_proximity_facts
+          (sort (crude_thm_ord o pairself snd o swap) facts)
+      val mess =
+        [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
+         (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
+         (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))]
+      val unknown =
+        fold (subtract eq_fact) [unknown_chained, proximity] raw_unknown
+    in (mesh_facts eq_fact max_facts mess, unknown) end
+
+val max_learn_on_query = 500 (* a query also learns the new facts only if "length facts - num_known_facts" is at most this *)
+
+(* Queries MaSh for facts relevant to the goal given by "hyp_ts" and
+   "concl_t".
+
+   If the access graph is empty, no query is issued and the suggestion list
+   is empty. Otherwise the query carries the features of the goal, the
+   nicknames of the chained facts that are in the graph as hints, and the
+   result of "maximal_wrt_access_graph" as parents. If
+   "length facts - num_known_facts" is at most "max_learn_on_query", the facts
+   that are not in the graph are sorted by "crude_thm_ord", linked through
+   crude parents, and sent along as "learns"; the query's parent is then the
+   last of them. This function does not itself extend the access graph, and
+   the "learn" option of "params" is not consulted.
+
+   The result is that of "find_mash_suggestions" (with the facts outside the
+   graph as unknown facts), mapped through "fact_of_raw_fact". *)
+fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts
+                         hyp_ts concl_t facts =
+  let
+    val thy = Proof_Context.theory_of ctxt
+    val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
+    val (access_G, suggs) =
+      peek_state ctxt (fn {access_G, num_known_facts, ...} =>
+        if Graph.is_empty access_G then
+          (access_G, [])
+        else
+          let
+            val parents = maximal_wrt_access_graph access_G facts
+            val feats =
+              features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
+            val hints =
+              chained |> filter (is_fact_in_graph access_G o snd)
+                      |> map (nickname_of_thm o snd)
+            (* "num_known_facts" serves as a cheap estimate of how many of
+               "facts" are new, before any per-fact graph lookup is done *)
+            val (learns, parents) =
+              if length facts - num_known_facts <= max_learn_on_query then
+                let
+                  val name_tabs = build_name_tables nickname_of_thm facts
+                  fun deps_of status th =
+                    if no_dependencies_for_status status then
+                      SOME []
+                    else
+                      isar_dependencies_of name_tabs th
+                      |> trim_dependencies
+                  fun learn_new_fact (parents,
+                                      ((_, stature as (_, status)), th)) =
+                    let
+                      val name = nickname_of_thm th
+                      val feats =
+                        features_of ctxt prover (theory_of_thm th) stature
+                                    [prop_of th]
+                      (* "NONE" (too many dependencies) becomes "[]" *)
+                      val deps = deps_of status th |> these
+                    in (name, parents, feats, deps) end
+                  val new_facts =
+                    facts |> filter_out (is_fact_in_graph access_G o snd)
+                          |> sort (crude_thm_ord o pairself snd)
+                          |> attach_crude_parents_to_facts parents
+                  val learns = new_facts |> map learn_new_fact
+                  val parents =
+                    if null new_facts then parents
+                    else [#1 (List.last learns)]
+                in (learns, parents) end
+              else
+                ([], parents)
+          in
+            (access_G, MaSh.query ctxt overlord max_facts
+                                  (learns, hints, parents, feats))
+          end)
+    val unknown = facts |> filter_out (is_fact_in_graph access_G o snd)
+  in
+    find_mash_suggestions ctxt max_facts suggs facts chained unknown
+    |> pairself (map fact_of_raw_fact)
+  end
+
+(* Registers a learned fact in the access graph.
+
+   The node "name" is added with "Graph.default_node" (kind "Isar_Proof").
+   For each parent an edge parent -> name is attempted with
+   "Graph.add_edge_acyclic" under "try_graph"; presumably a failed attempt
+   leaves the accumulator as it was, so that parent is dropped (confirm
+   against "try_graph"). The dependencies are filtered the same way, but the
+   graph built while testing them is discarded. The kept parents and
+   dependencies are accumulated by consing, hence in reverse order. The new
+   entry is consed onto "learns" and returned with the graph that has the
+   parent edges. *)
+fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
+  let
+    fun link_from src (acc as (srcs, G)) =
+      try_graph ctxt "updating graph" acc (fn () =>
+        (src :: srcs, Graph.add_edge_acyclic (src, name) G))
+    val G0 = Graph.default_node (name, Isar_Proof) graph
+    val (kept_parents, G1) = fold link_from parents ([], G0)
+    val (kept_deps, _) = fold link_from deps ([], G1)
+  in ((name, kept_parents, feats, kept_deps) :: learns, G1) end
+
+(* Registers a relearned fact in the access graph.
+
+   The node "name" is updated to kind "Automatic_Proof" with
+   "update_access_graph_node". Each dependency is then tested by attempting
+   an edge dependency -> name with "Graph.add_edge_acyclic" under
+   "try_graph"; the graph built during this test is discarded, only the list
+   of kept dependencies (in reverse order) is used. The new entry is consed
+   onto "relearns" and returned with the graph that has the updated node. *)
+fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
+  let
+    val G0 = update_access_graph_node (name, Automatic_Proof) graph
+    fun link_from src (acc as (srcs, G)) =
+      try_graph ctxt "updating graph" acc (fn () =>
+        (src :: srcs, Graph.add_edge_acyclic (src, name) G))
+    val (kept_deps, _) = fold link_from deps ([], G0)
+  in ((name, kept_deps) :: relearns, G0) end
+
+(* Updates the node "name" of the graph to kind
+   "Isar_Proof_wegen_Prover_Flop". *)
+fun flop_wrt_access_graph name graph =
+  update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop) graph
+
+val learn_timeout_slack = 2.0 (* "launch_thread" passes this factor and the timeout to "time_mult" to get the hard timeout *)
+
+(* Hands "task" to "Async_Manager.thread" under the tool name "MaShN". The
+   birth time is the current time; the death time is the birth time plus the
+   timeout scaled by "learn_timeout_slack" (via "time_mult"). *)
+fun launch_thread timeout task =
+  let
+    val birth_time = Time.now ()
+    val death_time =
+      Time.+ (birth_time, time_mult learn_timeout_slack timeout)
+    val desc = ("Machine learner for Sledgehammer", "")
+  in Async_Manager.thread MaShN birth_time death_time desc task end
+
+(* Learns a proof in a separate thread (see "launch_thread"), using the
+   "timeout" from "params" or "one_day" if there is none.
+
+   The task learns a single entry under a name from "freshish_name": its
+   features are those of "t", its parents the result of
+   "maximal_wrt_access_graph" on "facts", and its dependencies the nicknames
+   of the theorems in "used_ths" that are in the access graph. The task
+   returns "(true, \"\")". *)
+fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
+                     used_ths =
+  let
+    fun learn_task () =
+      let
+        val thy = Proof_Context.theory_of ctxt
+        val name = freshish_name ()
+        val feats = features_of ctxt prover thy (Local, General) [t]
+        val _ =
+          peek_state ctxt (fn {access_G, ...} =>
+            let
+              val parents = maximal_wrt_access_graph access_G facts
+              val deps =
+                map nickname_of_thm
+                  (filter (is_fact_in_graph access_G) used_ths)
+            in
+              MaSh.learn ctxt overlord [(name, parents, feats, deps)]
+            end)
+      in (true, "") end
+  in launch_thread (the_default one_day timeout) learn_task end
+
fun sendback sub =
Active.sendback_markup [Markup.padding_command] (sledgehammerN ^ " " ^ sub)
@@ -1007,7 +1043,7 @@
fun next_commit_time () =
Time.+ (Timer.checkRealTimer timer, commit_timeout)
val {access_G, ...} = peek_state ctxt I
- val is_in_access_G = is_fact_in_graph access_G snd
+ val is_in_access_G = is_fact_in_graph access_G o snd
val no_new_facts = forall is_in_access_G facts
in
if no_new_facts andalso not run_prover then
@@ -1025,7 +1061,7 @@
let
val name_tabs = build_name_tables nickname_of_thm facts
fun deps_of status th =
- if status = Non_Rec_Def orelse status = Rec_Def then
+ if no_dependencies_for_status status then
SOME []
else if run_prover then
prover_dependencies_of ctxt params prover auto_level facts name_tabs
@@ -1036,7 +1072,7 @@
isar_dependencies_of name_tabs th
|> trim_dependencies
fun do_commit [] [] [] state = state
- | do_commit learns relearns flops {access_G, dirty} =
+ | do_commit learns relearns flops {access_G, num_known_facts, dirty} =
let
val was_empty = Graph.is_empty access_G
val (learns, access_G) =
@@ -1044,6 +1080,7 @@
val (relearns, access_G) =
([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
val access_G = access_G |> fold flop_wrt_access_graph flops
+ val num_known_facts = num_known_facts + length learns
val dirty =
case (was_empty, dirty, relearns) of
(false, SOME names, []) => SOME (map #1 learns @ names)
@@ -1051,7 +1088,8 @@
in
MaSh.learn ctxt overlord (rev learns);
MaSh.relearn ctxt overlord relearns;
- {access_G = access_G, dirty = dirty}
+ {access_G = access_G, num_known_facts = num_known_facts,
+ dirty = dirty}
end
fun commit last learns relearns flops =
(if debug andalso auto_level = 0 then