src/HOL/Tools/SMT/z3_proof_literals.ML
author haftmann
Sat Aug 28 16:14:32 2010 +0200 (2010-08-28)
changeset 38864 4abe644fcea5
parent 38795 848be46708dc
child 40579 98ebd2300823
permissions -rw-r--r--
formerly unnamed infix equality now named HOL.eq
     1 (*  Title:      HOL/Tools/SMT/z3_proof_literals.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Proof tools related to conjunctions and disjunctions.
     5 *)
     6 
     7 signature Z3_PROOF_LITERALS =
     8 sig
     9   (* literal table *)
    10   type littab = thm Termtab.table
    11   val make_littab: thm list -> littab
    12   val insert_lit: thm -> littab -> littab
    13   val delete_lit: thm -> littab -> littab
    14   val lookup_lit: littab -> term -> thm option
    15   val get_first_lit: (term -> bool) -> littab -> thm option
    16 
    17   (* rules *)
    18   val true_thm: thm
    19   val rewrite_true: thm
    20 
    21   (* properties *)
    22   val is_conj: term -> bool
    23   val is_disj: term -> bool
    24   val exists_lit: bool -> (term -> bool) -> term -> bool
    25 
    26   (* proof tools *)
    27   val explode: bool -> bool -> bool -> term list -> thm -> thm list
    28   val join: bool -> littab -> term -> thm
    29   val prove_conj_disj_eq: cterm -> thm
    30 end
    31 
    32 structure Z3_Proof_Literals: Z3_PROOF_LITERALS =
    33 struct
    34 
    35 structure T = Z3_Proof_Tools
    36 
    37 
    38 
    39 (** literal table **)
    40 
    41 type littab = thm Termtab.table
    42 
    43 fun make_littab thms = fold (Termtab.update o `T.prop_of) thms Termtab.empty
    44 
    45 fun insert_lit thm = Termtab.update (`T.prop_of thm)
    46 fun delete_lit thm = Termtab.delete (T.prop_of thm)
    47 fun lookup_lit lits = Termtab.lookup lits
    48 fun get_first_lit f =
    49   Termtab.get_first (fn (t, thm) => if f t then SOME thm else NONE)
    50 
    51 
    52 
    53 (** rules **)
    54 
    55 val true_thm = @{lemma "~False" by simp}
    56 val rewrite_true = @{lemma "True == ~ False" by simp}
    57 
    58 
    59 
    60 (** properties and term operations **)
    61 
    62 val is_neg = (fn @{term Not} $ _ => true | _ => false)
    63 fun is_neg' f = (fn @{term Not} $ t => f t | _ => false)
    64 val is_dneg = is_neg' is_neg
    65 val is_conj = (fn @{term HOL.conj} $ _ $ _ => true | _ => false)
    66 val is_disj = (fn @{term HOL.disj} $ _ $ _ => true | _ => false)
    67 
    68 fun dest_disj_term' f = (fn
    69     @{term Not} $ (@{term HOL.disj} $ t $ u) => SOME (f t, f u)
    70   | _ => NONE)
    71 
    72 val dest_conj_term = (fn @{term HOL.conj} $ t $ u => SOME (t, u) | _ => NONE)
    73 val dest_disj_term =
    74   dest_disj_term' (fn @{term Not} $ t => t | t => @{term Not} $ t)
    75 
    76 fun exists_lit is_conj P =
    77   let
    78     val dest = if is_conj then dest_conj_term else dest_disj_term
    79     fun exists t = P t orelse
    80       (case dest t of
    81         SOME (t1, t2) => exists t1 orelse exists t2
    82       | NONE => false)
    83   in exists end
    84 
    85 
    86 
    87 (** proof tools **)
    88 
    89 (* explosion of conjunctions and disjunctions *)
    90 
    91 local
    92   fun destc ct = Thm.dest_binop (Thm.dest_arg ct)
    93   val dest_conj1 = T.precompose2 destc @{thm conjunct1}
    94   val dest_conj2 = T.precompose2 destc @{thm conjunct2}
    95   fun dest_conj_rules t =
    96     dest_conj_term t |> Option.map (K (dest_conj1, dest_conj2))
    97     
    98   fun destd f ct = f (Thm.dest_binop (Thm.dest_arg (Thm.dest_arg ct)))
    99   val dn1 = apfst Thm.dest_arg and dn2 = apsnd Thm.dest_arg
   100   val dest_disj1 = T.precompose2 (destd I) @{lemma "~(P | Q) ==> ~P" by fast}
   101   val dest_disj2 = T.precompose2 (destd dn1) @{lemma "~(~P | Q) ==> P" by fast}
   102   val dest_disj3 = T.precompose2 (destd I) @{lemma "~(P | Q) ==> ~Q" by fast}
   103   val dest_disj4 = T.precompose2 (destd dn2) @{lemma "~(P | ~Q) ==> Q" by fast}
   104 
   105   fun dest_disj_rules t =
   106     (case dest_disj_term' is_neg t of
   107       SOME (true, true) => SOME (dest_disj2, dest_disj4)
   108     | SOME (true, false) => SOME (dest_disj2, dest_disj3)
   109     | SOME (false, true) => SOME (dest_disj1, dest_disj4)
   110     | SOME (false, false) => SOME (dest_disj1, dest_disj3)
   111     | NONE => NONE)
   112 
   113   fun destn ct = [Thm.dest_arg (Thm.dest_arg (Thm.dest_arg ct))]
   114   val dneg_rule = T.precompose destn @{thm notnotD}
   115 in
   116 
   117 (* explode a term into literals and collect all rules to be able to deduce
   118    particular literals afterwards *)
   119 fun explode_term is_conj =
   120   let
   121     val dest = if is_conj then dest_conj_term else dest_disj_term
   122     val dest_rules = if is_conj then dest_conj_rules else dest_disj_rules
   123 
   124     fun add (t, rs) = Termtab.map_default (t, rs)
   125       (fn rs' => if length rs' < length rs then rs' else rs)
   126 
   127     fun explode1 rules t =
   128       (case dest t of
   129         SOME (t1, t2) =>
   130           let val (rule1, rule2) = the (dest_rules t)
   131           in
   132             explode1 (rule1 :: rules) t1 #>
   133             explode1 (rule2 :: rules) t2 #>
   134             add (t, rev rules)
   135           end
   136       | NONE => add (t, rev rules))
   137 
   138     fun explode0 (@{term Not} $ (@{term Not} $ t)) =
   139           Termtab.make [(t, [dneg_rule])]
   140       | explode0 t = explode1 [] t Termtab.empty
   141 
   142   in explode0 end
   143 
   144 (* extract a literal by applying previously collected rules *)
   145 fun extract_lit thm rules = fold T.compose rules thm
   146 
   147 
   148 (* explode a theorem into its literals *)
   149 fun explode is_conj full keep_intermediate stop_lits =
   150   let
   151     val dest_rules = if is_conj then dest_conj_rules else dest_disj_rules
   152     val tab = fold (Termtab.update o rpair ()) stop_lits Termtab.empty
   153 
   154     fun explode1 thm =
   155       if Termtab.defined tab (T.prop_of thm) then cons thm
   156       else
   157         (case dest_rules (T.prop_of thm) of
   158           SOME (rule1, rule2) =>
   159             explode2 rule1 thm #>
   160             explode2 rule2 thm #>
   161             keep_intermediate ? cons thm
   162         | NONE => cons thm)
   163 
   164     and explode2 dest_rule thm =
   165       if full orelse exists_lit is_conj (Termtab.defined tab) (T.prop_of thm)
   166       then explode1 (T.compose dest_rule thm)
   167       else cons (T.compose dest_rule thm)
   168 
   169     fun explode0 thm =
   170       if not is_conj andalso is_dneg (T.prop_of thm)
   171       then [T.compose dneg_rule thm]
   172       else explode1 thm []
   173 
   174   in explode0 end
   175 
   176 end
   177 
   178 
   179 
   180 (* joining of literals to conjunctions or disjunctions *)
   181 
   182 local
   183   fun on_cprem i f thm = f (Thm.cprem_of thm i)
   184   fun on_cprop f thm = f (Thm.cprop_of thm)
   185   fun precomp2 f g thm = (on_cprem 1 f thm, on_cprem 2 g thm, f, g, thm)
   186   fun comp2 (cv1, cv2, f, g, rule) thm1 thm2 =
   187     Thm.instantiate ([], [(cv1, on_cprop f thm1), (cv2, on_cprop g thm2)]) rule
   188     |> T.discharge thm1 |> T.discharge thm2
   189 
   190   fun d1 ct = Thm.dest_arg ct and d2 ct = Thm.dest_arg (Thm.dest_arg ct)
   191 
   192   val conj_rule = precomp2 d1 d1 @{thm conjI}
   193   fun comp_conj ((_, thm1), (_, thm2)) = comp2 conj_rule thm1 thm2
   194 
   195   val disj1 = precomp2 d2 d2 @{lemma "~P ==> ~Q ==> ~(P | Q)" by fast}
   196   val disj2 = precomp2 d2 d1 @{lemma "~P ==> Q ==> ~(P | ~Q)" by fast}
   197   val disj3 = precomp2 d1 d2 @{lemma "P ==> ~Q ==> ~(~P | Q)" by fast}
   198   val disj4 = precomp2 d1 d1 @{lemma "P ==> Q ==> ~(~P | ~Q)" by fast}
   199 
   200   fun comp_disj ((false, thm1), (false, thm2)) = comp2 disj1 thm1 thm2
   201     | comp_disj ((false, thm1), (true, thm2)) = comp2 disj2 thm1 thm2
   202     | comp_disj ((true, thm1), (false, thm2)) = comp2 disj3 thm1 thm2
   203     | comp_disj ((true, thm1), (true, thm2)) = comp2 disj4 thm1 thm2
   204 
   205   fun dest_conj (@{term HOL.conj} $ t $ u) = ((false, t), (false, u))
   206     | dest_conj t = raise TERM ("dest_conj", [t])
   207 
   208   val neg = (fn @{term Not} $ t => (true, t) | t => (false, @{term Not} $ t))
   209   fun dest_disj (@{term Not} $ (@{term HOL.disj} $ t $ u)) = (neg t, neg u)
   210     | dest_disj t = raise TERM ("dest_disj", [t])
   211 
   212   val dnegE = T.precompose (single o d2 o d1) @{thm notnotD}
   213   val dnegI = T.precompose (single o d1) @{lemma "P ==> ~~P" by fast}
   214   fun as_dneg f t = f (@{term Not} $ (@{term Not} $ t))
   215 
   216   fun dni f = apsnd f o Thm.dest_binop o f o d1
   217   val negIffE = T.precompose2 (dni d1) @{lemma "~(P = (~Q)) ==> Q = P" by fast}
   218   val negIffI = T.precompose2 (dni I) @{lemma "P = Q ==> ~(Q = (~P))" by fast}
   219   val iff_const = @{term "op = :: bool => _"}
   220   fun as_negIff f (@{term "op = :: bool => _"} $ t $ u) =
   221         f (@{term Not} $ (iff_const $ u $ (@{term Not} $ t)))
   222     | as_negIff _ _ = NONE
   223 in
   224 
   225 fun join is_conj littab t =
   226   let
   227     val comp = if is_conj then comp_conj else comp_disj
   228     val dest = if is_conj then dest_conj else dest_disj
   229 
   230     val lookup = lookup_lit littab
   231 
   232     fun lookup_rule t =
   233       (case t of
   234         @{term Not} $ (@{term Not} $ t) => (T.compose dnegI, lookup t)
   235       | @{term Not} $ (@{term "op = :: bool => _"} $ t $ (@{term Not} $ u)) =>
   236           (T.compose negIffI, lookup (iff_const $ u $ t))
   237       | @{term Not} $ ((eq as Const (@{const_name HOL.eq}, _)) $ t $ u) =>
   238           let fun rewr lit = lit COMP @{thm not_sym}
   239           in (rewr, lookup (@{term Not} $ (eq $ u $ t))) end
   240       | _ =>
   241           (case as_dneg lookup t of
   242             NONE => (T.compose negIffE, as_negIff lookup t)
   243           | x => (T.compose dnegE, x)))
   244 
   245     fun join1 (s, t) =
   246       (case lookup t of
   247         SOME lit => (s, lit)
   248       | NONE => 
   249           (case lookup_rule t of
   250             (rewrite, SOME lit) => (s, rewrite lit)
   251           | (_, NONE) => (s, comp (pairself join1 (dest t)))))
   252 
   253   in snd (join1 (if is_conj then (false, t) else (true, t))) end
   254 
   255 end
   256 
   257 
   258 
   259 (* proving equality of conjunctions or disjunctions *)
   260 
   261 fun iff_intro thm1 thm2 = thm2 COMP (thm1 COMP @{thm iffI})
   262 
   263 local
   264   val cp1 = @{lemma "(~P) = (~Q) ==> P = Q" by simp}
   265   val cp2 = @{lemma "(~P) = Q ==> P = (~Q)" by fastsimp}
   266   val cp3 = @{lemma "P = (~Q) ==> (~P) = Q" by simp}
   267   val neg = Thm.capply @{cterm Not}
   268 in
   269 fun contrapos1 prove (ct, cu) = prove (neg ct, neg cu) COMP cp1
   270 fun contrapos2 prove (ct, cu) = prove (neg ct, Thm.dest_arg cu) COMP cp2
   271 fun contrapos3 prove (ct, cu) = prove (Thm.dest_arg ct, neg cu) COMP cp3
   272 end
   273 
   274 
   275 local
   276   val contra_rule = @{lemma "P ==> ~P ==> False" by (rule notE)}
   277   fun contra_left conj thm =
   278     let
   279       val rules = explode_term conj (T.prop_of thm)
   280       fun contra_lits (t, rs) =
   281         (case t of
   282           @{term Not} $ u => Termtab.lookup rules u |> Option.map (pair rs)
   283         | _ => NONE)
   284     in
   285       (case Termtab.lookup rules @{term False} of
   286         SOME rs => extract_lit thm rs
   287       | NONE =>
   288           the (Termtab.get_first contra_lits rules)
   289           |> pairself (extract_lit thm)
   290           |> (fn (nlit, plit) => nlit COMP (plit COMP contra_rule)))
   291     end
   292 
   293   val falseE_v = Thm.dest_arg (Thm.dest_arg (Thm.cprop_of @{thm FalseE}))
   294   fun contra_right ct = Thm.instantiate ([], [(falseE_v, ct)]) @{thm FalseE}
   295 in
   296 fun contradict conj ct =
   297   iff_intro (T.under_assumption (contra_left conj) ct) (contra_right ct)
   298 end
   299 
   300 
   301 local
   302   fun prove_eq l r (cl, cr) =
   303     let
   304       fun explode' is_conj = explode is_conj true (l <> r) []
   305       fun make_tab is_conj thm = make_littab (true_thm :: explode' is_conj thm)
   306       fun prove is_conj ct tab = join is_conj tab (Thm.term_of ct)
   307 
   308       val thm1 = T.under_assumption (prove r cr o make_tab l) cl
   309       val thm2 = T.under_assumption (prove l cl o make_tab r) cr
   310     in iff_intro thm1 thm2 end
   311 
   312   datatype conj_disj = CONJ | DISJ | NCON | NDIS
   313   fun kind_of t =
   314     if is_conj t then SOME CONJ
   315     else if is_disj t then SOME DISJ
   316     else if is_neg' is_conj t then SOME NCON
   317     else if is_neg' is_disj t then SOME NDIS
   318     else NONE
   319 in
   320 
   321 fun prove_conj_disj_eq ct =
   322   let val cp as (cl, cr) = Thm.dest_binop (Thm.dest_arg ct)
   323   in
   324     (case (kind_of (Thm.term_of cl), Thm.term_of cr) of
   325       (SOME CONJ, @{term False}) => contradict true cl
   326     | (SOME DISJ, @{term "~False"}) => contrapos2 (contradict false o fst) cp
   327     | (kl, _) =>
   328         (case (kl, kind_of (Thm.term_of cr)) of
   329           (SOME CONJ, SOME CONJ) => prove_eq true true cp
   330         | (SOME CONJ, SOME NDIS) => prove_eq true false cp
   331         | (SOME CONJ, _) => prove_eq true true cp
   332         | (SOME DISJ, SOME DISJ) => contrapos1 (prove_eq false false) cp
   333         | (SOME DISJ, SOME NCON) => contrapos2 (prove_eq false true) cp
   334         | (SOME DISJ, _) => contrapos1 (prove_eq false false) cp
   335         | (SOME NCON, SOME NCON) => contrapos1 (prove_eq true true) cp
   336         | (SOME NCON, SOME DISJ) => contrapos3 (prove_eq true false) cp
   337         | (SOME NCON, NONE) => contrapos3 (prove_eq true false) cp
   338         | (SOME NDIS, SOME NDIS) => prove_eq false false cp
   339         | (SOME NDIS, SOME CONJ) => prove_eq false true cp
   340         | (SOME NDIS, NONE) => prove_eq false true cp
   341         | _ => raise CTERM ("prove_conj_disj_eq", [ct])))
   342   end
   343 
   344 end
   345 
   346 end