src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57007 d3eed0518882
parent 57006 20e5b110d19b
child 57009 8cb6a5f1ae84
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
@@ -14,6 +14,7 @@
   type prover_result = Sledgehammer_Prover.prover_result
 
   val trace : bool Config.T
+  val sml : bool Config.T
   val MePoN : string
   val MaShN : string
   val MeShN : string
@@ -34,18 +35,6 @@
   val encode_features : (string list * real) list -> string
   val extract_suggestions : string -> string * string list
 
-  structure MaSh:
-  sig
-    val unlearn : Proof.context -> bool -> unit
-    val learn : Proof.context -> bool -> bool ->
-      (string * string list * string list list * string list) list -> unit
-    val relearn : Proof.context -> bool -> bool -> (string * string list) list -> unit
-    val query : Proof.context -> bool -> int ->
-      (string * string list * string list list * string list) list * string list * string list *
-        (string list * real) list ->
-      string list
-  end
-
   val mash_unlearn : Proof.context -> params -> unit
   val is_mash_enabled : unit -> bool
   val nickname_of_thm : thm -> string
@@ -98,6 +87,7 @@
 open Sledgehammer_MePo
 
 val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
+val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false)
 
 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
 
@@ -117,12 +107,18 @@
 val relearn_isarN = "relearn_isar"
 val relearn_proverN = "relearn_prover"
 
-fun mash_model_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
-val mash_state_dir = mash_model_dir
+fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
 
+fun wipe_out_mash_state_dir () =
+  let val path = mash_state_dir () in
+    try (File.fold_dir (fn file => fn _ => try File.rm (Path.append path (Path.basic file))) path)
+      NONE;
+    ()
+  end
 
-(*** Low-level communication with MaSh ***)
+
+(*** Low-level communication with Python version of MaSh ***)
 
 val save_models_arg = "--saveModels"
 val shutdown_server_arg = "--shutdownServer"
@@ -145,7 +141,7 @@
     val sugg_path = Path.explode sugg_file
     val cmd_file = temp_dir ^ "/mash_commands" ^ serial
     val cmd_path = Path.explode cmd_file
-    val model_dir = File.shell_path (mash_model_dir ())
+    val model_dir = File.shell_path (mash_state_dir ())
     val command =
       "cd \"$ISABELLE_SLEDGEHAMMER_MASH\"/src; \
       \PYTHONDONTWRITEBYTECODE=y ./mash.py\
@@ -238,52 +234,65 @@
     [goal, suggs] => (decode_str goal, map_filter extract_suggestion (space_explode " " suggs))
   | _ => ("", []))
 
-structure MaSh =
+structure MaSh_Py =
 struct
 
 fun shutdown ctxt overlord =
