src/HOL/Library/Old_SMT/old_z3_proof_reconstruction.ML
author wenzelm
Wed, 08 Mar 2017 10:50:59 +0100
changeset 65151 a7394aa4d21c
parent 64574 1134e4d5e5b7
permissions -rw-r--r--
tuned proofs;

(*  Title:      HOL/Library/Old_SMT/old_z3_proof_reconstruction.ML
    Author:     Sascha Boehme, TU Muenchen

Proof reconstruction for proofs found by Z3.
*)

signature OLD_Z3_PROOF_RECONSTRUCTION =
sig
  (* insert a theorem into the net of schematic Z3 rules kept in the generic context *)
  val add_z3_rule: thm -> Context.generic -> Context.generic
  (* replay a Z3 proof: takes the outer proof context, the translation record and
     the output (a list of strings) that is handed to the Z3 proof parser; the
     result pairs a list of integers with a theorem -- the list is only non-empty
     when the filter_only_facts option is set, in which case the theorem is TrueI *)
  val reconstruct: Proof.context -> Old_SMT_Translate.recon -> string list -> int list * thm
end

structure Old_Z3_Proof_Reconstruction: OLD_Z3_PROOF_RECONSTRUCTION =
struct


(* abort reconstruction with an SMT failure whose text is prefixed by the
   name of this module's task *)
fun z3_exn msg =
  let
    val failure = Old_SMT_Failure.Other_Failure ("Z3 proof reconstruction: " ^ msg)
  in raise Old_SMT_Failure.SMT failure end



(* net of schematic rules *)

local
  val description = "declaration of Z3 proof rules"

  (* equality test used when inserting into, deleting from and merging nets *)
  val eq = Thm.eq_thm

  (* generic context data: a net of theorems, initially empty, merged with eq *)
  structure Old_Z3_Rules = Generic_Data
  (
    type T = thm Net.net
    val empty = Net.empty
    val extend = I
    val merge = Net.merge eq
  )

  (* rewrite a theorem with rewrite_true and pair the result with its
     proposition; the pair is what insert_term/delete_term receive *)
  fun prep context =
    `Thm.prop_of o rewrite_rule (Context.proof_of context) [Old_Z3_Proof_Literals.rewrite_true]

  (* insertion and removal of a rule; Net.INSERT and Net.DELETE are caught
     and leave the net as it was *)
  fun ins thm context =
    context |> Old_Z3_Rules.map (fn net => Net.insert_term eq (prep context thm) net handle Net.INSERT => net)
  fun rem thm context =
    context |> Old_Z3_Rules.map (fn net => Net.delete_term eq (prep context thm) net handle Net.DELETE => net)

  (* attribute versions of ins and rem *)
  val add = Thm.declaration_attribute ins
  val del = Thm.declaration_attribute rem
in

(* exported under the signature: same function as ins *)
val add_z3_rule = ins

(* look up an instance for ct in the rule net of ctxt; "the" raises Option
   when net_instance yields NONE *)
fun by_schematic_rule ctxt ct =
  the (Old_Z3_Proof_Tools.net_instance (Old_Z3_Rules.get (Context.Proof ctxt)) ct)

(* register the attribute old_z3_rule (add/del) and a dynamic fact of the
   same name that lists the current content of the net *)
val _ = Theory.setup
 (Attrib.setup @{binding old_z3_rule} (Attrib.add_del add del) description #>
  Global_Theory.add_thms_dynamic (@{binding old_z3_rule}, Net.content o Old_Z3_Rules.get))

end



(* proof tools *)

(* emit a trace line announcing the prover called name, then run it on ct *)
fun named ctxt name prover ct =
  (Old_SMT_Config.trace_msg ctxt I ("Z3: trying " ^ name ^ " ..."); prover ct)

(* tactic counterpart of named: emit the trace line, then apply tac to
   subgoal i of state st *)
fun NAMED ctxt name tac i st =
  (Old_SMT_Config.trace_msg ctxt I ("Z3: trying " ^ name ^ " ..."); tac i st)

(* pretty-print a goal: the proposition t, preceded by the list of
   assumptions when thms is not empty *)
fun pretty_goal ctxt thms t =
  let
    val prop = Pretty.block [Pretty.str "proposition: ", Syntax.pretty_term ctxt t]
    val assms = Pretty.big_list "assumptions:" (map (Thm.pretty_thm ctxt) thms)
  in
    if null thms then [prop] else [assms, prop]
  end

(* try a list of provers in order on a goal; the two schematic-rule provers
   (plain, then "full" with thms as premises) are always tried first;
   fails with an error message when no prover succeeds *)
fun try_apply ctxt thms =
  let
    (* text of the error raised once all provers have failed on ct *)
    fun failure_text ct =
      let
        val header =
          "Z3 found a proof," ^
            " but proof reconstruction failed at the following subgoal:"
        val goal = Pretty.big_list header (pretty_goal ctxt thms (Thm.term_of ct))
        val hint =
          Pretty.str ("Declaring a rule as [old_z3_rule] might solve this problem.")
      in Pretty.string_of (Pretty.chunks [goal, hint]) end

    (* first prover that does not raise wins *)
    fun first_success provers ct =
      (case provers of
        [] => error (failure_text ct)
      | prover :: rest =>
          (case try prover ct of
            NONE => first_success rest ct
          | SOME thm => (Old_SMT_Config.trace_msg ctxt I "Z3: succeeded"; thm)))

    (* prove via the net of schematic rules; in full mode the propositions of
       thms are first added as premises and eliminated again afterwards *)
    fun schematic full ct =
      let
        val label = if full then suffix " (full)" "schematic rules" else "schematic rules"
        val goal =
          if full then fold_rev (curry Drule.mk_implies o Thm.cprop_of) thms ct else ct
        val proved = named ctxt label (by_schematic_rule ctxt) goal
      in fold Thm.elim_implies thms proved end

  in fn provers => first_success (schematic false :: schematic true :: provers) end

local
  (* if-then-else expressed by two implications *)
  val if_as_imps =
    @{lemma "(if P then Q1 else Q2) = ((P --> Q1) & (~P --> Q2))" by simp}
in

(* fast_tac run with the claset HOL_cs installed in ctxt *)
fun HOL_fast_tac ctxt =
  let val fast_ctxt = put_claset HOL_cs ctxt
  in Classical.fast_tac fast_ctxt end

(* simplify with HOL_ss extended by the if-then-else rule, then run
   HOL_fast_tac on every resulting subgoal *)
fun simp_fast_tac ctxt =
  let val simp_ctxt = put_simpset HOL_ss ctxt addsimps [if_as_imps]
  in Simplifier.simp_tac simp_ctxt THEN_ALL_NEW HOL_fast_tac ctxt end

end



(* theorems and proofs *)

(** theorem incarnations **)

(* the three shapes in which a reconstructed proof step is stored; thm_of,
   meta_eq_of and literals_of convert between them *)
datatype theorem =
  Thm of thm | (* theorem without special features *)
  MetaEq of thm | (* meta equality "t == s" *)
  Literals of thm * Old_Z3_Proof_Literals.littab
    (* "P1 & ... & Pn" and table of all literals P1, ..., Pn *)

(* plain theorem of a proof step; a meta equality is turned into an object
   equality by composition with meta_eq_to_obj_eq *)
fun thm_of p =
  (case p of
    MetaEq eq => eq COMP @{thm meta_eq_to_obj_eq}
  | Thm thm => thm
  | Literals (thm, _) => thm)

(* meta equality of a proof step; anything that is not already a MetaEq is
   converted with mk_meta_eq *)
fun meta_eq_of p =
  (case p of
    MetaEq eq => eq
  | _ => mk_meta_eq (thm_of p))

(* literal table of a proof step; for steps without a stored table a fresh
   one is built from the single theorem of the step *)
fun literals_of p =
  (case p of
    Literals (_, lits) => lits
  | _ => Old_Z3_Proof_Literals.make_littab [thm_of p])



(** core proof rules **)

(* assumption: enter all asserted propositions into the proof table *)

local
  (* definitions of trigger, weight and fun_app as meta equalities *)
  val remove_trigger = mk_meta_eq @{thm trigger_def}
  val remove_weight = mk_meta_eq @{thm weight_def}
  val remove_fun_app = mk_meta_eq @{thm fun_app_def}

  (* conversion that fully rewrites with eqs; identity conversion for [] *)
  fun rewrite_conv _ [] = Conv.all_conv
    | rewrite_conv ctxt eqs = Simplifier.full_rewrite (empty_simpset ctxt addsimps eqs)

  (* rules always used when normalizing assumptions *)
  val prep_rules = [@{thm Let_def}, remove_trigger, remove_weight,
    remove_fun_app, Old_Z3_Proof_Literals.rewrite_true]

  (* theorem-level rewriting with eqs; identity for [] *)
  fun rewrite _ [] = I
    | rewrite ctxt eqs = Conv.fconv_rule (rewrite_conv ctxt eqs)

  (* all net instances for ct, each tagged with whether its proposition is
     alpha-equivalent to ct (an "exact" match) *)
  fun lookup_assm assms_net ct =
    Old_Z3_Proof_Tools.net_instances assms_net ct
    |> map (fn ithm as (_, thm) => (ithm, Thm.cprop_of thm aconvc ct))
in

(* fold over the asserted (index, proposition) pairs, starting from empty
   index/theorem lists and an empty proof table; the result has the form
   ((indices, theorems), (context, proof table)) *)
fun add_asserted outer_ctxt rewrite_rules assms asserted ctxt =
  let
    (* rewrite rules normalized with rewrite_true, united with prep_rules *)
    val eqs = map (rewrite ctxt [Old_Z3_Proof_Literals.rewrite_true]) rewrite_rules
    val eqs' = union Thm.eq_thm eqs prep_rules

    (* net of the given assumptions after rewriting and eta conversion *)
    val assms_net =
      assms
      |> map (apsnd (rewrite ctxt eqs'))
      |> map (apsnd (Conv.fconv_rule Thm.eta_conversion))
      |> Old_Z3_Proof_Tools.thm_net_of snd

    fun revert_conv ctxt = rewrite_conv ctxt eqs' then_conv Thm.eta_conversion

    (* add the first premise of thm as a local assumption of ctxt and
       eliminate it from thm *)
    fun assume thm ctxt =
      let
        val ct = Thm.cprem_of thm 1
        val (thm', ctxt') = yield_singleton Assumption.add_assumes ct ctxt
      in (Thm.implies_elim thm thm', ctxt') end

    (* process one matching assumption: an exact match discharges the premise
       of thm1 directly, otherwise the premise is assumed and the matching
       theorem th is remembered in the theorem list *)
    fun add1 idx thm1 ((i, th), exact) ((is, thms), (ctxt, ptab)) =
      let
        val (thm, ctxt') =
          if exact then (Thm.implies_elim thm1 th, ctxt)
          else assume thm1 ctxt
        val thms' = if exact then thms else th :: thms
      in
        ((insert (op =) i is, thms'),
          (ctxt', Inttab.update (idx, Thm thm) ptab))
      end

    (* process one asserted proposition: thm1 is "ct ==> ct" with its premise
       rewritten by revert_conv, thm2 its export to the outer context; the
       premise of thm2 is what gets looked up in the assumption net *)
    fun add (idx, ct) (cx as ((is, thms), (ctxt, ptab))) =
      let
        val thm1 =
          Thm.trivial ct
          |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
        val thm2 = singleton (Variable.export ctxt outer_ctxt) thm1
      in
        (case lookup_assm assms_net (Thm.cprem_of thm2 1) of
          [] =>
            (* no known assumption matches: assume the premise locally *)
            let val (thm, ctxt') = assume thm1 ctxt
            in ((is, thms), (ctxt', Inttab.update (idx, Thm thm) ptab)) end
        | ithms => fold (add1 idx thm1) ithms cx)
      end
  in fold add asserted (([], []), (ctxt, Inttab.empty)) end

end


(* modus ponens:  P = Q ==> P ==> Q   or   P --> Q ==> P ==> Q *)
local
  val precomp = Old_Z3_Proof_Tools.precompose2
  val comp = Old_Z3_Proof_Tools.compose

  val meta_iffD1 = @{lemma "P == Q ==> P ==> (Q::bool)" by simp}
  val meta_iffD1_c = precomp Thm.dest_binop meta_iffD1

  val iffD1_c = precomp (Thm.dest_binop o Thm.dest_arg) @{thm iffD1}
  val mp_c = precomp (Thm.dest_binop o Thm.dest_arg) @{thm mp}
in
(* p_q is the equation or implication, p the theorem for its left-hand side;
   for steps that are not a MetaEq the iffD1 rule is tried first and the mp
   rule serves as fallback when composition raises THM *)
fun mp p_q p =
  let
    val rule =
      (case p_q of
        MetaEq eq => comp meta_iffD1_c eq
      | _ =>
          let val pq = thm_of p_q
          in comp iffD1_c pq handle THM _ => comp mp_c pq end)
  in Thm (Thm.implies_elim rule p) end
end


(* and_elim:     P1 & ... & Pn ==> Pi *)
(* not_or_elim:  ~(P1 | ... | Pn) ==> ~Pi *)
local
  (* predicate built with exists_lit: holds when some literal u visited by
     exists_lit is alpha-equivalent to t *)
  fun is_sublit conj t = Old_Z3_Proof_Literals.exists_lit conj (fn u => u aconv t)

  (* t is not yet in the table: take the first stored literal that contains
     t, explode it, replace it in the table by the exploded parts, store the
     new table under idx in ptab, and return the literal now found for t;
     both uses of "the" raise Option when nothing is found *)
  fun derive conj t lits idx ptab =
    let
      val lit = the (Old_Z3_Proof_Literals.get_first_lit (is_sublit conj t) lits)
      val ls = Old_Z3_Proof_Literals.explode conj false false [t] lit
      val lits' = fold Old_Z3_Proof_Literals.insert_lit ls
        (Old_Z3_Proof_Literals.delete_lit lit lits)

      fun upd thm = Literals (thm_of thm, lits')
      val ptab' = Inttab.map_entry idx upd ptab
    in (the (Old_Z3_Proof_Literals.lookup_lit lits' t), ptab') end

  (* look the requested literal up in the table of premise p; derive it
     (and update the proof table entry idx) only when it is missing *)
  fun lit_elim conj (p, idx) ct ptab =
    let val lits = literals_of p
    in
      (case Old_Z3_Proof_Literals.lookup_lit lits (Old_SMT_Utils.term_of ct) of
        SOME lit => (Thm lit, ptab)
      | NONE => apfst Thm (derive conj (Old_SMT_Utils.term_of ct) lits idx ptab))
    end
in
(* conj = true for conjunctions, conj = false for negated disjunctions *)
val and_elim = lit_elim true
val not_or_elim = lit_elim false
end


(* P1, ..., Pn |- False ==> |- ~P1 | ... | ~Pn *)
local
  (* introduce the proposition of lit as a premise of thm and eliminate it
     right away with lit itself *)
  fun step lit thm =
    Thm.implies_elim (Thm.implies_intr (Thm.cprop_of lit) thm) lit
  val explode_disj = Old_Z3_Proof_Literals.explode false false false
  (* apply step for every literal obtained by exploding th w.r.t. hyps *)
  fun intro hyps thm th = fold step (explode_disj hyps th) thm

  (* selects a subterm of the ccontr rule for precompose
     NOTE(review): presumably the proposition P of "~P ==> False" -- confirm *)
  fun dest_ccontr ct = [Thm.dest_arg (Thm.dest_arg (Thm.dest_arg1 ct))]
  val ccontr = Old_Z3_Proof_Tools.precompose dest_ccontr @{thm ccontr}
in
(* thm is the premise with hypotheses, ct the expected result: the negated
   conclusion cu is assumed, the hypotheses of thm (those of the form
   Trueprop _) are discharged against it, and ccontr closes the proof *)
fun lemma thm ct =
  let
    val cu = Old_Z3_Proof_Literals.negate (Thm.dest_arg ct)
    val hyps = map_filter (try HOLogic.dest_Trueprop) (Thm.hyps_of thm)
    val th = Old_Z3_Proof_Tools.under_assumption (intro hyps thm) cu
  in Thm (Old_Z3_Proof_Tools.compose ccontr th) end
end


(* \/{P1, ..., Pn, Q1, ..., Qn}, ~P1, ..., ~Pn ==> \/{Q1, ..., Qn} *)
local
  val explode_disj = Old_Z3_Proof_Literals.explode false true false
  val join_disj = Old_Z3_Proof_Literals.join false
  (* under the assumption th, join a proof of the negation of thm's
     proposition from the unit theorems thms and the literals exploded from th *)
  fun unit thm thms th =
    let
      val t = @{const Not} $ Old_SMT_Utils.prop_of thm
      val ts = map Old_SMT_Utils.prop_of thms
    in
      join_disj (Old_Z3_Proof_Literals.make_littab (thms @ explode_disj ts th)) t
    end

  (* destructors selecting the subterms of the contraposition rule that
     precompose2 instantiates *)
  fun dest_arg2 ct = Thm.dest_arg (Thm.dest_arg ct)
  fun dest ct = apply2 dest_arg2 (Thm.dest_binop ct)
  val contrapos =
    Old_Z3_Proof_Tools.precompose2 dest @{lemma "(~P ==> ~Q) ==> Q ==> P" by fast}
in
(* thm is the disjunction, thms the unit clauses, ct the expected result:
   assume the negated result, derive the negation of thm, turn this around
   with the contraposition rule and discharge thm *)
fun unit_resolution thm thms ct =
  Old_Z3_Proof_Literals.negate (Thm.dest_arg ct)
  |> Old_Z3_Proof_Tools.under_assumption (unit thm thms)
  |> Thm o Old_Z3_Proof_Tools.discharge thm o Old_Z3_Proof_Tools.compose contrapos
end


(* P ==> P == True   or   ~P ==> P == False
   (True is represented as "~ False" in the first rule) *)
local
  val to_true = @{lemma "P ==> P == (~ False)" by simp}
  val to_false = @{lemma "~P ==> P == False" by simp}
in
fun iff_true thm =
  let val eq = thm COMP to_true
  in MetaEq eq end

fun iff_false thm =
  let val eq = thm COMP to_false
  in MetaEq eq end
end


(* distributivity of | over &: proved by the schematic rules of try_apply
   or, failing that, by HOL_fast_tac *)
fun distributivity ctxt =
  let
    val fast_prover =
      named ctxt "fast" (Old_Z3_Proof_Tools.by_tac ctxt (HOL_fast_tac ctxt))
  in Thm o try_apply ctxt [] [fast_prover] end
    (* FIXME: not very well tested *)


(* Tseitin-like axioms *)
local
  (* disjunction introduction rules, one per polarity combination *)
  val disjI1 = @{lemma "(P ==> Q) ==> ~P | Q" by fast}
  val disjI2 = @{lemma "(~P ==> Q) ==> P | Q" by fast}
  val disjI3 = @{lemma "(~Q ==> P) ==> P | Q" by fast}
  val disjI4 = @{lemma "(Q ==> P) ==> P | ~Q" by fast}

  (* explode the assumed theorem thm into literals (together with true_thm)
     and join a proof of ct2 from them; conj1 and conj2 tell whether the
     assumption and the goal are treated as conjunctions *)
  fun prove' conj1 conj2 ct2 thm =
    let
      val littab =
        Old_Z3_Proof_Literals.explode conj1 true (conj1 <> conj2) [] thm
        |> cons Old_Z3_Proof_Literals.true_thm
        |> Old_Z3_Proof_Literals.make_littab
    in Old_Z3_Proof_Literals.join conj2 littab (Thm.term_of ct2) end

  (* prove ct2 under the assumption ct1 and compose the result with rule *)
  fun prove rule (ct1, conj1) (ct2, conj2) =
    Old_Z3_Proof_Tools.under_assumption (prove' conj1 conj2 ct2) ct1 COMP rule

  (* dispatch on the shape of the left disjunct of the axiom; raises CTERM
     for shapes that are not handled here, so that try_apply moves on to the
     next prover *)
  fun prove_def_axiom ct =
    let val (ct1, ct2) = Thm.dest_binop (Thm.dest_arg ct)
    in
      (case Thm.term_of ct1 of
        @{const Not} $ (@{const HOL.conj} $ _ $ _) =>
          prove disjI1 (Thm.dest_arg ct1, true) (ct2, true)
      | @{const HOL.conj} $ _ $ _ =>
          prove disjI3 (Old_Z3_Proof_Literals.negate ct2, false) (ct1, true)
      | @{const Not} $ (@{const HOL.disj} $ _ $ _) =>
          prove disjI3 (Old_Z3_Proof_Literals.negate ct2, false) (ct1, false)
      | @{const HOL.disj} $ _ $ _ =>
          prove disjI2 (Old_Z3_Proof_Literals.negate ct1, false) (ct2, true)
      | Const (@{const_name distinct}, _) $ _ =>
          (* unfold "distinct" in the left disjunct before proving *)
          let
            fun dis_conv cv = Conv.arg_conv (Conv.arg1_conv cv)
            val unfold_dis_conv = dis_conv Old_Z3_Proof_Tools.unfold_distinct_conv
            fun prv cu =
              let val (cu1, cu2) = Thm.dest_binop (Thm.dest_arg cu)
              in prove disjI4 (Thm.dest_arg cu2, true) (cu1, true) end
          in Old_Z3_Proof_Tools.with_conv unfold_dis_conv prv ct end
      | @{const Not} $ (Const (@{const_name distinct}, _) $ _) =>
          (* same, with "distinct" sitting below the negation *)
          let
            fun dis_conv cv = Conv.arg_conv (Conv.arg1_conv (Conv.arg_conv cv))
            val unfold_dis_conv = dis_conv Old_Z3_Proof_Tools.unfold_distinct_conv
            fun prv cu =
              let val (cu1, cu2) = Thm.dest_binop (Thm.dest_arg cu)
              in prove disjI1 (Thm.dest_arg cu1, true) (cu2, true) end
          in Old_Z3_Proof_Tools.with_conv unfold_dis_conv prv ct end
      | _ => raise CTERM ("prove_def_axiom", [ct]))
    end
in
(* prove a definitional axiom: first by the structural prover above, then by
   simp+fast after abstraction.  The tactic is run by by_tac in the context
   ctxt' supplied by by_abstraction, i.e. the same context the tactic itself
   is built from, as in every other by_abstraction callback of this file. *)
fun def_axiom ctxt = Thm o try_apply ctxt [] [
  named ctxt "conj/disj/distinct" prove_def_axiom,
  Old_Z3_Proof_Tools.by_abstraction 0 (true, false) ctxt [] (fn ctxt' =>
    named ctxt' "simp+fast" (Old_Z3_Proof_Tools.by_tac ctxt' (simp_fast_tac ctxt')))]
end


(* local definitions *)
local
  (* shapes of definition introductions: each rule has the definition
     "n == ..." as its premise *)
  val intro_rules = [
    @{lemma "n == P ==> (~n | P) & (n | ~P)" by simp},
    @{lemma "n == (if P then s else t) ==> (~P | n = s) & (P | n = t)"
      by simp},
    @{lemma "n == P ==> n = P" by (rule meta_eq_to_obj_eq)} ]

  (* converse direction: recover a meta equation from an introduced definition *)
  val apply_rules = [
    @{lemma "(~n | P) & (n | ~P) ==> P == n" by (atomize(full)) fast},
    @{lemma "(~P | n = s) & (P | n = t) ==> (if P then s else t) == n"
      by (atomize(full)) fastforce} ]

  val inst_rule = Old_Z3_Proof_Tools.match_instantiate Thm.dest_arg

  (* first introduction rule that can be instantiated for ct; raises CTERM
     when none fits *)
  fun apply_rule ct =
    (case get_first (try (inst_rule ct)) intro_rules of
      SOME thm => thm
    | NONE => raise CTERM ("intro_def", [ct]))
in
(* instantiate an introduction rule for ct and hand it to make_hyp_def; the
   first component of its result is wrapped as Thm *)
fun intro_def ct = Old_Z3_Proof_Tools.make_hyp_def (apply_rule ct) #>> Thm

(* compose thm with the first apply rule that fits, giving a MetaEq; when no
   rule fits, thm is returned unchanged as Thm *)
fun apply_def thm =
  get_first (try (fn rule => MetaEq (thm COMP rule))) apply_rules
  |> the_default (Thm thm)
end


(* negation normal form *)
local
  (* rules for positive quantifiers: first list with a body-independent right-hand
     side, second list with the same quantifier kept on the right *)
  val quant_rules1 = ([
    @{lemma "(!!x. P x == Q) ==> ALL x. P x == Q" by simp},
    @{lemma "(!!x. P x == Q) ==> EX x. P x == Q" by simp}], [
    @{lemma "(!!x. P x == Q x) ==> ALL x. P x == ALL x. Q x" by simp},
    @{lemma "(!!x. P x == Q x) ==> EX x. P x == EX x. Q x" by simp}])

  (* rules for negated quantifiers: the quantifier on the right is dualized *)
  val quant_rules2 = ([
    @{lemma "(!!x. ~P x == Q) ==> ~(ALL x. P x) == Q" by simp},
    @{lemma "(!!x. ~P x == Q) ==> ~(EX x. P x) == Q" by simp}], [
    @{lemma "(!!x. ~P x == Q x) ==> ~(ALL x. P x) == EX x. Q x" by simp},
    @{lemma "(!!x. ~P x == Q x) ==> ~(EX x. P x) == ALL x. Q x" by simp}])

  (* recursive tactic: close the goal with thm, or strip one quantifier with
     a rule from qs1 or qs2 and continue; the explicit arguments i and st
     keep the recursion from unfolding before the tactic is applied *)
  fun nnf_quant_tac ctxt thm (qs as (qs1, qs2)) i st = (
    resolve_tac ctxt [thm] ORELSE'
    (match_tac ctxt qs1 THEN' nnf_quant_tac ctxt thm qs) ORELSE'
    (match_tac ctxt qs2 THEN' nnf_quant_tac ctxt thm qs)) i st

  (* same tactic, with eq first passed through varify for vars *)
  fun nnf_quant_tac_varified ctxt vars eq =
    nnf_quant_tac ctxt (Old_Z3_Proof_Tools.varify vars eq)

  (* prove the meta-equality form of ct from the premise p with the rules qs *)
  fun nnf_quant ctxt vars qs p ct =
    Old_Z3_Proof_Tools.as_meta_eq ct
    |> Old_Z3_Proof_Tools.by_tac ctxt (nnf_quant_tac_varified ctxt vars (meta_eq_of p) qs)

  (* propositional case: structural conj/disj prover, then simp+fast after
     abstraction *)
  fun prove_nnf ctxt = try_apply ctxt [] [
    named ctxt "conj/disj" Old_Z3_Proof_Literals.prove_conj_disj_eq,
    Old_Z3_Proof_Tools.by_abstraction 0 (true, false) ctxt [] (fn ctxt' =>
      named ctxt' "simp+fast" (Old_Z3_Proof_Tools.by_tac ctxt' (simp_fast_tac ctxt')))]
in
(* dispatch on the shape of the equation ct: quantifier on both sides,
   negated quantifier on the left, or anything else; the quantifier cases
   use only the first premise (hd ps raises on an empty list) *)
fun nnf ctxt vars ps ct =
  (case Old_SMT_Utils.term_of ct of
    _ $ (l as Const _ $ Abs _) $ (r as Const _ $ Abs _) =>
      if l aconv r
      then MetaEq (Thm.reflexive (Thm.dest_arg (Thm.dest_arg ct)))
      else MetaEq (nnf_quant ctxt vars quant_rules1 (hd ps) ct)
  | _ $ (@{const Not} $ (Const _ $ Abs _)) $ (Const _ $ Abs _) =>
      MetaEq (nnf_quant ctxt vars quant_rules2 (hd ps) ct)
  | _ =>
      let
        (* rewrite with the reversed premise equations before proving *)
        val nnf_rewr_conv = Conv.arg_conv (Conv.arg_conv
          (Old_Z3_Proof_Tools.unfold_eqs ctxt
            (map (Thm.symmetric o meta_eq_of) ps)))
      in Thm (Old_Z3_Proof_Tools.with_conv nnf_rewr_conv (prove_nnf ctxt) ct) end)
end



(** equality proof rules **)

(* |- t = t, as a reflexive meta equality on the argument of the argument of ct *)
fun refl ct =
  ct
  |> Thm.dest_arg
  |> Thm.dest_arg
  |> Thm.reflexive
  |> MetaEq


(* s = t ==> t = s; the result is always a meta equality *)
local
  val flip_rule = @{lemma "s = t ==> t == s" by simp}
in
fun symm p =
  (case p of
    MetaEq eq => MetaEq (Thm.symmetric eq)
  | _ => MetaEq (thm_of p COMP flip_rule))
end


(* s = t ==> t = u ==> s = u; one rule per combination of meta and object
   equality among the two premises, the result is always a meta equality *)
local
  val meta_obj = @{lemma "s == t ==> t =  u ==> s == u" by simp}
  val obj_meta = @{lemma "s =  t ==> t == u ==> s == u" by simp}
  val obj_obj = @{lemma "s =  t ==> t =  u ==> s == u" by simp}
in
fun trans p q =
  (case (p, q) of
    (MetaEq eq1, MetaEq eq2) => MetaEq (Thm.transitive eq1 eq2)
  | (MetaEq eq, _) => MetaEq (thm_of q COMP (eq COMP meta_obj))
  | (_, MetaEq eq) => MetaEq (eq COMP (thm_of p COMP obj_meta))
  | _ => MetaEq (thm_of q COMP (thm_of p COMP obj_obj)))
end


(* t1 = s1 ==> ... ==> tn = sn ==> f t1 ... tn = f s1 .. sn
   (reflexive antecedents are dropped) *)
local
  (* raised by prove_eq (in its raising mode) when no given equation matches *)
  exception MONO

  (* arguments are pairs (left cterm, right cterm) of the equation to prove *)
  fun prove_refl (ct, _) = Thm.reflexive ct
  (* split both sides into function and argument, prove the two parts with
     f and g, and combine the results *)
  fun prove_comb f g cp =
    let val ((ct1, ct2), (cu1, cu2)) = apply2 Thm.dest_comb cp
    in Thm.combination (f (ct1, cu1)) (g (ct2, cu2)) end
  fun prove_arg f = prove_comb prove_refl f

  (* descend along the function position, proving every argument with f;
     falls back to reflexivity when a CTERM exception is raised *)
  fun prove f cp = prove_comb (prove f) f cp handle CTERM _ => prove_refl cp

  (* n-ary connectives: try f on the whole pair first; on MONO descend into
     both arguments while the left term still satisfies is_comb *)
  fun prove_nary is_comb f =
    let
      fun prove (cp as (ct, _)) = f cp handle MONO =>
        if is_comb (Thm.term_of ct)
        then prove_comb (prove_arg prove) prove cp
        else prove_refl cp
    in prove end

  (* lists of length n: prove the head with f and recurse on the tail *)
  fun prove_list f n cp =
    if n = 0 then prove_refl cp
    else prove_comb (prove_arg f) (prove_list f (n-1)) cp

  (* supply the length of the HOL list on the left-hand side to f *)
  fun with_length f (cp as (cl, _)) =
    f (length (HOLogic.dest_list (Thm.term_of cl))) cp

  fun prove_distinct f = prove_arg (with_length (prove_list f))

  (* look up the equation "left == right" among the given ones; when it is
     missing either raise MONO (exn = true) or use reflexivity (exn = false) *)
  fun prove_eq exn lookup cp =
    (case lookup (Logic.mk_equals (apply2 Thm.term_of cp)) of
      SOME eq => eq
    | NONE => if exn then raise MONO else prove_refl cp)

  val prove_exn = prove_eq true
  and prove_safe = prove_eq false

  (* select the traversal according to the head symbol of the left-hand side *)
  fun mono f (cp as (cl, _)) =
    (case Term.head_of (Thm.term_of cl) of
      @{const HOL.conj} => prove_nary Old_Z3_Proof_Literals.is_conj (prove_exn f)
    | @{const HOL.disj} => prove_nary Old_Z3_Proof_Literals.is_disj (prove_exn f)
    | Const (@{const_name distinct}, _) => prove_distinct (prove_safe f)
    | _ => prove (prove_safe f)) cp
in
(* eqs are the premise equations, ct the equation to prove; every premise is
   made available in both orientations; a direct match of the whole goal is
   tried before the structural traversal *)
fun monotonicity eqs ct =
  let
    fun and_symmetric (t, thm) = [(t, thm), (t, Thm.symmetric thm)]
    val teqs = maps (and_symmetric o `Thm.prop_of o meta_eq_of) eqs
    val lookup = AList.lookup (op aconv) teqs
    val cp = Thm.dest_binop (Thm.dest_arg ct)
  in MetaEq (prove_exn lookup cp handle MONO => mono lookup cp) end
