src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57366 d01d1befe4a3
parent 57365 d2090a01e920
child 57367 e64c1b174f4b
equal deleted inserted replaced
57365:d2090a01e920 57366:d01d1befe4a3
   510         List.app add_sym feats
   510         List.app add_sym feats
   511       end
   511       end
   512 
   512 
   513     fun for i =
   513     fun for i =
   514       if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
   514       if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
   515 
       
   516     val ln_afreq = Math.ln (Real.fromInt num_facts)
       
   517   in
   515   in
   518     for 0;
   516     for 0;
   519     (Array.vector tfreq, Array.vector sfreq,
   517     (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
   520      Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq))
       
   521   end
   518   end
   522 
   519 
   523 fun learn num_facts get_deps get_feats num_feats =
   520 fun learn num_facts get_deps get_feats num_feats =
   524   let
   521   let
   525     val tfreq = Array.array (num_facts, 0)
   522     val tfreq = Array.array (num_facts, 0)
   527     val dffreq = Array.array (num_feats, 0)
   524     val dffreq = Array.array (num_feats, 0)
   528   in
   525   in
   529     learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
   526     learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
   530   end
   527   end
   531 
   528 
   532 fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, idf) =
   529 fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) =
   533   let
   530   let
   534     val tau = 0.05 (* FUDGE *)
   531     val tau = 0.05 (* FUDGE *)
   535     val pos_weight = 10.0 (* FUDGE *)
   532     val pos_weight = 10.0 (* FUDGE *)
   536     val def_val = ~15.0 (* FUDGE *)
   533     val def_val = ~15.0 (* FUDGE *)
       
   534 
       
   535     val ln_afreq = Math.ln (Real.fromInt num_facts)
       
   536     val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) dffreq
   537 
   537 
   538     fun tfidf feat = Vector.sub (idf, feat)
   538     fun tfidf feat = Vector.sub (idf, feat)
   539 
   539 
   540     fun log_posterior i =
   540     fun log_posterior i =
   541       let
   541       let
   627 
   627 
   628 val k_nearest_neighbors_cpp =
   628 val k_nearest_neighbors_cpp =
   629   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
   629   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
   630 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
   630 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
   631 
   631 
   632 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   632 fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
   633 
   633 
   634 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   634 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   635 
   635 
   636 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
   636 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
   637   let
   637   let
   641       k_nearest_neighbors_cpp max_suggs learns feats
   641       k_nearest_neighbors_cpp max_suggs learns feats
   642     else if engine = MaSh_SML_NB_Cpp then
   642     else if engine = MaSh_SML_NB_Cpp then
   643       naive_bayes_cpp max_suggs learns feats
   643       naive_bayes_cpp max_suggs learns feats
   644     else
   644     else
   645       let
   645       let
   646         val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) =
   646         val (rev_depss, rev_featss, ((num_facts, fact_tab), rev_facts), (num_feats, feat_tab)) =
   647           fold (fn (fact, feats, deps) =>
   647           fold (fn (fact, feats, deps) =>
   648                 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   648                 fn (rev_depss, rev_featss, (fact_xtab as (_, fact_tab), rev_facts), feat_xtab) =>
   649               let
   649               let
   650                 fun add_feat feat (xtab as (n, tab, _)) =
   650                 fun add_feat feat (xtab as (n, tab)) =
   651                   (case Symtab.lookup tab feat of
   651                   (case Symtab.lookup tab feat of
   652                     SOME i => (i, xtab)
   652                     SOME i => (i, xtab)
   653                   | NONE => (n, add_to_xtab feat xtab))
   653                   | NONE => (n, add_to_xtab feat xtab))
   654 
   654 
   655                 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
   655                 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
   656               in
   656               in
   657                 (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
   657                 (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
   658                  add_to_xtab fact fact_xtab, feat_xtab')
   658                  (add_to_xtab fact fact_xtab, fact :: rev_facts), feat_xtab')
   659               end)
   659               end)
   660             learns ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   660             learns ([], [], ((0, Symtab.empty), []), (0, Symtab.empty))
   661 
   661 
   662         val facts = rev rev_facts
   662         val facts = rev rev_facts
   663         val fact_vec = Vector.fromList facts
   663         val fact_vec = Vector.fromList facts
   664         val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
   664         val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
   665 
   665