take proximity into account for MaSh + fix a debilitating bug in feature generation
author blanchet
Wed Dec 05 13:25:06 2012 +0100 (2012-12-05)
changeset 50383 4274b25ff4e7
parent 50382 cb564ff43c28
child 50384 b9b967da28e9
take proximity into account for MaSh + fix a debilitating bug in feature generation
src/HOL/TPTP/mash_eval.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML
     1.1 --- a/src/HOL/TPTP/mash_eval.ML	Wed Dec 05 13:25:06 2012 +0100
     1.2 +++ b/src/HOL/TPTP/mash_eval.ML	Wed Dec 05 13:25:06 2012 +0100
     1.3 @@ -79,7 +79,7 @@
     1.4                slack_max_facts NONE hyp_ts concl_t facts
     1.5            |> Sledgehammer_MePo.weight_mepo_facts
     1.6          val mash_facts = suggested_facts suggs facts
     1.7 -        val mess = [(mepo_facts, []), (mash_facts, [])]
     1.8 +        val mess = [(0.5, (mepo_facts, [])), (0.5, (mash_facts, []))]
     1.9          val mesh_facts = mesh_facts slack_max_facts mess
    1.10          val isar_facts = suggested_facts (map (rpair 1.0) isar_deps) facts
    1.11          fun prove ok heading get facts =
     2.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Dec 05 13:25:06 2012 +0100
     2.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Dec 05 13:25:06 2012 +0100
     2.3 @@ -44,7 +44,8 @@
     2.4    val suggested_facts :
     2.5      (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
     2.6    val mesh_facts :
     2.7 -    int -> ((('a * thm) * real) list * ('a * thm) list) list -> ('a * thm) list
     2.8 +    int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list
     2.9 +    -> ('a * thm) list
    2.10    val theory_ord : theory * theory -> order
    2.11    val thm_ord : thm * thm -> order
    2.12    val goal_of_thm : theory -> thm -> thm
    2.13 @@ -59,8 +60,8 @@
    2.14      -> thm -> bool * string list option
    2.15    val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list
    2.16    val mash_suggested_facts :
    2.17 -    Proof.context -> params -> string -> int -> term list -> term
    2.18 -    -> fact list -> fact list * fact list
    2.19 +    Proof.context -> params -> string -> int -> term list -> term -> fact list
    2.20 +    -> fact list
    2.21    val mash_learn_proof :
    2.22      Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    2.23      -> unit
    2.24 @@ -298,7 +299,7 @@
    2.25  
    2.26  local
    2.27  
    2.28 -val version = "*** MaSh version 20121204a ***"
    2.29 +val version = "*** MaSh version 20121205a ***"
    2.30  
    2.31  exception Too_New of unit
    2.32  
    2.33 @@ -425,30 +426,43 @@
    2.34        Symtab.lookup tab name |> Option.map (rpair weight)
    2.35    in map_filter find_sugg suggs end
    2.36  
    2.37 -fun sum_avg [] = 0
    2.38 -  | sum_avg xs =
    2.39 +fun scaled_avg [] = 0
    2.40 +  | scaled_avg xs =
    2.41      Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
    2.42  
    2.43 -fun normalize_scores [] = []
    2.44 -  | normalize_scores ((fact, score) :: tail) =
    2.45 -    (fact, 1.0) :: map (apsnd (curry Real.* (1.0 / score))) tail
    2.46 +fun avg [] = 0.0
    2.47 +  | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
    2.48  
    2.49 -fun mesh_facts max_facts [(sels, unks)] =
    2.50 +fun normalize_scores _ [] = []
    2.51 +  | normalize_scores max_facts xs =
    2.52 +    let val avg = avg (map snd (take max_facts xs)) in
    2.53 +      map (apsnd (curry Real.* (1.0 / avg))) xs
    2.54 +    end
    2.55 +
    2.56 +fun mesh_facts max_facts [(_, (sels, unks))] =
    2.57      map fst (take max_facts sels) @ take (max_facts - length sels) unks
    2.58    | mesh_facts max_facts mess =
    2.59      let
    2.60 -      val mess = mess |> map (apfst (normalize_scores #> `length))
    2.61 +      val mess =
    2.62 +        mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
    2.63        val fact_eq = Thm.eq_thm o pairself snd
    2.64 -      fun score_at sels = try (nth sels) #> Option.map snd
    2.65 -      fun score_in fact ((sel_len, sels), unks) =
    2.66 -        case find_index (curry fact_eq fact o fst) sels of
    2.67 -          ~1 => (case find_index (curry fact_eq fact) unks of
    2.68 -                   ~1 => score_at sels sel_len
    2.69 -                 | _ => NONE)
    2.70 -        | rank => score_at sels rank
    2.71 -      fun weight_of fact = mess |> map_filter (score_in fact) |> sum_avg
    2.72 +      fun score_in fact (global_weight, ((sel_len, sels), unks)) =
    2.73 +        let
    2.74 +          fun score_at j =
    2.75 +            case try (nth sels) j of
    2.76 +              SOME (_, score) => SOME (global_weight * score)
    2.77 +            | NONE => NONE
    2.78 +        in
    2.79 +          case find_index (curry fact_eq fact o fst) sels of
    2.80 +            ~1 => (case find_index (curry fact_eq fact) unks of
    2.81 +                     ~1 => score_at sel_len
    2.82 +                   | _ => NONE)
    2.83 +          | rank => score_at rank
    2.84 +        end
    2.85 +      fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
    2.86        val facts =
    2.87 -        fold (union fact_eq o map fst o take max_facts o snd o fst) mess []
    2.88 +        fold (union fact_eq o map fst o take max_facts o snd o fst o snd) mess
    2.89 +             []
    2.90      in
    2.91        facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
    2.92              |> map snd |> take max_facts
    2.93 @@ -459,7 +473,7 @@
    2.94  fun type_feature_of s = ("t" ^ s, 1.0 (* FUDGE *))
    2.95  fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *))
    2.96  fun status_feature_of status = (string_of_status status, 1.0 (* FUDGE *))
    2.97 -val local_feature = ("local", 20.0 (* FUDGE *))
    2.98 +val local_feature = ("local", 1.0 (* FUDGE *))
    2.99  val lams_feature = ("lams", 1.0 (* FUDGE *))
   2.100  val skos_feature = ("skos", 1.0 (* FUDGE *))
   2.101  
   2.102 @@ -531,7 +545,7 @@
   2.103          let
   2.104            val ps = patternify_term (u :: args) depth t
   2.105            val qs = "" :: patternify_term [] (depth - 1) u
   2.106 -        in map_product (fn p => fn "" => p | q => "(" ^ q ^ ")") ps qs end
   2.107 +        in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
   2.108        | patternify_term _ _ _ = []
   2.109      val add_term_pattern =
   2.110        union (op = o pairself fst) o map term_feature_of oo patternify_term []
   2.111 @@ -692,24 +706,22 @@
   2.112                                        (Graph.imm_preds fact_G new) news))
   2.113    in find_maxes Symtab.empty ([], Graph.maximals fact_G) end
   2.114  
   2.115 -(* Generate more suggestions than requested, because some might be thrown out
   2.116 -   later for various reasons and "meshing" gives better results with some
   2.117 -   slack. *)
   2.118 -fun max_suggs_of max_facts = max_facts + Int.min (50, max_facts)
   2.119 -
   2.120  fun is_fact_in_graph fact_G (_, th) =
   2.121    can (Graph.get_node fact_G) (nickname_of th)
   2.122  
   2.123 -fun interleave 0 _ _ = []
   2.124 -  | interleave n [] ys = take n ys
   2.125 -  | interleave n xs [] = take n xs
   2.126 -  | interleave 1 (x :: _) _ = [x]
   2.127 -  | interleave n (x :: xs) (y :: ys) = x :: y :: interleave (n - 2) xs ys
   2.128 -
   2.129  (* factor that controls whether unknown global facts should be included *)
   2.130  val include_unk_global_factor = 15
   2.131  
   2.132 -val weight_mash_facts = weight_mepo_facts (* use MePo weights for now *)
   2.133 +(* use MePo weights for now *)
   2.134 +val weight_raw_mash_facts = weight_mepo_facts
   2.135 +val weight_mash_facts = weight_raw_mash_facts
   2.136 +
   2.137 +(* FUDGE *)
   2.138 +fun weight_of_proximity_fact rank =
   2.139 +  Math.pow (1.3, 15.5 - 0.05 * Real.fromInt rank) + 15.0
   2.140 +
   2.141 +fun weight_proximity_facts facts =
   2.142 +  facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
   2.143  
   2.144  fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   2.145                           concl_t facts =
   2.146 @@ -725,27 +737,23 @@
   2.147                val feats =
   2.148                  features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   2.149              in
   2.150 -              (fact_G, mash_QUERY ctxt overlord (max_suggs_of max_facts)
   2.151 -                                  (parents, feats))
   2.152 +              (fact_G, mash_QUERY ctxt overlord max_facts (parents, feats))
   2.153              end)
   2.154      val (chained, unchained) =
   2.155        List.partition (fn ((_, (scope, _)), _) => scope = Chained) facts
   2.156 -    val sels =
   2.157 +    val raw_mash =
   2.158        facts |> suggested_facts suggs
   2.159              (* The weights currently returned by "mash.py" are too spaced out to
   2.160                 make any sense. *)
   2.161              |> map fst
   2.162 -            |> filter_out (member (Thm.eq_thm_prop o pairself snd) chained)
   2.163 -    val (unk_global, unk_local) =
   2.164 -      unchained |> filter_out (is_fact_in_graph fact_G)
   2.165 -                |> List.partition (fn ((_, (scope, _)), _) => scope = Global)
   2.166 -    val (small_unk_global, big_unk_global) =
   2.167 -      ([], unk_global)
   2.168 -      |> include_unk_global_factor * length unk_global <= max_facts ? swap
   2.169 -  in
   2.170 -    (interleave max_facts (chained @ unk_local @ small_unk_global) sels,
   2.171 -     big_unk_global)
   2.172 -  end
   2.173 +    val proximity =
   2.174 +      chained @ (facts |> subtract (Thm.eq_thm_prop o pairself snd) chained
   2.175 +                       |> sort (thm_ord o pairself snd o swap))
   2.176 +    val unknown = facts |> filter_out (is_fact_in_graph fact_G)
   2.177 +    val mess =
   2.178 +      [(0.667 (* FUDGE *), (weight_raw_mash_facts raw_mash, unknown)),
   2.179 +       (0.333 (* FUDGE *), (weight_proximity_facts proximity, []))]
   2.180 +  in mesh_facts max_facts mess end
   2.181  
   2.182  fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   2.183    let
   2.184 @@ -995,6 +1003,10 @@
   2.185  fun is_mash_enabled () = (getenv "MASH" = "yes")
   2.186  fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
   2.187  
   2.188 +(* Generate more suggestions than requested, because some might be thrown out
   2.189 +   later for various reasons. *)
   2.190 +fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
   2.191 +
   2.192  (* The threshold should be large enough so that MaSh doesn't kick in for Auto
   2.193     Sledgehammer and Try. *)
   2.194  val min_secs_for_learning = 15
   2.195 @@ -1040,11 +1052,12 @@
   2.196                               facts
   2.197          |> weight_mepo_facts
   2.198        fun mash () =
   2.199 -        mash_suggested_facts ctxt params prover max_facts hyp_ts concl_t facts
   2.200 -        |>> weight_mash_facts
   2.201 +        mash_suggested_facts ctxt params prover (generous_max_facts max_facts)
   2.202 +            hyp_ts concl_t facts
   2.203 +        |> weight_mash_facts
   2.204        val mess =
   2.205 -        [] |> (if fact_filter <> mashN then cons (mepo (), []) else I)
   2.206 -           |> (if fact_filter <> mepoN then cons (mash ()) else I)
   2.207 +        [] |> (if fact_filter <> mashN then cons (0.5, (mepo (), [])) else I)
   2.208 +           |> (if fact_filter <> mepoN then cons (0.5, (mash (), [])) else I)
   2.209      in
   2.210        mesh_facts max_facts mess
   2.211        |> not (null add_ths) ? prepend_facts add_ths
     3.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML	Wed Dec 05 13:25:06 2012 +0100
     3.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML	Wed Dec 05 13:25:06 2012 +0100
     3.3 @@ -510,10 +510,12 @@
     3.4    end
     3.5  
     3.6  (* Ad hoc score function roughly based on Blanchette's Ringberg 2011 data. *)
     3.7 -fun weight_of_fact rank = Math.pow (1.5, 15.5 - 0.05 * Real.fromInt rank) + 15.0
     3.8 +(* FUDGE *)
     3.9 +fun weight_of_mepo_fact rank =
    3.10 +  Math.pow (1.5, 15.5 - 0.05 * Real.fromInt rank) + 15.0
    3.11  
    3.12  fun weight_mepo_facts facts =
    3.13 -  facts ~~ map weight_of_fact (0 upto length facts - 1)
    3.14 +  facts ~~ map weight_of_mepo_fact (0 upto length facts - 1)
    3.15  
    3.16  fun mepo_suggested_facts ctxt
    3.17          ({fact_thresholds = (thres0, thres1), ...} : params) prover