added SML implementation of MaSh
authorblanchet
Mon May 19 23:43:53 2014 +0200 (2014-05-19)
changeset 570098cb6a5f1ae84
parent 57008 10f68b83b474
child 57010 121b63d7bcdb
added SML implementation of MaSh
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
     1.3 @@ -1,5 +1,6 @@
     1.4  (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.5      Author:     Jasmin Blanchette, TU Muenchen
     1.6 +    Author:     Cezary Kaliszyk, University of Innsbruck
     1.7  
     1.8  Sledgehammer's machine-learning-based relevance filter (MaSh).
     1.9  *)
    1.10 @@ -253,9 +254,7 @@
    1.11  fun learn _ _ _ [] = ()
    1.12    | learn ctxt overlord save learns =
    1.13      (trace_msg ctxt (fn () =>
    1.14 -       let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
    1.15 -         "MaSh_Py learn" ^ (if names = "" then "" else " " ^ names)
    1.16 -       end);
    1.17 +       "MaSh_Py learn {" ^ elide_string 1000 (space_implode " " (map #1 learns)) ^ "}");
    1.18       run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
    1.19         (K ()))
    1.20  
    1.21 @@ -280,15 +279,203 @@
    1.22  structure MaSh_SML =
    1.23  struct
    1.24  
    1.25 -fun learn_and_query ctxt (learns : (string * string list * string list list * string list) list)
    1.26 -    max_suggs (query as (_, _, _, feats)) =
    1.27 -  (trace_msg ctxt (fn () =>
    1.28 -     let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
    1.29 -       "MaSh_SML learn" ^ (if names = "" then "" else " " ^ names) ^ "\n" ^
    1.30 -       "MaSh_SML query " ^ encode_features feats
    1.31 -     end);
    1.32 -   (* Implementation missing *)
    1.33 -   [])
    1.34 +fun max a b = if a > b then a else b
    1.35 +
    1.36 +exception BOTTOM of int
    1.37 +
    1.38 +fun heap cmp bnd a =
    1.39 +  let
    1.40 +    fun maxson l i =
    1.41 +      let
    1.42 +        val i31 = i + i + i + 1
    1.43 +      in
    1.44 +        if i31 + 2 < l then
    1.45 +          let
    1.46 +            val x = Unsynchronized.ref i31;
    1.47 +            val () = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
    1.48 +            val () = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
    1.49 +          in
    1.50 +            !x
    1.51 +          end
    1.52 +        else
    1.53 +          if i31 + 1 < l andalso cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS
    1.54 +          then i31 + 1 else if i31 < l then i31 else raise BOTTOM i
    1.55 +      end
    1.56 +
    1.57 +    fun trickledown l i e =
    1.58 +      let
    1.59 +        val j = maxson l i
    1.60 +      in
    1.61 +        if cmp (Array.sub (a, j), e) = GREATER then
    1.62 +          let val () = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
    1.63 +        else Array.update (a, i, e)
    1.64 +      end
    1.65 +
    1.66 +    fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e)
    1.67 +
    1.68 +    fun bubbledown l i =
    1.69 +      let
    1.70 +        val j = maxson l i
    1.71 +        val () = Array.update (a, i, Array.sub (a, j))
    1.72 +      in
    1.73 +        bubbledown l j
    1.74 +      end
    1.75 +
    1.76 +    fun bubble l i = bubbledown l i handle BOTTOM i => i
    1.77 +
    1.78 +    fun trickleup i e =
    1.79 +      let
    1.80 +        val father = (i - 1) div 3
    1.81 +      in
    1.82 +        if cmp (Array.sub (a, father), e) = LESS then
    1.83 +          let
    1.84 +            val () = Array.update (a, i, Array.sub (a, father))
    1.85 +          in
    1.86 +            if father > 0 then trickleup father e else Array.update (a, 0, e)
    1.87 +          end
    1.88 +        else Array.update (a, i, e)
    1.89 +      end
    1.90 +
    1.91 +    val l = Array.length a
    1.92 +
    1.93 +    fun for i =
    1.94 +      if i < 0 then () else
    1.95 +      let
    1.96 +        val _ = trickle l i (Array.sub (a, i))
    1.97 +      in
    1.98 +        for (i - 1)
    1.99 +      end
   1.100 +
   1.101 +    val () = for (((l + 1) div 3) - 1)
   1.102 +
   1.103 +    fun for2 i =
   1.104 +      if i < max 2 (l - bnd) then () else
   1.105 +      let
   1.106 +        val e = Array.sub (a, i)
   1.107 +        val () = Array.update (a, i, Array.sub (a, 0))
   1.108 +        val () = trickleup (bubble i 0) e
   1.109 +      in
   1.110 +        for2 (i - 1)
   1.111 +      end
   1.112 +
   1.113 +    val () = for2 (l - 1)
   1.114 +  in
   1.115 +    if l > 1 then
   1.116 +      let
   1.117 +        val e = Array.sub (a, 1)
   1.118 +        val () = Array.update (a, 1, Array.sub (a, 0))
   1.119 +      in
   1.120 +        Array.update (a, 0, e)
   1.121 +      end
   1.122 +    else ()
   1.123 +  end
   1.124 +
   1.125 +(*
   1.126 +  avail_no = maximum number of theorems to check dependencies and symbols
   1.127 +  get_deps = returns dependencies of a theorem
   1.128 +  get_sym_ths = get theorems that have this feature
   1.129 +  knns    = number of nearest neighbours
   1.130 +  advno   = number of predictions to return
   1.131 +  syms    = symbols of the conjecture
   1.132 +*)
   1.133 +fun knn avail_no get_deps get_sym_ths knns advno syms =
   1.134 +  let
   1.135 +    (* Can be later used for TFIDF *)
   1.136 +    fun sym_wght _ = 1.0;
   1.137 +    val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)));
   1.138 +    fun inc_overlap j v =
   1.139 +      let
   1.140 +        val ov = snd (Array.sub (overlaps_sqr,j))
   1.141 +      in
   1.142 +        Array.update (overlaps_sqr, j, (j, v + ov))
   1.143 +      end;
   1.144 +    fun do_sym (s, con_wght) =
   1.145 +      let
   1.146 +        val sw = sym_wght s;
   1.147 +        val w2 = sw * sw * con_wght;
   1.148 +        fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else ()
   1.149 +      in
   1.150 +        ignore (map do_th (get_sym_ths s))
   1.151 +      end;
   1.152 +    val () = ignore (map do_sym syms);
   1.153 +    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
   1.154 +    val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)));
   1.155 +    fun inc_recommend j v =
   1.156 +      let
   1.157 +        val ov = snd (Array.sub (recommends,j))
   1.158 +      in
   1.159 +        Array.update (recommends, j, (j, v + ov))
   1.160 +      end;
   1.161 +    fun for k =
   1.162 +      if k = knns then () else
   1.163 +      if k >= avail_no then () else
   1.164 +      let
   1.165 +        val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1);
   1.166 +        val o1 = Math.sqrt o2;
   1.167 +        val () = inc_recommend j o1;
   1.168 +        val ds = get_deps j;
   1.169 +        val l = Real.fromInt (length ds);
   1.170 +        val _ = map (fn d => inc_recommend d (o1 / l)) ds
   1.171 +      in
   1.172 +        for (k + 1)
   1.173 +      end;
   1.174 +    val () = for 0;
   1.175 +    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
   1.176 +    fun ret acc at =
   1.177 +      if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
   1.178 +  in
   1.179 +    ret [] (max 0 (avail_no - advno))
   1.180 +  end
   1.181 +
   1.182 +val knns = 40 (* FUDGE *)
   1.183 +
   1.184 +fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   1.185 +
   1.186 +fun map_array_index ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   1.187 +
   1.188 +(* TODO: take weight components of "feats" into consideration *)
   1.189 +fun learn_and_query ctxt parents access_G max_suggs hints feats =
   1.190 +  let
   1.191 +    val str_of_feat = space_implode "|"
   1.192 +
   1.193 +    val (depss0, featss, (_, _, facts0), (num_feats, feat_tab, _)) =
   1.194 +      fold_rev (fn fact => fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   1.195 +          let
   1.196 +            val (_, feats, deps) = Graph.get_node access_G fact
   1.197 +
   1.198 +            fun add_feat feat (xtab as (n, tab, _)) =
   1.199 +              (case Symtab.lookup tab feat of
   1.200 +                SOME i => (i, xtab)
   1.201 +              | NONE => (n, add_to_xtab feat xtab))
   1.202 +
   1.203 +            val (feats', feat_xtab') = fold_map (add_feat o str_of_feat) feats feat_xtab
   1.204 +          in
   1.205 +            (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
   1.206 +             add_to_xtab fact fact_xtab, feat_xtab')
   1.207 +          end)
   1.208 +        (Graph.all_preds access_G parents) ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   1.209 +
   1.210 +    val facts = rev facts0
   1.211 +    val fact_ary = Array.fromList facts
   1.212 +
   1.213 +    val deps_ary = Array.fromList (rev depss0)
   1.214 +    val facts_ary = Array.array (num_feats, [])
   1.215 +    val _ =
   1.216 +      fold (fn feats => fn fact =>
   1.217 +          let val fact' = fact - 1 in
   1.218 +            (List.app (map_array_index facts_ary (cons fact')) feats; fact')
   1.219 +          end)
   1.220 +        featss (length featss)
   1.221 +  in
   1.222 +    (trace_msg ctxt (fn () =>
   1.223 +       "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   1.224 +        elide_string 1000 (space_implode " " facts) ^ "}");
   1.225 +     knn (Array.length deps_ary) (curry Array.sub deps_ary)
   1.226 +       (map (rpair 1.0) (* FIXME *) o curry Array.sub facts_ary) knns max_suggs
   1.227 +       (map_filter (fn (feat, weight) =>
   1.228 +          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   1.229 +     |> map ((fn i => Array.sub (fact_ary, i)) o fst))
   1.230 +  end
   1.231  
   1.232  end;
   1.233  
   1.234 @@ -574,7 +761,7 @@
   1.235  
   1.236  fun sort_of_type alg T =
   1.237    let
   1.238 -    val graph = Sorts.classes_of alg
   1.239 +    val G = Sorts.classes_of alg
   1.240  
   1.241      fun cls_of S [] = S
   1.242        | cls_of S (cl :: cls) =
   1.243 @@ -585,7 +772,7 @@
   1.244              cls_of S (union (op =) cls' cls)
   1.245            end
   1.246    in
   1.247 -    cls_of [] (Graph.maximals graph)
   1.248 +    cls_of [] (Graph.maximals G)
   1.249    end
   1.250  
   1.251  val generalize_goal = false
   1.252 @@ -975,13 +1162,6 @@
   1.253  fun add_const_counts t =
   1.254    fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
   1.255  
   1.256 -fun learn_of_graph graph =
   1.257 -  let
   1.258 -    fun sched parents (name, (kind, feats, deps)) = (name, map fst parents, feats, deps)
   1.259 -  in
   1.260 -    Graph.schedule sched graph
   1.261 -  end
   1.262 -
   1.263  fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
   1.264    let
   1.265      val thy = Proof_Context.theory_of ctxt
   1.266 @@ -1025,9 +1205,10 @@
   1.267                |> map (nickname_of_thm o snd)
   1.268            in
   1.269              (access_G,
   1.270 -             (if Config.get ctxt sml then MaSh_SML.learn_and_query ctxt (learn_of_graph access_G)
   1.271 -              else MaSh_Py.query ctxt overlord)
   1.272 -               max_facts ([], hints, parents, feats))
   1.273 +             if Config.get ctxt sml then
   1.274 +               MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats
   1.275 +             else
   1.276 +               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats))
   1.277            end)
   1.278      val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   1.279    in
   1.280 @@ -1035,27 +1216,27 @@
   1.281      |> pairself (map fact_of_raw_fact)
   1.282    end
   1.283  
   1.284 -fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
   1.285 +fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
   1.286    let
   1.287 -    fun maybe_learn_from from (accum as (parents, graph)) =
   1.288 -      try_graph ctxt "updating graph" accum (fn () =>
   1.289 -        (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   1.290 -    val graph = graph |> Graph.default_node (name, (Isar_Proof, feats, deps))
   1.291 -    val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
   1.292 -    val (deps, _) = ([], graph) |> fold maybe_learn_from deps
   1.293 +    fun maybe_learn_from from (accum as (parents, G)) =
   1.294 +      try_graph ctxt "updating G" accum (fn () =>
   1.295 +        (from :: parents, Graph.add_edge_acyclic (from, name) G))
   1.296 +    val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
   1.297 +    val (parents, G) = ([], G) |> fold maybe_learn_from parents
   1.298 +    val (deps, _) = ([], G) |> fold maybe_learn_from deps
   1.299    in
   1.300 -    ((name, parents, feats, deps) :: learns, graph)
   1.301 +    ((name, parents, feats, deps) :: learns, G)
   1.302    end
   1.303  
   1.304 -fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
   1.305 +fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
   1.306    let
   1.307 -    fun maybe_relearn_from from (accum as (parents, graph)) =
   1.308 +    fun maybe_relearn_from from (accum as (parents, G)) =
   1.309        try_graph ctxt "updating graph" accum (fn () =>
   1.310 -        (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   1.311 -    val graph = graph |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
   1.312 -    val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
   1.313 +        (from :: parents, Graph.add_edge_acyclic (from, name) G))
   1.314 +    val G = G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
   1.315 +    val (deps, _) = ([], G) |> fold maybe_relearn_from deps
   1.316    in
   1.317 -    ((name, deps) :: relearns, graph)
   1.318 +    ((name, deps) :: relearns, G)
   1.319    end
   1.320  
   1.321  fun flop_wrt_access_graph name =