src/HOL/Tools/SMT/z3_replay.ML
author wenzelm
Sun Nov 26 21:08:32 2017 +0100 (18 months ago)
changeset 67091 1393c2340eec
parent 62519 a564458f94db
child 69204 d5ab1636660b
permissions -rw-r--r--
more symbols;
blanchet@58061
     1
(*  Title:      HOL/Tools/SMT/z3_replay.ML
blanchet@56078
     2
    Author:     Sascha Boehme, TU Muenchen
blanchet@56078
     3
    Author:     Jasmin Blanchette, TU Muenchen
blanchet@56078
     4
blanchet@57219
     5
Z3 proof parsing and replay.
blanchet@56078
     6
*)
blanchet@56078
     7
blanchet@58061
     8
signature Z3_REPLAY =
sig
  (* Parse Z3 output (given as a list of strings) and relate the asserted proof steps to the
     conjecture, the premises and the named facts; the result carries the ids of the facts used
     and a lazily constructed ATP-style proof. *)
  val parse_proof:
    Proof.context ->
    SMT_Translate.replay_data ->
    ((string * ATP_Problem_Generate.stature) * thm) list ->
    term list ->
    term ->
    string list ->
    SMT_Solver.parsed_proof

  (* Replay the Z3 proof contained in the given output inside Isabelle and return the
     reconstructed theorem. *)
  val replay:
    Proof.context -> SMT_Translate.replay_data -> string list -> thm
end;
blanchet@56078
    15
blanchet@58061
    16
structure Z3_Replay: Z3_REPLAY =
blanchet@56078
    17
struct
blanchet@56078
    18
wenzelm@56245
    19
(* Names and types of the leading Pure.all-bound parameters of a term. *)
val params_of = Term.strip_qnt_vars @{const_name Pure.all}
blanchet@56078
    20
blanchet@56078
    21
(* Instantiate the leading meta-quantified parameters of thm with schematic variables whose
   indices start above Thm.maxidx_of thm, so they do not clash with existing schematics. *)
fun varify ctxt thm =
  let
    val base = Thm.maxidx_of thm + 1
    fun schematic (i, (n, T)) = Thm.cterm_of ctxt (Var ((n, base + i), T))
    val cvars = map_index schematic (params_of (Thm.prop_of thm))
  in Drule.forall_elim_list cvars thm end
blanchet@56078
    27
blanchet@56078
    28
(* Pair each name in names with the type of the corresponding leading parameter of t and
   record it in the association list, replacing any existing entry for that name.
   fold2 requires names and the parameter list to have equal length. *)
fun add_paramTs names t =
  let fun bind n (_, T) = AList.update (op =) (n, T)
  in fold2 bind names (params_of t) end
blanchet@56078
    30
blanchet@56078
    31
(* Fix one fresh variable for every (name, type) pair in nTs.  Returns the extended context
   together with a table from each original name to the certified fresh Free. *)
