585 MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats') |
585 MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats') |
586 |> map (apfst fact_of_name) |
586 |> map (apfst fact_of_name) |
587 end |
587 end |
588 |
588 |
589 (* experimental *) |
589 (* experimental *) |
(* NOTE(review): the lines below interleave two revisions of the same function, line by line;
   each line carries its own revision line number and a trailing "|". This looks like a
   side-by-side diff listing rather than compilable ML -- confirm against the repository
   before changing any code here.

   k_nearest_neighbors_cpp: hands a k-nearest-neighbours query to an external program.
   It opens four output files by relative path, fills them, closes them, and then runs
   "~/misc/predict/knn" through Isabelle_System.bash_output:
     adv_syms  one line per learned fact: the fact, ":", then its features, each wrapped
               in double quotes and separated by ", "
     adv_deps  one line per learned fact: the fact, ":", then its dependencies separated
               by single spaces
     adv_seq   one learned fact per line
     adv_conj  the query features, each wrapped in double quotes, separated by ", ";
               the command reads this file on its standard input ("< adv_conj")
   The command line also receives number_of_nearest_neighbors (defined outside this
   function) and the integer passed to forkexec.

   Revision taking (avail_num adv_max get_deps get_syms advno syms): facts, features and
   dependencies are integers; facts 0 .. avail_num - 1 are written using get_syms and
   get_deps; forkexec is called with advno + avail_num - adv_max; the first component of
   the bash_output result is split on " " and every token that Int.fromString accepts is
   returned paired with 1.0.

   Revision taking (max_suggs learns cfeats): learns is a list of (name, feats, deps)
   triples where only the first component of each feats pair is written; cfeats is a list
   of feature strings; forkexec is called with max_suggs; the first component of the
   bash_output result is split on " " and the non-empty tokens are returned as they are.

   Local helpers: os writes a string to a stream, oi (integer revision only) writes an
   integer, and ol oc f sep xs applies f to each element of xs, writing sep to oc between
   consecutive elements.

   NOTE(review): the four streams are closed only on the normal path; if writing raises,
   they appear to stay open -- confirm whether that matters for this experimental code. *)
590 fun k_nearest_neighbors_cpp avail_num adv_max get_deps get_syms advno syms = |
590 fun k_nearest_neighbors_cpp max_suggs learns cfeats = |
591 let |
591 let |
592 val ocs = TextIO.openOut "adv_syms" |
592 val ocs = TextIO.openOut "adv_syms" |
593 val ocd = TextIO.openOut "adv_deps" |
593 val ocd = TextIO.openOut "adv_deps" |
594 val ocq = TextIO.openOut "adv_seq" |
594 val ocq = TextIO.openOut "adv_seq" |
595 val occ = TextIO.openOut "adv_conj" |
595 val occ = TextIO.openOut "adv_conj" |
596 fun os oc s = TextIO.output (oc, s) |
596 fun os oc s = TextIO.output (oc, s) |
597 fun oi oc i = os oc (Int.toString i) |
597 fun ol _ _ _ [] = () |
598 fun ol _ _ _ [] = () |
598 | ol _ f _ [e] = f e |
599 | ol _ f _ [e] = f e |

600 | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t) |
599 | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t) |
601 fun do_n n = |
600 fun do_learn (name, feats, deps) = |
602 (oi ocs n; os ocs ":"; ol ocs (fn i => (os ocs "\""; oi ocs i; os ocs "\"")) ", " (get_syms n); os ocs "\n"; |
601 (os ocs name; os ocs ":"; ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n"; |
603 oi ocd n; os ocd ":"; ol ocd (fn i => oi ocd i) " " (get_deps n); os ocd "\n"; |
602 os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n"; |
604 oi ocq n; os ocq "\n") |
603 os ocq name; os ocq "\n") |
605 fun for n = if n = avail_num then () else (do_n n; for (n + 1)) |

606 fun forkexec no = |
604 fun forkexec no = |
607 let |
605 let |
608 val cmd = |
606 val cmd = |
609 "~/misc/predict/knn " ^ string_of_int number_of_nearest_neighbors ^ |
607 "~/misc/predict/knn " ^ string_of_int number_of_nearest_neighbors ^ |
610 " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj" |
608 " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj" |
611 in |
609 in |
612 fst (Isabelle_System.bash_output cmd) |
610 fst (Isabelle_System.bash_output cmd) |
613 |> space_explode " " |
611 |> space_explode " " |
614 |> map_filter (Option.map (rpair 1.0) o Int.fromString) |
612 |> filter_out (curry (op =) "") |
615 end |
613 end |
616 in |
614 in |
617 (for 0; ol occ (fn i => (os occ "\""; oi occ i; os occ "\"")) ", " syms; TextIO.closeOut ocs; |
615 (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats; TextIO.closeOut ocs; |
618 TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ; |
616 TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ; |
619 forkexec (advno + avail_num - adv_max)) |
617 forkexec max_suggs) |
620 end |
618 end |
621 |
619 |
(* add_to_xtab key (next, tab, keys): registers key under the index next.
   Returns the triple with the counter advanced to next + 1, tab extended with
   key -> next via Symtab.update_new, and key consed onto the front of keys (so keys is
   in reverse insertion order).
   NOTE(review): Symtab.update_new presumably rejects a key that is already present --
   confirm; callers in this listing look the key up first or rely on keys being fresh.
   The definition appears twice because both interleaved revisions contain the same line. *)
622 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) |
620 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) |
623 |
621 |
(* map_array_at ary f i: in-place update of one array slot; element i of ary is replaced
   by f applied to the element currently stored there.
   The definition appears twice because both interleaved revisions contain the same line. *)
624 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
622 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
625 |
623 |
(* NOTE(review): the lines below interleave two revisions of query, line by line; each line
   carries its own revision line number and a trailing "|". This looks like a side-by-side
   diff listing rather than compilable ML -- confirm against the repository before
   changing any code here.

   query ctxt overlord engine visible_facts max_suggs (learns, hints, feats):
   dispatches a suggestion query to the engine selected by engine.

   Preparation common to both revisions:
   - The learned triples (fact, feats, deps) are reordered so that those whose first
     component is in visible_facts come first (List.partition followed by @).
   - When hints is non-empty, one extra triple (".goal", feats, hints) is appended.
   - A fold over that list numbers facts and features consecutively from 0 using
     add_to_xtab. Each fact's dependencies are translated through the fact table built
     so far, and dependencies not yet in that table are dropped by map_filter.
   - fact_vec and deps_vec are vectors indexed by fact number; a trace message lists
     the first num_visible_facts facts.

   Engine dispatch:
   - MaSh_SML_kNN: builds facts_ary, which maps each feature index to the list of
     (fact index, weight) pairs of the facts having that feature. The fold walks
     rev_featss starting at num_facts and decrements first, so fact indices run from
     num_facts - 1 downwards. Query features that are absent from feat_tab are dropped.
   - MaSh_SML_NB opts / MaSh_SML_NB_Py: query features are mapped to their indices, with
     ~1 standing for a feature absent from feat_tab; per-fact feature lists are passed
     without weights.
   - MaSh_SML_kNN_Cpp: in the revision whose parameter is named learns, it is one arm of
     the final case expression and receives the integer-indexed data. In the revision
     whose parameter is named learns0, it is tested first and k_nearest_neighbors_cpp
     receives max_suggs, the reordered learns and the query feature names directly,
     before any index tables are built.

   In the index-table path, the first component of each engine result is mapped back to
   a fact through fact_vec. In the learns0 revision the MaSh_SML_kNN_Cpp branch returns
   the result of k_nearest_neighbors_cpp without that mapping.

   The two revisions also differ in the trace text ("MaSh_SML " ^ " query " versus
   "MaSh_SML query ") and in the name of the indexed query features (feats' versus
   int_feats). *)
