added SML implementation of MaSh
author blanchet
Mon, 19 May 2014 23:43:53 +0200
changeset 57009 8cb6a5f1ae84
parent 57008 10f68b83b474
child 57010 121b63d7bcdb
added SML implementation of MaSh
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
@@ -1,5 +1,6 @@
 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     Author:     Jasmin Blanchette, TU Muenchen
+    Author:     Cezary Kaliszyk, University of Innsbruck
 
 Sledgehammer's machine-learning-based relevance filter (MaSh).
 *)
@@ -253,9 +254,7 @@
+(* Runs a batch of learn commands through "run_mash_tool". An empty batch is a no-op.
+   Otherwise the names (first components of "learns") are traced, elided with
+   "elide_string 1000", and "save_models_arg" is added to the tool arguments only when
+   "save" is true. The tool's result is discarded via "K ()". *)
 fun learn _ _ _ [] = ()
   | learn ctxt overlord save learns =
     (trace_msg ctxt (fn () =>
-       let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
-         "MaSh_Py learn" ^ (if names = "" then "" else " " ^ names)
-       end);
+       "MaSh_Py learn {" ^ elide_string 1000 (space_implode " " (map #1 learns)) ^ "}");
      run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
        (K ()))
 
@@ -280,15 +279,203 @@
 structure MaSh_SML =
 struct
 
-fun learn_and_query ctxt (learns : (string * string list * string list list * string list) list)
-    max_suggs (query as (_, _, _, feats)) =
-  (trace_msg ctxt (fn () =>
-     let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
-       "MaSh_SML learn" ^ (if names = "" then "" else " " ^ names) ^ "\n" ^
-       "MaSh_SML query " ^ encode_features feats
-     end);
-   (* Implementation missing *)
-   [])
+(* Curried maximum of two integers. *)
+fun max a b = if a > b then a else b
+
+(* Signals that a heap node has no child within the current bounds; carries that node's
+   index. *)
+exception BOTTOM of int
+
+(* In-place partial heap sort of array "a" using a ternary max-heap (children of node i are
+   3i + 1, 3i + 2 and 3i + 3).
+
+   cmp = ordering on elements, returning LESS / EQUAL / GREATER
+   bnd = how many of the greatest elements must end up sorted at the end of the array
+   a   = the array, modified destructively
+
+   After the call the last "bnd" slots of "a" hold its greatest elements in ascending
+   order; if "bnd" covers the whole array, the whole array is sorted. The remaining prefix
+   is not sorted. *)
+fun heap cmp bnd a =
+  let
+    (* Index of the greatest child of node i among the first l elements; raises BOTTOM i if
+       node i has no child below l. *)
+    fun maxson l i =
+      let
+        val i31 = i + i + i + 1
+      in
+        if i31 + 2 < l then
+          let
+            val x = Unsynchronized.ref i31;
+            val () = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
+            val () = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
+          in
+            !x
+          end
+        else
+          if i31 + 1 < l andalso cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS
+          then i31 + 1 else if i31 < l then i31 else raise BOTTOM i
+      end
+
+    (* Sifts element e down from the hole at i: as long as the greatest child is greater
+       than e, that child is moved up into the hole. BOTTOM escapes when a leaf is reached. *)
+    fun trickledown l i e =
+      let
+        val j = maxson l i
+      in
+        if cmp (Array.sub (a, j), e) = GREATER then
+          let val () = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
+        else Array.update (a, i, e)
+      end
+
+    (* Like trickledown, but stores e in the leaf hole when the bottom is reached. *)
+    fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e)
+
+    (* Moves the hole at i all the way down by always promoting the greatest child; only
+       terminates by BOTTOM being raised from maxson. *)
+    fun bubbledown l i =
+      let
+        val j = maxson l i
+        val () = Array.update (a, i, Array.sub (a, j))
+      in
+        bubbledown l j
+      end
+
+    (* Returns the index of the leaf where the hole ended up. *)
+    fun bubble l i = bubbledown l i handle BOTTOM i => i
+
+    (* Sifts element e up from the hole at i while its parent is smaller than e. Must not be
+       called with i = 0, since the parent index would then be negative. *)
+    fun trickleup i e =
+      let
+        val father = (i - 1) div 3
+      in
+        if cmp (Array.sub (a, father), e) = LESS then
+          let
+            val () = Array.update (a, i, Array.sub (a, father))
+          in
+            if father > 0 then trickleup father e else Array.update (a, 0, e)
+          end
+        else Array.update (a, i, e)
+      end
+
+    val l = Array.length a
+
+    (* Heap construction: sift down every node from index i back to the root. *)
+    fun for i =
+      if i < 0 then () else
+      let
+        val _ = trickle l i (Array.sub (a, i))
+      in
+        for (i - 1)
+      end
+
+    (* ((l + 1) div 3) - 1 is the index of the parent of the last element. *)
+    val () = for (((l + 1) div 3) - 1)
+
+    (* Extraction: for i from the last slot down to max 2 (l - bnd), move the current
+       maximum (the root) into slot i and re-insert the displaced element e into the heap
+       formed by the first i slots (hole goes down to a leaf, then e is sifted up). *)
+    fun for2 i =
+      if i < max 2 (l - bnd) then () else
+      let
+        val e = Array.sub (a, i)
+        val () = Array.update (a, i, Array.sub (a, 0))
+        val () = trickleup (bubble i 0) e
+      in
+        for2 (i - 1)
+      end
+
+    val () = for2 (l - 1)
+  in
+    (* The extraction loop never handles slots 0 and 1; since the root is the greater of
+       the two, swapping them puts them in ascending order. *)
+    if l > 1 then
+      let
+        val e = Array.sub (a, 1)
+        val () = Array.update (a, 1, Array.sub (a, 0))
+      in
+        Array.update (a, 0, e)
+      end
+    else ()
+  end
+
+(*
+  k-nearest-neighbours ranking of theorems for a conjecture. Theorems and symbols are
+  identified by integer indices.
+
+  avail_no = maximum number of theorems to check dependencies and symbols
+             (theorem indices >= avail_no are ignored)
+  get_deps = returns dependencies of a theorem (as theorem indices)
+  get_sym_ths = get theorems that have this feature, as (theorem index, weight) pairs
+  knns    = number of nearest neighbours
+  advno   = number of predictions to return
+  syms    = symbols of the conjecture, as (symbol, weight) pairs
+
+  Returns at most advno (theorem index, score) pairs, highest score first.
+*)
+fun knn avail_no get_deps get_sym_ths knns advno syms =
+  let
+    (* Can be later used for TFIDF *)
+    fun sym_wght _ = 1.0;
+    (* Squared overlap of each theorem with the conjecture; every entry carries its own
+       index so that it survives the reordering done by "heap". *)
+    val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)));
+    fun inc_overlap j v =
+      let
+        val ov = snd (Array.sub (overlaps_sqr,j))
+      in
+        Array.update (overlaps_sqr, j, (j, v + ov))
+      end;
+    fun do_sym (s, con_wght) =
+      let
+        val sw = sym_wght s;
+        val w2 = sw * sw * con_wght;
+        fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else ()
+      in
+        (* List.app rather than map: only the side effect on overlaps_sqr is wanted *)
+        List.app do_th (get_sym_ths s)
+      end;
+    val () = List.app do_sym syms;
+    (* Bring the knns greatest overlaps, in ascending order, to the end of the array *)
+    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
+    val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)));
+    fun inc_recommend j v =
+      let
+        val ov = snd (Array.sub (recommends,j))
+      in
+        Array.update (recommends, j, (j, v + ov))
+      end;
+    (* Walk the neighbours from nearest to farthest: each neighbour gets its overlap o1 as
+       score, and o1 is also split evenly among the neighbour's dependencies. *)
+    fun for k =
+      if k = knns then () else
+      if k >= avail_no then () else
+      let
+        val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1);
+        val o1 = Math.sqrt o2;
+        val () = inc_recommend j o1;
+        val ds = get_deps j;
+        val l = Real.fromInt (length ds);
+        (* l = 0.0 only if ds is empty, in which case the division is never evaluated *)
+        val () = List.app (fn d => inc_recommend d (o1 / l)) ds
+      in
+        for (k + 1)
+      end;
+    val () = for 0;
+    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
+    (* Consing while scanning upwards turns the ascending tail into a descending list *)
+    fun ret acc at =
+      if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
+  in
+    ret [] (max 0 (avail_no - advno))
+  end
+
+(* Number of nearest neighbours that "learn_and_query" passes to "knn". *)
+val knns = 40 (* FUDGE *)
+
+(* An "xtab" is a triple (next free index, table from key to index, keys in reverse
+   insertion order). Registers "key" under index "next" and bumps the counter.
+   NOTE(review): Symtab.update_new presumably fails on an already present key -- confirm. *)
+fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
+
+(* Destructively replaces the element of "ary" at index "i" by "f" applied to it. *)
+fun map_array_index ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
+
+(* TODO: take weight components of "feats" into consideration *)
+(* Builds the k-NN input tables from the access graph and queries "knn".
+
+   parents   = facts whose predecessors in "access_G" (per Graph.all_preds) are learned from
+   access_G  = graph whose node data is a triple with features and dependencies as
+               second and third components
+   max_suggs = number of suggestions requested from "knn"
+   hints     = not used by this implementation
+   feats     = (feature, weight) pairs of the query; a feature is a string list, joined
+               with "|" to form its table key
+
+   Returns the names of the suggested facts, in the order delivered by "knn". *)
+fun learn_and_query ctxt parents access_G max_suggs hints feats =
+  let
+    val str_of_feat = space_implode "|"
+
+    (* One pass over the facts, numbering facts and features consecutively from 0.
+       "depss" and "featss" grow by consing, so they end up in reverse fact-index order. *)
+    val (depss0, featss, (_, _, facts0), (num_feats, feat_tab, _)) =
+      fold_rev (fn fact => fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
+          let
+            val (_, feats, deps) = Graph.get_node access_G fact
+
+            (* Index of "feat", allocating a fresh one on first sight *)
+            fun add_feat feat (xtab as (n, tab, _)) =
+              (case Symtab.lookup tab feat of
+                SOME i => (i, xtab)
+              | NONE => (n, add_to_xtab feat xtab))
+
+            val (feats', feat_xtab') = fold_map (add_feat o str_of_feat) feats feat_xtab
+          in
+            (* Dependencies that have no index yet are silently dropped by map_filter *)
+            (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
+             add_to_xtab fact fact_xtab, feat_xtab')
+          end)
+        (Graph.all_preds access_G parents) ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
+
+    (* fact index -> fact name *)
+    val facts = rev facts0
+    val fact_ary = Array.fromList facts
+
+    (* fact index -> indices of its dependencies *)
+    val deps_ary = Array.fromList (rev depss0)
+    (* feature index -> indices of the facts that have this feature *)
+    val facts_ary = Array.array (num_feats, [])
+    val _ =
+      (* "featss" is in reverse order, so fact indices are counted down from its length *)
+      fold (fn feats => fn fact =>
+          let val fact' = fact - 1 in
+            (List.app (map_array_index facts_ary (cons fact')) feats; fact')
+          end)
+        featss (length featss)
+  in
+    (trace_msg ctxt (fn () =>
+       "MaSh_SML query " ^ encode_features feats ^ " from {" ^
+        elide_string 1000 (space_implode " " facts) ^ "}");
+     (* Every fact of a feature gets weight 1.0; query features that are absent from
+        "feat_tab" are dropped *)
+     knn (Array.length deps_ary) (curry Array.sub deps_ary)
+       (map (rpair 1.0) (* FIXME *) o curry Array.sub facts_ary) knns max_suggs
+       (map_filter (fn (feat, weight) =>
+          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
+     |> map ((fn i => Array.sub (fact_ary, i)) o fst))
+  end
 
 end;
 
@@ -574,7 +761,7 @@
 
 fun sort_of_type alg T =
   let
-    val graph = Sorts.classes_of alg
+    val G = Sorts.classes_of alg
 
     fun cls_of S [] = S
       | cls_of S (cl :: cls) =
@@ -585,7 +772,7 @@
             cls_of S (union (op =) cls' cls)
           end
   in
-    cls_of [] (Graph.maximals graph)
+    cls_of [] (Graph.maximals G)
   end
 
 val generalize_goal = false
@@ -975,13 +1162,6 @@
 fun add_const_counts t =
   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
 
-fun learn_of_graph graph =
-  let
-    fun sched parents (name, (kind, feats, deps)) = (name, map fst parents, feats, deps)
-  in
-    Graph.schedule sched graph
-  end
-
 fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
   let
     val thy = Proof_Context.theory_of ctxt
@@ -1025,9 +1205,10 @@
               |> map (nickname_of_thm o snd)
           in
             (access_G,
-             (if Config.get ctxt sml then MaSh_SML.learn_and_query ctxt (learn_of_graph access_G)
-              else MaSh_Py.query ctxt overlord)
-               max_facts ([], hints, parents, feats))
+             if Config.get ctxt sml then
+               MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats
+             else
+               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats))
           end)
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   in
@@ -1035,27 +1216,27 @@
     |> pairself (map fact_of_raw_fact)
   end
 
-fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
+(* Records fact "name" in the access graph G and in the list of pending learn commands.
+   The node is created with (Isar_Proof, feats, deps) unless it already exists. An edge
+   to "name" is then attempted from each parent; the parents collected by
+   "maybe_learn_from" come out in reverse order. Dependencies go through the same check,
+   but the graph resulting from their edges is discarded, so only parent edges end up in
+   the returned G.
+   NOTE(review): presumably "try_graph" returns "accum" unchanged when the edge cannot be
+   added (e.g. it would create a cycle) -- confirm against its definition. *)
+fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
   let
-    fun maybe_learn_from from (accum as (parents, graph)) =
-      try_graph ctxt "updating graph" accum (fn () =>
-        (from :: parents, Graph.add_edge_acyclic (from, name) graph))
-    val graph = graph |> Graph.default_node (name, (Isar_Proof, feats, deps))
-    val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
-    val (deps, _) = ([], graph) |> fold maybe_learn_from deps
+    fun maybe_learn_from from (accum as (parents, G)) =
+      try_graph ctxt "updating graph" accum (fn () =>
+        (from :: parents, Graph.add_edge_acyclic (from, name) G))
+    val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
+    val (parents, G) = ([], G) |> fold maybe_learn_from parents
+    val (deps, _) = ([], G) |> fold maybe_learn_from deps
   in
-    ((name, parents, feats, deps) :: learns, graph)
+    ((name, parents, feats, deps) :: learns, G)
   end
 
-fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
+(* Updates the node of fact "name" in the access graph G to (Automatic_Proof, feats, deps),
+   keeping its existing features, and queues a relearn command. The dependencies are
+   filtered by attempting to add an edge from each of them to "name"; the graph resulting
+   from those edges is discarded, so the returned G differs from the input only in the
+   node data.
+   NOTE(review): presumably "try_graph" returns "accum" unchanged when the edge cannot be
+   added -- confirm against its definition. *)
+fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
   let
-    fun maybe_relearn_from from (accum as (parents, graph)) =
+    fun maybe_relearn_from from (accum as (parents, G)) =
       try_graph ctxt "updating graph" accum (fn () =>
-        (from :: parents, Graph.add_edge_acyclic (from, name) graph))
-    val graph = graph |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
-    val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
+        (from :: parents, Graph.add_edge_acyclic (from, name) G))
+    val G = G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
+    val (deps, _) = ([], G) |> fold maybe_relearn_from deps
   in
-    ((name, deps) :: relearns, graph)
+    ((name, deps) :: relearns, G)
   end
 
 fun flop_wrt_access_graph name =