58 (string * real) list |
58 (string * real) list |
59 end |
59 end |
60 |
60 |
(* Export signature of MaSh_SML.  It declares the two learners
   k_nearest_neighbors and naive_bayes, the naive_bayes_py variant (which
   additionally takes a Proof.context and a bool), and the entry point
   query.  query takes a context, a bool, an engine, a list of fact names,
   a suggestion count and a (learns, hints, features) triple, and returns
   a list of fact names. *)
61 structure MaSh_SML : |
61 structure MaSh_SML : |
62 sig |
62 sig |
63 val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) -> |
63 val k_nearest_neighbors : int -> (int -> int list) -> (int -> (int * real) list) -> int -> |

64 int list -> (int * real) list -> (int * real) list |

65 val naive_bayes : (bool * bool) -> int -> (int -> int list) -> (int -> int list) -> int -> |
64 int -> (int * real) list -> (int * real) list |
66 int -> (int * real) list -> (int * real) list |
65 val naive_bayes : (bool * bool) -> int -> int -> (int -> int list) -> (int -> int list) -> |
67 val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) -> |
66 int -> int -> (int * real) list -> (int * real) list |
68 int -> int -> (int * real) list -> (int * real) list |
67 val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) -> |

68 (int -> int list) -> int -> int -> (int * real) list -> (int * real) list |

69 val query : Proof.context -> bool -> mash_engine -> string list -> int -> |
69 val query : Proof.context -> bool -> mash_engine -> string list -> int -> |
70 (string * (string * real) list * string list) list * string list * (string * real) list -> |
70 (string * (string * real) list * string list) list * string list * (string * real) list -> |
71 string list |
71 string list |
72 end |
72 end |
73 |
73 |
421 |
421 |
(* Payload-free exception.  do_k in k_nearest_neighbors raises it once its
   counter reaches the fact limit.  NOTE(review): the matching handler is not
   visible in this excerpt; presumably it ends the enclosing loop, to be
   confirmed. *)
422 exception EXIT of unit |
422 exception EXIT of unit |
423 |
423 |
424 (* |
424 (* |
425 num_facts = maximum number of theorems to check dependencies and symbols |
425 num_facts = maximum number of theorems to check dependencies and symbols |
426 num_visible_facts = only theorems whose index is strictly below this number are returned. |
|
427 Must satisfy: num_visible_facts <= num_facts. |
|
428 get_deps = returns dependencies of a theorem |
426 get_deps = returns dependencies of a theorem |
429 get_sym_ths = get theorems that have this feature |
427 get_sym_ths = get theorems that have this feature |
430 max_suggs = number of suggestions to return |
428 max_suggs = number of suggestions to return |
431 feats = features of the goal |
429 feats = features of the goal |
432 *) |
430 *) |
433 fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats = |
431 fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts feats = |
434 let |
432 let |
435 (* Can be later used for TFIDF *) |
433 (* Can be later used for TFIDF *) |
436 fun sym_wght _ = 1.0 |
434 fun sym_wght _ = 1.0 |
437 |
435 |
438 val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0) |
436 val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0) |
455 end |
453 end |
456 |
454 |
457 val _ = List.app do_feat feats |
455 val _ = List.app do_feat feats |
458 val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr |
456 val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr |
459 val no_recommends = Unsynchronized.ref 0 |
457 val no_recommends = Unsynchronized.ref 0 |
460 val recommends = Array.tabulate (num_visible_facts, rpair 0.0) |
458 val recommends = Array.tabulate (num_facts, rpair 0.0) |
461 val age = Unsynchronized.ref 1000000000.0 |
459 val age = Unsynchronized.ref 1000000000.0 |
462 |
460 |
463 fun inc_recommend j v = |
461 fun inc_recommend j v = |
464 let val ov = snd (Array.sub (recommends, j)) in |
462 let val ov = snd (Array.sub (recommends, j)) in |
465 if ov <= 0.0 then |
463 if ov <= 0.0 then |
468 (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ()) |
466 (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ()) |
469 end |
467 end |
470 |
468 |
471 val k = Unsynchronized.ref 0 |
469 val k = Unsynchronized.ref 0 |
472 fun do_k k = |
470 fun do_k k = |
473 if k >= num_visible_facts then |
471 if k >= num_facts then |
474 raise EXIT () |
472 raise EXIT () |
475 else |
473 else |
476 let |
474 let |
477 val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1) |
475 val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1) |
478 val o1 = Math.sqrt o2 |
476 val o1 = Math.sqrt o2 |
494 |
492 |
495 fun ret acc at = |
493 fun ret acc at = |
496 if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1) |
494 if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1) |
497 in |
495 in |
498 while1 (); while2 (); |
496 while1 (); while2 (); |
499 heap (Real.compare o pairself snd) max_suggs num_visible_facts recommends; |
497 heap (Real.compare o pairself snd) max_suggs num_facts recommends; |
500 ret [] (Integer.max 0 (num_visible_facts - max_suggs)) |
498 ret [] (Integer.max 0 (num_facts - max_suggs)) |
501 end |
499 end |
502 |
500 |
(* Hand-tuned constant (see the FUDGE marker).  NOTE(review): the name suggests
   a default prior weight for the naive Bayes learner, but its use is not
   visible in this excerpt; confirm against learn_facts. *)
