src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57012 43fd82a537a3
parent 57011 a4428f517f46
child 57013 ed95456499e6
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 00:13:31 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 02:47:23 2014 +0200
@@ -108,6 +108,12 @@
 val relearn_isarN = "relearn_isar"
 val relearn_proverN = "relearn_prover"
 
+val learned_proof_prefix = ".."  (* marks names generated for learned proofs *)
+
+(* name: the prefix, a local-time stamp, then serial_string () *)
+fun learned_proof_name () = learned_proof_prefix ^
+  Date.fmt "%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^ serial_string ()
+
 fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
 
@@ -297,8 +303,8 @@
         if i31 + 2 < l then
           let
             val x = Unsynchronized.ref i31;
-            val () = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
-            val () = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
+            val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
+            val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
           in
             !x
           end
@@ -312,7 +318,7 @@
         val j = maxson l i
       in
         if cmp (Array.sub (a, j), e) = GREATER then
-          let val () = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
+          let val _ = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
         else Array.update (a, i, e)
       end
 
@@ -321,7 +327,7 @@
     fun bubbledown l i =
       let
         val j = maxson l i
-        val () = Array.update (a, i, Array.sub (a, j))
+        val _ = Array.update (a, i, Array.sub (a, j))
       in
         bubbledown l j
       end
@@ -334,7 +340,7 @@
       in
         if cmp (Array.sub (a, father), e) = LESS then
           let
-            val () = Array.update (a, i, Array.sub (a, father))
+            val _ = Array.update (a, i, Array.sub (a, father))
           in
             if father > 0 then trickleup father e else Array.update (a, 0, e)
           end
@@ -351,24 +357,24 @@
         for (i - 1)
       end
 
-    val () = for (((l + 1) div 3) - 1)
+    val _ = for (((l + 1) div 3) - 1)
 
     fun for2 i =
       if i < max 2 (l - bnd) then () else
       let
         val e = Array.sub (a, i)
-        val () = Array.update (a, i, Array.sub (a, 0))
-        val () = trickleup (bubble i 0) e
+        val _ = Array.update (a, i, Array.sub (a, 0))
+        val _ = trickleup (bubble i 0) e
       in
         for2 (i - 1)
       end
 
-    val () = for2 (l - 1)
+    val _ = for2 (l - 1)
   in
     if l > 1 then
       let
         val e = Array.sub (a, 1)
-        val () = Array.update (a, 1, Array.sub (a, 0))
+        val _ = Array.update (a, 1, Array.sub (a, 0))
       in
         Array.update (a, 0, e)
       end
@@ -386,46 +392,54 @@
 fun knn avail_no get_deps get_sym_ths knns advno syms =
   let
     (* Can be later used for TFIDF *)
-    fun sym_wght _ = 1.0;
-    val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)));
+    fun sym_wght _ = 1.0
+
+    val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)))
+
     fun inc_overlap j v =
       let
         val ov = snd (Array.sub (overlaps_sqr,j))
       in
         Array.update (overlaps_sqr, j, (j, v + ov))
-      end;
+      end
+
     fun do_sym (s, con_wght) =
       let
-        val sw = sym_wght s;
-        val w2 = sw * sw * con_wght;
+        val sw = sym_wght s
+        val w2 = sw * sw * con_wght
         fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else ()
       in
         ignore (map do_th (get_sym_ths s))
-      end;
-    val () = ignore (map do_sym syms);
-    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
-    val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)));
+      end
+
+    val _ = ignore (map do_sym syms)
+    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
+    val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)))
+
     fun inc_recommend j v =
       let
         val ov = snd (Array.sub (recommends,j))
       in
         Array.update (recommends, j, (j, v + ov))
-      end;
+      end
+
     fun for k =
       if k = knns then () else
       if k >= avail_no then () else
       let
-        val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1);
-        val o1 = Math.sqrt o2;
-        val () = inc_recommend j o1;
-        val ds = get_deps j;
-        val l = Real.fromInt (length ds);
+        val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1)
+        val o1 = Math.sqrt o2
+        val _ = inc_recommend j o1
+        val ds = get_deps j
+        val l = Real.fromInt (length ds)
         val _ = map (fn d => inc_recommend d (o1 / l)) ds
       in
         for (k + 1)
-      end;
-    val () = for 0;
-    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
+      end
+
+    val _ = for 0
+    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends
+
     fun ret acc at =
       if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
   in
@@ -473,14 +487,15 @@
           end)
         featss (length featss)
   in