-  (trace_msg ctxt (K "MaSh shutdown");
+  (trace_msg ctxt (K "MaSh_Py shutdown");
    run_mash_tool ctxt overlord [shutdown_server_arg] false ([], K "") (K ()))
 
 fun save ctxt overlord =
-  (trace_msg ctxt (K "MaSh save");
+  (trace_msg ctxt (K "MaSh_Py save");
    run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ()))
 
 fun unlearn ctxt overlord =
-  let val path = mash_model_dir () in
-    trace_msg ctxt (K "MaSh unlearn");
-    shutdown ctxt overlord;
-    try (File.fold_dir (fn file => fn _ => try File.rm (Path.append path (Path.basic file))) path)
-      NONE;
-    ()
-  end
+  (trace_msg ctxt (K "MaSh_Py unlearn");
+   shutdown ctxt overlord;
+   wipe_out_mash_state_dir ())
 
 fun learn _ _ _ [] = ()
   | learn ctxt overlord save learns =
-    let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
-      (trace_msg ctxt (fn () => "MaSh learn" ^ (if names = "" then "" else " " ^ names));
-       run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
-         (K ()))
-    end
+    (trace_msg ctxt (fn () =>
+       let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
+         "MaSh_Py learn" ^ (if names = "" then "" else " " ^ names)
+       end);
+     run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
+       (K ()))
 
 fun relearn _ _ _ [] = ()
   | relearn ctxt overlord save relearns =
-    (trace_msg ctxt (fn () => "MaSh relearn " ^
+    (trace_msg ctxt (fn () => "MaSh_Py relearn " ^
        elide_string 1000 (space_implode " " (map #1 relearns)));
      run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
        (relearns, str_of_relearn) (K ()))
 
 fun query ctxt overlord max_suggs (query as (_, _, _, feats)) =
-  (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
+  (trace_msg ctxt (fn () => "MaSh_Py query " ^ encode_features feats);
    run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs) (fn suggs =>
-     (case suggs () of
-       [] => []
-     | suggs => snd (extract_suggestions (List.last suggs))))
+     (case suggs () of [] => [] | suggs => snd (extract_suggestions (List.last suggs))))
    handle List.Empty => [])
 
 end;
 
 
+(*** Standard ML version of MaSh ***)
+
+structure MaSh_SML =
+struct
+
+fun learn_and_query ctxt (learns : (string * string list * string list list * string list) list)
+    max_suggs (query as (_, _, _, feats)) =
+  (trace_msg ctxt (fn () =>
+     let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
+       "MaSh_SML learn" ^ (if names = "" then "" else " " ^ names) ^ "\n" ^
+       "MaSh_SML query " ^ encode_features feats
+     end);
+   (* Implementation missing *)
+   [])
+
+end;
+
+
 (*** Middle-level communication with MaSh ***)
 
 datatype proof_kind = Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop
@@ -364,7 +373,9 @@
                (try_graph ctxt "loading state" Graph.empty (fn () =>
                   fold add_node node_lines Graph.empty),
                 length node_lines)
-             | LESS => (MaSh.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
+             | LESS =>
+               (if Config.get ctxt sml then wipe_out_mash_state_dir ()
+                else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
              | GREATER => raise FILE_VERSION_TOO_NEW ())
          in
            trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
@@ -411,7 +422,9 @@
 
 fun clear_state ctxt overlord =
   (* "unlearn" also removes the state file *)
-  Synchronized.change global_state (fn _ => (MaSh.unlearn ctxt overlord; (false, empty_state)))
+  Synchronized.change global_state (fn _ =>
+    (if Config.get ctxt sml then wipe_out_mash_state_dir ()
+     else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
 
 end
 
@@ -953,10 +966,8 @@
       [(0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
        (0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, [])),
        (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))]
-    val unknown =
-      raw_unknown
-      |> fold (subtract (eq_snd Thm.eq_thm_prop))
-              [unknown_chained, unknown_proximate]
+    val unknown = raw_unknown
+      |> fold (subtract (eq_snd Thm.eq_thm_prop)) [unknown_chained, unknown_proximate]
   in
     (mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess, unknown)
   end
@@ -964,6 +975,13 @@
 fun add_const_counts t =
   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
 
+fun learn_of_graph graph =
+  let
+    fun sched parents (name, (kind, feats, deps)) = (name, map fst parents, feats, deps)
+  in
+    Graph.schedule sched graph
+  end
+
 fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
   let
     val thy = Proof_Context.theory_of ctxt
@@ -983,31 +1001,34 @@
 
     val (access_G, suggs) =
       peek_state ctxt overlord (fn {access_G, ...} =>
-          if Graph.is_empty access_G then
-            (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
-          else
-            let
-              val parents = maximal_wrt_access_graph access_G facts
-              val goal_feats =
-                features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
-              val chained_feats = chained
-                |> map (rpair 1.0)
-                |> map (chained_or_extra_features_of chained_feature_factor)
-                |> rpair [] |-> fold (union (eq_fst (op =)))
-              val extra_feats = facts
-                |> take (Int.max (0, num_extra_feature_facts - length chained))
-                |> filter fact_has_right_theory
-                |> weight_facts_steeply
-                |> map (chained_or_extra_features_of extra_feature_factor)
-                |> rpair [] |-> fold (union (eq_fst (op =)))
-              val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
-                |> debug ? sort (Real.compare o swap o pairself snd)
-              val hints = chained
-                |> filter (is_fact_in_graph access_G o snd)
-                |> map (nickname_of_thm o snd)
-            in
-              (access_G, MaSh.query ctxt overlord max_facts ([], hints, parents, feats))
-            end)
+        if Graph.is_empty access_G then
+          (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
+        else
+          let
+            val parents = maximal_wrt_access_graph access_G facts
+            val goal_feats =
+              features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
+            val chained_feats = chained
+              |> map (rpair 1.0)
+              |> map (chained_or_extra_features_of chained_feature_factor)
+              |> rpair [] |-> fold (union (eq_fst (op =)))
+            val extra_feats = facts
+              |> take (Int.max (0, num_extra_feature_facts - length chained))
+              |> filter fact_has_right_theory
+              |> weight_facts_steeply
+              |> map (chained_or_extra_features_of extra_feature_factor)
+              |> rpair [] |-> fold (union (eq_fst (op =)))
+            val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
+              |> debug ? sort (Real.compare o swap o pairself snd)
+            val hints = chained
+              |> filter (is_fact_in_graph access_G o snd)
+              |> map (nickname_of_thm o snd)
+          in
+            (access_G,
+             (if Config.get ctxt sml then MaSh_SML.learn_and_query ctxt (learn_of_graph access_G)
+              else MaSh_Py.query ctxt overlord)
+               max_facts ([], hints, parents, feats))
+          end)
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   in
     find_mash_suggestions ctxt max_facts suggs facts chained unknown
@@ -1060,14 +1081,15 @@
         val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst
       in
         peek_state ctxt overlord (fn {access_G, ...} =>
-            let
-              val parents = maximal_wrt_access_graph access_G facts
-              val deps =
-                used_ths |> filter (is_fact_in_graph access_G)
-                         |> map nickname_of_thm
-            in
-              MaSh.learn ctxt overlord true [("", parents, feats, deps)]
-            end);
+          let
+            val parents = maximal_wrt_access_graph access_G facts
+            val deps =
+              used_ths |> filter (is_fact_in_graph access_G)
+                       |> map nickname_of_thm
+          in
+            if Config.get ctxt sml then () (* TODO: implement *)
+            else MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]
+          end);
         (true, "")
       end)
   else
@@ -1126,8 +1148,11 @@
                   (false, SOME names, []) => SOME (map #1 learns @ names)
                 | _ => NONE)
             in
-              MaSh.learn ctxt overlord (save andalso null relearns) (rev learns);
-              MaSh.relearn ctxt overlord save relearns;
+              if Config.get ctxt sml then
+                ()
+              else
+                (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
+                 MaSh_Py.relearn ctxt overlord save relearns);
               {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
             end
 
@@ -1363,7 +1388,7 @@
            |> Par_List.map (apsnd (fn f => f ()))
       val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
     in
-      if save then MaSh.save ctxt overlord else ();
+      if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord;
       (case (fact_filter, mess) of
         (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
         [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
@@ -1372,7 +1397,8 @@
     end
 
 fun kill_learners ctxt ({overlord, ...} : params) =
-  (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord)
+  (Async_Manager.kill_threads MaShN "learner";
+   if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord)
 
 fun running_learners () = Async_Manager.running_threads MaShN "learner"