src/HOL/Tools/SMT/z3_proof_reconstruction.ML
author wenzelm
Fri Mar 21 20:33:56 2014 +0100 (2014-03-21)
changeset 56245 84fc7dfa3cd4
parent 54883 dd04a8b654fc
child 57957 e6ee35b8f4b5
permissions -rw-r--r--
more qualified names;
     1 (*  Title:      HOL/Tools/SMT/z3_proof_reconstruction.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Proof reconstruction for proofs found by Z3.
     5 *)
     6 
     7 signature Z3_PROOF_RECONSTRUCTION =
     8 sig
     9   val add_z3_rule: thm -> Context.generic -> Context.generic
    10   val reconstruct: Proof.context -> SMT_Translate.recon -> string list ->
    11     int list * thm
    12   val setup: theory -> theory
    13 end
    14 
    15 structure Z3_Proof_Reconstruction: Z3_PROOF_RECONSTRUCTION =
    16 struct
    17 
    18 
    19 fun z3_exn msg = raise SMT_Failure.SMT (SMT_Failure.Other_Failure
    20   ("Z3 proof reconstruction: " ^ msg))
    21 
    22 
    23 
    24 (* net of schematic rules *)
    25 
(* Name of the attribute and of the dynamic fact holding the Z3 rules. *)
val z3_ruleN = "z3_rule"

