src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML
changeset 38939 f0aa0c49fdbf
parent 38938 2b93dbc07778
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML	Tue Aug 31 13:12:56 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML	Tue Aug 31 20:19:58 2010 +0200
@@ -14,6 +14,7 @@
 
   val trace : bool Unsynchronized.ref
   val worse_irrel_freq : real Unsynchronized.ref
+  val higher_order_irrel_weight : real Unsynchronized.ref
   val abs_rel_weight : real Unsynchronized.ref
   val abs_irrel_weight : real Unsynchronized.ref
   val skolem_irrel_weight : real Unsynchronized.ref
@@ -80,13 +81,22 @@
 
 (*** constants with types ***)
 
+fun order_of_type (Type (@{type_name fun}, [T1, @{typ bool}])) =
+    order_of_type T1 (* cheat: pretend sets are first-order *)
+  | order_of_type (Type (@{type_name fun}, [T1, T2])) =
+    Int.max (order_of_type T1 + 1, order_of_type T2)
+  | order_of_type (Type (_, Ts)) = fold (Integer.max o order_of_type) Ts 0
+  | order_of_type _ = 0
+
 (* An abstraction of Isabelle types and first-order terms *)
 datatype pattern = PVar | PApp of string * pattern list
+datatype ptype = PType of int * pattern list
 
 fun string_for_pattern PVar = "_"
   | string_for_pattern (PApp (s, ps)) =
     if null ps then s else s ^ string_for_patterns ps
 and string_for_patterns ps = "(" ^ commas (map string_for_pattern ps) ^ ")"
+fun string_for_ptype (PType (_, ps)) = string_for_patterns ps
 
 (*Is the second type an instance of the first one?*)
 fun match_pattern (PVar, _) = true
@@ -97,17 +107,18 @@
   | match_patterns ([], _) = false
   | match_patterns (p :: ps, q :: qs) =
     match_pattern (p, q) andalso match_patterns (ps, qs)
+fun match_ptype (PType (_, ps), PType (_, qs)) = match_patterns (ps, qs)
 
 (* Is there a unifiable constant? *)
 fun pconst_mem f consts (s, ps) =
-  exists (curry (match_patterns o f) ps)
+  exists (curry (match_ptype o f) ps)
          (map snd (filter (curry (op =) s o fst) consts))
 fun pconst_hyper_mem f const_tab (s, ps) =
-  exists (curry (match_patterns o f) ps) (these (Symtab.lookup const_tab s))
+  exists (curry (match_ptype o f) ps) (these (Symtab.lookup const_tab s))
 
-fun ptype (Type (s, Ts)) = PApp (s, map ptype Ts)
-  | ptype (TFree (s, _)) = PApp (s, [])
-  | ptype (TVar _) = PVar
+fun pattern_for_type (Type (s, Ts)) = PApp (s, map pattern_for_type Ts)
+  | pattern_for_type (TFree (s, _)) = PApp (s, [])
+  | pattern_for_type (TVar _) = PVar
 
 fun pterm thy t =
   case strip_comb t of
@@ -116,14 +127,17 @@
   | (Var x, []) => PVar
   | _ => PApp ("?", [])  (* equivalence class of higher-order constructs *)
 (* Pairs a constant with the list of its type instantiations. *)
-and pconst_args thy const (s, T) ts =
-  (if const then map ptype (these (try (Sign.const_typargs thy) (s, T)))
+and ptype thy const x ts =
+  (if const then map pattern_for_type (these (try (Sign.const_typargs thy) x))
    else []) @
   (if term_patterns then map (pterm thy) ts else [])
-and pconst thy const (s, T) ts = (s, pconst_args thy const (s, T) ts)
+and pconst thy const (s, T) ts = (s, ptype thy const (s, T) ts)
+and rich_ptype thy const (s, T) ts =
+  PType (order_of_type T, ptype thy const (s, T) ts)
+and rich_pconst thy const (s, T) ts = (s, rich_ptype thy const (s, T) ts)
 
-fun string_for_hyper_pconst (s, pss) =
-  s ^ "{" ^ commas (map string_for_patterns pss) ^ "}"
+fun string_for_hyper_pconst (s, ps) =
+  s ^ "{" ^ commas (map string_for_ptype ps) ^ "}"
 
 val abs_name = "Sledgehammer.abs"
 val skolem_prefix = "Sledgehammer.sko"
@@ -136,12 +150,12 @@
 
 (* Add a pconstant to the table, but a [] entry means a standard
    connective, which we ignore.*)
-fun add_pconst_to_table also_skolem (c, ps) =
+fun add_pconst_to_table also_skolem (c, p) =
   if member (op =) boring_consts c orelse
      (not also_skolem andalso String.isPrefix skolem_prefix c) then
     I
   else
-    Symtab.map_default (c, [ps]) (insert (op =) ps)
+    Symtab.map_default (c, [p]) (insert (op =) p)
 
 fun is_formula_type T = (T = HOLogic.boolT orelse T = propT)
 
@@ -152,38 +166,40 @@
        each quantifiers that must necessarily be skolemized by the ATP, we
        introduce a fresh constant to simulate the effect of Skolemization. *)
     fun do_const const (s, T) ts =