503 val nb_def_prior_weight = 21 (* FUDGE *) |
501 val nb_def_prior_weight = 21 (* FUDGE *) |
504 |
502 |
505 fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats = |
503 fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats = |
539 val dffreq = Array.array (num_feats, 0) |
537 val dffreq = Array.array (num_feats, 0) |
540 in |
538 in |
541 learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats |
539 learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats |
542 end |
540 end |
543 |
541 |
544 fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts num_visible_facts max_suggs feats |
542 fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts max_suggs feats |
545 (tfreq, sfreq, idf) = |
543 (tfreq, sfreq, idf) = |
546 let |
544 let |
547 val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *) |
545 val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *) |
548 val pos_weight = if kuehlwein_params then 10.0 else 2.0 (* FUDGE *) |
546 val pos_weight = if kuehlwein_params then 10.0 else 2.0 (* FUDGE *) |
549 val def_val = ~15.0 (* FUDGE *) |
547 val def_val = ~15.0 (* FUDGE *) |
574 val sum_of_weights = Inttab.fold fold_sfh sfh 0.0 |
572 val sum_of_weights = Inttab.fold fold_sfh sfh 0.0 |
575 in |
573 in |
576 res + tau * sum_of_weights |
574 res + tau * sum_of_weights |
577 end |
575 end |
578 |
576 |
579 val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j))) |
577 val posterior = Array.tabulate (num_facts, (fn j => (j, log_posterior j))) |
580 |
578 |
581 fun ret acc at = |
579 fun ret acc at = |
582 if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1) |
580 if at = num_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1) |
583 in |
581 in |
584 heap (Real.compare o pairself snd) max_suggs num_visible_facts posterior; |
582 heap (Real.compare o pairself snd) max_suggs num_facts posterior; |
585 ret [] (Integer.max 0 (num_visible_facts - max_suggs)) |
583 ret [] (Integer.max 0 (num_facts - max_suggs)) |
586 end |
584 end |
587 |
585 |
(* Learn-then-query naive Bayes in one step: calls learn on num_facts,
   get_deps, get_feats and num_feats, then pipes its result into
   naive_bayes_query as the final argument, together with opts, the fact
   count(s), max_suggs and the goal features feats.  Nothing is kept between
   calls; learn is run again on every invocation. *)
588 fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats = |
586 fun naive_bayes opts num_facts get_deps get_feats num_feats max_suggs feats = |
589 learn num_facts get_deps get_feats num_feats |
587 learn num_facts get_deps get_feats num_feats |
590 |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats |
588 |> naive_bayes_query opts num_facts max_suggs feats |
591 |
589 |
592 (* experimental *) |
590 (* experimental *) |
593 fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs |
591 fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats = |
594 feats = |
|
595 let |
592 let |
596 fun name_of_fact j = "f" ^ string_of_int j |
593 fun name_of_fact j = "f" ^ string_of_int j |
597 fun fact_of_name s = the (Int.fromString (unprefix "f" s)) |
594 fun fact_of_name s = the (Int.fromString (unprefix "f" s)) |
598 fun name_of_feature j = "F" ^ string_of_int j |
595 fun name_of_feature j = "F" ^ string_of_int j |
599 fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)] |
596 fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)] |
600 |
597 |
601 val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j), |
598 val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j), |
602 map name_of_fact (get_deps j))) (0 upto num_facts - 1) |
599 map name_of_fact (get_deps j))) (0 upto num_facts - 1) |
603 val parents' = parents_of num_visible_facts |
600 val parents' = parents_of num_facts |
604 val feats' = map (apfst name_of_feature) feats |
601 val feats' = map (apfst name_of_feature) feats |
605 in |
602 in |
606 MaSh_Py.unlearn ctxt overlord; |
603 MaSh_Py.unlearn ctxt overlord; |
607 OS.Process.sleep (seconds 2.0); (* hack *) |
604 OS.Process.sleep (seconds 2.0); (* hack *) |
608 MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats') |
605 MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats') |
653 |
650 |
(* In-place update of one array cell: reads element i of ary, applies f to it
   and writes the result back at index i.  Returns unit. *)
