consider "locality" when assigning weights to facts
authorblanchet
Thu, 26 Aug 2010 10:42:06 +0200
changeset 38752 6628adcae4a7
parent 38751 01c4d14b2a61
child 38753 3913f58d0fcc
consider "locality" when assigning weights to facts
src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML
src/HOL/Tools/Sledgehammer/sledgehammer.ML
src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML
src/HOL/Tools/Sledgehammer/sledgehammer_fact_minimize.ML
src/HOL/Tools/Sledgehammer/sledgehammer_proof_reconstruct.ML
src/HOL/Tools/Sledgehammer/sledgehammer_translate.ML
src/HOL/Tools/Sledgehammer/sledgehammer_util.ML
--- a/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Thu Aug 26 09:23:21 2010 +0200
+++ b/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Thu Aug 26 10:42:06 2010 +0200
@@ -290,10 +290,12 @@
     | NONE => get_prover (default_atp_name ()))
   end
 
+type locality = Sledgehammer_Fact_Filter.locality
+
 local
 
 datatype sh_result =
-  SH_OK of int * int * (string * bool) list |
+  SH_OK of int * int * (string * locality) list |
   SH_FAIL of int * int |
   SH_ERROR
 
@@ -355,8 +357,8 @@
     case result of
       SH_OK (time_isa, time_atp, names) =>
         let
-          fun get_thms (name, chained) =
-            ((name, chained), thms_of_name (Proof.context_of st) name)
+          fun get_thms (name, loc) =
+            ((name, loc), thms_of_name (Proof.context_of st) name)
         in
           change_data id inc_sh_success;
           change_data id (inc_sh_lemmas (length names));
@@ -445,7 +447,7 @@
     then () else
     let
       val named_thms =
-        Unsynchronized.ref (NONE : ((string * bool) * thm list) list option)
+        Unsynchronized.ref (NONE : ((string * locality) * thm list) list option)
       val minimize = AList.defined (op =) args minimizeK
       val metis_ft = AList.defined (op =) args metis_ftK
   
--- a/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Thu Aug 26 09:23:21 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Thu Aug 26 10:42:06 2010 +0200
@@ -9,6 +9,7 @@
 signature SLEDGEHAMMER =
 sig
   type failure = ATP_Systems.failure
+  type locality = Sledgehammer_Fact_Filter.locality
   type relevance_override = Sledgehammer_Fact_Filter.relevance_override
   type minimize_command = Sledgehammer_Proof_Reconstruct.minimize_command
   type params =
@@ -28,16 +29,16 @@
     {subgoal: int,
      goal: Proof.context * (thm list * thm),
      relevance_override: relevance_override,
-     axioms: ((string * bool) * thm) list option}
+     axioms: ((string * locality) * thm) list option}
   type prover_result =
     {outcome: failure option,
      message: string,
      pool: string Symtab.table,
-     used_thm_names: (string * bool) list,
+     used_thm_names: (string * locality) list,
      atp_run_time_in_msecs: int,
      output: string,
      proof: string,
-     axiom_names: (string * bool) vector,
+     axiom_names: (string * locality) vector,
      conjecture_shape: int list list}
   type prover = params -> minimize_command -> problem -> prover_result
 
@@ -96,17 +97,17 @@
   {subgoal: int,
    goal: Proof.context * (thm list * thm),
    relevance_override: relevance_override,
-   axioms: ((string * bool) * thm) list option}
+   axioms: ((string * locality) * thm) list option}
 
 type prover_result =
   {outcome: failure option,
    message: string,
    pool: string Symtab.table,
-   used_thm_names: (string * bool) list,
+   used_thm_names: (string * locality) list,
    atp_run_time_in_msecs: int,
    output: string,
    proof: string,
-   axiom_names: (string * bool) vector,
+   axiom_names: (string * locality) vector,
    conjecture_shape: int list list}
 
 type prover = params -> minimize_command -> problem -> prover_result
@@ -193,8 +194,11 @@
           val axioms =
             j |> AList.lookup (op =) name_map |> these
               |> map_filter (try (unprefix axiom_prefix)) |> map unascii_of
