105 val unlearnN = "unlearn" |
105 val unlearnN = "unlearn" |
106 val learn_isarN = "learn_isar" |
106 val learn_isarN = "learn_isar" |
107 val learn_proverN = "learn_prover" |
107 val learn_proverN = "learn_prover" |
108 val relearn_isarN = "relearn_isar" |
108 val relearn_isarN = "relearn_isar" |
109 val relearn_proverN = "relearn_prover" |
109 val relearn_proverN = "relearn_prover" |
110 |
|
111 val learned_proof_prefix = ".." |
|
112 |
|
113 fun learned_proof_name () = |
|
114 learned_proof_prefix ^ Date.fmt "%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^ |
|
115 serial_string () |
|
116 |
110 |
117 fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir |
111 fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir |
118 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state") |
112 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state") |
119 |
113 |
120 fun wipe_out_mash_state_dir () = |
114 fun wipe_out_mash_state_dir () = |
380 end |
374 end |
381 else () |
375 else () |
382 end |
376 end |
383 |
377 |
384 (* |
378 (* |
385 avail_no = maximum number of theorems to check dependencies and symbols |
379 avail_num = maximum number of theorems to check dependencies and symbols |
|
380 adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num |
386 get_deps = returns dependencies of a theorem |
381 get_deps = returns dependencies of a theorem |
387 get_sym_ths = get theorems that have this feature |
382 get_sym_ths = get theorems that have this feature |
388 knns = number of nearest neighbours |
383 knns = number of nearest neighbours |
389 advno = number of predictions to return |
384 advno = number of predictions to return |
390 syms = symbols of the conjecture |
385 syms = symbols of the conjecture |
391 *) |
386 *) |
392 fun knn avail_no get_deps get_sym_ths knns advno syms = |
387 fun knn avail_num adv_max get_deps get_sym_ths knns advno syms = |
393 let |
388 let |
394 (* Can be later used for TFIDF *) |
389 (* Can be later used for TFIDF *) |
395 fun sym_wght _ = 1.0 |
390 fun sym_wght _ = 1.0; |
396 |
391 val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0))); |
397 val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0))) |
|
398 |
|
399 fun inc_overlap j v = |
392 fun inc_overlap j v = |
400 let |
393 let |
401 val ov = snd (Array.sub (overlaps_sqr,j)) |
394 val ov = snd (Array.sub (overlaps_sqr,j)) |
402 in |
395 in |
403 Array.update (overlaps_sqr, j, (j, v + ov)) |
396 Array.update (overlaps_sqr, j, (j, v + ov)) |
404 end |
397 end; |
405 |
|
406 fun do_sym (s, con_wght) = |
398 fun do_sym (s, con_wght) = |
407 let |
399 let |
408 val sw = sym_wght s |
400 val sw = sym_wght s; |
409 val w2 = sw * sw * con_wght |
401 val w2 = sw * sw * con_wght; |
410 fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else () |
402 fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else () |
411 in |
403 in |
412 ignore (map do_th (get_sym_ths s)) |
404 ignore (map do_th (get_sym_ths s)) |
413 end |
405 end; |
414 |
406 val () = ignore (map do_sym syms); |
415 val _ = ignore (map do_sym syms) |
407 val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr; |
416 val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr |
408 val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0))); |
417 val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0))) |
|
418 |
|
419 fun inc_recommend j v = |
409 fun inc_recommend j v = |
|
410 if j >= adv_max then () else |
420 let |
411 let |
421 val ov = snd (Array.sub (recommends,j)) |
412 val ov = snd (Array.sub (recommends,j)) |
422 in |
413 in |
423 Array.update (recommends, j, (j, v + ov)) |
414 Array.update (recommends, j, (j, v + ov)) |
424 end |
415 end; |
425 |
|
426 fun for k = |
416 fun for k = |
427 if k = knns then () else |
417 if k = knns then () else |
428 if k >= avail_no then () else |
418 if k >= adv_max then () else |
429 let |
419 let |
430 val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1) |
420 val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1); |
431 val o1 = Math.sqrt o2 |
421 val o1 = Math.sqrt o2; |
432 val _ = inc_recommend j o1 |
422 val () = inc_recommend j o1; |
433 val ds = get_deps j |
423 val ds = get_deps j; |
434 val l = Real.fromInt (length ds) |
424 val l = Real.fromInt (length ds); |
435 val _ = map (fn d => inc_recommend d (o1 / l)) ds |
425 val _ = map (fn d => inc_recommend d (o1 / l)) ds |
436 in |
426 in |
437 for (k + 1) |
427 for (k + 1) |
438 end |
428 end; |
439 |
429 val () = for 0; |
440 val _ = for 0 |
430 val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends; |
441 val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends |
|
442 |
|
443 fun ret acc at = |
431 fun ret acc at = |
444 if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1) |
432 if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1) |
445 in |
433 in |
446 ret [] (max 0 (avail_no - advno)) |
434 ret [] (max 0 (adv_max - advno)) |
447 end |
435 end |
448 |
436 |
449 val knns = 40 (* FUDGE *) |
437 val knns = 40 (* FUDGE *) |
450 |
438 |
451 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) |
439 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) |
454 |
442 |
455 fun learn_and_query ctxt parents access_G max_suggs hints feats = |
443 fun learn_and_query ctxt parents access_G max_suggs hints feats = |
456 let |
444 let |
457 val str_of_feat = space_implode "|" |
445 val str_of_feat = space_implode "|" |
458 |
446 |
459 val (depss0, featss, (_, _, facts0), (num_feats, feat_tab, _)) = |
447 val visible_facts = Graph.all_preds access_G parents |
460 fold_rev (fn fact => fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) => |
448 val visible_fact_set = Symtab.make_set visible_facts |
|
449 |
|
450 val all_nodes = |
|
451 Graph.schedule (K I) access_G |
|
452 |> List.partition (Symtab.defined visible_fact_set o fst) |
|
453 |> op @ |
|
454 |
|
455 val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) = |
|
456 fold (fn (fact, (_, feats, deps)) => |
|
457 fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) => |
461 let |
458 let |
462 val (_, feats, deps) = Graph.get_node access_G fact |
|
463 |
|
464 fun add_feat (feat, weight) (xtab as (n, tab, _)) = |
459 fun add_feat (feat, weight) (xtab as (n, tab, _)) = |
465 (case Symtab.lookup tab feat of |
460 (case Symtab.lookup tab feat of |
466 SOME i => ((i, weight), xtab) |
461 SOME i => ((i, weight), xtab) |
467 | NONE => ((n, weight), add_to_xtab feat xtab)) |
462 | NONE => ((n, weight), add_to_xtab feat xtab)) |
468 |
463 |
469 val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab |
464 val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab |
470 in |
465 in |
471 (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss, |
466 (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss, |
472 add_to_xtab fact fact_xtab, feat_xtab') |
467 add_to_xtab fact fact_xtab, feat_xtab') |
473 end) |
468 end) |
474 (Graph.all_preds access_G parents) ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) |
469 all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) |
475 |
470 |
476 val facts = rev facts0 |
471 val facts = rev rev_facts |
477 val fact_ary = Array.fromList facts |
472 val fact_ary = Array.fromList facts |
478 |
473 |
479 val deps_ary = Array.fromList (rev depss0) |
474 val deps_ary = Array.fromList (rev rev_depss) |
480 val facts_ary = Array.array (num_feats, []) |
475 val facts_ary = Array.array (num_feats, []) |
481 val _ = |
476 val _ = |
482 fold (fn feats => fn fact => |
477 fold (fn feats => fn fact => |
483 let val fact' = fact - 1 in |
478 let val fact' = fact - 1 in |
484 List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat) |
479 List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat) |
485 feats; |
480 feats; |
486 fact' |
481 fact' |
487 end) |
482 end) |
488 featss (length featss) |
483 featss (length featss) |
489 in |
484 in |
490 trace_msg ctxt (fn () => |
485 trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^ |
491 "MaSh_SML query " ^ encode_features feats ^ " from {" ^ |
486 elide_string 1000 (space_implode " " facts) ^ "}"); |
492 elide_string 1000 (space_implode " " facts) ^ "}"); |
487 knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary) |
493 knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns |
488 (curry Array.sub facts_ary) knns max_suggs |
494 max_suggs |
|
495 (map_filter (fn (feat, weight) => |
489 (map_filter (fn (feat, weight) => |
496 Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats) |
490 Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats) |
497 |> map ((fn i => Array.sub (fact_ary, i)) o fst) |
491 |> map ((fn i => Array.sub (fact_ary, i)) o fst) |
498 |> filter_out (String.isPrefix learned_proof_prefix) |
|
499 end |
492 end |
500 |
493 |
501 end; |
494 end; |
502 |
495 |
503 |
496 |
515 |
508 |
516 fun add_edge_to name parent = |
509 fun add_edge_to name parent = |
517 Graph.default_node (parent, (Isar_Proof, [], [])) |
510 Graph.default_node (parent, (Isar_Proof, [], [])) |
518 #> Graph.add_edge (parent, name) |
511 #> Graph.add_edge (parent, name) |
519 |
512 |
520 fun add_node kind name feats deps = |
513 fun add_node kind name parents feats deps = |
521 Graph.default_node (name, (kind, feats, deps)) |
514 Graph.default_node (name, (kind, feats, deps)) |
522 #> Graph.map_node name (K (kind, feats, deps)) |
515 #> Graph.map_node name (K (kind, feats, deps)) |
|
516 #> fold (add_edge_to name) parents |
523 |
517 |
524 fun try_graph ctxt when def f = |
518 fun try_graph ctxt when def f = |
525 f () |
519 f () |
526 handle |
520 handle |
527 Graph.CYCLES (cycle :: _) => |
521 Graph.CYCLES (cycle :: _) => |
573 SOME (version' :: node_lines) => |
567 SOME (version' :: node_lines) => |
574 let |
568 let |
575 fun extract_line_and_add_node line = |
569 fun extract_line_and_add_node line = |
576 (case extract_node line of |
570 (case extract_node line of |
577 NONE => I (* should not happen *) |
571 NONE => I (* should not happen *) |
578 | SOME (kind, name, parents, feats, deps) => |
572 | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps) |
579 add_node kind name feats deps |
|
580 #> fold (add_edge_to name) parents) |
|
581 |
573 |
582 val (access_G, num_known_facts) = |
574 val (access_G, num_known_facts) = |
583 (case string_ord (version', version) of |
575 (case string_ord (version', version) of |
584 EQUAL => |
576 EQUAL => |
585 (try_graph ctxt "loading state" Graph.empty (fn () => |
577 (try_graph ctxt "loading state" Graph.empty (fn () => |
1130 (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news)) |
1122 (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news)) |
1131 in |
1123 in |
1132 find_maxes Symtab.empty ([], Graph.maximals G) |
1124 find_maxes Symtab.empty ([], Graph.maximals G) |
1133 end |
1125 end |
1134 |
1126 |
1135 fun graph_islands G = |
|
1136 Graph.fold (fn (m, (_, (preds, succs))) => |
|
1137 (Graph.Keys.is_empty preds andalso Graph.Keys.is_empty succs) ? cons m) G []; |
|
1138 |
|
1139 (* islands represent learned proofs associated with no facts *) |
|
1140 fun maximal_wrt_access_graph access_G facts = |
1127 fun maximal_wrt_access_graph access_G facts = |
1141 map (nickname_of_thm o snd) facts @ graph_islands access_G |
1128 map (nickname_of_thm o snd) facts |
1142 |> maximal_wrt_graph access_G |
1129 |> maximal_wrt_graph access_G |
1143 |
1130 |
1144 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm |
1131 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm |
1145 |
1132 |
1146 val chained_feature_factor = 0.5 (* FUDGE *) |
1133 val chained_feature_factor = 0.5 (* FUDGE *) |
1238 end |
1225 end |
1239 |
1226 |
1240 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) = |
1227 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) = |
1241 let |
1228 let |
1242 fun maybe_learn_from from (accum as (parents, G)) = |
1229 fun maybe_learn_from from (accum as (parents, G)) = |
1243 try_graph ctxt "updating G" accum (fn () => |
1230 try_graph ctxt "updating graph" accum (fn () => |
1244 (from :: parents, Graph.add_edge_acyclic (from, name) G)) |
1231 (from :: parents, Graph.add_edge_acyclic (from, name) G)) |
1245 val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps)) |
1232 val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps)) |
1246 val (parents, G) = ([], G) |> fold maybe_learn_from parents |
1233 val (parents, G) = ([], G) |> fold maybe_learn_from parents |
1247 val (deps, _) = ([], G) |> fold maybe_learn_from deps |
1234 val (deps, _) = ([], G) |> fold maybe_learn_from deps |
1248 in |
1235 in |
1283 val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |
1273 val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |
1284 in |
1274 in |
1285 map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} => |
1275 map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} => |
1286 let |
1276 let |
1287 val name = learned_proof_name () |
1277 val name = learned_proof_name () |
|
1278 val parents = maximal_wrt_access_graph access_G facts |
1288 val deps = used_ths |
1279 val deps = used_ths |
1289 |> filter (is_fact_in_graph access_G) |
1280 |> filter (is_fact_in_graph access_G) |
1290 |> map nickname_of_thm |
1281 |> map nickname_of_thm |
1291 in |
1282 in |
1292 if Config.get ctxt sml then |
1283 if Config.get ctxt sml then |
1293 let val access_G = access_G |> add_node Automatic_Proof name feats deps in |
1284 let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in |
1294 {access_G = access_G, num_known_facts = num_known_facts + 1, |
1285 {access_G = access_G, num_known_facts = num_known_facts + 1, |
1295 dirty = Option.map (cons name) dirty} |
1286 dirty = Option.map (cons name) dirty} |
1296 end |
1287 end |
1297 else |
1288 else |
1298 let val parents = maximal_wrt_access_graph access_G facts in |
1289 (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state) |
1299 (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state) |
|
1300 end |
|
1301 end); |
1290 end); |
1302 (true, "") |
1291 (true, "") |
1303 end) |
1292 end) |
1304 else |
1293 else |
1305 () |
1294 () |