end


(* |- f a b = f b a (where f is equality): instantiate the commutativity
   rule for the meta-equality form of ct *)
local
  val eq_commute_rule = @{lemma "a = b == b = a" by (atomize(full)) (rule eq_commute)}
in
fun commutativity ct =
  let val goal = Old_Z3_Proof_Tools.as_meta_eq ct
  in MetaEq (Old_Z3_Proof_Tools.match_instantiate I goal eq_commute_rule) end
end



(** quantifier proof rules **)

(* P ?x = Q ?x ==> (ALL x. P x) = (ALL x. Q x)
   P ?x = Q ?x ==> (EX x. P x) = (EX x. Q x)    *)
local
  val congruence_rules = [
    @{lemma "(!!x. P x == Q x) ==> (ALL x. P x) == (ALL x. Q x)" by simp},
    @{lemma "(!!x. P x == Q x) ==> (EX x. P x) == (EX x. Q x)" by simp}]
in
(* the premise p (varified for vars) is put in front of the two congruence
   rules, and the goal is solved by matching against these repeatedly *)
fun quant_intro ctxt vars p ct =
  let
    val premise_rule = Old_Z3_Proof_Tools.varify vars (meta_eq_of p)
    val goal = Old_Z3_Proof_Tools.as_meta_eq ct
    val strip_tac = REPEAT_ALL_NEW (match_tac ctxt (premise_rule :: congruence_rules))
  in
    goal
    |> Old_Z3_Proof_Tools.by_tac ctxt strip_tac
    |> MetaEq
  end