-          val chained = forall (is_true_for axiom_names) axioms
-        in (axioms |> space_implode " ", chained) end
+          val loc =
+            case axioms of
+              [axiom] => find_first_in_vector axiom_names axiom General
+            | _ => General
+        in (axioms |> space_implode " ", loc) end
     in
       (conjecture_shape |> map (maps renumber_conjecture),
        seq |> map name_for_number |> Vector.fromList)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML	Thu Aug 26 09:23:21 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML	Thu Aug 26 10:42:06 2010 +0200
@@ -5,6 +5,8 @@
 
 signature SLEDGEHAMMER_FACT_FILTER =
 sig
+  datatype locality = General | Theory | Local | Chained
+
   type relevance_override =
     {add: Facts.ref list,
      del: Facts.ref list,
@@ -13,11 +15,11 @@
   val trace : bool Unsynchronized.ref
   val name_thm_pairs_from_ref :
     Proof.context -> unit Symtab.table -> thm list -> Facts.ref
-    -> ((unit -> string * bool) * (bool * thm)) list
+    -> ((string * locality) * thm) list
   val relevant_facts :
     bool -> real * real -> int -> bool -> relevance_override
     -> Proof.context * (thm list * 'a) -> term list -> term
-    -> ((string * bool) * thm) list
+    -> ((string * locality) * thm) list
 end;
 
 structure Sledgehammer_Fact_Filter : SLEDGEHAMMER_FACT_FILTER =
@@ -30,6 +32,8 @@
 
 val respect_no_atp = true
 
+datatype locality = General | Theory | Local | Chained
+
 type relevance_override =
   {add: Facts.ref list,
    del: Facts.ref list,
@@ -47,11 +51,11 @@
     val name = Facts.string_of_ref xref
     val multi = length ths > 1
   in
-    fold (fn th => fn (j, rest) =>
-             (j + 1, (fn () => (repair_name reserved multi j name,
-                                member Thm.eq_thm chained_ths th),
-                      (multi, th)) :: rest))
-         ths (1, [])
+    (ths, (1, []))
+    |-> fold (fn th => fn (j, rest) =>
+                 (j + 1, ((repair_name reserved multi j name,
+                          if member Thm.eq_thm chained_ths th then Chained
+                          else General), th) :: rest))
     |> snd
   end
 
@@ -245,19 +249,27 @@
 *)
 fun irrel_log n = Math.ln (Real.fromInt n + 19.0) / 6.4
 
+(* FUDGE *)
+val skolem_weight = 1.0
+val abs_weight = 2.0
+
 (* Computes a constant's weight, as determined by its frequency. *)
 val rel_weight = rel_log oo pseudoconst_freq match_pseudotypes
 fun irrel_weight const_tab (c as (s, _)) =
-  if String.isPrefix skolem_prefix s then 1.0
-  else if String.isPrefix abs_prefix s then 2.0
+  if String.isPrefix skolem_prefix s then skolem_weight
+  else if String.isPrefix abs_prefix s then abs_weight
   else irrel_log (pseudoconst_freq (match_pseudotypes o swap) const_tab c)
 (* TODO: experiment
 fun irrel_weight _ _ = 1.0
 *)
 
-val chained_bonus_factor = 2.0
+(* FUDGE *)
+fun locality_multiplier General = 1.0
+  | locality_multiplier Theory = 1.1
+  | locality_multiplier Local = 1.3
+  | locality_multiplier Chained = 2.0
 
-fun axiom_weight chained const_tab relevant_consts axiom_consts =
+fun axiom_weight loc const_tab relevant_consts axiom_consts =
   case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
                     ||> filter_out (pseudoconst_mem swap relevant_consts) of
     ([], []) => 0.0
@@ -265,7 +277,7 @@
   | (rel, irrel) =>
     let
       val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
-                       |> chained ? curry Real.* chained_bonus_factor
+                       |> curry Real.* (locality_multiplier loc)
       val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
       val res = rel_weight / (rel_weight + irrel_weight)
     in if Real.isFinite res then res else 0.0 end
@@ -294,7 +306,7 @@
                 |> pseudoconsts_of_term thy)
 
 type annotated_thm =
-  ((unit -> string * bool) * thm) * (string * pseudotype list) list
+  (((unit -> string) * locality) * thm) * (string * pseudotype list) list
 
 fun take_most_relevant max_max_imperfect max_relevant remaining_max
                        (candidates : (annotated_thm * real) list) =
@@ -315,12 +327,13 @@
                         Real.toString (#2 (hd accepts)));
     trace_msg (fn () => "Actually passed (" ^ Int.toString (length accepts) ^
         "): " ^ (accepts
-                 |> map (fn (((name, _), _), weight) =>
-                            fst (name ()) ^ " [" ^ Real.toString weight ^ "]")
+                 |> map (fn ((((name, _), _), _), weight) =>
+                            name () ^ " [" ^ Real.toString weight ^ "]")
                  |> commas));
     (accepts, more_rejects @ rejects)
   end
 
+(* FUDGE *)
 val threshold_divisor = 2.0
 val ridiculous_threshold = 0.1
 val max_max_imperfect_fudge_factor = 0.66
@@ -392,17 +405,17 @@
                       hopeless_rejects hopeful_rejects)
             end
           | relevant candidates rejects hopeless
-                     (((ax as ((name, th), axiom_consts)), cached_weight)
+                     (((ax as (((_, loc), th), axiom_consts)), cached_weight)
                       :: hopeful) =
             let
               val weight =
                 case cached_weight of
                   SOME w => w
-                | NONE => axiom_weight (snd (name ())) const_tab rel_const_tab
-                                       axiom_consts
+                | NONE => axiom_weight loc const_tab rel_const_tab axiom_consts
 (* TODO: experiment
-val _ = if String.isPrefix "lift.simps(3" (fst (name ())) then
-tracing ("*** " ^ (fst (name ())) ^ PolyML.makestring (debug_axiom_weight const_tab rel_const_tab axiom_consts))
+val name = fst (fst (fst ax)) ()
+val _ = if String.isPrefix "lift.simps(3" name then
+tracing ("*** " ^ name ^ PolyML.makestring (debug_axiom_weight const_tab rel_const_tab axiom_consts))
 else
 ()
 *)
@@ -570,10 +583,12 @@
 
 fun all_name_thms_pairs ctxt reserved full_types add_thms chained_ths =
   let
-    val is_chained = member Thm.eq_thm chained_ths
-    val global_facts = PureThy.facts_of (ProofContext.theory_of ctxt)
+    val thy = ProofContext.theory_of ctxt
+    val thy_prefix = Context.theory_name thy ^ Long_Name.separator
+    val global_facts = PureThy.facts_of thy
     val local_facts = ProofContext.facts_of ctxt
     val named_locals = local_facts |> Facts.dest_static []
+    val is_chained = member Thm.eq_thm chained_ths
     (* Unnamed, not chained formulas with schematic variables are omitted,
        because they are rejected by the backticks (`...`) parser for some
        reason. *)
@@ -585,7 +600,7 @@
                   |> map (pair "" o single)
     val full_space =
       Name_Space.merge (Facts.space_of global_facts, Facts.space_of local_facts)
-    fun add_valid_facts foldx facts =
+    fun add_facts global foldx facts =
       foldx (fn (name0, ths) =>
         if name0 <> "" andalso
            forall (not o member Thm.eq_thm add_thms) ths andalso
@@ -596,6 +611,10 @@
           I
         else
           let
+            val base_loc =
+              if not global then Local
+              else if String.isPrefix thy_prefix name0 then Theory
+              else General
             val multi = length ths > 1
             fun backquotify th =
               "`" ^ Print_Mode.setmp [Print_Mode.input]
@@ -614,23 +633,24 @@
                      not (member Thm.eq_thm add_thms th) then
                     rest
                   else
-                    (fn () =>
-                        (if name0 = "" then
-                           th |> backquotify
-                         else
-                           let
-                             val name1 = Facts.extern facts name0
-                             val name2 = Name_Space.extern full_space name0
-                           in
-                             case find_first check_thms [name1, name2, name0] of
-                               SOME name => repair_name reserved multi j name
-                             | NONE => ""
-                           end, is_chained th), (multi, th)) :: rest)) ths
+                    (((fn () =>
+                          if name0 = "" then
+                            th |> backquotify
+                          else
+                            let
+                              val name1 = Facts.extern facts name0
+                              val name2 = Name_Space.extern full_space name0
+                            in
+                              case find_first check_thms [name1, name2, name0] of
+                                SOME name => repair_name reserved multi j name
+                              | NONE => ""
+                            end), if is_chained th then Chained else base_loc),
+                      (multi, th)) :: rest)) ths
             #> snd
           end)
   in
