extended relevance filter with first-order term matching
author blanchet
Fri, 27 Aug 2010 15:39:17 +0200
changeset 38827 cf01645cbbce
parent 38826 f42f425edf24
child 38828 91ad85f962c4
extended relevance filter with first-order term matching
src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML	Fri Aug 27 15:37:03 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML	Fri Aug 27 15:39:17 2010 +0200
@@ -13,6 +13,7 @@
      only: bool}
 
   val trace : bool Unsynchronized.ref
+  val term_patterns : bool Unsynchronized.ref
   val name_thm_pairs_from_ref :
     Proof.context -> unit Symtab.table -> thm list -> Facts.ref
     -> ((string * locality) * thm) list
@@ -30,6 +31,8 @@
 val trace = Unsynchronized.ref false
 fun trace_msg msg = if !trace then tracing (msg ()) else ()
 
+val term_patterns = Unsynchronized.ref true
+
 val respect_no_atp = true
 
 datatype locality = General | Theory | Local | Chained
@@ -84,18 +87,29 @@
     match_pattern (p, q) andalso match_patterns (ps, qs)
 
 (* Is there a unifiable constant? *)
-fun pconst_mem f const_tab (s, ps) =
+fun pconst_mem f consts (s, ps) =
+  exists (curry (match_patterns o f) ps)
+         (map snd (filter (curry (op =) s o fst) consts))
+fun pconst_hyper_mem f const_tab (s, ps) =
   exists (curry (match_patterns o f) ps) (these (Symtab.lookup const_tab s))
 
-fun pattern_for_type (Type (s, Ts)) = PApp (s, map pattern_for_type Ts)
-  | pattern_for_type (TFree (s, _)) = PApp (s, [])
-  | pattern_for_type (TVar _) = PVar
+fun ptype (Type (s, Ts)) = PApp (s, map ptype Ts)
+  | ptype (TFree (s, _)) = PApp (s, [])
+  | ptype (TVar _) = PVar
+
+fun pterm thy t =
+  case strip_comb t of
+    (Const x, ts) => PApp (pconst thy true x ts)
+  | (Free x, ts) => PApp (pconst thy false x ts)
+  | (Var x, []) => PVar
+  | _ => PApp ("?", [])  (* equivalence class of higher-order constructs *)
 (* Pairs a constant with the list of its type instantiations. *)
-fun pconst_for thy (c, T) =
-  (c, map pattern_for_type (Sign.const_typargs thy (c, T)))
-  handle TYPE _ => (c, [])  (* Variable (locale constant): monomorphic *)
+and pconst_args thy const (s, T) ts =
+  (if const then map ptype (Sign.const_typargs thy (s, T)) else []) @
+  (if !term_patterns then map (pterm thy) ts else [])
+and pconst thy const (s, T) ts = (s, pconst_args thy const (s, T) ts)
 
-fun string_for_super_pconst (s, pss) =
+fun string_for_hyper_pconst (s, pss) =
   s ^ "{" ^ commas (map string_for_patterns pss) ^ "}"
 
 val abs_name = "Sledgehammer.abs"
@@ -124,14 +138,17 @@
     (* We include free variables, as well as constants, to handle locales. For
        each quantifiers that must necessarily be skolemized by the ATP, we
        introduce a fresh constant to simulate the effect of Skolemization. *)
