tuning
authorblanchet
Tue, 20 May 2014 16:11:37 +0200
changeset 57017 afdf75c0de58
parent 57016 c44ce6f4067d
child 57018 142950e9c7e2
tuning
src/HOL/Tools/ATP/atp_proof.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/ATP/atp_proof.ML	Tue May 20 16:00:00 2014 +0200
+++ b/src/HOL/Tools/ATP/atp_proof.ML	Tue May 20 16:11:37 2014 +0200
@@ -183,8 +183,7 @@
   |> find_first (fn (_, pattern) => String.isSubstring pattern output)
   |> Option.map fst
 
-fun extract_tstplike_proof_and_outcome verbose proof_delims known_failures
-                                       output =
+fun extract_tstplike_proof_and_outcome verbose proof_delims known_failures output =
   (case (extract_tstplike_proof proof_delims output,
       extract_known_atp_failure known_failures output) of
     (_, SOME ProofIncomplete) => ("", NONE)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:00:00 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:11:37 2014 +0200
@@ -284,22 +284,16 @@
 structure MaSh_SML =
 struct
 
-fun max a b = if a > b then a else b
-
 exception BOTTOM of int
 
 fun heap cmp bnd a =
   let
     fun maxson l i =
-      let
-        val i31 = i + i + i + 1
-      in
+      let val i31 = i + i + i + 1 in
         if i31 + 2 < l then
-          let
-            val x = Unsynchronized.ref i31;
-            val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
-            val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
-          in
+          let val x = Unsynchronized.ref i31 in
+            if cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS then x := i31 + 1 else ();
+            if cmp (Array.sub (a, !x), Array.sub (a, i31 + 2)) = LESS then x := i31 + 2 else ();
             !x
           end
         else
@@ -354,7 +348,7 @@
     val _ = for (((l + 1) div 3) - 1)
 
     fun for2 i =
-      if i < max 2 (l - bnd) then () else
+      if i < Integer.max 2 (l - bnd) then () else
       let
         val e = Array.sub (a, i)
         val _ = Array.update (a, i, Array.sub (a, 0))
@@ -387,51 +381,57 @@
 fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
   let
     (* Can be later used for TFIDF *)
-    fun sym_wght _ = 1.0;
-    val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)));
+    fun sym_wght _ = 1.0
+
+    val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)))
+
     fun inc_overlap j v =
       let
-        val ov = snd (Array.sub (overlaps_sqr,j))
+        val ov = snd (Array.sub (overlaps_sqr, j))
       in
         Array.update (overlaps_sqr, j, (j, v + ov))
-      end;
+      end
+
     fun do_sym (s, con_wght) =
       let
-        val sw = sym_wght s;
-        val w2 = sw * sw * con_wght;
+        val sw = sym_wght s
+        val w2 = sw * sw * con_wght
+
         fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
       in
-        ignore (map do_th (get_sym_ths s))
-      end;
-    val () = ignore (map do_sym syms);
-    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
-    val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0)));
+        List.app do_th (get_sym_ths s)
+      end
+
+    val _ = List.app do_sym syms
+    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
+    val recommends = Array.tabulate (adv_max, rpair 0.0)
+
     fun inc_recommend j v =
-      if j >= adv_max then () else
-      let
-        val ov = snd (Array.sub (recommends,j))
-      in
-        Array.update (recommends, j, (j, v + ov))
-      end;
+      if j >= adv_max then ()
+      else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
+
     fun for k =
-      if k = knns then () else
-      if k >= adv_max then () else
-      let
-        val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1);
-        val o1 = Math.sqrt o2;
-        val () = inc_recommend j o1;
-        val ds = get_deps j;
-        val l = Real.fromInt (length ds);
-        val _ = map (fn d => inc_recommend d (o1 / l)) ds
-      in
-        for (k + 1)
-      end;
-    val () = for 0;
-    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
+      if k = knns orelse k >= adv_max then
+        ()
+      else
+        let
+          val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1)
+          val o1 = Math.sqrt o2
+          val _ = inc_recommend j o1
+          val ds = get_deps j
+          val l = Real.fromInt (length ds)
+          val _ = map (fn d => inc_recommend d (o1 / l)) ds
+        in
+          for (k + 1)
+        end
+
+    val _ = for 0
+    val _ = heap (Real.compare o pairself snd) advno recommends
+
     fun ret acc at =
