speed up MaSh a bit
authorblanchet
Tue, 01 Jul 2014 16:47:10 +0200
changeset 57460 9cc802a8ab06
parent 57459 22023ab4df3c
child 57461 29efe682335b
speed up MaSh a bit
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jul 01 16:47:10 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jul 01 16:47:10 2014 +0200
@@ -187,7 +187,7 @@
       fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
     in
       fold (union fact_eq o map fst o take max_facts o fst o snd) mess []
-      |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
+      |> map (`weight_of) |> sort (int_ord o pairself fst o swap)
       |> map snd |> take max_facts
     end
 
@@ -197,16 +197,12 @@
 fun weight_facts_smoothly facts = facts ~~ map smooth_weight_of_fact (0 upto length facts - 1)
 fun weight_facts_steeply facts = facts ~~ map steep_weight_of_fact (0 upto length facts - 1)
 
-
-(*** Isabelle-agnostic machine learning ***)
-
-structure MaSh =
-struct
-
-fun heap cmp bnd al a =
+fun rev_sort_array_prefix cmp bnd a =
   let
     exception BOTTOM of int
 
+    val al = Array.length a
+
     fun maxson l i =
       let val i31 = i + i + i + 1 in
         if i31 + 2 < l then
@@ -270,6 +266,18 @@
       ()
   end
 
+fun rev_sort_list_prefix cmp bnd xs =
+  let val ary = Array.fromList xs in
+    rev_sort_array_prefix cmp bnd ary;
+    Array.foldr (op ::) [] ary
+  end
+
+
+(*** Isabelle-agnostic machine learning ***)
+
+structure MaSh =
+struct
+
 fun select_visible_facts big_number recommends =
   List.app (fn at =>
     let val (j, ov) = Array.sub (recommends, at) in
@@ -354,7 +362,7 @@
       if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc)
   in
     select_visible_facts 100000.0 posterior visible_facts;
-    heap (Real.compare o pairself snd) max_suggs num_facts posterior;
+    rev_sort_array_prefix (Real.compare o pairself snd) max_suggs posterior;
     ret (Integer.max 0 (num_facts - max_suggs)) []
   end
 
@@ -387,7 +395,7 @@
       end
 
     val _ = List.app do_feat goal_feats
-    val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
+    val _ = rev_sort_array_prefix (Real.compare o pairself snd) num_facts overlaps_sqr
     val no_recommends = Unsynchronized.ref 0
     val recommends = Array.tabulate (num_facts, rpair 0.0)
     val age = Unsynchronized.ref 500000000.0
@@ -432,7 +440,7 @@
     while1 ();
     while2 ();
     select_visible_facts 1000000000.0 recommends visible_facts;
-    heap (Real.compare o pairself snd) max_suggs num_facts recommends;
+    rev_sort_array_prefix (Real.compare o pairself snd) max_suggs recommends;
     ret [] (Integer.max 0 (num_facts - max_suggs))
   end
 
@@ -1110,9 +1118,13 @@
     find_maxes Symtab.empty ([], Graph.maximals G)
   end
 
-fun maximal_wrt_access_graph access_G facts =
-  map (nickname_of_thm o snd) facts
-  |> maximal_wrt_graph access_G
+fun maximal_wrt_access_graph _ [] = []
+  | maximal_wrt_access_graph access_G ((fact as (_, th)) :: facts) =
+    let val thy = theory_of_thm th in
+      fact :: filter_out (fn (_, th') => Theory.subthy (theory_of_thm th', thy)) facts
+      |> map (nickname_of_thm o snd)
+      |> maximal_wrt_graph access_G
+    end
 
 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
 
@@ -1144,8 +1156,11 @@
     val thy_name = Context.theory_name thy
     val engine = the_mash_engine ()
 
-    val facts = facts |> sort (crude_thm_ord o pairself snd o swap)
-    val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
+    val facts = facts
+      |> rev_sort_list_prefix (crude_thm_ord o pairself snd)
+        (Int.max (num_extra_feature_facts, max_proximity_facts))
+
+    val chained = filter (fn ((_, (scope, _)), _) => scope = Chained) facts
 
     fun fact_has_right_theory (_, th) =
       thy_name = Context.theory_name (theory_of_thm th)
@@ -1155,53 +1170,44 @@
       |> features_of ctxt (theory_of_thm th) stature
       |> map (rpair (weight * factor))
 
-    fun query_args access_G =
-      let
-        val parents = maximal_wrt_access_graph access_G facts
-
-        val goal_feats = features_of ctxt thy (Local, General) (concl_t :: hyp_ts)
-        val chained_feats = chained
-          |> map (rpair 1.0)
-          |> map (chained_or_extra_features_of chained_feature_factor)
-          |> rpair [] |-> fold (union (eq_fst (op =)))
-        val extra_feats =
-          facts
-          |> take (Int.max (0, num_extra_feature_facts - length chained))
-          |> filter fact_has_right_theory
-          |> weight_facts_steeply
-          |> map (chained_or_extra_features_of extra_feature_factor)
-          |> rpair [] |-> fold (union (eq_fst (op =)))
-        val feats =
-          fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats)
-          |> debug ? sort (Real.compare o swap o pairself snd)
-      in
-        (parents, feats)
-      end
-
     val {access_G, xtabs = ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs, ...} =
       peek_state ctxt
 
+    val goal_feats0 = features_of ctxt thy (Local, General) (concl_t :: hyp_ts)
+    val chained_feats = chained
+      |> map (rpair 1.0)
+      |> map (chained_or_extra_features_of chained_feature_factor)
+      |> rpair [] |-> fold (union (eq_fst (op =)))
+    val extra_feats = facts
+      |> take (Int.max (0, num_extra_feature_facts - length chained))
+      |> filter fact_has_right_theory
+      |> weight_facts_steeply
+      |> map (chained_or_extra_features_of extra_feature_factor)
+      |> rpair [] |-> fold (union (eq_fst (op =)))
+
+    val goal_feats =
+      fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats0)
+      |> debug ? sort (Real.compare o swap o pairself snd)
+
+    val parents = maximal_wrt_access_graph access_G facts
+    val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
+
     val suggs =
-      let
-        val (parents, goal_feats) = query_args access_G
-        val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
-      in
-        if engine = MaSh_NB_Ext orelse engine = MaSh_kNN_Ext then
-          let
-            val learns =
-              Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
-          in
-            MaSh.query_external ctxt engine max_suggs learns goal_feats
-          end
-        else
-          let
-            val int_goal_feats =
-              map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
-          in
-            MaSh.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts max_suggs
-              goal_feats int_goal_feats
-          end
-      end
+      if engine = MaSh_NB_Ext orelse engine = MaSh_kNN_Ext then
+        let
+          val learns =
+            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+        in
+          MaSh.query_external ctxt engine max_suggs learns goal_feats
+        end
+      else
+        let
+          val int_goal_feats =
+            map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
+        in
+          MaSh.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts max_suggs
+            goal_feats int_goal_feats
+        end
 
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   in
@@ -1264,6 +1270,7 @@
       let
         val thy = Proof_Context.theory_of ctxt
         val feats = features_of ctxt thy (Local, General) [t]
+        val facts = rev_sort_list_prefix (crude_thm_ord o pairself snd) 1 facts
       in
         map_state ctxt
           (fn {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} =>
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML	Tue Jul 01 16:47:10 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML	Tue Jul 01 16:47:10 2014 +0200
@@ -383,7 +383,7 @@
                     val atp_proof =
                       atp_proof
                       |> termify_atp_proof ctxt name format type_enc pool lifted sym_tab
-                      |> introduce_spass_skolem
+                      |> spass ? introduce_spass_skolem
                       |> factify_atp_proof (map fst used_from) hyp_ts concl_t
                   in
                     (verbose, (metis_type_enc, metis_lam_trans), preplay_timeout, compress, try0,