34 val encode_features : (string * real) list -> string |
34 val encode_features : (string * real) list -> string |
35 val extract_suggestions : string -> string * string list |
35 val extract_suggestions : string -> string * string list |
36 |
36 |
37 structure MaSh: |
37 structure MaSh: |
38 sig |
38 sig |
39 val unlearn : Proof.context -> unit |
39 val unlearn : Proof.context -> bool -> unit |
40 val learn : |
40 val learn : |
41 Proof.context -> bool |
41 Proof.context -> bool |
42 -> (string * string list * string list * string list) list -> unit |
42 -> (string * string list * string list * string list) list -> unit |
43 val relearn : |
43 val relearn : |
44 Proof.context -> bool -> (string * string list) list -> unit |
44 Proof.context -> bool -> (string * string list) list -> unit |
47 -> (string * string list * string list * string list) list |
47 -> (string * string list * string list * string list) list |
48 * string list * string list * (string * real) list |
48 * string list * string list * (string * real) list |
49 -> string list |
49 -> string list |
50 end |
50 end |
51 |
51 |
52 val mash_unlearn : Proof.context -> unit |
52 val mash_unlearn : Proof.context -> params -> unit |
53 val nickname_of_thm : thm -> string |
53 val nickname_of_thm : thm -> string |
54 val find_suggested_facts : |
54 val find_suggested_facts : |
55 Proof.context -> ('b * thm) list -> string list -> ('b * thm) list |
55 Proof.context -> ('b * thm) list -> string list -> ('b * thm) list |
56 val mesh_facts : |
56 val mesh_facts : |
57 ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list |
57 ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list |
88 Proof.context -> params -> string -> term -> ('a * thm) list -> thm list |
88 Proof.context -> params -> string -> term -> ('a * thm) list -> thm list |
89 -> unit |
89 -> unit |
90 val mash_learn : |
90 val mash_learn : |
91 Proof.context -> params -> fact_override -> thm list -> bool -> unit |
91 Proof.context -> params -> fact_override -> thm list -> bool -> unit |
92 val is_mash_enabled : unit -> bool |
92 val is_mash_enabled : unit -> bool |
93 val mash_can_suggest_facts : Proof.context -> bool |
93 val mash_can_suggest_facts : Proof.context -> bool -> bool |
94 val generous_max_facts : int -> int |
94 val generous_max_facts : int -> int |
95 val mepo_weight : real |
95 val mepo_weight : real |
96 val mash_weight : real |
96 val mash_weight : real |
97 val relevant_facts : |
97 val relevant_facts : |
98 Proof.context -> params -> string -> int -> fact_override -> term list |
98 Proof.context -> params -> string -> int -> fact_override -> term list |
99 -> term -> raw_fact list -> (string * fact list) list |
99 -> term -> raw_fact list -> (string * fact list) list |
100 val kill_learners : Proof.context -> unit |
100 val kill_learners : Proof.context -> params -> unit |
101 val running_learners : unit -> unit |
101 val running_learners : unit -> unit |
102 end; |
102 end; |
103 |
103 |
104 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH = |
104 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH = |
105 struct |
105 struct |
251 | _ => ("", []) |
251 | _ => ("", []) |
252 |
252 |
253 structure MaSh = |
253 structure MaSh = |
254 struct |
254 struct |
255 |
255 |
256 fun shutdown ctxt = |
256 fun shutdown ctxt overlord = |
257 run_mash_tool ctxt false [shutdown_server_arg] ([], K "") (K ()) |
257 run_mash_tool ctxt overlord [shutdown_server_arg] ([], K "") (K ()) |
258 |
258 |
259 fun unlearn ctxt = |
259 fun unlearn ctxt overlord = |
260 let val path = mash_model_dir () in |
260 let val path = mash_model_dir () in |
261 trace_msg ctxt (K "MaSh unlearn"); |
261 trace_msg ctxt (K "MaSh unlearn"); |
262 shutdown ctxt; |
262 shutdown ctxt overlord; |
263 try (File.fold_dir (fn file => fn _ => |
263 try (File.fold_dir (fn file => fn _ => |
264 try File.rm (Path.append path (Path.basic file))) |
264 try File.rm (Path.append path (Path.basic file))) |
265 path) NONE; |
265 path) NONE; |
266 () |
266 () |
267 end |
267 end |
357 SOME (unencode_str name, unencode_strs parents, |
357 SOME (unencode_str name, unencode_strs parents, |
358 try proof_kind_of_str kind |> the_default Isar_Proof) |
358 try proof_kind_of_str kind |> the_default Isar_Proof) |
359 | _ => NONE) |
359 | _ => NONE) |
360 | _ => NONE |
360 | _ => NONE |
361 |
361 |
362 fun load _ (state as (true, _)) = state |
362 fun load _ _ (state as (true, _)) = state |
363 | load ctxt _ = |
363 | load ctxt overlord _ = |
364 let val path = mash_state_file () in |
364 let val path = mash_state_file () in |
365 (true, |
365 (true, |
366 case try File.read_lines path of |
366 case try File.read_lines path of |
367 SOME (version' :: node_lines) => |
367 SOME (version' :: node_lines) => |
368 let |
368 let |
381 (try_graph ctxt "loading state" Graph.empty (fn () => |
381 (try_graph ctxt "loading state" Graph.empty (fn () => |
382 fold add_node node_lines Graph.empty), |
382 fold add_node node_lines Graph.empty), |
383 length node_lines) |
383 length node_lines) |
384 | LESS => |
384 | LESS => |
385 (* can't parse old file *) |
385 (* can't parse old file *) |
386 (MaSh.unlearn ctxt; (Graph.empty, 0)) |
386 (MaSh.unlearn ctxt overlord; (Graph.empty, 0)) |
387 | GREATER => raise FILE_VERSION_TOO_NEW () |
387 | GREATER => raise FILE_VERSION_TOO_NEW () |
388 in |
388 in |
389 trace_msg ctxt (fn () => |
389 trace_msg ctxt (fn () => |
390 "Loaded fact graph (" ^ graph_info access_G ^ ")"); |
390 "Loaded fact graph (" ^ graph_info access_G ^ ")"); |
391 {access_G = access_G, num_known_facts = num_known_facts, |
391 {access_G = access_G, num_known_facts = num_known_facts, |
421 val global_state = |
421 val global_state = |
422 Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state) |
422 Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state) |
423 |
423 |
424 in |
424 in |
425 |
425 |
426 fun map_state ctxt f = |
426 fun map_state ctxt overlord f = |
427 Synchronized.change global_state (load ctxt ##> (f #> save ctxt)) |
427 Synchronized.change global_state (load ctxt overlord ##> (f #> save ctxt)) |
428 handle FILE_VERSION_TOO_NEW () => () |
428 handle FILE_VERSION_TOO_NEW () => () |
429 |
429 |
430 fun peek_state ctxt f = |
430 fun peek_state ctxt overlord f = |
431 Synchronized.change_result global_state |
431 Synchronized.change_result global_state |
432 (perhaps (try (load ctxt)) #> `snd #>> f) |
432 (perhaps (try (load ctxt overlord)) #> `snd #>> f) |
433 |
433 |
434 fun clear_state ctxt = |
434 fun clear_state ctxt overlord = |
435 Synchronized.change global_state (fn _ => |
435 Synchronized.change global_state (fn _ => |
436 (MaSh.unlearn ctxt; (* also removes the state file *) |
436 (MaSh.unlearn ctxt overlord; (* also removes the state file *) |
437 (false, empty_state))) |
437 (false, empty_state))) |
438 |
438 |
439 end |
439 end |
440 |
440 |
441 val mash_unlearn = clear_state |
441 fun mash_unlearn ctxt ({overlord, ...} : params) = clear_state ctxt overlord |
442 |
442 |
443 |
443 |
444 (*** Isabelle helpers ***) |
444 (*** Isabelle helpers ***) |
445 |
445 |
446 val local_prefix = "local" ^ Long_Name.separator |
446 val local_prefix = "local" ^ Long_Name.separator |
587 | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S |
587 | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S |
588 |
588 |
589 fun maybe_singleton_str _ "" = [] |
589 fun maybe_singleton_str _ "" = [] |
590 | maybe_singleton_str pref s = [pref ^ s] |
590 | maybe_singleton_str pref s = [pref ^ s] |
591 |
591 |
592 val max_pat_breadth = 10 |
592 val max_pat_breadth = 10 (* FUDGE *) |
593 |
593 |
594 fun term_features_of ctxt prover thy_name num_facts const_tab term_max_depth |
594 fun term_features_of ctxt prover thy_name num_facts const_tab term_max_depth |
595 type_max_depth ts = |
595 type_max_depth ts = |
596 let |
596 let |
597 val thy = Proof_Context.theory_of ctxt |
597 val thy = Proof_Context.theory_of ctxt |
639 16.0 + |
639 16.0 + |
640 (if num_facts = 0 then |
640 (if num_facts = 0 then |
641 0.0 |
641 0.0 |
642 else |
642 else |
643 let val count = Symtab.lookup const_tab s |> the_default 1 in |
643 let val count = Symtab.lookup const_tab s |> the_default 1 in |
644 (Real.fromInt num_facts / Real.fromInt count) (* FUDGE *) |
644 Real.fromInt num_facts / Real.fromInt count (* FUDGE *) |
645 end) |
645 end) |
646 fun pattify_term _ _ 0 _ = [] |
646 fun pattify_term _ _ 0 _ = [] |
647 | pattify_term _ args _ (Const (x as (s, _))) = |
647 | pattify_term _ args _ (Const (x as (s, _))) = |
648 if fst (is_built_in x args) then [] |
648 if fst (is_built_in x args) then [] |
649 else [(massage_long_name s, weight_of_const s)] |
649 else [(massage_long_name s, weight_of_const s)] |
904 |
904 |
905 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm |
905 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm |
906 |
906 |
907 val chained_feature_factor = 0.5 |
907 val chained_feature_factor = 0.5 |
908 val extra_feature_factor = 0.1 |
908 val extra_feature_factor = 0.1 |
909 val num_extra_feature_facts = 10 (* FUDGE *) |
909 val num_extra_feature_facts = 0 (* FUDGE *) |
910 |
910 |
911 (* FUDGE *) |
911 (* FUDGE *) |
912 fun weight_of_proximity_fact rank = |
912 fun weight_of_proximity_fact rank = |
913 Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0 |
913 Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0 |
914 |
914 |
956 fun chained_or_extra_features_of factor (((_, stature), th), weight) = |
956 fun chained_or_extra_features_of factor (((_, stature), th), weight) = |
957 [prop_of th] |
957 [prop_of th] |
958 |> features_of ctxt prover (theory_of_thm th) num_facts const_tab stature |
958 |> features_of ctxt prover (theory_of_thm th) num_facts const_tab stature |
959 |> map (apsnd (fn r => weight * factor * r)) |
959 |> map (apsnd (fn r => weight * factor * r)) |
960 val (access_G, suggs) = |
960 val (access_G, suggs) = |
961 peek_state ctxt (fn {access_G, ...} => |
961 peek_state ctxt overlord (fn {access_G, ...} => |
962 if Graph.is_empty access_G then |
962 if Graph.is_empty access_G then |
963 (access_G, []) |
963 (access_G, []) |
964 else |
964 else |
965 let |
965 let |
966 val parents = maximal_wrt_access_graph access_G facts |
966 val parents = maximal_wrt_access_graph access_G facts |
972 |> map (rpair 1.0) |
972 |> map (rpair 1.0) |
973 |> map (chained_or_extra_features_of chained_feature_factor) |
973 |> map (chained_or_extra_features_of chained_feature_factor) |
974 |> rpair [] |-> fold (union (op = o pairself fst)) |
974 |> rpair [] |-> fold (union (op = o pairself fst)) |
975 val extra_feats = |
975 val extra_feats = |
976 facts |
976 facts |
977 |> take (num_extra_feature_facts - length chained) |
977 |> take (Int.max (0, num_extra_feature_facts - length chained)) |
978 |> weight_facts_steeply |
978 |> weight_facts_steeply |
979 |> map (chained_or_extra_features_of extra_feature_factor) |
979 |> map (chained_or_extra_features_of extra_feature_factor) |
980 |> rpair [] |-> fold (union (op = o pairself fst)) |
980 |> rpair [] |-> fold (union (op = o pairself fst)) |
981 val feats = |
981 val feats = |
982 fold (union (op = o pairself fst)) [chained_feats, extra_feats] |
982 fold (union (op = o pairself fst)) [chained_feats, extra_feats] |
1034 val name = freshish_name () |
1034 val name = freshish_name () |
1035 val feats = |
1035 val feats = |
1036 features_of ctxt prover thy 0 Symtab.empty (Local, General) [t] |
1036 features_of ctxt prover thy 0 Symtab.empty (Local, General) [t] |
1037 |> map fst |
1037 |> map fst |
1038 in |
1038 in |
1039 peek_state ctxt (fn {access_G, ...} => |
1039 peek_state ctxt overlord (fn {access_G, ...} => |
1040 let |
1040 let |
1041 val parents = maximal_wrt_access_graph access_G facts |
1041 val parents = maximal_wrt_access_graph access_G facts |
1042 val deps = |
1042 val deps = |
1043 used_ths |> filter (is_fact_in_graph access_G) |
1043 used_ths |> filter (is_fact_in_graph access_G) |
1044 |> map nickname_of_thm |
1044 |> map nickname_of_thm |
1058 auto_level run_prover learn_timeout facts = |
1058 auto_level run_prover learn_timeout facts = |
1059 let |
1059 let |
1060 val timer = Timer.startRealTimer () |
1060 val timer = Timer.startRealTimer () |
1061 fun next_commit_time () = |
1061 fun next_commit_time () = |
1062 Time.+ (Timer.checkRealTimer timer, commit_timeout) |
1062 Time.+ (Timer.checkRealTimer timer, commit_timeout) |
1063 val {access_G, ...} = peek_state ctxt I |
1063 val {access_G, ...} = peek_state ctxt overlord I |
1064 val is_in_access_G = is_fact_in_graph access_G o snd |
1064 val is_in_access_G = is_fact_in_graph access_G o snd |
1065 val no_new_facts = forall is_in_access_G facts |
1065 val no_new_facts = forall is_in_access_G facts |
1066 in |
1066 in |
1067 if no_new_facts andalso not run_prover then |
1067 if no_new_facts andalso not run_prover then |
1068 if auto_level < 2 then |
1068 if auto_level < 2 then |
1112 fun commit last learns relearns flops = |
1112 fun commit last learns relearns flops = |
1113 (if debug andalso auto_level = 0 then |
1113 (if debug andalso auto_level = 0 then |
1114 Output.urgent_message "Committing..." |
1114 Output.urgent_message "Committing..." |
1115 else |
1115 else |
1116 (); |
1116 (); |
1117 map_state ctxt (do_commit (rev learns) relearns flops); |
1117 map_state ctxt overlord (do_commit (rev learns) relearns flops); |
1118 if not last andalso auto_level = 0 then |
1118 if not last andalso auto_level = 0 then |
1119 let val num_proofs = length learns + length relearns in |
1119 let val num_proofs = length learns + length relearns in |
1120 "Learned " ^ string_of_int num_proofs ^ " " ^ |
1120 "Learned " ^ string_of_int num_proofs ^ " " ^ |
1121 (if run_prover then "automatic" else "Isar") ^ " proof" ^ |
1121 (if run_prover then "automatic" else "Isar") ^ " proof" ^ |
1122 plural_s num_proofs ^ " in the last " ^ |
1122 plural_s num_proofs ^ " in the last " ^ |
1248 |> Output.urgent_message; |
1248 |> Output.urgent_message; |
1249 learn 0 false) |
1249 learn 0 false) |
1250 end |
1250 end |
1251 |
1251 |
1252 fun is_mash_enabled () = (getenv "MASH" = "yes") |
1252 fun is_mash_enabled () = (getenv "MASH" = "yes") |
1253 fun mash_can_suggest_facts ctxt = |
1253 fun mash_can_suggest_facts ctxt overlord = |
1254 not (Graph.is_empty (#access_G (peek_state ctxt I))) |
1254 not (Graph.is_empty (#access_G (peek_state ctxt overlord I))) |
1255 |
1255 |
1256 (* Generate more suggestions than requested, because some might be thrown out |
1256 (* Generate more suggestions than requested, because some might be thrown out |
1257 later for various reasons. *) |
1257 later for various reasons. *) |
1258 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts) |
1258 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts) |
1259 |
1259 |
1262 |
1262 |
1263 (* The threshold should be large enough so that MaSh doesn't kick in for Auto |
1263 (* The threshold should be large enough so that MaSh doesn't kick in for Auto |
1264 Sledgehammer and Try. *) |
1264 Sledgehammer and Try. *) |
1265 val min_secs_for_learning = 15 |
1265 val min_secs_for_learning = 15 |
1266 |
1266 |
1267 fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover |
1267 fun relevant_facts ctxt (params as {overlord, learn, fact_filter, timeout, ...}) |
1268 max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts = |
1268 prover max_facts ({add, only, ...} : fact_override) hyp_ts concl_t |
|
1269 facts = |
1269 if not (subset (op =) (the_list fact_filter, fact_filters)) then |
1270 if not (subset (op =) (the_list fact_filter, fact_filters)) then |
1270 error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".") |
1271 error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".") |
1271 else if only then |
1272 else if only then |
1272 let val facts = facts |> map fact_of_raw_fact in |
1273 let val facts = facts |> map fact_of_raw_fact in |
1273 [("", facts)] |
1274 [("", facts)] |
1293 case fact_filter of |
1294 case fact_filter of |
1294 SOME ff => (() |> ff <> mepoN ? maybe_learn; ff) |
1295 SOME ff => (() |> ff <> mepoN ? maybe_learn; ff) |
1295 | NONE => |
1296 | NONE => |
1296 if is_mash_enabled () then |
1297 if is_mash_enabled () then |
1297 (maybe_learn (); |
1298 (maybe_learn (); |
1298 if mash_can_suggest_facts ctxt then meshN else mepoN) |
1299 if mash_can_suggest_facts ctxt overlord then meshN else mepoN) |
1299 else |
1300 else |
1300 mepoN |
1301 mepoN |
1301 val add_ths = Attrib.eval_thms ctxt add |
1302 val add_ths = Attrib.eval_thms ctxt add |
1302 fun in_add (_, th) = member Thm.eq_thm_prop add_ths th |
1303 fun in_add (_, th) = member Thm.eq_thm_prop add_ths th |
1303 fun add_and_take accepts = |
1304 fun add_and_take accepts = |
1333 [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take), |
1334 [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take), |
1334 (mashN, mash |> map fst |> add_and_take)] |
1335 (mashN, mash |> map fst |> add_and_take)] |
1335 | _ => [("", mesh)] |
1336 | _ => [("", mesh)] |
1336 end |
1337 end |
1337 |
1338 |
1338 fun kill_learners ctxt = |
1339 fun kill_learners ctxt ({overlord, ...} : params) = |
1339 (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt) |
1340 (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord) |
1340 fun running_learners () = Async_Manager.running_threads MaShN "learner" |
1341 fun running_learners () = Async_Manager.running_threads MaShN "learner" |
1341 |
1342 |
1342 end; |
1343 end; |