-      if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
+      if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   in
-    ret [] (max 0 (adv_max - advno))
+    ret [] (Integer.max 0 (adv_max - advno))
   end
 
 val knns = 40 (* FUDGE *)
@@ -440,7 +440,7 @@
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
 
-fun learn_and_query ctxt parents access_G max_suggs hints feats =
+fun query ctxt parents access_G max_suggs hints feats =
   let
     val str_of_feat = space_implode "|"
 
@@ -469,9 +469,9 @@
         all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
 
     val facts = rev rev_facts
-    val fact_ary = Array.fromList facts
+    val fact_vec = Vector.fromList facts
 
-    val deps_ary = Array.fromList (rev rev_depss)
+    val deps_vec = Vector.fromList (rev rev_depss)
     val facts_ary = Array.array (num_feats, [])
     val _ =
       fold (fn feats => fn fact =>
@@ -484,11 +484,11 @@
   in
     trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
       elide_string 1000 (space_implode " " facts) ^ "}");
-    knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary)
+    knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
       (curry Array.sub facts_ary) knns max_suggs
       (map_filter (fn (feat, weight) =>
          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
-    |> map ((fn i => Array.sub (fact_ary, i)) o fst)
+    |> map (curry Vector.sub fact_vec o fst)
   end
 
 end;
@@ -625,7 +625,7 @@
   Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
 
 fun clear_state ctxt overlord =
-  (* "unlearn" also removes the state file *)
+  (* "MaSh_Py.unlearn" also removes the state file *)
   Synchronized.change global_state (fn _ =>
     (if Config.get ctxt sml then wipe_out_mash_state_dir ()
      else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
@@ -971,9 +971,6 @@
     | NONE => false)
   | is_size_def _ _ = false
 
-fun no_dependencies_for_status status =
-  status = Non_Rec_Def orelse status = Rec_Def
-
 fun trim_dependencies deps =
   if length deps > max_dependencies then NONE else SOME deps
 
@@ -1022,18 +1019,17 @@
       val num_isar_deps = length isar_deps
     in
       if verbose andalso auto_level = 0 then
-        "MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^
-        " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts."
-        |> Output.urgent_message
+        Output.urgent_message ("MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^
+          string_of_int num_isar_deps ^ " + " ^ string_of_int (length facts - num_isar_deps) ^
+          " facts.")
       else
         ();
       (case run_prover_for_mash ctxt params prover name facts goal of
         {outcome = NONE, used_facts, ...} =>
         (if verbose andalso auto_level = 0 then
            let val num_facts = length used_facts in
-             "Found proof with " ^ string_of_int num_facts ^ " fact" ^
-             plural_s num_facts ^ "."
-             |> Output.urgent_message
+             Output.urgent_message ("Found proof with " ^ string_of_int num_facts ^ " fact" ^
+               plural_s num_facts ^ ".")
            end
          else
            ();
@@ -1187,40 +1183,57 @@
       |> features_of ctxt (theory_of_thm th) num_facts const_tab stature false
       |> map (apsnd (fn r => weight * factor * r))
 
-    val (access_G, suggs) =
+    fun query_args access_G =
+      let
+        val parents = maximal_wrt_access_graph access_G facts
+        val hints = chained
+          |> filter (is_fact_in_graph access_G o snd)
+          |> map (nickname_of_thm o snd)
+
+        val goal_feats =
+          features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
+        val chained_feats = chained
+          |> map (rpair 1.0)
+          |> map (chained_or_extra_features_of chained_feature_factor)
+          |> rpair [] |-> fold (union (eq_fst (op =)))
+        val extra_feats = facts
+          |> take (Int.max (0, num_extra_feature_facts - length chained))
+          |> filter fact_has_right_theory
+          |> weight_facts_steeply
+          |> map (chained_or_extra_features_of extra_feature_factor)
+          |> rpair [] |-> fold (union (eq_fst (op =)))
+        val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
+          |> debug ? sort (Real.compare o swap o pairself snd)
+      in
+        (parents, hints, feats)
+      end
+
+    val sml = Config.get ctxt sml
+
+    val (access_G, py_suggs) =
       peek_state ctxt overlord (fn {access_G, ...} =>
         if Graph.is_empty access_G then
           (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
         else
-          let
-            val parents = maximal_wrt_access_graph access_G facts
-            val goal_feats =
-              features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
-            val chained_feats = chained
-              |> map (rpair 1.0)
-              |> map (chained_or_extra_features_of chained_feature_factor)
-              |> rpair [] |-> fold (union (eq_fst (op =)))
-            val extra_feats = facts
-              |> take (Int.max (0, num_extra_feature_facts - length chained))
-              |> filter fact_has_right_theory
-              |> weight_facts_steeply
-              |> map (chained_or_extra_features_of extra_feature_factor)
-              |> rpair [] |-> fold (union (eq_fst (op =)))
-            val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
-              |> debug ? sort (Real.compare o swap o pairself snd)
-            val hints = chained
-              |> filter (is_fact_in_graph access_G o snd)
-              |> map (nickname_of_thm o snd)
-          in
-            (access_G,
-             if Config.get ctxt sml then
-               MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats
-             else
-               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats))
-          end)
+          (access_G,
+           if sml then
+             []
+           else
+             let val (parents, hints, feats) = query_args access_G in
+               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+             end))
+
+    val sml_suggs =
+      if sml then
+        let val (parents, hints, feats) = query_args access_G in
+          MaSh_SML.query ctxt parents access_G max_facts hints feats
+        end
+      else
+        []
+
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   in
-    find_mash_suggestions ctxt max_facts suggs facts chained unknown
+    find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown
     |> pairself (map fact_of_raw_fact)
   end
 
@@ -1323,7 +1336,7 @@
         val name_tabs = build_name_tables nickname_of_thm facts
 
         fun deps_of status th =
-          if no_dependencies_for_status status then
+          if status = Non_Rec_Def orelse status = Rec_Def then
             SOME []
           else if run_prover then
             prover_dependencies_of ctxt params prover auto_level facts name_tabs th
@@ -1355,18 +1368,13 @@
             end
 
         fun commit last learns relearns flops =
-          (if debug andalso auto_level = 0 then
-             Output.urgent_message "Committing..."
-           else
-             ();
+          (if debug andalso auto_level = 0 then Output.urgent_message "Committing..." else ();
            map_state ctxt overlord (do_commit (rev learns) relearns flops);
            if not last andalso auto_level = 0 then
              let val num_proofs = length learns + length relearns in
-               "Learned " ^ string_of_int num_proofs ^ " " ^
-               (if run_prover then "automatic" else "Isar") ^ " proof" ^
-               plural_s num_proofs ^ " in the last " ^
-               string_of_time commit_timeout ^ "."
-               |> Output.urgent_message
+               Output.urgent_message ("Learned " ^ string_of_int num_proofs ^ " " ^
+                 (if run_prover then "automatic" else "Isar") ^ " proof" ^
+                 plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout ^ ".")
              end
            else
              ())
@@ -1478,14 +1486,12 @@
       |> Output.urgent_message
   in
     if run_prover then
-      ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
-       " for automatic proofs (" ^ quote prover ^ " timeout: " ^ string_of_time timeout ^
-       ").\n\nCollecting Isar proofs first..."
-       |> Output.urgent_message;
+      (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
+         plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^ " timeout: " ^
+         string_of_time timeout ^ ").\n\nCollecting Isar proofs first...");
        learn 1 false;
-       "Now collecting automatic proofs. This may take several hours. You can safely stop the \
-       \learning process at any point."
-       |> Output.urgent_message;
+       Output.urgent_message "Now collecting automatic proofs. This may take several hours. You \
+         \can safely stop the learning process at any point.";
        learn 0 true)
     else
       (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^