42 val mash_unlearn : Proof.context -> unit |
42 val mash_unlearn : Proof.context -> unit |
43 val nickname_of : thm -> string |
43 val nickname_of : thm -> string |
44 val suggested_facts : |
44 val suggested_facts : |
45 (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list |
45 (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list |
46 val mesh_facts : |
46 val mesh_facts : |
47 int -> ((('a * thm) * real) list * ('a * thm) list) list -> ('a * thm) list |
47 int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list |
|
48 -> ('a * thm) list |
48 val theory_ord : theory * theory -> order |
49 val theory_ord : theory * theory -> order |
49 val thm_ord : thm * thm -> order |
50 val thm_ord : thm * thm -> order |
50 val goal_of_thm : theory -> thm -> thm |
51 val goal_of_thm : theory -> thm -> thm |
51 val run_prover_for_mash : |
52 val run_prover_for_mash : |
52 Proof.context -> params -> string -> fact list -> thm -> prover_result |
53 Proof.context -> params -> string -> fact list -> thm -> prover_result |
57 val atp_dependencies_of : |
58 val atp_dependencies_of : |
58 Proof.context -> params -> string -> int -> fact list -> unit Symtab.table |
59 Proof.context -> params -> string -> int -> fact list -> unit Symtab.table |
59 -> thm -> bool * string list option |
60 -> thm -> bool * string list option |
60 val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list |
61 val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list |
61 val mash_suggested_facts : |
62 val mash_suggested_facts : |
62 Proof.context -> params -> string -> int -> term list -> term |
63 Proof.context -> params -> string -> int -> term list -> term -> fact list |
63 -> fact list -> fact list * fact list |
64 -> fact list |
64 val mash_learn_proof : |
65 val mash_learn_proof : |
65 Proof.context -> params -> string -> term -> ('a * thm) list -> thm list |
66 Proof.context -> params -> string -> term -> ('a * thm) list -> thm list |
66 -> unit |
67 -> unit |
67 val mash_learn : |
68 val mash_learn : |
68 Proof.context -> params -> fact_override -> thm list -> bool -> unit |
69 Proof.context -> params -> fact_override -> thm list -> bool -> unit |
423 val tab = Symtab.empty |> fold add_fact facts |
424 val tab = Symtab.empty |> fold add_fact facts |
424 fun find_sugg (name, weight) = |
425 fun find_sugg (name, weight) = |
425 Symtab.lookup tab name |> Option.map (rpair weight) |
426 Symtab.lookup tab name |> Option.map (rpair weight) |
426 in map_filter find_sugg suggs end |
427 in map_filter find_sugg suggs end |
427 |
428 |
428 fun sum_avg [] = 0 |
429 fun scaled_avg [] = 0 |
429 | sum_avg xs = |
430 | scaled_avg xs = |
430 Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs |
431 Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs |
431 |
432 |
432 fun normalize_scores [] = [] |
433 fun avg [] = 0.0 |
433 | normalize_scores ((fact, score) :: tail) = |
434 | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs) |
434 (fact, 1.0) :: map (apsnd (curry Real.* (1.0 / score))) tail |
435 |
435 |
436 fun normalize_scores _ [] = [] |
436 fun mesh_facts max_facts [(sels, unks)] = |
437 | normalize_scores max_facts xs = |
|
438 let val avg = avg (map snd (take max_facts xs)) in |
|
439 map (apsnd (curry Real.* (1.0 / avg))) xs |
|
440 end |
|
441 |
|
442 fun mesh_facts max_facts [(_, (sels, unks))] = |
437 map fst (take max_facts sels) @ take (max_facts - length sels) unks |
443 map fst (take max_facts sels) @ take (max_facts - length sels) unks |
438 | mesh_facts max_facts mess = |
444 | mesh_facts max_facts mess = |
439 let |
445 let |
440 val mess = mess |> map (apfst (normalize_scores #> `length)) |
446 val mess = |
|
447 mess |> map (apsnd (apfst (normalize_scores max_facts #> `length))) |
441 val fact_eq = Thm.eq_thm o pairself snd |
448 val fact_eq = Thm.eq_thm o pairself snd |
442 fun score_at sels = try (nth sels) #> Option.map snd |
449 fun score_in fact (global_weight, ((sel_len, sels), unks)) = |
443 fun score_in fact ((sel_len, sels), unks) = |
450 let |
444 case find_index (curry fact_eq fact o fst) sels of |
451 fun score_at j = |
445 ~1 => (case find_index (curry fact_eq fact) unks of |
452 case try (nth sels) j of |
446 ~1 => score_at sels sel_len |
453 SOME (_, score) => SOME (global_weight * score) |
447 | _ => NONE) |
454 | NONE => NONE |
448 | rank => score_at sels rank |
455 in |
449 fun weight_of fact = mess |> map_filter (score_in fact) |> sum_avg |
456 case find_index (curry fact_eq fact o fst) sels of |
|
457 ~1 => (case find_index (curry fact_eq fact) unks of |
|
458 ~1 => score_at sel_len |
|
459 | _ => NONE) |
|
460 | rank => score_at rank |
|
461 end |
|
462 fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg |
450 val facts = |
463 val facts = |
451 fold (union fact_eq o map fst o take max_facts o snd o fst) mess [] |
464 fold (union fact_eq o map fst o take max_facts o snd o fst o snd) mess |
|
465 [] |
452 in |
466 in |
453 facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst) |
467 facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst) |
454 |> map snd |> take max_facts |
468 |> map snd |> take max_facts |
455 end |
469 end |
456 |
470 |
457 fun thy_feature_of s = ("y" ^ s, 1.0 (* FUDGE *)) |
471 fun thy_feature_of s = ("y" ^ s, 1.0 (* FUDGE *)) |
458 fun term_feature_of s = ("c" ^ s, 1.0 (* FUDGE *)) |
472 fun term_feature_of s = ("c" ^ s, 1.0 (* FUDGE *)) |
459 fun type_feature_of s = ("t" ^ s, 1.0 (* FUDGE *)) |
473 fun type_feature_of s = ("t" ^ s, 1.0 (* FUDGE *)) |
460 fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *)) |
474 fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *)) |
461 fun status_feature_of status = (string_of_status status, 1.0 (* FUDGE *)) |
475 fun status_feature_of status = (string_of_status status, 1.0 (* FUDGE *)) |
462 val local_feature = ("local", 20.0 (* FUDGE *)) |
476 val local_feature = ("local", 1.0 (* FUDGE *)) |
463 val lams_feature = ("lams", 1.0 (* FUDGE *)) |
477 val lams_feature = ("lams", 1.0 (* FUDGE *)) |
464 val skos_feature = ("skos", 1.0 (* FUDGE *)) |
478 val skos_feature = ("skos", 1.0 (* FUDGE *)) |
465 |
479 |
466 fun theory_ord p = |
480 fun theory_ord p = |
467 if Theory.eq_thy p then |
481 if Theory.eq_thy p then |
529 | patternify_term _ 0 _ = [] |
543 | patternify_term _ 0 _ = [] |
530 | patternify_term args depth (t $ u) = |
544 | patternify_term args depth (t $ u) = |
531 let |
545 let |
532 val ps = patternify_term (u :: args) depth t |
546 val ps = patternify_term (u :: args) depth t |
533 val qs = "" :: patternify_term [] (depth - 1) u |
547 val qs = "" :: patternify_term [] (depth - 1) u |
534 in map_product (fn p => fn "" => p | q => "(" ^ q ^ ")") ps qs end |
548 in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end |
535 | patternify_term _ _ _ = [] |
549 | patternify_term _ _ _ = [] |
536 val add_term_pattern = |
550 val add_term_pattern = |
537 union (op = o pairself fst) o map term_feature_of oo patternify_term [] |
551 union (op = o pairself fst) o map term_feature_of oo patternify_term [] |
538 fun add_term_patterns ~1 _ = I |
552 fun add_term_patterns ~1 _ = I |
539 | add_term_patterns depth t = |
553 | add_term_patterns depth t = |
690 else |
704 else |
691 (maxs, Graph.Keys.fold (insert_new seen) |
705 (maxs, Graph.Keys.fold (insert_new seen) |
692 (Graph.imm_preds fact_G new) news)) |
706 (Graph.imm_preds fact_G new) news)) |
693 in find_maxes Symtab.empty ([], Graph.maximals fact_G) end |
707 in find_maxes Symtab.empty ([], Graph.maximals fact_G) end |
694 |
708 |
695 (* Generate more suggestions than requested, because some might be thrown out |
|
696 later for various reasons and "meshing" gives better results with some |
|
697 slack. *) |
|
698 fun max_suggs_of max_facts = max_facts + Int.min (50, max_facts) |
|
699 |
|
700 fun is_fact_in_graph fact_G (_, th) = |
709 fun is_fact_in_graph fact_G (_, th) = |
701 can (Graph.get_node fact_G) (nickname_of th) |
710 can (Graph.get_node fact_G) (nickname_of th) |
702 |
711 |
703 fun interleave 0 _ _ = [] |
|
704 | interleave n [] ys = take n ys |
|
705 | interleave n xs [] = take n xs |
|
706 | interleave 1 (x :: _) _ = [x] |
|
707 | interleave n (x :: xs) (y :: ys) = x :: y :: interleave (n - 2) xs ys |
|
708 |
|
709 (* factor that controls whether unknown global facts should be included *) |
712 (* factor that controls whether unknown global facts should be included *) |
710 val include_unk_global_factor = 15 |
713 val include_unk_global_factor = 15 |
711 |
714 |
712 val weight_mash_facts = weight_mepo_facts (* use MePo weights for now *) |
715 (* use MePo weights for now *) |
|
716 val weight_raw_mash_facts = weight_mepo_facts |
|
717 val weight_mash_facts = weight_raw_mash_facts |
|
718 |
|
719 (* FUDGE *) |
|
720 fun weight_of_proximity_fact rank = |
|
721 Math.pow (1.3, 15.5 - 0.05 * Real.fromInt rank) + 15.0 |
|
722 |
|
723 fun weight_proximity_facts facts = |
|
724 facts ~~ map weight_of_proximity_fact (0 upto length facts - 1) |
713 |
725 |
714 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts |
726 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts |
715 concl_t facts = |
727 concl_t facts = |
716 let |
728 let |
717 val thy = Proof_Context.theory_of ctxt |
729 val thy = Proof_Context.theory_of ctxt |
723 let |
735 let |
724 val parents = maximal_in_graph fact_G facts |
736 val parents = maximal_in_graph fact_G facts |
725 val feats = |
737 val feats = |
726 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts) |
738 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts) |
727 in |
739 in |
728 (fact_G, mash_QUERY ctxt overlord (max_suggs_of max_facts) |
740 (fact_G, mash_QUERY ctxt overlord max_facts (parents, feats)) |
729 (parents, feats)) |
|
730 end) |
741 end) |
731 val (chained, unchained) = |
742 val (chained, unchained) = |
732 List.partition (fn ((_, (scope, _)), _) => scope = Chained) facts |
743 List.partition (fn ((_, (scope, _)), _) => scope = Chained) facts |
733 val sels = |
744 val raw_mash = |
734 facts |> suggested_facts suggs |
745 facts |> suggested_facts suggs |
735 (* The weights currently returned by "mash.py" are too spaced out to |
746 (* The weights currently returned by "mash.py" are too spaced out to |
736 make any sense. *) |
747 make any sense. *) |
737 |> map fst |
748 |> map fst |
738 |> filter_out (member (Thm.eq_thm_prop o pairself snd) chained) |
749 val proximity = |
739 val (unk_global, unk_local) = |
750 chained @ (facts |> subtract (Thm.eq_thm_prop o pairself snd) chained |
740 unchained |> filter_out (is_fact_in_graph fact_G) |
751 |> sort (thm_ord o pairself snd o swap)) |
741 |> List.partition (fn ((_, (scope, _)), _) => scope = Global) |
752 val unknown = facts |> filter_out (is_fact_in_graph fact_G) |
742 val (small_unk_global, big_unk_global) = |
753 val mess = |
743 ([], unk_global) |
754 [(0.667 (* FUDGE *), (weight_raw_mash_facts raw_mash, unknown)), |
744 |> include_unk_global_factor * length unk_global <= max_facts ? swap |
755 (0.333 (* FUDGE *), (weight_proximity_facts proximity, []))] |
745 in |
756 in mesh_facts max_facts mess end |
746 (interleave max_facts (chained @ unk_local @ small_unk_global) sels, |
|
747 big_unk_global) |
|
748 end |
|
749 |
757 |
750 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) = |
758 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) = |
751 let |
759 let |
752 fun maybe_add_from from (accum as (parents, graph)) = |
760 fun maybe_add_from from (accum as (parents, graph)) = |
753 try_graph ctxt "updating graph" accum (fn () => |
761 try_graph ctxt "updating graph" accum (fn () => |
993 end |
1001 end |
994 |
1002 |
995 fun is_mash_enabled () = (getenv "MASH" = "yes") |
1003 fun is_mash_enabled () = (getenv "MASH" = "yes") |
996 fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt))) |
1004 fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt))) |
997 |
1005 |
|
1006 (* Generate more suggestions than requested, because some might be thrown out |
|
1007 later for various reasons. *) |
|
1008 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts) |
|
1009 |
998 (* The threshold should be large enough so that MaSh doesn't kick in for Auto |
1010 (* The threshold should be large enough so that MaSh doesn't kick in for Auto |
999 Sledgehammer and Try. *) |
1011 Sledgehammer and Try. *) |
1000 val min_secs_for_learning = 15 |
1012 val min_secs_for_learning = 15 |
1001 |
1013 |
1002 fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover |
1014 fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover |
1038 fun mepo () = |
1050 fun mepo () = |
1039 mepo_suggested_facts ctxt params prover max_facts NONE hyp_ts concl_t |
1051 mepo_suggested_facts ctxt params prover max_facts NONE hyp_ts concl_t |
1040 facts |
1052 facts |
1041 |> weight_mepo_facts |
1053 |> weight_mepo_facts |
1042 fun mash () = |
1054 fun mash () = |
1043 mash_suggested_facts ctxt params prover max_facts hyp_ts concl_t facts |
1055 mash_suggested_facts ctxt params prover (generous_max_facts max_facts) |
1044 |>> weight_mash_facts |
1056 hyp_ts concl_t facts |
|
1057 |> weight_mash_facts |
1045 val mess = |
1058 val mess = |
1046 [] |> (if fact_filter <> mashN then cons (mepo (), []) else I) |
1059 [] |> (if fact_filter <> mashN then cons (0.5, (mepo (), [])) else I) |
1047 |> (if fact_filter <> mepoN then cons (mash ()) else I) |
1060 |> (if fact_filter <> mepoN then cons (0.5, (mash (), [])) else I) |
1048 in |
1061 in |
1049 mesh_facts max_facts mess |
1062 mesh_facts max_facts mess |
1050 |> not (null add_ths) ? prepend_facts add_ths |
1063 |> not (null add_ths) ? prepend_facts add_ths |
1051 end |
1064 end |
1052 |
1065 |