src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML
changeset 48311 3c4e10606567
parent 48309 42c05a6c6c1e
child 48312 b40722a81ac9
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -19,6 +19,8 @@
   val escape_metas : string list -> string
   val unescape_meta : string -> string
   val unescape_metas : string -> string list
+  val extract_query : string -> string * string list
+  val suggested_facts : string list -> ('a * thm) list -> ('a * thm) list
   val all_non_tautological_facts_of :
     theory -> status Termtab.table -> fact list
   val theory_ord : theory * theory -> order
@@ -36,12 +38,12 @@
     Proof.context -> (string * string list * string list * string list) list
     -> unit
   val mash_DEL : Proof.context -> string list -> string list -> unit
-  val mash_SUGGEST : Proof.context -> string list -> string list -> string list
+  val mash_QUERY : Proof.context -> string list * string list -> string list
   val mash_reset : Proof.context -> unit
   val mash_can_suggest_facts : Proof.context -> bool
   val mash_suggest_facts :
-    Proof.context -> params -> string -> int -> term list -> term -> fact list
-    -> fact list * fact list
+    Proof.context -> params -> string -> term list -> term -> fact list
+    -> fact list
   val mash_can_learn_thy : Proof.context -> theory -> bool
   val mash_learn_thy : Proof.context -> theory -> real -> unit
   val mash_learn_proof : Proof.context -> theory -> term -> thm list -> unit
@@ -92,6 +94,17 @@
 val unescape_meta = unmeta_chars [] o String.explode
 val unescape_metas = map unescape_meta o space_explode " "
 
+val explode_suggs =
+  space_explode " " #> filter_out (curry (op =) "") #> map unescape_meta
+fun extract_query line =
+  case space_explode ":" line of
+    [goal_name, suggs] => (unescape_meta goal_name, explode_suggs suggs)
+  | _ => ("", explode_suggs line)
+
+fun find_suggested facts sugg =
+  find_first (fn (_, th) => Thm.get_name_hint th = sugg) facts
+fun suggested_facts suggs facts = map_filter (find_suggested facts) suggs
+
 val thy_feature_prefix = "y_"
 
 val thy_feature_name_of = prefix thy_feature_prefix
@@ -278,6 +291,32 @@
 
 (*** Low-level communication with MaSh ***)
 
+fun run_mash ctxt save write_cmds read_preds =
+  let
+    val temp_dir = getenv "ISABELLE_TMP"
+    val serial = serial_string ()
+    val cmd_file = temp_dir ^ "/mash_commands." ^ serial
+    val cmd_path = Path.explode cmd_file
+    val pred_file = temp_dir ^ "/mash_preds." ^ serial
+    val log_file = temp_dir ^ "/mash_log." ^ serial
+    val command =
+      getenv "MASH_HOME" ^ "/mash.py --inputFile " ^ cmd_file ^
+      " --outputDir " ^ mash_dir () ^ " --predictions " ^ pred_file ^
+      " --log " ^ log_file ^ (if save then " --saveModel" else "") ^
+      " > /dev/null"
+    val _ = File.write cmd_path ""
+    val _ = write_cmds (File.append cmd_path)
+    val _ = trace_msg ctxt (fn () => "  running " ^ command)
+    val _ = Isabelle_System.bash command
+  in read_preds (fn () => File.read_lines (Path.explode pred_file)) end
+
+fun str_of_update (fact, access, feats, deps) =
+  "! " ^ escape_meta fact ^ ": " ^ escape_metas access ^ "; " ^
+  escape_metas feats ^ "; " ^ escape_metas deps ^ "\n"
+
+fun str_of_query (access, feats) =
+  "? " ^ escape_metas access ^ "; " ^ escape_metas feats
+
 fun mash_RESET ctxt =
   let val path = mash_dir () |> Path.explode in
     trace_msg ctxt (K "MaSh RESET");
@@ -287,39 +326,20 @@
   end
 
 fun mash_ADD _ [] = ()
