fiddle with relevance filter
author blanchet
Thu, 26 Aug 2010 00:49:04 +0200
changeset 38747 b264ae66cede
parent 38746 9b465a288c62
child 38748 69fea359d3f8
fiddle with relevance filter
src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML	Wed Aug 25 19:47:25 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML	Thu Aug 26 00:49:04 2010 +0200
@@ -101,36 +101,40 @@
   | string_for_super_pseudoconst (s, Tss) =
     s ^ "{" ^ commas (map string_for_pseudotypes Tss) ^ "}"
 
-(*Add a const/type pair to the table, but a [] entry means a standard connective,
-  which we ignore.*)
-fun add_const_to_table (c, ctyps) =
-  Symtab.map_default (c, [ctyps])
-                     (fn [] => [] | ctypss => insert (op =) ctyps ctypss)
+val skolem_prefix = "Sledgehammer."
+
+(* Add a pseudoconstant to the table (Skolem ones only if "also_skolem"). A []
+   entry means a standard connective, which we ignore. *)
+fun add_pseudoconst_to_table also_skolem (c, ctyps) =
+  if also_skolem orelse not (String.isPrefix skolem_prefix c) then
+    Symtab.map_default (c, [ctyps])
+                       (fn [] => [] | ctypss => insert (op =) ctyps ctypss)
+  else
+    I
 
 fun is_formula_type T = (T = HOLogic.boolT orelse T = propT)
 
-val fresh_prefix = "Sledgehammer.skolem."
 val flip = Option.map not
 (* These are typically simplified away by "Meson.presimplify". *)
 val boring_consts =
   [@{const_name False}, @{const_name True}, @{const_name If}, @{const_name Let}]
 
-fun get_consts thy pos ts =
+fun get_pseudoconsts thy also_skolems pos ts =
   let
     (* We include free variables, as well as constants, to handle locales. For
        each quantifiers that must necessarily be skolemized by the ATP, we
        introduce a fresh constant to simulate the effect of Skolemization. *)
     fun do_term t =
       case t of
-        Const x => add_const_to_table (pseudoconst_for thy x)
-      | Free (s, _) => add_const_to_table (s, [])
+        Const x => add_pseudoconst_to_table also_skolems (pseudoconst_for thy x)
+      | Free (s, _) => add_pseudoconst_to_table also_skolems (s, [])
       | t1 $ t2 => fold do_term [t1, t2]
