420 for 0; |
420 for 0; |
421 heap (Real.compare o pairself snd) advno recommends; |
421 heap (Real.compare o pairself snd) advno recommends; |
422 ret [] (Integer.max 0 (adv_max - advno)) |
422 ret [] (Integer.max 0 (adv_max - advno)) |
423 end |
423 end |
424 |
424 |
425 (* Two arguments control the behaviour of naive Bayes: prior and ess. Prior expresses our belief in |
425 val tau = 0.02 |
426 usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the |
426 val posWeight = 2.0 |
427 prior. *) |
427 val defVal = ~15.0 |
428 fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms = |
428 val defPriWei = 20 |
|
429 |
|
430 fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms = |
429 let |
431 let |
430 val afreq = Unsynchronized.ref 0 |
432 val afreq = Unsynchronized.ref 0; |
431 val tfreq = Array.array (avail_num, 0) |
433 val tfreq = Array.array (avail_num, 0); |
432 val sfreq = Array.array (avail_num, Inttab.empty) |
434 val sfreq = Array.array (avail_num, Inttab.empty); |
433 |
435 val dffreq = Array.array (sym_num, 0); |
434 fun nb_learn syms ts = |
436 |
|
437 fun learn th syms deps = |
435 let |
438 let |
436 fun add_sym hpis sym = |
439 fun add_th t = |
437 let |
440 let |
438 val im = Array.sub (sfreq, hpis) |
441 val im = Array.sub (sfreq, t); |
439 val v = the_default 0 (Inttab.lookup im sym) |
442 fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf; |
440 in |
443 in |
441 Array.update (sfreq, hpis, Inttab.update (sym, v + 1) im) |
444 Array.update (tfreq, t, 1 + Array.sub (tfreq, t)); |
442 end |
445 Array.update (sfreq, t, fold fold_fn syms im) |
443 |
446 end; |
444 fun add_th t = |
447 fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s)); |
445 (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms) |
|
446 in |
448 in |
447 afreq := !afreq + 1; List.app add_th ts |
449 List.app add_th (replicate defPriWei th); |
448 end |
450 List.app add_th deps; |
449 |
451 List.app add_sym syms; |
450 fun nb_eval syms = |
452 afreq := !afreq + 1 |
|
453 end; |
|
454 |
|
455 fun tfidf _ = 1.0; |
|
456 (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)));*) |
|
457 |
|
458 fun eval syms = |
451 let |
459 let |
452 fun log_posterior i = |
460 fun log_posterior i = |
453 let |
461 let |
454 val symh = fold (Inttab.update o rpair ()) syms Inttab.empty |
462 val tfreq = Real.fromInt (Array.sub (tfreq, i)); |
455 val n = Real.fromInt (Array.sub (tfreq, i)) |
463 fun fold_syms (f, fw) (res, sfh) = |
456 val sfreqh = Array.sub (sfreq, i) |
464 (case Inttab.lookup sfh f of |
457 val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq) |
465 SOME sf => |
458 val mp = ess * p |
466 (res + tfidf f * fw * Math.ln (posWeight * Real.fromInt sf / tfreq), |
459 val logmp = Math.ln mp |
467 Inttab.delete f sfh) |
460 val lognmp = Math.ln (n + mp) |
468 | NONE => (res + fw * defVal, sfh)); |
461 |
469 val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i)); |
462 fun in_sfreqh (s, sfreqv) (sofar, sfsymh) = |
470 fun fold_sfh (f, sf) sow = |
463 let val sfreqv = Real.fromInt sfreqv in |
471 sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)); |
464 if Inttab.defined sfsymh s then |
472 val sumOfWei = Inttab.fold fold_sfh sfh 0.0; |
465 (sofar + Math.ln (sfreqv + mp), Inttab.delete s sfsymh) |
|
466 else |
|
467 (sofar + Math.ln (n - sfreqv + mp), sfsymh) |
|
468 end |
|
469 |
|
470 val (postsfreqh, symh) = Inttab.fold in_sfreqh sfreqh (Math.ln n, symh) |
|
471 val len_mem = length (Inttab.keys symh) |
|
472 val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh) |
|
473 in |
473 in |
474 postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp - |
474 res + tau * sumOfWei |
475 Real.fromInt sym_num * Math.ln (n + ess) |
|
476 end |
475 end |
477 |
476 val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j))); |
478 val posterior = Array.tabulate (adv_max, swap o `log_posterior) |
|
479 |
|
480 fun ret acc at = |
477 fun ret acc at = |
481 if at = Array.length posterior then acc |
478 if at = adv_max then acc else ret (Array.sub (posterior,at) :: acc) (at + 1) |
482 else ret (Array.sub (posterior, at) :: acc) (at + 1) |
|
483 in |
479 in |
484 heap (Real.compare o pairself snd) advno posterior; |
480 heap (Real.compare o pairself snd) advno posterior; |
485 ret [] (Integer.max 0 (adv_max - advno)) |
481 ret [] (Integer.max 0 (adv_max - advno)) |
486 end |
482 end; |
487 |
483 |
488 fun for i = |
484 fun for i = |
489 if i = avail_num then () else (nb_learn (get_th_syms i) (i :: get_deps i); for (i + 1)) |
485 if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1)) |
490 in |
486 in |
491 for 0; nb_eval syms |
487 for 0; eval syms |
492 end |
488 end |
493 |
489 |
494 val knns = 40 (* FUDGE *) |
490 val knns = 40 (* FUDGE *) |
495 val ess = 0.00001 (* FUDGE *) |
|
496 val prior = 0.001 (* FUDGE *) |
|
497 |
491 |
498 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) |
492 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) |
499 |
493 |
500 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
494 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
501 |
495 |
1235 val chained_feats = chained |
1228 val chained_feats = chained |
1236 |> map (rpair 1.0) |
1229 |> map (rpair 1.0) |
1237 |> map (chained_or_extra_features_of chained_feature_factor) |
1230 |> map (chained_or_extra_features_of chained_feature_factor) |
1238 |> rpair [] |-> fold (union (eq_fst (op =))) |
1231 |> rpair [] |-> fold (union (eq_fst (op =))) |
1239 val extra_feats = |
1232 val extra_feats = |
1240 (* As long as SML NB does not support weights, it makes little sense to include these |
1233 facts |
1241 extra features. *) |
1234 |> take (Int.max (0, num_extra_feature_facts - length chained)) |
1242 if engine = MaSh_SML_NB then |
1235 |> filter fact_has_right_theory |
1243 [] |
1236 |> weight_facts_steeply |
1244 else |
1237 |> map (chained_or_extra_features_of extra_feature_factor) |
1245 facts |
1238 |> rpair [] |-> fold (union (eq_fst (op =))) |
1246 |> take (Int.max (0, num_extra_feature_facts - length chained)) |
|
1247 |> filter fact_has_right_theory |
|
1248 |> weight_facts_steeply |
|
1249 |> map (chained_or_extra_features_of extra_feature_factor) |
|
1250 |> rpair [] |-> fold (union (eq_fst (op =))) |
|
1251 val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats |
1239 val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats |
1252 |> debug ? sort (Real.compare o swap o pairself snd) |
1240 |> debug ? sort (Real.compare o swap o pairself snd) |
1253 in |
1241 in |
1254 (parents, hints, feats) |
1242 (parents, hints, feats) |
1255 end |
1243 end |