(* Hand-tuned ("FUDGE") constants. Every line appears twice because this listing
   interleaves two revisions of the file; the values are identical in both.
   In the visible code only nb_def_prior_weight is used: naive_bayes_learn adds
   it to a fact's own frequency entry, whereas each dependency gets weight 1.
   NOTE(review): the uses of nb_tau, nb_pos_weight and nb_def_val are not
   visible in this listing -- presumably in the elided part of
   naive_bayes_query; confirm. *)
427 val nb_tau = 0.02 (* FUDGE *) |
427 val nb_tau = 0.02 (* FUDGE *) |
428 val nb_pos_weight = 2.0 (* FUDGE *) |
428 val nb_pos_weight = 2.0 (* FUDGE *) |
429 val nb_def_val = ~15.0 (* FUDGE *) |
429 val nb_def_val = ~15.0 (* FUDGE *) |
430 val nb_def_prior_weight = 20 (* FUDGE *) |
430 val nb_def_prior_weight = 20 (* FUDGE *) |
431 |
431 |
|
432 (* TODO: Either use IDF or don't use it. See commented out code portions below. *) |
|
433 |
432 fun naive_bayes_learn num_facts get_deps get_th_feats num_feats = |
434 fun naive_bayes_learn num_facts get_deps get_th_feats num_feats = |
433 let |
435 let |
434 val afreq = Unsynchronized.ref 0 |
|
435 val tfreq = Array.array (num_facts, 0) |
436 val tfreq = Array.array (num_facts, 0) |
436 val sfreq = Array.array (num_facts, Inttab.empty) |
437 val sfreq = Array.array (num_facts, Inttab.empty) |
437 val dffreq = Array.array (num_feats, 0) |
438 (* val dffreq = Array.array (num_feats, 0) *) |
438 |
439 |
439 fun learn th feats deps = |
440 fun learn th feats deps = |
440 let |
441 let |
441 fun add_th weight t = |
442 fun add_th weight t = |
442 let |
443 let |
445 in |
446 in |
446 Array.update (tfreq, t, weight + Array.sub (tfreq, t)); |
447 Array.update (tfreq, t, weight + Array.sub (tfreq, t)); |
447 Array.update (sfreq, t, fold fold_fn feats im) |
448 Array.update (sfreq, t, fold fold_fn feats im) |
448 end |
449 end |
449 |
450 |
450 fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s)) |
451 (* fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s)) *) |
451 in |
452 in |
452 add_th nb_def_prior_weight th; |
453 add_th nb_def_prior_weight th; |
453 List.app (add_th 1) deps; |
454 List.app (add_th 1) deps |
454 List.app add_sym feats; |
455 (* ; List.app add_sym feats *) |
455 afreq := !afreq + 1 |
|
456 end |
456 end |
457 |
457 |
458 fun for i = |
458 fun for i = |
459 if i = num_facts then () else (learn i (get_th_feats i) (get_deps i); for (i + 1)) |
459 if i = num_facts then () else (learn i (get_th_feats i) (get_deps i); for (i + 1)) |
460 in |
460 in |
461 for 0; (Real.fromInt (!afreq), Array.vector tfreq, Array.vector sfreq, Array.vector dffreq) |
461 for 0; (Array.vector tfreq, Array.vector sfreq (*, Array.vector dffreq *)) |
462 end |
462 end |
463 |
463 |
464 fun naive_bayes_query num_visible_facts max_suggs feats (afreq, tfreq, sfreq, dffreq) = |
464 fun naive_bayes_query _ (* num_facts *) num_visible_facts max_suggs feats (tfreq, sfreq (*, dffreq*)) = |
465 let |
465 let |
|
466 (* |
|
467 val afreq = Real.fromInt num_facts |
466 fun tfidf feat = Math.ln afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat))) |
468 fun tfidf feat = Math.ln afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat))) |
|
469 *) |
|
470 fun tfidf _ = 1.0 |
467 |
471 |
468 fun log_posterior i = |
472 fun log_posterior i = |
469 let |
473 let |
470 val tfreq = Real.fromInt (Vector.sub (tfreq, i)) |
474 val tfreq = Real.fromInt (Vector.sub (tfreq, i)) |
471 |
475 |
495 ret [] (Integer.max 0 (num_visible_facts - max_suggs)) |
499 ret [] (Integer.max 0 (num_visible_facts - max_suggs)) |
496 end |
500 end |
497 |
501 |
(* Entry point: runs naive_bayes_learn over the facts and pipes the learned
   tables straight into naive_bayes_query for the goal features feats.
   The two revisions shown differ only in the query call: the older one passes
   (num_visible_facts, max_suggs, feats); the newer one additionally passes
   num_facts as the first argument, matching the extra leading parameter of
   the newer naive_bayes_query (which currently ignores it). *)
498 fun naive_bayes num_facts num_visible_facts get_deps get_th_feats num_feats max_suggs feats = |
502 fun naive_bayes num_facts num_visible_facts get_deps get_th_feats num_feats max_suggs feats = |
499 naive_bayes_learn num_facts get_deps get_th_feats num_feats |
503 naive_bayes_learn num_facts get_deps get_th_feats num_feats |
500 |> naive_bayes_query num_visible_facts max_suggs feats |
504 |> naive_bayes_query num_facts num_visible_facts max_suggs feats |
501 |
505 |
(* Registers key under the next free index. Takes the state triple
   (next, tab, keys) and returns it with the counter incremented, key bound to
   the old counter value in tab, and key consed onto the front of keys (so keys
   is in reverse insertion order).
   NOTE(review): Symtab.update_new presumably rejects a key that is already
   present -- confirm how callers guard against duplicates. *)
502 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) |
506 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) |
503 |
507 |
(* In-place update: replaces the element of ary at index i with f applied to
   the element currently stored there. Returns unit (the result of
   Array.update). *)
504 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
508 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) |
505 |
509 |
1212 val facts = facts |> sort (crude_thm_ord o pairself snd o swap) |
1216 val facts = facts |> sort (crude_thm_ord o pairself snd o swap) |
1213 val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained) |
1217 val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained) |
1214 val num_facts = length facts |
1218 val num_facts = length facts |
1215 |
1219 |
1216 (* Weights appear to hurt kNN more than they help. *) |
1220 (* Weights appear to hurt kNN more than they help. *) |
1217 val const_tab = Symtab.empty |> engine = MaSh_Py ? fold (add_const_counts o prop_of o snd) facts |
1221 val const_tab = Symtab.empty |> engine <> MaSh_SML_kNN |
|
1222 ? fold (add_const_counts o prop_of o snd) facts |
1218 |
1223 |
1219 fun fact_has_right_theory (_, th) = |
1224 fun fact_has_right_theory (_, th) = |
1220 thy_name = Context.theory_name (theory_of_thm th) |
1225 thy_name = Context.theory_name (theory_of_thm th) |
1221 |
1226 |
1222 fun chained_or_extra_features_of factor (((_, stature), th), weight) = |
1227 fun chained_or_extra_features_of factor (((_, stature), th), weight) = |