src/HOL/Tools/SMT/z3_proof_reconstruction.ML
author boehmes
Wed Nov 24 15:33:35 2010 +0100 (2010-11-24)
changeset 40686 4725ed462387
parent 40681 872b08416fb4
child 41127 2ea84c8535c6
permissions -rw-r--r--
swap names for built-in tester functions (to better reflect the intuition of what they do);
eta-expand all built-in functions (even those which are only partially supported)
     1 (*  Title:      HOL/Tools/SMT/z3_proof_reconstruction.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Proof reconstruction for proofs found by Z3.
     5 *)
     6 
     7 signature Z3_PROOF_RECONSTRUCTION =
     8 sig
         (* add a theorem to the net of schematic Z3 rules kept in the generic
            context (the same insertion the "z3_rule" attribute performs) *)
     9   val add_z3_rule: thm -> Context.generic -> Context.generic
         (* replay the Z3 proof found in the solver output (string list); yields
            a list of assumption identifiers paired with a theorem, together
            with the context produced while parsing/replaying *)
    10   val reconstruct: Proof.context -> SMT_Translate.recon -> string list ->
    11     (int list * thm) * Proof.context
         (* theory setup: the "z3_rule" declarations followed by Z3_Simps.setup *)
    12   val setup: theory -> theory
    13 end
    14 
    15 structure Z3_Proof_Reconstruction: Z3_PROOF_RECONSTRUCTION =
    16 struct
    17 
    18 structure P = Z3_Proof_Parser  (* proof parser: P.parse, P.Proof_Step, proof rule names *)
    19 structure T = Z3_Proof_Tools  (* helpers used below: nets, by_tac, with_conv, compose, ... *)
    20 structure L = Z3_Proof_Literals  (* literal tables, explode/join, negate, rewrite_true *)
    21 structure M = Z3_Proof_Methods  (* only M.prove_injectivity is used in this file *)
    22 
    23 fun z3_exn msg =
         (* abort reconstruction: raises SMT_Failure.SMT carrying an
            Other_Failure whose text is msg behind a fixed prefix *)
    24   let val text = "Z3 proof reconstruction: " ^ msg
         in raise SMT_Failure.SMT (SMT_Failure.Other_Failure text) end
    25 
    26 
    27 
    28 (** net of schematic rules **)
    29 
       (* name of both the attribute and the dynamic fact declared below *)
    30 val z3_ruleN = "z3_rule"
    31 
    32 local
    33   val description = "declaration of Z3 proof rules"
    34 
    35   val eq = Thm.eq_thm
    36 
         (* generic context data holding the rules in a term net; Thm.eq_thm
            decides which entries count as duplicates when nets are merged *)
    37   structure Z3_Rules = Generic_Data
    38   (
    39     type T = thm Net.net
    40     val empty = Net.empty
    41     val extend = I
    42     val merge = Net.merge eq
    43   )
    44 
         (* rewrite a rule with L.rewrite_true and pair the result with its
            proposition, which is the term used as index into the net *)
    45   val prep = `Thm.prop_of o Simplifier.rewrite_rule [L.rewrite_true]
    46 
         (* inserting a rule that is already stored, or deleting one that is
            not stored, silently keeps the net as it is *)
    47   fun ins thm net = Net.insert_term eq (prep thm) net handle Net.INSERT => net
    48   fun del thm net = Net.delete_term eq (prep thm) net handle Net.DELETE => net
    49 
         (* attribute versions of ins/del; note that `val del` shadows
            `fun del` for the remainder of this local block *)
    50   val add = Thm.declaration_attribute (Z3_Rules.map o ins)
    51   val del = Thm.declaration_attribute (Z3_Rules.map o del)
    52 in
    53 
       (* programmatic counterpart of adding a rule via the attribute *)
    54 val add_z3_rule = Z3_Rules.map o ins
    55 
       (* ask T.net_instance for a stored rule that fits ct; `the` raises an
          exception when the answer is NONE, which try_apply (below) treats as
          "this prover failed" *)
    56 fun by_schematic_rule ctxt ct =
    57   the (T.net_instance (Z3_Rules.get (Context.Proof ctxt)) ct)
    58 
       (* declares the attribute "z3_rule" (with add/del) and a dynamic fact of
          the same name that lists the current content of the net *)
    59 val z3_rules_setup =
    60   Attrib.setup (Binding.name z3_ruleN) (Attrib.add_del add del) description #>
    61   Global_Theory.add_thms_dynamic (Binding.name z3_ruleN, Net.content o Z3_Rules.get)
    62 
    63 end
    64 
    65 
    66 
    67 (** proof tools **)
    68 
    69 fun named ctxt name prover ct =
         (* announce the prover in the trace ("Z3: trying <name> ...") and then
            run it on ct; nothing happens before ct is supplied *)
    70   (SMT_Config.trace_msg ctxt I ("Z3: trying " ^ name ^ " ...");
    71    prover ct)
    72 
    73 fun NAMED ctxt name tac i st =
         (* tactic counterpart of `named`: the trace message is emitted when the
            tactic is applied to subgoal i of state st *)
    74   (SMT_Config.trace_msg ctxt I ("Z3: trying " ^ name ^ " ...");
    75    tac i st)
    76 
    77 fun pretty_goal ctxt thms t =
         (* pretty-print t as "proposition: ..."; if thms is non-empty, an
            "assumptions:" list of these theorems is put in front of it *)
    78   [Pretty.block [Pretty.str "proposition: ", Syntax.pretty_term ctxt t]]
    79   |> not (null thms) ? cons (Pretty.big_list "assumptions:"
    80        (map (Display.pretty_thm ctxt) thms))
    81 
    82 fun try_apply ctxt thms =
         (* try_apply ctxt thms provers ct: run the provers on ct one after the
            other and return the first theorem obtained.  The schematic-rule
            prover is always put in front of the given list.  thms is only used
            for the "assumptions:" part of the error message. *)
    83   let
           (* text of the error raised once every prover has failed on ct *)
    84     fun try_apply_err ct = Pretty.string_of (Pretty.chunks [
    85       Pretty.big_list ("Z3 found a proof," ^
    86         " but proof reconstruction failed at the following subgoal:")
    87         (pretty_goal ctxt thms (Thm.term_of ct)),
    88       Pretty.str ("Adding a rule to the lemma group " ^ quote z3_ruleN ^
    89         " might solve this problem.")])
    90 
           (* `try` yields NONE when a prover raises an exception, so a failing
              prover simply hands over to the next one in the list *)
    91     fun apply [] ct = error (try_apply_err ct)
    92       | apply (prover :: provers) ct =
    93           (case try prover ct of
    94             SOME thm => (SMT_Config.trace_msg ctxt I "Z3: succeeded"; thm)
    95           | NONE => apply provers ct)
    96 
    97   in apply o cons (named ctxt "schematic rules" (by_schematic_rule ctxt)) end
    98 
    99 local
         (* expresses an if-then-else by a conjunction of two implications *)
   100   val rewr_if =
   101     @{lemma "(if P then Q1 else Q2) = ((P --> Q1) & (~P --> Q2))" by simp}
   102 in
       (* simplify with HOL_ss extended by rewr_if, then apply fast_tac (HOL_cs)
          to all subgoals left over by the simplifier *)
   103 val simp_fast_tac =
   104   Simplifier.simp_tac (HOL_ss addsimps [rewr_if])
   105   THEN_ALL_NEW Classical.fast_tac HOL_cs
   106 end
   107 
   108 
   109 
   110 (** theorems and proofs **)
   111 
   112 (* theorem incarnations: the forms in which the result of a proof step is
          kept; thm_of, meta_eq_of and literals_of convert between them *)
   113 
   114 datatype theorem =
   115   Thm of thm | (* theorem without special features *)
   116   MetaEq of thm | (* meta equality "t == s" *)
   117   Literals of thm * L.littab
   118     (* "P1 & ... & Pn" and table of all literals P1, ..., Pn *)
   119 
   120 fun thm_of p =
         (* plain theorem of any incarnation; a meta equality is turned into an
            object-level one by composing with meta_eq_to_obj_eq *)
   121   (case p of Thm thm => thm | Literals (thm, _) => thm
   122    | MetaEq thm => thm COMP @{thm meta_eq_to_obj_eq})
   123 
   124 fun meta_eq_of p =
         (* a MetaEq is returned as stored; anything else goes through
            mk_meta_eq applied to its plain theorem *)
   125   (case p of MetaEq thm => thm | _ => mk_meta_eq (thm_of p))
   126 
   127 fun literals_of p =
         (* stored literal table if there is one, otherwise a fresh table built
            from the single plain theorem *)
   128   (case p of Literals (_, lits) => lits | _ => L.make_littab [thm_of p])
   129 
   130 
   131 (* proof representation: an entry of the proof table is either a parsed
          step that has not been replayed yet, or the theorem obtained for it *)
   132 
   133 datatype proof = Unproved of P.proof_step | Proved of theorem
   134 
   135 
   136 
   137 (** core proof rules **)
   138 
   139 (* assumption *)
   140 
   141 local
         (* meta equations that drop the trigger/weight wrappers *)
   142   val remove_trigger = @{lemma "trigger t p == p"
   143     by (rule eq_reflection, rule trigger_def)}
   144 
   145   val remove_weight = @{lemma "weight w p == p"
   146     by (rule eq_reflection, rule weight_def)}
   147 
         (* rewrite rules applied to every assumption before it is stored *)
   148   val prep_rules = [@{thm Let_def}, remove_trigger, remove_weight,
   149     L.rewrite_true]
   150 
         (* full rewriting with a simpset that contains nothing but eqs *)
   151   fun rewrite_conv ctxt eqs = Simplifier.full_rewrite
   152     (Simplifier.context ctxt Simplifier.empty_ss addsimps eqs)
   153 
         (* rewrite all elements of a list with eqs; f lifts the rewriting to
            the element type (I for theorems, apsnd for (id, theorem) pairs) *)
   154   fun rewrites f ctxt eqs = map (f (Conv.fconv_rule (rewrite_conv ctxt eqs)))
   155 
   156   fun burrow_snd_option f (i, thm) = Option.map (pair i) (f thm)
         (* find an (id, theorem) pair in the net whose theorem fits ct; a
            missing assumption is reported via z3_exn ("not asserted: ...") *)
   157   fun lookup_assm ctxt assms ct =
   158     (case T.net_instance' burrow_snd_option assms ct of
   159       SOME ithm => ithm
   160     | _ => z3_exn ("not asserted: " ^
   161         quote (Syntax.string_of_term ctxt (Thm.term_of ct))))
   162 in
       (* rewrite the unfold equations with L.rewrite_true, rewrite the
          assumptions with these equations plus prep_rules, and store the
          assumptions in a net (built by T.thm_net_of from the theorem
          component); the result is the pair expected by asserted/find_assm *)
   163 fun prepare_assms ctxt unfolds assms =
   164   let
   165     val unfolds' = rewrites I ctxt [L.rewrite_true] unfolds
   166     val assms' =
   167       rewrites apsnd ctxt (union Thm.eq_thm unfolds' prep_rules) assms
   168   in (unfolds', T.thm_net_of snd assms') end
   169 
       (* proof of an Asserted/Goal step: ct is looked up among the prepared
          assumptions, with the unfold rewriting passed to T.with_conv
          (NOTE(review): presumably with_conv converts ct first and transfers
          the theorem back -- confirm in Z3_Proof_Tools) *)
   170 fun asserted ctxt (unfolds, assms) ct =
   171   let val revert_conv = rewrite_conv ctxt unfolds
   172   in Thm (T.with_conv revert_conv (snd o lookup_assm ctxt assms) ct) end
   173 
       (* identifier (first component) of the assumption matching ct after ct
          has been rewritten with the unfold equations *)
   174 fun find_assm ctxt (unfolds, assms) ct =
   175   fst (lookup_assm ctxt assms (Thm.rhs_of (rewrite_conv ctxt unfolds ct)))
   176 end
   177 
   178 
   179 
   180 (* P = Q ==> P ==> Q   or   P --> Q ==> P ==> Q *)
   181 local
         (* modus ponens for a meta equality between propositions *)
   182   val meta_iffD1 = @{lemma "P == Q ==> P ==> (Q::bool)" by simp}
   183   val meta_iffD1_c = T.precompose2 Thm.dest_binop meta_iffD1
   184 
         (* iffD1 and mp prepared with T.precompose2 for use with T.compose *)
   185   val iffD1_c = T.precompose2 (Thm.dest_binop o Thm.dest_arg) @{thm iffD1}
   186   val mp_c = T.precompose2 (Thm.dest_binop o Thm.dest_arg) @{thm mp}
   187 in
       (* mp p_q p: the first argument is the equation/implication (as a
          `theorem`), the second the plain theorem for its left-hand side.
          A MetaEq uses meta_iffD1; otherwise iffD1 is tried first and, when
          that raises THM, the implication rule mp is used instead. *)
   188 fun mp (MetaEq thm) p = Thm (Thm.implies_elim (T.compose meta_iffD1_c thm) p)
   189   | mp p_q p = 
   190       let
   191         val pq = thm_of p_q
   192         val thm = T.compose iffD1_c pq handle THM _ => T.compose mp_c pq
   193       in Thm (Thm.implies_elim thm p) end
   194 end
   195 
   196 
   197 
   198 (* and_elim:     P1 & ... & Pn ==> Pi *)
   199 (* not_or_elim:  ~(P1 | ... | Pn) ==> ~Pi *)
   200 local
         (* does t occur as a literal inside the given literal? *)
   201   fun is_sublit conj t = L.exists_lit conj (fn u => u aconv t)
   202 
         (* t is not in the table yet: pick a stored literal that contains t,
            explode it towards t, and replace it in the table by the pieces.
            The proof table entry idx is updated with the refined literal
            table so that later eliminations from the same premise can reuse
            it.  `the` raises an exception if no stored literal contains t. *)
   203   fun derive conj t lits idx ptab =
   204     let
   205       val lit = the (L.get_first_lit (is_sublit conj t) lits)
   206       val ls = L.explode conj false false [t] lit
   207       val lits' = fold L.insert_lit ls (L.delete_lit lit lits)
   208 
   209       fun upd (Proved thm) = Proved (Literals (thm_of thm, lits'))
   210         | upd p = p
   211     in (the (L.lookup_lit lits' t), Inttab.map_entry idx upd ptab) end
   212 
         (* (p, idx) is the premise and its proof id, ct the literal to be
            obtained; returns the theorem and the (possibly updated) table *)
   213   fun lit_elim conj (p, idx) ct ptab =
   214     let val lits = literals_of p
   215     in
   216       (case L.lookup_lit lits (T.term_of ct) of
   217         SOME lit => (Thm lit, ptab)
   218       | NONE => apfst Thm (derive conj (T.term_of ct) lits idx ptab))
   219     end
   220 in
   221 val and_elim = lit_elim true
   222 val not_or_elim = lit_elim false
   223 end
   224 
   225 
   226 
   227 (* P1, ..., Pn |- False ==> |- ~P1 | ... | ~Pn *)
   228 local
         (* turn the hypothesis stated by lit into an implication and discharge
            it right away with the theorem lit *)
   229   fun step lit thm =
   230     Thm.implies_elim (Thm.implies_intr (Thm.cprop_of lit) thm) lit
   231   val explode_disj = L.explode false false false
         (* th is exploded into literals (guided by hyps); each of them is used
            to discharge a hypothesis of thm *)
   232   fun intro hyps thm th = fold step (explode_disj hyps th) thm
   233 
   234   fun dest_ccontr ct = [Thm.dest_arg (Thm.dest_arg (Thm.dest_arg1 ct))]
   235   val ccontr = T.precompose dest_ccontr @{thm ccontr}
   236 in
       (* thm is the premise that still carries hypotheses, ct the clause to be
          proved: the negated clause (cu) is assumed via T.under_assumption,
          the hypotheses of thm that are HOL propositions are discharged from
          it, and the result is composed with ccontr *)
   237 fun lemma thm ct =
   238   let
   239     val cu = L.negate (Thm.dest_arg ct)
   240     val hyps = map_filter (try HOLogic.dest_Trueprop) (#hyps (Thm.rep_thm thm))
   241   in Thm (T.compose ccontr (T.under_assumption (intro hyps thm) cu)) end
   242 end
   243 
   244 
   245 
   246 (* \/{P1, ..., Pn, Q1, ..., Qn}, ~P1, ..., ~Pn ==> \/{Q1, ..., Qn} *)
   247 local
   248   val explode_disj = L.explode false true false
   249   val join_disj = L.join false
         (* thm is the clause, thms the unit premises, th the assumed negation
            of the conclusion: from the literals of thms and of the exploded
            th, the negation of thm's proposition is joined together *)
   250   fun unit thm thms th =
   251     let val t = @{const Not} $ T.prop_of thm and ts = map T.prop_of thms
   252     in join_disj (L.make_littab (thms @ explode_disj ts th)) t end
   253 
   254   fun dest_arg2 ct = Thm.dest_arg (Thm.dest_arg ct)
   255   fun dest ct = pairself dest_arg2 (Thm.dest_binop ct)
   256   val contrapos = T.precompose2 dest @{lemma "(~P ==> ~Q) ==> Q ==> P" by fast}
   257 in
       (* thm: clause premise; thms: remaining premises; ct: expected result.
          The derivation under the assumed negation of the result is composed
          with the contraposition rule and its remaining premise is then
          discharged with thm. *)
   258 fun unit_resolution thm thms ct =
   259   L.negate (Thm.dest_arg ct)
   260   |> T.under_assumption (unit thm thms)
   261   |> Thm o T.discharge thm o T.compose contrapos
   262 end
   263 
   264 
   265 
   266 (* P ==> P == True   or   P ==> P == False *)
   267 local
         (* the right-hand side is stated as "~ False" instead of "True";
            NOTE(review): presumably to agree with the L.rewrite_true
            normalisation applied to rules and assumptions in this file --
            confirm in Z3_Proof_Literals *)
   268   val iff1 = @{lemma "P ==> P == (~ False)" by simp}
   269   val iff2 = @{lemma "~P ==> P == False" by simp}
   270 in
       (* both take a plain theorem and return a MetaEq *)
   271 fun iff_true thm = MetaEq (thm COMP iff1)
   272 fun iff_false thm = MetaEq (thm COMP iff2)
   273 end
   274 
   275 
   276 
   277 (* distributivity of | over & *)
       (* proved by the schematic rules (always tried first by try_apply) and
          otherwise by fast_tac with HOL_cs *)
   278 fun distributivity ctxt = Thm o try_apply ctxt [] [
   279   named ctxt "fast" (T.by_tac (Classical.fast_tac HOL_cs))]
   280     (* FIXME: not very well tested *)
   281 
   282 
   283 
   284 (* Tseitin-like axioms *)
   285 
   286 local
         (* introduction rules for a disjunction whose one disjunct is derived
            from the (negated) other one *)
   287   val disjI1 = @{lemma "(P ==> Q) ==> ~P | Q" by fast}
   288   val disjI2 = @{lemma "(~P ==> Q) ==> P | Q" by fast}
   289   val disjI3 = @{lemma "(~Q ==> P) ==> P | Q" by fast}
   290   val disjI4 = @{lemma "(Q ==> P) ==> P | ~Q" by fast}
   291 
         (* explode the assumed theorem into literals (conj1 tells whether it
            is a conjunction) and join them into ct2 (conj2 likewise) *)
   292   fun prove' conj1 conj2 ct2 thm =
   293     let val lits = L.true_thm :: L.explode conj1 true (conj1 <> conj2) [] thm
   294     in L.join conj2 (L.make_littab lits) (Thm.term_of ct2) end
   295 
         (* derive ct2 under the assumption ct1 and finish with `rule` *)
   296   fun prove rule (ct1, conj1) (ct2, conj2) =
   297     T.under_assumption (prove' conj1 conj2 ct2) ct1 COMP rule
   298 
         (* ct is a disjunction; the shape of its first disjunct selects the
            rule.  For the two "distinct" cases the distinct expression is
            first unfolded by T.unfold_distinct_conv.  Any other shape raises
            CTERM, which makes try_apply move on to the next prover. *)
   299   fun prove_def_axiom ct =
   300     let val (ct1, ct2) = Thm.dest_binop (Thm.dest_arg ct)
   301     in
   302       (case Thm.term_of ct1 of
   303         @{const Not} $ (@{const HOL.conj} $ _ $ _) =>
   304           prove disjI1 (Thm.dest_arg ct1, true) (ct2, true)
   305       | @{const HOL.conj} $ _ $ _ =>
   306           prove disjI3 (L.negate ct2, false) (ct1, true)
   307       | @{const Not} $ (@{const HOL.disj} $ _ $ _) =>
   308           prove disjI3 (L.negate ct2, false) (ct1, false)
   309       | @{const HOL.disj} $ _ $ _ =>
   310           prove disjI2 (L.negate ct1, false) (ct2, true)
   311       | Const (@{const_name distinct}, _) $ _ =>
   312           let
   313             fun dis_conv cv = Conv.arg_conv (Conv.arg1_conv cv)
   314             fun prv cu =
   315               let val (cu1, cu2) = Thm.dest_binop (Thm.dest_arg cu)
   316               in prove disjI4 (Thm.dest_arg cu2, true) (cu1, true) end
   317           in T.with_conv (dis_conv T.unfold_distinct_conv) prv ct end
   318       | @{const Not} $ (Const (@{const_name distinct}, _) $ _) =>
   319           let
   320             fun dis_conv cv = Conv.arg_conv (Conv.arg1_conv (Conv.arg_conv cv))
   321             fun prv cu =
   322               let val (cu1, cu2) = Thm.dest_binop (Thm.dest_arg cu)
   323               in prove disjI1 (Thm.dest_arg cu1, true) (cu2, true) end
   324           in T.with_conv (dis_conv T.unfold_distinct_conv) prv ct end
   325       | _ => raise CTERM ("prove_def_axiom", [ct]))
   326     end
   327 in
       (* provers in order: schematic rules (added by try_apply), the
          structural proof above, and finally simp+fast via T.by_abstraction *)
   328 fun def_axiom ctxt = Thm o try_apply ctxt [] [
   329   named ctxt "conj/disj/distinct" prove_def_axiom,
   330   T.by_abstraction (true, false) ctxt [] (fn ctxt' =>
   331     named ctxt' "simp+fast" (T.by_tac simp_fast_tac))]
   332 end
   333 
   334 
   335 
   336 (* local definitions *)
   337 local
         (* rules turning a definition "n == ..." into the formula that an
            intro-def step is expected to state *)
   338   val intro_rules = [
   339     @{lemma "n == P ==> (~n | P) & (n | ~P)" by simp},
   340     @{lemma "n == (if P then s else t) ==> (~P | n = s) & (P | n = t)"
   341       by simp},
   342     @{lemma "n == P ==> n = P" by (rule meta_eq_to_obj_eq)} ]
   343 
         (* the opposite direction, yielding a meta equation *)
   344   val apply_rules = [
   345     @{lemma "(~n | P) & (n | ~P) ==> P == n" by (atomize(full)) fast},
   346     @{lemma "(~P | n = s) & (P | n = t) ==> (if P then s else t) == n"
   347       by (atomize(full)) fastsimp} ]
   348 
   349   val inst_rule = T.match_instantiate Thm.dest_arg
   350 
         (* first intro rule (note: intro_rules, despite the function's name)
            that can be instantiated to ct; CTERM "intro_def" if none fits *)
   351   fun apply_rule ct =
   352     (case get_first (try (inst_rule ct)) intro_rules of
   353       SOME thm => thm
   354     | NONE => raise CTERM ("intro_def", [ct]))
   355 in
       (* takes ct and a context; the instantiated rule is handed to
          T.make_hyp_def and the resulting theorem is wrapped as Thm
          (the second component of make_hyp_def's result is passed through) *)
   356 fun intro_def ct = T.make_hyp_def (apply_rule ct) #>> Thm
   357 
       (* MetaEq from the first apply rule that composes with thm; if none
          does, thm itself is returned as Thm *)
   358 fun apply_def thm =
   359   get_first (try (fn rule => MetaEq (thm COMP rule))) apply_rules
   360   |> the_default (Thm thm)
   361 end
   362 
   363 
   364 
   365 (* negation normal form *)
   366 
   367 local
         (* quantifier rules for the positive case: the first list covers a
            right-hand side without the quantifier, the second one a
            right-hand side with the same quantifier *)
   368   val quant_rules1 = ([
   369     @{lemma "(!!x. P x == Q) ==> ALL x. P x == Q" by simp},
   370     @{lemma "(!!x. P x == Q) ==> EX x. P x == Q" by simp}], [
   371     @{lemma "(!!x. P x == Q x) ==> ALL x. P x == ALL x. Q x" by simp},
   372     @{lemma "(!!x. P x == Q x) ==> EX x. P x == EX x. Q x" by simp}])
   373 
         (* same for a negated quantifier on the left; in the second list the
            quantifier on the right is the dual one *)
   374   val quant_rules2 = ([
   375     @{lemma "(!!x. ~P x == Q) ==> ~(ALL x. P x) == Q" by simp},
   376     @{lemma "(!!x. ~P x == Q) ==> ~(EX x. P x) == Q" by simp}], [
   377     @{lemma "(!!x. ~P x == Q x) ==> ~(ALL x. P x) == EX x. Q x" by simp},
   378     @{lemma "(!!x. ~P x == Q x) ==> ~(EX x. P x) == ALL x. Q x" by simp}])
   379 
         (* close the goal with thm, or strip one quantifier with a rule from
            qs1 (tried before qs2) and continue recursively *)
   380   fun nnf_quant_tac thm (qs as (qs1, qs2)) i st = (
   381     Tactic.rtac thm ORELSE'
   382     (Tactic.match_tac qs1 THEN' nnf_quant_tac thm qs) ORELSE'
   383     (Tactic.match_tac qs2 THEN' nnf_quant_tac thm qs)) i st
   384 
         (* prove ct (as meta equation) from the premise p, whose meta equation
            is passed through T.varify vars before use *)
   385   fun nnf_quant vars qs p ct =
   386     T.as_meta_eq ct
   387     |> T.by_tac (nnf_quant_tac (T.varify vars (meta_eq_of p)) qs)
   388 
   389   fun prove_nnf ctxt = try_apply ctxt [] [
   390     named ctxt "conj/disj" L.prove_conj_disj_eq,
   391     T.by_abstraction (true, false) ctxt [] (fn ctxt' =>
   392       named ctxt' "simp+fast" (T.by_tac simp_fast_tac))]
   393 in
       (* ps are the premises, ct the equation to prove.  Three cases by the
          shape of ct: both sides are a constant applied to an abstraction
          (reflexivity if the sides coincide, else quant_rules1 with the first
          premise); the left side is such a term under negation
          (quant_rules2); anything else is proved by prove_nnf after
          rewriting with the symmetric premise equations.  Note that `hd ps`
          raises an exception when ps is empty. *)
   394 fun nnf ctxt vars ps ct =
   395   (case T.term_of ct of
   396     _ $ (l as Const _ $ Abs _) $ (r as Const _ $ Abs _) =>
   397       if l aconv r
   398       then MetaEq (Thm.reflexive (Thm.dest_arg (Thm.dest_arg ct)))
   399       else MetaEq (nnf_quant vars quant_rules1 (hd ps) ct)
   400   | _ $ (@{const Not} $ (Const _ $ Abs _)) $ (Const _ $ Abs _) =>
   401       MetaEq (nnf_quant vars quant_rules2 (hd ps) ct)
   402   | _ =>
   403       let
   404         val nnf_rewr_conv = Conv.arg_conv (Conv.arg_conv
   405           (T.unfold_eqs ctxt (map (Thm.symmetric o meta_eq_of) ps)))
   406       in Thm (T.with_conv nnf_rewr_conv (prove_nnf ctxt) ct) end)
   407 end
   408 
   409 
   410 
   411 (** equality proof rules **)
   412 
   413 (* |- t = t, as a meta equality on the right-hand argument of ct *)
   414 fun refl ct = ct |> Thm.dest_arg |> Thm.dest_arg |> Thm.reflexive |> MetaEq
   415 
   416 
   417 
   418 (* s = t ==> t = s *)
   419 local
         (* object equality in, swapped meta equality out *)
   420   val symm_rule = @{lemma "s = t ==> t == s" by simp}
   421 in
       (* a MetaEq is flipped with Thm.symmetric; any other incarnation goes
          through symm_rule; the result is always a MetaEq *)
   422 fun symm (MetaEq thm) = MetaEq (Thm.symmetric thm)
   423   | symm p = MetaEq (thm_of p COMP symm_rule)
   424 end
   425 
   426 
   427 
   428 (* s = t ==> t = u ==> s = u *)
   429 local
         (* one rule per mix of meta (==) and object (=) equality premises *)
   430   val trans1 = @{lemma "s == t ==> t =  u ==> s == u" by simp}
   431   val trans2 = @{lemma "s =  t ==> t == u ==> s == u" by simp}
   432   val trans3 = @{lemma "s =  t ==> t =  u ==> s == u" by simp}
   433 in
       (* two meta equalities use Thm.transitive directly; otherwise the
          rule matching the premise kinds is composed with both premises;
          the result is always a MetaEq *)
   434 fun trans (MetaEq thm1) (MetaEq thm2) = MetaEq (Thm.transitive thm1 thm2)
   435   | trans (MetaEq thm) q = MetaEq (thm_of q COMP (thm COMP trans1))
   436   | trans p (MetaEq thm) = MetaEq (thm COMP (thm_of p COMP trans2))
   437   | trans p q = MetaEq (thm_of q COMP (thm_of p COMP trans3))
   438 end
   439 
   440 
   441 
   442 (* t1 = s1 ==> ... ==> tn = sn ==> f t1 ... tn = f s1 .. sn
   443    (reflexive antecedents are dropped) *)
   444 local
         (* raised by prove_eq (in "exn" mode) when no premise equation fits *)
   445   exception MONO
   446 
         (* all helpers work on a pair cp = (left cterm, right cterm) and
            return a meta equation between the two *)
   447   fun prove_refl (ct, _) = Thm.reflexive ct
         (* split both sides into function and argument; f proves the function
            parts, g the argument parts; joined by Thm.combination *)
   448   fun prove_comb f g cp =
   449     let val ((ct1, ct2), (cu1, cu2)) = pairself Thm.dest_comb cp
   450     in Thm.combination (f (ct1, cu1)) (g (ct2, cu2)) end
   451   fun prove_arg f = prove_comb prove_refl f
   452 
         (* descend along the function position, applying f to each argument;
            when the term is no application (CTERM) fall back to reflexivity *)
   453   fun prove f cp = prove_comb (prove f) f cp handle CTERM _ => prove_refl cp
   454 
         (* for nested conjunctions/disjunctions: try f on the whole pair; on
            MONO descend into both arguments if is_comb holds for the left
            term, otherwise use reflexivity *)
   455   fun prove_nary is_comb f =
   456     let
   457       fun prove (cp as (ct, _)) = f cp handle MONO =>
   458         if is_comb (Thm.term_of ct)
   459         then prove_comb (prove_arg prove) prove cp
   460         else prove_refl cp
   461     in prove end
   462 
         (* apply f to the first n elements of a pair of list terms *)
   463   fun prove_list f n cp =
   464     if n = 0 then prove_refl cp
   465     else prove_comb (prove_arg f) (prove_list f (n-1)) cp
   466 
         (* supply the length of the left-hand HOL list as first argument of f *)
   467   fun with_length f (cp as (cl, _)) =
   468     f (length (HOLogic.dest_list (Thm.term_of cl))) cp
   469 
   470   fun prove_distinct f = prove_arg (with_length (prove_list f))
   471 
         (* look up the meta equation between the two sides among the premises;
            if it is missing either raise MONO (exn = true) or assume the
            sides are identical and use reflexivity (exn = false) *)
   472   fun prove_eq exn lookup cp =
   473     (case lookup (Logic.mk_equals (pairself Thm.term_of cp)) of
   474       SOME eq => eq
   475     | NONE => if exn then raise MONO else prove_refl cp)
   476   
   477   val prove_eq_exn = prove_eq true
   478   and prove_eq_safe = prove_eq false
   479 
         (* choose the traversal by the head symbol of the left-hand side *)
   480   fun mono f (cp as (cl, _)) =
   481     (case Term.head_of (Thm.term_of cl) of
   482       @{const HOL.conj} => prove_nary L.is_conj (prove_eq_exn f)
   483     | @{const HOL.disj} => prove_nary L.is_disj (prove_eq_exn f)
   484     | Const (@{const_name distinct}, _) => prove_distinct (prove_eq_safe f)
   485     | _ => prove (prove_eq_safe f)) cp
   486 in
       (* eqs are the premise equations, ct the equation to prove.  Every
          premise is made available in both orientations, indexed by the
          proposition of the original orientation and compared up to aconv.
          The goal itself is looked up first; only if it is not among the
          premises (MONO) is the term structure traversed. *)
   487 fun monotonicity eqs ct =
   488   let
   489     fun and_symmetric (t, thm) = [(t, thm), (t, Thm.symmetric thm)]
   490     val teqs = maps (and_symmetric o `Thm.prop_of o meta_eq_of) eqs
   491     val lookup = AList.lookup (op aconv) teqs
   492     val cp = Thm.dest_binop (Thm.dest_arg ct)
   493   in MetaEq (prove_eq_exn lookup cp handle MONO => mono lookup cp) end
   494 end
   495 
   496 
   497 
   498 (* |- f a b = f b a (where f is equality) *)
   499 local
         (* commutativity of HOL equality, stated as a meta equation *)
   500   val rule = @{lemma "a = b == b = a" by (atomize(full)) (rule eq_commute)}
   501 in
       (* instantiate the rule against ct (turned into a meta equation by
          T.as_meta_eq); the result is a MetaEq *)
   502 fun commutativity ct = MetaEq (T.match_instantiate I (T.as_meta_eq ct) rule)
   503 end
   504 
   505 
   506 
   507 (** quantifier proof rules **)
   508 
   509 (* P ?x = Q ?x ==> (ALL x. P x) = (ALL x. Q x)
   510    P ?x = Q ?x ==> (EX x. P x) = (EX x. Q x)    *)
   511 local
         (* congruence rules for the two quantifiers on meta equations *)
   512   val rules = [
   513     @{lemma "(!!x. P x == Q x) ==> (ALL x. P x) == (ALL x. Q x)" by simp},
   514     @{lemma "(!!x. P x == Q x) ==> (EX x. P x) == (EX x. Q x)" by simp}]
   515 in
       (* p is the premise about the quantifier bodies, ct the quantified
          equation.  The premise (after T.varify vars) is put in front of the
          congruence rules and match_tac is repeated on all new subgoals. *)
   516 fun quant_intro vars p ct =
   517   let
   518     val thm = meta_eq_of p
   519     val rules' = T.varify vars thm :: rules
   520     val cu = T.as_meta_eq ct
   521   in MetaEq (T.by_tac (REPEAT_ALL_NEW (Tactic.match_tac rules')) cu) end
   522 end
   523 
   524 
   525 
   526 (* |- ((ALL x. P x) | Q) = (ALL x. P x | Q) *)
       (* proved by the schematic rules (always tried first by try_apply) and
          otherwise by fast_tac with HOL_cs *)
   527 fun pull_quant ctxt = Thm o try_apply ctxt [] [
   528   named ctxt "fast" (T.by_tac (Classical.fast_tac HOL_cs))]
   529     (* FIXME: not very well tested *)
   530 
   531 
   532 
   533 (* |- (ALL x. P x & Q x) = ((ALL x. P x) & (ALL x. Q x)) *)
       (* proved by the schematic rules (always tried first by try_apply) and
          otherwise by fast_tac with HOL_cs *)
   534 fun push_quant ctxt = Thm o try_apply ctxt [] [
   535   named ctxt "fast" (T.by_tac (Classical.fast_tac HOL_cs))]
   536     (* FIXME: not very well tested *)
   537 
   538 
   539 
   540 (* |- (ALL x1 ... xn y1 ... yn. P x1 ... xn) = (ALL x1 ... xn. P x1 ... xn) *)
   541 local
         (* a quantifier whose body does not mention the bound variable *)
   542   val elim_all = @{lemma "(ALL x. P) == P" by simp}
   543   val elim_ex = @{lemma "(EX x. P) == P" by simp}
   544 
         (* rewrite with elim_all/elim_ex at the first argument of the goal's
            conclusion, below all goal parameters (params_conv with ~1) *)
   545   fun elim_unused_conv ctxt =
   546     Conv.params_conv ~1 (K (Conv.arg_conv (Conv.arg1_conv
   547       (Conv.rewrs_conv [elim_all, elim_ex])))) ctxt
   548 
         (* repeat on all new subgoals: finish with refl or step under a
            quantifier on both sides (iff_allI/iff_exI); if none of these
            rules applies, drop an unused quantifier on the left *)
   549   fun elim_unused_tac ctxt =
   550     REPEAT_ALL_NEW (
   551       Tactic.match_tac [@{thm refl}, @{thm iff_allI}, @{thm iff_exI}]
   552       ORELSE' CONVERSION (elim_unused_conv ctxt))
   553 in
   554 fun elim_unused_vars ctxt = Thm o T.by_tac (elim_unused_tac ctxt)
   555 end
   556 
   557 
   558 
   559 (* |- (ALL x1 ... xn. ~(x1 = t1 & ... xn = tn) | P x1 ... xn) = P t1 ... tn *)
       (* proved by the schematic rules (always tried first by try_apply) and
          otherwise by fast_tac with HOL_cs *)
   560 fun dest_eq_res ctxt = Thm o try_apply ctxt [] [
   561   named ctxt "fast" (T.by_tac (Classical.fast_tac HOL_cs))]
   562     (* FIXME: not very well tested *)
   563 
   564 
   565 
   566 (* |- ~(ALL x1...xn. P x1...xn) | P a1...an *)
   567 local
         (* strips one universal quantifier below the negation *)
   568   val rule = @{lemma "~ P x | Q ==> ~(ALL x. P x) | Q" by fast}
   569 in
       (* strip the quantifiers one by one with `rule`, then close the
          remaining goal with excluded_middle *)
   570 val quant_inst = Thm o T.by_tac (
   571   REPEAT_ALL_NEW (Tactic.match_tac [rule])
   572   THEN' Tactic.rtac @{thm excluded_middle})
   573 end
   574 
   575 
   576 
   577 (* c = SOME x. P x |- (EX x. P x) = P c
   578    c = SOME x. ~ P x |- ~(ALL x. P x) = ~ P c *)
   579 local
         (* elimination of a quantifier whose bound variable is unused *)
   580   val elim_ex = @{lemma "EX x. P == P" by simp}
   581   val elim_all = @{lemma "~ (ALL x. P) == ~P" by simp}
         (* the two Skolemization rules, based on Hilbert choice (SOME) *)
   582   val sk_ex = @{lemma "c == SOME x. P x ==> EX x. P x == P c"
   583     by simp (intro eq_reflection some_eq_ex[symmetric])}
   584   val sk_all = @{lemma "c == SOME x. ~ P x ==> ~(ALL x. P x) == ~ P c"
   585     by (simp only: not_all) (intro eq_reflection some_eq_ex[symmetric])}
         (* ((rule, destructor applied to its right-hand side), rule for an
            unused variable) *)
   586   val sk_ex_rule = ((sk_ex, I), elim_ex)
   587   and sk_all_rule = ((sk_all, Thm.dest_arg), elim_all)
   588 
         (* the application "P c" in the conclusion of a Skolemization rule,
            split into (P, c); f strips the negation in the ALL case *)
   589   fun dest f sk_rule = 
   590     Thm.dest_comb (f (Thm.dest_arg (Thm.dest_arg (Thm.cprop_of sk_rule))))
   591   fun type_of f sk_rule = Thm.ctyp_of_term (snd (dest f sk_rule))
   592   fun pair2 (a, b) (c, d) = [(a, c), (b, d)]
         (* instantiate first the type of c, then P and c themselves, and
            beta-normalise the result *)
   593   fun inst_sk (sk_rule, f) p c =
   594     Thm.instantiate ([(type_of f sk_rule, Thm.ctyp_of_term c)], []) sk_rule
   595     |> (fn sk' => Thm.instantiate ([], (pair2 (dest f sk') (p, c))) sk')
   596     |> Conv.fconv_rule (Thm.beta_conversion true)
   597 
         (* select rule, destructor and constructor by the shape of the
            left-hand side; other shapes raise TERM "skolemize" *)
   598   fun kind (Const (@{const_name Ex}, _) $ _) = (sk_ex_rule, I, I)
   599     | kind (@{const Not} $ (Const (@{const_name All}, _) $ _)) =
   600         (sk_all_rule, Thm.dest_arg, L.negate)
   601     | kind t = raise TERM ("skolemize", [t])
   602 
   603   fun dest_abs_type (Abs (_, T, _)) = T
   604     | dest_abs_type t = raise TERM ("dest_abs_type", [t])
   605 
         (* strip the outer quantifiers of lhs one at a time, replacing each
            bound variable by a schematic variable ("x", idx), until the
            remaining body matches rhs (first-order matching).  Returns the
            rule together with the term instantiation of the match and the
            list of (variable, quantified abstraction) pairs, outermost first. *)
   606   fun bodies_of thy lhs rhs =
   607     let
   608       val (rule, dest, make) = kind (Thm.term_of lhs)
   609 
   610       fun dest_body idx cbs ct =
   611         let
   612           val cb = Thm.dest_arg (dest ct)
   613           val T = dest_abs_type (Thm.term_of cb)
   614           val cv = Thm.cterm_of thy (Var (("x", idx), T))
   615           val cu = make (Drule.beta_conv cb cv)
   616           val cbs' = (cv, cb) :: cbs
   617         in
   618           (snd (Thm.first_order_match (cu, rhs)), rev cbs')
   619           handle Pattern.MATCH => dest_body (idx+1) cbs' cu
   620         end
   621     in (rule, dest_body 1 [] lhs) end
   622 
   623   fun transitive f thm = Thm.transitive thm (f (Thm.rhs_of thm))
   624 
         (* one quantifier: if the match assigned a term ct to its variable,
            the instantiated Skolemization rule is passed to T.make_hyp_def
            and chained to the equation built so far; otherwise the quantifier
            is removed with the rule for unused variables.  `is` collects the
            instantiations applied to the later abstractions. *)
   625   fun sk_step (rule, elim) (cv, mct, cb) ((is, thm), ctxt) =
   626     (case mct of
   627       SOME ct =>
   628         ctxt
   629         |> T.make_hyp_def (inst_sk rule (Thm.instantiate_cterm ([], is) cb) ct)
   630         |>> pair ((cv, ct) :: is) o Thm.transitive thm
   631     | NONE => ((is, transitive (Conv.rewr_conv elim) thm), ctxt))
   632 in
       (* ct is the equation between the quantified formula (lhs) and its
          Skolemized form (rhs); starting from reflexivity on lhs, one sk_step
          per quantifier extends the equation.  Returns a MetaEq and the
          context produced by the steps. *)
   633 fun skolemize ct ctxt =
   634   let
   635     val (lhs, rhs) = Thm.dest_binop (Thm.dest_arg ct)
   636     val (rule, (ctab, cbs)) = bodies_of (ProofContext.theory_of ctxt) lhs rhs
   637     fun lookup_var (cv, cb) = (cv, AList.lookup (op aconvc) ctab cv, cb)
   638   in
   639     (([], Thm.reflexive lhs), ctxt)
   640     |> fold (sk_step rule) (map lookup_var cbs)
   641     |>> MetaEq o snd
   642   end
   643 end
   644 
   645 
   646 
   647 (** theory proof rules **)
   648 
   649 (* theory lemmas: linear arithmetic, arrays *)
   650 
       (* thms are the premises of the step; they are handed to T.by_abstraction
          and shown in the error message.  After the schematic rules, the goal
          is tried with arith alone and then with simp followed by arith on
          all resulting subgoals. *)
   651 fun th_lemma ctxt simpset thms = Thm o try_apply ctxt thms [
   652   T.by_abstraction (false, true) ctxt thms (fn ctxt' => T.by_tac (
   653     NAMED ctxt' "arith" (Arith_Data.arith_tac ctxt')
   654     ORELSE' NAMED ctxt' "simp+arith" (Simplifier.simp_tac simpset THEN_ALL_NEW
   655       Arith_Data.arith_tac ctxt')))]
   656 
   657 
   658 
   659 (* rewriting: prove equalities:
   660      * ACI of conjunction/disjunction
   661      * contradiction, excluded middle
   662      * logical rewriting rules (for negation, implication, equivalence,
   663          distinct)
   664      * normal forms for polynomials (integer/real arithmetic)
   665      * quantifier elimination over linear arithmetic
   666      * ... ? **)
       (* named theorem collection "z3_simp"; its content is turned into the
          simpset used by the rewrite and th_lemma rules (see `prove` below) *)
   667 structure Z3_Simps = Named_Thms
   668 (
   669   val name = "z3_simp"
   670   val description = "simplification rules for Z3 proof reconstruction"
   671 )
   672 
   673 local
         (* remove outermost universal quantifiers with `spec` as long as that
            resolves, then turn the theorem into a meta equation *)
   674   fun spec_meta_eq_of thm =
   675     (case try (fn th => th RS @{thm spec}) thm of
   676       SOME thm' => spec_meta_eq_of thm'
   677     | NONE => mk_meta_eq thm)
   678 
   679   fun prep (Thm thm) = spec_meta_eq_of thm
   680     | prep (MetaEq thm) = thm
   681     | prep (Literals (thm, _)) = spec_meta_eq_of thm
   682 
         (* unfold the premise equations on both sides of the goal equation *)
   683   fun unfold_conv ctxt ths =
   684     Conv.arg_conv (Conv.binop_conv (T.unfold_eqs ctxt (map prep ths)))
   685 
         (* without premises the prover is used unchanged *)
   686   fun with_conv _ [] prv = prv
   687     | with_conv ctxt ths prv = T.with_conv (unfold_conv ctxt ths) prv
   688 
         (* NB: this value shadows the function unfold_conv above; with_conv
            has already captured the function, so only prove_conj_disj_eq
            below refers to this "distinct"-unfolding conversion *)
   689   val unfold_conv =
   690     Conv.arg_conv (Conv.binop_conv (Conv.try_conv T.unfold_distinct_conv))
   691   val prove_conj_disj_eq = T.with_conv unfold_conv L.prove_conj_disj_eq
   692 
         (* add the premises of thm as assumptions to the context and
            eliminate them from thm with the assumed theorems *)
   693   fun assume_prems ctxt thm =
   694     Assumption.add_assumes (Drule.cprems_of thm) ctxt
   695     |>> (fn thms => fold Thm.elim_implies thms thm)
   696 in
   697 
       (* ths are the premises (empty for Rewrite, non-empty for RewriteStar),
          ct the equation to prove.  Provers are tried in this order:
          schematic rules (added by try_apply), conj/disj/distinct reasoning,
          then simp combined with fast/arith under the three abstraction
          modes (true,false), (false,true) and (true,true), and finally
          M.prove_injectivity.  Premises left in the resulting theorem become
          assumptions of the returned context. *)
   698 fun rewrite simpset ths ct ctxt =
   699   apfst Thm (assume_prems ctxt (with_conv ctxt ths (try_apply ctxt [] [
   700     named ctxt "conj/disj/distinct" prove_conj_disj_eq,
   701     T.by_abstraction (true, false) ctxt [] (fn ctxt' => T.by_tac (
   702       NAMED ctxt' "simp (logic)" (Simplifier.simp_tac simpset)
   703       THEN_ALL_NEW NAMED ctxt' "fast (logic)" (Classical.fast_tac HOL_cs))),
   704     T.by_abstraction (false, true) ctxt [] (fn ctxt' => T.by_tac (
   705       NAMED ctxt' "simp (theory)" (Simplifier.simp_tac simpset)
   706       THEN_ALL_NEW (
   707         NAMED ctxt' "fast (theory)" (Classical.fast_tac HOL_cs)
   708         ORELSE' NAMED ctxt' "arith (theory)" (Arith_Data.arith_tac ctxt')))),
   709     T.by_abstraction (true, true) ctxt [] (fn ctxt' => T.by_tac (
   710       NAMED ctxt' "simp (full)" (Simplifier.simp_tac simpset)
   711       THEN_ALL_NEW (
   712         NAMED ctxt' "fast (full)" (Classical.fast_tac HOL_cs)
   713         ORELSE' NAMED ctxt' "arith (full)" (Arith_Data.arith_tac ctxt')))),
   714     named ctxt "injectivity" (M.prove_injectivity ctxt)]) ct))
   715 
   716 end
   717 
   718 
   719 
   720 (** proof reconstruction **)
   721 
   722 (* tracing and checking *)
   723 
   724 local
         (* (number of proved entries, number of all entries) of the table *)
   725   fun count_rules ptab =
   726     let
   727       fun count (_, Unproved _) (solved, total) = (solved, total + 1)
   728         | count (_, Proved _) (solved, total) = (solved + 1, total + 1)
   729     in Inttab.fold count ptab (0, 0) end
   730 
         (* trace line "Z3: #<idx>: <rule> (goal <solved+1> of <total>)" *)
   731   fun header idx r (solved, total) = 
   732     "Z3: #" ^ string_of_int idx ^ ": " ^ P.string_of_rule r ^ " (goal " ^
   733     string_of_int (solved + 1) ^ " of " ^ string_of_int total ^ ")"
   734 
         (* join the proofs of the theorem and compare its proposition with
            the expected ct (aconvc); a mismatch is reported via z3_exn
            together with the premises and the expected term *)
   735   fun check ctxt idx r ps ct p =
   736     let val thm = thm_of p |> tap (Thm.join_proofs o single)
   737     in
   738       if (Thm.cprop_of thm) aconvc ct then ()
   739       else z3_exn (Pretty.string_of (Pretty.big_list ("proof step failed: " ^
   740         quote (P.string_of_rule r) ^ " (#" ^ string_of_int idx ^ ")")
   741           (pretty_goal ctxt (map (thm_of o fst) ps) (Thm.prop_of thm) @
   742            [Pretty.block [Pretty.str "expected: ",
   743             Syntax.pretty_term ctxt (Thm.term_of ct)]])))
   744     end
   745 in
       (* wrapper around a single proof step: emits the header, runs
          `prove r ps ct cxp`, and -- only when SMT_Config.trace is set in the
          resulting context -- checks the outcome against ct *)
   746 fun trace_rule idx prove r ps ct (cxp as (ctxt, ptab)) =
   747   let
   748     val _ = SMT_Config.trace_msg ctxt (header idx r o count_rules) ptab
   749     val result as (p, (ctxt', _)) = prove r ps ct cxp
   750     val _ = if not (Config.get ctxt' SMT_Config.trace) then ()
   751       else check ctxt' idx r ps ct p
   752   in result end
   753 end
   754 
   755 
   756 (* overall reconstruction procedure *)
   757 
   758 local
   759   fun not_supported r =
           (* used for proof rules that have no replay implementation *)
   760     let val rule_name = quote (P.string_of_rule r)
           in raise Fail ("Z3: proof rule not implemented: " ^ rule_name) end
   761 
   762   fun step assms simpset vars r ps ct (cxp as (cx, ptab)) =
           (* replay one proof step: r is the rule, ps the premises as
              (theorem, proof id) pairs, ct the expected proposition, and
              cxp = (context, proof table).  Returns the theorem of the step
              with the possibly updated (context, proof table).  Rules that
              need a fixed number of premises only match premise lists of
              that length; anything else ends in the final catch-all case. *)
   763     (case (r, ps) of
   764       (* core rules *)
   765       (P.TrueAxiom, _) => (Thm L.true_thm, cxp)
   766     | (P.Asserted, _) => (asserted cx assms ct, cxp)
   767     | (P.Goal, _) => (asserted cx assms ct, cxp)
   768     | (P.ModusPonens, [(p, _), (q, _)]) => (mp q (thm_of p), cxp)
   769     | (P.ModusPonensOeq, [(p, _), (q, _)]) => (mp q (thm_of p), cxp)
   770     | (P.AndElim, [(p, i)]) => and_elim (p, i) ct ptab ||> pair cx
   771     | (P.NotOrElim, [(p, i)]) => not_or_elim (p, i) ct ptab ||> pair cx
   772     | (P.Hypothesis, _) => (Thm (Thm.assume ct), cxp)
   773     | (P.Lemma, [(p, _)]) => (lemma (thm_of p) ct, cxp)
   774     | (P.UnitResolution, (p, _) :: ps) =>
   775         (unit_resolution (thm_of p) (map (thm_of o fst) ps) ct, cxp)
   776     | (P.IffTrue, [(p, _)]) => (iff_true (thm_of p), cxp)
   777     | (P.IffFalse, [(p, _)]) => (iff_false (thm_of p), cxp)
   778     | (P.Distributivity, _) => (distributivity cx ct, cxp)
   779     | (P.DefAxiom, _) => (def_axiom cx ct, cxp)
   780     | (P.IntroDef, _) => intro_def ct cx ||> rpair ptab
   781     | (P.ApplyDef, [(p, _)]) => (apply_def (thm_of p), cxp)
   782     | (P.IffOeq, [(p, _)]) => (p, cxp)
   783     | (P.NnfPos, _) => (nnf cx vars (map fst ps) ct, cxp)
   784     | (P.NnfNeg, _) => (nnf cx vars (map fst ps) ct, cxp)
   785 
   786       (* equality rules *)
   787     | (P.Reflexivity, _) => (refl ct, cxp)
   788     | (P.Symmetry, [(p, _)]) => (symm p, cxp)
   789     | (P.Transitivity, [(p, _), (q, _)]) => (trans p q, cxp)
   790     | (P.Monotonicity, _) => (monotonicity (map fst ps) ct, cxp)
   791     | (P.Commutativity, _) => (commutativity ct, cxp)
   792 
   793       (* quantifier rules *)
   794     | (P.QuantIntro, [(p, _)]) => (quant_intro vars p ct, cxp)
   795     | (P.PullQuant, _) => (pull_quant cx ct, cxp)
   796     | (P.PushQuant, _) => (push_quant cx ct, cxp)
   797     | (P.ElimUnusedVars, _) => (elim_unused_vars cx ct, cxp)
   798     | (P.DestEqRes, _) => (dest_eq_res cx ct, cxp)
   799     | (P.QuantInst, _) => (quant_inst ct, cxp)
   800     | (P.Skolemize, _) => skolemize ct cx ||> rpair ptab
   801 
   802       (* theory rules *)
   803     | (P.ThLemma _, _) =>  (* FIXME: use arguments *)
   804         (th_lemma cx simpset (map (thm_of o fst) ps) ct, cxp)
   805     | (P.Rewrite, _) => rewrite simpset [] ct cx ||> rpair ptab
   806     | (P.RewriteStar, ps) => rewrite simpset (map fst ps) ct cx ||> rpair ptab
   807 
           (* rules without a replay implementation *)
   808     | (P.NnfStar, _) => not_supported r
   809     | (P.CnfStar, _) => not_supported r
   810     | (P.TransitivityStar, _) => not_supported r
   811     | (P.PullQuantStar, _) => not_supported r
   812 
   813     | _ => raise Fail ("Z3: proof rule " ^ quote (P.string_of_rule r) ^
   814        " has an unexpected number of arguments."))
   815 
   816   fun prove ctxt assms vars =
           (* prove ctxt assms vars idx ptab: replay the step with id idx and,
              recursively, all steps it depends on.  The parsed table is first
              mapped to Unproved entries; every replayed step is stored as
              Proved, so a step shared by several others is replayed once. *)
   817     let
   818       val simpset = T.make_simpset ctxt (Z3_Simps.get ctxt)
   819  
             (* replay step idx from its proved premises and record the result *)
   820       fun conclude idx rule prop (ps, cxp) =
   821         trace_rule idx (step assms simpset vars) rule ps prop cxp
   822         |-> (fn p => apsnd (Inttab.update (idx, Proved p)) #> pair p)
   823  
             (* theorem for idx: premises first (each paired with its id),
                then the step itself; an id missing from the table is
                reported via z3_exn *)
   824       fun lookup idx (cxp as (_, ptab)) =
   825         (case Inttab.lookup ptab idx of
   826           SOME (Unproved (P.Proof_Step {rule, prems, prop})) =>
   827             fold_map lookup prems cxp
   828             |>> map2 rpair prems
   829             |> conclude idx rule prop
   830         | SOME (Proved p) => (p, cxp)
   831         | NONE => z3_exn ("unknown proof id: " ^ quote (string_of_int idx)))
   832  
   833       fun result (p, (cx, _)) = (thm_of p, cx)
   834     in
   835       (fn idx => result o lookup idx o pair ctxt o Inttab.map (K Unproved))
   836     end
   837 
   838   fun filter_assms ctxt assms ptab =
           (* filter_assms ctxt assms ptab idx ids: walk the proof from step
              idx down to its leaves and add to ids (without duplicates) the
              identifier of every assumption used by an Asserted or Goal
              step; no theorem is proved.
              NOTE(review): visited steps are not recorded here, so a step
              reachable along several paths appears to be traversed once per
              path -- confirm whether this matters for large proofs. *)
   839     let
   840       fun step r ct =
   841         (case r of
   842           P.Asserted => insert (op =) (find_assm ctxt assms ct)
   843         | P.Goal => insert (op =) (find_assm ctxt assms ct)
   844         | _ => I)
   845 
   846       fun lookup idx =
   847         (case Inttab.lookup ptab idx of
   848           SOME (P.Proof_Step {rule, prems, prop}) =>
   849             fold lookup prems #> step rule prop
   850         | NONE => z3_exn ("unknown proof id: " ^ quote (string_of_int idx)))
   851     in lookup end
   852 in
   853 
   854 fun reconstruct ctxt {typs, terms, unfolds, assms} output =
         (* parse the Z3 output into a proof table (P.parse also returns the
            id of the final step, the variables and a new context) and prepare
            the assumptions.  With SMT_Config.filter_only_facts set, only the
            identifiers of the assumptions used by the proof are computed and
            the theorem is the dummy TrueI; otherwise the proof is replayed
            and the list of identifiers is empty. *)
   855   let
   856     val (idx, (ptab, vars, cx)) = P.parse ctxt typs terms output
   857     val assms' = prepare_assms cx unfolds assms
   858   in
   859     if Config.get cx SMT_Config.filter_only_facts
   860     then ((filter_assms cx assms' ptab idx [], @{thm TrueI}), cx)
   861     else apfst (pair []) (prove cx assms' vars idx ptab)
   862   end
   863 
   864 end
   865 
   866 val setup = Z3_Simps.setup o z3_rules_setup
       (* i.e. the z3_rule declarations are installed first, then z3_simp *)
   867 
   868 end