get rid of visibility in MaSh -- it slows it down more than it helps
author blanchet
Mon, 21 Jan 2019 22:29:41 +0100
changeset 69706 6d6235b828fc
parent 69705 c9ea1e9916fb
child 69707 920fe0a2fd22
get rid of visibility in MaSh -- it slows it down more than it helps
src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML
src/HOL/Tools/Sledgehammer/sledgehammer.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Mon Jan 21 21:28:16 2019 +0100
+++ b/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Mon Jan 21 22:29:41 2019 +0100
@@ -308,9 +308,9 @@
     | NONE => default_prover_name ())
   end
 
-fun get_prover ctxt name params goal all_facts =
+fun get_prover ctxt name params goal =
   let
-    val learn = Sledgehammer_MaSh.mash_learn_proof ctxt params (Thm.prop_of goal) all_facts
+    val learn = Sledgehammer_MaSh.mash_learn_proof ctxt params (Thm.prop_of goal)
   in
     Sledgehammer_Prover_Minimize.get_minimizing_prover ctxt Sledgehammer_Prover.Normal learn name
   end
@@ -429,7 +429,7 @@
                      "Line " ^ str0 (Position.line_of pos) ^ ": " ^
                      Sledgehammer.string_of_factss factss
                      |> writeln)
-        val prover = get_prover ctxt prover_name params goal facts
+        val prover = get_prover ctxt prover_name params goal
         val problem =
           {comment = "", state = st', goal = goal, subgoal = i,
            subgoal_count = Sledgehammer_Util.subgoal_count st, factss = factss, found_proof = I}
--- a/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Mon Jan 21 21:28:16 2019 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Mon Jan 21 22:29:41 2019 +0100
@@ -302,7 +302,7 @@
             val problem =
               {comment = "", state = state, goal = goal, subgoal = i, subgoal_count = n,
                factss = factss, found_proof = found_proof}
-            val learn = mash_learn_proof ctxt params (Thm.prop_of goal) all_facts
+            val learn = mash_learn_proof ctxt params (Thm.prop_of goal)
             val launch = launch_prover params mode writeln_result only learn
           in
             if mode = Auto_Try then
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon Jan 21 21:28:16 2019 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon Jan 21 22:29:41 2019 +0100
@@ -70,7 +70,7 @@
     raw_fact list -> fact list * fact list
 
   val mash_unlearn : Proof.context -> unit
-  val mash_learn_proof : Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit
+  val mash_learn_proof : Proof.context -> params -> term -> thm list -> unit
   val mash_learn_facts : Proof.context -> params -> string -> int -> bool -> Time.time ->
     raw_fact list -> string
   val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit
@@ -291,7 +291,7 @@
 structure MaSh =
 struct
 
-fun select_visible_facts (big_number : real) recommends =
+fun select_fact_idxs (big_number : real) recommends =
   List.app (fn at =>
     let val (j, ov) = Array.sub (recommends, at) in
       Array.update (recommends, at, (j, big_number + ov))
@@ -337,7 +337,7 @@
     (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
   end
 
-fun naive_bayes (tfreq, sfreq, dffreq) num_facts max_suggs visible_facts goal_feats =
+fun naive_bayes (tfreq, sfreq, dffreq) num_facts max_suggs fact_idxs goal_feats =
   let
     val tau = 0.2 (* FUDGE *)
     val pos_weight = 5.0 (* FUDGE *)
@@ -375,14 +375,14 @@
     fun ret at acc =
       if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc)
   in
-    select_visible_facts 100000.0 posterior visible_facts;
+    select_fact_idxs 100000.0 posterior fact_idxs;
     sort_array_suffix (Real.compare o apply2 snd) max_suggs posterior;
     ret (Integer.max 0 (num_facts - max_suggs)) []
   end
 
 val initial_k = 0
 
-fun k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts goal_feats =
+fun k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs fact_idxs goal_feats =
   let
     exception EXIT of unit
 
@@ -451,7 +451,7 @@
   in
     while1 ();
     while2 ();
-    select_visible_facts 1000000000.0 recommends visible_facts;
+    select_fact_idxs 1000000000.0 recommends fact_idxs;
     sort_array_suffix (Real.compare o apply2 snd) max_suggs recommends;
     ret [] (Integer.max 0 (num_facts - max_suggs))
   end
@@ -502,14 +502,13 @@
    | MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats))
 
 fun query_internal ctxt algorithm num_facts num_feats (fact_names, featss, depss)
