src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML
changeset 40062 cfaebaa8588f
parent 40061 71cc5aac8b76
child 40063 d086e3699e78
equal deleted inserted replaced
40061:71cc5aac8b76 40062:cfaebaa8588f
    27   nontriv_calls: int,
    27   nontriv_calls: int,
    28   nontriv_success: int,
    28   nontriv_success: int,
    29   lemmas: int,
    29   lemmas: int,
    30   max_lems: int,
    30   max_lems: int,
    31   time_isa: int,
    31   time_isa: int,
    32   time_atp: int,
    32   time_prover: int,
    33   time_atp_fail: int}
    33   time_prover_fail: int}
    34 
    34 
    35 datatype me_data = MeData of {
    35 datatype me_data = MeData of {
    36   calls: int,
    36   calls: int,
    37   success: int,
    37   success: int,
    38   nontriv_calls: int,
    38   nontriv_calls: int,
    49   ab_ratios: int
    49   ab_ratios: int
    50   }
    50   }
    51 
    51 
    52 fun make_sh_data
    52 fun make_sh_data
    53       (calls,success,nontriv_calls,nontriv_success,lemmas,max_lems,time_isa,
    53       (calls,success,nontriv_calls,nontriv_success,lemmas,max_lems,time_isa,
    54        time_atp,time_atp_fail) =
    54        time_prover,time_prover_fail) =
    55   ShData{calls=calls, success=success, nontriv_calls=nontriv_calls,
    55   ShData{calls=calls, success=success, nontriv_calls=nontriv_calls,
    56          nontriv_success=nontriv_success, lemmas=lemmas, max_lems=max_lems,
    56          nontriv_success=nontriv_success, lemmas=lemmas, max_lems=max_lems,
    57          time_isa=time_isa, time_atp=time_atp, time_atp_fail=time_atp_fail}
    57          time_isa=time_isa, time_prover=time_prover,
       
    58          time_prover_fail=time_prover_fail}
    58 
    59 
    59 fun make_min_data (succs, ab_ratios) =
    60 fun make_min_data (succs, ab_ratios) =
    60   MinData{succs=succs, ab_ratios=ab_ratios}
    61   MinData{succs=succs, ab_ratios=ab_ratios}
    61 
    62 
    62 fun make_me_data (calls,success,nontriv_calls,nontriv_success,proofs,time,
    63 fun make_me_data (calls,success,nontriv_calls,nontriv_success,proofs,time,
    69 val empty_min_data = make_min_data (0, 0)
    70 val empty_min_data = make_min_data (0, 0)
    70 val empty_me_data = make_me_data (0, 0, 0, 0, 0, 0, 0, (0,0,0), [])
    71 val empty_me_data = make_me_data (0, 0, 0, 0, 0, 0, 0, (0,0,0), [])
    71 
    72 
    72 fun tuple_of_sh_data (ShData {calls, success, nontriv_calls, nontriv_success,
    73 fun tuple_of_sh_data (ShData {calls, success, nontriv_calls, nontriv_success,
    73                               lemmas, max_lems, time_isa,
    74                               lemmas, max_lems, time_isa,
    74   time_atp, time_atp_fail}) = (calls, success, nontriv_calls, nontriv_success,
    75   time_prover, time_prover_fail}) = (calls, success, nontriv_calls,
    75   lemmas, max_lems, time_isa, time_atp, time_atp_fail)
    76   nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail)
    76 
    77 
    77 fun tuple_of_min_data (MinData {succs, ab_ratios}) = (succs, ab_ratios)
    78 fun tuple_of_min_data (MinData {succs, ab_ratios}) = (succs, ab_ratios)
    78 
    79 
    79 fun tuple_of_me_data (MeData {calls, success, nontriv_calls, nontriv_success,
    80 fun tuple_of_me_data (MeData {calls, success, nontriv_calls, nontriv_success,
    80   proofs, time, timeout, lemmas, posns}) = (calls, success, nontriv_calls,
    81   proofs, time, timeout, lemmas, posns}) = (calls, success, nontriv_calls,
   125   make_data (sh, min, me_u, me_m, me_uft, me_mft, mini)
   126   make_data (sh, min, me_u, me_m, me_uft, me_mft, mini)
   126 
   127 
   127 fun inc_max (n:int) (s,sos,m) = (s+n, sos + n*n, Int.max(m,n));
   128 fun inc_max (n:int) (s,sos,m) = (s+n, sos + n*n, Int.max(m,n));
   128 
   129 
   129 val inc_sh_calls =  map_sh_data
   130 val inc_sh_calls =  map_sh_data
   130   (fn (calls, success, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_atp, time_atp_fail)
   131   (fn (calls, success, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_prover, time_prover_fail)
   131     => (calls + 1, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_atp, time_atp_fail))
   132     => (calls + 1, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail))
   132 
   133 
   133 val inc_sh_success = map_sh_data
   134 val inc_sh_success = map_sh_data
   134   (fn (calls, success, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_atp, time_atp_fail)
   135   (fn (calls, success, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_prover, time_prover_fail)
   135     => (calls, success + 1, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_atp, time_atp_fail))
   136     => (calls, success + 1, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_prover, time_prover_fail))
   136 
   137 
   137 val inc_sh_nontriv_calls =  map_sh_data
   138 val inc_sh_nontriv_calls =  map_sh_data
   138   (fn (calls, success, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_atp, time_atp_fail)
   139   (fn (calls, success, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_prover, time_prover_fail)
   139     => (calls, success, nontriv_calls + 1, nontriv_success, lemmas, max_lems, time_isa, time_atp, time_atp_fail))
   140     => (calls, success, nontriv_calls + 1, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail))
   140 
   141 
   141 val inc_sh_nontriv_success = map_sh_data
   142 val inc_sh_nontriv_success = map_sh_data
   142   (fn (calls, success, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_atp, time_atp_fail)
   143   (fn (calls, success, nontriv_calls, nontriv_success, lemmas,max_lems, time_isa, time_prover, time_prover_fail)
   143     => (calls, success, nontriv_calls, nontriv_success + 1, lemmas,max_lems, time_isa, time_atp, time_atp_fail))
   144     => (calls, success, nontriv_calls, nontriv_success + 1, lemmas,max_lems, time_isa, time_prover, time_prover_fail))
   144 
   145 
   145 fun inc_sh_lemmas n = map_sh_data
   146 fun inc_sh_lemmas n = map_sh_data
   146   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_atp,time_atp_fail)
   147   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_prover,time_prover_fail)
   147     => (calls,success,nontriv_calls, nontriv_success, lemmas+n,max_lems,time_isa,time_atp,time_atp_fail))
   148     => (calls,success,nontriv_calls, nontriv_success, lemmas+n,max_lems,time_isa,time_prover,time_prover_fail))
   148 
   149 
   149 fun inc_sh_max_lems n = map_sh_data
   150 fun inc_sh_max_lems n = map_sh_data
   150   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_atp,time_atp_fail)
   151   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_prover,time_prover_fail)
   151     => (calls,success,nontriv_calls, nontriv_success, lemmas,Int.max(max_lems,n),time_isa,time_atp,time_atp_fail))
   152     => (calls,success,nontriv_calls, nontriv_success, lemmas,Int.max(max_lems,n),time_isa,time_prover,time_prover_fail))
   152 
   153 
   153 fun inc_sh_time_isa t = map_sh_data
   154 fun inc_sh_time_isa t = map_sh_data
   154   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_atp,time_atp_fail)
   155   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_prover,time_prover_fail)
   155     => (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa + t,time_atp,time_atp_fail))
   156     => (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa + t,time_prover,time_prover_fail))
   156 
   157 
   157 fun inc_sh_time_atp t = map_sh_data
   158 fun inc_sh_time_prover t = map_sh_data
   158   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_atp,time_atp_fail)
   159   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_prover,time_prover_fail)
   159     => (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_atp + t,time_atp_fail))
   160     => (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_prover + t,time_prover_fail))
   160 
   161 
   161 fun inc_sh_time_atp_fail t = map_sh_data
   162 fun inc_sh_time_prover_fail t = map_sh_data
   162   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_atp,time_atp_fail)
   163   (fn (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_prover,time_prover_fail)
   163     => (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_atp,time_atp_fail + t))
   164     => (calls,success,nontriv_calls, nontriv_success, lemmas,max_lems,time_isa,time_prover,time_prover_fail + t))
   164 
   165 
   165 val inc_min_succs = map_min_data
   166 val inc_min_succs = map_min_data
   166   (fn (succs,ab_ratios) => (succs+1, ab_ratios))
   167   (fn (succs,ab_ratios) => (succs+1, ab_ratios))
   167 
   168 
   168 fun inc_min_ab_ratios r = map_min_data
   169 fun inc_min_ab_ratios r = map_min_data
   212 fun time t = Real.fromInt t / 1000.0
   213 fun time t = Real.fromInt t / 1000.0
   213 fun avg_time t n =
   214 fun avg_time t n =
   214   if n > 0 then (Real.fromInt t / 1000.0) / Real.fromInt n else 0.0
   215   if n > 0 then (Real.fromInt t / 1000.0) / Real.fromInt n else 0.0
   215 
   216 
   216 fun log_sh_data log
   217 fun log_sh_data log
   217     (calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_atp, time_atp_fail) =
   218     (calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail) =
   218  (log ("Total number of sledgehammer calls: " ^ str calls);
   219  (log ("Total number of sledgehammer calls: " ^ str calls);
   219   log ("Number of successful sledgehammer calls: " ^ str success);
   220   log ("Number of successful sledgehammer calls: " ^ str success);
   220   log ("Number of sledgehammer lemmas: " ^ str lemmas);
   221   log ("Number of sledgehammer lemmas: " ^ str lemmas);
   221   log ("Max number of sledgehammer lemmas: " ^ str max_lems);
   222   log ("Max number of sledgehammer lemmas: " ^ str max_lems);
   222   log ("Success rate: " ^ percentage success calls ^ "%");
   223   log ("Success rate: " ^ percentage success calls ^ "%");
   223   log ("Total number of nontrivial sledgehammer calls: " ^ str nontriv_calls);
   224   log ("Total number of nontrivial sledgehammer calls: " ^ str nontriv_calls);
   224   log ("Number of successful nontrivial sledgehammer calls: " ^ str nontriv_success);
   225   log ("Number of successful nontrivial sledgehammer calls: " ^ str nontriv_success);
   225   log ("Total time for sledgehammer calls (Isabelle): " ^ str3 (time time_isa));
   226   log ("Total time for sledgehammer calls (Isabelle): " ^ str3 (time time_isa));
   226   log ("Total time for successful sledgehammer calls (ATP): " ^ str3 (time time_atp));
   227   log ("Total time for successful sledgehammer calls (ATP): " ^ str3 (time time_prover));
   227   log ("Total time for failed sledgehammer calls (ATP): " ^ str3 (time time_atp_fail));
   228   log ("Total time for failed sledgehammer calls (ATP): " ^ str3 (time time_prover_fail));
   228   log ("Average time for sledgehammer calls (Isabelle): " ^
   229   log ("Average time for sledgehammer calls (Isabelle): " ^
   229     str3 (avg_time time_isa calls));
   230     str3 (avg_time time_isa calls));
   230   log ("Average time for successful sledgehammer calls (ATP): " ^
   231   log ("Average time for successful sledgehammer calls (ATP): " ^
   231     str3 (avg_time time_atp success));
   232     str3 (avg_time time_prover success));
   232   log ("Average time for failed sledgehammer calls (ATP): " ^
   233   log ("Average time for failed sledgehammer calls (ATP): " ^
   233     str3 (avg_time time_atp_fail (calls - success)))
   234     str3 (avg_time time_prover_fail (calls - success)))
   234   )
   235   )
   235 
   236 
   236 
   237 
   237 fun str_of_pos (pos, triv) =
   238 fun str_of_pos (pos, triv) =
   238   let val str0 = string_of_int o the_default 0
   239   let val str0 = string_of_int o the_default 0
   311   |> K ()
   312   |> K ()
   312 
   313 
   313 fun change_data id f = (Unsynchronized.change data (AList.map_entry (op =) id f); ())
   314 fun change_data id f = (Unsynchronized.change data (AList.map_entry (op =) id f); ())
   314 
   315 
   315 
   316 
   316 fun get_atp thy args =
   317 fun get_prover thy args =
   317   let
   318   let
   318     fun default_atp_name () =
   319     fun default_prover_name () =
   319       hd (#provers (Sledgehammer_Isar.default_params thy []))
   320       hd (#provers (Sledgehammer_Isar.default_params thy []))
   320       handle Empty => error "No ATP available."
   321       handle Empty => error "No ATP available."
   321     fun get_atp name = (name, Sledgehammer.run_atp thy name)
   322     fun get_prover name = (name, Sledgehammer.get_prover thy name)
   322   in
   323   in
   323     (case AList.lookup (op =) args proverK of
   324     (case AList.lookup (op =) args proverK of
   324       SOME name => get_atp name
   325       SOME name => get_prover name
   325     | NONE => get_atp (default_atp_name ()))
   326     | NONE => get_prover (default_prover_name ()))
   326   end
   327   end
   327 
   328 
   328 type locality = Sledgehammer_Filter.locality
   329 type locality = Sledgehammer_Filter.locality
   329 
   330 
   330 local
   331 local
   347                  #> Config.put Sledgehammer.measure_run_time true)
   348                  #> Config.put Sledgehammer.measure_run_time true)
   348     val params as {full_types, relevance_thresholds, max_relevant, ...} =
   349     val params as {full_types, relevance_thresholds, max_relevant, ...} =
   349       Sledgehammer_Isar.default_params thy
   350       Sledgehammer_Isar.default_params thy
   350           [("timeout", Int.toString timeout ^ " s")]
   351           [("timeout", Int.toString timeout ^ " s")]
   351     val relevance_override = {add = [], del = [], only = false}
   352     val relevance_override = {add = [], del = [], only = false}
   352     val {default_max_relevant, ...} = ATP_Systems.get_atp thy prover_name
   353     val default_max_relevant =
       
   354       if member (op =) Sledgehammer.smt_prover_names prover_name then
       
   355         Sledgehammer.smt_default_max_relevant
       
   356       else
       
   357         #default_max_relevant (ATP_Systems.get_atp thy prover_name)
   353     val (_, hyp_ts, concl_t) = Sledgehammer_Util.strip_subgoal goal i
   358     val (_, hyp_ts, concl_t) = Sledgehammer_Util.strip_subgoal goal i
   354     val axioms =
   359     val axioms =
   355       Sledgehammer_Filter.relevant_facts ctxt full_types relevance_thresholds
   360       Sledgehammer_Filter.relevant_facts ctxt full_types relevance_thresholds
   356           (the_default default_max_relevant max_relevant)
   361           (the_default default_max_relevant max_relevant)
   357           relevance_override chained_ths hyp_ts concl_t
   362           relevance_override chained_ths hyp_ts concl_t
   360        axioms = axioms |> map Sledgehammer.Unprepared, only = true}
   365        axioms = axioms |> map Sledgehammer.Unprepared, only = true}
   361     val time_limit =
   366     val time_limit =
   362       (case hard_timeout of
   367       (case hard_timeout of
   363         NONE => I
   368         NONE => I
   364       | SOME secs => TimeLimit.timeLimit (Time.fromSeconds secs))
   369       | SOME secs => TimeLimit.timeLimit (Time.fromSeconds secs))
   365     val ({outcome, message, used_axioms, run_time_in_msecs = time_atp, ...}
   370     val ({outcome, message, used_axioms, run_time_in_msecs = SOME time_prover, ...}
   366          : Sledgehammer.prover_result,
   371          : Sledgehammer.prover_result,
   367         time_isa) = time_limit (Mirabelle.cpu_time
   372         time_isa) = time_limit (Mirabelle.cpu_time
   368                       (prover params (K ""))) problem
   373                       (prover params (K ""))) problem
   369   in
   374   in
   370     case outcome of
   375     case outcome of
   371       NONE => (message, SH_OK (time_isa, time_atp, used_axioms))
   376       NONE => (message, SH_OK (time_isa, time_prover, used_axioms))
   372     | SOME _ => (message, SH_FAIL (time_isa, time_atp))
   377     | SOME _ => (message, SH_FAIL (time_isa, time_prover))
   373   end
   378   end
   374   handle ERROR msg => ("error: " ^ msg, SH_ERROR)
   379   handle ERROR msg => ("error: " ^ msg, SH_ERROR)
   375        | TimeLimit.TimeOut => ("timeout", SH_ERROR)
   380        | TimeLimit.TimeOut => ("timeout", SH_ERROR)
   376 
   381 
   377 fun thms_of_name ctxt name =
   382 fun thms_of_name ctxt name =
   392 fun run_sledgehammer trivial args named_thms id ({pre=st, log, ...}: Mirabelle.run_args) =
   397 fun run_sledgehammer trivial args named_thms id ({pre=st, log, ...}: Mirabelle.run_args) =
   393   let
   398   let
   394     val triv_str = if trivial then "[T] " else ""
   399     val triv_str = if trivial then "[T] " else ""
   395     val _ = change_data id inc_sh_calls
   400     val _ = change_data id inc_sh_calls
   396     val _ = if trivial then () else change_data id inc_sh_nontriv_calls
   401     val _ = if trivial then () else change_data id inc_sh_nontriv_calls
   397     val (prover_name, prover) = get_atp (Proof.theory_of st) args
   402     val (prover_name, prover) = get_prover (Proof.theory_of st) args
   398     val dir = AList.lookup (op =) args keepK
   403     val dir = AList.lookup (op =) args keepK
   399     val timeout = Mirabelle.get_int_setting args (prover_timeoutK, 30)
   404     val timeout = Mirabelle.get_int_setting args (prover_timeoutK, 30)
   400     val hard_timeout = AList.lookup (op =) args prover_hard_timeoutK
   405     val hard_timeout = AList.lookup (op =) args prover_hard_timeoutK
   401       |> Option.map (fst o read_int o explode)
   406       |> Option.map (fst o read_int o explode)
   402     val (msg, result) = run_sh prover_name prover hard_timeout timeout dir st
   407     val (msg, result) = run_sh prover_name prover hard_timeout timeout dir st
   403   in
   408   in
   404     case result of
   409     case result of
   405       SH_OK (time_isa, time_atp, names) =>
   410       SH_OK (time_isa, time_prover, names) =>
   406         let
   411         let
   407           fun get_thms (_, Sledgehammer_Filter.Chained) = NONE
   412           fun get_thms (_, Sledgehammer_Filter.Chained) = NONE
   408             | get_thms (name, loc) =
   413             | get_thms (name, loc) =
   409               SOME ((name, loc), thms_of_name (Proof.context_of st) name)
   414               SOME ((name, loc), thms_of_name (Proof.context_of st) name)
   410         in
   415         in
   411           change_data id inc_sh_success;
   416           change_data id inc_sh_success;
   412           if trivial then () else change_data id inc_sh_nontriv_success;
   417           if trivial then () else change_data id inc_sh_nontriv_success;
   413           change_data id (inc_sh_lemmas (length names));
   418           change_data id (inc_sh_lemmas (length names));
   414           change_data id (inc_sh_max_lems (length names));
   419           change_data id (inc_sh_max_lems (length names));
   415           change_data id (inc_sh_time_isa time_isa);
   420           change_data id (inc_sh_time_isa time_isa);
   416           change_data id (inc_sh_time_atp time_atp);
   421           change_data id (inc_sh_time_prover time_prover);
   417           named_thms := SOME (map_filter get_thms names);
   422           named_thms := SOME (map_filter get_thms names);
   418           log (sh_tag id ^ triv_str ^ "succeeded (" ^ string_of_int time_isa ^ "+" ^
   423           log (sh_tag id ^ triv_str ^ "succeeded (" ^ string_of_int time_isa ^ "+" ^
   419             string_of_int time_atp ^ ") [" ^ prover_name ^ "]:\n" ^ msg)
   424             string_of_int time_prover ^ ") [" ^ prover_name ^ "]:\n" ^ msg)
   420         end
   425         end
   421     | SH_FAIL (time_isa, time_atp) =>
   426     | SH_FAIL (time_isa, time_prover) =>
   422         let
   427         let
   423           val _ = change_data id (inc_sh_time_isa time_isa)
   428           val _ = change_data id (inc_sh_time_isa time_isa)
   424           val _ = change_data id (inc_sh_time_atp_fail time_atp)
   429           val _ = change_data id (inc_sh_time_prover_fail time_prover)
   425         in log (sh_tag id ^ triv_str ^ "failed: " ^ msg) end
   430         in log (sh_tag id ^ triv_str ^ "failed: " ^ msg) end
   426     | SH_ERROR => log (sh_tag id ^ "failed: " ^ msg)
   431     | SH_ERROR => log (sh_tag id ^ "failed: " ^ msg)
   427   end
   432   end
   428 
   433 
   429 end
   434 end
   434 fun run_minimize args named_thms id ({pre=st, log, ...}: Mirabelle.run_args) =
   439 fun run_minimize args named_thms id ({pre=st, log, ...}: Mirabelle.run_args) =
   435   let
   440   let
   436     open Metis_Translate
   441     open Metis_Translate
   437     val thy = Proof.theory_of st
   442     val thy = Proof.theory_of st
   438     val n0 = length (these (!named_thms))
   443     val n0 = length (these (!named_thms))
   439     val (prover_name, _) = get_atp thy args
   444     val (prover_name, _) = get_prover thy args
   440     val timeout =
   445     val timeout =
   441       AList.lookup (op =) args minimize_timeoutK
   446       AList.lookup (op =) args minimize_timeoutK
   442       |> Option.map (fst o read_int o explode)
   447       |> Option.map (fst o read_int o explode)
   443       |> the_default 5
   448       |> the_default 5
   444     val params = Sledgehammer_Isar.default_params thy
   449     val params = Sledgehammer_Isar.default_params thy