# HG changeset patch # User blanchet # Date 1354710306 -3600 # Node ID 4274b25ff4e7243273029db49bc4fdc0e69c3d2c # Parent cb564ff43c28bfb5275b16f68fb29451384e15cc take proximity into account for MaSh + fix a debilitating bug in feature generation diff -r cb564ff43c28 -r 4274b25ff4e7 src/HOL/TPTP/mash_eval.ML --- a/src/HOL/TPTP/mash_eval.ML Wed Dec 05 13:25:06 2012 +0100 +++ b/src/HOL/TPTP/mash_eval.ML Wed Dec 05 13:25:06 2012 +0100 @@ -79,7 +79,7 @@ slack_max_facts NONE hyp_ts concl_t facts |> Sledgehammer_MePo.weight_mepo_facts val mash_facts = suggested_facts suggs facts - val mess = [(mepo_facts, []), (mash_facts, [])] + val mess = [(0.5, (mepo_facts, [])), (0.5, (mash_facts, []))] val mesh_facts = mesh_facts slack_max_facts mess val isar_facts = suggested_facts (map (rpair 1.0) isar_deps) facts fun prove ok heading get facts = diff -r cb564ff43c28 -r 4274b25ff4e7 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed Dec 05 13:25:06 2012 +0100 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed Dec 05 13:25:06 2012 +0100 @@ -44,7 +44,8 @@ val suggested_facts : (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list val mesh_facts : - int -> ((('a * thm) * real) list * ('a * thm) list) list -> ('a * thm) list + int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list + -> ('a * thm) list val theory_ord : theory * theory -> order val thm_ord : thm * thm -> order val goal_of_thm : theory -> thm -> thm @@ -59,8 +60,8 @@ -> thm -> bool * string list option val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list val mash_suggested_facts : - Proof.context -> params -> string -> int -> term list -> term - -> fact list -> fact list * fact list + Proof.context -> params -> string -> int -> term list -> term -> fact list + -> fact list val mash_learn_proof : Proof.context -> params -> string -> term -> ('a * thm) list -> thm list -> unit @@ -298,7 +299,7 @@ local -val version = "*** MaSh version 20121204a ***" +val version = "*** MaSh version 20121205a ***" exception Too_New of unit @@ -425,30 +426,43 @@ Symtab.lookup tab name |> Option.map (rpair weight) in map_filter find_sugg suggs end -fun sum_avg [] = 0 - | sum_avg xs = +fun scaled_avg [] = 0 + | scaled_avg xs = Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs -fun normalize_scores [] = [] - | normalize_scores ((fact, score) :: tail) = - (fact, 1.0) :: map (apsnd (curry Real.* (1.0 / score))) tail +fun avg [] = 0.0 + | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs) -fun mesh_facts max_facts [(sels, unks)] = +fun normalize_scores _ [] = [] + | normalize_scores max_facts xs = + let val avg = avg (map snd (take max_facts xs)) in + map (apsnd (curry Real.* (1.0 / avg))) xs + end + +fun mesh_facts max_facts [(_, (sels, unks))] = map fst (take max_facts sels) @ take (max_facts - length sels) unks | mesh_facts max_facts mess = let - val mess = mess |> map (apfst (normalize_scores #> `length)) + val mess = + mess |> map (apsnd (apfst (normalize_scores max_facts #> `length))) val fact_eq = Thm.eq_thm o pairself snd - fun score_at sels = try (nth sels) #> Option.map snd - fun score_in fact ((sel_len, sels), unks) = - case find_index (curry fact_eq fact o fst) sels of - ~1 => (case find_index (curry fact_eq fact) unks of - ~1 => score_at sels sel_len - | _ => NONE) - | rank => score_at sels rank - fun weight_of fact = mess |> map_filter (score_in fact) |> sum_avg + fun score_in fact (global_weight, ((sel_len, sels), unks)) = + let + fun score_at j = + case try (nth sels) j of + SOME (_, score) => SOME (global_weight * score) + | NONE => NONE + in + case find_index (curry fact_eq fact o fst) sels of + ~1 => (case find_index (curry fact_eq fact) unks of + ~1 => score_at sel_len + | _ => NONE) + | rank => score_at rank + end + fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg val facts = - fold (union fact_eq o map fst o take max_facts o snd o fst) mess [] + fold (union fact_eq o map fst o take max_facts o snd o fst o snd) mess + [] in facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst) |> map snd |> take max_facts @@ -459,7 +473,7 @@ fun type_feature_of s = ("t" ^ s, 1.0 (* FUDGE *)) fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *)) fun status_feature_of status = (string_of_status status, 1.0 (* FUDGE *)) -val local_feature = ("local", 20.0 (* FUDGE *)) +val local_feature = ("local", 1.0 (* FUDGE *)) val lams_feature = ("lams", 1.0 (* FUDGE *)) val skos_feature = ("skos", 1.0 (* FUDGE *)) @@ -531,7 +545,7 @@ let val ps = patternify_term (u :: args) depth t val qs = "" :: patternify_term [] (depth - 1) u - in map_product (fn p => fn "" => p | q => "(" ^ q ^ ")") ps qs end + in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end | patternify_term _ _ _ = [] val add_term_pattern = union (op = o pairself fst) o map term_feature_of oo patternify_term [] @@ -692,24 +706,22 @@ (Graph.imm_preds fact_G new) news)) in find_maxes Symtab.empty ([], Graph.maximals fact_G) end -(* Generate more suggestions than requested, because some might be thrown out - later for various reasons and "meshing" gives better results with some - slack. *) -fun max_suggs_of max_facts = max_facts + Int.min (50, max_facts) - fun is_fact_in_graph fact_G (_, th) = can (Graph.get_node fact_G) (nickname_of th) -fun interleave 0 _ _ = [] - | interleave n [] ys = take n ys - | interleave n xs [] = take n xs - | interleave 1 (x :: _) _ = [x] - | interleave n (x :: xs) (y :: ys) = x :: y :: interleave (n - 2) xs ys - (* factor that controls whether unknown global facts should be included *) val include_unk_global_factor = 15 -val weight_mash_facts = weight_mepo_facts (* use MePo weights for now *) +(* use MePo weights for now *) +val weight_raw_mash_facts = weight_mepo_facts +val weight_mash_facts = weight_raw_mash_facts + +(* FUDGE *) +fun weight_of_proximity_fact rank = + Math.pow (1.3, 15.5 - 0.05 * Real.fromInt rank) + 15.0 + +fun weight_proximity_facts facts = + facts ~~ map weight_of_proximity_fact (0 upto length facts - 1) fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts concl_t facts = @@ -725,27 +737,23 @@ val feats = features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts) in - (fact_G, mash_QUERY ctxt overlord (max_suggs_of max_facts) - (parents, feats)) + (fact_G, mash_QUERY ctxt overlord max_facts (parents, feats)) end) val (chained, unchained) = List.partition (fn ((_, (scope, _)), _) => scope = Chained) facts - val sels = + val raw_mash = facts |> suggested_facts suggs (* The weights currently returned by "mash.py" are too spaced out to make any sense. *) |> map fst - |> filter_out (member (Thm.eq_thm_prop o pairself snd) chained) - val (unk_global, unk_local) = - unchained |> filter_out (is_fact_in_graph fact_G) - |> List.partition (fn ((_, (scope, _)), _) => scope = Global) - val (small_unk_global, big_unk_global) = - ([], unk_global) - |> include_unk_global_factor * length unk_global <= max_facts ? swap - in - (interleave max_facts (chained @ unk_local @ small_unk_global) sels, - big_unk_global) - end + val proximity = + chained @ (facts |> subtract (Thm.eq_thm_prop o pairself snd) chained + |> sort (thm_ord o pairself snd o swap)) + val unknown = facts |> filter_out (is_fact_in_graph fact_G) + val mess = + [(0.667 (* FUDGE *), (weight_raw_mash_facts raw_mash, unknown)), + (0.333 (* FUDGE *), (weight_proximity_facts proximity, []))] + in mesh_facts max_facts mess end fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) = let @@ -995,6 +1003,10 @@ fun is_mash_enabled () = (getenv "MASH" = "yes") fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt))) +(* Generate more suggestions than requested, because some might be thrown out + later for various reasons. *) +fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts) + (* The threshold should be large enough so that MaSh doesn't kick in for Auto Sledgehammer and Try. *) val min_secs_for_learning = 15 @@ -1040,11 +1052,12 @@ facts |> weight_mepo_facts fun mash () = - mash_suggested_facts ctxt params prover max_facts hyp_ts concl_t facts - |>> weight_mash_facts + mash_suggested_facts ctxt params prover (generous_max_facts max_facts) + hyp_ts concl_t facts + |> weight_mash_facts val mess = - [] |> (if fact_filter <> mashN then cons (mepo (), []) else I) - |> (if fact_filter <> mepoN then cons (mash ()) else I) + [] |> (if fact_filter <> mashN then cons (0.5, (mepo (), [])) else I) + |> (if fact_filter <> mepoN then cons (0.5, (mash (), [])) else I) in mesh_facts max_facts mess |> not (null add_ths) ? prepend_facts add_ths diff -r cb564ff43c28 -r 4274b25ff4e7 src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML Wed Dec 05 13:25:06 2012 +0100 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML Wed Dec 05 13:25:06 2012 +0100 @@ -510,10 +510,12 @@ end (* Ad hoc score function roughly based on Blanchette's Ringberg 2011 data. *) -fun weight_of_fact rank = Math.pow (1.5, 15.5 - 0.05 * Real.fromInt rank) + 15.0 +(* FUDGE *) +fun weight_of_mepo_fact rank = + Math.pow (1.5, 15.5 - 0.05 * Real.fromInt rank) + 15.0 fun weight_mepo_facts facts = - facts ~~ map weight_of_fact (0 upto length facts - 1) + facts ~~ map weight_of_mepo_fact (0 upto length facts - 1) fun mepo_suggested_facts ctxt ({fact_thresholds = (thres0, thres1), ...} : params) prover