use correct weights in MeSh driver
authorblanchet
Thu, 17 Jan 2013 23:29:22 +0100
changeset 50965 7a7d1418301e
parent 50964 2a990baa09af
child 50966 b85cb3049df9
use correct weights in MeSh driver
src/HOL/TPTP/mash_eval.ML
src/HOL/TPTP/mash_export.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/TPTP/mash_eval.ML	Thu Jan 17 23:29:17 2013 +0100
+++ b/src/HOL/TPTP/mash_eval.ML	Thu Jan 17 23:29:22 2013 +0100
@@ -93,11 +93,12 @@
                          mesh_isar_line), mesh_prover_line)) =
       if in_range range j then
         let
-          val (name1, mepo_suggs) = extract_suggestions mepo_line
-          val (name2, mash_isar_suggs) = extract_suggestions mash_isar_line
-          val (name3, mash_prover_suggs) = extract_suggestions mash_prover_line
-          val (name4, mesh_isar_suggs) = extract_suggestions mesh_isar_line
-          val (name5, mesh_prover_suggs) = extract_suggestions mesh_prover_line
+          val get_suggs = extract_suggestions ##> take slack_max_facts
+          val (name1, mepo_suggs) = get_suggs mepo_line
+          val (name2, mash_isar_suggs) = get_suggs mash_isar_line
+          val (name3, mash_prover_suggs) = get_suggs mash_prover_line
+          val (name4, mesh_isar_suggs) = get_suggs mesh_isar_line
+          val (name5, mesh_prover_suggs) = get_suggs mesh_prover_line
           val [name] =
             [name1, name2, name3, name4, name5]
             |> filter (curry (op <>) "") |> distinct (op =)
@@ -115,12 +116,12 @@
           val mepo_facts =
             get_facts mepo_suggs (fn _ =>
                 mepo_suggested_facts ctxt params prover slack_max_facts NONE
-                                     hyp_ts concl_t facts
-                |> weight_mepo_facts)
+                                     hyp_ts concl_t facts)
+            |> weight_mepo_facts
           fun mash_of suggs =
             get_facts suggs (fn _ =>
-                find_mash_suggestions slack_max_facts suggs facts [] []
-                |> fst |> weight_mash_facts)
+                find_mash_suggestions slack_max_facts suggs facts [] [] |> fst)
+            |> weight_mash_facts
           val mash_isar_facts = mash_of mash_isar_suggs
           val mash_prover_facts = mash_of mash_prover_suggs
           fun mess_of mash_facts =
@@ -129,12 +130,10 @@
           fun mesh_of suggs mash_facts =
             get_facts suggs (fn _ =>
                 mesh_facts (Thm.eq_thm_prop o pairself snd) slack_max_facts
-                           (mess_of mash_facts)
-                |> map (rpair 1.0))
+                           (mess_of mash_facts))
           val mesh_isar_facts = mesh_of mesh_isar_suggs mash_isar_facts
           val mesh_prover_facts = mesh_of mesh_prover_suggs mash_prover_facts
-          val isar_facts =
-            find_suggested_facts (map (rpair 1.0) isar_deps) facts
+          val isar_facts = find_suggested_facts isar_deps facts
           (* adapted from "mirabelle_sledgehammer.ML" *)
           fun set_file_name method (SOME dir) =
               let
@@ -147,7 +146,7 @@
                 #> Config.put SMT_Config.debug_files (dir ^ "/" ^ prob_prefix)
               end
             | set_file_name _ NONE = I
-          fun prove method facts =
+          fun prove method get facts =
             if not (member (op =) methods method) orelse
                (null facts andalso method <> IsarN) then
               (str_of_method method ^ "Skipped", 0)
@@ -157,7 +156,7 @@
                   ((K (encode_str (nickname_of_thm th)), stature), th)
                 val facts =
                   facts
