# HG changeset patch # User blanchet # Date 1404226030 -7200 # Node ID 419180c354c04091b622d45cdea517112e29360f # Parent b2bafc09b7e7fa2497fdd2d0f3ad1d8145b24c68 tuned (reordered) code diff -r b2bafc09b7e7 -r 419180c354c0 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200 @@ -34,10 +34,10 @@ val decode_strs : string -> string list datatype mash_engine = - MaSh_kNN + MaSh_NB + | MaSh_NB_Ext + | MaSh_kNN | MaSh_kNN_Ext - | MaSh_NB - | MaSh_NB_Ext val is_mash_enabled : unit -> bool val the_mash_engine : unit -> mash_engine @@ -138,20 +138,20 @@ end datatype mash_engine = - MaSh_kNN + MaSh_NB +| MaSh_NB_Ext +| MaSh_kNN | MaSh_kNN_Ext -| MaSh_NB -| MaSh_NB_Ext fun mash_engine () = let val flag1 = Options.default_string @{system_option MaSh} in (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of "yes" => SOME MaSh_NB | "sml" => SOME MaSh_NB + | "nb" => SOME MaSh_NB + | "nb_ext" => SOME MaSh_NB_Ext | "knn" => SOME MaSh_kNN | "knn_ext" => SOME MaSh_kNN_Ext - | "nb" => SOME MaSh_NB - | "nb_ext" => SOME MaSh_NB_Ext | _ => NONE) end @@ -231,86 +231,12 @@ () end -val number_of_nearest_neighbors = 10 (* FUDGE *) - fun select_visible_facts big_number recommends = List.app (fn at => let val (j, ov) = Array.sub (recommends, at) in Array.update (recommends, at, (j, big_number + ov)) end) -fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats = - let - exception EXIT of unit - - val ln_afreq = Math.ln (Real.fromInt num_facts) - fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat))) - - val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0) - - fun do_feat (s, sw0) = - let - val sw = sw0 * tfidf s - val w2 = sw * sw - - fun inc_overlap j = - let val (_, ov) = Array.sub (overlaps_sqr, j) in - Array.update (overlaps_sqr, j, (j, w2 + ov)) - end - in - List.app inc_overlap (Array.sub (feat_facts, s)) - end - - val _ = List.app do_feat goal_feats - val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr - val no_recommends = Unsynchronized.ref 0 - val recommends = Array.tabulate (num_facts, rpair 0.0) - val age = Unsynchronized.ref 500000000.0 - - fun inc_recommend j v = - let val (_, ov) = Array.sub (recommends, j) in - if ov <= 0.0 then - (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov))) - else if ov < !age + 1000.0 then - Array.update (recommends, j, (j, v + ov)) - else - () - end - - val k = Unsynchronized.ref 0 - fun do_k k = - if k >= num_facts then - raise EXIT () - else - let - val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1) - val o1 = Math.sqrt o2 - val _ = inc_recommend j o1 - val ds = Vector.sub (depss, j) - val l = Real.fromInt (length ds) - in - List.app (fn d => inc_recommend d (o1 / l)) ds - end - - fun while1 () = - if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ()) - handle EXIT () => () - - fun while2 () = - if !no_recommends >= max_suggs then () - else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ()) - handle EXIT () => () - - fun ret acc at = - if at = num_facts then acc else ret (Array.sub (recommends, at) :: acc) (at + 1) - in - while1 (); - while2 (); - select_visible_facts 1000000000.0 recommends visible_facts; - heap (Real.compare o pairself snd) max_suggs num_facts recommends; - ret [] (Integer.max 0 (num_facts - max_suggs)) - end - fun wider_array_of_vector init vec = let val ary = Array.array init in Array.copyVec {src = vec, dst = ary, di = 0}; @@ -393,6 +319,80 @@ ret (Integer.max 0 (num_facts - max_suggs)) [] end +val number_of_nearest_neighbors = 10 (* FUDGE *) + +fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats = + let + exception EXIT of unit + + val ln_afreq = Math.ln (Real.fromInt num_facts) + fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat))) + + val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0) + + fun do_feat (s, sw0) = + let + val sw = sw0 * tfidf s + val w2 = sw * sw + + fun inc_overlap j = + let val (_, ov) = Array.sub (overlaps_sqr, j) in + Array.update (overlaps_sqr, j, (j, w2 + ov)) + end + in + List.app inc_overlap (Array.sub (feat_facts, s)) + end + + val _ = List.app do_feat goal_feats + val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr + val no_recommends = Unsynchronized.ref 0 + val recommends = Array.tabulate (num_facts, rpair 0.0) + val age = Unsynchronized.ref 500000000.0 + + fun inc_recommend v j = + let val (_, ov) = Array.sub (recommends, j) in + if ov <= 0.0 then + (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov))) + else if ov < !age + 1000.0 then + Array.update (recommends, j, (j, v + ov)) + else + () + end + + val k = Unsynchronized.ref 0 + fun do_k k = + if k >= num_facts then + raise EXIT () + else + let + val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1) + val o1 = Math.sqrt o2 + val _ = inc_recommend o1 j + val ds = Vector.sub (depss, j) + val l = Real.fromInt (length ds) + in + List.app (inc_recommend (o1 / l)) ds + end + + fun while1 () = + if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ()) + handle EXIT () => () + + fun while2 () = + if !no_recommends >= max_suggs then () + else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ()) + handle EXIT () => () + + fun ret acc at = + if at = num_facts then acc else ret (Array.sub (recommends, at) :: acc) (at + 1) + in + while1 (); + while2 (); + select_visible_facts 1000000000.0 recommends visible_facts; + heap (Real.compare o pairself snd) max_suggs num_facts recommends; + ret [] (Integer.max 0 (num_facts - max_suggs)) + end + (* experimental *) fun external_tool tool max_suggs learns goal_feats = let @@ -435,15 +435,16 @@ fun query_external ctxt engine max_suggs learns goal_feats = (trace_msg ctxt (fn () => "MaSh query external " ^ commas (map fst goal_feats)); (case engine of - MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats - | MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats)) + MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats + | MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats)) fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss) (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats = (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^ elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}"); (case engine of - MaSh_kNN => + MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats + | MaSh_kNN => let val feat_facts = Array.array (num_feats, []) val _ = @@ -452,8 +453,7 @@ 0 featss in k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats - end - | MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats) + end) |> map (curry Vector.sub fact_names o fst)) end; @@ -1178,7 +1178,7 @@ val (parents, goal_feats) = query_args access_G val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents) in - if engine = MaSh_kNN_Ext orelse engine = MaSh_NB_Ext then + if engine = MaSh_NB_Ext orelse engine = MaSh_kNN_Ext then let val learns = Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G