src/HOL/Tools/SMT/z3_replay.ML
author wenzelm
Sun Nov 26 21:08:32 2017 +0100 (17 months ago)
changeset 67091 1393c2340eec
parent 62519 a564458f94db
child 69204 d5ab1636660b
permissions -rw-r--r--
more symbols;
     1 (*  Title:      HOL/Tools/SMT/z3_replay.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3     Author:     Jasmin Blanchette, TU Muenchen
     4 
     5 Z3 proof parsing and replay.
     6 *)
     7 
     (* Interface of the Z3 proof replay module.
        parse_proof: builds an SMT_Solver.parsed_proof record from the solver output
          (a string list) without replaying the individual proof steps.
        replay: parses the solver output (a string list) and replays the proof steps,
          yielding a theorem. *)
     8 signature Z3_REPLAY =
     9 sig
    10   val parse_proof: Proof.context -> SMT_Translate.replay_data ->
    11     ((string * ATP_Problem_Generate.stature) * thm) list -> term list -> term -> string list ->
    12     SMT_Solver.parsed_proof
    13   val replay: Proof.context -> SMT_Translate.replay_data -> string list -> thm
    14 end;
    15 
    16 structure Z3_Replay: Z3_REPLAY =
    17 struct
    18 
    (* The (name, type) pairs of the variables bound by the Pure.all quantifiers of a term,
       as returned by Term.strip_qnt_vars. *)
    19 val params_of = Term.strip_qnt_vars @{const_name Pure.all}
    20 
    (* Instantiates the Pure.all-bound parameters of thm's proposition with schematic
       variables. The k-th parameter (counting from 0) named n becomes Var ((n, k + maxidx), T),
       where maxidx is one above Thm.maxidx_of thm -- presumably so that the new variables
       cannot clash with schematic variables already occurring in thm. *)
    21 fun varify ctxt thm =
    22   let
    23     val maxidx = Thm.maxidx_of thm + 1
    24     val vs = params_of (Thm.prop_of thm)
    25     val vars = map_index (fn (i, (n, T)) => Var ((n, i + maxidx), T)) vs
    26   in Drule.forall_elim_list (map (Thm.cterm_of ctxt) vars) thm end
    27 
    (* Extends an association list (the implicit last argument) with one entry per given
       name: the name is paired with the type of the Pure.all parameter of t at the same
       position. The parameter's own name is ignored; only its type is recorded.
       NOTE(review): names and the parameters of t are traversed in lockstep by fold2, so
       they are expected to have equal length. *)
    28 fun add_paramTs names t =
    29   fold2 (fn n => fn (_, T) => AList.update (op =) (n, T)) names (params_of t)
    30 
    (* Creates one new fixed variable per (name, type) entry of nTs via
       Variable.variant_fixes. Returns the extended context together with a symbol table
       that maps each original name to the certified Free carrying the newly chosen name
       and the recorded type. *)
    31 fun new_fixes ctxt nTs =
    32   let
    33     val (ns, ctxt') = Variable.variant_fixes (replicate (length nTs) "") ctxt
    34     fun mk (n, T) n' = (n, Thm.cterm_of ctxt' (Free (n', T)))
    35   in (ctxt', Symtab.make (map2 mk nTs ns)) end
    36 
    (* Term-level elimination of an outermost Pure.all: beta-applies the quantifier's
       abstraction to the term underlying ct. Raises TERM for any term that is not a
       Pure.all applied to an abstraction. *)
    37 fun forall_elim_term ct (Const (@{const_name Pure.all}, _) $ (a as Abs _)) =
    38       Term.betapply (a, Thm.term_of ct)
    39   | forall_elim_term _ qt = raise TERM ("forall_elim'", [qt])
    40 
    (* Folds elim over a list of names (an implicit further argument): each name is looked
       up in the symbol table env and the entry found is passed to elim. A name without an
       entry makes "the" fail. *)
    41 fun apply_fixes elim env = fold (elim o the o Symtab.lookup env)
    42 
    (* apply_fixes_prem: eliminates the named parameters of a theorem given as a
       (names, theorem) pair, using Thm.forall_elim.
       apply_fixes_concl: the same on the term level, using forall_elim_term. *)
    43 fun apply_fixes_prem env = uncurry (apply_fixes Thm.forall_elim env)
    44 fun apply_fixes_concl env = apply_fixes forall_elim_term env
    45 
    (* Generalizes a theorem (the implicit last argument) over the env entries of the given
       names by means of Drule.forall_intr_list. A name without an entry makes "the" fail. *)
    46 fun export_fixes env names = names |> map (the o Symtab.lookup env) |> Drule.forall_intr_list
    47 
    (* Runs the proof function f with the named parameters replaced by new fixed variables.
         prems: theorems whose Pure.all parameters are turned into schematic variables
         nthms: (names, theorem) pairs whose named parameters are instantiated with the
           new fixed variables
         names, concl: the conclusion term and the names of its parameters
       f receives the extended context, the prepared theorems and the opened conclusion;
       its result is generalized again over the conclusion's names. *)
    48 fun under_fixes f ctxt (prems, nthms) names concl =
    49   let
    50     val thms1 = map (varify ctxt) prems
           (* collect name/type entries from the conclusion and all nthms, then fix them *)
    51     val (ctxt', env) =
    52       add_paramTs names concl []
    53       |> fold (uncurry add_paramTs o apsnd Thm.prop_of) nthms
    54       |> new_fixes ctxt
    55     val thms2 = map (apply_fixes_prem env) nthms
    56     val t = apply_fixes_concl env names concl
    57   in export_fixes env names (f ctxt' (thms1 @ thms2) t) end
    58 
    (* Produces the theorem for a single proof step.
       Non-assumption steps are proved by the method Z3_Replay_Methods.method_for selects
       for the rule, run under fresh fixes; for a fix step the premise theorems are passed
       without their names, otherwise as (names, theorem) pairs.
       Assumption steps take the theorem recorded under the step id in assumed, or assume
       the conclusion if nothing is recorded. *)
    59 fun replay_thm ctxt assumed nthms (Z3_Proof.Z3_Step {id, rule, concl, fixes, is_fix_step, ...}) =
    60   if not (Z3_Proof.is_assumption rule) then
    61     under_fixes (Z3_Replay_Methods.method_for rule) ctxt
    62       (if is_fix_step then (map snd nthms, []) else ([], nthms)) fixes concl
    63   else
    64     (case Inttab.lookup assumed id of
    65       NONE => Thm.assume (Thm.cterm_of ctxt concl)
    66     | SOME (_, thm) => thm)
    67 
    (* Replays one step and threads the state (proofs, stats).
       The theorems of the step's premises are looked up in proofs (a missing premise makes
       "the" fail). replay_thm is timed and run under the time limit configured by
       SMT_Config.reconstruction_step_timeout; a timeout is re-raised as
       SMT_Failure.SMT SMT_Failure.Time_Out.
       Result: proofs extended with (fixes, thm) under the step id, and stats extended
       with the elapsed milliseconds under the rule's name. *)
    68 fun replay_step ctxt assumed (step as Z3_Proof.Z3_Step {id, rule, prems, fixes, ...}) state =
    69   let
    70     val (proofs, stats) = state
    71     val nthms = map (the o Inttab.lookup proofs) prems
    72     val replay = Timing.timing (replay_thm ctxt assumed nthms)
    73     val ({elapsed, ...}, thm) =
    74       SMT_Config.with_time_limit ctxt SMT_Config.reconstruction_step_timeout replay step
    75         handle Timeout.TIMEOUT _ => raise SMT_Failure.SMT SMT_Failure.Time_Out
    76     val stats' = Symtab.cons_list (Z3_Proof.string_of_rule rule, Time.toMilliseconds elapsed) stats
    77   in (Inttab.update (id, (fixes, thm)) proofs, stats') end
    78 
    (* Matching of the assumption steps of a Z3 proof against the original assertions.
       The local part provides rewriting helpers; add_asserted is the only exported name. *)
    79 local
         (* trigger_def and fun_app_def as meta-level equations *)
    80   val remove_trigger = mk_meta_eq @{thm trigger_def}
    81   val remove_fun_app = mk_meta_eq @{thm fun_app_def}
    82 
         (* conversion rewriting with exactly the given equations; identity for no equations *)
    83   fun rewrite_conv _ [] = Conv.all_conv
    84     | rewrite_conv ctxt eqs = Simplifier.full_rewrite (empty_simpset ctxt addsimps eqs)
    85 
    86   val rewrite_true_rule = @{lemma "True \<equiv> \<not> False" by simp}
    87   val prep_rules = [@{thm Let_def}, remove_trigger, remove_fun_app, rewrite_true_rule]
    88 
         (* rule version of rewrite_conv; identity for no equations *)
    89   fun rewrite _ [] = I
    90     | rewrite ctxt eqs = Conv.fconv_rule (rewrite_conv ctxt eqs)
    91 
         (* candidates for ct from the net, each paired with a flag telling whether the
            candidate's proposition and ct are related by aconvc *)
    92   fun lookup_assm assms_net ct =
    93     Z3_Replay_Util.net_instances assms_net ct
    94     |> map (fn ithm as (_, thm) => (ithm, Thm.cprop_of thm aconvc ct))
    95 in
    96 
    (* Processes all assumption steps other than Hypothesis steps.
       Returns ((iidths, thms), (ctxt, ptab)) where
         iidths: entries (i, (id, th)) linking the first component i of a matched entry of
           assms and its theorem th to the id of the proof step
         thms: theorems of non-exact matches (they are not used to discharge the premise
           here, it is assumed instead)
         ctxt: the context extended by the assumptions made
         ptab: table from step id to (fixes, theorem) *)
    97 fun add_asserted outer_ctxt rewrite_rules assms steps ctxt =
    98   let
    99     val eqs = map (rewrite ctxt [rewrite_true_rule]) rewrite_rules
   100     val eqs' = union Thm.eq_thm eqs prep_rules
   101 
           (* net of the given assumptions after rewriting with eqs' and eta conversion *)
   102     val assms_net =
   103       assms
   104       |> map (apsnd (rewrite ctxt eqs'))
   105       |> map (apsnd (Conv.fconv_rule Thm.eta_conversion))
   106       |> Z3_Replay_Util.thm_net_of snd
   107 
           (* the same normalization as a conversion, applied to step conclusions below *)
   108     fun revert_conv ctxt = rewrite_conv ctxt eqs' then_conv Thm.eta_conversion
   109 
           (* adds the first premise of thm as an assumption to the context and resolves
              the assumed theorem against thm *)
   110     fun assume thm ctxt =
   111       let
   112         val ct = Thm.cprem_of thm 1
   113         val (thm', ctxt') = yield_singleton Assumption.add_assumes ct ctxt
   114       in (thm' RS thm, ctxt') end
   115 
           (* handles one candidate: an exact match discharges the premise of thm1 directly
              with the candidate theorem; otherwise the premise is assumed and the candidate
              theorem is remembered in thms. Every candidate stores its result under the
              same id in ptab. *)
   116     fun add1 id fixes thm1 ((i, th), exact) ((iidths, thms), (ctxt, ptab)) =
   117       let
   118         val (thm, ctxt') = if exact then (Thm.implies_elim thm1 th, ctxt) else assume thm1 ctxt
   119         val thms' = if exact then thms else th :: thms
   120       in (((i, (id, th)) :: iidths, thms'), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
   121 
   122     fun add (Z3_Proof.Z3_Step {id, rule, concl, fixes, ...})
   123         (cx as ((iidths, thms), (ctxt, ptab))) =
   124       if Z3_Proof.is_assumption rule andalso rule <> Z3_Proof.Hypothesis then
   125         let
   126           val ct = Thm.cterm_of ctxt concl
               (* trivial theorem for the conclusion, with its first argument normalized *)
   127           val thm1 = Thm.trivial ct |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
   128           val thm2 = singleton (Variable.export ctxt outer_ctxt) thm1
   129         in
   130           (case lookup_assm assms_net (Thm.cprem_of thm2 1) of
                 (* no candidate at all: assume the premise and record only the theorem *)
   131             [] =>
   132               let val (thm, ctxt') = assume thm1 ctxt
   133               in ((iidths, thms), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
   134           | ithms => fold (add1 id fixes thm1) ithms cx)
   135         end
   136       else
   137         cx
   138   in fold add steps (([], []), (ctxt, Inttab.empty)) end
   139 
   140 end
   141 
   (* Discharging of skolemization premises of the two forms
        |- (EX x. P x) = P c     |- ~ (ALL x. P x) = ~ P c
      sk_rules proves each form under the premise that c is the corresponding
      choice term (SOME x. ...). *)
   143 local
   144   val sk_rules = @{lemma
   145     "c = (SOME x. P x) \<Longrightarrow> (\<exists>x. P x) = P c"
   146     "c = (SOME x. \<not> P x) \<Longrightarrow> (\<not> (\<forall>x. P x)) = (\<not> P c)"
   147     by (metis someI_ex)+}
   148 in
   149 
   (* Tactic for subgoal i: resolve with trans, then with one of sk_rules; then subgoal
      i + 1 is closed by refl or, failing that, by a recursive call of this tactic;
      finally subgoal i is closed by refl. *)
   150 fun discharge_sk_tac ctxt i st =
   151   (resolve_tac ctxt @{thms trans} i
   152    THEN resolve_tac ctxt sk_rules i
   153    THEN (resolve_tac ctxt @{thms refl} ORELSE' discharge_sk_tac ctxt) (i+1)
   154    THEN resolve_tac ctxt @{thms refl} i) st
   155 
   156 end
   157 
   (* true_thm: the fact that False does not hold.
      make_discharge_rules: appends the standard rules allI, refl, reflexive and true_thm
      to the given rules; the given rules come first in the resulting list. *)
   158 val true_thm = @{lemma "\<not>False" by simp}
   159 fun make_discharge_rules rules = rules @ [@{thm allI}, @{thm refl}, @{thm reflexive}, true_thm]
   160 
   (* Two tautologies, each a conjunction of both orientations of the excluded middle.
      They are offered as introduction rules by discharge_assms_tac. *)
   161 val intro_def_rules = @{lemma
   162   "(\<not> P \<or> P) \<and> (P \<or> \<not> P)"
   163   "(P \<or> \<not> P) \<and> (\<not> P \<or> P)"
   164   by fast+}
   165 
   (* Repeatedly works on the first subgoal: resolve it with intro_def_rules or the given
      rules, or else solve it completely with discharge_sk_tac. *)
   166 fun discharge_assms_tac ctxt rules =
   167   REPEAT
   168     (HEADGOAL (resolve_tac ctxt (intro_def_rules @ rules) ORELSE'
   169       SOLVED' (discharge_sk_tac ctxt)))
   170 
   (* Applies discharge_assms_tac to a theorem that still has premises and takes the first
      result of the tactic; raises THM if the tactic yields no result. A theorem without
      premises is passed through. In either case the result is normalized with
      Goal.norm_result. *)
   171 fun discharge_assms ctxt rules thm =
   172   let
   173     fun first_result () =
   174       (case Seq.pull (discharge_assms_tac ctxt rules thm) of
   175         NONE => raise THM ("failed to discharge premise", 1, [thm])
   176       | SOME (discharged, _) => discharged)
   177     val thm' = if Thm.nprems_of thm > 0 then first_result () else thm
   178   in Goal.norm_result ctxt thm' end
   179 
   (* Exports a theorem (the implicit last argument) from inner_ctxt to outer_ctxt and then
      discharges its premises in outer_ctxt, using the given rules extended by
      make_discharge_rules. *)
   180 fun discharge rules outer_ctxt inner_ctxt =
   181   singleton (Proof_Context.export inner_ctxt outer_ctxt)
   182   #> discharge_assms outer_ctxt (make_discharge_rules rules)
   183 
   (* Parses the solver output and relates the asserted steps of the proof to the
      conjecture, the premises and the facts, without replaying any proof step.
      The index convention used here: 0 is the conjecture, prems occupy the indices from 1
      on, the entries of xfacts follow from index 1 + length prems on.
      The resulting record has no outcome, the used facts as fact_ids, and a delayed
      computation of the ATP proof via Z3_Isar.atp_proof_of_z3_proof. *)
   184 fun parse_proof outer_ctxt
   185     ({context = ctxt, typs, terms, ll_defs, rewrite_rules, assms} : SMT_Translate.replay_data)
   186     xfacts prems concl output =
   187   let
   188     val (steps, ctxt2) = Z3_Proof.parse typs terms output ctxt
   189     val ((iidths, _), _) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
   190 
           (* step id recorded for assertion index i, or ~1 if there is none *)
   191     fun id_of_index i = the_default ~1 (Option.map fst (AList.lookup (op =) iidths i))
   192 
   193     val conjecture_i = 0
   194     val prems_i = 1
   195     val facts_i = prems_i + length prems
   196 
   197     val conjecture_id = id_of_index conjecture_i
   198     val prem_ids = map id_of_index (prems_i upto facts_i - 1)
           (* entries whose index selects an element of xfacts; all other entries make nth
              fail and are dropped by try *)
   199     val fact_ids' =
   200       map_filter (fn (i, (id, _)) => try (apsnd (nth xfacts)) (id, i - facts_i)) iidths
           (* entries with index ~1; the pattern fails on every other index and try drops
              those entries *)
   201     val helper_ids' = map_filter (try (fn (~1, idth) => idth)) iidths
   202 
   203     val fact_helper_ts =
   204       map (fn (_, th) => (ATP_Util.short_thm_name ctxt th, Thm.prop_of th)) helper_ids' @
   205       map (fn (_, ((s, _), th)) => (s, Thm.prop_of th)) fact_ids'
   206     val fact_helper_ids' =
   207       map (apsnd (ATP_Util.short_thm_name ctxt)) helper_ids' @ map (apsnd (fst o fst)) fact_ids'
   208   in
   209     {outcome = NONE, fact_ids = SOME fact_ids',
   210      atp_proof = fn () => Z3_Isar.atp_proof_of_z3_proof ctxt ll_defs rewrite_rules prems concl
   211        fact_helper_ts prem_ids conjecture_id fact_helper_ids' steps}
   212   end
   213 
   (* Progress reporting through SMT_Config.statistics_msg: the resulting function takes
      the number of steps reconstructed so far (the implicit last argument) and builds a
      message from that number, the total number of steps and the milliseconds elapsed
      since start. *)
   214 fun intermediate_statistics ctxt start total =
   215   SMT_Config.statistics_msg ctxt (fn current =>
   216     "Reconstructed " ^ string_of_int current ^ " of " ^ string_of_int total ^ " steps in " ^
   217     string_of_int (Time.toMilliseconds (#elapsed (Timing.result start))) ^ " ms")
   218 
   (* Pretty-prints the reconstruction statistics.
        total: overall reconstruction time in milliseconds
        stats: symbol table from rule name to the list of step times in milliseconds
      The output is a list headed "Z3 proof reconstruction statistics:" with one item for
      the total time and one item per rule giving the number of occurrences, the mean, the
      maximum and the summed time. *)
   219 fun pretty_statistics total stats =
   220   let
           (* Arithmetic mean, rounded by integer division; 0 for an empty list. The former
              code returned the middle element(s) of the list in recording order, which is
              neither a mean nor a median, and failed on an empty list. *)
   221     fun mean_of is =
   222       let
   223         val len = length is
   224         val sum = fold Integer.add is 0
   225       in if len = 0 then 0 else sum div len end
   226     fun pretty_item name p = Pretty.item (Pretty.separate ":" [Pretty.str name, p])
   227     fun pretty (name, milliseconds) = pretty_item name (Pretty.block (Pretty.separate "," [
   228       Pretty.str (string_of_int (length milliseconds) ^ " occurrences") ,
   229       Pretty.str (string_of_int (mean_of milliseconds) ^ " ms mean time"),
   230       Pretty.str (string_of_int (fold Integer.max milliseconds 0) ^ " ms maximum time"),
   231       Pretty.str (string_of_int (fold Integer.add milliseconds 0) ^ " ms total time")]))
   232   in
   233     Pretty.big_list "Z3 proof reconstruction statistics:" (
   234       pretty_item "total time" (Pretty.str (string_of_int total ^ " ms")) ::
   235       map pretty (Symtab.dest stats))
   236   end
   237 
   (* Parses the solver output and replays all proof steps in order.
      The assumption steps are prepared by add_asserted; the replay context gets the
      simpset of Z3_Replay_Util.make_simpset and the SAT solver configured by
      SMT_Config.sat_solver. Progress is reported before every 100th step and at the end,
      followed by the per-rule statistics. The result is the theorem of the last step,
      exported to outer_ctxt with its premises discharged.
      NOTE(review): split_last is applied to steps, so the parsed proof is expected to
      contain at least one step. *)
   238 fun replay outer_ctxt
   239     ({context = ctxt, typs, terms, rewrite_rules, assms, ...} : SMT_Translate.replay_data) output =
   240   let
   241     val (steps, ctxt2) = Z3_Proof.parse typs terms output ctxt
   242     val ((_, rules), (ctxt3, assumed)) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
   243     val ctxt4 =
   244       ctxt3
   245       |> put_simpset (Z3_Replay_Util.make_simpset ctxt3 [])
   246       |> Config.put SAT.solver (Config.get ctxt3 SMT_Config.sat_solver)
   247     val len = length steps
   248     val start = Timing.start ()
   249     val print_runtime_statistics = intermediate_statistics ctxt4 start len
           (* wraps f with a progress message at every positive step index divisible by 100 *)
   250     fun blockwise f (i, x) y =
   251       (if i > 0 andalso i mod 100 = 0 then print_runtime_statistics i else (); f x y)
           (* the table of assumed steps is the initial proof table *)
   252     val (proofs, stats) =
   253       fold_index (blockwise (replay_step ctxt4 assumed)) steps (assumed, Symtab.empty)
   254     val _ = print_runtime_statistics len
   255     val total = Time.toMilliseconds (#elapsed (Timing.result start))
   256     val (_, Z3_Proof.Z3_Step {id, ...}) = split_last steps
   257     val _ = SMT_Config.statistics_msg ctxt4 (Pretty.string_of o pretty_statistics total) stats
   258   in
   259     Inttab.lookup proofs id |> the |> snd |> discharge rules outer_ctxt ctxt4
   260   end
   261 
   262 end;