src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML
changeset 48309 42c05a6c6c1e
parent 48308 89674e5a4d35
child 48311 3c4e10606567
equal deleted inserted replaced
48308:89674e5a4d35 48309:42c05a6c6c1e
    42   val mash_suggest_facts :
    42   val mash_suggest_facts :
    43     Proof.context -> params -> string -> int -> term list -> term -> fact list
    43     Proof.context -> params -> string -> int -> term list -> term -> fact list
    44     -> fact list * fact list
    44     -> fact list * fact list
    45   val mash_can_learn_thy : Proof.context -> theory -> bool
    45   val mash_can_learn_thy : Proof.context -> theory -> bool
    46   val mash_learn_thy : Proof.context -> theory -> real -> unit
    46   val mash_learn_thy : Proof.context -> theory -> real -> unit
    47   val mash_learn_proof : Proof.context -> term -> thm list -> unit
    47   val mash_learn_proof : Proof.context -> theory -> term -> thm list -> unit
    48   val relevant_facts :
    48   val relevant_facts :
    49     Proof.context -> params -> string -> int -> fact_override -> term list
    49     Proof.context -> params -> string -> int -> fact_override -> term list
    50     -> term -> fact list -> fact list
    50     -> term -> fact list -> fact list
    51 end;
    51 end;
    52 
    52 
    62 
    62 
    63 val trace =
    63 val trace =
    64   Attrib.setup_config_bool @{binding sledgehammer_filter_mash_trace} (K false)
    64   Attrib.setup_config_bool @{binding sledgehammer_filter_mash_trace} (K false)
    65 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
    65 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
    66 
    66 
    67 val mash_dir = "mash"
    67 fun mash_dir () =
    68 val model_file = "model"
    68   getenv "ISABELLE_HOME_USER" ^ "/mash"
    69 val state_file = "state"
    69   |> tap (fn dir => Isabelle_System.mkdir (Path.explode dir))
    70 
    70 fun mash_state_path () = mash_dir () ^ "/state" |> Path.explode
    71 fun mk_path file =
       
    72   getenv "ISABELLE_HOME_USER" ^ "/" ^ mash_dir ^ "/" ^ file
       
    73   |> Path.explode
       
    74 
       
    75 
    71 
    76 (*** Isabelle helpers ***)
    72 (*** Isabelle helpers ***)
    77 
    73 
    78 fun meta_char c =
    74 fun meta_char c =
    79   if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse
    75   if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse
   281 
   277 
   282 
   278 
   283 (*** Low-level communication with MaSh ***)
   279 (*** Low-level communication with MaSh ***)
   284 
   280 
   285 fun mash_RESET ctxt =
   281 fun mash_RESET ctxt =
   286   (trace_msg ctxt (K "MaSh RESET"); File.write (mk_path model_file) "")
   282   let val path = mash_dir () |> Path.explode in
   287 
   283     trace_msg ctxt (K "MaSh RESET");
   288 fun mash_ADD ctxt =
   284     File.fold_dir (fn file => fn () =>
   289   let
   285                       File.rm (Path.append path (Path.basic file)))
   290     fun add_record (fact, access, feats, deps) =
   286                   path ()
   291       let
   287   end
   292         val s =
   288 
   293           escape_meta fact ^ ": " ^ escape_metas access ^ "; " ^
   289 fun mash_ADD _ [] = ()
   294           escape_metas feats ^ "; " ^ escape_metas deps
   290   | mash_ADD ctxt records =
   295       in trace_msg ctxt (fn () => "MaSh ADD " ^ s) end
   291     let
   296   in List.app add_record end
   292       val temp_dir = getenv "ISABELLE_TMP"
       
   293       val serial = serial_string ()
       
   294       val cmd_file = temp_dir ^ "/mash_commands." ^ serial
       
   295       val cmd_path = Path.explode cmd_file
       
   296       val pred_file = temp_dir ^ "/mash_preds." ^ serial
       
   297       val log_file = temp_dir ^ "/mash_log." ^ serial
       
   298       val _ = File.write cmd_path ""
       
   299       val _ =
       
   300         trace_msg ctxt (fn () =>
       
   301             "MaSh ADD " ^ space_implode " " (map #1 records))
       
   302       fun append_record (fact, access, feats, deps) =
       
   303         "! " ^ escape_meta fact ^ ": " ^ escape_metas access ^ "; " ^
       
   304         escape_metas feats ^ "; " ^ escape_metas deps ^ "\n"
       
   305         |> File.append cmd_path
       
   306       val command =
       
   307         getenv "MASH_HOME" ^ "/mash.py --inputFile " ^ cmd_file ^
       
   308         " --outputDir " ^ mash_dir () ^ " --predictions " ^ pred_file ^
       
   309         " --log " ^ log_file ^ " --saveModel > /dev/null"
       
   310       val _ = trace_msg ctxt (fn () => "Run: " ^ command)
       
   311       val _ = List.app append_record records
       
   312       val _ = Isabelle_System.bash command
       
   313     in () end
   297 
   314 
   298 fun mash_DEL ctxt facts feats =
   315 fun mash_DEL ctxt facts feats =
   299   trace_msg ctxt (fn () =>
   316   trace_msg ctxt (fn () =>
   300       "MaSh DEL " ^ escape_metas facts ^ "; " ^ escape_metas feats)
   317       "MaSh DEL " ^ escape_metas facts ^ "; " ^ escape_metas feats)
   301 
   318 
   317 
   334 
   318 local
   335 local
   319 
   336 
   320 fun mash_load (state as (true, _)) = state
   337 fun mash_load (state as (true, _)) = state
   321   | mash_load _ =
   338   | mash_load _ =
   322     let
   339     let val path = mash_state_path () in
   323       val path = mk_path state_file
       
   324       val _ = Isabelle_System.mkdir (path |> Path.dir)
       
   325     in
       
   326       (true,
   340       (true,
   327        case try File.read_lines path of
   341        case try File.read_lines path of
   328          SOME (dirty_line :: facts_lines) =>
   342          SOME (dirty_line :: facts_lines) =>
   329          let
   343          let
   330            fun dirty_thys_of_line line =
   344            fun dirty_thys_of_line line =
   340        | _ => empty_state)
   354        | _ => empty_state)
   341     end
   355     end
   342 
   356 
   343 fun mash_save ({dirty_thys, thy_facts} : mash_state) =
   357 fun mash_save ({dirty_thys, thy_facts} : mash_state) =
   344   let
   358   let
   345     val path = mk_path state_file
   359     val path = mash_state_path ()
   346     val dirty_line = (escape_metas (Symtab.keys dirty_thys)) ^ "\n"
   360     val dirty_line = (escape_metas (Symtab.keys dirty_thys)) ^ "\n"
   347     fun fact_line_for (thy, facts) = escape_metas (thy :: facts) ^ "\n"
   361     fun fact_line_for (thy, facts) = escape_metas (thy :: facts) ^ "\n"
   348   in
   362   in
   349     File.write path dirty_line;
   363     File.write path dirty_line;
   350     Symtab.fold (fn thy_fact => fn () =>
   364     Symtab.fold (fn thy_fact => fn () =>
   361 
   375 
   362 fun mash_get () = Synchronized.change_result global_state (mash_load #> `snd)
   376 fun mash_get () = Synchronized.change_result global_state (mash_load #> `snd)
   363 
   377 
   364 fun mash_reset ctxt =
   378 fun mash_reset ctxt =
   365   Synchronized.change global_state (fn _ =>
   379   Synchronized.change global_state (fn _ =>
   366       (mash_RESET ctxt; File.write (mk_path state_file) "";
   380       (mash_RESET ctxt; File.write (mash_state_path ()) "";
   367        (true, empty_state)))
   381        (true, empty_state)))
   368 
   382 
   369 end
   383 end
   370 
   384 
   371 fun mash_can_suggest_facts (_ : Proof.context) =
   385 fun mash_can_suggest_facts (_ : Proof.context) =
   428             val feats = features_of thy status [prop_of th]
   442             val feats = features_of thy status [prop_of th]
   429             val deps = isabelle_dependencies_of all_names th
   443             val deps = isabelle_dependencies_of all_names th
   430             val record = (name, prevs, feats, deps)
   444             val record = (name, prevs, feats, deps)
   431           in ([name], record :: records) end
   445           in ([name], record :: records) end
   432         val parents = parent_facts thy thy_facts
   446         val parents = parent_facts thy thy_facts
   433         val (_, records) = (parents, []) |> fold_rev do_fact new_facts
   447         val (_, records) = (parents, []) |> fold do_fact new_facts
   434         val new_thy_facts = new_facts |> thy_facts_from_thms
   448         val new_thy_facts = new_facts |> thy_facts_from_thms
   435         fun trans {dirty_thys, thy_facts} =
   449         fun trans {dirty_thys, thy_facts} =
   436           (mash_ADD ctxt records;
   450           (mash_ADD ctxt (rev records);
   437            {dirty_thys = dirty_thys,
   451            {dirty_thys = dirty_thys,
   438             thy_facts = thy_facts |> add_thy_facts_from_thys new_thy_facts})
   452             thy_facts = thy_facts |> add_thy_facts_from_thys new_thy_facts})
   439       in mash_map trans end
   453       in mash_map trans end
   440   end
   454   end
   441 
   455 
   442 fun mash_learn_proof ctxt t ths =
   456 fun mash_learn_proof ctxt thy t ths =
   443   let val thy = Proof_Context.theory_of ctxt in
   457   mash_map (fn state as {dirty_thys, thy_facts} =>
   444     mash_map (fn state as {dirty_thys, thy_facts} =>
   458     let val deps = ths |> map Thm.get_name_hint in
   445       let val deps = ths |> map Thm.get_name_hint in
   459       if forall (is_fact_in_thy_facts thy_facts) deps then
   446         if forall (is_fact_in_thy_facts thy_facts) deps then
   460         let
   447           let
   461           val fact = ATP_Util.timestamp () (* should be fairly fresh *)
   448             val fact = ATP_Util.timestamp () (* should be fairly fresh *)
   462           val access = accessibility_of thy thy_facts
   449             val access = accessibility_of thy thy_facts
   463           val feats = features_of thy General [t]
   450             val feats = features_of thy General [t]
   464         in
   451           in
   465           mash_ADD ctxt [(fact, access, feats, deps)];
   452             mash_ADD ctxt [(fact, access, feats, deps)];
   466           {dirty_thys = dirty_thys, thy_facts = thy_facts}
   453             {dirty_thys = dirty_thys, thy_facts = thy_facts}
   467         end
   454           end
   468       else
   455         else
   469         state
   456           state
   470     end)
   457       end)
       
   458   end
       
   459 
   471 
   460 fun relevant_facts ctxt params prover max_facts
   472 fun relevant_facts ctxt params prover max_facts
   461         ({add, only, ...} : fact_override) hyp_ts concl_t facts =
   473         ({add, only, ...} : fact_override) hyp_ts concl_t facts =
   462   if only then
   474   if only then
   463     facts
   475     facts