654 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
651 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
655 |
652 |
656 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) = |
653 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) = |
657 let |
654 let |
658 val visible_fact_set = Symtab.make_set visible_facts |
655 val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)]) |
659 val learns = |
|
660 (learns0 |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @ |
|
661 (if null hints then [] else [(".hints", feats, hints)]) |
|
662 in |
656 in |
663 if engine = MaSh_SML_kNN_Cpp then |
657 if engine = MaSh_SML_kNN_Cpp then |
664 k_nearest_neighbors_cpp max_suggs learns (map fst feats) |
658 k_nearest_neighbors_cpp max_suggs learns (map fst feats) |
665 else if engine = MaSh_SML_NB_Cpp then |
659 else if engine = MaSh_SML_NB_Cpp then |
666 naive_bayes_cpp max_suggs learns (map fst feats) |
660 naive_bayes_cpp max_suggs learns (map fst feats) |
667 else |
661 else |
668 let |
662 let |
669 val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) = |
663 val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) = |
670 fold (fn (fact, feats, deps) => |
664 fold (fn (fact, feats, deps) => |
671 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) => |
665 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) => |
672 let |
666 let |
673 fun add_feat (feat, weight) (xtab as (n, tab, _)) = |
667 fun add_feat (feat, weight) (xtab as (n, tab, _)) = |
674 (case Symtab.lookup tab feat of |
668 (case Symtab.lookup tab feat of |
685 val facts = rev rev_facts |
679 val facts = rev rev_facts |
686 val fact_vec = Vector.fromList facts |
680 val fact_vec = Vector.fromList facts |
687 |
681 |
688 val deps_vec = Vector.fromList (rev rev_depss) |
682 val deps_vec = Vector.fromList (rev rev_depss) |
689 |
683 |
690 val num_visible_facts = length visible_facts |
|
691 val get_deps = curry Vector.sub deps_vec |
684 val get_deps = curry Vector.sub deps_vec |
|
685 |
|
686 val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts |
692 in |
687 in |
693 trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^ |
688 trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^ |
694 elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}"); |
689 elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}"); |
695 (if engine = MaSh_SML_kNN then |
690 (if engine = MaSh_SML_kNN then |
696 let |
691 let |
697 val facts_ary = Array.array (num_feats, []) |
692 val facts_ary = Array.array (num_feats, []) |
698 val _ = |
693 val _ = |
699 fold (fn feats => fn fact => |
694 fold (fn feats => fn fact => |
702 map_array_at facts_ary (cons (fact', weight)) feat) feats; |
697 map_array_at facts_ary (cons (fact', weight)) feat) feats; |
703 fact' |
698 fact' |
704 end) |
699 end) |
705 rev_featss num_facts |
700 rev_featss num_facts |
706 val get_facts = curry Array.sub facts_ary |
701 val get_facts = curry Array.sub facts_ary |
707 val feats' = map_filter (fn (feat, weight) => |
702 val int_feats = map_filter (fn (feat, weight) => |
708 Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats |
703 Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats |
709 in |
704 in |
710 k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats' |
705 k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts int_feats |
711 end |
706 end |
712 else |
707 else |
713 let |
708 let |
714 val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) |
709 val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) |
715 val get_unweighted_feats = curry Vector.sub unweighted_feats_ary |
710 val get_unweighted_feats = curry Vector.sub unweighted_feats_ary |
716 val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats |
711 val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats |
717 in |
712 in |
718 (case engine of |
713 (case engine of |
719 MaSh_SML_NB opts => |
714 MaSh_SML_NB opts => |
720 naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats |
715 naive_bayes opts num_facts get_deps get_unweighted_feats num_feats max_suggs |
721 max_suggs int_feats |
716 int_feats |
722 | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps |
717 | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps |
723 get_unweighted_feats num_feats max_suggs int_feats) |
718 get_unweighted_feats num_feats max_suggs int_feats) |
724 end) |
719 end) |
725 |> map (curry Vector.sub fact_vec o fst) |
720 |> map (curry Vector.sub fact_vec o fst) |
726 end |
721 end |
727 end |
722 end |