# HG changeset patch # User blanchet # Date 1403782585 -7200 # Node ID f40ac83d076c54e2fe1a13fd7aafe2b8429cda7f # Parent b75438e23925ed27d2364d7a2db51cfbaa81c4b7 refactoring diff -r b75438e23925 -r f40ac83d076c src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:36:22 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:36:25 2014 +0200 @@ -398,7 +398,7 @@ exception EXIT of unit -fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts conj_feats = +fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats = let val ln_afreq = Math.ln (Real.fromInt num_facts) fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat))) @@ -419,7 +419,7 @@ List.app do_th (Array.sub (feat_facts, s)) end - val _ = List.app do_feat conj_feats + val _ = List.app do_feat goal_feats val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr val no_recommends = Unsynchronized.ref 0 val recommends = Array.tabulate (num_facts, rpair 0.0) @@ -512,7 +512,7 @@ (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq) end -fun naive_bayes tfreq sfreq dffreq num_facts max_suggs visible_facts conj_feats = +fun naive_bayes (tfreq, sfreq, dffreq) num_facts max_suggs visible_facts goal_feats = let val tau = 0.05 (* FUDGE *) val pos_weight = 10.0 (* FUDGE *) @@ -534,7 +534,7 @@ Inttab.delete f sfh) | NONE => (res + tfidf f * def_val, sfh)) - val (res, sfh) = fold fold_feats conj_feats (Math.ln tfreq, Vector.sub (sfreq, i)) + val (res, sfh) = fold fold_feats goal_feats (Math.ln tfreq, Vector.sub (sfreq, i)) fun fold_sfh (f, sf) sow = sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq) @@ -554,7 +554,7 @@ end (* experimental *) -fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs conj_feats = +fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs goal_feats = let fun name_of_fact j = "f" ^ string_of_int j fun fact_of_name s = the (Int.fromString (unprefix "f" s)) @@ -565,16 +565,16 @@ map name_of_feature (Vector.sub (featss, j)), map name_of_fact (Vector.sub (depss, j)))) (0 upto num_facts - 1) val parents' = parents_of num_facts - val conj_feats' = map (rpair 1.0 o name_of_feature) conj_feats + val goal_feats' = map (rpair 1.0 o name_of_feature) goal_feats in MaSh_Py.unlearn ctxt overlord; OS.Process.sleep (seconds 2.0); (* hack *) - MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', conj_feats') + MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', goal_feats') |> map (apfst fact_of_name) end (* experimental *) -fun experimental_external_tool tool max_suggs learns cfeats = +fun external_tool tool max_suggs learns goal_feats = let val ser = string_of_int (serial ()) (* poor person's attempt at thread-safety *) val ocs = TextIO.openOut ("adv_syms" ^ ser) @@ -603,67 +603,38 @@ |> filter_out (curry (op =) "") end in - (List.app do_learn learns; ol occ (os occ o quote) ", " cfeats; + (List.app do_learn learns; ol occ (os occ o quote) ", " goal_feats; TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ; forkexec max_suggs) end val k_nearest_neighbors_ext = - experimental_external_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors) -val naive_bayes_ext = experimental_external_tool "predict/nbayes" - -fun reorder_learns (num_facts, fact_tab) learns0 = - let - val learns = Array.array (num_facts, ("", [], [])) - in - List.app (fn learn as (fact, _, _) => - Array.update (learns, the (Symtab.lookup fact_tab fact), learn)) - learns0; - Array.foldr (op ::) [] learns - end + external_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors) +val naive_bayes_ext = external_tool "predict/nbayes" -fun query ctxt engine (fact_xtab as (num_facts, fact_tab)) (num_feats, feat_tab) visible_facts - max_suggs learns0 conj_feats = - if engine = MaSh_SML_kNN_Ext then - k_nearest_neighbors_ext max_suggs learns0 conj_feats - else if engine = MaSh_SML_NB_Ext then - naive_bayes_ext max_suggs learns0 conj_feats - else - let - val learns = reorder_learns fact_xtab learns0 - - val facts = Vector.fromList (map #1 learns) - val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns) - val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns) +fun query_external ctxt engine max_suggs learns goal_feats = + (trace_msg ctxt (fn () => "MaSh_SML query external " ^ encode_strs goal_feats); + (case engine of + MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats + | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats)) - val tfreq = Vector.tabulate (num_facts, K 0) - val sfreq = Vector.tabulate (num_facts, K Inttab.empty) - val dffreq = Vector.tabulate (num_feats, K 0) - - val (tfreq, sfreq, dffreq) = - learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss - - val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts - val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats - in - trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^ - elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}"); - (case engine of - MaSh_SML_kNN => - let - val feat_facts = Array.array (num_feats, []) - val _ = - Vector.foldl (fn (feats, fact) => - (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1)) - 0 featss - in - k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs int_visible_facts - int_conj_feats - end - | MaSh_SML_NB => - naive_bayes tfreq sfreq dffreq num_facts max_suggs int_visible_facts int_conj_feats) - |> map (curry Vector.sub facts o fst) - end +fun query_internal ctxt engine num_facts num_feats (facts, featss, depss) (freqs as (_, _, dffreq)) + visible_facts max_suggs goal_feats int_goal_feats = + (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_strs goal_feats ^ " from {" ^ + elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}"); + (case engine of + MaSh_SML_kNN => + let + val feat_facts = Array.array (num_feats, []) + val _ = + Vector.foldl (fn (feats, fact) => + (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1)) + 0 featss + in + k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats + end + | MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats) + |> map (curry Vector.sub facts o fst)) end; @@ -1312,7 +1283,17 @@ fun add_const_counts t = fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t []) -fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts = +fun reorder_learns (num_facts, fact_tab) learns0 = + let + val learns = Array.array (num_facts, ("", [], [])) + in + List.app (fn learn as (fact, _, _) => + Array.update (learns, the (Symtab.lookup fact_tab fact), learn)) + learns0; + Array.foldr (op ::) [] learns + end + +fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts = let val thy = Proof_Context.theory_of ctxt val thy_name = Context.theory_name thy @@ -1360,14 +1341,14 @@ (parents, hints, feats) end - val ((access_G, (fact_xtab, feat_xtab)), py_suggs) = + val ((access_G, (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab))), py_suggs) = peek_state ctxt overlord (fn {access_G, xtabs, ...} => ((access_G, xtabs), if Graph.is_empty access_G then (trace_msg ctxt (K "Nothing has been learned yet"); []) else if engine = MaSh_Py then let val (parents, hints, feats) = query_args access_G in - MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats) + MaSh_Py.query ctxt overlord max_suggs ([], hints, parents, feats) |> map fst end else @@ -1378,19 +1359,46 @@ [] else let - val (parents, hints, feats0) = query_args access_G - val feats = map fst feats0 - val visible_facts = Graph.all_preds access_G parents - val learns = - (if null hints then [] else [(hintsN, feats, hints)]) @ (* ### FIXME *) - Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G + val (parents, hints, goal_feats0) = query_args access_G + val goal_feats = map fst goal_feats0 + val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents) in - MaSh_SML.query ctxt engine fact_xtab feat_xtab visible_facts max_facts learns feats + if engine = MaSh_SML_kNN_Ext orelse engine = MaSh_SML_NB_Ext then + let + val learns = + (if null hints then [] else [(hintsN, goal_feats, hints)]) @ (* ### FIXME *) + Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G + in + MaSh_SML.query_external ctxt engine max_suggs learns goal_feats + end + else + let + val learns0 = + (if null hints then [] else [(hintsN, goal_feats, hints)]) @ (* ### FIXME *) + Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G + val learns = reorder_learns fact_xtab learns0 + + val facts = Vector.fromList (map #1 learns) + val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns) + val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns) + + val tfreq = Vector.tabulate (num_facts, K 0) + val sfreq = Vector.tabulate (num_facts, K Inttab.empty) + val dffreq = Vector.tabulate (num_feats, K 0) + + val freqs' = + MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss + + val int_goal_feats = map_filter (Symtab.lookup feat_tab) goal_feats + in + MaSh_SML.query_internal ctxt engine num_facts num_feats (facts, featss, depss) freqs' + visible_facts max_suggs goal_feats int_goal_feats + end end val unknown = filter_out (is_fact_in_graph access_G o snd) facts in - find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown + find_mash_suggestions ctxt max_suggs (py_suggs @ sml_suggs) facts chained unknown |> pairself (map fact_of_raw_fact) end