-    (trace_msg ctxt (fn () =>
-       "MaSh_SML query " ^ encode_features feats ^ " from {" ^
-        elide_string 1000 (space_implode " " facts) ^ "}");
-     knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
-       max_suggs
-       (map_filter (fn (feat, weight) =>
-          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
-     |> map ((fn i => Array.sub (fact_ary, i)) o fst))
+    trace_msg ctxt (fn () =>
+      "MaSh_SML query " ^ encode_features feats ^ " from {" ^
+       elide_string 1000 (space_implode " " facts) ^ "}");
+    knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
+      max_suggs
+      (map_filter (fn (feat, weight) =>
+         Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
+    |> map ((fn i => Array.sub (fact_ary, i)) o fst)
+    |> filter_out (String.isPrefix learned_proof_prefix)
   end
 
 end;
@@ -502,10 +517,9 @@
   Graph.default_node (parent, (Isar_Proof, [], []))
   #> Graph.add_edge (parent, name)
 
-fun add_node kind name parents feats deps =
+fun add_node kind name feats deps =  (* records (kind, feats, deps) for name; adds no edges *)
   Graph.default_node (name, (kind, feats, deps))
   #> Graph.map_node name (K (kind, feats, deps))
-  #> fold (add_edge_to name) parents;
 
 fun try_graph ctxt when def f =
   f ()
@@ -526,7 +540,6 @@
 fun graph_info G =
   string_of_int (length (Graph.keys G)) ^ " node(s), " ^
   string_of_int (fold (Integer.add o length o snd) (Graph.dest G) 0) ^ " edge(s), " ^
-  string_of_int (length (Graph.minimals G)) ^ " minimal, " ^
   string_of_int (length (Graph.maximals G)) ^ " maximal"
 
 type mash_state =
@@ -562,7 +575,9 @@
            fun extract_line_and_add_node line =
              (case extract_node line of
                NONE => I (* should not happen *)
-             | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
+             | SOME (kind, name, parents, feats, deps) =>
+               add_node kind name feats deps
+               #> fold (add_edge_to name) parents)
 
            val (access_G, num_known_facts) =
              (case string_ord (version', version) of
@@ -1095,40 +1110,36 @@
   let
     val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
 
-    fun insert_new seen name =
-      not (Symtab.defined seen name) ? insert (op =) name
+    fun insert_new seen name = not (Symtab.defined seen name) ? insert (op =) name
 
     fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
 
     fun find_maxes _ (maxs, []) = map snd maxs
       | find_maxes seen (maxs, new :: news) =
-        find_maxes
-            (seen |> num_keys (Graph.imm_succs G new) > 1
-                     ? Symtab.default (new, ()))
-            (if Symtab.defined tab new then
-               let
-                 val newp = Graph.all_preds G [new]
-                 fun is_ancestor x yp = member (op =) yp x
-                 val maxs =
-                   maxs |> filter (fn (_, max) => not (is_ancestor max newp))
-               in
-                 if exists (is_ancestor new o fst) maxs then
-                   (maxs, news)
-                 else
-                   ((newp, new)
-                    :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
-                    news)
-               end
-             else
-               (maxs, Graph.Keys.fold (insert_new seen)
-                                      (Graph.imm_preds G new) news))
+        find_maxes (seen |> num_keys (Graph.imm_succs G new) > 1 ? Symtab.default (new, ()))
+          (if Symtab.defined tab new then
+             let
+               val newp = Graph.all_preds G [new]
+               fun is_ancestor x yp = member (op =) yp x
+               val maxs = maxs |> filter (fn (_, max) => not (is_ancestor max newp))
+             in
+               if exists (is_ancestor new o fst) maxs then (maxs, news)
+               else ((newp, new) :: filter_out (fn (_, max) => is_ancestor max newp) maxs, news)
+             end
+           else
+             (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news))
   in
     find_maxes Symtab.empty ([], Graph.maximals G)
   end
 
-fun maximal_wrt_access_graph access_G =
-  map (nickname_of_thm o snd)
-  #> maximal_wrt_graph access_G
+(* keys of G whose predecessor and successor sets are both empty *)
+fun graph_islands G =
+  Graph.fold (fn (m, (_, (preds, succs))) =>
+    (Graph.Keys.is_empty preds andalso Graph.Keys.is_empty succs) ? cons m) G []
+
+(* islands represent learned proofs associated with no facts *)
+fun maximal_wrt_access_graph access_G facts =
+  maximal_wrt_graph access_G (map (nickname_of_thm o snd) facts @ graph_islands access_G)
 
 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
 
@@ -1264,9 +1275,6 @@
     Async_Manager.thread MaShN birth_time death_time desc task
   end
 
-fun fresh_enough_name () =
-  Date.fmt ".%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^ serial_string ()
-
 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths =
   if is_mash_enabled () then
     launch_thread timeout (fn () =>
@@ -1276,18 +1284,20 @@
       in
         map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} =>
           let
-            val name = fresh_enough_name ()
-            val parents = maximal_wrt_access_graph access_G facts
+            val name = learned_proof_name ()
             val deps = used_ths
               |> filter (is_fact_in_graph access_G)
               |> map nickname_of_thm
           in
             if Config.get ctxt sml then
-              {access_G = add_node Automatic_Proof name parents feats deps access_G,
-               num_known_facts = num_known_facts + 1,
-               dirty = Option.map (cons name) dirty}
+              let val access_G = access_G |> add_node Automatic_Proof name feats deps in
+                {access_G = access_G, num_known_facts = num_known_facts + 1,
+                 dirty = Option.map (cons name) dirty}
+              end
             else
-              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
+              let val parents = maximal_wrt_access_graph access_G facts in
+                (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
+              end
           end);
         (true, "")
       end)