-    [] |> add_valid_facts fold local_facts (unnamed_locals @ named_locals)
-       |> add_valid_facts Facts.fold_static global_facts global_facts
+    [] |> add_facts false fold local_facts (unnamed_locals @ named_locals)
+       |> add_facts true Facts.fold_static global_facts global_facts
   end
 
 (* The single-name theorems go after the multiple-name ones, so that single
@@ -653,7 +673,8 @@
     val reserved = reserved_isar_keyword_table ()
     val axioms =
       (if only then
-         maps (name_thm_pairs_from_ref ctxt reserved chained_ths) add
+         maps (map (fn ((name, loc), th) => ((K name, loc), (true, th)))
+               o name_thm_pairs_from_ref ctxt reserved chained_ths) add
        else
          all_name_thms_pairs ctxt reserved full_types add_thms chained_ths)
       |> name_thm_pairs ctxt (respect_no_atp andalso not only)
@@ -668,7 +689,7 @@
      else
        relevance_filter ctxt threshold0 decay max_relevant theory_relevant
                         relevance_override axioms (concl_t :: hyp_ts))
-    |> map (apfst (fn f => f ())) |> sort_wrt (fst o fst)
+    |> map (apfst (apfst (fn f => f ()))) |> sort_wrt (fst o fst)
   end
 
 end;
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_fact_minimize.ML	Thu Aug 26 09:23:21 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_fact_minimize.ML	Thu Aug 26 10:42:06 2010 +0200
@@ -7,11 +7,12 @@
 
 signature SLEDGEHAMMER_FACT_MINIMIZE =
 sig
