src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57367 e64c1b174f4b
parent 57366 d01d1befe4a3
child 57368 b89937ed6099
equal deleted inserted replaced
57366:d01d1befe4a3 57367:e64c1b174f4b
   627 
   627 
   628 val k_nearest_neighbors_cpp =
   628 val k_nearest_neighbors_cpp =
   629   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
   629   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
   630 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
   630 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
   631 
   631 
       
   632 val empty_xtab = (0, Symtab.empty)
       
   633 
   632 fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
   634 fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
       
   635 fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
   633 
   636 
   634 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   637 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   635 
   638 
   636 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
   639 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
   637   let
   640   let
   641       k_nearest_neighbors_cpp max_suggs learns feats
   644       k_nearest_neighbors_cpp max_suggs learns feats
   642     else if engine = MaSh_SML_NB_Cpp then
   645     else if engine = MaSh_SML_NB_Cpp then
   643       naive_bayes_cpp max_suggs learns feats
   646       naive_bayes_cpp max_suggs learns feats
   644     else
   647     else
   645       let
   648       let
   646         val (rev_depss, rev_featss, ((num_facts, fact_tab), rev_facts), (num_feats, feat_tab)) =
   649         val facts = map #1 learns
   647           fold (fn (fact, feats, deps) =>
       
   648                 fn (rev_depss, rev_featss, (fact_xtab as (_, fact_tab), rev_facts), feat_xtab) =>
       
   649               let
       
   650                 fun add_feat feat (xtab as (n, tab)) =
       
   651                   (case Symtab.lookup tab feat of
       
   652                     SOME i => (i, xtab)
       
   653                   | NONE => (n, add_to_xtab feat xtab))
       
   654 
       
   655                 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
       
   656               in
       
   657                 (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
       
   658                  (add_to_xtab fact fact_xtab, fact :: rev_facts), feat_xtab')
       
   659               end)
       
   660             learns ([], [], ((0, Symtab.empty), []), (0, Symtab.empty))
       
   661 
       
   662         val facts = rev rev_facts
       
   663         val fact_vec = Vector.fromList facts
   650         val fact_vec = Vector.fromList facts
       
   651 
       
   652         val fact_xtab as (num_facts, fact_tab) = fold add_to_xtab facts empty_xtab
       
   653         val feat_xtab as (num_feats, feat_tab) = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
       
   654 
       
   655         val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
       
   656 
       
   657         val deps_vec = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
       
   658 
   664         val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
   659         val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
   665 
       
   666         val deps_vec = Vector.fromList (rev rev_depss)
       
   667 
   660 
   668         val get_deps = curry Vector.sub deps_vec
   661         val get_deps = curry Vector.sub deps_vec
   669 
   662 
   670         val int_feats = map_filter (Symtab.lookup feat_tab) feats
   663         val int_feats = map_filter (Symtab.lookup feat_tab) feats
   671       in
   664       in
   674         (if engine = MaSh_SML_kNN then
   667         (if engine = MaSh_SML_kNN then
   675            let
   668            let
   676              val facts_ary = Array.array (num_feats, [])
   669              val facts_ary = Array.array (num_feats, [])
   677              val _ =
   670              val _ =
   678                fold (fn feats => fn fact =>
   671                fold (fn feats => fn fact =>
   679                    let val fact' = fact - 1 in
   672                    (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
   680                      List.app (map_array_at facts_ary (cons fact')) feats; fact'
   673                  featss 0
   681                    end)
       
   682                  rev_featss num_facts
       
   683              val get_facts = curry Array.sub facts_ary
   674              val get_facts = curry Array.sub facts_ary
   684            in
   675            in
   685              k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
   676              k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
   686                int_feats
   677                int_feats
   687            end
   678            end
   688          else
   679          else
   689            let
   680            let
   690              val unweighted_feats_ary = Vector.fromList (rev rev_featss)
   681              val unweighted_feats_ary = Vector.fromList featss
   691              val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   682              val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   692            in
   683            in
   693              (case engine of
   684              (case engine of
   694                MaSh_SML_NB =>
   685                MaSh_SML_NB =>
   695                naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs
   686                naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs