36 val extract_suggestions : string -> string * (string * real) list |
36 val extract_suggestions : string -> string * (string * real) list |
37 |
37 |
(* Engines that MaSh can be run with.  The engine is chosen from a setting string: "py" selects
   MaSh_Py, "sml_knn" selects MaSh_SML_kNN, "sml_knn_cpp" selects MaSh_SML_kNN_Cpp, and the
   "sml_nbCC"/"sml_nbCD"/"sml_nbDC"/"sml_nbDD" variants select MaSh_SML_NB with the four possible
   flag pairs.  NOTE(review): the meaning of the two booleans carried by MaSh_SML_NB is not
   visible here -- confirm against the naive Bayes implementation. *)
38 datatype mash_engine = |
38 datatype mash_engine = |
39 MaSh_Py |
39 MaSh_Py |
40 | MaSh_SML_kNN |
40 | MaSh_SML_kNN |

41 | MaSh_SML_kNN_Cpp |
41 | MaSh_SML_NB of bool * bool |
42 | MaSh_SML_NB of bool * bool |
42 | MaSh_SML_NB_Py |
43 | MaSh_SML_NB_Py |
43 |
44 |
44 val is_mash_enabled : unit -> bool |
45 val is_mash_enabled : unit -> bool |
45 val the_mash_engine : unit -> mash_engine |
46 val the_mash_engine : unit -> mash_engine |
167 (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of |
169 (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of |
168 "yes" => SOME default_MaSh_SML_NB |
170 "yes" => SOME default_MaSh_SML_NB |
169 | "py" => SOME MaSh_Py |
171 | "py" => SOME MaSh_Py |
170 | "sml" => SOME default_MaSh_SML_NB |
172 | "sml" => SOME default_MaSh_SML_NB |
171 | "sml_knn" => SOME MaSh_SML_kNN |
173 | "sml_knn" => SOME MaSh_SML_kNN |
|
174 | "sml_knn_cpp" => SOME MaSh_SML_kNN_Cpp |
172 | "sml_nb" => SOME default_MaSh_SML_NB |
175 | "sml_nb" => SOME default_MaSh_SML_NB |
173 | "sml_nbCC" => SOME (MaSh_SML_NB (false, false)) |
176 | "sml_nbCC" => SOME (MaSh_SML_NB (false, false)) |
174 | "sml_nbCD" => SOME (MaSh_SML_NB (false, true)) |
177 | "sml_nbCD" => SOME (MaSh_SML_NB (false, true)) |
175 | "sml_nbDC" => SOME (MaSh_SML_NB (true, false)) |
178 | "sml_nbDC" => SOME (MaSh_SML_NB (true, false)) |
176 | "sml_nbDD" => SOME (MaSh_SML_NB (true, true)) |
179 | "sml_nbDD" => SOME (MaSh_SML_NB (true, true)) |
581 OS.Process.sleep (seconds 2.0); (* hack *) |
584 OS.Process.sleep (seconds 2.0); (* hack *) |
582 MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats') |
585 MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats') |
583 |> map (apfst fact_of_name) |
586 |> map (apfst fact_of_name) |
584 end |
587 end |
585 |
588 |
|
(* experimental *)
(* Nearest-neighbor prediction delegated to an external "knn" executable.

   For every fact index 0 .. avail_num - 1 one line is written to each of the files "adv_syms"
   (index, ":", quoted symbols separated by ", "), "adv_deps" (index, ":", dependencies separated
   by " ") and "adv_seq" (index only).  The query symbols "syms" are written, quoted and separated
   by ", ", to "adv_conj".  All file names are relative to the current directory.  The command's
   output is then split on spaces and every token that parses as an integer is returned, paired
   with the weight 1.0.

   NOTE(review): the streams are closed only after all writes have succeeded, so an exception
   raised by "get_syms", "get_deps" or an output operation leaves them open.  The executable path
   "~/misc/predict/knn" is hard-coded. *)
fun k_nearest_neighbors_cpp avail_num adv_max get_deps get_syms advno syms =
  let
    val syms_out = TextIO.openOut "adv_syms"
    val deps_out = TextIO.openOut "adv_deps"
    val seq_out = TextIO.openOut "adv_seq"
    val conj_out = TextIO.openOut "adv_conj"

    fun put out str = TextIO.output (out, str)
    fun put_int out i = put out (Int.toString i)
    fun put_quoted out i = (put out "\""; put_int out i; put out "\"")

    (* print the elements via "emit", writing "sep" between neighbors only *)
    fun put_list _ _ _ [] = ()
      | put_list out emit sep (x :: xs) =
          (emit x; List.app (fn y => (put out sep; emit y)) xs)

    (* one line per fact in each of the three fact files *)
    fun dump_fact j =
      (put_int syms_out j; put syms_out ":";
       put_list syms_out (put_quoted syms_out) ", " (get_syms j);
       put syms_out "\n";
       put_int deps_out j; put deps_out ":";
       put_list deps_out (put_int deps_out) " " (get_deps j);
       put deps_out "\n";
       put_int seq_out j; put seq_out "\n")

    fun dump_facts_from j =
      if j = avail_num then () else (dump_fact j; dump_facts_from (j + 1))

    (* run the external program and parse the integers it prints *)
    fun run_knn no =
      let
        val cmd =
          "~/misc/predict/knn " ^ string_of_int number_of_nearest_neighbors ^
          " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj"
        val (output, _) = Isabelle_System.bash_output cmd
      in
        map_filter (fn tok => Option.map (fn j => (j, 1.0)) (Int.fromString tok))
          (space_explode " " output)
      end
  in
    dump_facts_from 0;
    put_list conj_out (put_quoted conj_out) ", " syms;
    List.app TextIO.closeOut [syms_out, deps_out, seq_out, conj_out];
    run_knn (advno + avail_num - adv_max)
  end
|
621 |
(* Registers "key" under the index "next" in the table (Symtab.update_new is used, so the key is
   expected to be fresh), and returns the incremented counter, the extended table and the key
   list with "key" consed onto its front. *)
fun add_to_xtab key (next, tab, keys) =
  let
    val tab' = Symtab.update_new (key, next) tab
  in
    (next + 1, tab', key :: keys)
  end
587 |
623 |
(* Replaces the element at position "i" of "ary", in place, by the result of applying "f" to it. *)
fun map_array_at ary f i =
  let
    val old = Array.sub (ary, i)
  in
    Array.update (ary, i, f old)
  end
589 |
625 |
590 fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) = |
626 fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) = |
617 val deps_vec = Vector.fromList (rev rev_depss) |
653 val deps_vec = Vector.fromList (rev rev_depss) |
618 |
654 |
619 val num_visible_facts = length visible_facts |
655 val num_visible_facts = length visible_facts |
620 val get_deps = curry Vector.sub deps_vec |
656 val get_deps = curry Vector.sub deps_vec |
621 in |
657 in |
622 trace_msg ctxt (fn () => "MaSh_SML " ^ (if engine = MaSh_SML_kNN then "kNN" else "NB") ^ |
658 trace_msg ctxt (fn () => "MaSh_SML " ^ " query " ^ encode_features feats ^ " from {" ^ |
623 " query " ^ encode_features feats ^ " from {" ^ |
|
624 elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}"); |
659 elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}"); |
625 (if engine = MaSh_SML_kNN then |
660 (if engine = MaSh_SML_kNN then |
626 let |
661 let |
627 val facts_ary = Array.array (num_feats, []) |
662 val facts_ary = Array.array (num_feats, []) |
628 val _ = |
663 val _ = |
644 val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) |
679 val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) |
645 val get_unweighted_feats = curry Vector.sub unweighted_feats_ary |
680 val get_unweighted_feats = curry Vector.sub unweighted_feats_ary |
646 val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats |
681 val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats |
647 in |
682 in |
648 (case engine of |
683 (case engine of |
649 MaSh_SML_NB opts => naive_bayes opts |
684 MaSh_SML_kNN_Cpp => |
650 | _ => naive_bayes_py ctxt overlord) |
685 k_nearest_neighbors_cpp num_facts num_visible_facts get_deps get_unweighted_feats |
651 num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats' |
686 max_suggs (map fst feats') |
|
687 | MaSh_SML_NB opts => |
|
688 naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats |
|
689 max_suggs feats' |
|
690 | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps |
|
691 get_unweighted_feats num_feats max_suggs feats') |
652 end) |
692 end) |
653 |> map (curry Vector.sub fact_vec o fst) |
693 |> map (curry Vector.sub fact_vec o fst) |
654 end |
694 end |
655 |
695 |
656 end; |
696 end; |