627 |
627 |
628 val k_nearest_neighbors_cpp = |
628 val k_nearest_neighbors_cpp = |
629 c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors) |
629 c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors) |
630 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes" |
630 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes" |
631 |
631 |
|
632 val empty_xtab = (0, Symtab.empty) |
|
633 |
632 fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab) |
634 fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab) |
|
635 fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key)) |
633 |
636 |
634 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
637 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
635 |
638 |
636 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) = |
639 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) = |
637 let |
640 let |
641 k_nearest_neighbors_cpp max_suggs learns feats |
644 k_nearest_neighbors_cpp max_suggs learns feats |
642 else if engine = MaSh_SML_NB_Cpp then |
645 else if engine = MaSh_SML_NB_Cpp then |
643 naive_bayes_cpp max_suggs learns feats |
646 naive_bayes_cpp max_suggs learns feats |
644 else |
647 else |
645 let |
648 let |
646 val (rev_depss, rev_featss, ((num_facts, fact_tab), rev_facts), (num_feats, feat_tab)) = |
649 val facts = map #1 learns |
647 fold (fn (fact, feats, deps) => |
|
648 fn (rev_depss, rev_featss, (fact_xtab as (_, fact_tab), rev_facts), feat_xtab) => |
|
649 let |
|
650 fun add_feat feat (xtab as (n, tab)) = |
|
651 (case Symtab.lookup tab feat of |
|
652 SOME i => (i, xtab) |
|
653 | NONE => (n, add_to_xtab feat xtab)) |
|
654 |
|
655 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab |
|
656 in |
|
657 (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss, |
|
658 (add_to_xtab fact fact_xtab, fact :: rev_facts), feat_xtab') |
|
659 end) |
|
660 learns ([], [], ((0, Symtab.empty), []), (0, Symtab.empty)) |
|
661 |
|
662 val facts = rev rev_facts |
|
663 val fact_vec = Vector.fromList facts |
650 val fact_vec = Vector.fromList facts |
|
651 |
|
652 val fact_xtab as (num_facts, fact_tab) = fold add_to_xtab facts empty_xtab |
|
653 val feat_xtab as (num_feats, feat_tab) = fold (fold maybe_add_to_xtab o #2) learns empty_xtab |
|
654 |
|
655 val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns |
|
656 |
|
657 val deps_vec = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns) |
|
658 |
664 val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts |
659 val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts |
665 |
|
666 val deps_vec = Vector.fromList (rev rev_depss) |
|
667 |
660 |
668 val get_deps = curry Vector.sub deps_vec |
661 val get_deps = curry Vector.sub deps_vec |
669 |
662 |
670 val int_feats = map_filter (Symtab.lookup feat_tab) feats |
663 val int_feats = map_filter (Symtab.lookup feat_tab) feats |
671 in |
664 in |
674 (if engine = MaSh_SML_kNN then |
667 (if engine = MaSh_SML_kNN then |
675 let |
668 let |
676 val facts_ary = Array.array (num_feats, []) |
669 val facts_ary = Array.array (num_feats, []) |
677 val _ = |
670 val _ = |
678 fold (fn feats => fn fact => |
671 fold (fn feats => fn fact => |
679 let val fact' = fact - 1 in |
672 (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1)) |
680 List.app (map_array_at facts_ary (cons fact')) feats; fact' |
673 featss 0 |
681 end) |
|
682 rev_featss num_facts |
|
683 val get_facts = curry Array.sub facts_ary |
674 val get_facts = curry Array.sub facts_ary |
684 in |
675 in |
685 k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats |
676 k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats |
686 int_feats |
677 int_feats |
687 end |
678 end |
688 else |
679 else |
689 let |
680 let |
690 val unweighted_feats_ary = Vector.fromList (rev rev_featss) |
681 val unweighted_feats_ary = Vector.fromList featss |
691 val get_unweighted_feats = curry Vector.sub unweighted_feats_ary |
682 val get_unweighted_feats = curry Vector.sub unweighted_feats_ary |
692 in |
683 in |
693 (case engine of |
684 (case engine of |
694 MaSh_SML_NB => |
685 MaSh_SML_NB => |
695 naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs |
686 naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs |