started work on MaSh/SML
author blanchet
Mon May 19 23:43:53 2014 +0200 (2014-05-19)
changeset 57007 d3eed0518882
parent 57006 20e5b110d19b
child 57008 10f68b83b474
started work on MaSh/SML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
     1.3 @@ -14,6 +14,7 @@
     1.4    type prover_result = Sledgehammer_Prover.prover_result
     1.5  
     1.6    val trace : bool Config.T
     1.7 +  val sml : bool Config.T
     1.8    val MePoN : string
     1.9    val MaShN : string
    1.10    val MeShN : string
    1.11 @@ -34,18 +35,6 @@
    1.12    val encode_features : (string list * real) list -> string
    1.13    val extract_suggestions : string -> string * string list
    1.14  
    1.15 -  structure MaSh:
    1.16 -  sig
    1.17 -    val unlearn : Proof.context -> bool -> unit
    1.18 -    val learn : Proof.context -> bool -> bool ->
    1.19 -      (string * string list * string list list * string list) list -> unit
    1.20 -    val relearn : Proof.context -> bool -> bool -> (string * string list) list -> unit
    1.21 -    val query : Proof.context -> bool -> int ->
    1.22 -      (string * string list * string list list * string list) list * string list * string list *
    1.23 -        (string list * real) list ->
    1.24 -      string list
    1.25 -  end
    1.26 -
    1.27    val mash_unlearn : Proof.context -> params -> unit
    1.28    val is_mash_enabled : unit -> bool
    1.29    val nickname_of_thm : thm -> string
    1.30 @@ -98,6 +87,7 @@
    1.31  open Sledgehammer_MePo
    1.32  
    1.33  val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
    1.34 +val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false)
    1.35  
    1.36  fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
    1.37  
    1.38 @@ -117,12 +107,18 @@
    1.39  val relearn_isarN = "relearn_isar"
    1.40  val relearn_proverN = "relearn_prover"
    1.41  
    1.42 -fun mash_model_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
    1.43 -val mash_state_dir = mash_model_dir
    1.44 +fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
    1.45  fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
    1.46  
    1.47 +fun wipe_out_mash_state_dir () =
    1.48 +  let val path = mash_state_dir () in
    1.49 +    try (File.fold_dir (fn file => fn _ => try File.rm (Path.append path (Path.basic file))) path)
    1.50 +      NONE;
    1.51 +    ()
    1.52 +  end
    1.53  
    1.54 -(*** Low-level communication with MaSh ***)
    1.55 +
    1.56 +(*** Low-level communication with Python version of MaSh ***)
    1.57  
    1.58  val save_models_arg = "--saveModels"
    1.59  val shutdown_server_arg = "--shutdownServer"
    1.60 @@ -145,7 +141,7 @@
    1.61      val sugg_path = Path.explode sugg_file
    1.62      val cmd_file = temp_dir ^ "/mash_commands" ^ serial
    1.63      val cmd_path = Path.explode cmd_file
    1.64 -    val model_dir = File.shell_path (mash_model_dir ())
    1.65 +    val model_dir = File.shell_path (mash_state_dir ())
    1.66      val command =
    1.67        "cd \"$ISABELLE_SLEDGEHAMMER_MASH\"/src; \
    1.68        \PYTHONDONTWRITEBYTECODE=y ./mash.py\
    1.69 @@ -238,52 +234,65 @@
    1.70      [goal, suggs] => (decode_str goal, map_filter extract_suggestion (space_explode " " suggs))
    1.71    | _ => ("", []))
    1.72  
    1.73 -structure MaSh =
    1.74 +structure MaSh_Py =
    1.75  struct
    1.76  
    1.77  fun shutdown ctxt overlord =
    1.78 -  (trace_msg ctxt (K "MaSh shutdown");
    1.79 +  (trace_msg ctxt (K "MaSh_Py shutdown");
    1.80     run_mash_tool ctxt overlord [shutdown_server_arg] false ([], K "") (K ()))
    1.81  
    1.82  fun save ctxt overlord =
    1.83 -  (trace_msg ctxt (K "MaSh save");
    1.84 +  (trace_msg ctxt (K "MaSh_Py save");
    1.85     run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ()))
    1.86  
    1.87  fun unlearn ctxt overlord =
    1.88 -  let val path = mash_model_dir () in
    1.89 -    trace_msg ctxt (K "MaSh unlearn");
    1.90 -    shutdown ctxt overlord;
    1.91 -    try (File.fold_dir (fn file => fn _ => try File.rm (Path.append path (Path.basic file))) path)
    1.92 -      NONE;
    1.93 -    ()
    1.94 -  end
    1.95 +  (trace_msg ctxt (K "MaSh_Py unlearn");
    1.96 +   shutdown ctxt overlord;
    1.97 +   wipe_out_mash_state_dir ())
    1.98  
    1.99  fun learn _ _ _ [] = ()
   1.100    | learn ctxt overlord save learns =
   1.101 -    let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
   1.102 -      (trace_msg ctxt (fn () => "MaSh learn" ^ (if names = "" then "" else " " ^ names));
   1.103 -       run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
   1.104 -         (K ()))
   1.105 -    end
   1.106 +    (trace_msg ctxt (fn () =>
   1.107 +       let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
   1.108 +         "MaSh_Py learn" ^ (if names = "" then "" else " " ^ names)
   1.109 +       end);
   1.110 +     run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
   1.111 +       (K ()))
   1.112  
   1.113  fun relearn _ _ _ [] = ()
   1.114    | relearn ctxt overlord save relearns =
   1.115 -    (trace_msg ctxt (fn () => "MaSh relearn " ^
   1.116 +    (trace_msg ctxt (fn () => "MaSh_Py relearn " ^
   1.117         elide_string 1000 (space_implode " " (map #1 relearns)));
   1.118       run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
   1.119         (relearns, str_of_relearn) (K ()))
   1.120  
   1.121  fun query ctxt overlord max_suggs (query as (_, _, _, feats)) =
   1.122 -  (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
   1.123 +  (trace_msg ctxt (fn () => "MaSh_Py query " ^ encode_features feats);
   1.124     run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs) (fn suggs =>
   1.125 -     (case suggs () of
   1.126 -       [] => []
   1.127 -     | suggs => snd (extract_suggestions (List.last suggs))))
   1.128 +     (case suggs () of [] => [] | suggs => snd (extract_suggestions (List.last suggs))))
   1.129     handle List.Empty => [])
   1.130  
   1.131  end;
   1.132  
   1.133  
   1.134 +(*** Standard ML version of MaSh ***)
   1.135 +
   1.136 +structure MaSh_SML =
   1.137 +struct
   1.138 +
   1.139 +fun learn_and_query ctxt (learns : (string * string list * string list list * string list) list)
   1.140 +    max_suggs (query as (_, _, _, feats)) =
   1.141 +  (trace_msg ctxt (fn () =>
   1.142 +     let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
   1.143 +       "MaSh_SML learn" ^ (if names = "" then "" else " " ^ names) ^ "\n" ^
   1.144 +       "MaSh_SML query " ^ encode_features feats
   1.145 +     end);
   1.146 +   (* Implementation missing *)
   1.147 +   [])
   1.148 +
   1.149 +end;
   1.150 +
   1.151 +
   1.152  (*** Middle-level communication with MaSh ***)
   1.153  
   1.154  datatype proof_kind = Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop
   1.155 @@ -364,7 +373,9 @@
   1.156                 (try_graph ctxt "loading state" Graph.empty (fn () =>
   1.157                    fold add_node node_lines Graph.empty),
   1.158                  length node_lines)
   1.159 -             | LESS => (MaSh.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
   1.160 +             | LESS =>
   1.161 +               (if Config.get ctxt sml then wipe_out_mash_state_dir ()
   1.162 +                else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
   1.163               | GREATER => raise FILE_VERSION_TOO_NEW ())
   1.164           in
   1.165             trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
   1.166 @@ -411,7 +422,9 @@
   1.167  
   1.168  fun clear_state ctxt overlord =
   1.169    (* "unlearn" also removes the state file *)
   1.170 -  Synchronized.change global_state (fn _ => (MaSh.unlearn ctxt overlord; (false, empty_state)))
   1.171 +  Synchronized.change global_state (fn _ =>
   1.172 +    (if Config.get ctxt sml then wipe_out_mash_state_dir ()
   1.173 +     else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
   1.174  
   1.175  end
   1.176  
   1.177 @@ -953,10 +966,8 @@
   1.178        [(0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
   1.179         (0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, [])),
   1.180         (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))]
   1.181 -    val unknown =
   1.182 -      raw_unknown
   1.183 -      |> fold (subtract (eq_snd Thm.eq_thm_prop))
   1.184 -              [unknown_chained, unknown_proximate]
   1.185 +    val unknown = raw_unknown
   1.186 +      |> fold (subtract (eq_snd Thm.eq_thm_prop)) [unknown_chained, unknown_proximate]
   1.187    in
   1.188      (mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess, unknown)
   1.189    end
   1.190 @@ -964,6 +975,13 @@
   1.191  fun add_const_counts t =
   1.192    fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
   1.193  
   1.194 +fun learn_of_graph graph =
   1.195 +  let
   1.196 +    fun sched parents (name, (kind, feats, deps)) = (name, map fst parents, feats, deps)
   1.197 +  in
   1.198 +    Graph.schedule sched graph
   1.199 +  end
   1.200 +
   1.201  fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
   1.202    let
   1.203      val thy = Proof_Context.theory_of ctxt
   1.204 @@ -983,31 +1001,34 @@
   1.205  
   1.206      val (access_G, suggs) =
   1.207        peek_state ctxt overlord (fn {access_G, ...} =>
   1.208 -          if Graph.is_empty access_G then
   1.209 -            (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
   1.210 -          else
   1.211 -            let
   1.212 -              val parents = maximal_wrt_access_graph access_G facts
   1.213 -              val goal_feats =
   1.214 -                features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
   1.215 -              val chained_feats = chained
   1.216 -                |> map (rpair 1.0)
   1.217 -                |> map (chained_or_extra_features_of chained_feature_factor)
   1.218 -                |> rpair [] |-> fold (union (eq_fst (op =)))
   1.219 -              val extra_feats = facts
   1.220 -                |> take (Int.max (0, num_extra_feature_facts - length chained))
   1.221 -                |> filter fact_has_right_theory
   1.222 -                |> weight_facts_steeply
   1.223 -                |> map (chained_or_extra_features_of extra_feature_factor)
   1.224 -                |> rpair [] |-> fold (union (eq_fst (op =)))
   1.225 -              val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
   1.226 -                |> debug ? sort (Real.compare o swap o pairself snd)
   1.227 -              val hints = chained
   1.228 -                |> filter (is_fact_in_graph access_G o snd)
   1.229 -                |> map (nickname_of_thm o snd)
   1.230 -            in
   1.231 -              (access_G, MaSh.query ctxt overlord max_facts ([], hints, parents, feats))
   1.232 -            end)
   1.233 +        if Graph.is_empty access_G then
   1.234 +          (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
   1.235 +        else
   1.236 +          let
   1.237 +            val parents = maximal_wrt_access_graph access_G facts
   1.238 +            val goal_feats =
   1.239 +              features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
   1.240 +            val chained_feats = chained
   1.241 +              |> map (rpair 1.0)
   1.242 +              |> map (chained_or_extra_features_of chained_feature_factor)
   1.243 +              |> rpair [] |-> fold (union (eq_fst (op =)))
   1.244 +            val extra_feats = facts
   1.245 +              |> take (Int.max (0, num_extra_feature_facts - length chained))
   1.246 +              |> filter fact_has_right_theory
   1.247 +              |> weight_facts_steeply
   1.248 +              |> map (chained_or_extra_features_of extra_feature_factor)
   1.249 +              |> rpair [] |-> fold (union (eq_fst (op =)))
   1.250 +            val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
   1.251 +              |> debug ? sort (Real.compare o swap o pairself snd)
   1.252 +            val hints = chained
   1.253 +              |> filter (is_fact_in_graph access_G o snd)
   1.254 +              |> map (nickname_of_thm o snd)
   1.255 +          in
   1.256 +            (access_G,
   1.257 +             (if Config.get ctxt sml then MaSh_SML.learn_and_query ctxt (learn_of_graph access_G)
   1.258 +              else MaSh_Py.query ctxt overlord)
   1.259 +               max_facts ([], hints, parents, feats))
   1.260 +          end)
   1.261      val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   1.262    in
   1.263      find_mash_suggestions ctxt max_facts suggs facts chained unknown
   1.264 @@ -1060,14 +1081,15 @@
   1.265          val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst
   1.266        in
   1.267          peek_state ctxt overlord (fn {access_G, ...} =>
   1.268 -            let
   1.269 -              val parents = maximal_wrt_access_graph access_G facts
   1.270 -              val deps =
   1.271 -                used_ths |> filter (is_fact_in_graph access_G)
   1.272 -                         |> map nickname_of_thm
   1.273 -            in
   1.274 -              MaSh.learn ctxt overlord true [("", parents, feats, deps)]
   1.275 -            end);
   1.276 +          let
   1.277 +            val parents = maximal_wrt_access_graph access_G facts
   1.278 +            val deps =
   1.279 +              used_ths |> filter (is_fact_in_graph access_G)
   1.280 +                       |> map nickname_of_thm
   1.281 +          in
   1.282 +            if Config.get ctxt sml then () (* TODO: implement *)
   1.283 +            else MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]
   1.284 +          end);
   1.285          (true, "")
   1.286        end)
   1.287    else
   1.288 @@ -1126,8 +1148,11 @@
   1.289                    (false, SOME names, []) => SOME (map #1 learns @ names)
   1.290                  | _ => NONE)
   1.291              in
   1.292 -              MaSh.learn ctxt overlord (save andalso null relearns) (rev learns);
   1.293 -              MaSh.relearn ctxt overlord save relearns;
   1.294 +              if Config.get ctxt sml then
   1.295 +                ()
   1.296 +              else
   1.297 +                (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
   1.298 +                 MaSh_Py.relearn ctxt overlord save relearns);
   1.299                {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
   1.300              end
   1.301  
   1.302 @@ -1363,7 +1388,7 @@
   1.303             |> Par_List.map (apsnd (fn f => f ()))
   1.304        val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
   1.305      in
   1.306 -      if save then MaSh.save ctxt overlord else ();
   1.307 +      if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord;
   1.308        (case (fact_filter, mess) of
   1.309          (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
   1.310          [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
   1.311 @@ -1372,7 +1397,8 @@
   1.312      end
   1.313  
   1.314  fun kill_learners ctxt ({overlord, ...} : params) =
   1.315 -  (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord)
   1.316 +  (Async_Manager.kill_threads MaShN "learner";
   1.317 +   if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord)
   1.318  
   1.319  fun running_learners () = Async_Manager.running_threads MaShN "learner"
   1.320