local
  val description = "declaration of Z3 proof rules"

  val eq = Thm.eq_thm

  (* Generic context data: a discrimination net of theorems. *)
  structure Z3_Rules = Generic_Data
  (
    type T = thm Net.net
    val empty = Net.empty
    val extend = I
    val merge = Net.merge eq
  )

  (* Rewrite a rule with Z3_Proof_Literals.rewrite_true and pair the
     resulting proposition (the net key) with the rewritten theorem. *)
  fun prep context =
    `Thm.prop_of o rewrite_rule (Context.proof_of context) [Z3_Proof_Literals.rewrite_true]

  (* Insert/delete a rule; Net.INSERT / Net.DELETE are caught, so duplicate
     insertions and deletions of absent rules leave the net unchanged. *)
  fun ins thm context =
    context |> Z3_Rules.map (fn net => Net.insert_term eq (prep context thm) net handle Net.INSERT => net)
  fun rem thm context =
    context |> Z3_Rules.map (fn net => Net.delete_term eq (prep context thm) net handle Net.DELETE => net)

  val add = Thm.declaration_attribute ins
  val del = Thm.declaration_attribute rem
in

val add_z3_rule = ins

(* Prove ct with an instance of a rule from the net, as found by
   Z3_Proof_Tools.net_instance. `the` raises Option when no rule applies;
   try_apply guards calls with `try`. *)
fun by_schematic_rule ctxt ct =
  the (Z3_Proof_Tools.net_instance (Z3_Rules.get (Context.Proof ctxt)) ct)

(* Set up the add/del attribute "z3_rule" and the dynamic fact of the same
   name, which lists the current contents of the net. *)
val z3_rules_setup =
  Attrib.setup (Binding.name z3_ruleN) (Attrib.add_del add del) description #>
  Global_Theory.add_thms_dynamic (Binding.name z3_ruleN, Net.content o Z3_Rules.get)

end
    63 
    64 
    65 
    66 (* proof tools *)
    67 
    68 fun named ctxt name prover ct =
    69   let val _ = SMT_Config.trace_msg ctxt I ("Z3: trying " ^ name ^ " ...")
    70   in prover ct end
    71 
    72 fun NAMED ctxt name tac i st =
    73   let val _ = SMT_Config.trace_msg ctxt I ("Z3: trying " ^ name ^ " ...")
    74   in tac i st end
    75 
    76 fun pretty_goal ctxt thms t =
    77   [Pretty.block [Pretty.str "proposition: ", Syntax.pretty_term ctxt t]]
    78   |> not (null thms) ? cons (Pretty.big_list "assumptions:"
    79        (map (Display.pretty_thm ctxt) thms))
    80 
    81 fun try_apply ctxt thms =
    82   let
    83     fun try_apply_err ct = Pretty.string_of (Pretty.chunks [
    84       Pretty.big_list ("Z3 found a proof," ^
    85         " but proof reconstruction failed at the following subgoal:")
    86         (pretty_goal ctxt thms (Thm.term_of ct)),
    87       Pretty.str ("Adding a rule to the lemma group " ^ quote z3_ruleN ^
    88         " might solve this problem.")])
    89 
    90     fun apply [] ct = error (try_apply_err ct)
    91       | apply (prover :: provers) ct =
    92           (case try prover ct of
    93             SOME thm => (SMT_Config.trace_msg ctxt I "Z3: succeeded"; thm)
    94           | NONE => apply provers ct)
    95 
    96     fun schematic_label full = "schematic rules" |> full ? suffix " (full)"
    97     fun schematic ctxt full ct =
    98       ct
    99       |> full ? fold_rev (curry Drule.mk_implies o Thm.cprop_of) thms
   100       |> named ctxt (schematic_label full) (by_schematic_rule ctxt)
   101       |> fold Thm.elim_implies thms
   102 
   103   in apply o cons (schematic ctxt false) o cons (schematic ctxt true) end
   104 
   105 local
   106   val rewr_if =
   107     @{lemma "(if P then Q1 else Q2) = ((P --> Q1) & (~P --> Q2))" by simp}
   108 in
   109 
   110 fun HOL_fast_tac ctxt =
   111   Classical.fast_tac (put_claset HOL_cs ctxt)
   112 
   113 fun simp_fast_tac ctxt =
   114   Simplifier.simp_tac (put_simpset HOL_ss ctxt addsimps [rewr_if])
   115   THEN_ALL_NEW HOL_fast_tac ctxt
   116 
   117 end
   118 
   119 
   120 
   121 (* theorems and proofs *)
   122 
   123 (** theorem incarnations **)
   124 
(* Incarnations of the result of a proof step. Meta equalities and literal
   tables are kept in their special form so that later steps can use them
   directly (see meta_eq_of and literals_of). *)
datatype theorem =
  Thm of thm | (* theorem without special features *)
  MetaEq of thm | (* meta equality "t == s" *)
  Literals of thm * Z3_Proof_Literals.littab
    (* "P1 & ... & Pn" and table of all literals P1, ..., Pn *)
   130 
   131 fun thm_of (Thm thm) = thm
   132   | thm_of (MetaEq thm) = thm COMP @{thm meta_eq_to_obj_eq}
   133   | thm_of (Literals (thm, _)) = thm
   134 
   135 fun meta_eq_of (MetaEq thm) = thm
   136   | meta_eq_of p = mk_meta_eq (thm_of p)
   137 
   138 fun literals_of (Literals (_, lits)) = lits
   139   | literals_of p = Z3_Proof_Literals.make_littab [thm_of p]
   140 
   141 
   142 
   143 (** core proof rules **)
   144 
   145 (* assumption *)
   146 
local
  (* SMT markers (trigger, weight, fun_app) are unfolded via their
     definitions, turned into meta equalities. *)
  val remove_trigger = mk_meta_eq @{thm SMT.trigger_def}
  val remove_weight = mk_meta_eq @{thm SMT.weight_def}
  val remove_fun_app = mk_meta_eq @{thm SMT.fun_app_def}

  (* Full rewriting with exactly the given equations (empty simpset);
     with no equations this is the identity conversion. *)
  fun rewrite_conv _ [] = Conv.all_conv
    | rewrite_conv ctxt eqs = Simplifier.full_rewrite (empty_simpset ctxt addsimps eqs)

  (* Rewrites always applied to the assumptions before indexing them. *)
  val prep_rules = [@{thm Let_def}, remove_trigger, remove_weight,
    remove_fun_app, Z3_Proof_Literals.rewrite_true]

  (* Theorem-level counterpart of rewrite_conv. *)
  fun rewrite _ [] = I
    | rewrite ctxt eqs = Conv.fconv_rule (rewrite_conv ctxt eqs)

  (* Candidate assumptions for ct from the net, each paired with a flag
     telling whether its proposition is exactly (aconvc) ct. *)
  fun lookup_assm assms_net ct =
    Z3_Proof_Tools.net_instances assms_net ct
    |> map (fn ithm as (_, thm) => (ithm, Thm.cprop_of thm aconvc ct))
in

(* Turn Z3's asserted propositions into theorems.
   outer_ctxt:    context in which the assumptions live;
   rewrite_rules: additional rewrite rules (combined with prep_rules);
   assms:         indexed assumptions (index, thm);
   asserted:      asserted proof steps (step index, cterm);
   ctxt:          context extended with assertions that must be assumed.
   Result: ((indices of matched assumptions, assumption theorems of
   non-exact matches), (extended ctxt, table step index -> Thm)). *)
fun add_asserted outer_ctxt rewrite_rules assms asserted ctxt =
  let
    val eqs = map (rewrite ctxt [Z3_Proof_Literals.rewrite_true]) rewrite_rules
    val eqs' = union Thm.eq_thm eqs prep_rules

    (* Index the assumptions by their rewritten, eta-converted propositions. *)
    val assms_net =
      assms
      |> map (apsnd (rewrite ctxt eqs'))
      |> map (apsnd (Conv.fconv_rule Thm.eta_conversion))
      |> Z3_Proof_Tools.thm_net_of snd

    (* Normalize an assertion the same way, so it can be looked up in assms_net. *)
    fun revert_conv ctxt = rewrite_conv ctxt eqs' then_conv Thm.eta_conversion

    (* Discharge the first premise of thm by assuming it in ctxt. *)
    fun assume thm ctxt =
      let
        val ct = Thm.cprem_of thm 1
        val (thm', ctxt') = yield_singleton Assumption.add_assumes ct ctxt
      in (Thm.implies_elim thm thm', ctxt') end

    (* Record one matching assumption ((i, th), exact) for step idx. An exact
       match discharges the premise of thm1 with th. Otherwise the premise
       is assumed and th is remembered in thms. *)
    fun add1 idx thm1 ((i, th), exact) ((is, thms), (ctxt, ptab)) =
      let
        val (thm, ctxt') =
          if exact then (Thm.implies_elim thm1 th, ctxt)
          else assume thm1 ctxt
        val thms' = if exact then thms else th :: thms
      in
        ((insert (op =) i is, thms'),
          (ctxt', Inttab.update (idx, Thm thm) ptab))
      end

    (* Process asserted step (idx, ct). thm1 is "ct' ==> ct", where ct' is
       ct normalized by revert_conv (under outer_ctxt). The premise (as
       exported to outer_ctxt) is looked up among the assumptions; without a
       match the premise is simply assumed in ctxt. *)
    fun add (idx, ct) (cx as ((is, thms), (ctxt, ptab))) =
      let
        val thm1 =
          Thm.trivial ct
          |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
        val thm2 = singleton (Variable.export ctxt outer_ctxt) thm1
      in
        (case lookup_assm assms_net (Thm.cprem_of thm2 1) of
          [] =>
            let val (thm, ctxt') = assume thm1 ctxt
            in ((is, thms), (ctxt', Inttab.update (idx, Thm thm) ptab)) end
        | ithms => fold (add1 idx thm1) ithms cx)
      end
  in fold add asserted (([], []), (ctxt, Inttab.empty)) end

end
   212 
   213 
   214 (* P = Q ==> P ==> Q   or   P --> Q ==> P ==> Q *)
local
  val precomp = Z3_Proof_Tools.precompose2
  val comp = Z3_Proof_Tools.compose

  (* Modus ponens for a meta equality "P == Q". *)
  val meta_iffD1 = @{lemma "P == Q ==> P ==> (Q::bool)" by simp}
  val meta_iffD1_c = precomp Thm.dest_binop meta_iffD1

  (* Object-level variants for "P = Q" (iffD1) and "P --> Q" (mp). The extra
     Thm.dest_arg presumably looks through Trueprop (TODO confirm). *)
  val iffD1_c = precomp (Thm.dest_binop o Thm.dest_arg) @{thm iffD1}
  val mp_c = precomp (Thm.dest_binop o Thm.dest_arg) @{thm mp}
in
(* Modus ponens: from p_q ("P == Q", "P = Q" or "P --> Q") and p proving P,
   derive Q. For object-level premises iffD1 is tried first; if composition
   raises THM, mp is used instead. *)
fun mp (MetaEq thm) p = Thm (Thm.implies_elim (comp meta_iffD1_c thm) p)
  | mp p_q p =
      let
        val pq = thm_of p_q
        val thm = comp iffD1_c pq handle THM _ => comp mp_c pq
      in Thm (Thm.implies_elim thm p) end
end
   232 
   233 
   234 (* and_elim:     P1 & ... & Pn ==> Pi *)
   235 (* not_or_elim:  ~(P1 | ... | Pn) ==> ~Pi *)
local
  (* Predicate (via Z3_Proof_Literals.exists_lit) selecting literals that
     contain t among their conjuncts (conj = true) or disjuncts. *)
  fun is_sublit conj t = Z3_Proof_Literals.exists_lit conj (fn u => u aconv t)

  (* t is not directly in lits:
     1. take the first literal containing t;
     2. explode it with Z3_Proof_Literals.explode (passing [t]);
     3. replace it in the table by the resulting literals;
     4. cache the new table for step idx in ptab (as Literals).
     Returns the theorem for t together with the updated ptab. *)
  fun derive conj t lits idx ptab =
    let
      val lit = the (Z3_Proof_Literals.get_first_lit (is_sublit conj t) lits)
      val ls = Z3_Proof_Literals.explode conj false false [t] lit
      val lits' = fold Z3_Proof_Literals.insert_lit ls
        (Z3_Proof_Literals.delete_lit lit lits)

      fun upd thm = Literals (thm_of thm, lits')
      val ptab' = Inttab.map_entry idx upd ptab
    in (the (Z3_Proof_Literals.lookup_lit lits' t), ptab') end

  (* Extract literal ct from proof step (p, idx): look it up in p's literal
     table, falling back to derive. *)
  fun lit_elim conj (p, idx) ct ptab =
    let val lits = literals_of p
    in
      (case Z3_Proof_Literals.lookup_lit lits (SMT_Utils.term_of ct) of
        SOME lit => (Thm lit, ptab)
      | NONE => apfst Thm (derive conj (SMT_Utils.term_of ct) lits idx ptab))
    end
in
(* P1 & ... & Pn ==> Pi *)
val and_elim = lit_elim true
(* ~(P1 | ... | Pn) ==> ~Pi *)
val not_or_elim = lit_elim false
end
   261 
   262 
   263 (* P1, ..., Pn |- False ==> |- ~P1 | ... | ~Pn *)
local
  (* Discharge the hypothesis proved by lit in thm: make it a premise with
     implies_intr and eliminate it with lit. *)
  fun step lit thm =
    Thm.implies_elim (Thm.implies_intr (Thm.cprop_of lit) thm) lit
  val explode_disj = Z3_Proof_Literals.explode false false false
  (* Discharge the hypotheses of thm with the literals obtained by exploding
     th (hyps are passed to explode). *)
  fun intro hyps thm th = fold step (explode_disj hyps th) thm

  fun dest_ccontr ct = [Thm.dest_arg (Thm.dest_arg (Thm.dest_arg1 ct))]
  val ccontr = Z3_Proof_Tools.precompose dest_ccontr @{thm ccontr}
in
(* Lemma: thm proves False from hypotheses P1, ..., Pn; ct states
   ~P1 | ... | ~Pn. Under the assumption of the negation of ct, its exploded
   literals discharge the hypotheses; ccontr then yields ct. *)
fun lemma thm ct =
  let
    val cu = Z3_Proof_Literals.negate (Thm.dest_arg ct)
    val hyps = map_filter (try HOLogic.dest_Trueprop) (Thm.hyps_of thm)
    val th = Z3_Proof_Tools.under_assumption (intro hyps thm) cu
  in Thm (Z3_Proof_Tools.compose ccontr th) end
end
   280 
   281 
   282 (* \/{P1, ..., Pn, Q1, ..., Qn}, ~P1, ..., ~Pn ==> \/{Q1, ..., Qn} *)
local
  val explode_disj = Z3_Proof_Literals.explode false true false
  val join_disj = Z3_Proof_Literals.join false
  (* Given th (a theorem for the negated goal):
     - explode th into literals, passing the unit propositions ts;
     - add the unit theorems thms;
     - join the literal table into "~ P", where P is the proposition of thm. *)
  fun unit thm thms th =
    let
      val t = @{const Not} $ SMT_Utils.prop_of thm
      val ts = map SMT_Utils.prop_of thms
    in
      join_disj (Z3_Proof_Literals.make_littab (thms @ explode_disj ts th)) t
    end

  fun dest_arg2 ct = Thm.dest_arg (Thm.dest_arg ct)
  fun dest ct = pairself dest_arg2 (Thm.dest_binop ct)
  val contrapos =
    Z3_Proof_Tools.precompose2 dest @{lemma "(~P ==> ~Q) ==> Q ==> P" by fast}
in
(* Unit resolution: from the disjunction thm and the unit literals thms,
   derive the remaining disjunction ct. Assuming the negation of ct, unit
   derives ~P (P = proposition of thm); contraposition turns this into
   "P ==> ct", and thm discharges P. *)
fun unit_resolution thm thms ct =
  Z3_Proof_Literals.negate (Thm.dest_arg ct)
  |> Z3_Proof_Tools.under_assumption (unit thm thms)
  |> Thm o Z3_Proof_Tools.discharge thm o Z3_Proof_Tools.compose contrapos
end
   304 
   305 
   306 (* P ==> P == True   or   P ==> P == False *)
   307 local
   308   val iff1 = @{lemma "P ==> P == (~ False)" by simp}
   309   val iff2 = @{lemma "~P ==> P == False" by simp}
   310 in
   311 fun iff_true thm = MetaEq (thm COMP iff1)
   312 fun iff_false thm = MetaEq (thm COMP iff2)
   313 end
   314 
   315 
   316 (* distributivity of | over & *)
   317 fun distributivity ctxt = Thm o try_apply ctxt [] [
   318   named ctxt "fast" (Z3_Proof_Tools.by_tac ctxt (HOL_fast_tac ctxt))]
   319     (* FIXME: not very well tested *)
   320 
   321 
   322 (* Tseitin-like axioms *)
   323 local
   324   val disjI1 = @{lemma "(P ==> Q) ==> ~P | Q" by fast}
   325   val disjI2 = @{lemma "(~P ==> Q) ==> P | Q" by fast}
   326   val disjI3 = @{lemma "(~Q ==> P) ==> P | Q" by fast}
   327   val disjI4 = @{lemma "(Q ==> P) ==> P | ~Q" by fast}
   328 
   329   fun prove' conj1 conj2 ct2 thm =
   330     let
   331       val littab =
   332         Z3_Proof_Literals.explode conj1 true (conj1 <> conj2) [] thm
   333         |> cons Z3_Proof_Literals.true_thm
   334         |> Z3_Proof_Literals.make_littab
   335     in Z3_Proof_Literals.join conj2 littab (Thm.term_of ct2) end
   336 
   337   fun prove rule (ct1, conj1) (ct2, conj2) =
   338     Z3_Proof_Tools.under_assumption (prove' conj1 conj2 ct2) ct1 COMP rule
   339 
   340   fun prove_def_axiom ct =
   341     let val (ct1, ct2) = Thm.dest_binop (Thm.dest_arg ct)
   342     in
   343       (case Thm.term_of ct1 of
   344         @{const Not} $ (@{const HOL.conj} $ _ $ _) =>
   345           prove disjI1 (Thm.dest_arg ct1, true) (ct2, true)
   346       | @{const HOL.conj} $ _ $ _ =>
   347           prove disjI3 (Z3_Proof_Literals.negate ct2, false) (ct1, true)
   348       | @{const Not} $ (@{const HOL.disj} $ _ $ _) =>
   349           prove disjI3 (Z3_Proof_Literals.negate ct2, false) (ct1, false)
   350       | @{const HOL.disj} $ _ $ _ =>
   351           prove disjI2 (Z3_Proof_Literals.negate ct1, false) (ct2, true)
   352       | Const (@{const_name distinct}, _) $ _ =>
   353           let
   354             fun dis_conv cv = Conv.arg_conv (Conv.arg1_conv cv)
   355             val unfold_dis_conv = dis_conv Z3_Proof_Tools.unfold_distinct_conv
   356             fun prv cu =
   357               let val (cu1, cu2) = Thm.dest_binop (Thm.dest_arg cu)
   358               in prove disjI4 (Thm.dest_arg cu2, true) (cu1, true) end
   359           in Z3_Proof_Tools.with_conv unfold_dis_conv prv ct end
   360       | @{const Not} $ (Const (@{const_name distinct}, _) $ _) =>
   361           let
   362             fun dis_conv cv = Conv.arg_conv (Conv.arg1_conv (Conv.arg_conv cv))
   363             val unfold_dis_conv = dis_conv Z3_Proof_Tools.unfold_distinct_conv
   364             fun prv cu =
   365               let val (cu1, cu2) = Thm.dest_binop (Thm.dest_arg cu)
   366               in prove disjI1 (Thm.dest_arg cu1, true) (cu2, true) end
   367           in Z3_Proof_Tools.with_conv unfold_dis_conv prv ct end
   368       | _ => raise CTERM ("prove_def_axiom", [ct]))
   369     end
   370 in
   371 fun def_axiom ctxt = Thm o try_apply ctxt [] [
   372   named ctxt "conj/disj/distinct" prove_def_axiom,
   373   Z3_Proof_Tools.by_abstraction 0 (true, false) ctxt [] (fn ctxt' =>
   374     named ctxt' "simp+fast" (Z3_Proof_Tools.by_tac ctxt (simp_fast_tac ctxt')))]
   375 end
   376 
   377 
   378 (* local definitions *)
local
  (* Introduction rules for a definition "n == ..." introduced by Z3. *)
  val intro_rules = [
    @{lemma "n == P ==> (~n | P) & (n | ~P)" by simp},
    @{lemma "n == (if P then s else t) ==> (~P | n = s) & (P | n = t)"
      by simp},
    @{lemma "n == P ==> n = P" by (rule meta_eq_to_obj_eq)} ]

  (* Rules turning the clauses of a definition back into a meta equality
     "... == n". *)
  val apply_rules = [
    @{lemma "(~n | P) & (n | ~P) ==> P == n" by (atomize(full)) fast},
    @{lemma "(~P | n = s) & (P | n = t) ==> (if P then s else t) == n"
      by (atomize(full)) fastforce} ]

  val inst_rule = Z3_Proof_Tools.match_instantiate Thm.dest_arg

  (* First intro rule that can be instantiated to match ct; CTERM if none. *)
  fun apply_rule ct =
    (case get_first (try (inst_rule ct)) intro_rules of
      SOME thm => thm
    | NONE => raise CTERM ("intro_def", [ct]))
in
(* Introduce a definition for ct using Z3_Proof_Tools.make_hyp_def on the
   instantiated intro rule (presumably turning its "n == ..." premise into a
   hypothesis — TODO confirm). Takes a context; returns (Thm thm, ctxt'). *)
fun intro_def ct = Z3_Proof_Tools.make_hyp_def (apply_rule ct) #>> Thm

(* Apply a definition: the first apply_rule that thm resolves with gives a
   MetaEq; if none applies, thm is returned unchanged as Thm. *)
fun apply_def thm =
  get_first (try (fn rule => MetaEq (thm COMP rule))) apply_rules
  |> the_default (Thm thm)
end
   404 
   405 
   406 (* negation normal form *)
local
  (* Congruence rules under quantifiers. The first list covers a right-hand
     side without a quantifier; the second covers the same quantifier on
     both sides. *)
  val quant_rules1 = ([
    @{lemma "(!!x. P x == Q) ==> ALL x. P x == Q" by simp},
    @{lemma "(!!x. P x == Q) ==> EX x. P x == Q" by simp}], [
    @{lemma "(!!x. P x == Q x) ==> ALL x. P x == ALL x. Q x" by simp},
    @{lemma "(!!x. P x == Q x) ==> EX x. P x == EX x. Q x" by simp}])

  (* The same for a negated quantifier on the left; ALL and EX swap. *)
  val quant_rules2 = ([
    @{lemma "(!!x. ~P x == Q) ==> ~(ALL x. P x) == Q" by simp},
    @{lemma "(!!x. ~P x == Q) ==> ~(EX x. P x) == Q" by simp}], [
    @{lemma "(!!x. ~P x == Q x) ==> ~(ALL x. P x) == EX x. Q x" by simp},
    @{lemma "(!!x. ~P x == Q x) ==> ~(EX x. P x) == ALL x. Q x" by simp}])

  (* Close the goal with thm directly, or apply a rule from qs1 or qs2 and
     recurse on the new subgoals. *)
  fun nnf_quant_tac thm (qs as (qs1, qs2)) i st = (
    rtac thm ORELSE'
    (match_tac qs1 THEN' nnf_quant_tac thm qs) ORELSE'
    (match_tac qs2 THEN' nnf_quant_tac thm qs)) i st

  (* nnf_quant_tac with the variables vars of eq made schematic (varify). *)
  fun nnf_quant_tac_varified vars eq =
    nnf_quant_tac (Z3_Proof_Tools.varify vars eq)

  (* Prove ct as a meta equality from the premise p with the rule lists qs. *)
  fun nnf_quant ctxt vars qs p ct =
    Z3_Proof_Tools.as_meta_eq ct
    |> Z3_Proof_Tools.by_tac ctxt (nnf_quant_tac_varified vars (meta_eq_of p) qs)

  (* Generic fallback: the conj/disj prover, then simp+fast under abstraction. *)
  fun prove_nnf ctxt = try_apply ctxt [] [
    named ctxt "conj/disj" Z3_Proof_Literals.prove_conj_disj_eq,
    Z3_Proof_Tools.by_abstraction 0 (true, false) ctxt [] (fn ctxt' =>
      named ctxt' "simp+fast" (Z3_Proof_Tools.by_tac ctxt' (simp_fast_tac ctxt')))]
in
(* Negation normal form step (Nnf_Pos / Nnf_Neg) with premises ps.
   Quantified equations use the congruence rules with the first premise
   (hd ps). Identical sides are closed by reflexivity. Otherwise the
   right-hand side is rewritten backwards with the premises and proved by
   prove_nnf. *)
fun nnf ctxt vars ps ct =
  (case SMT_Utils.term_of ct of
    _ $ (l as Const _ $ Abs _) $ (r as Const _ $ Abs _) =>
      if l aconv r
      then MetaEq (Thm.reflexive (Thm.dest_arg (Thm.dest_arg ct)))
      else MetaEq (nnf_quant ctxt vars quant_rules1 (hd ps) ct)
  | _ $ (@{const Not} $ (Const _ $ Abs _)) $ (Const _ $ Abs _) =>
      MetaEq (nnf_quant ctxt vars quant_rules2 (hd ps) ct)
  | _ =>
      let
        (* rewrite only the right-hand side, using the premises reversed *)
        val nnf_rewr_conv = Conv.arg_conv (Conv.arg_conv
          (Z3_Proof_Tools.unfold_eqs ctxt
            (map (Thm.symmetric o meta_eq_of) ps)))
      in Thm (Z3_Proof_Tools.with_conv nnf_rewr_conv (prove_nnf ctxt) ct) end)
end
   452 
   453 
   454 
   455 (** equality proof rules **)
   456 
   457 (* |- t = t *)
   458 fun refl ct = MetaEq (Thm.reflexive (Thm.dest_arg (Thm.dest_arg ct)))
   459 
   460 
   461 (* s = t ==> t = s *)
   462 local
   463   val symm_rule = @{lemma "s = t ==> t == s" by simp}
   464 in
   465 fun symm (MetaEq thm) = MetaEq (Thm.symmetric thm)
   466   | symm p = MetaEq (thm_of p COMP symm_rule)
   467 end
   468 
   469 
   470 (* s = t ==> t = u ==> s = u *)
   471 local
   472   val trans1 = @{lemma "s == t ==> t =  u ==> s == u" by simp}
   473   val trans2 = @{lemma "s =  t ==> t == u ==> s == u" by simp}
   474   val trans3 = @{lemma "s =  t ==> t =  u ==> s == u" by simp}
   475 in
   476 fun trans (MetaEq thm1) (MetaEq thm2) = MetaEq (Thm.transitive thm1 thm2)
   477   | trans (MetaEq thm) q = MetaEq (thm_of q COMP (thm COMP trans1))
   478   | trans p (MetaEq thm) = MetaEq (thm COMP (thm_of p COMP trans2))
   479   | trans p q = MetaEq (thm_of q COMP (thm_of p COMP trans3))
   480 end
   481 
   482 
(* t1 = s1 ==> ... ==> tn = sn ==> f t1 ... tn = f s1 .. sn
   (reflexive antecedents are dropped) *)
local
  (* Raised by prove_exn when no given equation matches. *)
  exception MONO

  (* Prove the pair (t, _) by reflexivity on t. *)
  fun prove_refl (ct, _) = Thm.reflexive ct
  (* For two applications: prove the function parts with f and the
     arguments with g, then combine. *)
  fun prove_comb f g cp =
    let val ((ct1, ct2), (cu1, cu2)) = pairself Thm.dest_comb cp
    in Thm.combination (f (ct1, cu1)) (g (ct2, cu2)) end
  (* prove_comb with reflexivity on the function part. *)
  fun prove_arg f = prove_comb prove_refl f

  (* Descend through nested applications, proving every argument with f;
     a CTERM (e.g. from dest_comb on a non-application) yields reflexivity. *)
  fun prove f cp = prove_comb (prove f) f cp handle CTERM _ => prove_refl cp

  (* For chains of a binary connective: try f on the whole pair. On MONO,
     split the term if is_comb recognizes it, otherwise use reflexivity. *)
  fun prove_nary is_comb f =
    let
      fun prove (cp as (ct, _)) = f cp handle MONO =>
        if is_comb (Thm.term_of ct)
        then prove_comb (prove_arg prove) prove cp
        else prove_refl cp
    in prove end

  (* Element-wise proof for a list term with n elements. *)
  fun prove_list f n cp =
    if n = 0 then prove_refl cp
    else prove_comb (prove_arg f) (prove_list f (n-1)) cp

  (* Pass the length of the HOL list on the left to f. *)
  fun with_length f (cp as (cl, _)) =
    f (length (HOL  Logic.dest_list (Thm.term_of cl))) cp

  (* distinct xs = distinct ys: prove the list argument element-wise. *)
  fun prove_distinct f = prove_arg (with_length (prove_list f))

  (* Look up "t == u" among the equations. If it is missing, raise MONO
     (exn = true) or fall back to reflexivity (exn = false). *)
  fun prove_eq exn lookup cp =
    (case lookup (Logic.mk_equals (pairself Thm.term_of cp)) of
      SOME eq => eq
    | NONE => if exn then raise MONO else prove_refl cp)

  val prove_exn = prove_eq true
  and prove_safe = prove_eq false

  (* Dispatch on the head constant of the left-hand side. *)
  fun mono f (cp as (cl, _)) =
    (case Term.head_of (Thm.term_of cl) of
      @{const HOL.conj} => prove_nary Z3_Proof_Literals.is_conj (prove_exn f)
    | @{const HOL.disj} => prove_nary Z3_Proof_Literals.is_disj (prove_exn f)
    | Const (@{const_name distinct}, _) => prove_distinct (prove_safe f)
    | _ => prove (prove_safe f)) cp
in
(* Monotonicity: prove the equation stated by ct from the equations eqs,
   each usable in both directions. The whole equation is looked up first;
   on MONO, congruence reasoning via mono follows. *)
fun monotonicity eqs ct =
  let
    fun and_symmetric (t, thm) = [(t, thm), (t, Thm.symmetric thm)]
    val teqs = maps (and_symmetric o `Thm.prop_of o meta_eq_of) eqs
    val lookup = AList.lookup (op aconv) teqs
    val cp = Thm.dest_binop (Thm.dest_arg ct)
  in MetaEq (prove_exn lookup cp handle MONO => mono lookup cp) end
