src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57095 001ec97c3e59
parent 57089 353652f47974
child 57096 e4074b91b2a6
equal deleted inserted replaced
57094:589ec121ce1a 57095:001ec97c3e59
   119 fun mash_engine () =
   119 fun mash_engine () =
   120   let val flag1 = Options.default_string @{system_option MaSh} in
   120   let val flag1 = Options.default_string @{system_option MaSh} in
   121     (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
   121     (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
   122       "yes" => SOME MaSh_Py
   122       "yes" => SOME MaSh_Py
   123     | "py" => SOME MaSh_Py
   123     | "py" => SOME MaSh_Py
   124     | "sml" => SOME MaSh_SML_kNN
   124     | "sml" => SOME MaSh_SML_NB
   125     | "sml_knn" => SOME MaSh_SML_kNN
   125     | "sml_knn" => SOME MaSh_SML_kNN
   126     | "sml_nb" => SOME MaSh_SML_NB
   126     | "sml_nb" => SOME MaSh_SML_NB
   127     | _ => NONE)
   127     | _ => NONE)
   128   end
   128   end
   129 
   129 
   420     for 0;
   420     for 0;
   421     heap (Real.compare o pairself snd) advno recommends;
   421     heap (Real.compare o pairself snd) advno recommends;
   422     ret [] (Integer.max 0 (adv_max - advno))
   422     ret [] (Integer.max 0 (adv_max - advno))
   423   end
   423   end
   424 
   424 
   425 (* Two arguments control the behaviour of naive Bayes: prior and ess. Prior expresses our belief in
   425 val tau = 0.02
   426    usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the
   426 val posWeight = 2.0
   427    prior. *)
   427 val defVal = ~15.0
   428 fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
   428 val defPriWei = 20
       
   429 
       
   430 fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms =
   429   let
   431   let
   430     val afreq = Unsynchronized.ref 0
   432     val afreq = Unsynchronized.ref 0;
   431     val tfreq = Array.array (avail_num, 0)
   433     val tfreq = Array.array (avail_num, 0);
   432     val sfreq = Array.array (avail_num, Inttab.empty)
   434     val sfreq = Array.array (avail_num, Inttab.empty);
   433 
   435     val dffreq = Array.array (sym_num, 0);
   434     fun nb_learn syms ts =
   436 
       
   437     fun learn th syms deps =
   435       let
   438       let
   436         fun add_sym hpis sym =
   439         fun add_th t =
   437           let
   440           let
   438             val im = Array.sub (sfreq, hpis)
   441             val im = Array.sub (sfreq, t);
   439             val v = the_default 0 (Inttab.lookup im sym)
   442             fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf;
   440           in
   443           in
   441             Array.update (sfreq, hpis, Inttab.update (sym, v + 1) im)
   444             Array.update (tfreq, t, 1 + Array.sub (tfreq, t));
   442           end
   445             Array.update (sfreq, t, fold fold_fn syms im)
   443 
   446           end;
   444         fun add_th t =
   447         fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s));
   445           (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms)
       
   446       in
   448       in
   447         afreq := !afreq + 1; List.app add_th ts
   449         List.app add_th (replicate defPriWei th);
   448       end
   450         List.app add_th deps;
   449 
   451         List.app add_sym syms;
   450     fun nb_eval syms =
   452         afreq := !afreq + 1
       
   453       end;
       
   454 
       
   455     fun tfidf _ = 1.0;
       
   456     (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)));*)
       
   457 
       
   458     fun eval syms =
   451       let
   459       let
   452         fun log_posterior i =
   460         fun log_posterior i =
   453           let
   461           let
   454             val symh = fold (Inttab.update o rpair ()) syms Inttab.empty
   462             val tfreq = Real.fromInt (Array.sub (tfreq, i));
   455             val n = Real.fromInt (Array.sub (tfreq, i))
   463             fun fold_syms (f, fw) (res, sfh) =
   456             val sfreqh = Array.sub (sfreq, i)
   464               (case Inttab.lookup sfh f of
   457             val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq)
   465                 SOME sf =>
   458             val mp = ess * p
   466                 (res + tfidf f * fw * Math.ln (posWeight * Real.fromInt sf / tfreq),
   459             val logmp = Math.ln mp
   467                  Inttab.delete f sfh)
   460             val lognmp = Math.ln (n + mp)
   468               | NONE => (res + fw * defVal, sfh));
   461 
   469             val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i));
   462             fun in_sfreqh (s, sfreqv) (sofar, sfsymh) =
   470             fun fold_sfh (f, sf) sow =
   463               let val sfreqv = Real.fromInt sfreqv in
   471               sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq));
   464                 if Inttab.defined sfsymh s then
   472             val sumOfWei = Inttab.fold fold_sfh sfh 0.0;
   465                   (sofar + Math.ln (sfreqv + mp), Inttab.delete s sfsymh)
       
   466                 else
       
   467                   (sofar + Math.ln (n - sfreqv + mp), sfsymh)
       
   468               end
       
   469 
       
   470             val (postsfreqh, symh) = Inttab.fold in_sfreqh sfreqh (Math.ln n, symh)
       
   471             val len_mem = length (Inttab.keys symh)
       
   472             val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh)
       
   473           in
   473           in
   474             postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp -
   474             res + tau * sumOfWei
   475               Real.fromInt sym_num * Math.ln (n + ess)
       
   476           end
   475           end
   477 
   476         val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j)));
   478         val posterior = Array.tabulate (adv_max, swap o `log_posterior)
       
   479 
       
   480         fun ret acc at =
   477         fun ret acc at =
   481           if at = Array.length posterior then acc
   478           if at = adv_max then acc else ret (Array.sub (posterior,at) :: acc) (at + 1)
   482           else ret (Array.sub (posterior, at) :: acc) (at + 1)
       
   483       in
   479       in
   484         heap (Real.compare o pairself snd) advno posterior;
   480         heap (Real.compare o pairself snd) advno posterior;
   485         ret [] (Integer.max 0 (adv_max - advno))
   481         ret [] (Integer.max 0 (adv_max - advno))
   486       end
   482       end;
   487 
   483 
   488     fun for i =
   484     fun for i =
   489       if i = avail_num then () else (nb_learn (get_th_syms i) (i :: get_deps i); for (i + 1))
   485       if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1))
   490   in
   486   in
   491     for 0; nb_eval syms
   487     for 0; eval syms
   492   end
   488   end
   493 
   489 
   494 val knns = 40 (* FUDGE *)
   490 val knns = 40 (* FUDGE *)
   495 val ess = 0.00001 (* FUDGE *)
       
   496 val prior = 0.001 (* FUDGE *)
       
   497 
   491 
   498 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   492 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   499 
   493 
   500 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   494 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   501 
   495 
   530 
   524 
   531     val deps_vec = Vector.fromList (rev rev_depss)
   525     val deps_vec = Vector.fromList (rev rev_depss)
   532 
   526 
   533     val num_visible_facts = length visible_facts
   527     val num_visible_facts = length visible_facts
   534     val get_deps = curry Vector.sub deps_vec
   528     val get_deps = curry Vector.sub deps_vec
       
   529     val syms = map_filter (fn (feat, weight) =>
       
   530       Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
   535   in
   531   in
   536     trace_msg ctxt (fn () => "MaSh_SML " ^ (if engine = MaSh_SML_kNN then "kNN" else "NB") ^
   532     trace_msg ctxt (fn () => "MaSh_SML " ^ (if engine = MaSh_SML_kNN then "kNN" else "NB") ^
   537       " query " ^ encode_features feats ^ " from {" ^
   533       " query " ^ encode_features feats ^ " from {" ^
   538       elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
   534       elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
   539     (if engine = MaSh_SML_kNN then
   535     (if engine = MaSh_SML_kNN then
   546                    feats;
   542                    feats;
   547                  fact'
   543                  fact'
   548                end)
   544                end)
   549              rev_featss num_facts
   545              rev_featss num_facts
   550          val get_facts = curry Array.sub facts_ary
   546          val get_facts = curry Array.sub facts_ary
   551          val syms = map_filter (fn (feat, weight) =>
       
   552            Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
       
   553        in
   547        in
   554          k_nearest_neighbors num_facts num_visible_facts get_deps get_facts knns max_suggs syms
   548          k_nearest_neighbors num_facts num_visible_facts get_deps get_facts knns max_suggs syms
   555        end
   549        end
   556      else
   550      else
   557        let
   551        let
   558          val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   552          val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   559          val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   553          val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   560          val unweighted_syms = map_filter (Symtab.lookup feat_tab o fst) feats
       
   561        in
   554        in
   562          naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats ess prior
   555          naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs
   563            max_suggs unweighted_syms
   556            syms
   564        end)
   557        end)
   565     |> map (curry Vector.sub fact_vec o fst)
   558     |> map (curry Vector.sub fact_vec o fst)
   566   end
   559   end
   567 
   560 
   568 end;
   561 end;
  1235         val chained_feats = chained
  1228         val chained_feats = chained
  1236           |> map (rpair 1.0)
  1229           |> map (rpair 1.0)
  1237           |> map (chained_or_extra_features_of chained_feature_factor)
  1230           |> map (chained_or_extra_features_of chained_feature_factor)
  1238           |> rpair [] |-> fold (union (eq_fst (op =)))
  1231           |> rpair [] |-> fold (union (eq_fst (op =)))
  1239         val extra_feats =
  1232         val extra_feats =
  1240           (* As long as SML NB does not support weights, it makes little sense to include these
  1233           facts
  1241              extra features. *)
  1234           |> take (Int.max (0, num_extra_feature_facts - length chained))
  1242           if engine = MaSh_SML_NB then
  1235           |> filter fact_has_right_theory
  1243             []
  1236           |> weight_facts_steeply
  1244           else
  1237           |> map (chained_or_extra_features_of extra_feature_factor)
  1245             facts
  1238           |> rpair [] |-> fold (union (eq_fst (op =)))
  1246             |> take (Int.max (0, num_extra_feature_facts - length chained))
       
  1247             |> filter fact_has_right_theory
       
  1248             |> weight_facts_steeply
       
  1249             |> map (chained_or_extra_features_of extra_feature_factor)
       
  1250             |> rpair [] |-> fold (union (eq_fst (op =)))
       
  1251         val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
  1239         val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
  1252           |> debug ? sort (Real.compare o swap o pairself snd)
  1240           |> debug ? sort (Real.compare o swap o pairself snd)
  1253       in
  1241       in
  1254         (parents, hints, feats)
  1242         (parents, hints, feats)
  1255       end
  1243       end