take proximity into account for MaSh + fix a debilitating bug in feature generation
author blanchet
Wed, 05 Dec 2012 13:25:06 +0100
changeset 50383 4274b25ff4e7
parent 50382 cb564ff43c28
child 50384 b9b967da28e9
take proximity into account for MaSh + fix a debilitating bug in feature generation
src/HOL/TPTP/mash_eval.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML
--- a/src/HOL/TPTP/mash_eval.ML	Wed Dec 05 13:25:06 2012 +0100
+++ b/src/HOL/TPTP/mash_eval.ML	Wed Dec 05 13:25:06 2012 +0100
@@ -79,7 +79,7 @@
               slack_max_facts NONE hyp_ts concl_t facts
           |> Sledgehammer_MePo.weight_mepo_facts
         val mash_facts = suggested_facts suggs facts
-        val mess = [(mepo_facts, []), (mash_facts, [])]
+        val mess = [(0.5, (mepo_facts, [])), (0.5, (mash_facts, []))]
         val mesh_facts = mesh_facts slack_max_facts mess
         val isar_facts = suggested_facts (map (rpair 1.0) isar_deps) facts
         fun prove ok heading get facts =
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Dec 05 13:25:06 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Dec 05 13:25:06 2012 +0100
@@ -44,7 +44,8 @@
   val suggested_facts :
     (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
   val mesh_facts :
-    int -> ((('a * thm) * real) list * ('a * thm) list) list -> ('a * thm) list
+    int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list
+    -> ('a * thm) list
   val theory_ord : theory * theory -> order
   val thm_ord : thm * thm -> order
   val goal_of_thm : theory -> thm -> thm
@@ -59,8 +60,8 @@
     -> thm -> bool * string list option
   val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list
   val mash_suggested_facts :
-    Proof.context -> params -> string -> int -> term list -> term
-    -> fact list -> fact list * fact list
+    Proof.context -> params -> string -> int -> term list -> term -> fact list
+    -> fact list
   val mash_learn_proof :
     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
     -> unit
@@ -298,7 +299,7 @@
 
 local
 
-val version = "*** MaSh version 20121204a ***"
+val version = "*** MaSh version 20121205a ***"
 
 exception Too_New of unit
 
@@ -425,30 +426,43 @@
       Symtab.lookup tab name |> Option.map (rpair weight)
   in map_filter find_sugg suggs end
 
-fun sum_avg [] = 0
-  | sum_avg xs =
+fun scaled_avg [] = 0
+  | scaled_avg xs =
     Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
 
-fun normalize_scores [] = []
-  | normalize_scores ((fact, score) :: tail) =
-    (fact, 1.0) :: map (apsnd (curry Real.* (1.0 / score))) tail
+fun avg [] = 0.0
+  | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
 
-fun mesh_facts max_facts [(sels, unks)] =
+fun normalize_scores _ [] = []
+  | normalize_scores max_facts xs =
+    let val avg = avg (map snd (take max_facts xs)) in
+      map (apsnd (curry Real.* (1.0 / avg))) xs
+    end
+
+fun mesh_facts max_facts [(_, (sels, unks))] =
     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   | mesh_facts max_facts mess =
     let
-      val mess = mess |> map (apfst (normalize_scores #> `length))
+      val mess =
+        mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
       val fact_eq = Thm.eq_thm o pairself snd
-      fun score_at sels = try (nth sels) #> Option.map snd
-      fun score_in fact ((sel_len, sels), unks) =
-        case find_index (curry fact_eq fact o fst) sels of
-          ~1 => (case find_index (curry fact_eq fact) unks of
-                   ~1 => score_at sels sel_len
-                 | _ => NONE)
-        | rank => score_at sels rank
-      fun weight_of fact = mess |> map_filter (score_in fact) |> sum_avg
+      fun score_in fact (global_weight, ((sel_len, sels), unks)) =
+        let
+          fun score_at j =
+            case try (nth sels) j of
+              SOME (_, score) => SOME (global_weight * score)
+            | NONE => NONE
+        in
+          case find_index (curry fact_eq fact o fst) sels of
+            ~1 => (case find_index (curry fact_eq fact) unks of
+                     ~1 => score_at sel_len
+                   | _ => NONE)
+          | rank => score_at rank
+        end
+      fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
       val facts =
-        fold (union fact_eq o map fst o take max_facts o snd o fst) mess []
+        fold (union fact_eq o map fst o take max_facts o snd o fst o snd) mess
+             []
     in
       facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
             |> map snd |> take max_facts
@@ -459,7 +473,7 @@
 fun type_feature_of s = ("t" ^ s, 1.0 (* FUDGE *))
 fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *))
 fun status_feature_of status = (string_of_status status, 1.0 (* FUDGE *))
-val local_feature = ("local", 20.0 (* FUDGE *))
+val local_feature = ("local", 1.0 (* FUDGE *))
 val lams_feature = ("lams", 1.0 (* FUDGE *))
 val skos_feature = ("skos", 1.0 (* FUDGE *))
 
@@ -531,7 +545,7 @@
         let
           val ps = patternify_term (u :: args) depth t
           val qs = "" :: patternify_term [] (depth - 1) u
-        in map_product (fn p => fn "" => p | q => "(" ^ q ^ ")") ps qs end
+        in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
       | patternify_term _ _ _ = []
     val add_term_pattern =
       union (op = o pairself fst) o map term_feature_of oo patternify_term []
@@ -692,24 +706,22 @@
                                       (Graph.imm_preds fact_G new) news))
   in find_maxes Symtab.empty ([], Graph.maximals fact_G) end
 
-(* Generate more suggestions than requested, because some might be thrown out
-   later for various reasons and "meshing" gives better results with some
-   slack. *)
-fun max_suggs_of max_facts = max_facts + Int.min (50, max_facts)
-
 fun is_fact_in_graph fact_G (_, th) =
   can (Graph.get_node fact_G) (nickname_of th)
 
-fun interleave 0 _ _ = []
-  | interleave n [] ys = take n ys
-  | interleave n xs [] = take n xs
-  | interleave 1 (x :: _) _ = [x]
-  | interleave n (x :: xs) (y :: ys) = x :: y :: interleave (n - 2) xs ys
-
 (* factor that controls whether unknown global facts should be included *)
 val include_unk_global_factor = 15
 
-val weight_mash_facts = weight_mepo_facts (* use MePo weights for now *)
+(* use MePo weights for now *)
+val weight_raw_mash_facts = weight_mepo_facts
+val weight_mash_facts = weight_raw_mash_facts
+
+(* FUDGE *)
+fun weight_of_proximity_fact rank =
+  Math.pow (1.3, 15.5 - 0.05 * Real.fromInt rank) + 15.0
+
+fun weight_proximity_facts facts =
+  facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
 
 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
                          concl_t facts =
@@ -725,27 +737,23 @@
               val feats =
                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
             in
-              (fact_G, mash_QUERY ctxt overlord (max_suggs_of max_facts)
-                                  (parents, feats))
+              (fact_G, mash_QUERY ctxt overlord max_facts (parents, feats))
             end)
     val (chained, unchained) =
       List.partition (fn ((_, (scope, _)), _) => scope = Chained) facts
-    val sels =
+    val raw_mash =
       facts |> suggested_facts suggs
             (* The weights currently returned by "mash.py" are too spaced out to
                make any sense. *)
             |> map fst
-            |> filter_out (member (Thm.eq_thm_prop o pairself snd) chained)
-    val (unk_global, unk_local) =
-      unchained |> filter_out (is_fact_in_graph fact_G)
-                |> List.partition (fn ((_, (scope, _)), _) => scope = Global)
-    val (small_unk_global, big_unk_global) =
-      ([], unk_global)
-      |> include_unk_global_factor * length unk_global <= max_facts ? swap
-  in
-    (interleave max_facts (chained @ unk_local @ small_unk_global) sels,
-     big_unk_global)
-  end
+    val proximity =
+      chained @ (facts |> subtract (Thm.eq_thm_prop o pairself snd) chained
+                       |> sort (thm_ord o pairself snd o swap))
+    val unknown = facts |> filter_out (is_fact_in_graph fact_G)
+    val mess =
+      [(0.667 (* FUDGE *), (weight_raw_mash_facts raw_mash, unknown)),
+       (0.333 (* FUDGE *), (weight_proximity_facts proximity, []))]
+  in mesh_facts max_facts mess end
 
 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   let
@@ -995,6 +1003,10 @@
 fun is_mash_enabled () = (getenv "MASH" = "yes")
 fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
 
+(* Generate more suggestions than requested, because some might be thrown out
+   later for various reasons. *)
+fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
+
 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
    Sledgehammer and Try. *)
 val min_secs_for_learning = 15