+  type locality = Sledgehammer_Fact_Filter.locality
   type params = Sledgehammer.params
 
   val minimize_theorems :
-    params -> int -> int -> Proof.state -> ((string * bool) * thm list) list
-    -> ((string * bool) * thm list) list option * string
+    params -> int -> int -> Proof.state -> ((string * locality) * thm list) list
+    -> ((string * locality) * thm list) list option * string
   val run_minimize : params -> int -> Facts.ref list -> Proof.state -> unit
 end;
 
@@ -120,7 +121,7 @@
          val n = length min_thms
          val _ = priority (cat_lines
            ["Minimized: " ^ string_of_int n ^ " theorem" ^ plural_s n] ^
-            (case length (filter (snd o fst) min_thms) of
+            (case length (filter (curry (op =) Chained o snd o fst) min_thms) of
                0 => ""
              | n => " (including " ^ Int.toString n ^ " chained)") ^ ".")
        in
@@ -149,7 +150,7 @@
     val reserved = reserved_isar_keyword_table ()
     val chained_ths = #facts (Proof.goal state)
     val axioms =
-      maps (map (fn (name, (_, th)) => (name (), [th]))
+      maps (map (apsnd single)
             o name_thm_pairs_from_ref ctxt reserved chained_ths) refs
   in
     case subgoal_count state of
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_proof_reconstruct.ML	Thu Aug 26 09:23:21 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_proof_reconstruct.ML	Thu Aug 26 10:42:06 2010 +0200
@@ -8,19 +8,20 @@
 
 signature SLEDGEHAMMER_PROOF_RECONSTRUCT =
 sig
+  type locality = Sledgehammer_Fact_Filter.locality
   type minimize_command = string list -> string
 
   val metis_proof_text:
-    bool * minimize_command * string * (string * bool) vector * thm * int
-    -> string * (string * bool) list
+    bool * minimize_command * string * (string * locality) vector * thm * int
+    -> string * (string * locality) list
   val isar_proof_text:
     string Symtab.table * bool * int * Proof.context * int list list
-    -> bool * minimize_command * string * (string * bool) vector * thm * int
-    -> string * (string * bool) list
+    -> bool * minimize_command * string * (string * locality) vector * thm * int
+    -> string * (string * locality) list
   val proof_text:
     bool -> string Symtab.table * bool * int * Proof.context * int list list
-    -> bool * minimize_command * string * (string * bool) vector * thm * int
-    -> string * (string * bool) list
+    -> bool * minimize_command * string * (string * locality) vector * thm * int
+    -> string * (string * locality) list
 end;
 
 structure Sledgehammer_Proof_Reconstruct : SLEDGEHAMMER_PROOF_RECONSTRUCT =
@@ -578,7 +579,7 @@
           (case strip_prefix_and_unascii axiom_prefix (List.last rest) of
              SOME name =>
              if member (op =) rest "file" then
-               SOME (name, is_true_for axiom_names name)
+               SOME (name, find_first_in_vector axiom_names name General)
              else
                axiom_name_at_index num
            | NONE => axiom_name_at_index num)
@@ -624,8 +625,8 @@
 
 fun used_facts axiom_names =
   used_facts_in_atp_proof axiom_names
-  #> List.partition snd
-  #> pairself (sort_distinct (string_ord) o map fst)
+  #> List.partition (curry (op =) Chained o snd)
+  #> pairself (sort_distinct (string_ord o pairself fst))
 
 fun metis_proof_text (full_types, minimize_command, atp_proof, axiom_names,
                       goal, i) =
@@ -633,9 +634,9 @@
     val (chained_lemmas, other_lemmas) = used_facts axiom_names atp_proof
     val n = Logic.count_prems (prop_of goal)
   in
-    (metis_line full_types i n other_lemmas ^
-     minimize_line minimize_command (other_lemmas @ chained_lemmas),
-     map (rpair false) other_lemmas @ map (rpair true) chained_lemmas)
+    (metis_line full_types i n (map fst other_lemmas) ^
+     minimize_line minimize_command (map fst (other_lemmas @ chained_lemmas)),
+     other_lemmas @ chained_lemmas)
   end
 
 (** Isar proof construction and manipulation **)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_translate.ML	Thu Aug 26 09:23:21 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_translate.ML	Thu Aug 26 10:42:06 2010 +0200
@@ -18,8 +18,8 @@
   val tfrees_name : string
   val prepare_problem :
     Proof.context -> bool -> bool -> bool -> bool -> term list -> term
-    -> ((string * bool) * thm) list
-    -> string problem * string Symtab.table * int * (string * bool) vector
+    -> ((string * 'a) * thm) list
+    -> string problem * string Symtab.table * int * (string * 'a) vector
 end;
 
 structure Sledgehammer_Translate : SLEDGEHAMMER_TRANSLATE =
@@ -39,11 +39,11 @@
 (* Freshness almost guaranteed! *)
 val sledgehammer_weak_prefix = "Sledgehammer:"
 
-datatype fol_formula =
-  FOLFormula of {name: string,
-                 kind: kind,
-                 combformula: (name, combterm) formula,
-                 ctypes_sorts: typ list}
+type fol_formula =
+  {name: string,
+   kind: kind,
+   combformula: (name, combterm) formula,
+   ctypes_sorts: typ list}
 
 fun mk_anot phi = AConn (ANot, [phi])
 fun mk_aconn c phi1 phi2 = AConn (c, [phi1, phi2])
@@ -190,15 +190,14 @@
               |> kind <> Axiom ? freeze_term
     val (combformula, ctypes_sorts) = combformula_for_prop thy t []
   in
-    FOLFormula {name = name, combformula = combformula, kind = kind,
-                ctypes_sorts = ctypes_sorts}
+    {name = name, combformula = combformula, kind = kind,
+     ctypes_sorts = ctypes_sorts}
   end
 
-fun make_axiom ctxt presimp ((name, chained), th) =
+fun make_axiom ctxt presimp ((name, loc), th) =
   case make_formula ctxt presimp name Axiom (prop_of th) of
-    FOLFormula {combformula = AAtom (CombConst (("c_True", _), _, _)), ...} =>
-    NONE
-  | formula => SOME ((name, chained), formula)
+    {combformula = AAtom (CombConst (("c_True", _), _, _)), ...} => NONE
+  | formula => SOME ((name, loc), formula)
 fun make_conjecture ctxt ts =
   let val last = length ts - 1 in
     map2 (fn j => make_formula ctxt true (Int.toString j)
@@ -215,7 +214,7 @@
 fun count_combformula (AQuant (_, _, phi)) = count_combformula phi
   | count_combformula (AConn (_, phis)) = fold count_combformula phis
   | count_combformula (AAtom tm) = count_combterm tm
-fun count_fol_formula (FOLFormula {combformula, ...}) =
+fun count_fol_formula ({combformula, ...} : fol_formula) =
   count_combformula combformula
 
 val optional_helpers =
@@ -326,13 +325,13 @@
       | aux (AAtom tm) = AAtom (fo_term_for_combterm full_types tm)
   in aux end
 
-fun formula_for_axiom full_types (FOLFormula {combformula, ctypes_sorts, ...}) =
+fun formula_for_axiom full_types
+                      ({combformula, ctypes_sorts, ...} : fol_formula) =
   mk_ahorn (map (formula_for_fo_literal o fo_literal_for_type_literal)
                 (type_literals_for_types ctypes_sorts))
            (formula_for_combformula full_types combformula)
 
-fun problem_line_for_fact prefix full_types
-                          (formula as FOLFormula {name, kind, ...}) =
+fun problem_line_for_fact prefix full_types (formula as {name, kind, ...}) =
   Fof (prefix ^ ascii_of name, kind, formula_for_axiom full_types formula)
 
 fun problem_line_for_class_rel_clause (ClassRelClause {name, subclass,
@@ -357,11 +356,11 @@
                      (fo_literal_for_arity_literal conclLit)))
 
 fun problem_line_for_conjecture full_types
-                                (FOLFormula {name, kind, combformula, ...}) =
+                                ({name, kind, combformula, ...} : fol_formula) =
   Fof (conjecture_prefix ^ name, kind,
        formula_for_combformula full_types combformula)
 
-fun free_type_literals_for_conjecture (FOLFormula {ctypes_sorts, ...}) =
+fun free_type_literals_for_conjecture ({ctypes_sorts, ...} : fol_formula) =
   map fo_literal_for_type_literal (type_literals_for_types ctypes_sorts)
 
 fun problem_line_for_free_type lit =
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_util.ML	Thu Aug 26 09:23:21 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_util.ML	Thu Aug 26 10:42:06 2010 +0200
@@ -6,7 +6,7 @@
 
 signature SLEDGEHAMMER_UTIL =
 sig
-  val is_true_for : (string * bool) vector -> string -> bool
+  val find_first_in_vector : (''a * 'b) vector -> ''a -> 'b -> 'b
   val plural_s : int -> string
   val serial_commas : string -> string list -> string list
   val simplify_spaces : string -> string
@@ -29,8 +29,9 @@
 structure Sledgehammer_Util : SLEDGEHAMMER_UTIL =
 struct
 
-fun is_true_for v s =
-  Vector.foldl (fn ((s', b'), b) => if s' = s then b' else b) false v
+fun find_first_in_vector vec key default =
+  Vector.foldl (fn ((key', value'), value) =>
+                   if key' = key then value' else value) default vec
 
 fun plural_s n = if n = 1 then "" else "s"