-    (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
+    (freqs as (_, _, dffreq)) fact_idxs max_suggs goal_feats int_goal_feats =
   let
     fun nb () =
-      naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats
+      naive_bayes freqs num_facts max_suggs fact_idxs int_goal_feats
       |> map fst
     fun knn () =
-      k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts
-        int_goal_feats
+      k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs fact_idxs int_goal_feats
       |> map fst
   in
     (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
@@ -652,7 +651,7 @@
 
 local
 
-val version = "*** MaSh version 20161123 ***"
+val version = "*** MaSh version 20190121 ***"
 
 exception FILE_VERSION_TOO_NEW of unit
 
@@ -904,7 +903,7 @@
           #> union (op =)) S
 
     fun pattify_type 0 _ = []
-      | pattify_type depth (Type (s, [])) = if member (op =) bad_types s then [] else [s]
+      | pattify_type _ (Type (s, [])) = if member (op =) bad_types s then [] else [s]
       | pattify_type depth (Type (s, U :: Ts)) =
         let
           val T = Type (s, Ts)
@@ -930,7 +929,7 @@
       | add_subtypes T = add_type T
 
     fun pattify_term _ 0 _ = []
-      | pattify_term _ depth (Const (s, _)) =
+      | pattify_term _ _ (Const (s, _)) =
         if is_widely_irrelevant_const s then [] else [s]
       | pattify_term _ _ (Free (s, T)) =
         maybe_singleton_str (crude_str_of_typ T)
@@ -1136,64 +1135,6 @@
       |> drop (length old_facts)
     end
 
-fun maximal_wrt_graph _ [] = []
-  | maximal_wrt_graph G keys =
-    if can (Graph.get_node G o the_single) keys then
-      keys
-    else
-      let
-        val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
-
-        fun insert_new seen name = not (Symtab.defined seen name) ? insert (op =) name
-
-        fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
-
-        fun find_maxes _ (maxs, []) = map snd maxs
-          | find_maxes seen (maxs, new :: news) =
-            find_maxes (seen |> num_keys (Graph.imm_succs G new) > 1 ? Symtab.default (new, ()))
-              (if Symtab.defined tab new then
-                 let
-                   val newp = Graph.all_preds G [new]
-                   fun is_ancestor x yp = member (op =) yp x
-                   val maxs = maxs |> filter (fn (_, max) => not (is_ancestor max newp))
-                 in
-                   if exists (is_ancestor new o fst) maxs then (maxs, news)
-                   else ((newp, new) :: filter_out (fn (_, max) => is_ancestor max newp) maxs, news)
-                 end
-               else
-                 (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news))
-      in
-        find_maxes Symtab.empty ([],
-          G |> Graph.restrict (not o String.isPrefix anonymous_proof_prefix) |> Graph.maximals)
-      end
-
-val max_facts_for_shuffle_cleanup = 20
-
-fun maximal_wrt_access_graph _ [] = []
-  | maximal_wrt_access_graph access_G (fact :: facts) =
-    let
-      fun cleanup_wrt (_, th) =
-        let val thy_id = Thm.theory_id th in
-          filter_out (fn (_, th') =>
-            Context.proper_subthy_id (Thm.theory_id th', thy_id))
-        end
-
-      fun shuffle_cleanup accum [] = accum
-        | shuffle_cleanup accum (fact :: facts) =
-          let
-            val accum' = accum |> cleanup_wrt fact
-            val facts' = facts |> cleanup_wrt fact
-          in
-            shuffle_cleanup accum' facts'
-          end
-    in
-      fact :: cleanup_wrt fact facts
-      |> (fn facts => facts
-        |> length facts <= max_facts_for_shuffle_cleanup ? shuffle_cleanup [])
-      |> map (nickname_of_thm o snd)
-      |> maximal_wrt_graph access_G
-    end
-
 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
 
 val chained_feature_factor = 0.5 (* FUDGE *)
@@ -1257,8 +1198,7 @@
           fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats0)
           |> debug ? sort (Real.compare o swap o apply2 snd)
 
-        val parents = maximal_wrt_access_graph access_G facts
-        val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
+        val fact_idxs = map_filter (Symtab.lookup fact_tab o nickname_of_thm o snd) facts
 
         val suggs =
           if algorithm = MaSh_NB_Ext orelse algorithm = MaSh_kNN_Ext then
@@ -1274,8 +1214,8 @@
               val int_goal_feats =
                 map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
             in
-              MaSh.query_internal ctxt algorithm num_facts num_feats ffds freqs visible_facts
-                max_suggs goal_feats int_goal_feats
+              MaSh.query_internal ctxt algorithm num_facts num_feats ffds freqs fact_idxs max_suggs
+                goal_feats int_goal_feats
             end
 
         val unknown = filter_out (is_fact_in_graph access_G o snd) facts
@@ -1336,25 +1276,24 @@
   Date.fmt (anonymous_proof_prefix ^ "%Y%m%d.%H%M%S.") (Date.fromTimeLocal (Time.now ())) ^
   serial_string ()
 
-fun mash_learn_proof ctxt ({timeout, ...} : params) t facts used_ths =
+fun mash_learn_proof ctxt ({timeout, ...} : params) t used_ths =
   if not (null used_ths) andalso is_mash_enabled () then
     launch_thread timeout (fn () =>
       let
         val thy = Proof_Context.theory_of ctxt
         val feats = features_of ctxt (Context.theory_name thy) (Local, General) [t]
-        val facts = rev_sort_list_prefix (crude_thm_ord ctxt o apply2 snd) 1 facts
       in
         map_state ctxt
           (fn {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} =>
              let
-               val parents = maximal_wrt_access_graph access_G facts
                val deps = used_ths
                  |> filter (is_fact_in_graph access_G)
                  |> map nickname_of_thm
 
                val name = anonymous_proof_name ()
                val (access_G', xtabs', rev_learns) =
-                 add_node Automatic_Proof name parents feats deps (access_G, xtabs, [])
+                 add_node Automatic_Proof name [] (* ignore parents *) feats deps
+                   (access_G, xtabs, [])
 
                val (ffds', freqs') =
                  recompute_ffds_freqs_from_learns (rev rev_learns) xtabs' num_facts0 ffds freqs
@@ -1475,7 +1414,7 @@
                 let
                   val new_facts = facts
                     |> sort (crude_thm_ord ctxt o apply2 snd)
-                    |> attach_parents_to_facts []
+                    |> map (pair []) (* ignore parents *)
                     |> filter_out (is_in_access_G o snd)
                   val (learns, (num_nontrivial, _, _)) =
                     ([], (0, next_commit_time (), false))
@@ -1582,7 +1521,7 @@
 
 (* Generate more suggestions than requested, because some might be thrown out later for various
    reasons (e.g., duplicates). *)
-fun generous_max_suggestions max_facts = 3 * max_facts div 2 + 25
+fun generous_max_suggestions max_facts = 2 * max_facts + 25 (* FUDGE *)
 
 val mepo_weight = 0.5 (* FUDGE *)
 val mash_weight = 0.5 (* FUDGE *)