@@ -1040,11 +1052,12 @@
                              facts
         |> weight_mepo_facts
       fun mash () =
-        mash_suggested_facts ctxt params prover max_facts hyp_ts concl_t facts
-        |>> weight_mash_facts
+        mash_suggested_facts ctxt params prover (generous_max_facts max_facts)
+            hyp_ts concl_t facts
+        |> weight_mash_facts
       val mess =
-        [] |> (if fact_filter <> mashN then cons (mepo (), []) else I)
-           |> (if fact_filter <> mepoN then cons (mash ()) else I)
+        [] |> (if fact_filter <> mashN then cons (0.5, (mepo (), [])) else I)
+           |> (if fact_filter <> mepoN then cons (0.5, (mash (), [])) else I)
     in
       mesh_facts max_facts mess
       |> not (null add_ths) ? prepend_facts add_ths
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML	Wed Dec 05 13:25:06 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML	Wed Dec 05 13:25:06 2012 +0100
@@ -510,10 +510,12 @@
   end
 
 (* Ad hoc score function roughly based on Blanchette's Ringberg 2011 data. *)
-fun weight_of_fact rank = Math.pow (1.5, 15.5 - 0.05 * Real.fromInt rank) + 15.0
+(* FUDGE *)
+fun weight_of_mepo_fact rank =
+  Math.pow (1.5, 15.5 - 0.05 * Real.fromInt rank) + 15.0
 
 fun weight_mepo_facts facts =
-  facts ~~ map weight_of_fact (0 upto length facts - 1)
+  facts ~~ map weight_of_mepo_fact (0 upto length facts - 1)
 
 fun mepo_suggested_facts ctxt
         ({fact_thresholds = (thres0, thres1), ...} : params) prover