626 fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) = |
624 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) = |
627 let |
625 let |
628 val visible_fact_set = Symtab.make_set visible_facts |
626 val visible_fact_set = Symtab.make_set visible_facts |
629 |
627 val learns = |
630 val learns' = |
628 (learns0 |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @ |
631 (learns |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @ |

632 (if null hints then [] else [(".goal", feats, hints)]) |
629 (if null hints then [] else [(".goal", feats, hints)]) |
633 |
630 in |
634 val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) = |
631 if engine = MaSh_SML_kNN_Cpp then |
635 fold (fn (fact, feats, deps) => |
632 k_nearest_neighbors_cpp max_suggs learns (map fst feats) |
636 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) => |
633 else |
637 let |
634 let |
638 fun add_feat (feat, weight) (xtab as (n, tab, _)) = |
635 val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) = |
639 (case Symtab.lookup tab feat of |
636 fold (fn (fact, feats, deps) => |
640 SOME i => ((i, weight), xtab) |
637 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) => |
641 | NONE => ((n, weight), add_to_xtab feat xtab)) |
638 let |
642 |
639 fun add_feat (feat, weight) (xtab as (n, tab, _)) = |
643 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab |
640 (case Symtab.lookup tab feat of |
644 in |
641 SOME i => ((i, weight), xtab) |
645 (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss, |
642 | NONE => ((n, weight), add_to_xtab feat xtab)) |
646 add_to_xtab fact fact_xtab, feat_xtab') |
643 |
647 end) |
644 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab |
648 learns' ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) |
645 in |
649 |
646 (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss, |
650 val facts = rev rev_facts |
647 add_to_xtab fact fact_xtab, feat_xtab') |
651 val fact_vec = Vector.fromList facts |
648 end) |
652 |
649 learns ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) |
653 val deps_vec = Vector.fromList (rev rev_depss) |
650 |
654 |
651 val facts = rev rev_facts |
655 val num_visible_facts = length visible_facts |
652 val fact_vec = Vector.fromList facts |
656 val get_deps = curry Vector.sub deps_vec |
653 |
657 in |
654 val deps_vec = Vector.fromList (rev rev_depss) |
658 trace_msg ctxt (fn () => "MaSh_SML " ^ " query " ^ encode_features feats ^ " from {" ^ |
655 |
659 elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}"); |
656 val num_visible_facts = length visible_facts |
660 (if engine = MaSh_SML_kNN then |
657 val get_deps = curry Vector.sub deps_vec |
661 let |
658 in |
662 val facts_ary = Array.array (num_feats, []) |
659 trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^ |
663 val _ = |
660 elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}"); |
664 fold (fn feats => fn fact => |
661 (if engine = MaSh_SML_kNN then |
665 let val fact' = fact - 1 in |
662 let |
666 List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat) |
663 val facts_ary = Array.array (num_feats, []) |
667 feats; |
664 val _ = |
668 fact' |
665 fold (fn feats => fn fact => |
669 end) |
666 let val fact' = fact - 1 in |
670 rev_featss num_facts |
667 List.app (fn (feat, weight) => |
671 val get_facts = curry Array.sub facts_ary |
668 map_array_at facts_ary (cons (fact', weight)) feat) feats; |
672 val feats' = map_filter (fn (feat, weight) => |
669 fact' |
673 Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats |
670 end) |
674 in |
671 rev_featss num_facts |
675 k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats' |
672 val get_facts = curry Array.sub facts_ary |
676 end |
673 val feats' = map_filter (fn (feat, weight) => |
677 else |
674 Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats |
678 let |
675 in |
679 val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) |
676 k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats' |
680 val get_unweighted_feats = curry Vector.sub unweighted_feats_ary |
677 end |
681 val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats |
678 else |
682 in |
679 let |
683 (case engine of |
680 val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) |
684 MaSh_SML_kNN_Cpp => |
681 val get_unweighted_feats = curry Vector.sub unweighted_feats_ary |
685 k_nearest_neighbors_cpp num_facts num_visible_facts get_deps get_unweighted_feats |
682 val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats |
686 max_suggs (map fst feats') |
683 in |
687 | MaSh_SML_NB opts => |
684 (case engine of |
688 naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats |
685 MaSh_SML_NB opts => |
689 max_suggs feats' |
686 naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats |
690 | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps |
687 max_suggs int_feats |
691 get_unweighted_feats num_feats max_suggs feats') |
688 | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps |
692 end) |
689 get_unweighted_feats num_feats max_suggs int_feats) |
693 |> map (curry Vector.sub fact_vec o fst) |
690 end) |

691 |> map (curry Vector.sub fact_vec o fst) |

692 end |
694 end |
693 end |
695 |
694 |
696 end; |
695 end; |
697 |
696 |
698 |
697 |