-      add_pconst_to_table also_skolems (pconst thy const (s, T) ts)
+      add_pconst_to_table also_skolems (rich_pconst thy const (s, T) ts)
       #> fold do_term ts
     and do_term t =
       case strip_comb t of
         (Const x, ts) => do_const true x ts
       | (Free x, ts) => do_const false x ts
-      | (Abs (_, _, t'), ts) =>
-        (null ts ? add_pconst_to_table true (abs_name, []))
+      | (Abs (_, T, t'), ts) =>
+        (null ts
+         ? add_pconst_to_table true (abs_name, PType (order_of_type T + 1, [])))
         #> fold do_term (t' :: ts)
       | (_, ts) => fold do_term ts
-    fun do_quantifier will_surely_be_skolemized body_t =
+    fun do_quantifier will_surely_be_skolemized abs_T body_t =
       do_formula pos body_t
       #> (if also_skolems andalso will_surely_be_skolemized then
-            add_pconst_to_table true (gensym skolem_prefix, [])
+            add_pconst_to_table true
+                         (gensym skolem_prefix, PType (order_of_type abs_T, []))
           else
             I)
     and do_term_or_formula T =
       if is_formula_type T then do_formula NONE else do_term
     and do_formula pos t =
       case t of
-        Const (@{const_name all}, _) $ Abs (_, _, body_t) =>
-        do_quantifier (pos = SOME false) body_t
+        Const (@{const_name all}, _) $ Abs (_, T, t') =>
+        do_quantifier (pos = SOME false) T t'
       | @{const "==>"} $ t1 $ t2 =>
         do_formula (flip pos) t1 #> do_formula pos t2
       | Const (@{const_name "=="}, Type (_, [T, _])) $ t1 $ t2 =>
         fold (do_term_or_formula T) [t1, t2]
       | @{const Trueprop} $ t1 => do_formula pos t1
       | @{const Not} $ t1 => do_formula (flip pos) t1
-      | Const (@{const_name All}, _) $ Abs (_, _, body_t) =>
-        do_quantifier (pos = SOME false) body_t
-      | Const (@{const_name Ex}, _) $ Abs (_, _, body_t) =>
-        do_quantifier (pos = SOME true) body_t
+      | Const (@{const_name All}, _) $ Abs (_, T, t') =>
+        do_quantifier (pos = SOME false) T t'
+      | Const (@{const_name Ex}, _) $ Abs (_, T, t') =>
+        do_quantifier (pos = SOME true) T t'
       | @{const HOL.conj} $ t1 $ t2 => fold (do_formula pos) [t1, t2]
       | @{const HOL.disj} $ t1 $ t2 => fold (do_formula pos) [t1, t2]
       | @{const HOL.implies} $ t1 $ t2 =>
@@ -193,14 +209,14 @@
       | Const (@{const_name If}, Type (_, [_, Type (_, [T, _])]))
         $ t1 $ t2 $ t3 =>
         do_formula NONE t1 #> fold (do_term_or_formula T) [t2, t3]
-      | Const (@{const_name Ex1}, _) $ Abs (_, _, body_t) =>
-        do_quantifier (is_some pos) body_t
-      | Const (@{const_name Ball}, _) $ t1 $ Abs (_, _, body_t) =>
-        do_quantifier (pos = SOME false)
-                      (HOLogic.mk_imp (incr_boundvars 1 t1 $ Bound 0, body_t))
-      | Const (@{const_name Bex}, _) $ t1 $ Abs (_, _, body_t) =>
-        do_quantifier (pos = SOME true)
-                      (HOLogic.mk_conj (incr_boundvars 1 t1 $ Bound 0, body_t))
+      | Const (@{const_name Ex1}, _) $ Abs (_, T, t') =>
+        do_quantifier (is_some pos) T t'
+      | Const (@{const_name Ball}, _) $ t1 $ Abs (_, T, t') =>
+        do_quantifier (pos = SOME false) T
+                      (HOLogic.mk_imp (incr_boundvars 1 t1 $ Bound 0, t'))
+      | Const (@{const_name Bex}, _) $ t1 $ Abs (_, T, t') =>
+        do_quantifier (pos = SOME true) T
+                      (HOLogic.mk_conj (incr_boundvars 1 t1 $ Bound 0, t'))
       | (t0 as Const (_, @{typ bool})) $ t1 =>
         do_term t0 #> do_formula pos t1  (* theory constant *)
       | _ => do_term t
@@ -230,16 +246,17 @@
   | (PApp _, PVar) => GREATER
   | (PApp q1, PApp q2) =>
     prod_ord fast_string_ord (dict_ord pattern_ord) (q1, q2)
+fun ptype_ord (PType p, PType q) =
+  prod_ord (dict_ord pattern_ord) int_ord (swap p, swap q)
 
-structure CTtab =
-  Table(type key = pattern list val ord = dict_ord pattern_ord)
+structure PType_Tab = Table(type key = ptype val ord = ptype_ord)
 
 fun count_axiom_consts theory_relevant thy =
   let
     fun do_const const (s, T) ts =
       (* Two-dimensional table update. Constant maps to types maps to count. *)
-      CTtab.map_default (pconst_args thy const (s, T) ts, 0) (Integer.add 1)
-      |> Symtab.map_default (s, CTtab.empty)
+      PType_Tab.map_default (rich_ptype thy const (s, T) ts, 0) (Integer.add 1)
+      |> Symtab.map_default (s, PType_Tab.empty)
       #> fold do_term ts
     and do_term t =
       case strip_comb t of
@@ -252,10 +269,14 @@
 
 (**** Actual Filtering Code ****)
 
+fun pow_int x 0 = 1.0
+  | pow_int x 1 = x
+  | pow_int x n = if n > 0 then x * pow_int x (n - 1) else pow_int x (n + 1) / x
+
 (*The frequency of a constant is the sum of those of all instances of its type.*)
 fun pconst_freq match const_tab (c, ps) =
-  CTtab.fold (fn (qs, m) => match (ps, qs) ? Integer.add m)
-             (the (Symtab.lookup const_tab c)) 0
+  PType_Tab.fold (fn (qs, m) => match (ps, qs) ? Integer.add m)
+                 (the (Symtab.lookup const_tab c)) 0
 
 
 (* A surprising number of theorems contain only a few significant constants.
@@ -264,19 +285,21 @@
 (* "log" seems best in practice. A constant function of one ignores the constant
    frequencies. Rare constants give more points if they are relevant than less
    rare ones. *)
-fun rel_weight_for_freq n = 1.0 + 2.0 / Math.ln (Real.fromInt n + 1.0)
+fun rel_weight_for order freq = 1.0 + 2.0 / Math.ln (Real.fromInt freq + 1.0)
 
 (* FUDGE *)
 val worse_irrel_freq = Unsynchronized.ref 100.0
+val higher_order_irrel_weight = Unsynchronized.ref 1.05
 
 (* Irrelevant constants are treated differently. We associate lower penalties to
    very rare constants and very common ones -- the former because they can't
    lead to the inclusion of too many new facts, and the latter because they are
    so common as to be of little interest. *)
-fun irrel_weight_for_freq n =
+fun irrel_weight_for order freq =
   let val (k, x) = !worse_irrel_freq |> `Real.ceil in
-    if n < k then Math.ln (Real.fromInt (n + 1)) / Math.ln x
-    else rel_weight_for_freq n / rel_weight_for_freq k
+    (if freq < k then Math.ln (Real.fromInt (freq + 1)) / Math.ln x
+     else rel_weight_for order freq / rel_weight_for order k)
+    * pow_int (!higher_order_irrel_weight) (order - 1)
   end
 
 (* FUDGE *)
@@ -285,17 +308,17 @@
 val skolem_irrel_weight = Unsynchronized.ref 0.75
 
 (* Computes a constant's weight, as determined by its frequency. *)
-fun generic_pconst_weight abs_weight skolem_weight weight_for_freq f const_tab
-                          (c as (s, _)) =
+fun generic_pconst_weight abs_weight skolem_weight weight_for f const_tab
+                          (c as (s, PType (m, _))) =
   if s = abs_name then abs_weight
   else if String.isPrefix skolem_prefix s then skolem_weight
-  else weight_for_freq (pconst_freq (match_patterns o f) const_tab c)
+  else weight_for m (pconst_freq (match_ptype o f) const_tab c)
 
 fun rel_pconst_weight const_tab =
-  generic_pconst_weight (!abs_rel_weight) 0.0 rel_weight_for_freq I const_tab
+  generic_pconst_weight (!abs_rel_weight) 0.0 rel_weight_for I const_tab
 fun irrel_pconst_weight const_tab =
   generic_pconst_weight (!abs_irrel_weight) (!skolem_irrel_weight)
-                        irrel_weight_for_freq swap const_tab
+                        irrel_weight_for swap const_tab
 
 (* FUDGE *)
 val intro_bonus = Unsynchronized.ref 0.15
@@ -340,7 +363,7 @@
         ~ (locality_bonus loc)
         |> fold (curry (op +) o irrel_pconst_weight const_tab) irrel
 val _ = tracing (PolyML.makestring ("REL: ", map (`(rel_pconst_weight const_tab)) rel))
-val _ = tracing (PolyML.makestring ("IRREL: ", map (`(irrel_pconst_weight const_tab)) irrel))(*###*)
+val _ = tracing (PolyML.makestring ("IRREL: ", map (`(irrel_pconst_weight const_tab)) irrel))
       val res = rels_weight / (rels_weight + irrels_weight)
     in if Real.isFinite res then res else 0.0 end
 *)
@@ -355,7 +378,7 @@
   | consts => SOME ((axiom, consts), NONE)
 
 type annotated_thm =
-  (((unit -> string) * locality) * thm) * (string * pattern list) list
+  (((unit -> string) * locality) * thm) * (string * ptype) list
 
 (* FUDGE *)
 val max_imperfect = Unsynchronized.ref 11.5