-                  |> map (fst #> nickify)
+                  |> map (get #> nickify)
                   |> maybe_instantiate_inducts ctxt hyp_ts concl_t
                   |> take (the max_facts)
                 val ctxt = ctxt |> set_file_name method prob_dir_name
@@ -166,12 +165,12 @@
                 val ok = if is_none outcome then 1 else 0
               in (str_of_result method facts res, ok) end
           val ress =
-            [fn () => prove MePoN mepo_facts,
-             fn () => prove MaSh_IsarN mash_isar_facts,
-             fn () => prove MaSh_ProverN mash_prover_facts,
-             fn () => prove MeSh_IsarN mesh_isar_facts,
-             fn () => prove MeSh_ProverN mesh_prover_facts,
-             fn () => prove IsarN isar_facts]
+            [fn () => prove MePoN fst mepo_facts,
+             fn () => prove MaSh_IsarN fst mash_isar_facts,
+             fn () => prove MaSh_ProverN fst mash_prover_facts,
+             fn () => prove MeSh_IsarN I mesh_isar_facts,
+             fn () => prove MeSh_ProverN I mesh_prover_facts,
+             fn () => prove IsarN I isar_facts]
             |> (* Par_List. *) map (fn f => f ())
         in
           "Goal " ^ string_of_int j ^ ": " ^ name :: map fst ress
--- a/src/HOL/TPTP/mash_export.ML	Thu Jan 17 23:29:17 2013 +0100
+++ b/src/HOL/TPTP/mash_export.ML	Thu Jan 17 23:29:22 2013 +0100
@@ -213,10 +213,10 @@
       let
         val (name, mash_suggs) =
           extract_suggestions mash_line
-          ||> (map fst #> weight_mash_facts)
+          ||> weight_mash_facts
         val (name', mepo_suggs) =
           extract_suggestions mepo_line
-          ||> (map fst #> weight_mash_facts)
+          ||> weight_mepo_facts
         val _ = if name = name' then () else error "Input files out of sync."
         val mess =
           [(mepo_weight, (mepo_suggs, [])),
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jan 17 23:29:17 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jan 17 23:29:22 2013 +0100
@@ -29,7 +29,7 @@
   val unencode_str : string -> string
   val unencode_strs : string -> string list
   val encode_features : (string * real) list -> string
-  val extract_suggestions : string -> string * (string * real) list
+  val extract_suggestions : string -> string * string list
 
   structure MaSh:
   sig
@@ -41,14 +41,12 @@
       Proof.context -> bool -> (string * string list) list -> unit
     val suggest :
       Proof.context -> bool -> bool -> int
-      -> string list * (string * real) list * string list
-      -> (string * real) list
+      -> string list * (string * real) list * string list -> string list
   end
 
   val mash_unlearn : Proof.context -> unit
   val nickname_of_thm : thm -> string
-  val find_suggested_facts :
-    (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
+  val find_suggested_facts : string list -> ('b * thm) list -> ('b * thm) list
   val mesh_facts :
     ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list
     -> 'a list
@@ -69,8 +67,8 @@
     -> bool * string list
   val weight_mash_facts : 'a list -> ('a * real) list
   val find_mash_suggestions :
-    int -> (Symtab.key * 'a) list -> ('b * thm) list -> ('b * thm) list
-    -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
+    int -> string list -> ('b * thm) list -> ('b * thm) list -> ('b * thm) list
+    -> ('b * thm) list * ('b * thm) list
   val mash_suggested_facts :
     Proof.context -> params -> string -> int -> term list -> term -> fact list
     -> fact list * fact list
@@ -219,11 +217,13 @@
   (if learn_hints orelse null hints then "" else "; " ^ encode_strs hints) ^
   "\n"
 
+(* The weights currently returned by "mash.py" are too spaced out to make any
+   sense. *)
 fun extract_suggestion sugg =
   case space_explode "=" sugg of
     [name, weight] =>
-    SOME (unencode_str name, Real.fromString weight |> the_default 1.0)
-  | [name] => SOME (unencode_str name, 1.0)
+    SOME (unencode_str name (* , Real.fromString weight |> the_default 1.0 *))
+  | [name] => SOME (unencode_str name (* , 1.0 *))
   | _ => NONE
 
 fun extract_suggestions line =
@@ -436,10 +436,8 @@
 fun find_suggested_facts suggs facts =
   let
     fun add_fact (fact as (_, th)) = Symtab.default (nickname_of_thm th, fact)
-    val tab = Symtab.empty |> fold add_fact facts
-    fun find_sugg (name, weight) =
-      Symtab.lookup tab name |> Option.map (rpair weight)
-  in map_filter find_sugg suggs end
+    val tab = fold add_fact facts Symtab.empty
+  in map_filter (Symtab.lookup tab) suggs end
 
 fun scaled_avg [] = 0
   | scaled_avg xs =
@@ -776,11 +774,7 @@
 fun find_mash_suggestions _ [] _ _ raw_unknown = ([], raw_unknown)
   | find_mash_suggestions max_facts suggs facts chained raw_unknown =
     let
-      val raw_mash =
-        facts |> find_suggested_facts suggs
-              (* The weights currently returned by "mash.py" are too spaced out
-                 to make any sense. *)
-              |> map fst
+      val raw_mash = find_suggested_facts suggs facts
       val unknown_chained =
         inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown
       val proximity =
@@ -814,9 +808,8 @@
                 chained |> filter (is_fact_in_graph access_G snd)
                         |> map (nickname_of_thm o snd)
             in
-              (access_G,
-               MaSh.suggest ctxt overlord learn max_facts
-                            (parents, feats, hints))
+              (access_G, MaSh.suggest ctxt overlord learn max_facts
+                                      (parents, feats, hints))
             end)
     val unknown = facts |> filter_out (is_fact_in_graph access_G snd)
   in find_mash_suggestions max_facts suggs facts chained unknown end
@@ -1079,7 +1072,7 @@
 
 (* Generate more suggestions than requested, because some might be thrown out
    later for various reasons. *)
-fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts div 2)
+fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
 
 val mepo_weight = 0.5
 val mash_weight = 0.5