fun new_fixes ctxt nTs =
  let
    val (fresh, ctxt') = Variable.variant_fixes (map (K "") nTs) ctxt
    fun entry (n, T) n' = (n, Thm.cterm_of ctxt' (Free (n', T)))
    val entries = map2 entry nTs fresh
  in (ctxt', Symtab.make entries) end
blanchet@56078
    36
wenzelm@56245
    37
(* Term-level counterpart of Thm.forall_elim: instantiate the outermost Pure.all of a term
   with the term of ct by beta reduction.  Raises TERM (tagged with this function's name)
   when the term is not a Pure.all applied to an abstraction. *)
fun forall_elim_term ct (Const (@{const_name Pure.all}, _) $ (a as Abs _)) =
      Term.betapply (a, Thm.term_of ct)
  | forall_elim_term _ qt = raise TERM ("forall_elim_term", [qt])
blanchet@56078
    40
blanchet@56078
    41
(* Thread x through elim, once per name, using the certified fix stored in env for that name.
   Raises Option if a name has no entry in env. *)
fun apply_fixes elim env names x =
  fold (fn name => elim (the (Symtab.lookup env name))) names x
blanchet@56078
    42
blanchet@56078
    43
(* Specialise a premise theorem, paired with its parameter names, to the fixes in env. *)
fun apply_fixes_prem env (names, thm) = apply_fixes Thm.forall_elim env names thm

(* Specialise a conclusion term to the fixes in env by instantiating its leading parameters. *)
fun apply_fixes_concl env names t = apply_fixes forall_elim_term env names t
blanchet@56078
    45
blanchet@56078
    46
(* Generalise a theorem back over the fixes recorded in env for names.  The fixes are looked up
   as soon as env and names are supplied, before the theorem argument is evaluated. *)
fun export_fixes env names =
  let val frees = map (fn n => the (Symtab.lookup env n)) names
  in Drule.forall_intr_list frees end
blanchet@56078
    47
blanchet@56078
    48
(* Run the replay method f on concl with its parameters fixed as fresh variables.
   prems have their leading parameters turned into schematic variables (varify).
   nthms, pairs of parameter names and theorems, are specialised to the shared fixes, as is
   concl.  The theorem produced by f is generalised back over names. *)
fun under_fixes f ctxt (prems, nthms) names concl =
  let
    val schematic_prems = map (varify ctxt) prems
    val paramTs =
      fold (fn (ns, thm) => add_paramTs ns (Thm.prop_of thm)) nthms
        (add_paramTs names concl [])
    val (ctxt', env) = new_fixes ctxt paramTs
    val fixed_prems = map (apply_fixes_prem env) nthms
    val goal = apply_fixes_concl env names concl
  in export_fixes env names (f ctxt' (schematic_prems @ fixed_prems) goal) end
blanchet@56078
    58
blanchet@58061
    59
(* Produce the theorem for a single proof step.
   - Non-assumption steps are proved by the rule's replay method under the step's fixes.
     For fix steps the premise theorems are passed as plain premises; otherwise they are
     specialised to the shared fixes.
   - Assumption steps use the theorem registered in assumed for the step id, or fall back to
     assuming concl. *)
fun replay_thm ctxt assumed nthms (Z3_Proof.Z3_Step {id, rule, concl, fixes, is_fix_step, ...}) =
  if not (Z3_Proof.is_assumption rule) then
    let val prem_args = if is_fix_step then (map snd nthms, []) else ([], nthms)
    in under_fixes (Z3_Replay_Methods.method_for rule) ctxt prem_args fixes concl end
  else
    (case Inttab.lookup assumed id of
      NONE => Thm.assume (Thm.cterm_of ctxt concl)
    | SOME (_, thm) => thm)
blanchet@56078
    67
boehmes@59215
    68
(* Replay one step under the per-step reconstruction timeout and record the outcome.
   The state is (proofs, stats): proofs maps step ids to (fixes, theorem), and stats collects
   elapsed milliseconds per rule name.  A timeout is reported as SMT_Failure.Time_Out. *)
fun replay_step ctxt assumed (step as Z3_Proof.Z3_Step {id, rule, prems, fixes, ...}) (proofs, stats) =
  let
    val prem_thms = map (fn p => the (Inttab.lookup proofs p)) prems
    val timed_replay = Timing.timing (replay_thm ctxt assumed prem_thms)
    val ({elapsed, ...}, thm) =
      SMT_Config.with_time_limit ctxt SMT_Config.reconstruction_step_timeout timed_replay step
        handle Timeout.TIMEOUT _ => raise SMT_Failure.SMT SMT_Failure.Time_Out
    val rule_name = Z3_Proof.string_of_rule rule
  in
    (Inttab.update (id, (fixes, thm)) proofs,
     Symtab.cons_list (rule_name, Time.toMilliseconds elapsed) stats)
  end
blanchet@56078
    78
blanchet@56078
    79
local
  (* Meta-equations unfolding trigger_def and fun_app_def. *)
  val remove_trigger = mk_meta_eq @{thm trigger_def}
  val remove_fun_app = mk_meta_eq @{thm fun_app_def}

  (* Full rewriting with eqs; the identity conversion when eqs is empty. *)
  fun rewrite_conv ctxt eqs =
    if null eqs then Conv.all_conv
    else Simplifier.full_rewrite (empty_simpset ctxt addsimps eqs)

  val rewrite_true_rule = @{lemma "True \<equiv> \<not> False" by simp}
  val prep_rules = [@{thm Let_def}, remove_trigger, remove_fun_app, rewrite_true_rule]

  (* Rewrite a theorem with eqs; the identity when eqs is empty. *)
  fun rewrite ctxt eqs =
    if null eqs then I else Conv.fconv_rule (rewrite_conv ctxt eqs)

  (* Candidate assumptions from the net for ct, each flagged with whether its proposition is
     alpha-equivalent to ct. *)
  fun lookup_assm assms_net ct =
    map (fn ithm as (_, thm) => (ithm, Thm.cprop_of thm aconvc ct))
      (Z3_Replay_Util.net_instances assms_net ct)
in

(* Register the assumption steps of a Z3 proof, except hypotheses.  Each step's conclusion is
   matched against assms, which are first rewritten and eta-converted, via a net.
   Result ((iidths, thms), (ctxt', ptab)):
   - iidths pairs assumption indices with (step id, assumption theorem);
   - thms collects assumption theorems that matched without being alpha-equivalent;
   - ctxt' extends ctxt with the assumptions made along the way;
   - ptab maps step ids to (fixes, theorem). *)
fun add_asserted outer_ctxt rewrite_rules assms steps ctxt =
  let
    val eqs = map (rewrite ctxt [rewrite_true_rule]) rewrite_rules
    val eqs' = union Thm.eq_thm eqs prep_rules

    val rewrite_assm = rewrite ctxt eqs'
    val eta_normalize = Conv.fconv_rule Thm.eta_conversion
    val assms_net =
      assms
      |> map (apsnd rewrite_assm)
      |> map (apsnd eta_normalize)
      |> Z3_Replay_Util.thm_net_of snd

    fun revert_conv revert_ctxt = rewrite_conv revert_ctxt eqs' then_conv Thm.eta_conversion

    (* Discharge the first premise of thm by a new assumption added to the context. *)
    fun assume thm asm_ctxt =
      let val (asm, asm_ctxt') = yield_singleton Assumption.add_assumes (Thm.cprem_of thm 1) asm_ctxt
      in (asm RS thm, asm_ctxt') end

    (* Record one net match ((i, th), exact) for step id.  Exact matches discharge the premise
       of thm1 directly; other matches assume it and remember th. *)
    fun add_match id fixes thm1 ((i, th), exact) ((iidths, thms), (cur_ctxt, ptab)) =
      let
        val (thm, cur_ctxt', thms') =
          if exact then (Thm.implies_elim thm1 th, cur_ctxt, thms)
          else
            let val (assumed_thm, assumed_ctxt) = assume thm1 cur_ctxt
            in (assumed_thm, assumed_ctxt, th :: thms) end
        val ptab' = Inttab.update (id, (fixes, thm)) ptab
      in (((i, (id, th)) :: iidths, thms'), (cur_ctxt', ptab')) end

    fun add (Z3_Proof.Z3_Step {id, rule, concl, fixes, ...})
        (cx as ((iidths, thms), (cur_ctxt, ptab))) =
      if not (Z3_Proof.is_assumption rule) orelse rule = Z3_Proof.Hypothesis then cx
      else
        let
          val thm1 =
            Thm.trivial (Thm.cterm_of cur_ctxt concl)
            |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
          val thm2 = singleton (Variable.export cur_ctxt outer_ctxt) thm1
        in
          (case lookup_assm assms_net (Thm.cprem_of thm2 1) of
            [] =>
              let val (thm, cur_ctxt') = assume thm1 cur_ctxt
              in ((iidths, thms), (cur_ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
          | ithms => fold (add_match id fixes thm1) ithms cx)
        end
  in fold add steps (([], []), (ctxt, Inttab.empty)) end

end
blanchet@56078
   141
blanchet@56078
   142
(* Skolemization goals:
     |- (EX x. P x) = P c       |- ~ (ALL x. P x) = ~ P c
   where c is the Hilbert choice witness (premise of sk_rules). *)
local
  val sk_rules = @{lemma
    "c = (SOME x. P x) \<Longrightarrow> (\<exists>x. P x) = P c"
    "c = (SOME x. \<not> P x) \<Longrightarrow> (\<not> (\<forall>x. P x)) = (\<not> P c)"
    by (metis someI_ex)+}
in

(* Solve subgoal i via transitivity and sk_rules.  The choice premise (subgoal i+1) is closed
   by refl or by recursion, and the remaining subgoal i by refl. *)
fun discharge_sk_tac ctxt i st =
  let
    val refl_tac = resolve_tac ctxt @{thms refl}
    val skolem_step =
      resolve_tac ctxt @{thms trans} i
      THEN resolve_tac ctxt sk_rules i
      THEN (refl_tac ORELSE' discharge_sk_tac ctxt) (i + 1)
      THEN refl_tac i
  in skolem_step st end

end
blanchet@56078
   157
wenzelm@67091
   158
(* The theorem ~ False, one of the closing discharge rules. *)
val true_thm = @{lemma "\<not>False" by simp}

(* The given rules followed by the generic closing rules allI, refl, reflexive and true_thm. *)
fun make_discharge_rules rules =
  let val closing_rules = [@{thm allI}, @{thm refl}, @{thm reflexive}, true_thm]
  in rules @ closing_rules end
blanchet@56078
   160
blanchet@56078
   161
(* Tautologies (~P | P) & (P | ~P) and (P | ~P) & (~P | P), tried first by discharge_assms_tac.
   NOTE(review): presumably aimed at premises produced by Z3 intro-def steps - confirm. *)
val intro_def_rules =
  @{lemma
      "(\<not> P \<or> P) \<and> (P \<or> \<not> P)"
      "(P \<or> \<not> P) \<and> (\<not> P \<or> P)"
    by fast+}
blanchet@56078
   165
wenzelm@59498
   166
(* Repeatedly work on the first subgoal: resolve it with intro_def_rules or rules, or else
   solve it completely with discharge_sk_tac; stops when neither applies. *)
fun discharge_assms_tac ctxt rules =
  let
    val by_rule = resolve_tac ctxt (intro_def_rules @ rules)
    val by_skolem = SOLVED' (discharge_sk_tac ctxt)
  in REPEAT (HEADGOAL (by_rule ORELSE' by_skolem)) end
blanchet@57230
   170
blanchet@56078
   171
(* Discharge the premises of thm with discharge_assms_tac, taking its first result, and
   normalise the outcome with Goal.norm_result.  A theorem without premises is only normalised.
   Raises THM if the tactic yields no result at all. *)
fun discharge_assms ctxt rules thm =
  let
    val discharged =
      if Thm.nprems_of thm = 0 then thm
      else
        (case Seq.pull (discharge_assms_tac ctxt rules thm) of
          NONE => raise THM ("failed to discharge premise", 1, [thm])
        | SOME (thm', _) => thm')
  in Goal.norm_result ctxt discharged end
blanchet@56078
   179
blanchet@56078
   180
(* Export a theorem from inner_ctxt to outer_ctxt, then discharge its premises using rules
   extended with the generic closing rules. *)
fun discharge rules outer_ctxt inner_ctxt =
  let
    val export_thm = singleton (Proof_Context.export inner_ctxt outer_ctxt)
    val discharge_prems = discharge_assms outer_ctxt (make_discharge_rules rules)
  in discharge_prems o export_thm end
blanchet@56078
   183
blanchet@57157
   184
(* Parse the Z3 output and relate its asserted steps to the problem.  Assumption indices are:
   0 for the conjecture, 1 .. length prems for the premises, the following ones for xfacts,
   and ~1 for helper facts.  The ATP-style proof is only built when atp_proof is called. *)
fun parse_proof outer_ctxt
    ({context = ctxt, typs, terms, ll_defs, rewrite_rules, assms} : SMT_Translate.replay_data)
    xfacts prems concl output =
  let
    val (steps, ctxt2) = Z3_Proof.parse typs terms output ctxt
    val ((iidths, _), _) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2

    (* id of the step asserting assumption index i, or ~1 if no step does *)
    fun id_of_index i =
      (case AList.lookup (op =) iidths i of
        SOME (id, _) => id
      | NONE => ~1)

    val conjecture_i = 0
    val prems_i = 1
    val facts_i = prems_i + length prems
    val num_facts = length xfacts

    val conjecture_id = id_of_index conjecture_i
    val prem_ids = map id_of_index (prems_i upto (facts_i - 1))

    (* asserted entries of xfacts, paired with the id of the step asserting them *)
    fun as_fact (i, (id, _)) =
      let val j = i - facts_i
      in if 0 <= j andalso j < num_facts then SOME (id, nth xfacts j) else NONE end
    val fact_ids' = map_filter as_fact iidths

    (* helper facts (index ~1) as (step id, theorem) *)
    val helper_ids' = map_filter (fn (i, idth) => if i = ~1 then SOME idth else NONE) iidths

    val helper_name = ATP_Util.short_thm_name ctxt
    val fact_helper_ts =
      map (fn (_, th) => (helper_name th, Thm.prop_of th)) helper_ids' @
      map (fn (_, ((s, _), th)) => (s, Thm.prop_of th)) fact_ids'
    val fact_helper_ids' =
      map (apsnd helper_name) helper_ids' @ map (apsnd (fst o fst)) fact_ids'

    fun atp_proof () =
      Z3_Isar.atp_proof_of_z3_proof ctxt ll_defs rewrite_rules prems concl
        fact_helper_ts prem_ids conjecture_id fact_helper_ids' steps
  in
    {outcome = NONE, fact_ids = SOME fact_ids', atp_proof = atp_proof}
  end
blanchet@57157
   213
boehmes@59374
   214
(* Progress reporter: given the number of reconstructed steps, emit a statistics message with
   that count, total and the milliseconds elapsed since start. *)
fun intermediate_statistics ctxt start total =
  let
    fun progress current =
      String.concat ["Reconstructed ", string_of_int current, " of ", string_of_int total,
        " steps in ", string_of_int (Time.toMilliseconds (#elapsed (Timing.result start))),
        " ms"]
  in SMT_Config.statistics_msg ctxt progress end
boehmes@59374
   218
boehmes@59215
   219
(* Pretty-print reconstruction statistics.  The overall total time comes first.  Then, for
   every rule name in stats, it prints the number of occurrences and the mean, maximum and
   total time, all in milliseconds. *)
fun pretty_statistics total stats =
  let
    (* Arithmetic mean of the recorded times, as announced by the "mean time" label.  The
       lists are in recording order (not sorted), so no positional shortcut applies.
       The empty list cannot arise from Symtab.cons_list but is guarded against anyway. *)
    fun mean_of [] = 0
      | mean_of ms = fold Integer.add ms 0 div length ms

    fun pretty_item name p = Pretty.item (Pretty.separate ":" [Pretty.str name, p])
    fun pretty (name, milliseconds) = pretty_item name (Pretty.block (Pretty.separate "," [
      Pretty.str (string_of_int (length milliseconds) ^ " occurrences"),
      Pretty.str (string_of_int (mean_of milliseconds) ^ " ms mean time"),
      Pretty.str (string_of_int (fold Integer.max milliseconds 0) ^ " ms maximum time"),
      Pretty.str (string_of_int (fold Integer.add milliseconds 0) ^ " ms total time")]))
  in
    Pretty.big_list "Z3 proof reconstruction statistics:" (
      pretty_item "total time" (Pretty.str (string_of_int total ^ " ms")) ::
      map pretty (Symtab.dest stats))
  end
boehmes@59215
   237
blanchet@56078
   238
(* Replay a complete Z3 proof:
   - parse the output and register the asserted facts;
   - replay every step in order, reporting progress every 100 steps;
   - emit per-rule statistics;
   - return the theorem of the last step, exported to outer_ctxt with its premises discharged.
   Raises Empty if the proof has no steps. *)
fun replay outer_ctxt
    ({context = ctxt, typs, terms, rewrite_rules, assms, ...} : SMT_Translate.replay_data) output =
  let
    val (steps, ctxt2) = Z3_Proof.parse typs terms output ctxt
    val ((_, rules), (ctxt3, assumed)) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
    val ctxt4 =
      ctxt3
      |> put_simpset (Z3_Replay_Util.make_simpset ctxt3 [])
      |> Config.put SAT.solver (Config.get ctxt3 SMT_Config.sat_solver)

    val num_steps = length steps
    val start = Timing.start ()
    val report_progress = intermediate_statistics ctxt4 start num_steps

    fun replay_indexed (i, step) acc =
      (if i > 0 andalso i mod 100 = 0 then report_progress i else ();
       replay_step ctxt4 assumed step acc)

    val (proofs, stats) = fold_index replay_indexed steps (assumed, Symtab.empty)
    val _ = report_progress num_steps
    val total = Time.toMilliseconds (#elapsed (Timing.result start))
    val Z3_Proof.Z3_Step {id = last_id, ...} = List.last steps
    val _ = SMT_Config.statistics_msg ctxt4 (Pretty.string_of o pretty_statistics total) stats
    val (_, final_thm) = the (Inttab.lookup proofs last_id)
  in discharge rules outer_ctxt ctxt4 final_thm end
blanchet@56078
   261
blanchet@57229
   262
end;