end


(* |- ((ALL x. P x) | Q) = (ALL x. P x | Q): proved by the schematic rules
   of try_apply or, failing that, by HOL_fast_tac *)
fun pull_quant ctxt =
  let
    val fast_prover =
      named ctxt "fast" (Old_Z3_Proof_Tools.by_tac ctxt (HOL_fast_tac ctxt))
  in Thm o try_apply ctxt [] [fast_prover] end
    (* FIXME: not very well tested *)


(* |- (ALL x. P x & Q x) = ((ALL x. P x) & (ALL x. Q x)): proved by the
   schematic rules of try_apply or, failing that, by HOL_fast_tac *)
fun push_quant ctxt =
  let
    val fast_prover =
      named ctxt "fast" (Old_Z3_Proof_Tools.by_tac ctxt (HOL_fast_tac ctxt))
  in Thm o try_apply ctxt [] [fast_prover] end
    (* FIXME: not very well tested *)


(* |- (ALL x1 ... xn y1 ... yn. P x1 ... xn) = (ALL x1 ... xn. P x1 ... xn) *)
local
  (* drop a quantifier whose bound variable does not occur in the body *)
  val elim_all = @{lemma "P = Q ==> (ALL x. P) = Q" by fast}
  val elim_ex = @{lemma "P = Q ==> (EX x. P) = Q" by fast}

  (* recursive tactic: finish with refl, or drop an unused quantifier on the
     left, or descend below a quantifier present on both sides (iff_allI,
     iff_exI), and continue; the explicit arguments i and st keep the
     recursion from unfolding before the tactic is applied *)
  fun elim_unused_tac ctxt i st = (
    match_tac ctxt [@{thm refl}]
    ORELSE' (match_tac ctxt [elim_all, elim_ex] THEN' elim_unused_tac ctxt)
    ORELSE' (
      match_tac ctxt [@{thm iff_allI}, @{thm iff_exI}]
      THEN' elim_unused_tac ctxt)) i st