-  | mash_ADD ctxt records =
-    let
-      val temp_dir = getenv "ISABELLE_TMP"
-      val serial = serial_string ()
-      val cmd_file = temp_dir ^ "/mash_commands." ^ serial
-      val cmd_path = Path.explode cmd_file
-      val pred_file = temp_dir ^ "/mash_preds." ^ serial
-      val log_file = temp_dir ^ "/mash_log." ^ serial
-      val _ = File.write cmd_path ""
-      val _ =
-        trace_msg ctxt (fn () =>
-            "MaSh ADD " ^ space_implode " " (map #1 records))
-      fun append_record (fact, access, feats, deps) =
-        "! " ^ escape_meta fact ^ ": " ^ escape_metas access ^ "; " ^
-        escape_metas feats ^ "; " ^ escape_metas deps ^ "\n"
-        |> File.append cmd_path
-      val command =
-        getenv "MASH_HOME" ^ "/mash.py --inputFile " ^ cmd_file ^
-        " --outputDir " ^ mash_dir () ^ " --predictions " ^ pred_file ^
-        " --log " ^ log_file ^ " --saveModel > /dev/null"
-      val _ = trace_msg ctxt (fn () => "Run: " ^ command)
-      val _ = List.app append_record records
-      val _ = Isabelle_System.bash command
-    in () end
+  | mash_ADD ctxt upds =
+    (trace_msg ctxt (fn () => "MaSh ADD " ^ space_implode " " (map #1 upds));
+     run_mash ctxt true (fn append => List.app (append o str_of_update) upds)
+              (K ()))
 
 fun mash_DEL ctxt facts feats =
   trace_msg ctxt (fn () =>
       "MaSh DEL " ^ escape_metas facts ^ "; " ^ escape_metas feats)
 
-fun mash_SUGGEST ctxt access feats =
-  (trace_msg ctxt (fn () =>
-       "MaSh SUGGEST " ^ escape_metas access ^ "; " ^ escape_metas feats);
-   [])
+fun mash_QUERY ctxt (query as (_, feats)) =
+  (trace_msg ctxt (fn () => "MaSh SUGGEST " ^ space_implode " " feats);
+   run_mash ctxt false (fn append => append (str_of_query query))
+                 (fn preds => snd (extract_query (List.last (preds ()))))
+   handle List.Empty => [])
 
 
 (*** High-level communication with MaSh ***)
@@ -385,13 +405,14 @@
 fun mash_can_suggest_facts (_ : Proof.context) =
   not (Symtab.is_empty (#thy_facts (mash_get ())))
 
-fun mash_suggest_facts ctxt params prover max_facts hyp_ts concl_t facts =
+fun mash_suggest_facts ctxt params prover hyp_ts concl_t facts =
   let
     val thy = Proof_Context.theory_of ctxt
-    val access = accessibility_of thy (#thy_facts (mash_get ()))
+    val thy_facts = #thy_facts (mash_get ())
+    val access = accessibility_of thy thy_facts
     val feats = features_of thy General (concl_t :: hyp_ts)
-    val suggs = mash_SUGGEST ctxt access feats
-  in (facts, []) end
+    val suggs = mash_QUERY ctxt (access, feats)
+  in suggested_facts suggs facts end
 
 fun mash_can_learn_thy (_ : Proof.context) thy =
   not (Symtab.defined (#dirty_thys (mash_get ())) (Context.theory_name thy))
@@ -436,18 +457,18 @@
       let
         val ths = facts |> map snd
         val all_names = ths |> map Thm.get_name_hint
-        fun do_fact ((_, (_, status)), th) (prevs, records) =
+        fun do_fact ((_, (_, status)), th) (prevs, upds) =
           let
             val name = Thm.get_name_hint th
             val feats = features_of thy status [prop_of th]
             val deps = isabelle_dependencies_of all_names th
-            val record = (name, prevs, feats, deps)
-          in ([name], record :: records) end
+            val upd = (name, prevs, feats, deps)
+          in ([name], upd :: upds) end
         val parents = parent_facts thy thy_facts
-        val (_, records) = (parents, []) |> fold do_fact new_facts
+        val (_, upds) = (parents, []) |> fold do_fact new_facts
         val new_thy_facts = new_facts |> thy_facts_from_thms
         fun trans {dirty_thys, thy_facts} =
-          (mash_ADD ctxt (rev records);
+          (mash_ADD ctxt (rev upds);
            {dirty_thys = dirty_thys,
             thy_facts = thy_facts |> add_thy_facts_from_thys new_thy_facts})
       in mash_map trans end
@@ -485,8 +506,9 @@
       val iter_facts =
         iterative_relevant_facts ctxt params prover max_facts NONE hyp_ts
                                  concl_t facts
-      val (mash_facts, mash_rejected) =
-        mash_suggest_facts ctxt params prover max_facts hyp_ts concl_t facts
+      val (mash_facts, mash_antifacts) =
+        facts |> mash_suggest_facts ctxt params prover hyp_ts concl_t
+              |> chop max_facts
       val mesh_facts = iter_facts (* ### *)
     in
       mesh_facts