src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML
changeset 38747 b264ae66cede
parent 38745 ad577fd62ee4
child 38749 0d2f7f0614d1
equal deleted inserted replaced
38746:9b465a288c62 38747:b264ae66cede
    99   | string_for_pseudoconst (s, Ts) = s ^ string_for_pseudotypes Ts
    99   | string_for_pseudoconst (s, Ts) = s ^ string_for_pseudotypes Ts
   (* Identical in both revisions (each row appears once per column).
      Renders a name s together with a list of pseudotype lists Tss: when the
      list is exactly [[]] only s is produced; otherwise every inner list is
      rendered with string_for_pseudotypes, the pieces are combined with
      commas, and the result is wrapped in braces after s. *)
   100 fun string_for_super_pseudoconst (s, [[]]) = s
   100 fun string_for_super_pseudoconst (s, [[]]) = s
   101   | string_for_super_pseudoconst (s, Tss) =
   101   | string_for_super_pseudoconst (s, Tss) =
   102     s ^ "{" ^ commas (map string_for_pseudotypes Tss) ^ "}"
   102     s ^ "{" ^ commas (map string_for_pseudotypes Tss) ^ "}"
   103 
   103 
   (* Reading aid for the interleaved rows below.
      Old revision (rows 104-108, first of each pair): add_const_to_table
      (c, ctyps) always updates the entry for c through Symtab.map_default,
      using [ctyps] as the default; an entry that is [] stays [], any other
      entry gets ctyps added via insert (op =).
      New revision (rows 104-113): introduces skolem_prefix and renames the
      function to add_pseudoconst_to_table with an extra also_skolem flag.
      The same table update is performed only when also_skolem is true or c
      does not start with skolem_prefix; otherwise the identity function I is
      returned and the table is left unchanged. *)
   104 (*Add a const/type pair to the table, but a [] entry means a standard connective,
   104 val skolem_prefix = "Sledgehammer."
   105   which we ignore.*)
   105 
   106 fun add_const_to_table (c, ctyps) =
   106 (* Add a pseudoconstant to the table, but a [] entry means a standard
   107   Symtab.map_default (c, [ctyps])
   107    connective, which we ignore.*)
   108                      (fn [] => [] | ctypss => insert (op =) ctyps ctypss)
   108 fun add_pseudoconst_to_table also_skolem (c, ctyps) =
       
   109   if also_skolem orelse not (String.isPrefix skolem_prefix c) then
       
   110     Symtab.map_default (c, [ctyps])
       
   111                        (fn [] => [] | ctypss => insert (op =) ctyps ctypss)
       
   112   else
       
   113     I
   109 
   114 
   (* Identical in both revisions: true exactly when T equals HOLogic.boolT
      or propT. *)
   110 fun is_formula_type T = (T = HOLogic.boolT orelse T = propT)
   115 fun is_formula_type T = (T = HOLogic.boolT orelse T = propT)
   111 
   116 
   (* fresh_prefix exists only in the old revision (its row has no partner in
      the new column). flip and boring_consts are identical in both
      revisions: flip negates the boolean inside an option and leaves NONE
      alone (Option.map not); boring_consts lists the constant names for
      False, True, If and Let. *)
   112 val fresh_prefix = "Sledgehammer.skolem."
       
   113 val flip = Option.map not
   117 val flip = Option.map not
   114 (* These are typically simplified away by "Meson.presimplify". *)
   118 (* These are typically simplified away by "Meson.presimplify". *)
   115 val boring_consts =
   119 val boring_consts =
   116   [@{const_name False}, @{const_name True}, @{const_name If}, @{const_name Let}]
   120   [@{const_name False}, @{const_name True}, @{const_name If}, @{const_name Let}]
   117 
   121 
   118 fun get_consts thy pos ts =
   122 fun get_pseudoconsts thy also_skolems pos ts =
   119   let
   123   let
   120     (* We include free variables, as well as constants, to handle locales. For
   124     (* We include free variables, as well as constants, to handle locales. For
   121        each quantifier that must necessarily be skolemized by the ATP, we
   125        each quantifier that must necessarily be skolemized by the ATP, we
   122        introduce a fresh constant to simulate the effect of Skolemization. *)
   126        introduce a fresh constant to simulate the effect of Skolemization. *)
   123     fun do_term t =
   127     fun do_term t =
   124       case t of
   128       case t of
   125         Const x => add_const_to_table (pseudoconst_for thy x)
   129         Const x => add_pseudoconst_to_table also_skolems (pseudoconst_for thy x)
   126       | Free (s, _) => add_const_to_table (s, [])
   130       | Free (s, _) => add_pseudoconst_to_table also_skolems (s, [])
   127       | t1 $ t2 => fold do_term [t1, t2]
   131       | t1 $ t2 => fold do_term [t1, t2]
   128       | Abs (_, _, t') => do_term t'
   132       | Abs (_, _, t') => do_term t'  (* FIXME: add penalty? *)
   129       | _ => I
   133       | _ => I
   130     fun do_quantifier will_surely_be_skolemized body_t =
   134     fun do_quantifier will_surely_be_skolemized body_t =
   131       do_formula pos body_t
   135       do_formula pos body_t
   132       #> (if will_surely_be_skolemized then
   136       #> (if also_skolems andalso will_surely_be_skolemized then
   133             add_const_to_table (gensym fresh_prefix, [])
   137             add_pseudoconst_to_table true (gensym skolem_prefix, [])
   134           else
   138           else
   135             I)
   139             I)
   136     and do_term_or_formula T =
   140     and do_term_or_formula T =
   137       if is_formula_type T then do_formula NONE else do_term
   141       if is_formula_type T then do_formula NONE else do_term
   138     and do_formula pos t =
   142     and do_formula pos t =
   231 (* A surprising number of theorems contain only a few significant constants.
   235 (* A surprising number of theorems contain only a few significant constants.
   232    These include all induction rules, and other general theorems. *)
   236    These include all induction rules, and other general theorems. *)
   233 
   237 
   234 (* "log" seems best in practice. A constant function of one ignores the constant
   238 (* "log" seems best in practice. A constant function of one ignores the constant
   235    frequencies. *)
   239    frequencies. *)
   (* Old revision (rows 236-237): rel_log and irrel_log take a real x.
      New revision (rows 240-244): both take an integer n and convert it with
      Real.fromInt; the formulas themselves are unchanged
      (1.0 + 2.0 / ln (n + 1.0) and ln (n + 19.0) / 6.4).
      Rows 241-243 are a commented-out experimental alternative for
      irrel_log and are not active code. *)
   236 fun rel_log (x : real) = 1.0 + 2.0 / Math.ln (x + 1.0)
   240 fun rel_log n = 1.0 + 2.0 / Math.ln (Real.fromInt n + 1.0)
   237 fun irrel_log (x : real) = Math.ln (x + 19.0) / 6.4
   241 (* TODO: experiment
       
   242 fun irrel_log n = 0.5 + 1.0 / Math.ln (Real.fromInt n + 1.0)
       
   243 *)
       
   244 fun irrel_log n = Math.ln (Real.fromInt n + 19.0) / 6.4
   238 
   245 
   239 (* Computes a constant's weight, as determined by its frequency. *)
   246 (* Computes a constant's weight, as determined by its frequency. *)
   (* Old revision (rows 240-243): both weights are point-free compositions
      that pass the result of pseudoconst_freq through real before applying
      rel_log / irrel_log.
      New revision (rows 247-253): the real conversion is dropped from these
      definitions, and irrel_weight becomes an explicit function that returns
      1.0 for any name starting with skolem_prefix and otherwise applies
      irrel_log to the frequency computed with match_pseudotypes o swap.
      The constant-1.0 variant of irrel_weight is commented out in both
      revisions. *)
   240 val rel_weight = rel_log o real oo pseudoconst_freq match_pseudotypes
   247 val rel_weight = rel_log oo pseudoconst_freq match_pseudotypes
   241 val irrel_weight =
   248 fun irrel_weight const_tab (c as (s, _)) =
   242   irrel_log o real oo pseudoconst_freq (match_pseudotypes o swap)
   249   if String.isPrefix skolem_prefix s then 1.0
   243 (* fun irrel_weight _ _ = 1.0  FIXME: OLD CODE *)
   250   else irrel_log (pseudoconst_freq (match_pseudotypes o swap) const_tab c)
       
   251 (* TODO: experiment
       
   252 fun irrel_weight _ _ = 1.0
       
   253 *)
   244 
   254 
   245 fun axiom_weight const_tab relevant_consts axiom_consts =
   255 fun axiom_weight const_tab relevant_consts axiom_consts =
   246   case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
   256   case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
   247                     ||> filter_out (pseudoconst_mem swap relevant_consts) of
   257                     ||> filter_out (pseudoconst_mem swap relevant_consts) of
   248     ([], []) => 0.0
   258     ([], []) => 0.0
   252       val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
   262       val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
   253       val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
   263       val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
   254       val res = rel_weight / (rel_weight + irrel_weight)
   264       val res = rel_weight / (rel_weight + irrel_weight)
   255     in if Real.isFinite res then res else 0.0 end
   265     in if Real.isFinite res then res else 0.0 end
   256 
   266 
   (* Old revision (rows 257-259): consts_of_term. New revision (rows
      283-285): the same function renamed to pseudoconsts_of_term. Both fold
      over the table built from the single term t and emit one (x, y) pair
      for every y in the list stored under x, consing onto an initially empty
      list. The old version calls get_consts thy (SOME true); the new one
      calls get_pseudoconsts thy true (SOME true).
      Rows 267-281 in between exist only in the new revision and are a
      commented-out debug variant of axiom_weight that traces the rel/irrel
      partitions; they are not active code. *)
   257 fun consts_of_term thy t =
   267 (* TODO: experiment
       
   268 fun debug_axiom_weight const_tab relevant_consts axiom_consts =
       
   269   case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
       
   270                     ||> filter_out (pseudoconst_mem swap relevant_consts) of
       
   271     ([], []) => 0.0
       
   272   | (_, []) => 1.0
       
   273   | (rel, irrel) =>
       
   274     let
       
   275 val _ = tracing (PolyML.makestring ("REL: ", rel))
       
   276 val _ = tracing (PolyML.makestring ("IRREL: ", irrel))
       
   277       val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
       
   278       val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
       
   279       val res = rel_weight / (rel_weight + irrel_weight)
       
   280     in if Real.isFinite res then res else 0.0 end
       
   281 *)
       
   282 
       
   283 fun pseudoconsts_of_term thy t =
   258   Symtab.fold (fn (x, ys) => fold (fn y => cons (x, y)) ys)
   284   Symtab.fold (fn (x, ys) => fold (fn y => cons (x, y)) ys)
   259               (get_consts thy (SOME true) [t]) []
   285               (get_pseudoconsts thy true (SOME true) [t]) []
   260 
       
   (* Returns the axiom paired with the constants extracted from it: the
      second component of axiom is passed through
      theory_const_prop_of theory_relevant and then through consts_of_term
      (old revision) or pseudoconsts_of_term (new revision). That callee
      rename is the only difference between the two columns. *)
   261 fun pair_consts_axiom theory_relevant thy axiom =
   286 fun pair_consts_axiom theory_relevant thy axiom =
   262   (axiom, axiom |> snd |> theory_const_prop_of theory_relevant
   287   (axiom, axiom |> snd |> theory_const_prop_of theory_relevant
   263                 |> consts_of_term thy)
   288                 |> pseudoconsts_of_term thy)
   264 
   289 
   (* Identical in both revisions. A pair whose first component couples a
      thunk yielding (string, bool) with a thm, and whose second component is
      a list of (string, pseudotype list) entries. The meaning of the bool is
      not visible in this excerpt. *)
   265 type annotated_thm =
   290 type annotated_thm =
   266   ((unit -> string * bool) * thm) * (string * pseudotype list) list
   291   ((unit -> string * bool) * thm) * (string * pseudotype list) list
   267 
   292 
   (* Reading aid for the interleaved rows below.
      Old revision (rows 268-287): take_best max candidates. Candidates with
      weight > 0.99999 are split off as perfect and cut at max - 1 with chop;
      the remaining ones are sorted by weight (comparison arguments swapped,
      presumably giving descending order). The head of
      more_perfect @ imperfect, if any, is added to the accepted list and the
      tail becomes the rejects.
      New revision (rows 293-316): take_most_relevant takes
      max_max_imperfect, max_relevant and remaining_max instead of max. It
      computes max_imperfect as the ceiling of
      max_max_imperfect raised to (remaining_max / max_relevant), keeps at
      most that many of the sorted imperfect candidates, combines them with
      the perfect ones via append, and cuts the result at remaining_max.
      Everything cut off is returned as more_rejects @ rejects.
      Both revisions emit the same trace messages; the new one additionally
      prints the number of accepted facts.
      NOTE(review): hd accepts inside the trace thunk would fail on an empty
      accepts list if that thunk is evaluated -- confirm callers never pass
      an empty candidate list or a zero limit. *)
   268 fun take_best max (candidates : (annotated_thm * real) list) =
   293 fun take_most_relevant max_max_imperfect max_relevant remaining_max
   269   let
   294                        (candidates : (annotated_thm * real) list) =
   270     val ((perfect, more_perfect), imperfect) =
   295   let
   271       candidates |> List.partition (fn (_, w) => w > 0.99999) |>> chop (max - 1)
   296     val max_imperfect =
       
   297       Real.ceil (Math.pow (max_max_imperfect,
       
   298                            Real.fromInt remaining_max
       
   299                            / Real.fromInt max_relevant))
       
   300     val (perfect, imperfect) =
       
   301       candidates |> List.partition (fn (_, w) => w > 0.99999)
   272                  ||> sort (Real.compare o swap o pairself snd)
   302                  ||> sort (Real.compare o swap o pairself snd)
   273     val (accepts, rejects) =
   303     val ((accepts, more_rejects), rejects) =
   274       case more_perfect @ imperfect of
   304       chop max_imperfect imperfect |>> append perfect |>> chop remaining_max
   275         [] => (perfect, [])
       
   276       | (q :: qs) => (q :: perfect, qs)
       
   277   in
   305   in
   278     trace_msg (fn () => "Number of candidates: " ^
   306     trace_msg (fn () => "Number of candidates: " ^
   279                         string_of_int (length candidates));
   307                         string_of_int (length candidates));
   280     trace_msg (fn () => "Effective threshold: " ^
   308     trace_msg (fn () => "Effective threshold: " ^
   281                         Real.toString (#2 (hd accepts)));
   309                         Real.toString (#2 (hd accepts)));
   282     trace_msg (fn () => "Actually passed: " ^
   310     trace_msg (fn () => "Actually passed (" ^ Int.toString (length accepts) ^
   283         (accepts |> map (fn (((name, _), _), weight) =>
   311         "): " ^ (accepts
       
   312                  |> map (fn (((name, _), _), weight) =>
   284                             fst (name ()) ^ " [" ^ Real.toString weight ^ "]")
   313                             fst (name ()) ^ " [" ^ Real.toString weight ^ "]")
   285                  |> commas));
   314                  |> commas));
   286     (accepts, rejects)
   315     (accepts, more_rejects @ rejects)
   287   end
   316   end
   288 
   317 
   (* Tuning constants read by relevance_filter. threshold_divisor divides
      the threshold when the first iteration accepts nothing and is retried;
      that retry only happens while the threshold is at least
      ridiculous_threshold. max_max_imperfect_fudge_factor is new in this
      revision: relevance_filter multiplies it by max_relevant and takes the
      square root to obtain max_max_imperfect. *)
   289 val threshold_divisor = 2.0
   318 val threshold_divisor = 2.0
   290 val ridiculous_threshold = 0.1
   319 val ridiculous_threshold = 0.1
       
   320 val max_max_imperfect_fudge_factor = 0.66
   291 
   321 
   292 fun relevance_filter ctxt threshold0 decay max_relevant theory_relevant
   322 fun relevance_filter ctxt threshold0 decay max_relevant theory_relevant
   293                      ({add, del, ...} : relevance_override) axioms goal_ts =
   323                      ({add, del, ...} : relevance_override) axioms goal_ts =
   294   let
   324   let
   295     val thy = ProofContext.theory_of ctxt
   325     val thy = ProofContext.theory_of ctxt
   296     val const_tab = fold (count_axiom_consts theory_relevant thy) axioms
   326     val const_tab = fold (count_axiom_consts theory_relevant thy) axioms
   297                          Symtab.empty
   327                          Symtab.empty
   298     val add_thms = maps (ProofContext.get_fact ctxt) add
   328     val add_thms = maps (ProofContext.get_fact ctxt) add
   299     val del_thms = maps (ProofContext.get_fact ctxt) del
   329     val del_thms = maps (ProofContext.get_fact ctxt) del
   300     fun iter j max threshold rel_const_tab hopeless hopeful =
   330     val max_max_imperfect =
       
   331       Math.sqrt (Real.fromInt max_relevant * max_max_imperfect_fudge_factor)
       
   332     fun iter j remaining_max threshold rel_const_tab hopeless hopeful =
   301       let
   333       let
   302         fun game_over rejects =
   334         fun game_over rejects =
   303           if j = 0 andalso threshold >= ridiculous_threshold then
   335           (* Add "add:" facts. *)
   304             (* First iteration? Try again. *)
   336           if null add_thms then
   305             iter 0 max (threshold / threshold_divisor) rel_const_tab hopeless
   337             []
   306                  hopeful
       
   307           else
   338           else
   308             (* Add "add:" facts. *)
   339             map_filter (fn ((p as (_, th), _), _) =>
   309             if null add_thms then
   340                            if member Thm.eq_thm add_thms th then SOME p
   310               []
   341                            else NONE) rejects
       
   342         fun relevant [] rejects hopeless [] =
       
   343             (* Nothing has been added this iteration. *)
       
   344             if j = 0 andalso threshold >= ridiculous_threshold then
       
   345               (* First iteration? Try again. *)
       
   346               iter 0 max_relevant (threshold / threshold_divisor) rel_const_tab
       
   347                    hopeless hopeful
   311             else
   348             else
   312               map_filter (fn ((p as (_, th), _), _) =>
   349               game_over (rejects @ hopeless)
   313                              if member Thm.eq_thm add_thms th then SOME p
   350           | relevant candidates rejects hopeless [] =
   314                              else NONE) rejects
       
   315         fun relevant [] rejects [] hopeless =
       
   316             (* Nothing has been added this iteration. *)
       
   317             game_over (map (apsnd SOME) (rejects @ hopeless))
       
   318           | relevant candidates rejects [] hopeless =
       
   319             let
   351             let
   320               val (accepts, more_rejects) = take_best max candidates
   352               val (accepts, more_rejects) =
       
   353                 take_most_relevant max_max_imperfect max_relevant remaining_max
       
   354                                    candidates
   321               val rel_const_tab' =
   355               val rel_const_tab' =
   322                 rel_const_tab
   356                 rel_const_tab
   323                 |> fold add_const_to_table (maps (snd o fst) accepts)
   357                 |> fold (add_pseudoconst_to_table false)
       
   358                         (maps (snd o fst) accepts)
   324               fun is_dirty (c, _) =
   359               fun is_dirty (c, _) =
   325                 Symtab.lookup rel_const_tab' c <> Symtab.lookup rel_const_tab c
   360                 Symtab.lookup rel_const_tab' c <> Symtab.lookup rel_const_tab c
   326               val (hopeful_rejects, hopeless_rejects) =
   361               val (hopeful_rejects, hopeless_rejects) =
   327                  (rejects @ hopeless, ([], []))
   362                  (rejects @ hopeless, ([], []))
   328                  |-> fold (fn (ax as (_, consts), old_weight) =>
   363                  |-> fold (fn (ax as (_, consts), old_weight) =>
   332                                 apsnd (cons (ax, old_weight)))
   367                                 apsnd (cons (ax, old_weight)))
   333                  |>> append (more_rejects
   368                  |>> append (more_rejects
   334                              |> map (fn (ax as (_, consts), old_weight) =>
   369                              |> map (fn (ax as (_, consts), old_weight) =>
   335                                         (ax, if exists is_dirty consts then NONE
   370                                         (ax, if exists is_dirty consts then NONE
   336                                              else SOME old_weight)))
   371                                              else SOME old_weight)))
   337               val threshold = threshold + (1.0 - threshold) * decay
   372               val threshold =
   338               val max = max - length accepts
   373                 threshold + (1.0 - threshold)
       
   374                 * Math.pow (decay, Real.fromInt (length accepts))
       
   375               val remaining_max = remaining_max - length accepts
   339             in
   376             in
   340               trace_msg (fn () => "New or updated constants: " ^
   377               trace_msg (fn () => "New or updated constants: " ^
   341                   commas (rel_const_tab' |> Symtab.dest
   378                   commas (rel_const_tab' |> Symtab.dest
   342                           |> subtract (op =) (Symtab.dest rel_const_tab)
   379                           |> subtract (op =) (Symtab.dest rel_const_tab)
   343                           |> map string_for_super_pseudoconst));
   380                           |> map string_for_super_pseudoconst));
   344               map (fst o fst) accepts @
   381               map (fst o fst) accepts @
   345               (if max = 0 then
   382               (if remaining_max = 0 then
   346                  game_over (hopeful_rejects @ map (apsnd SOME) hopeless_rejects)
   383                  game_over (hopeful_rejects @ map (apsnd SOME) hopeless_rejects)
   347                else
   384                else
   348                  iter (j + 1) max threshold rel_const_tab' hopeless_rejects
   385                  iter (j + 1) remaining_max threshold rel_const_tab'
   349                       hopeful_rejects)
   386                       hopeless_rejects hopeful_rejects)
   350             end
   387             end
   351           | relevant candidates rejects
   388           | relevant candidates rejects hopeless
   352                      (((ax as ((name, th), axiom_consts)), cached_weight)
   389                      (((ax as ((name, th), axiom_consts)), cached_weight)
   353                       :: hopeful) hopeless =
   390                       :: hopeful) =
   354             let
   391             let
   355               val weight =
   392               val weight =
   356                 case cached_weight of
   393                 case cached_weight of
   357                   SOME w => w
   394                   SOME w => w
   358                 | NONE => axiom_weight const_tab rel_const_tab axiom_consts
   395                 | NONE => axiom_weight const_tab rel_const_tab axiom_consts
       
   396 (* TODO: experiment
       
   397 val _ = if String.isPrefix "lift.simps(3" (fst (name ())) then
       
   398 tracing ("*** " ^ (fst (name ())) ^ PolyML.makestring (debug_axiom_weight const_tab rel_const_tab axiom_consts))
       
   399 else
       
   400 ()
       
   401 *)
   359             in
   402             in
   360               if weight >= threshold then
   403               if weight >= threshold then
   361                 relevant ((ax, weight) :: candidates) rejects hopeful hopeless
   404                 relevant ((ax, weight) :: candidates) rejects hopeless hopeful
   362               else
   405               else
   363                 relevant candidates ((ax, weight) :: rejects) hopeful hopeless
   406                 relevant candidates ((ax, weight) :: rejects) hopeless hopeful
   364             end
   407             end
   365         in
   408         in
   366           trace_msg (fn () =>
   409           trace_msg (fn () =>
   367               "ITERATION " ^ string_of_int j ^ ": current threshold: " ^
   410               "ITERATION " ^ string_of_int j ^ ": current threshold: " ^
   368               Real.toString threshold ^ ", constants: " ^
   411               Real.toString threshold ^ ", constants: " ^
   369               commas (rel_const_tab |> Symtab.dest
   412               commas (rel_const_tab |> Symtab.dest
   370                       |> filter (curry (op <>) [] o snd)
   413                       |> filter (curry (op <>) [] o snd)
   371                       |> map string_for_super_pseudoconst));
   414                       |> map string_for_super_pseudoconst));
   372           relevant [] [] hopeful hopeless
   415           relevant [] [] hopeless hopeful
   373         end
   416         end
   374   in
   417   in
   375     axioms |> filter_out (member Thm.eq_thm del_thms o snd)
   418     axioms |> filter_out (member Thm.eq_thm del_thms o snd)
   376            |> map (rpair NONE o pair_consts_axiom theory_relevant thy)
   419            |> map (rpair NONE o pair_consts_axiom theory_relevant thy)
   377            |> iter 0 max_relevant threshold0
   420            |> iter 0 max_relevant threshold0
   378                    (get_consts thy (SOME false) goal_ts) []
   421                    (get_pseudoconsts thy false (SOME false) goal_ts) []
   379            |> tap (fn res => trace_msg (fn () =>
   422            |> tap (fn res => trace_msg (fn () =>
   380                                 "Total relevant: " ^ Int.toString (length res)))
   423                                 "Total relevant: " ^ Int.toString (length res)))
   381   end
   424   end
   382 
   425 
   383 
   426