--- a/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML Wed Aug 25 09:42:28 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML Wed Aug 25 17:49:52 2010 +0200
@@ -11,11 +11,11 @@
only: bool}
val trace : bool Unsynchronized.ref
- val name_thms_pair_from_ref :
+ val name_thm_pairs_from_ref :
Proof.context -> unit Symtab.table -> thm list -> Facts.ref
- -> (unit -> string * bool) * thm list
+ -> ((unit -> string * bool) * (bool * thm)) list
val relevant_facts :
- bool -> real -> real -> int -> bool -> relevance_override
+ bool -> real -> real option -> int -> bool -> relevance_override
-> Proof.context * (thm list * 'a) -> term list -> term
-> ((string * bool) * thm) list
end;
@@ -37,13 +37,22 @@
val sledgehammer_prefix = "Sledgehammer" ^ Long_Name.separator
-fun name_thms_pair_from_ref ctxt reserved chained_ths xref =
- let val ths = ProofContext.get_fact ctxt xref in
- (fn () => let
- val name = Facts.string_of_ref xref
- val name = name |> Symtab.defined reserved name ? quote
- val chained = forall (member Thm.eq_thm chained_ths) ths
- in (name, chained) end, ths)
+fun repair_name reserved multi j name =
+ (name |> Symtab.defined reserved name ? quote) ^
+ (if multi then "(" ^ Int.toString j ^ ")" else "")
+
+fun name_thm_pairs_from_ref ctxt reserved chained_ths xref =
+ let
+ val ths = ProofContext.get_fact ctxt xref
+ val name = Facts.string_of_ref xref
+ val multi = length ths > 1
+ in
+ fold (fn th => fn (j, rest) =>
+ (j + 1, (fn () => (repair_name reserved multi j name,
+ member Thm.eq_thm chained_ths th),
+ (multi, th)) :: rest))
+ ths (1, [])
+ |> snd
end
(***************************************************************)
@@ -53,30 +62,44 @@
(*** constants with types ***)
(*An abstraction of Isabelle types*)
-datatype const_typ = CTVar | CType of string * const_typ list
+datatype pseudotype = PVar | PType of string * pseudotype list
+
+fun string_for_pseudotype PVar = "?"
+ | string_for_pseudotype (PType (s, Ts)) =
+ (case Ts of
+ [] => ""
+ | [T] => string_for_pseudotype T
+ | Ts => string_for_pseudotypes Ts ^ " ") ^ s
+and string_for_pseudotypes Ts =
+ "(" ^ commas (map string_for_pseudotype Ts) ^ ")"
(*Is the second type an instance of the first one?*)
-fun match_type (CType(con1,args1)) (CType(con2,args2)) =
- con1=con2 andalso match_types args1 args2
- | match_type CTVar _ = true
- | match_type _ CTVar = false
-and match_types [] [] = true
- | match_types (a1::as1) (a2::as2) = match_type a1 a2 andalso match_types as1 as2;
+fun match_pseudotype (PType (a, T), PType (b, U)) =
+ a = b andalso match_pseudotypes (T, U)
+ | match_pseudotype (PVar, _) = true
+ | match_pseudotype (_, PVar) = false
+and match_pseudotypes ([], []) = true
+ | match_pseudotypes (T :: Ts, U :: Us) =
+ match_pseudotype (T, U) andalso match_pseudotypes (Ts, Us)
(*Is there a unifiable constant?*)
-fun const_mem const_tab (c, c_typ) =
- exists (match_types c_typ) (these (Symtab.lookup const_tab c))
+fun pseudoconst_mem f const_tab (c, c_typ) =
+ exists (curry (match_pseudotypes o f) c_typ)
+ (these (Symtab.lookup const_tab c))
-(*Maps a "real" type to a const_typ*)
-fun const_typ_of (Type (c,typs)) = CType (c, map const_typ_of typs)
- | const_typ_of (TFree _) = CTVar
- | const_typ_of (TVar _) = CTVar
+fun pseudotype_for (Type (c,typs)) = PType (c, map pseudotype_for typs)
+ | pseudotype_for (TFree _) = PVar
+ | pseudotype_for (TVar _) = PVar
+(* Pairs a constant with the list of its type instantiations. *)
+fun pseudoconst_for thy (c, T) =
+ (c, map pseudotype_for (Sign.const_typargs thy (c, T)))
+ handle TYPE _ => (c, []) (* Variable (locale constant): monomorphic *)
-(*Pairs a constant with the list of its type instantiations (using const_typ)*)
-fun const_with_typ thy (c,typ) =
- let val tvars = Sign.const_typargs thy (c,typ) in
- (c, map const_typ_of tvars) end
- handle TYPE _ => (c, []) (*Variable (locale constant): monomorphic*)
+fun string_for_pseudoconst (s, []) = s
+ | string_for_pseudoconst (s, Ts) = s ^ string_for_pseudotypes Ts
+fun string_for_super_pseudoconst (s, [[]]) = s
+ | string_for_super_pseudoconst (s, Tss) =
+ s ^ "{" ^ commas (map string_for_pseudotypes Tss) ^ "}"
(*Add a const/type pair to the table, but a [] entry means a standard connective,
which we ignore.*)
@@ -86,7 +109,7 @@
fun is_formula_type T = (T = HOLogic.boolT orelse T = propT)
-val fresh_prefix = "Sledgehammer.FRESH."
+val fresh_prefix = "Sledgehammer.skolem."
val flip = Option.map not
(* These are typically simplified away by "Meson.presimplify". *)
val boring_consts =
@@ -99,7 +122,7 @@
introduce a fresh constant to simulate the effect of Skolemization. *)
fun do_term t =
case t of
- Const x => add_const_to_table (const_with_typ thy x)
+ Const x => add_const_to_table (pseudoconst_for thy x)
| Free (s, _) => add_const_to_table (s, [])
| t1 $ t2 => fold do_term [t1, t2]
| Abs (_, _, t') => do_term t'
@@ -166,23 +189,23 @@
(* A two-dimensional symbol table counts frequencies of constants. It's keyed
first by constant name and second by its list of type instantiations. For the
- latter, we need a linear ordering on "const_typ list". *)
+ latter, we need a linear ordering on "pseudotype list". *)
-fun const_typ_ord p =
+fun pseudotype_ord p =
case p of
- (CTVar, CTVar) => EQUAL
- | (CTVar, CType _) => LESS
- | (CType _, CTVar) => GREATER
- | (CType q1, CType q2) =>
- prod_ord fast_string_ord (dict_ord const_typ_ord) (q1, q2)
+ (PVar, PVar) => EQUAL
+ | (PVar, PType _) => LESS
+ | (PType _, PVar) => GREATER
+ | (PType q1, PType q2) =>
+ prod_ord fast_string_ord (dict_ord pseudotype_ord) (q1, q2)
structure CTtab =
- Table(type key = const_typ list val ord = dict_ord const_typ_ord)
+ Table(type key = pseudotype list val ord = dict_ord pseudotype_ord)
fun count_axiom_consts theory_relevant thy (_, th) =
let
fun do_const (a, T) =
- let val (c, cts) = const_with_typ thy (a, T) in
+ let val (c, cts) = pseudoconst_for thy (a, T) in
(* Two-dimensional table update. Constant maps to types maps to
count. *)
CTtab.map_default (cts, 0) (Integer.add 1)
@@ -199,8 +222,8 @@
(**** Actual Filtering Code ****)
(*The frequency of a constant is the sum of those of all instances of its type.*)
-fun const_frequency const_tab (c, cts) =
- CTtab.fold (fn (cts', m) => match_types cts cts' ? Integer.add m)
+fun pseudoconst_freq match const_tab (c, cts) =
+ CTtab.fold (fn (cts', m) => match (cts, cts') ? Integer.add m)
(the (Symtab.lookup const_tab c)) 0
handle Option.Option => 0
@@ -214,29 +237,22 @@
fun irrel_log (x : real) = Math.ln (x + 19.0) / 6.4
(* Computes a constant's weight, as determined by its frequency. *)
-val rel_const_weight = rel_log o real oo const_frequency
-val irrel_const_weight = irrel_log o real oo const_frequency
-(* fun irrel_const_weight _ _ = 1.0 FIXME: OLD CODE *)
+val rel_weight = rel_log o real oo pseudoconst_freq match_pseudotypes
+val irrel_weight =
+ irrel_log o real oo pseudoconst_freq (match_pseudotypes o swap)
+(* fun irrel_weight _ _ = 1.0 FIXME: OLD CODE *)
fun axiom_weight const_tab relevant_consts axiom_consts =
- let
- val (rel, irrel) = List.partition (const_mem relevant_consts) axiom_consts
- val rel_weight = fold (curry Real.+ o rel_const_weight const_tab) rel 0.0
- val irrel_weight = fold (curry Real.+ o irrel_const_weight const_tab) irrel 0.0
- val res = rel_weight / (rel_weight + irrel_weight)
- in if Real.isFinite res then res else 0.0 end
-
-(* OLD CODE:
-(*Relevant constants are weighted according to frequency,
- but irrelevant constants are simply counted. Otherwise, Skolem functions,
- which are rare, would harm a formula's chances of being picked.*)
-fun axiom_weight const_tab relevant_consts axiom_consts =
- let
- val rel = filter (const_mem relevant_consts) axiom_consts
- val rel_weight = fold (curry Real.+ o rel_const_weight const_tab) rel 0.0
- val res = rel_weight / (rel_weight + real (length axiom_consts - length rel))
- in if Real.isFinite res then res else 0.0 end
-*)
+ case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
+ ||> filter_out (pseudoconst_mem swap relevant_consts) of
+ ([], []) => 0.0
+ | (_, []) => 1.0
+ | (rel, irrel) =>
+ let
+ val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
+ val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
+ val res = rel_weight / (rel_weight + irrel_weight)
+ in if Real.isFinite res then res else 0.0 end
fun consts_of_term thy t =
Symtab.fold (fn (x, ys) => fold (fn y => cons (x, y)) ys)
@@ -247,83 +263,82 @@
|> consts_of_term thy)
type annotated_thm =
- ((unit -> string * bool) * thm) * (string * const_typ list) list
+ ((unit -> string * bool) * thm) * (string * pseudotype list) list
-(*For a reverse sort, putting the largest values first.*)
-fun compare_pairs ((_, w1), (_, w2)) = Real.compare (w2, w1)
+fun rev_compare_pairs ((_, w1), (_, w2)) = Real.compare (w2, w1)
-(* Limit the number of new facts, to prevent runaway acceptance. *)
-fun take_best max_relevant_per_iter (new_pairs : (annotated_thm * real) list) =
- let val nnew = length new_pairs in
- if nnew <= max_relevant_per_iter then
- (map #1 new_pairs, [])
- else
- let
- val new_pairs = sort compare_pairs new_pairs
- val accepted = List.take (new_pairs, max_relevant_per_iter)
- in
- trace_msg (fn () => ("Number of candidates, " ^ Int.toString nnew ^
- ", exceeds the limit of " ^ Int.toString max_relevant_per_iter));
- trace_msg (fn () => ("Effective pass mark: " ^ Real.toString (#2 (List.last accepted))));
- trace_msg (fn () => "Actually passed: " ^
- space_implode ", " (map (fst o (fn f => f ()) o fst o fst o fst) accepted));
- (map #1 accepted, List.drop (new_pairs, max_relevant_per_iter))
- end
- end;
+fun take_best max (new_pairs : (annotated_thm * real) list) =
+ let
+ val ((perfect, more_perfect), imperfect) =
+ new_pairs |> List.partition (fn (_, w) => w > 0.99999)
+ |>> chop (max - 1) ||> sort rev_compare_pairs
+ val (accepted, rejected) =
+ case more_perfect @ imperfect of
+ [] => (perfect, [])
+ | (q :: qs) => (q :: perfect, qs)
+ in
+ trace_msg (fn () => "Number of candidates: " ^
+ string_of_int (length new_pairs));
+ trace_msg (fn () => "Effective threshold: " ^
+ Real.toString (#2 (hd accepted)));
+ trace_msg (fn () => "Actually passed: " ^
+ (accepted |> map (fn (((name, _), _), weight) =>
+ fst (name ()) ^ " [" ^ Real.toString weight ^ "]")
+ |> commas));
+ (map #1 accepted, rejected)
+ end
val threshold_divisor = 2.0
val ridiculous_threshold = 0.1
-fun relevance_filter ctxt relevance_threshold relevance_decay
- max_relevant_per_iter theory_relevant
- ({add, del, ...} : relevance_override) axioms goal_ts =
+fun relevance_filter ctxt relevance_threshold relevance_decay max_relevant
+ theory_relevant ({add, del, ...} : relevance_override)
+ axioms goal_ts =
let
val thy = ProofContext.theory_of ctxt
val const_tab = fold (count_axiom_consts theory_relevant thy) axioms
Symtab.empty
- val goal_const_tab = get_consts thy (SOME false) goal_ts
- val _ =
- trace_msg (fn () => "Initial constants: " ^
- commas (goal_const_tab |> Symtab.dest
- |> filter (curry (op <>) [] o snd)
- |> map fst))
val add_thms = maps (ProofContext.get_fact ctxt) add
val del_thms = maps (ProofContext.get_fact ctxt) del
- fun iter j threshold rel_const_tab =
+ fun iter j max threshold rel_const_tab rest =
let
+ fun game_over rejects =
+ if j = 0 andalso threshold >= ridiculous_threshold then
+ (* First iteration? Try again. *)
+ iter 0 max (threshold / threshold_divisor) rel_const_tab rejects
+ else
+ (* Add "add:" facts. *)
+ if null add_thms then
+ []
+ else
+ map_filter (fn ((p as (_, th), _), _) =>
+ if member Thm.eq_thm add_thms th then SOME p
+ else NONE) rejects
fun relevant ([], rejects) [] =
- (* Nothing was added this iteration. *)
- if j = 0 andalso threshold >= ridiculous_threshold then
- (* First iteration? Try again. *)
- iter 0 (threshold / threshold_divisor) rel_const_tab
- (map (apsnd SOME) rejects)
- else
- (* Add "add:" facts. *)
- if null add_thms then
- []
- else
- map_filter (fn ((p as (_, th), _), _) =>
- if member Thm.eq_thm add_thms th then SOME p
- else NONE) rejects
+ (* Nothing has been added this iteration. *)
+ game_over (map (apsnd SOME) rejects)
| relevant (new_pairs, rejects) [] =
let
- val (new_rels, more_rejects) =
- take_best max_relevant_per_iter new_pairs
+ val (new_rels, more_rejects) = take_best max new_pairs
val rel_const_tab' =
rel_const_tab |> fold add_const_to_table (maps snd new_rels)
- fun is_dirty c =
- const_mem rel_const_tab' c andalso
- not (const_mem rel_const_tab c)
+ fun is_dirty (c, _) =
+ Symtab.lookup rel_const_tab' c <> Symtab.lookup rel_const_tab c
val rejects =
more_rejects @ rejects
|> map (fn (ax as (_, consts), old_weight) =>
(ax, if exists is_dirty consts then NONE
else SOME old_weight))
val threshold = threshold + (1.0 - threshold) * relevance_decay
+ val max = max - length new_rels
in
- trace_msg (fn () => "relevant this iteration: " ^
- Int.toString (length new_rels));
- map #1 new_rels @ iter (j + 1) threshold rel_const_tab' rejects
+ trace_msg (fn () => "New or updated constants: " ^
+ commas (rel_const_tab' |> Symtab.dest
+ |> subtract (op =) (Symtab.dest rel_const_tab)
+ |> map string_for_super_pseudoconst));
+ map #1 new_rels @
+ (if max = 0 then game_over rejects
+ else iter (j + 1) max threshold rel_const_tab' rejects)
end
| relevant (new_rels, rejects)
(((ax as ((name, th), axiom_consts)), cached_weight)
@@ -335,26 +350,29 @@
| NONE => axiom_weight const_tab rel_const_tab axiom_consts
in
if weight >= threshold then
- (trace_msg (fn () =>
- fst (name ()) ^ " passes: " ^ Real.toString weight
- ^ " consts: " ^ commas (map fst axiom_consts));
- relevant ((ax, weight) :: new_rels, rejects) rest)
+ relevant ((ax, weight) :: new_rels, rejects) rest
else
relevant (new_rels, (ax, weight) :: rejects) rest
end
in
- trace_msg (fn () => "relevant_facts, current threshold: " ^
- Real.toString threshold);
- relevant ([], [])
+ trace_msg (fn () =>
+ "ITERATION " ^ string_of_int j ^ ": current threshold: " ^
+ Real.toString threshold ^ ", constants: " ^
+ commas (rel_const_tab |> Symtab.dest
+ |> filter (curry (op <>) [] o snd)
+ |> map string_for_super_pseudoconst));
+ relevant ([], []) rest
end
in
axioms |> filter_out (member Thm.eq_thm del_thms o snd)
|> map (rpair NONE o pair_consts_axiom theory_relevant thy)
- |> iter 0 relevance_threshold goal_const_tab
+ |> iter 0 max_relevant relevance_threshold
+ (get_consts thy (SOME false) goal_ts)
|> tap (fn res => trace_msg (fn () =>
"Total relevant: " ^ Int.toString (length res)))
end
+
(***************************************************************)
(* Retrieving and filtering lemmas *)
(***************************************************************)
@@ -547,14 +565,7 @@
val name2 = Name_Space.extern full_space name0
in
case find_first check_thms [name1, name2, name0] of
- SOME name =>
- let
- val name =
- name |> Symtab.defined reserved name ? quote
- in
- if multi then name ^ "(" ^ Int.toString j ^ ")"
- else name
- end
+ SOME name => repair_name reserved multi j name
| NONE => ""
end, is_chained th), (multi, th)) :: rest)) ths
#> snd
@@ -567,25 +578,26 @@
(* The single-name theorems go after the multiple-name ones, so that single
names are preferred when both are available. *)
fun name_thm_pairs ctxt respect_no_atp =
- List.partition (fst o snd) #> op @
- #> map (apsnd snd)
+ List.partition (fst o snd) #> op @ #> map (apsnd snd)
#> respect_no_atp ? filter_out (No_ATPs.member ctxt o snd)
(***************************************************************)
(* ATP invocation methods setup *)
(***************************************************************)
-fun relevant_facts full_types relevance_threshold relevance_decay
- max_relevant_per_iter theory_relevant
- (relevance_override as {add, del, only})
+fun relevant_facts full_types relevance_threshold relevance_decay max_relevant
+ theory_relevant (relevance_override as {add, del, only})
(ctxt, (chained_ths, _)) hyp_ts concl_t =
let
+ val relevance_decay =
+ case relevance_decay of
+ SOME x => x
+ | NONE => 0.35 / Math.ln (Real.fromInt (max_relevant + 1))
val add_thms = maps (ProofContext.get_fact ctxt) add
val reserved = reserved_isar_keyword_table ()
val axioms =
(if only then
- maps ((fn (n, ths) => map (pair n o pair false) ths)
- o name_thms_pair_from_ref ctxt reserved chained_ths) add
+ maps (name_thm_pairs_from_ref ctxt reserved chained_ths) add
else
all_name_thms_pairs ctxt reserved full_types add_thms chained_ths)
|> name_thm_pairs ctxt (respect_no_atp andalso not only)
@@ -598,11 +610,10 @@
else if relevance_threshold < 0.0 then
axioms
else
- relevance_filter ctxt relevance_threshold relevance_decay
- max_relevant_per_iter theory_relevant relevance_override
- axioms (concl_t :: hyp_ts))
- |> map (apfst (fn f => f ()))
- |> sort_wrt (fst o fst)
+ relevance_filter ctxt relevance_threshold relevance_decay max_relevant
+ theory_relevant relevance_override axioms
+ (concl_t :: hyp_ts))
+ |> map (apfst (fn f => f ())) |> sort_wrt (fst o fst)
end
end;