in

(* prove the given cterm with the tactic above and wrap the result as Thm *)
fun elim_unused_vars ctxt = Thm o Old_Z3_Proof_Tools.by_tac ctxt (elim_unused_tac ctxt)

end


(* |- (ALL x1 ... xn. ~(x1 = t1 & ... xn = tn) | P x1 ... xn) = P t1 ... tn:
   proved by the schematic rules of try_apply or, failing that, by HOL_fast_tac *)
fun dest_eq_res ctxt =
  let
    val fast_prover =
      named ctxt "fast" (Old_Z3_Proof_Tools.by_tac ctxt (HOL_fast_tac ctxt))
  in Thm o try_apply ctxt [] [fast_prover] end
    (* FIXME: not very well tested *)


(* |- ~(ALL x1...xn. P x1...xn) | P a1...an *)
local
  val strip_all_rule = @{lemma "~ P x | Q ==> ~(ALL x. P x) | Q" by fast}
in
(* strip the universal quantifiers below the negation one by one, then close
   the remaining goal with excluded_middle *)
fun quant_inst ctxt =
  let
    val strip_alls = REPEAT_ALL_NEW (match_tac ctxt [strip_all_rule])
    val finish = resolve_tac ctxt @{thms excluded_middle}
  in Thm o Old_Z3_Proof_Tools.by_tac ctxt (strip_alls THEN' finish) end
end


(* |- (EX x. P x) = P c     |- ~(ALL x. P x) = ~ P c *)
local
  (* pattern of the constant Pure.all, to be instantiated per bound variable *)
  val forall =
    Old_SMT_Utils.mk_const_pat @{theory} @{const_name Pure.all}
      (Old_SMT_Utils.destT1 o Old_SMT_Utils.destT1)
  (* meta-quantify ct over cv *)
  fun mk_forall cv ct =
    Thm.apply (Old_SMT_Utils.instT' cv forall) (Thm.lambda cv ct)

  (* collect atoms of t with f, keep those satisfying pred, and certify the
     terms built from them by mk *)
  fun get_vars f mk pred ctxt t =
    Term.fold_aterms f t []
    |> map_filter (fn v =>
         if pred v then SOME (Thm.cterm_of ctxt (mk v)) else NONE)

  (* meta-quantify ct over those of its free variables whose names occur in
     vars, apply f to the closed term, and afterwards instantiate the
     schematic variables of the resulting theorem with these free variables
     again (paired positionally by ~~) *)
  fun close vars f ct ctxt =
    let
      val frees_of = get_vars Term.add_frees Free (member (op =) vars o fst)
      val vs = frees_of ctxt (Thm.term_of ct)
      val (thm, ctxt') = f (fold_rev mk_forall vs ct) ctxt
      val vars_of = get_vars Term.add_vars Var (K true) ctxt'
    in
      (Thm.instantiate ([], map (dest_Var o Thm.term_of) (vars_of (Thm.prop_of thm)) ~~ vs) thm,
        ctxt')
    end

  (* Skolem constants are characterized by Hilbert choice *)
  val sk_rules = @{lemma
    "c = (SOME x. P x) ==> (EX x. P x) = P c"
    "c = (SOME x. ~P x) ==> (~(ALL x. P x)) = (~P c)"
    by (metis someI_ex)+}
in

(* the skolemization equation is not proved here: its closed form is added
   as an assumption to the context and returned as Thm *)
fun skolemize vars =
  apfst Thm oo close vars (yield_singleton Assumption.add_assumes)

(* tactic for discharging such assumptions later: split goal i with trans,
   apply a Skolem rule, solve subgoal i+1 by refl or recursively, and
   finish subgoal i by refl *)
fun discharge_sk_tac ctxt i st = (
  resolve_tac ctxt @{thms trans} i
  THEN resolve_tac ctxt sk_rules i
  THEN (resolve_tac ctxt @{thms refl} ORELSE' discharge_sk_tac ctxt) (i+1)
  THEN resolve_tac ctxt @{thms refl} i) st

end



(** theory proof rules **)

(* theory lemmas: linear arithmetic, arrays *)
(* thms are the premises of the step; they are passed both to try_apply (for
   the schematic rules) and to by_abstraction; inside the abstraction the
   goal is attacked by arith_tac alone and, if that fails, by simplification
   with simpset followed by arith_tac on all resulting subgoals *)
fun th_lemma ctxt simpset thms = Thm o try_apply ctxt thms [
  Old_Z3_Proof_Tools.by_abstraction 0 (false, true) ctxt thms (fn ctxt' =>
    Old_Z3_Proof_Tools.by_tac ctxt' (
      NAMED ctxt' "arith" (Arith_Data.arith_tac ctxt')
      ORELSE' NAMED ctxt' "simp+arith" (
        Simplifier.asm_full_simp_tac (put_simpset simpset ctxt')
        THEN_ALL_NEW Arith_Data.arith_tac ctxt')))]


(* rewriting: prove equalities:
     * ACI of conjunction/disjunction
     * contradiction, excluded middle
     * logical rewriting rules (for negation, implication, equivalence,
         distinct)
     * normal forms for polynomials (integer/real arithmetic)
     * quantifier elimination over linear arithmetic
     * ... ? **)
local
  (* strip outermost universal quantifiers with "spec" as long as possible,
     then convert to a meta equality *)
  fun spec_meta_eq_of thm =
    (case try (fn th => th RS @{thm spec}) thm of
      SOME thm' => spec_meta_eq_of thm'
    | NONE => mk_meta_eq thm)

  (* meta equality of a premise, used for unfolding *)
  fun prep (Thm thm) = spec_meta_eq_of thm
    | prep (MetaEq thm) = thm
    | prep (Literals (thm, _)) = spec_meta_eq_of thm

  (* unfold the premise equations on both sides of the goal equation *)
  fun unfold_conv ctxt ths =
    Conv.arg_conv (Conv.binop_conv (Old_Z3_Proof_Tools.unfold_eqs ctxt
      (map prep ths)))

  (* run prv after unfolding the premises ths; nothing to do without premises *)
  fun with_conv _ [] prv = prv
    | with_conv ctxt ths prv =
        Old_Z3_Proof_Tools.with_conv (unfold_conv ctxt ths) prv

  (* from here on unfold_conv denotes a different conversion (unfolding of
     "distinct" on both sides); with_conv above still refers to the earlier
     definition *)
  val unfold_conv =
    Conv.arg_conv (Conv.binop_conv
      (Conv.try_conv Old_Z3_Proof_Tools.unfold_distinct_conv))
  val prove_conj_disj_eq =
    Old_Z3_Proof_Tools.with_conv unfold_conv Old_Z3_Proof_Literals.prove_conj_disj_eq

  (* add the hypotheses of thm as assumptions to ctxt and return thm together
     with the extended context *)
  fun declare_hyps ctxt thm =
    (thm, snd (Assumption.add_assumes (Thm.chyps_of thm) ctxt))
in

val abstraction_depth = 3
  (*
    This value was chosen large enough to potentially catch exceptions,
    yet small enough to not cause too much harm.  The value might be
    increased in the future, if reconstructing 'rewrite' fails on problems
    that get too much abstracted to be reconstructable.
  *)

(* prove the equation ct, using the premises ths for unfolding; the provers
   are tried in the listed order (after the schematic rules of try_apply);
   the result is the theorem wrapped as Thm, paired with the context
   extended by the hypotheses of that theorem.

   The "pull-ite" prover is passed to "named" as one parenthesized argument,
   like the "injectivity" prover below, so that its trace message is emitted
   when the prover is actually tried and not while this list is built. *)
fun rewrite simpset ths ct ctxt =
  apfst Thm (declare_hyps ctxt (with_conv ctxt ths (try_apply ctxt [] [
    named ctxt "conj/disj/distinct" prove_conj_disj_eq,
    named ctxt "pull-ite" (Old_Z3_Proof_Methods.prove_ite ctxt),
    Old_Z3_Proof_Tools.by_abstraction 0 (true, false) ctxt [] (fn ctxt' =>
      Old_Z3_Proof_Tools.by_tac ctxt' (
        NAMED ctxt' "simp (logic)" (Simplifier.simp_tac (put_simpset simpset ctxt'))
        THEN_ALL_NEW NAMED ctxt' "fast (logic)" (fast_tac ctxt'))),
    Old_Z3_Proof_Tools.by_abstraction 0 (false, true) ctxt [] (fn ctxt' =>
      Old_Z3_Proof_Tools.by_tac ctxt' (
        (resolve_tac ctxt' @{thms iff_allI} ORELSE' K all_tac)
        THEN' NAMED ctxt' "simp (theory)" (Simplifier.simp_tac (put_simpset simpset ctxt'))
        THEN_ALL_NEW (
          NAMED ctxt' "fast (theory)" (HOL_fast_tac ctxt')
          ORELSE' NAMED ctxt' "arith (theory)" (Arith_Data.arith_tac ctxt')))),
    Old_Z3_Proof_Tools.by_abstraction 0 (true, true) ctxt [] (fn ctxt' =>
      Old_Z3_Proof_Tools.by_tac ctxt' (
        (resolve_tac ctxt' @{thms iff_allI} ORELSE' K all_tac)
        THEN' NAMED ctxt' "simp (full)" (Simplifier.simp_tac (put_simpset simpset ctxt'))
        THEN_ALL_NEW (
          NAMED ctxt' "fast (full)" (HOL_fast_tac ctxt')
          ORELSE' NAMED ctxt' "arith (full)" (Arith_Data.arith_tac ctxt')))),
    named ctxt "injectivity" (Old_Z3_Proof_Methods.prove_injectivity ctxt),
    Old_Z3_Proof_Tools.by_abstraction abstraction_depth (true, true) ctxt []
      (fn ctxt' =>
        Old_Z3_Proof_Tools.by_tac ctxt' (
          (resolve_tac ctxt' @{thms iff_allI} ORELSE' K all_tac)
          THEN' NAMED ctxt' "simp (deepen)" (Simplifier.simp_tac (put_simpset simpset ctxt'))
          THEN_ALL_NEW (
            NAMED ctxt' "fast (deepen)" (HOL_fast_tac ctxt')
            ORELSE' NAMED ctxt' "arith (deepen)" (Arith_Data.arith_tac
              ctxt'))))]) ct))

end



(* proof reconstruction *)

(** tracing and checking **)

(* trace the proof step number idx together with the name of its rule; the
   rule is the remaining argument of the resulting function *)
fun trace_before ctxt idx =
  let
    fun message r =
      "Z3: #" ^ string_of_int idx ^ ": " ^ Old_Z3_Proof_Parser.string_of_rule r
  in Old_SMT_Config.trace_msg ctxt message end

(* when tracing is enabled, check that the theorem reconstructed for step
   idx (rule r, premises ps) states exactly the expected proposition ct;
   a mismatch aborts reconstruction with a report of both propositions *)
fun check_after idx r ps ct (p, (ctxt, _)) =
  if Config.get ctxt Old_SMT_Config.trace then
    let
      val thm = thm_of p
      val _ = Thm.consolidate thm

      fun mismatch () =
        let
          val header =
            "proof step failed: " ^ quote (Old_Z3_Proof_Parser.string_of_rule r) ^
              " (#" ^ string_of_int idx ^ ")"
          val actual = pretty_goal ctxt (map (thm_of o fst) ps) (Thm.prop_of thm)
          val expected =
            Pretty.block [Pretty.str "expected: ",
              Syntax.pretty_term ctxt (Thm.term_of ct)]
        in z3_exn (Pretty.string_of (Pretty.big_list header (actual @ [expected]))) end
    in
      if (Thm.cprop_of thm) aconvc ct then () else mismatch ()
    end
  else ()


(** overall reconstruction procedure **)

local
  (* rules of the proof format that this module does not replay *)
  fun not_supported r = raise Fail ("Z3: proof rule not implemented: " ^
    quote (Old_Z3_Proof_Parser.string_of_rule r))

  (* replay a single proof step: r is the rule, ps the premises (each paired
     with its proof index), ct the expected conclusion, cxp the pair of
     current context and proof table; the result is the reconstructed
     theorem with the possibly updated pair; a rule applied to a premise
     list of unexpected shape ends up in the last case *)
  fun prove_step simpset vars r ps ct (cxp as (cx, ptab)) =
    (case (r, ps) of
      (* core rules *)
      (Old_Z3_Proof_Parser.True_Axiom, _) => (Thm Old_Z3_Proof_Literals.true_thm, cxp)
    | (Old_Z3_Proof_Parser.Asserted, _) => raise Fail "bad assertion"
    | (Old_Z3_Proof_Parser.Goal, _) => raise Fail "bad assertion"
    | (Old_Z3_Proof_Parser.Modus_Ponens, [(p, _), (q, _)]) =>
        (mp q (thm_of p), cxp)
    | (Old_Z3_Proof_Parser.Modus_Ponens_Oeq, [(p, _), (q, _)]) =>
        (mp q (thm_of p), cxp)
    | (Old_Z3_Proof_Parser.And_Elim, [(p, i)]) =>
        and_elim (p, i) ct ptab ||> pair cx
    | (Old_Z3_Proof_Parser.Not_Or_Elim, [(p, i)]) =>
        not_or_elim (p, i) ct ptab ||> pair cx
    | (Old_Z3_Proof_Parser.Hypothesis, _) => (Thm (Thm.assume ct), cxp)
    | (Old_Z3_Proof_Parser.Lemma, [(p, _)]) => (lemma (thm_of p) ct, cxp)
    | (Old_Z3_Proof_Parser.Unit_Resolution, (p, _) :: ps) =>
        (unit_resolution (thm_of p) (map (thm_of o fst) ps) ct, cxp)
    | (Old_Z3_Proof_Parser.Iff_True, [(p, _)]) => (iff_true (thm_of p), cxp)
    | (Old_Z3_Proof_Parser.Iff_False, [(p, _)]) => (iff_false (thm_of p), cxp)
    | (Old_Z3_Proof_Parser.Distributivity, _) => (distributivity cx ct, cxp)
    | (Old_Z3_Proof_Parser.Def_Axiom, _) => (def_axiom cx ct, cxp)
    | (Old_Z3_Proof_Parser.Intro_Def, _) => intro_def ct cx ||> rpair ptab
    | (Old_Z3_Proof_Parser.Apply_Def, [(p, _)]) => (apply_def (thm_of p), cxp)
    | (Old_Z3_Proof_Parser.Iff_Oeq, [(p, _)]) => (p, cxp)
    | (Old_Z3_Proof_Parser.Nnf_Pos, _) => (nnf cx vars (map fst ps) ct, cxp)
    | (Old_Z3_Proof_Parser.Nnf_Neg, _) => (nnf cx vars (map fst ps) ct, cxp)

      (* equality rules *)
    | (Old_Z3_Proof_Parser.Reflexivity, _) => (refl ct, cxp)
    | (Old_Z3_Proof_Parser.Symmetry, [(p, _)]) => (symm p, cxp)
    | (Old_Z3_Proof_Parser.Transitivity, [(p, _), (q, _)]) => (trans p q, cxp)
    | (Old_Z3_Proof_Parser.Monotonicity, _) => (monotonicity (map fst ps) ct, cxp)
    | (Old_Z3_Proof_Parser.Commutativity, _) => (commutativity ct, cxp)

      (* quantifier rules *)
    | (Old_Z3_Proof_Parser.Quant_Intro, [(p, _)]) => (quant_intro cx vars p ct, cxp)
    | (Old_Z3_Proof_Parser.Pull_Quant, _) => (pull_quant cx ct, cxp)
    | (Old_Z3_Proof_Parser.Push_Quant, _) => (push_quant cx ct, cxp)
    | (Old_Z3_Proof_Parser.Elim_Unused_Vars, _) => (elim_unused_vars cx ct, cxp)
    | (Old_Z3_Proof_Parser.Dest_Eq_Res, _) => (dest_eq_res cx ct, cxp)
    | (Old_Z3_Proof_Parser.Quant_Inst, _) => (quant_inst cx ct, cxp)
    | (Old_Z3_Proof_Parser.Skolemize, _) => skolemize vars ct cx ||> rpair ptab

      (* theory rules *)
    | (Old_Z3_Proof_Parser.Th_Lemma _, _) =>  (* FIXME: use arguments *)
        (th_lemma cx simpset (map (thm_of o fst) ps) ct, cxp)
    | (Old_Z3_Proof_Parser.Rewrite, _) => rewrite simpset [] ct cx ||> rpair ptab
    | (Old_Z3_Proof_Parser.Rewrite_Star, ps) =>
        rewrite simpset (map fst ps) ct cx ||> rpair ptab

      (* rules without a reconstruction method *)
    | (Old_Z3_Proof_Parser.Nnf_Star, _) => not_supported r
    | (Old_Z3_Proof_Parser.Cnf_Star, _) => not_supported r
    | (Old_Z3_Proof_Parser.Transitivity_Star, _) => not_supported r
    | (Old_Z3_Proof_Parser.Pull_Quant_Star, _) => not_supported r

    | _ => raise Fail ("Z3: proof rule " ^
        quote (Old_Z3_Proof_Parser.string_of_rule r) ^
        " has an unexpected number of arguments."))

  (* fetch an already reconstructed step and pair it with its index; an
     unknown index aborts reconstruction *)
  fun lookup_proof ptab idx =
    (case Inttab.lookup ptab idx of
      SOME p => (p, idx)
    | NONE => z3_exn ("unknown proof id: " ^ quote (string_of_int idx)))

  (* fold function over the proof steps: look up the premises, trace the
     step, replay it, check the result, and store it under idx; the theorem
     of the previous step (first component of the accumulator) is ignored *)
  fun prove simpset vars (idx, step) (_, cxp as (ctxt, ptab)) =
    let
      val Old_Z3_Proof_Parser.Proof_Step {rule=r, prems, prop, ...} = step
      val ps = map (lookup_proof ptab) prems
      val _ = trace_before ctxt idx r
      val (thm, (ctxt', ptab')) =
        cxp
        |> prove_step simpset vars r ps prop
        |> tap (check_after idx r ps prop)
    in (thm, (ctxt', Inttab.update (idx, thm) ptab')) end

  (* the given rules, followed by a fixed set of rules for closing premises *)
  fun make_discharge_rules rules = rules @ [@{thm allI}, @{thm refl},
    @{thm reflexive}, Old_Z3_Proof_Literals.true_thm]

  (* repeatedly work on the first subgoal: resolve with one of the rules, or
     solve it completely with discharge_sk_tac *)
  fun discharge_assms_tac ctxt rules =
    REPEAT (HEADGOAL (resolve_tac ctxt rules ORELSE' SOLVED' (discharge_sk_tac ctxt)))

  (* remove the premises of thm with the tactic above (only the first result
     of the tactic is used) and normalize the result; raises THM when the
     tactic yields no result *)
  fun discharge_assms ctxt rules thm =
    if Thm.nprems_of thm = 0 then Goal.norm_result ctxt thm
    else
      (case Seq.pull (discharge_assms_tac ctxt rules thm) of
        SOME (thm', _) => Goal.norm_result ctxt thm'
      | NONE => raise THM ("failed to discharge premise", 1, [thm]))

  (* export the final theorem from the inner to the outer context and
     discharge the premises introduced by the export *)
  fun discharge rules outer_ctxt (p, (inner_ctxt, _)) =
    thm_of p
    |> singleton (Proof_Context.export inner_ctxt outer_ctxt)
    |> discharge_assms outer_ctxt (make_discharge_rules rules)
in

(* entry point: parse the output, enter the asserted propositions into the
   proof table, and then either return the collected indices together with
   TrueI (option filter_only_facts) or replay all steps in order and return
   the discharged theorem of the last step together with an empty list *)
fun reconstruct outer_ctxt recon output =
  let
    val {context=ctxt, typs, terms, rewrite_rules, assms} = recon
    val (asserted, steps, vars, ctxt1) =
      Old_Z3_Proof_Parser.parse ctxt typs terms output

    (* simpset built from the theorems declared as old_z3_simp *)
    val simpset =
      Old_Z3_Proof_Tools.make_simpset ctxt1 (Named_Theorems.get ctxt1 @{named_theorems old_z3_simp})

    val ((is, rules), cxp as (ctxt2, _)) =
      add_asserted outer_ctxt rewrite_rules assms asserted ctxt1
  in
    if Config.get ctxt2 Old_SMT_Config.filter_only_facts then (is, @{thm TrueI})
    else
      (* TrueI is only the initial accumulator value; prove ignores it *)
      (Thm @{thm TrueI}, cxp)
      |> fold (prove simpset vars) steps
      |> discharge rules outer_ctxt
      |> pair []
  end

end

end