src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML
changeset 32864 a226f29d4bdc
parent 32819 004b251ac927
child 32868 5f1805c6ef2a
equal deleted inserted replaced
32852:7c8bc41ce810 32864:a226f29d4bdc
   289   |> the_default (hd (space_explode " " (AtpManager.get_atps ())))
   289   |> the_default (hd (space_explode " " (AtpManager.get_atps ())))
   290   |> (fn name => (name, the (AtpManager.get_prover name thy)))
   290   |> (fn name => (name, the (AtpManager.get_prover name thy)))
   291 
   291 
   292 local
   292 local
   293 
   293 
   294 fun safe init done f x =
       
   295   let
       
   296     val y = init x
       
   297     val z = Exn.capture f y
       
   298     val _ = done y
       
   299   in Exn.release z end
       
   300 
       
   301 fun init_sh NONE = !AtpWrapper.destdir
       
   302   | init_sh (SOME path) =
       
   303       let
       
   304         (* Warning: we implicitly assume single-threaded execution here! *)
       
   305         val old = !AtpWrapper.destdir
       
   306         val _ = AtpWrapper.destdir := path
       
   307       in old end
       
   308 
       
   309 fun done_sh path = AtpWrapper.destdir := path
       
   310 
       
   311 datatype sh_result =
   294 datatype sh_result =
   312   SH_OK of int * int * string list |
   295   SH_OK of int * int * string list |
   313   SH_FAIL of int * int |
   296   SH_FAIL of int * int |
   314   SH_ERROR
   297   SH_ERROR
   315 
   298 
   316 fun run_sh (prover_name, prover) hard_timeout timeout st _ =
   299 fun run_sh prover hard_timeout timeout dir st =
   317   let
   300   let
   318     val atp = prover timeout NONE NONE prover_name 1
   301     val (ctxt, goal) = Proof.get_goal st
       
   302     val ctxt' = ctxt |> is_some dir ? Config.put AtpWrapper.destdir (the dir)
       
   303     val atp = prover (AtpWrapper.atp_problem_of_goal
       
   304       (AtpManager.get_full_types ()) 1 (ctxt', goal))
       
   305 
   319     val time_limit =
   306     val time_limit =
   320       (case hard_timeout of
   307       (case hard_timeout of
   321         NONE => I
   308         NONE => I
   322       | SOME secs => TimeLimit.timeLimit (Time.fromSeconds secs))
   309       | SOME secs => TimeLimit.timeLimit (Time.fromSeconds secs))
   323     val ((success, (message, thm_names), time_atp, _, _, _), time_isa) =
   310     val (AtpWrapper.Prover_Result {success, message, theorem_names,
   324       time_limit (Mirabelle.cpu_time atp) (Proof.get_goal st)
   311       runtime=time_atp, ...}, time_isa) =
       
   312       time_limit (Mirabelle.cpu_time atp) timeout
   325   in
   313   in
   326     if success then (message, SH_OK (time_isa, time_atp, thm_names))
   314     if success then (message, SH_OK (time_isa, time_atp, theorem_names))
   327     else (message, SH_FAIL(time_isa, time_atp))
   315     else (message, SH_FAIL(time_isa, time_atp))
   328   end
   316   end
   329   handle ResHolClause.TOO_TRIVIAL => ("trivial", SH_OK (0, 0, []))
   317   handle ResHolClause.TOO_TRIVIAL => ("trivial", SH_OK (0, 0, []))
   330        | ERROR msg => ("error: " ^ msg, SH_ERROR)
   318        | ERROR msg => ("error: " ^ msg, SH_ERROR)
   331        | TimeLimit.TimeOut => ("timeout", SH_ERROR)
   319        | TimeLimit.TimeOut => ("timeout", SH_ERROR)
   346 in
   334 in
   347 
   335 
   348 fun run_sledgehammer args named_thms id ({pre=st, log, ...}: Mirabelle.run_args) =
   336 fun run_sledgehammer args named_thms id ({pre=st, log, ...}: Mirabelle.run_args) =
   349   let
   337   let
   350     val _ = change_data id inc_sh_calls
   338     val _ = change_data id inc_sh_calls
   351     val atp as (prover_name, _) = get_atp (Proof.theory_of st) args
   339     val (prover_name, prover) = get_atp (Proof.theory_of st) args
   352     val dir = AList.lookup (op =) args keepK
   340     val dir = AList.lookup (op =) args keepK
   353     val timeout = Mirabelle.get_int_setting args (prover_timeoutK, 30)
   341     val timeout = Mirabelle.get_int_setting args (prover_timeoutK, 30)
   354     val hard_timeout = AList.lookup (op =) args prover_hard_timeoutK
   342     val hard_timeout = AList.lookup (op =) args prover_hard_timeoutK
   355       |> Option.map (fst o read_int o explode)
   343       |> Option.map (fst o read_int o explode)
   356     val (msg, result) = safe init_sh done_sh
   344     val (msg, result) = run_sh prover hard_timeout timeout dir st
   357       (run_sh atp hard_timeout timeout st) dir
       
   358   in
   345   in
   359     case result of
   346     case result of
   360       SH_OK (time_isa, time_atp, names) =>
   347       SH_OK (time_isa, time_atp, names) =>
   361         let fun get_thms name = (name, thms_of_name (Proof.context_of st) name)
   348         let fun get_thms name = (name, thms_of_name (Proof.context_of st) name)
   362         in
   349         in