src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57029 75cc30d2b83f
parent 57028 e5466055e94f
child 57039 1ddd1f75fb40
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 22:28:08 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 22:28:44 2014 +0200
     1.3 @@ -115,26 +115,21 @@
     1.4      ()
     1.5    end
     1.6  
     1.7 -datatype mash_engine = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB
     1.8 +datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB
     1.9  
    1.10  fun mash_engine () =
    1.11    let val flag1 = Options.default_string @{system_option maSh} in
    1.12      (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
    1.13        "yes" => SOME MaSh_Py
    1.14      | "py" => SOME MaSh_Py
    1.15 -    | "sml" => SOME MaSh_SML_KNN
    1.16 -    | "sml_knn" => SOME MaSh_SML_KNN
    1.17 +    | "sml" => SOME MaSh_SML_kNN
    1.18 +    | "sml_knn" => SOME MaSh_SML_kNN
    1.19      | "sml_nb" => SOME MaSh_SML_NB
    1.20      | _ => NONE)
    1.21    end
    1.22  
    1.23  val is_mash_enabled = is_some o mash_engine
    1.24 -
    1.25 -fun is_mash_sml_enabled () =
    1.26 -  (case mash_engine () of
    1.27 -    SOME MaSh_SML_KNN => true
    1.28 -  | SOME MaSh_SML_NB => true
    1.29 -  | _ => false)
    1.30 +val the_mash_engine = the_default MaSh_SML_kNN o mash_engine
    1.31  
    1.32  
    1.33  (*** Low-level communication with Python version of MaSh ***)
    1.34 @@ -320,71 +315,55 @@
    1.35        end
    1.36  
    1.37      fun trickledown l i e =
    1.38 -      let
    1.39 -        val j = maxson l i
    1.40 -      in
    1.41 +      let val j = maxson l i in
    1.42          if cmp (Array.sub (a, j), e) = GREATER then
    1.43 -          let val _ = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
    1.44 -        else Array.update (a, i, e)
    1.45 +          (Array.update (a, i, Array.sub (a, j)); trickledown l j e)
    1.46 +        else
    1.47 +          Array.update (a, i, e)
    1.48        end
    1.49  
    1.50      fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e)
    1.51  
    1.52      fun bubbledown l i =
    1.53 -      let
    1.54 -        val j = maxson l i
    1.55 -        val _ = Array.update (a, i, Array.sub (a, j))
    1.56 -      in
    1.57 +      let val j = maxson l i in
    1.58 +        Array.update (a, i, Array.sub (a, j));
    1.59          bubbledown l j
    1.60        end
    1.61  
    1.62      fun bubble l i = bubbledown l i handle BOTTOM i => i
    1.63  
    1.64      fun trickleup i e =
    1.65 -      let
    1.66 -        val father = (i - 1) div 3
    1.67 -      in
    1.68 +      let val father = (i - 1) div 3 in
    1.69          if cmp (Array.sub (a, father), e) = LESS then
    1.70 -          let
    1.71 -            val _ = Array.update (a, i, Array.sub (a, father))
    1.72 -          in
    1.73 -            if father > 0 then trickleup father e else Array.update (a, 0, e)
    1.74 -          end
    1.75 -        else Array.update (a, i, e)
    1.76 +          (Array.update (a, i, Array.sub (a, father));
    1.77 +           if father > 0 then trickleup father e else Array.update (a, 0, e))
    1.78 +        else
    1.79 +          Array.update (a, i, e)
    1.80        end
    1.81  
    1.82      val l = Array.length a
    1.83  
    1.84 -    fun for i =
    1.85 -      if i < 0 then () else
    1.86 -      let
    1.87 -        val _ = trickle l i (Array.sub (a, i))
    1.88 -      in
    1.89 -        for (i - 1)
    1.90 -      end
    1.91 -
    1.92 -    val _ = for (((l + 1) div 3) - 1)
    1.93 +    fun for i = if i < 0 then () else (trickle l i (Array.sub (a, i)); for (i - 1))
    1.94  
    1.95      fun for2 i =
    1.96 -      if i < Integer.max 2 (l - bnd) then () else
    1.97 -      let
    1.98 -        val e = Array.sub (a, i)
    1.99 -        val _ = Array.update (a, i, Array.sub (a, 0))
   1.100 -        val _ = trickleup (bubble i 0) e
   1.101 -      in
   1.102 -        for2 (i - 1)
   1.103 -      end
   1.104 -
   1.105 -    val _ = for2 (l - 1)
   1.106 +      if i < Integer.max 2 (l - bnd) then
   1.107 +        ()
   1.108 +      else
   1.109 +        let val e = Array.sub (a, i) in
   1.110 +          Array.update (a, i, Array.sub (a, 0));
   1.111 +          trickleup (bubble i 0) e;
   1.112 +          for2 (i - 1)
   1.113 +        end
   1.114    in
   1.115 +    for (((l + 1) div 3) - 1);
   1.116 +    for2 (l - 1);
   1.117      if l > 1 then
   1.118 -      let
   1.119 -        val e = Array.sub (a, 1)
   1.120 -        val _ = Array.update (a, 1, Array.sub (a, 0))
   1.121 -      in
   1.122 +      let val e = Array.sub (a, 1) in
   1.123 +        Array.update (a, 1, Array.sub (a, 0));
   1.124          Array.update (a, 0, e)
   1.125        end
   1.126 -    else ()
   1.127 +    else
   1.128 +      ()
   1.129    end
   1.130  
   1.131  (*
   1.132 @@ -421,7 +400,7 @@
   1.133        end
   1.134  
   1.135      val _ = List.app do_sym syms
   1.136 -    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
   1.137 +    val _ = heap (Real.compare o pairself snd) knns overlaps_sqr
   1.138      val recommends = Array.tabulate (adv_max, rpair 0.0)
   1.139  
   1.140      fun inc_recommend j v =
   1.141 @@ -438,27 +417,97 @@
   1.142            val _ = inc_recommend j o1
   1.143            val ds = get_deps j
   1.144            val l = Real.fromInt (length ds)
   1.145 -          val _ = map (fn d => inc_recommend d (o1 / l)) ds
   1.146          in
   1.147 -          for (k + 1)
   1.148 +          List.app (fn d => inc_recommend d (o1 / l)) ds; for (k + 1)
   1.149          end
   1.150  
   1.151 -    val _ = for 0
   1.152 -    val _ = heap (Real.compare o pairself snd) advno recommends
   1.153 -
   1.154      fun ret acc at =
   1.155        if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   1.156    in
   1.157 +    for 0;
   1.158 +    heap (Real.compare o pairself snd) advno recommends;
   1.159      ret [] (Integer.max 0 (adv_max - advno))
   1.160    end
   1.161  
   1.162 +(* Two arguments control the behaviour of nbayes: prior and ess. Prior expresses our belief in
   1.163 +   usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the
   1.164 +   prior. *)
   1.165 +fun nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
   1.166 +  let
   1.167 +    val afreq = Unsynchronized.ref 0
   1.168 +    val tfreq = Array.array (avail_num, 0)
   1.169 +    val sfreq = Array.array (avail_num, Inttab.empty)
   1.170 +
   1.171 +    fun nb_learn syms ts =
   1.172 +      let
   1.173 +        fun add_sym hpis sym =
   1.174 +          let
   1.175 +            val im = Array.sub (sfreq, hpis)
   1.176 +            val v = the_default 0 (Inttab.lookup im sym)
   1.177 +          in
   1.178 +            Array.update(sfreq, hpis, Inttab.update (sym, v + 1) im)
   1.179 +          end
   1.180 +
   1.181 +        fun add_th t =
   1.182 +          (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms)
   1.183 +      in
   1.184 +        afreq := !afreq + 1;
   1.185 +        List.app add_th ts
   1.186 +      end
   1.187 +
   1.188 +    fun nb_eval syms =
   1.189 +      let
   1.190 +        fun log_posterior i =
   1.191 +          let
   1.192 +            val symh = fold (fn s => fn sf => Inttab.update (s, ()) sf) syms Inttab.empty
   1.193 +            val n = Real.fromInt (Array.sub (tfreq, i))
   1.194 +            val sfreqh = Array.sub (sfreq, i)
   1.195 +            val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq)
   1.196 +            val mp = ess * p
   1.197 +            val logmp = Math.ln mp
   1.198 +            val lognmp = Math.ln (n + mp)
   1.199 +
   1.200 +            fun in_sfreqh (s, sfreqv) (sofar, sfsymh) =
   1.201 +              let val sfreqv = Real.fromInt sfreqv in
   1.202 +                if Inttab.defined sfsymh s then
   1.203 +                  (sofar + Math.ln (sfreqv + mp), Inttab.delete s sfsymh)
   1.204 +                else
   1.205 +                  (sofar + Math.ln (n - sfreqv + mp), sfsymh)
   1.206 +              end
   1.207 +
   1.208 +            val (postsfreqh, symh) = Inttab.fold in_sfreqh sfreqh (Math.ln n, symh)
   1.209 +            val len_mem = length (Inttab.keys symh)
   1.210 +            val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh)
   1.211 +          in
   1.212 +            postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp -
   1.213 +              Real.fromInt sym_num * Math.ln(n + ess)
   1.214 +          end
   1.215 +
   1.216 +        val posterior = Array.tabulate (adv_max, swap o `log_posterior)
   1.217 +
   1.218 +        fun ret acc at =
   1.219 +          if at = Array.length posterior then acc
   1.220 +          else ret (Array.sub (posterior,at) :: acc) (at + 1)
   1.221 +      in
   1.222 +        heap (Real.compare o pairself snd) advno posterior;
   1.223 +        ret [] (Integer.max 0 (adv_max - advno))
   1.224 +      end
   1.225 +
   1.226 +    fun for i =
   1.227 +      if i = avail_num then () else (nb_learn (get_th_syms i) (get_deps i); for (i + 1))
   1.228 +  in
   1.229 +    for 0; nb_eval syms
   1.230 +  end
   1.231 +
   1.232  val knns = 40 (* FUDGE *)
   1.233 +val ess = 0.00001 (* FUDGE *)
   1.234 +val prior = 0.001 (* FUDGE *)
   1.235  
   1.236  fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   1.237  
   1.238  fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   1.239  
   1.240 -fun query ctxt parents access_G max_suggs hints feats =
   1.241 +fun query ctxt engine parents access_G max_suggs hints feats =
   1.242    let
   1.243      val str_of_feat = space_implode "|"
   1.244  
   1.245 @@ -470,9 +519,9 @@
   1.246         |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
   1.247        (if null hints then [] else [(".goal", feats, hints)])
   1.248  
   1.249 -    val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) =
   1.250 +    val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
   1.251        fold (fn (fact, feats, deps) =>
   1.252 -            fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   1.253 +            fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   1.254            let
   1.255              fun add_feat (feat, weight) (xtab as (n, tab, _)) =
   1.256                (case Symtab.lookup tab feat of
   1.257 @@ -481,7 +530,7 @@
   1.258  
   1.259              val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab
   1.260            in
   1.261 -            (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
   1.262 +            (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
   1.263               add_to_xtab fact fact_xtab, feat_xtab')
   1.264            end)
   1.265          all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   1.266 @@ -490,22 +539,40 @@
   1.267      val fact_vec = Vector.fromList facts
   1.268  
   1.269      val deps_vec = Vector.fromList (rev rev_depss)
   1.270 -    val facts_ary = Array.array (num_feats, [])
   1.271 -    val _ =
   1.272 -      fold (fn feats => fn fact =>
   1.273 -          let val fact' = fact - 1 in
   1.274 -            List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
   1.275 -              feats;
   1.276 -            fact'
   1.277 -          end)
   1.278 -        featss (length featss)
   1.279 +
   1.280 +    val avail_num = Vector.length deps_vec
   1.281 +    val adv_max = length visible_facts
   1.282 +    val get_deps = curry Vector.sub deps_vec
   1.283 +    val advno = max_suggs
   1.284    in
   1.285      trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   1.286        elide_string 1000 (space_implode " " facts) ^ "}");
   1.287 -    knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
   1.288 -      (curry Array.sub facts_ary) knns max_suggs
   1.289 -      (map_filter (fn (feat, weight) =>
   1.290 -         Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   1.291 +    (if engine = MaSh_SML_kNN then
   1.292 +       let
   1.293 +        val facts_ary = Array.array (num_feats, [])
   1.294 +        val _ =
   1.295 +          fold (fn feats => fn fact =>
   1.296 +              let val fact' = fact - 1 in
   1.297 +                List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
   1.298 +                  feats;
   1.299 +                fact'
   1.300 +              end)
   1.301 +            rev_featss num_facts
   1.302 +         val get_sym_ths = curry Array.sub facts_ary
   1.303 +         val syms = map_filter (fn (feat, weight) =>
   1.304 +           Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats
   1.305 +       in
   1.306 +         knn avail_num adv_max get_deps get_sym_ths knns advno syms
   1.307 +       end
   1.308 +     else
   1.309 +       let
   1.310 +         val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   1.311 +         val get_th_syms = curry Vector.sub unweighted_feats_ary
   1.312 +         val sym_num = num_feats
   1.313 +         val unweighted_syms = map_filter (Symtab.lookup feat_tab o str_of_feat o fst) feats
   1.314 +       in
   1.315 +         nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno unweighted_syms
   1.316 +       end)
   1.317      |> map (curry Vector.sub fact_vec o fst)
   1.318    end
   1.319  
   1.320 @@ -596,8 +663,10 @@
   1.321                    fold extract_line_and_add_node node_lines Graph.empty),
   1.322                  length node_lines)
   1.323               | LESS =>
   1.324 -               (if is_mash_sml_enabled () then wipe_out_mash_state_dir ()
   1.325 -                else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
   1.326 +               (* cannot parse old file *)
   1.327 +               (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
   1.328 +                else wipe_out_mash_state_dir ();
   1.329 +                (Graph.empty, 0))
   1.330               | GREATER => raise FILE_VERSION_TOO_NEW ())
   1.331           in
   1.332             trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
   1.333 @@ -645,7 +714,8 @@
   1.334  fun clear_state ctxt overlord =
   1.335    (* "MaSh_Py.unlearn" also removes the state file *)
   1.336    Synchronized.change global_state (fn _ =>
   1.337 -    (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord;
   1.338 +    (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
   1.339 +     else wipe_out_mash_state_dir ();
   1.340       (false, empty_state)))
   1.341  
   1.342  end
   1.343 @@ -1224,7 +1294,7 @@
   1.344          (parents, hints, feats)
   1.345        end
   1.346  
   1.347 -    val sml = is_mash_sml_enabled ()
   1.348 +    val engine = the_mash_engine ()
   1.349  
   1.350      val (access_G, py_suggs) =
   1.351        peek_state ctxt overlord (fn {access_G, ...} =>
   1.352 @@ -1232,20 +1302,20 @@
   1.353            (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
   1.354          else
   1.355            (access_G,
   1.356 -           if sml then
   1.357 -             []
   1.358 -           else
   1.359 +           if engine = MaSh_Py then
   1.360               let val (parents, hints, feats) = query_args access_G in
   1.361                 MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
   1.362 -             end))
   1.363 +             end
   1.364 +           else
   1.365 +             []))
   1.366  
   1.367      val sml_suggs =
   1.368 -      if sml then
   1.369 +      if engine = MaSh_Py then
   1.370 +        []
   1.371 +      else
   1.372          let val (parents, hints, feats) = query_args access_G in
   1.373 -          MaSh_SML.query ctxt parents access_G max_facts hints feats
   1.374 +          MaSh_SML.query ctxt engine parents access_G max_facts hints feats
   1.375          end
   1.376 -      else
   1.377 -        []
   1.378  
   1.379      val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   1.380    in
   1.381 @@ -1309,13 +1379,13 @@
   1.382                |> filter (is_fact_in_graph access_G)
   1.383                |> map nickname_of_thm
   1.384            in
   1.385 -            if is_mash_sml_enabled () then
   1.386 +            if the_mash_engine () = MaSh_Py then
   1.387 +              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
   1.388 +            else
   1.389                let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
   1.390                  {access_G = access_G, num_known_facts = num_known_facts + 1,
   1.391                   dirty = Option.map (cons name) dirty}
   1.392                end
   1.393 -            else
   1.394 -              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
   1.395            end);
   1.396          (true, "")
   1.397        end)
   1.398 @@ -1334,7 +1404,7 @@
   1.399      val timer = Timer.startRealTimer ()
   1.400      fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
   1.401  
   1.402 -    val sml = is_mash_sml_enabled ()
   1.403 +    val engine = the_mash_engine ()
   1.404      val {access_G, ...} = peek_state ctxt overlord I
   1.405      val is_in_access_G = is_fact_in_graph access_G o snd
   1.406      val no_new_facts = forall is_in_access_G facts
   1.407 @@ -1376,11 +1446,11 @@
   1.408                    (false, SOME names, []) => SOME (map #1 learns @ names)
   1.409                  | _ => NONE)
   1.410              in
   1.411 -              if sml then
   1.412 -                ()
   1.413 +              if engine = MaSh_Py then
   1.414 +                (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
   1.415 +                 MaSh_Py.relearn ctxt overlord save relearns)
   1.416                else
   1.417 -                (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
   1.418 -                 MaSh_Py.relearn ctxt overlord save relearns);
   1.419 +                ();
   1.420                {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
   1.421              end
   1.422  
   1.423 @@ -1613,7 +1683,7 @@
   1.424             |> Par_List.map (apsnd (fn f => f ()))
   1.425        val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
   1.426      in
   1.427 -      if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord;
   1.428 +      if the_mash_engine () = MaSh_Py andalso save then MaSh_Py.save ctxt overlord else ();
   1.429        (case (fact_filter, mess) of
   1.430          (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
   1.431          [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
   1.432 @@ -1623,7 +1693,7 @@
   1.433  
   1.434  fun kill_learners ctxt ({overlord, ...} : params) =
   1.435    (Async_Manager.kill_threads MaShN "learner";
   1.436 -   if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord)
   1.437 +   if the_mash_engine () = MaSh_Py then MaSh_Py.shutdown ctxt overlord else ())
   1.438  
   1.439  fun running_learners () = Async_Manager.running_threads MaShN "learner"
   1.440