end
   536 
   537 
   538 (* |- f a b = f b a (where f is equality) *)
   539 local
   540   val rule = @{lemma "a = b == b = a" by (atomize(full)) (rule eq_commute)}
   541 in
   542 fun commutativity ct =
   543   MetaEq (Z3_Proof_Tools.match_instantiate I
   544     (Z3_Proof_Tools.as_meta_eq ct) rule)
   545 end
   546 
   547 
   548 
   549 (** quantifier proof rules **)
   550 
   551 (* P ?x = Q ?x ==> (ALL x. P x) = (ALL x. Q x)
   552    P ?x = Q ?x ==> (EX x. P x) = (EX x. Q x)    *)
   553 local
   554   val rules = [
   555     @{lemma "(!!x. P x == Q x) ==> (ALL x. P x) == (ALL x. Q x)" by simp},
   556     @{lemma "(!!x. P x == Q x) ==> (EX x. P x) == (EX x. Q x)" by simp}]
   557 in
   558 fun quant_intro ctxt vars p ct =
   559   let
   560     val thm = meta_eq_of p
   561     val rules' = Z3_Proof_Tools.varify vars thm :: rules
   562     val cu = Z3_Proof_Tools.as_meta_eq ct
   563     val tac = REPEAT_ALL_NEW (match_tac rules')
   564   in MetaEq (Z3_Proof_Tools.by_tac ctxt tac cu) end
   565 end
   566 
   567 
   568 (* |- ((ALL x. P x) | Q) = (ALL x. P x | Q) *)
   569 fun pull_quant ctxt = Thm o try_apply ctxt [] [
   570   named ctxt "fast" (Z3_Proof_Tools.by_tac ctxt (HOL_fast_tac ctxt))]
   571     (* FIXME: not very well tested *)
   572 
   573 
   574 (* |- (ALL x. P x & Q x) = ((ALL x. P x) & (ALL x. Q x)) *)
   575 fun push_quant ctxt = Thm o try_apply ctxt [] [
   576   named ctxt "fast" (Z3_Proof_Tools.by_tac ctxt (HOL_fast_tac ctxt))]
   577     (* FIXME: not very well tested *)
   578 
   579 
   580 (* |- (ALL x1 ... xn y1 ... yn. P x1 ... xn) = (ALL x1 ... xn. P x1 ... xn) *)
local
  (* A quantifier whose body P does not mention the bound variable can be
     dropped. *)
  val elim_all = @{lemma "P = Q ==> (ALL x. P) = Q" by fast}
  val elim_ex = @{lemma "P = Q ==> (EX x. P) = Q" by fast}

  (* Close the goal by reflexivity, or drop an unused quantifier on the
     left, or descend under matching quantifiers on both sides (iff_allI /
     iff_exI); recurse after each rule. *)
  fun elim_unused_tac i st = (
    match_tac [@{thm refl}]
    ORELSE' (match_tac [elim_all, elim_ex] THEN' elim_unused_tac)
    ORELSE' (
      match_tac [@{thm iff_allI}, @{thm iff_exI}]
      THEN' elim_unused_tac)) i st
in

(* Prove the equation ct removing unused quantified variables. *)
fun elim_unused_vars ctxt = Thm o Z3_Proof_Tools.by_tac ctxt elim_unused_tac

end
   596 
   597 
   598 (* |- (ALL x1 ... xn. ~(x1 = t1 & ... xn = tn) | P x1 ... xn) = P t1 ... tn *)
   599 fun dest_eq_res ctxt = Thm o try_apply ctxt [] [
   600   named ctxt "fast" (Z3_Proof_Tools.by_tac ctxt (HOL_fast_tac ctxt))]
   601     (* FIXME: not very well tested *)
   602 
   603 
   604 (* |- ~(ALL x1...xn. P x1...xn) | P a1...an *)
   605 local
   606   val rule = @{lemma "~ P x | Q ==> ~(ALL x. P x) | Q" by fast}
   607 in
   608 fun quant_inst ctxt = Thm o Z3_Proof_Tools.by_tac ctxt (
   609   REPEAT_ALL_NEW (match_tac [rule])
   610   THEN' rtac @{thm excluded_middle})
   611 end
   612 
   613 
   614 (* |- (EX x. P x) = P c     |- ~(ALL x. P x) = ~ P c *)
local
  (* Pattern for Pure.all; instT' below instantiates its type from the bound
     variable's type. *)
  val forall =
    SMT_Utils.mk_const_pat @{theory} @{const_name Pure.all}
      (SMT_Utils.destT1 o SMT_Utils.destT1)
  (* The meta-level quantification "!!cv. ct". *)
  fun mk_forall cv ct =
    Thm.apply (SMT_Utils.instT' cv forall) (Thm.lambda cv ct)

  (* Atoms of t collected by f, filtered by pred, rebuilt with mk and
     certified in ctxt. *)
  fun get_vars f mk pred ctxt t =
    Term.fold_aterms f t []
    |> map_filter (fn v =>
         if pred v then SOME (SMT_Utils.certify ctxt (mk v)) else NONE)

  (* Steps:
     1. quantify ct over its free variables whose names are in vars;
     2. run f on the closed proposition;
     3. instantiate the schematic variables of the result's proposition,
        in order, with those free variables. *)
  fun close vars f ct ctxt =
    let
      val frees_of = get_vars Term.add_frees Free (member (op =) vars o fst)
      val vs = frees_of ctxt (Thm.term_of ct)
      val (thm, ctxt') = f (fold_rev mk_forall vs ct) ctxt
      val vars_of = get_vars Term.add_vars Var (K true) ctxt'
    in (Thm.instantiate ([], vars_of (Thm.prop_of thm) ~~ vs) thm, ctxt') end

  (* Skolemization via Hilbert choice. *)
  val sk_rules = @{lemma
    "c = (SOME x. P x) ==> (EX x. P x) = P c"
    "c = (SOME x. ~P x) ==> (~(ALL x. P x)) = (~P c)"
    by (metis someI_ex)+}
in

(* Skolemization step: ct, closed over vars, is assumed in the context.
   Returns (Thm thm, ctxt'). *)
fun skolemize vars =
  apfst Thm oo close vars (yield_singleton Assumption.add_assumes)

(* Tactic for skolemized equations: transitivity, a Skolem rule, then
   reflexivity or recursion on the following subgoal, and reflexivity. *)
fun discharge_sk_tac i st = (
  rtac @{thm trans} i
  THEN resolve_tac sk_rules i
  THEN (rtac @{thm refl} ORELSE' discharge_sk_tac) (i+1)
  THEN rtac @{thm refl} i) st

end
   651 
   652 
   653 
   654 (** theory proof rules **)
   655 
   656 (* theory lemmas: linear arithmetic, arrays *)
   657 fun th_lemma ctxt simpset thms = Thm o try_apply ctxt thms [
   658   Z3_Proof_Tools.by_abstraction 0 (false, true) ctxt thms (fn ctxt' =>
   659     Z3_Proof_Tools.by_tac ctxt' (
   660       NAMED ctxt' "arith" (Arith_Data.arith_tac ctxt')
   661       ORELSE' NAMED ctxt' "simp+arith" (
   662         Simplifier.asm_full_simp_tac (put_simpset simpset ctxt')
   663         THEN_ALL_NEW Arith_Data.arith_tac ctxt')))]
   664 
   665 
(* rewriting: prove equalities:
     * ACI of conjunction/disjunction
     * contradiction, excluded middle
     * logical rewriting rules (for negation, implication, equivalence,
         distinct)
     * normal forms for polynomials (integer/real arithmetic)
     * quantifier elimination over linear arithmetic
     * ... ? *)

(* Named theorem collection "z3_simp" of simplification rules for Z3 proof
   reconstruction. *)
structure Z3_Simps = Named_Thms
(
  val name = @{binding z3_simp}
  val description = "simplification rules for Z3 proof reconstruction"
)
   679 
   680 local
   681   fun spec_meta_eq_of thm =
   682     (case try (fn th => th RS @{thm spec}) thm of
   683       SOME thm' => spec_meta_eq_of thm'
   684     | NONE => mk_meta_eq thm)
   685 
   686   fun prep (Thm thm) = spec_meta_eq_of thm
   687     | prep (MetaEq thm) = thm
   688     | prep (Literals (thm, _)) = spec_meta_eq_of thm
   689 
   690   fun unfold_conv ctxt ths =
   691     Conv.arg_conv (Conv.binop_conv (Z3_Proof_Tools.unfold_eqs ctxt
   692       (map prep ths)))
   693 
   694   fun with_conv _ [] prv = prv
   695     | with_conv ctxt ths prv =
   696         Z3_Proof_Tools.with_conv (unfold_conv ctxt ths) prv
   697 
   698   val unfold_conv =
   699     Conv.arg_conv (Conv.binop_conv
   700       (Conv.try_conv Z3_Proof_Tools.unfold_distinct_conv))
   701   val prove_conj_disj_eq =
   702     Z3_Proof_Tools.with_conv unfold_conv Z3_Proof_Literals.prove_conj_disj_eq
   703 
   704   fun declare_hyps ctxt thm =
   705     (thm, snd (Assumption.add_assumes (#hyps (Thm.crep_thm thm)) ctxt))
   706 in
   707 
   708 val abstraction_depth = 3
   709   (*
   710     This value was chosen large enough to potentially catch exceptions,
   711     yet small enough to not cause too much harm.  The value might be
   712     increased in the future, if reconstructing 'rewrite' fails on problems
   713     that get too much abstracted to be reconstructable.
   714   *)
   715 
   716 fun rewrite simpset ths ct ctxt =
   717   apfst Thm (declare_hyps ctxt (with_conv ctxt ths (try_apply ctxt [] [
   718     named ctxt "conj/disj/distinct" prove_conj_disj_eq,
   719     named ctxt "pull-ite" Z3_Proof_Methods.prove_ite ctxt,
   720     Z3_Proof_Tools.by_abstraction 0 (true, false) ctxt [] (fn ctxt' =>
   721       Z3_Proof_Tools.by_tac ctxt' (
   722         NAMED ctxt' "simp (logic)" (Simplifier.simp_tac (put_simpset simpset ctxt'))
   723         THEN_ALL_NEW NAMED ctxt' "fast (logic)" (fast_tac ctxt'))),
   724     Z3_Proof_Tools.by_abstraction 0 (false, true) ctxt [] (fn ctxt' =>
   725       Z3_Proof_Tools.by_tac ctxt' (
   726         (rtac @{thm iff_allI} ORELSE' K all_tac)
   727         THEN' NAMED ctxt' "simp (theory)" (Simplifier.simp_tac (put_simpset simpset ctxt'))
   728         THEN_ALL_NEW (
   729           NAMED ctxt' "fast (theory)" (HOL_fast_tac ctxt')
   730           ORELSE' NAMED ctxt' "arith (theory)" (Arith_Data.arith_tac ctxt')))),
   731     Z3_Proof_Tools.by_abstraction 0 (true, true) ctxt [] (fn ctxt' =>
   732       Z3_Proof_Tools.by_tac ctxt' (
   733         (rtac @{thm iff_allI} ORELSE' K all_tac)
   734         THEN' NAMED ctxt' "simp (full)" (Simplifier.simp_tac (put_simpset simpset ctxt'))
   735         THEN_ALL_NEW (
   736           NAMED ctxt' "fast (full)" (HOL_fast_tac ctxt')
   737           ORELSE' NAMED ctxt' "arith (full)" (Arith_Data.arith_tac ctxt')))),
   738     named ctxt "injectivity" (Z3_Proof_Methods.prove_injectivity ctxt),
   739     Z3_Proof_Tools.by_abstraction abstraction_depth (true, true) ctxt []
   740       (fn ctxt' =>
   741         Z3_Proof_Tools.by_tac ctxt' (
   742           (rtac @{thm iff_allI} ORELSE' K all_tac)
   743           THEN' NAMED ctxt' "simp (deepen)" (Simplifier.simp_tac (put_simpset simpset ctxt'))
   744           THEN_ALL_NEW (
   745             NAMED ctxt' "fast (deepen)" (HOL_fast_tac ctxt')
   746             ORELSE' NAMED ctxt' "arith (deepen)" (Arith_Data.arith_tac
   747               ctxt'))))]) ct))
   748 
   749 end
   750 
   751 
   752 
   753 (* proof reconstruction *)
   754 
   755 (** tracing and checking **)
   756 
   757 fun trace_before ctxt idx = SMT_Config.trace_msg ctxt (fn r =>
   758   "Z3: #" ^ string_of_int idx ^ ": " ^ Z3_Proof_Parser.string_of_rule r)
   759 
   760 fun check_after idx r ps ct (p, (ctxt, _)) =
   761   if not (Config.get ctxt SMT_Config.trace) then ()
   762   else
   763     let val thm = thm_of p |> tap (Thm.join_proofs o single)
   764     in
   765       if (Thm.cprop_of thm) aconvc ct then ()
   766       else
   767         z3_exn (Pretty.string_of (Pretty.big_list
   768           ("proof step failed: " ^ quote (Z3_Proof_Parser.string_of_rule r) ^
   769             " (#" ^ string_of_int idx ^ ")")
   770           (pretty_goal ctxt (map (thm_of o fst) ps) (Thm.prop_of thm) @
   771             [Pretty.block [Pretty.str "expected: ",
   772               Syntax.pretty_term ctxt (Thm.term_of ct)]])))
   773     end
   774 
   775 
   776 (** overall reconstruction procedure **)
   777 
       (* Overall replay of a Z3 proof.  The helpers in this local block are
          private; only reconstruct is exported. *)
   778 local
         (* Raise Fail for a Z3 rule that is recognised by the parser but has
            no reconstruction implemented (used for the Star rules below). *)
   779   fun not_supported r = raise Fail ("Z3: proof rule not implemented: " ^
   780     quote (Z3_Proof_Parser.string_of_rule r))
   781
         (* Replay a single proof step.
            r: the Z3 rule of the step; ps: its premises as (proof, index)
            pairs, as built by lookup_proof; ct: the expected conclusion of
            the step; cxp = (cx, ptab): the current context and the table of
            already proven steps, keyed by step index.
            Returns the new proof paired with the possibly updated
            (context, table): And_Elim and Not_Or_Elim may update the table,
            Intro_Def, Skolemize, Rewrite and Rewrite_Star may update the
            context, every other rule passes cxp through unchanged.
            Asserted and Goal steps raise Fail; presumably such steps are
            consumed by add_asserted before replay starts -- confirm.
            Unimplemented rules and rules whose premise list has an
            unexpected shape also raise Fail.  Arm order matters: the final
            wildcard only catches what no earlier arm matched. *)
   782   fun prove_step simpset vars r ps ct (cxp as (cx, ptab)) =
   783     (case (r, ps) of
   784       (* core rules *)
   785       (Z3_Proof_Parser.True_Axiom, _) => (Thm Z3_Proof_Literals.true_thm, cxp)
   786     | (Z3_Proof_Parser.Asserted, _) => raise Fail "bad assertion"
   787     | (Z3_Proof_Parser.Goal, _) => raise Fail "bad assertion"
   788     | (Z3_Proof_Parser.Modus_Ponens, [(p, _), (q, _)]) =>
   789         (mp q (thm_of p), cxp)
   790     | (Z3_Proof_Parser.Modus_Ponens_Oeq, [(p, _), (q, _)]) =>
   791         (mp q (thm_of p), cxp)
               (* the premise index i is forwarded; the helper returns an
                  updated proof table, re-paired with the unchanged context *)
   792     | (Z3_Proof_Parser.And_Elim, [(p, i)]) =>
   793         and_elim (p, i) ct ptab ||> pair cx
   794     | (Z3_Proof_Parser.Not_Or_Elim, [(p, i)]) =>
   795         not_or_elim (p, i) ct ptab ||> pair cx
   796     | (Z3_Proof_Parser.Hypothesis, _) => (Thm (Thm.assume ct), cxp)
   797     | (Z3_Proof_Parser.Lemma, [(p, _)]) => (lemma (thm_of p) ct, cxp)
   798     | (Z3_Proof_Parser.Unit_Resolution, (p, _) :: ps) =>
   799         (unit_resolution (thm_of p) (map (thm_of o fst) ps) ct, cxp)
   800     | (Z3_Proof_Parser.Iff_True, [(p, _)]) => (iff_true (thm_of p), cxp)
   801     | (Z3_Proof_Parser.Iff_False, [(p, _)]) => (iff_false (thm_of p), cxp)
   802     | (Z3_Proof_Parser.Distributivity, _) => (distributivity cx ct, cxp)
   803     | (Z3_Proof_Parser.Def_Axiom, _) => (def_axiom cx ct, cxp)
               (* the helper returns an updated context, re-paired with ptab *)
   804     | (Z3_Proof_Parser.Intro_Def, _) => intro_def ct cx ||> rpair ptab
   805     | (Z3_Proof_Parser.Apply_Def, [(p, _)]) => (apply_def (thm_of p), cxp)
   806     | (Z3_Proof_Parser.Iff_Oeq, [(p, _)]) => (p, cxp)
   807     | (Z3_Proof_Parser.Nnf_Pos, _) => (nnf cx vars (map fst ps) ct, cxp)
   808     | (Z3_Proof_Parser.Nnf_Neg, _) => (nnf cx vars (map fst ps) ct, cxp)
   809
   810       (* equality rules *)
   811     | (Z3_Proof_Parser.Reflexivity, _) => (refl ct, cxp)
   812     | (Z3_Proof_Parser.Symmetry, [(p, _)]) => (symm p, cxp)
   813     | (Z3_Proof_Parser.Transitivity, [(p, _), (q, _)]) => (trans p q, cxp)
   814     | (Z3_Proof_Parser.Monotonicity, _) => (monotonicity (map fst ps) ct, cxp)
   815     | (Z3_Proof_Parser.Commutativity, _) => (commutativity ct, cxp)
   816
   817       (* quantifier rules *)
   818     | (Z3_Proof_Parser.Quant_Intro, [(p, _)]) => (quant_intro cx vars p ct, cxp)
   819     | (Z3_Proof_Parser.Pull_Quant, _) => (pull_quant cx ct, cxp)
   820     | (Z3_Proof_Parser.Push_Quant, _) => (push_quant cx ct, cxp)
   821     | (Z3_Proof_Parser.Elim_Unused_Vars, _) => (elim_unused_vars cx ct, cxp)
   822     | (Z3_Proof_Parser.Dest_Eq_Res, _) => (dest_eq_res cx ct, cxp)
   823     | (Z3_Proof_Parser.Quant_Inst, _) => (quant_inst cx ct, cxp)
   824     | (Z3_Proof_Parser.Skolemize, _) => skolemize vars ct cx ||> rpair ptab
   825
   826       (* theory rules *)
   827     | (Z3_Proof_Parser.Th_Lemma _, _) =>  (* FIXME: use arguments *)
   828         (th_lemma cx simpset (map (thm_of o fst) ps) ct, cxp)
   829     | (Z3_Proof_Parser.Rewrite, _) => rewrite simpset [] ct cx ||> rpair ptab
               (* NOTE(review): this pattern rebinds ps to the same premise list
                  as the outer ps; a wildcard would say the same more clearly *)
   830     | (Z3_Proof_Parser.Rewrite_Star, ps) =>
   831         rewrite simpset (map fst ps) ct cx ||> rpair ptab
   832
   833     | (Z3_Proof_Parser.Nnf_Star, _) => not_supported r
   834     | (Z3_Proof_Parser.Cnf_Star, _) => not_supported r
   835     | (Z3_Proof_Parser.Transitivity_Star, _) => not_supported r
   836     | (Z3_Proof_Parser.Pull_Quant_Star, _) => not_supported r
   837
               (* reached only when a known rule came with a premise list of a
                  shape that none of the arms above accepts *)
   838     | _ => raise Fail ("Z3: proof rule " ^
   839         quote (Z3_Proof_Parser.string_of_rule r) ^
   840         " has an unexpected number of arguments."))
   841
         (* Look up the proof of step idx in ptab and pair it with idx; the
            index is what And_Elim and Not_Or_Elim receive as i.  Fails via
            z3_exn for an unknown step id. *)
   842   fun lookup_proof ptab idx =
   843     (case Inttab.lookup ptab idx of
   844       SOME p => (p, idx)
   845     | NONE => z3_exn ("unknown proof id: " ^ quote (string_of_int idx)))
   846
         (* Fold step used by reconstruct.  The accumulator is
            (previous proof, (context, proof table)); the previous proof is
            ignored.  Resolves the premises of step idx, calls trace_before,
            replays the step with prove_step, runs check_after on the result
            (a no-op unless tracing is enabled), and records the new proof in
            the table under idx.  Note that thm here is the proof value
            produced by prove_step, converted to a theorem with thm_of where
            one is needed. *)
   847   fun prove simpset vars (idx, step) (_, cxp as (ctxt, ptab)) =
   848     let
   849       val Z3_Proof_Parser.Proof_Step {rule=r, prems, prop, ...} = step
   850       val ps = map (lookup_proof ptab) prems
   851       val _ = trace_before ctxt idx r
   852       val (thm, (ctxt', ptab')) =
   853         cxp
   854         |> prove_step simpset vars r ps prop
   855         |> tap (check_after idx r ps prop)
   856     in (thm, (ctxt', Inttab.update (idx, thm) ptab')) end
   857
         (* Rules for discharging premises left after export: the given rules
            followed by allI, refl, reflexive and the true theorem. *)
   858   fun make_discharge_rules rules = rules @ [@{thm allI}, @{thm refl},
   859     @{thm reflexive}, Z3_Proof_Literals.true_thm]
   860
         (* Repeatedly attack the first subgoal: resolve with one of rules or,
            failing that, apply discharge_sk_tac, which must close that
            subgoal completely (SOLVED'). *)
   861   fun discharge_assms_tac rules =
   862     REPEAT (HEADGOAL (resolve_tac rules ORELSE' SOLVED' discharge_sk_tac))
   863
         (* Discharge the premises of thm with discharge_assms_tac, take the
            first result, and normalise it with Goal.norm_result; a theorem
            without premises is only normalised.
            NOTE(review): if REPEAT is Isabelle's standard combinator, it never
            fails, so the NONE branch looks unreachable and premises the tactic
            cannot solve remain in the result -- confirm whether a check on
            Thm.nprems_of thm' was intended. *)
   864   fun discharge_assms ctxt rules thm =
   865     if Thm.nprems_of thm = 0 then Goal.norm_result ctxt thm
   866     else
   867       (case Seq.pull (discharge_assms_tac rules thm) of
   868         SOME (thm', _) => Goal.norm_result ctxt thm'
   869       | NONE => raise THM ("failed to discharge premise", 1, [thm]))
   870
         (* Final step of reconstruct: take the theorem of the last replayed
            proof, export it from the inner reconstruction context to
            outer_ctxt, and discharge its remaining premises using the given
            rules plus the standard ones from make_discharge_rules. *)
   871   fun discharge rules outer_ctxt (p, (inner_ctxt, _)) =
   872     thm_of p
   873     |> singleton (Proof_Context.export inner_ctxt outer_ctxt)
   874     |> discharge_assms outer_ctxt (make_discharge_rules rules)
   875 in
   876
       (* Reconstruct a theorem from Z3's proof output (a list of lines) for the
          problem described by recon.
          If the filter_only_facts option is set in the context returned by
          add_asserted, only its index list is returned, paired with TrueI
          (presumably the indices of the facts Z3 used -- confirm).  Otherwise
          all parsed steps are replayed in order with prove, the final proof is
          exported and discharged, and the theorem is returned with an empty
          index list.  The initial TrueI in the fold accumulator is what gets
          discharged only if the parser produced no steps at all. *)
   877 fun reconstruct outer_ctxt recon output =
   878   let
   879     val {context=ctxt, typs, terms, rewrite_rules, assms} = recon
   880     val (asserted, steps, vars, ctxt1) =
   881       Z3_Proof_Parser.parse ctxt typs terms output
   882
   883     val simpset = Z3_Proof_Tools.make_simpset ctxt1 (Z3_Simps.get ctxt1)
   884
   885     val ((is, rules), cxp as (ctxt2, _)) =
   886       add_asserted outer_ctxt rewrite_rules assms asserted ctxt1
   887   in
   888     if Config.get ctxt2 SMT_Config.filter_only_facts then (is, @{thm TrueI})
   889     else
   890       (Thm @{thm TrueI}, cxp)
   891       |> fold (prove simpset vars) steps
   892       |> discharge rules outer_ctxt
   893       |> pair []
   894   end
   895
   896 end
   897 
   898 val setup = z3_rules_setup #> Z3_Simps.setup
   899 
   900 end