-      | Abs (_, _, t') => do_term t'
+      | Abs (_, _, t') => do_term t'  (* FIXME: add penalty? *)
       | _ => I
     fun do_quantifier will_surely_be_skolemized body_t =
       do_formula pos body_t
-      #> (if will_surely_be_skolemized then
-            add_const_to_table (gensym fresh_prefix, [])
+      #> (if also_skolems andalso will_surely_be_skolemized then
+            add_pseudoconst_to_table true (gensym skolem_prefix, [])
           else
             I)
     and do_term_or_formula T =
@@ -233,14 +237,20 @@
 
 (* "log" seems best in practice. A constant function of one ignores the constant
    frequencies. *)
-fun rel_log (x : real) = 1.0 + 2.0 / Math.ln (x + 1.0)
-fun irrel_log (x : real) = Math.ln (x + 19.0) / 6.4
+fun rel_log n = 1.0 + 2.0 / Math.ln (Real.fromInt n + 1.0)
+(* TODO: experiment
+fun irrel_log n = 0.5 + 1.0 / Math.ln (Real.fromInt n + 1.0)
+*)
+fun irrel_log n = Math.ln (Real.fromInt n + 19.0) / 6.4
 
 (* Computes a constant's weight, as determined by its frequency. *)
-val rel_weight = rel_log o real oo pseudoconst_freq match_pseudotypes
-val irrel_weight =
-  irrel_log o real oo pseudoconst_freq (match_pseudotypes o swap)
-(* fun irrel_weight _ _ = 1.0  FIXME: OLD CODE *)
+val rel_weight = rel_log oo pseudoconst_freq match_pseudotypes
+fun irrel_weight const_tab (c as (s, _)) =
+  if String.isPrefix skolem_prefix s then 1.0
+  else irrel_log (pseudoconst_freq (match_pseudotypes o swap) const_tab c)
+(* TODO: experiment
+fun irrel_weight _ _ = 1.0
+*)
 
 fun axiom_weight const_tab relevant_consts axiom_consts =
   case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
@@ -254,40 +264,60 @@
       val res = rel_weight / (rel_weight + irrel_weight)
     in if Real.isFinite res then res else 0.0 end
 
-fun consts_of_term thy t =
+(* TODO: experiment
+fun debug_axiom_weight const_tab relevant_consts axiom_consts =
+  case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
+                    ||> filter_out (pseudoconst_mem swap relevant_consts) of
+    ([], []) => 0.0
+  | (_, []) => 1.0
+  | (rel, irrel) =>
+    let
+val _ = tracing (PolyML.makestring ("REL: ", rel))
+val _ = tracing (PolyML.makestring ("IRREL: ", irrel))
+      val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
+      val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
+      val res = rel_weight / (rel_weight + irrel_weight)
+    in if Real.isFinite res then res else 0.0 end
+*)
+
+fun pseudoconsts_of_term thy t =
   Symtab.fold (fn (x, ys) => fold (fn y => cons (x, y)) ys)
-              (get_consts thy (SOME true) [t]) []
-
+              (get_pseudoconsts thy true (SOME true) [t]) []
 fun pair_consts_axiom theory_relevant thy axiom =
   (axiom, axiom |> snd |> theory_const_prop_of theory_relevant
-                |> consts_of_term thy)
+                |> pseudoconsts_of_term thy)
 
 type annotated_thm =
   ((unit -> string * bool) * thm) * (string * pseudotype list) list
 
-fun take_best max (candidates : (annotated_thm * real) list) =
+fun take_most_relevant max_max_imperfect max_relevant remaining_max
+                       (candidates : (annotated_thm * real) list) =
   let
-    val ((perfect, more_perfect), imperfect) =
-      candidates |> List.partition (fn (_, w) => w > 0.99999) |>> chop (max - 1)
+    val max_imperfect =
+      Real.ceil (Math.pow (max_max_imperfect,
+                           Real.fromInt remaining_max
+                           / Real.fromInt max_relevant))
+    val (perfect, imperfect) =
+      candidates |> List.partition (fn (_, w) => w > 0.99999)
                  ||> sort (Real.compare o swap o pairself snd)
-    val (accepts, rejects) =
-      case more_perfect @ imperfect of
-        [] => (perfect, [])
-      | (q :: qs) => (q :: perfect, qs)
+    val ((accepts, more_rejects), rejects) =
+      chop max_imperfect imperfect |>> append perfect |>> chop remaining_max
   in
     trace_msg (fn () => "Number of candidates: " ^
                         string_of_int (length candidates));
     trace_msg (fn () => "Effective threshold: " ^
                         Real.toString (#2 (hd accepts)));
-    trace_msg (fn () => "Actually passed: " ^
-        (accepts |> map (fn (((name, _), _), weight) =>
+    trace_msg (fn () => "Actually passed (" ^ Int.toString (length accepts) ^
+        "): " ^ (accepts
+                 |> map (fn (((name, _), _), weight) =>
                             fst (name ()) ^ " [" ^ Real.toString weight ^ "]")
                  |> commas));
-    (accepts, rejects)
+    (accepts, more_rejects @ rejects)
   end
 
 val threshold_divisor = 2.0
 val ridiculous_threshold = 0.1
+val max_max_imperfect_fudge_factor = 0.66
 
 fun relevance_filter ctxt threshold0 decay max_relevant theory_relevant
                      ({add, del, ...} : relevance_override) axioms goal_ts =
@@ -297,30 +327,35 @@
                          Symtab.empty
     val add_thms = maps (ProofContext.get_fact ctxt) add
     val del_thms = maps (ProofContext.get_fact ctxt) del
-    fun iter j max threshold rel_const_tab hopeless hopeful =
+    val max_max_imperfect =
+      Math.sqrt (Real.fromInt max_relevant * max_max_imperfect_fudge_factor)
+    fun iter j remaining_max threshold rel_const_tab hopeless hopeful =
       let
         fun game_over rejects =
-          if j = 0 andalso threshold >= ridiculous_threshold then
-            (* First iteration? Try again. *)
-            iter 0 max (threshold / threshold_divisor) rel_const_tab hopeless
-                 hopeful
+          (* Add "add:" facts. *)
+          if null add_thms then
+            []
           else
-            (* Add "add:" facts. *)
-            if null add_thms then
-              []
+            map_filter (fn ((p as (_, th), _), _) =>
+                           if member Thm.eq_thm add_thms th then SOME p
+                           else NONE) rejects
+        fun relevant [] rejects hopeless [] =
+            (* Nothing has been added this iteration. *)
+            if j = 0 andalso threshold >= ridiculous_threshold then
+              (* First iteration? Try again. *)
+              iter 0 max_relevant (threshold / threshold_divisor) rel_const_tab
+                   hopeless hopeful
             else
-              map_filter (fn ((p as (_, th), _), _) =>
-                             if member Thm.eq_thm add_thms th then SOME p
-                             else NONE) rejects
-        fun relevant [] rejects [] hopeless =
-            (* Nothing has been added this iteration. *)
-            game_over (map (apsnd SOME) (rejects @ hopeless))
-          | relevant candidates rejects [] hopeless =
+              game_over (rejects @ hopeless)
+          | relevant candidates rejects hopeless [] =
             let
-              val (accepts, more_rejects) = take_best max candidates
+              val (accepts, more_rejects) =
+                take_most_relevant max_max_imperfect max_relevant remaining_max
+                                   candidates
               val rel_const_tab' =
                 rel_const_tab
-                |> fold add_const_to_table (maps (snd o fst) accepts)
+                |> fold (add_pseudoconst_to_table false)
+                        (maps (snd o fst) accepts)
               fun is_dirty (c, _) =
                 Symtab.lookup rel_const_tab' c <> Symtab.lookup rel_const_tab c
               val (hopeful_rejects, hopeless_rejects) =
@@ -334,33 +369,41 @@
                              |> map (fn (ax as (_, consts), old_weight) =>
                                         (ax, if exists is_dirty consts then NONE
                                              else SOME old_weight)))
-              val threshold = threshold + (1.0 - threshold) * decay
-              val max = max - length accepts
+              val threshold =
+                threshold + (1.0 - threshold)
+                * Math.pow (decay, Real.fromInt (length accepts))
+              val remaining_max = remaining_max - length accepts
             in
               trace_msg (fn () => "New or updated constants: " ^
                   commas (rel_const_tab' |> Symtab.dest
                           |> subtract (op =) (Symtab.dest rel_const_tab)
                           |> map string_for_super_pseudoconst));
               map (fst o fst) accepts @
-              (if max = 0 then
+              (if remaining_max = 0 then
                  game_over (hopeful_rejects @ map (apsnd SOME) hopeless_rejects)
                else
-                 iter (j + 1) max threshold rel_const_tab' hopeless_rejects
-                      hopeful_rejects)
+                 iter (j + 1) remaining_max threshold rel_const_tab'
+                      hopeless_rejects hopeful_rejects)
             end
-          | relevant candidates rejects
+          | relevant candidates rejects hopeless
                      (((ax as ((name, th), axiom_consts)), cached_weight)
-                      :: hopeful) hopeless =
+                      :: hopeful) =
             let
               val weight =
                 case cached_weight of
                   SOME w => w
                 | NONE => axiom_weight const_tab rel_const_tab axiom_consts
+(* TODO: experiment
+val _ = if String.isPrefix "lift.simps(3" (fst (name ())) then
+tracing ("*** " ^ (fst (name ())) ^ PolyML.makestring (debug_axiom_weight const_tab rel_const_tab axiom_consts))
+else
+()
+*)
             in
               if weight >= threshold then
-                relevant ((ax, weight) :: candidates) rejects hopeful hopeless
+                relevant ((ax, weight) :: candidates) rejects hopeless hopeful
               else
-                relevant candidates ((ax, weight) :: rejects) hopeful hopeless
+                relevant candidates ((ax, weight) :: rejects) hopeless hopeful
             end
         in
           trace_msg (fn () =>
@@ -369,13 +412,13 @@
               commas (rel_const_tab |> Symtab.dest
                       |> filter (curry (op <>) [] o snd)
                       |> map string_for_super_pseudoconst));
-          relevant [] [] hopeful hopeless
+          relevant [] [] hopeless hopeful
         end
   in
     axioms |> filter_out (member Thm.eq_thm del_thms o snd)
            |> map (rpair NONE o pair_consts_axiom theory_relevant thy)
            |> iter 0 max_relevant threshold0
-                   (get_consts thy (SOME false) goal_ts) []
+                   (get_pseudoconsts thy false (SOME false) goal_ts) []
            |> tap (fn res => trace_msg (fn () =>
                                 "Total relevant: " ^ Int.toString (length res)))
   end