added naive Bayes ML implementation, due to Cezary Kaliszyk (like k-NN)
author blanchet
Tue May 20 22:28:44 2014 +0200 (2014-05-20)
changeset 57029 75cc30d2b83f
parent 57028 e5466055e94f
child 57030 b592202a45cc
added naive Bayes ML implementation, due to Cezary Kaliszyk (like k-NN)
NEWS
src/Doc/Sledgehammer/document/root.tex
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/NEWS	Tue May 20 22:28:08 2014 +0200
     1.2 +++ b/NEWS	Tue May 20 22:28:44 2014 +0200
     1.3 @@ -386,8 +386,8 @@
     1.4        - Activation of MaSh now works via the "mash" system option (without
     1.5          requiring restart), instead of former settings variable "MASH".
     1.6          The option can be edited in Isabelle/jEdit menu Plugin
     1.7 -        Options / Isabelle / General. Allowed values include "sml" (for the new
     1.8 -        SML engine), "py" (for the Python engine), and "no".
     1.9 +        Options / Isabelle / General. Allowed values include "sml" (for the
    1.10 +        default SML engine), "py" (for the old Python engine), and "none".
    1.11    - New option:
    1.12        smt_proofs
    1.13    - Renamed options:
     2.1 --- a/src/Doc/Sledgehammer/document/root.tex	Tue May 20 22:28:08 2014 +0200
     2.2 +++ b/src/Doc/Sledgehammer/document/root.tex	Tue May 20 22:28:44 2014 +0200
     2.3 @@ -1070,8 +1070,8 @@
     2.4  The experimental MaSh machine learner. Three learning engines are provided:
     2.5  
     2.6  \begin{enum}
     2.7 -\item[\labelitemi] \textbf{\textit{sml}} (also called
     2.8 -\textbf{\textit{sml\_knn}}) is a Standard ML implementation of $k$-nearest
     2.9 +\item[\labelitemi] \textbf{\textit{sml\_knn}} (also called
    2.10 +\textbf{\textit{sml}}) is a Standard ML implementation of $k$-nearest
    2.11  neighbors.
    2.12  
    2.13  \item[\labelitemi] \textbf{\textit{sml\_nb}} is a Standard ML implementation of
     3.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 22:28:08 2014 +0200
     3.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 22:28:44 2014 +0200
     3.3 @@ -115,26 +115,21 @@
     3.4      ()
     3.5    end
     3.6  
     3.7 -datatype mash_engine = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB
     3.8 +datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB
     3.9  
    3.10  fun mash_engine () =
    3.11    let val flag1 = Options.default_string @{system_option maSh} in
    3.12      (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
    3.13        "yes" => SOME MaSh_Py
    3.14      | "py" => SOME MaSh_Py
    3.15 -    | "sml" => SOME MaSh_SML_KNN
    3.16 -    | "sml_knn" => SOME MaSh_SML_KNN
    3.17 +    | "sml" => SOME MaSh_SML_kNN
    3.18 +    | "sml_knn" => SOME MaSh_SML_kNN
    3.19      | "sml_nb" => SOME MaSh_SML_NB
    3.20      | _ => NONE)
    3.21    end
    3.22  
    3.23  val is_mash_enabled = is_some o mash_engine
    3.24 -
    3.25 -fun is_mash_sml_enabled () =
    3.26 -  (case mash_engine () of
    3.27 -    SOME MaSh_SML_KNN => true
    3.28 -  | SOME MaSh_SML_NB => true
    3.29 -  | _ => false)
    3.30 +val the_mash_engine = the_default MaSh_SML_kNN o mash_engine
    3.31  
    3.32  
    3.33  (*** Low-level communication with Python version of MaSh ***)
    3.34 @@ -320,71 +315,55 @@
    3.35        end
    3.36  
    3.37      fun trickledown l i e =
    3.38 -      let
    3.39 -        val j = maxson l i
    3.40 -      in
    3.41 +      let val j = maxson l i in
    3.42          if cmp (Array.sub (a, j), e) = GREATER then
    3.43 -          let val _ = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
    3.44 -        else Array.update (a, i, e)
    3.45 +          (Array.update (a, i, Array.sub (a, j)); trickledown l j e)
    3.46 +        else
    3.47 +          Array.update (a, i, e)
    3.48        end
    3.49  
    3.50      fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e)
    3.51  
    3.52      fun bubbledown l i =
    3.53 -      let
    3.54 -        val j = maxson l i
    3.55 -        val _ = Array.update (a, i, Array.sub (a, j))
    3.56 -      in
    3.57 +      let val j = maxson l i in
    3.58 +        Array.update (a, i, Array.sub (a, j));
    3.59          bubbledown l j
    3.60        end
    3.61  
    3.62      fun bubble l i = bubbledown l i handle BOTTOM i => i
    3.63  
    3.64      fun trickleup i e =
    3.65 -      let
    3.66 -        val father = (i - 1) div 3
    3.67 -      in
    3.68 +      let val father = (i - 1) div 3 in
    3.69          if cmp (Array.sub (a, father), e) = LESS then
    3.70 -          let
    3.71 -            val _ = Array.update (a, i, Array.sub (a, father))
    3.72 -          in
    3.73 -            if father > 0 then trickleup father e else Array.update (a, 0, e)
    3.74 -          end
    3.75 -        else Array.update (a, i, e)
    3.76 +          (Array.update (a, i, Array.sub (a, father));
    3.77 +           if father > 0 then trickleup father e else Array.update (a, 0, e))
    3.78 +        else
    3.79 +          Array.update (a, i, e)
    3.80        end
    3.81  
    3.82      val l = Array.length a
    3.83  
    3.84 -    fun for i =
    3.85 -      if i < 0 then () else
    3.86 -      let
    3.87 -        val _ = trickle l i (Array.sub (a, i))
    3.88 -      in
    3.89 -        for (i - 1)
    3.90 -      end
    3.91 -
    3.92 -    val _ = for (((l + 1) div 3) - 1)
    3.93 +    fun for i = if i < 0 then () else (trickle l i (Array.sub (a, i)); for (i - 1))
    3.94  
    3.95      fun for2 i =
    3.96 -      if i < Integer.max 2 (l - bnd) then () else
    3.97 -      let
    3.98 -        val e = Array.sub (a, i)
    3.99 -        val _ = Array.update (a, i, Array.sub (a, 0))
   3.100 -        val _ = trickleup (bubble i 0) e
   3.101 -      in
   3.102 -        for2 (i - 1)
   3.103 -      end
   3.104 -
   3.105 -    val _ = for2 (l - 1)
   3.106 +      if i < Integer.max 2 (l - bnd) then
   3.107 +        ()
   3.108 +      else
   3.109 +        let val e = Array.sub (a, i) in
   3.110 +          Array.update (a, i, Array.sub (a, 0));
   3.111 +          trickleup (bubble i 0) e;
   3.112 +          for2 (i - 1)
   3.113 +        end
   3.114    in
   3.115 +    for (((l + 1) div 3) - 1);
   3.116 +    for2 (l - 1);
   3.117      if l > 1 then
   3.118 -      let
   3.119 -        val e = Array.sub (a, 1)
   3.120 -        val _ = Array.update (a, 1, Array.sub (a, 0))
   3.121 -      in
   3.122 +      let val e = Array.sub (a, 1) in
   3.123 +        Array.update (a, 1, Array.sub (a, 0));
   3.124          Array.update (a, 0, e)
   3.125        end
   3.126 -    else ()
   3.127 +    else
   3.128 +      ()
   3.129    end
   3.130  
   3.131  (*
   3.132 @@ -421,7 +400,7 @@
   3.133        end
   3.134  
   3.135      val _ = List.app do_sym syms
   3.136 -    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
   3.137 +    val _ = heap (Real.compare o pairself snd) knns overlaps_sqr
   3.138      val recommends = Array.tabulate (adv_max, rpair 0.0)
   3.139  
   3.140      fun inc_recommend j v =
   3.141 @@ -438,27 +417,97 @@
   3.142            val _ = inc_recommend j o1
   3.143            val ds = get_deps j
   3.144            val l = Real.fromInt (length ds)
   3.145 -          val _ = map (fn d => inc_recommend d (o1 / l)) ds
   3.146          in
   3.147 -          for (k + 1)
   3.148 +          List.app (fn d => inc_recommend d (o1 / l)) ds; for (k + 1)
   3.149          end
   3.150  
   3.151 -    val _ = for 0
   3.152 -    val _ = heap (Real.compare o pairself snd) advno recommends
   3.153 -
   3.154      fun ret acc at =
   3.155        if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   3.156    in
   3.157 +    for 0;
   3.158 +    heap (Real.compare o pairself snd) advno recommends;
   3.159      ret [] (Integer.max 0 (adv_max - advno))
   3.160    end
   3.161  
   3.162 +(* Two arguments control the behaviour of nbayes: prior and ess. Prior expresses our belief in
   3.163 +   usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the
   3.164 +   prior. *)
   3.165 +fun nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
   3.166 +  let
   3.167 +    val afreq = Unsynchronized.ref 0
   3.168 +    val tfreq = Array.array (avail_num, 0)
   3.169 +    val sfreq = Array.array (avail_num, Inttab.empty)
   3.170 +
   3.171 +    fun nb_learn syms ts =
   3.172 +      let
   3.173 +        fun add_sym hpis sym =
   3.174 +          let
   3.175 +            val im = Array.sub (sfreq, hpis)
   3.176 +            val v = the_default 0 (Inttab.lookup im sym)
   3.177 +          in
   3.178 +            Array.update(sfreq, hpis, Inttab.update (sym, v + 1) im)
   3.179 +          end
   3.180 +
   3.181 +        fun add_th t =
   3.182 +          (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms)
   3.183 +      in
   3.184 +        afreq := !afreq + 1;
   3.185 +        List.app add_th ts
   3.186 +      end
   3.187 +
   3.188 +    fun nb_eval syms =
   3.189 +      let
   3.190 +        fun log_posterior i =
   3.191 +          let
   3.192 +            val symh = fold (fn s => fn sf => Inttab.update (s, ()) sf) syms Inttab.empty
   3.193 +            val n = Real.fromInt (Array.sub (tfreq, i))
   3.194 +            val sfreqh = Array.sub (sfreq, i)
   3.195 +            val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq)
   3.196 +            val mp = ess * p
   3.197 +            val logmp = Math.ln mp
   3.198 +            val lognmp = Math.ln (n + mp)
   3.199 +
   3.200 +            fun in_sfreqh (s, sfreqv) (sofar, sfsymh) =
   3.201 +              let val sfreqv = Real.fromInt sfreqv in
   3.202 +                if Inttab.defined sfsymh s then
   3.203 +                  (sofar + Math.ln (sfreqv + mp), Inttab.delete s sfsymh)
   3.204 +                else
   3.205 +                  (sofar + Math.ln (n - sfreqv + mp), sfsymh)
   3.206 +              end
   3.207 +
   3.208 +            val (postsfreqh, symh) = Inttab.fold in_sfreqh sfreqh (Math.ln n, symh)
   3.209 +            val len_mem = length (Inttab.keys symh)
   3.210 +            val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh)
   3.211 +          in
   3.212 +            postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp -
   3.213 +              Real.fromInt sym_num * Math.ln(n + ess)
   3.214 +          end
   3.215 +
   3.216 +        val posterior = Array.tabulate (adv_max, swap o `log_posterior)
   3.217 +
   3.218 +        fun ret acc at =
   3.219 +          if at = Array.length posterior then acc
   3.220 +          else ret (Array.sub (posterior,at) :: acc) (at + 1)
   3.221 +      in
   3.222 +        heap (Real.compare o pairself snd) advno posterior;
   3.223 +        ret [] (Integer.max 0 (adv_max - advno))
   3.224 +      end
   3.225 +
   3.226 +    fun for i =
   3.227 +      if i = avail_num then () else (nb_learn (get_th_syms i) (get_deps i); for (i + 1))
   3.228 +  in
   3.229 +    for 0; nb_eval syms
   3.230 +  end
   3.231 +
   3.232  val knns = 40 (* FUDGE *)
   3.233 +val ess = 0.00001 (* FUDGE *)
   3.234 +val prior = 0.001 (* FUDGE *)
   3.235  
   3.236  fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   3.237  
   3.238  fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   3.239  
   3.240 -fun query ctxt parents access_G max_suggs hints feats =
   3.241 +fun query ctxt engine parents access_G max_suggs hints feats =
   3.242    let
   3.243      val str_of_feat = space_implode "|"
   3.244  
   3.245 @@ -470,9 +519,9 @@
   3.246         |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
   3.247        (if null hints then [] else [(".goal", feats, hints)])
   3.248  
   3.249 -    val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) =
   3.250 +    val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
   3.251        fold (fn (fact, feats, deps) =>
   3.252 -            fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   3.253 +            fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   3.254            let
   3.255              fun add_feat (feat, weight) (xtab as (n, tab, _)) =
   3.256                (case Symtab.lookup tab feat of
   3.257 @@ -481,7 +530,7 @@
   3.258  
   3.259              val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab
   3.260            in
   3.261 -            (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
   3.262 +            (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
   3.263               add_to_xtab fact fact_xtab, feat_xtab')
   3.264            end)
   3.265          all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   3.266 @@ -490,22 +539,40 @@
   3.267      val fact_vec = Vector.fromList facts
   3.268  
   3.269      val deps_vec = Vector.fromList (rev rev_depss)
   3.270 -    val facts_ary = Array.array (num_feats, [])
   3.271 -    val _ =
   3.272 -      fold (fn feats => fn fact =>
   3.273 -          let val fact' = fact - 1 in
   3.274 -            List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
   3.275 -              feats;
   3.276 -            fact'
   3.277 -          end)
   3.278 -        featss (length featss)
   3.279 +
   3.280 +    val avail_num = Vector.length deps_vec
   3.281 +    val adv_max = length visible_facts
   3.282 +    val get_deps = curry Vector.sub deps_vec
   3.283 +    val advno = max_suggs
   3.284    in
   3.285      trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   3.286        elide_string 1000 (space_implode " " facts) ^ "}");
   3.287 -    knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
   3.288 -      (curry Array.sub facts_ary) knns max_suggs
   3.289 -      (map_filter (fn (feat, weight) =>
   3.290 -         Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   3.291 +    (if engine = MaSh_SML_kNN then
   3.292 +       let
   3.293 +        val facts_ary = Array.array (num_feats, [])
   3.294 +        val _ =
   3.295 +          fold (fn feats => fn fact =>
   3.296 +              let val fact' = fact - 1 in
   3.297 +                List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
   3.298 +                  feats;
   3.299 +                fact'
   3.300 +              end)
   3.301 +            rev_featss num_facts
   3.302 +         val get_sym_ths = curry Array.sub facts_ary
   3.303 +         val syms = map_filter (fn (feat, weight) =>
   3.304 +           Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats
   3.305 +       in
   3.306 +         knn avail_num adv_max get_deps get_sym_ths knns advno syms
   3.307 +       end
   3.308 +     else
   3.309 +       let
   3.310 +         val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   3.311 +         val get_th_syms = curry Vector.sub unweighted_feats_ary
   3.312 +         val sym_num = num_feats
   3.313 +         val unweighted_syms = map_filter (Symtab.lookup feat_tab o str_of_feat o fst) feats
   3.314 +       in
   3.315 +         nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno unweighted_syms
   3.316 +       end)
   3.317      |> map (curry Vector.sub fact_vec o fst)
   3.318    end
   3.319  
   3.320 @@ -596,8 +663,10 @@
   3.321                    fold extract_line_and_add_node node_lines Graph.empty),
   3.322                  length node_lines)
   3.323               | LESS =>
   3.324 -               (if is_mash_sml_enabled () then wipe_out_mash_state_dir ()
   3.325 -                else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
   3.326 +               (* cannot parse old file *)
   3.327 +               (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
   3.328 +                else wipe_out_mash_state_dir ();
   3.329 +                (Graph.empty, 0))
   3.330               | GREATER => raise FILE_VERSION_TOO_NEW ())
   3.331           in
   3.332             trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
   3.333 @@ -645,7 +714,8 @@
   3.334  fun clear_state ctxt overlord =
   3.335    (* "MaSh_Py.unlearn" also removes the state file *)
   3.336    Synchronized.change global_state (fn _ =>
   3.337 -    (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord;
   3.338 +    (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
   3.339 +     else wipe_out_mash_state_dir ();
   3.340       (false, empty_state)))
   3.341  
   3.342  end
   3.343 @@ -1224,7 +1294,7 @@
   3.344          (parents, hints, feats)
   3.345        end
   3.346  
   3.347 -    val sml = is_mash_sml_enabled ()
   3.348 +    val engine = the_mash_engine ()
   3.349  
   3.350      val (access_G, py_suggs) =
   3.351        peek_state ctxt overlord (fn {access_G, ...} =>
   3.352 @@ -1232,20 +1302,20 @@
   3.353            (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
   3.354          else
   3.355            (access_G,
   3.356 -           if sml then
   3.357 -             []
   3.358 -           else
   3.359 +           if engine = MaSh_Py then
   3.360               let val (parents, hints, feats) = query_args access_G in
   3.361                 MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
   3.362 -             end))
   3.363 +             end
   3.364 +           else
   3.365 +             []))
   3.366  
   3.367      val sml_suggs =
   3.368 -      if sml then
   3.369 +      if engine = MaSh_Py then
   3.370 +        []
   3.371 +      else
   3.372          let val (parents, hints, feats) = query_args access_G in
   3.373 -          MaSh_SML.query ctxt parents access_G max_facts hints feats
   3.374 +          MaSh_SML.query ctxt engine parents access_G max_facts hints feats
   3.375          end
   3.376 -      else
   3.377 -        []
   3.378  
   3.379      val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   3.380    in
   3.381 @@ -1309,13 +1379,13 @@
   3.382                |> filter (is_fact_in_graph access_G)
   3.383                |> map nickname_of_thm
   3.384            in
   3.385 -            if is_mash_sml_enabled () then
   3.386 +            if the_mash_engine () = MaSh_Py then
   3.387 +              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
   3.388 +            else
   3.389                let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
   3.390                  {access_G = access_G, num_known_facts = num_known_facts + 1,
   3.391                   dirty = Option.map (cons name) dirty}
   3.392                end
   3.393 -            else
   3.394 -              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
   3.395            end);
   3.396          (true, "")
   3.397        end)
   3.398 @@ -1334,7 +1404,7 @@
   3.399      val timer = Timer.startRealTimer ()
   3.400      fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
   3.401  
   3.402 -    val sml = is_mash_sml_enabled ()
   3.403 +    val engine = the_mash_engine ()
   3.404      val {access_G, ...} = peek_state ctxt overlord I
   3.405      val is_in_access_G = is_fact_in_graph access_G o snd
   3.406      val no_new_facts = forall is_in_access_G facts
   3.407 @@ -1376,11 +1446,11 @@
   3.408                    (false, SOME names, []) => SOME (map #1 learns @ names)
   3.409                  | _ => NONE)
   3.410              in
   3.411 -              if sml then
   3.412 -                ()
   3.413 +              if engine = MaSh_Py then
   3.414 +                (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
   3.415 +                 MaSh_Py.relearn ctxt overlord save relearns)
   3.416                else
   3.417 -                (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
   3.418 -                 MaSh_Py.relearn ctxt overlord save relearns);
   3.419 +                ();
   3.420                {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
   3.421              end
   3.422  
   3.423 @@ -1613,7 +1683,7 @@
   3.424             |> Par_List.map (apsnd (fn f => f ()))
   3.425        val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
   3.426      in
   3.427 -      if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord;
   3.428 +      if the_mash_engine () = MaSh_Py andalso save then MaSh_Py.save ctxt overlord else ();
   3.429        (case (fact_filter, mess) of
   3.430          (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
   3.431          [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
   3.432 @@ -1623,7 +1693,7 @@
   3.433  
   3.434  fun kill_learners ctxt ({overlord, ...} : params) =
   3.435    (Async_Manager.kill_threads MaShN "learner";
   3.436 -   if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord)
   3.437 +   if the_mash_engine () = MaSh_Py then MaSh_Py.shutdown ctxt overlord else ())
   3.438  
   3.439  fun running_learners () = Async_Manager.running_threads MaShN "learner"
   3.440