src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 50858 42c5fcc6f28f
parent 50857 80768e28c9ee
child 50860 e32a283b8ce0
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Sun Jan 13 12:15:43 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Sun Jan 13 12:15:44 2013 +0100
@@ -40,7 +40,7 @@
     val relearn :
       Proof.context -> bool -> (string * string list) list -> unit
     val suggest :
-      Proof.context -> bool -> int
+      Proof.context -> bool -> bool -> int
       -> string list * (string * real) list * string list
       -> (string * real) list
   end
@@ -196,6 +196,10 @@
 val unencode_strs =
   space_explode " " #> filter_out (curry (op =) "") #> map unencode_str
 
+fun freshish_name () =
+  Date.fmt ".%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^
+  serial_string ()
+
 fun encode_feature (name, weight) =
   encode_str name ^
   (if Real.== (weight, 1.0) then "" else "=" ^ Real.toString weight)
@@ -209,9 +213,12 @@
 fun str_of_relearn (name, deps) =
   "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n"
 
-fun str_of_query (parents, feats, hints) =
+fun str_of_query learn_hints (parents, feats, hints) =
+  (if not learn_hints orelse null hints then ""
+   else str_of_learn (freshish_name (), parents, feats, hints)) ^
   "? " ^ encode_strs parents ^ "; " ^ encode_features feats ^
-  (if null hints then "" else "; " ^ encode_strs hints) ^ "\n"
+  (if learn_hints orelse null hints then "" else "; " ^ encode_strs hints) ^
+  "\n"
 
 fun extract_suggestion sugg =
   case space_explode "=" sugg of
@@ -250,10 +257,10 @@
          elide_string 1000 (space_implode " " (map #1 relearns)));
      run_mash_tool ctxt overlord true 0 (relearns, str_of_relearn) (K ()))
 
-fun suggest ctxt overlord max_suggs (query as (_, feats, hints)) =
+fun suggest ctxt overlord learn_hints max_suggs (query as (_, feats, hints)) =
   (trace_msg ctxt (fn () => "MaSh suggest " ^ encode_features feats);
-   run_mash_tool ctxt overlord (not (null hints)) max_suggs
-       ([query], str_of_query)
+   run_mash_tool ctxt overlord (learn_hints andalso not (null hints))
+       max_suggs ([query], str_of_query learn_hints)
        (fn suggs =>
            case suggs () of
              [] => []
@@ -778,16 +785,14 @@
       facts |> sort (thm_ord o pairself snd o swap)
             |> take max_proximity_facts
     val mess =
-      [(0.80 (* FUDGE *), (map (rpair 1.0) chained, [])),
-       (0.16 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
-       (0.04 (* FUDGE *), (weight_proximity_facts proximity, []))]
+      [(0.8 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
+       (0.2 (* FUDGE *), (weight_proximity_facts proximity, []))]
     val unknown =
-      raw_unknown
-      |> fold (subtract (Thm.eq_thm_prop o pairself snd)) [chained, proximity]
+      raw_unknown |> subtract (Thm.eq_thm_prop o pairself snd) proximity
   in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
 
-fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
-                         concl_t facts =
+fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts
+                         hyp_ts concl_t facts =
   let
     val thy = Proof_Context.theory_of ctxt
     val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
@@ -803,7 +808,8 @@
               val hints = map (nickname_of_thm o snd) chained
             in
               (access_G,
-               MaSh.suggest ctxt overlord max_facts (parents, feats, hints))
+               MaSh.suggest ctxt overlord learn max_facts
+                            (parents, feats, hints))
             end)
     val unknown = facts |> filter_out (is_fact_in_graph access_G)
   in find_mash_suggestions max_facts suggs facts chained unknown end
@@ -840,10 +846,6 @@
     val desc = ("Machine learner for Sledgehammer", "")
   in Async_Manager.launch MaShN birth_time death_time desc task end
 
-fun freshish_name () =
-  Date.fmt ".%Y_%m_%d_%H_%M_%S__" (Date.fromTimeLocal (Time.now ())) ^
-  serial_string ()
-
 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
                      used_ths =
   if is_smt_prover ctxt prover then