-    fun do_term t =
-      case t of
-        Const x => add_pconst_to_table also_skolems (pconst_for thy x)
-      | Free (s, _) => add_pconst_to_table also_skolems (s, [])
-      | t1 $ t2 => fold do_term [t1, t2]
-      | Abs (_, _, t') =>
-        do_term t' #> add_pconst_to_table true (abs_name, [])
-      | _ => I
+    fun do_const const (s, T) ts =
+      add_pconst_to_table also_skolems (pconst thy const (s, T) ts)
+      #> fold do_term ts
+    and do_term t =
+      case strip_comb t of
+        (Const x, ts) => do_const true x ts
+      | (Free x, ts) => do_const false x ts
+      | (Abs (_, _, t'), ts) =>
+        null ts ? add_pconst_to_table true (abs_name, [])
+        #> fold do_term (t' :: ts)
+      | (_, ts) => fold do_term ts
     fun do_quantifier will_surely_be_skolemized body_t =
       do_formula pos body_t
       #> (if also_skolems andalso will_surely_be_skolemized then
@@ -204,21 +221,20 @@
 structure CTtab =
   Table(type key = pattern list val ord = dict_ord pattern_ord)
 
-fun count_axiom_consts theory_relevant thy (_, th) =
+fun count_axiom_consts theory_relevant thy =
   let
-    fun do_const (a, T) =
-      let val (c, cts) = pconst_for thy (a, T) in
-        (* Two-dimensional table update. Constant maps to types maps to
-           count. *)
-        CTtab.map_default (cts, 0) (Integer.add 1)
-        |> Symtab.map_default (c, CTtab.empty)
-      end
-    fun do_term (Const x) = do_const x
-      | do_term (Free x) = do_const x
-      | do_term (t $ u) = do_term t #> do_term u
-      | do_term (Abs (_, _, t)) = do_term t
-      | do_term _ = I
-  in th |> theory_const_prop_of theory_relevant |> do_term end
+    fun do_const const (s, T) ts =
+      (* Two-dimensional table update. Constant maps to types maps to count. *)
+      CTtab.map_default (pconst_args thy const (s, T) ts, 0) (Integer.add 1)
+      |> Symtab.map_default (s, CTtab.empty)
+      #> fold do_term ts
+    and do_term t =
+      case strip_comb t of
+        (Const x, ts) => do_const true x ts
+      | (Free x, ts) => do_const false x ts
+      | (Abs (_, _, t'), ts) => fold do_term (t' :: ts)
+      | (_, ts) => fold do_term ts
+  in do_term o theory_const_prop_of theory_relevant o snd end
 
 
 (**** Actual Filtering Code ****)
@@ -235,9 +251,6 @@
 (* "log" seems best in practice. A constant function of one ignores the constant
    frequencies. *)
 fun rel_log n = 1.0 + 2.0 / Math.ln (Real.fromInt n + 1.0)
-(* TODO: experiment
-fun irrel_log n = 0.5 + 1.0 / Math.ln (Real.fromInt n + 1.0)
-*)
 fun irrel_log n = Math.ln (Real.fromInt n + 19.0) / 6.4
 
 (* FUDGE *)
@@ -263,40 +276,30 @@
   | locality_multiplier Chained = 2.0
 
 fun axiom_weight loc const_tab relevant_consts axiom_consts =
-  case axiom_consts |> List.partition (pconst_mem I relevant_consts)
-                    ||> filter_out (pconst_mem swap relevant_consts) of
-    ([], []) => 0.0
-  | (_, []) => 1.0
+  case axiom_consts |> List.partition (pconst_hyper_mem I relevant_consts)
+                    ||> filter_out (pconst_hyper_mem swap relevant_consts) of
+    ([], _) => 0.0
   | (rel, irrel) =>
-    let
-      val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
-                       |> curry Real.* (locality_multiplier loc)
-      val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
-      val res = rel_weight / (rel_weight + irrel_weight)
-    in if Real.isFinite res then res else 0.0 end
-
-(* TODO: experiment
-fun debug_axiom_weight const_tab relevant_consts axiom_consts =
-  case axiom_consts |> List.partition (pconst_mem I relevant_consts)
-                    ||> filter_out (pconst_mem swap relevant_consts) of
-    ([], []) => 0.0
-  | (_, []) => 1.0
-  | (rel, irrel) =>
-    let
-val _ = tracing (PolyML.makestring ("REL: ", rel))
-val _ = tracing (PolyML.makestring ("IRREL: ", irrel))
-      val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
-      val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
-      val res = rel_weight / (rel_weight + irrel_weight)
-    in if Real.isFinite res then res else 0.0 end
-*)
+    case irrel |> filter_out (pconst_mem swap rel) of
+      [] => 1.0
+    | irrel =>
+      let
+        val rel_weight =
+          fold (curry Real.+ o rel_weight const_tab) rel 0.0
+          |> curry Real.* (locality_multiplier loc)
+        val irrel_weight =
+          fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
+        val res = rel_weight / (rel_weight + irrel_weight)
+      in if Real.isFinite res then res else 0.0 end
 
 fun pconsts_in_axiom thy t =
   Symtab.fold (fn (s, pss) => fold (cons o pair s) pss)
               (get_pconsts thy true (SOME true) [t]) []
 fun pair_consts_axiom theory_relevant thy axiom =
-  (axiom, axiom |> snd |> theory_const_prop_of theory_relevant
-                |> pconsts_in_axiom thy)
+  case axiom |> snd |> theory_const_prop_of theory_relevant
+             |> pconsts_in_axiom thy of
+    [] => NONE
+  | consts => SOME ((axiom, consts), NONE)
 
 type annotated_thm =
   (((unit -> string) * locality) * thm) * (string * pattern list) list
@@ -314,12 +317,8 @@
     val ((accepts, more_rejects), rejects) =
       chop max_imperfect imperfect |>> append perfect |>> chop remaining_max
   in
-    trace_msg (fn () => "Number of candidates: " ^
-                        string_of_int (length candidates));
-    trace_msg (fn () => "Effective threshold: " ^
-                        Real.toString (#2 (List.last accepts)));
     trace_msg (fn () => "Actually passed (" ^ Int.toString (length accepts) ^
-        "): " ^ (accepts
+        " of " ^ Int.toString (length candidates) ^ "): " ^ (accepts
                  |> map (fn ((((name, _), _), _), weight) =>
                             name () ^ " [" ^ Real.toString weight ^ "]")
                  |> commas));
@@ -401,7 +400,7 @@
               trace_msg (fn () => "New or updated constants: " ^
                   commas (rel_const_tab' |> Symtab.dest
                           |> subtract (op =) (rel_const_tab |> Symtab.dest)
-                          |> map string_for_super_pconst));
+                          |> map string_for_hyper_pconst));
               map (fst o fst) accepts @
               (if remaining_max = 0 then
                  game_over (hopeful_rejects @ map (apsnd SOME) hopeless_rejects)
@@ -436,12 +435,12 @@
               Real.toString threshold ^ ", constants: " ^
               commas (rel_const_tab |> Symtab.dest
                       |> filter (curry (op <>) [] o snd)
-                      |> map string_for_super_pconst));
+                      |> map string_for_hyper_pconst));
           relevant [] [] hopeless hopeful
         end
   in
     axioms |> filter_out (member Thm.eq_thm del_thms o snd)
-           |> map (rpair NONE o pair_consts_axiom theory_relevant thy)
+           |> map_filter (pair_consts_axiom theory_relevant thy)
            |> iter 0 max_relevant threshold0 goal_const_tab []
            |> tap (fn res => trace_msg (fn () =>
                                 "Total relevant: " ^ Int.toString (length res)))