src/HOL/Tools/SMT/z3_replay_literals.ML
changeset 59381 de4218223e00
parent 59379 c7f6f01ede15
child 59382 a78e71fcd146
equal deleted inserted replaced
59379:c7f6f01ede15 59381:de4218223e00
     1 (*  Title:      HOL/Tools/SMT/z3_replay_literals.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3 
       
     4 Proof tools related to conjunctions and disjunctions.
       
     5 *)
       
     6 
       
signature Z3_REPLAY_LITERALS =
sig
  (*literal table*)
  (*theorems indexed by their proposition (SMT_Util.prop_of)*)
  type littab = thm Termtab.table
  val make_littab: thm list -> littab
  val insert_lit: thm -> littab -> littab
  val delete_lit: thm -> littab -> littab
  (*theorem whose proposition is exactly the given term*)
  val lookup_lit: littab -> term -> thm option
  (*first theorem whose proposition satisfies the predicate*)
  val get_first_lit: (term -> bool) -> littab -> thm option

  (*rules*)
  val true_thm: thm      (*~False*)
  val rewrite_true: thm  (*True == ~ False*)

  (*properties*)
  val is_conj: term -> bool  (*term has the shape t & u*)
  val is_disj: term -> bool  (*term has the shape t | u*)
  (*exists_lit is_conj P t: P holds for t or for some literal reached by
    recursively decomposing conjunctions (is_conj = true) or negated
    disjunctions (is_conj = false)*)
  val exists_lit: bool -> (term -> bool) -> term -> bool
  (*negation of a certified term: ct becomes ~ct*)
  val negate: cterm -> cterm

  (*proof tools*)
  (*explode is_conj full keep_intermediate stop_lits thm: split thm into
    its literals (conjunction or negated-disjunction mode)*)
  val explode: bool -> bool -> bool -> term list -> thm -> thm list
  (*join is_conj littab t: prove the conjunction resp. negated
    disjunction t from the literals stored in littab*)
  val join: bool -> littab -> term -> thm
  (*prove an equation whose sides are (possibly negated) conjunctions or
    disjunctions; raises CTERM for unsupported shapes*)
  val prove_conj_disj_eq: cterm -> thm
end;
       
    32 
       
structure Z3_Replay_Literals: Z3_REPLAY_LITERALS =
struct

(* literal table *)

(*
  A literal table indexes theorems by their proposition, as computed by
  SMT_Util.prop_of, so that a literal can be retrieved by its term.
*)
type littab = thm Termtab.table

(*build a table from thms; since Termtab.update overwrites existing
  entries, a later theorem replaces an earlier one with the same
  proposition*)
fun make_littab thms = fold (Termtab.update o `SMT_Util.prop_of) thms Termtab.empty

(*add thm under its proposition, replacing any previous entry*)
fun insert_lit thm = Termtab.update (`SMT_Util.prop_of thm)
(*remove the entry keyed by the proposition of thm*)
fun delete_lit thm = Termtab.delete (SMT_Util.prop_of thm)
(*theorem whose proposition is exactly the given term, if any*)
fun lookup_lit lits = Termtab.lookup lits
(*first theorem, in table traversal order, whose key satisfies f*)
fun get_first_lit f =
  Termtab.get_first (fn (t, thm) => if f t then SOME thm else NONE)


(* rules *)

(*~False: a literal that is always available; prove_eq below adds it to
  every literal table it builds*)
val true_thm = @{lemma "~False" by simp}
(*meta-equation replacing True by ~False*)
val rewrite_true = @{lemma "True == ~ False" by simp}


(* properties and term operations *)

(*syntactic shape tests on HOL terms*)
val is_neg = (fn @{const Not} $ _ => true | _ => false)    (*~t*)
fun is_neg' f = (fn @{const Not} $ t => f t | _ => false)  (*~t where f t*)
val is_dneg = is_neg' is_neg                              (*~~t*)
val is_conj = (fn @{const HOL.conj} $ _ $ _ => true | _ => false)  (*t & u*)
val is_disj = (fn @{const HOL.disj} $ _ $ _ => true | _ => false)  (*t | u*)

(*components of a negated disjunction ~(t | u), each mapped through f;
  NONE for any other term*)
fun dest_disj_term' f = (fn
    @{const Not} $ (@{const HOL.disj} $ t $ u) => SOME (f t, f u)
  | _ => NONE)

(*components of a conjunction t & u; NONE for any other term*)
val dest_conj_term = (fn @{const HOL.conj} $ t $ u => SOME (t, u) | _ => NONE)
(*literals implied by ~(t | u), i.e. the negations of t and u, where a
  disjunct of the form ~v yields v rather than ~~v*)
val dest_disj_term =
  dest_disj_term' (fn @{const Not} $ t => t | t => @{const Not} $ t)

(*
  Does P hold for the term itself or for any literal obtained by
  repeatedly decomposing conjunctions (is_conj = true) or negated
  disjunctions (is_conj = false)?  Note: the parameter is_conj shadows
  the function of the same name inside this definition.
*)
fun exists_lit is_conj P =
  let
    val dest = if is_conj then dest_conj_term else dest_disj_term
    fun exists t = P t orelse
      (case dest t of
        SOME (t1, t2) => exists t1 orelse exists t2
      | NONE => false)
  in exists end

(*negation on certified terms: ct becomes ~ct*)
val negate = Thm.apply (Thm.cterm_of @{theory} @{const Not})


(* proof tools *)

(** explosion of conjunctions and disjunctions **)

local
  val precomp = Z3_Replay_Util.precompose2

  (*Elimination rules.  destc, destd and destn select cterms from a rule;
    presumably Z3_Replay_Util.precompose(2) applies them to the rule's
    premise to locate its schematic variables -- TODO confirm.*)

  (*conjunction elimination: A & B ==> A and A & B ==> B*)
  fun destc ct = Thm.dest_binop (Thm.dest_arg ct)
  val dest_conj1 = precomp destc @{thm conjunct1}
  val dest_conj2 = precomp destc @{thm conjunct2}
  fun dest_conj_rules t =
    dest_conj_term t |> Option.map (K (dest_conj1, dest_conj2))

  (*negated-disjunction elimination; dn1/dn2 additionally strip a
    negation from the left/right part, matching the lemmas whose disjunct
    has the form ~P resp. ~Q*)
  fun destd f ct = f (Thm.dest_binop (Thm.dest_arg (Thm.dest_arg ct)))
  val dn1 = apfst Thm.dest_arg and dn2 = apsnd Thm.dest_arg
  val dest_disj1 = precomp (destd I) @{lemma "~(P | Q) ==> ~P" by fast}
  val dest_disj2 = precomp (destd dn1) @{lemma "~(~P | Q) ==> P" by fast}
  val dest_disj3 = precomp (destd I) @{lemma "~(P | Q) ==> ~Q" by fast}
  val dest_disj4 = precomp (destd dn2) @{lemma "~(P | ~Q) ==> Q" by fast}

  (*for ~(t | u), choose the left and right elimination rule according to
    whether t resp. u is itself a negation*)
  fun dest_disj_rules t =
    (case dest_disj_term' is_neg t of
      SOME (true, true) => SOME (dest_disj2, dest_disj4)
    | SOME (true, false) => SOME (dest_disj2, dest_disj3)
    | SOME (false, true) => SOME (dest_disj1, dest_disj4)
    | SOME (false, false) => SOME (dest_disj1, dest_disj3)
    | NONE => NONE)

  (*double-negation elimination ~~P ==> P*)
  fun destn ct = [Thm.dest_arg (Thm.dest_arg (Thm.dest_arg ct))]
  val dneg_rule = Z3_Replay_Util.precompose destn @{thm notnotD}
in

(*
  explode a term into literals and collect all rules to be able to deduce
  particular literals afterwards

  The result maps every sub-term reached during decomposition (the term
  itself and intermediate conjunctions/negated disjunctions included) to
  the list of elimination rules, in application order, that derive it from
  the original term.  If a sub-term is reachable in several ways, the
  shorter rule list is kept.  A top-level double negation ~~t is not
  decomposed: only t is recorded, derived by double-negation elimination.
*)
fun explode_term is_conj =
  let
    val dest = if is_conj then dest_conj_term else dest_disj_term
    val dest_rules = if is_conj then dest_conj_rules else dest_disj_rules

    (*keep the shorter derivation (on a tie, the new one)*)
    fun add (t, rs) = Termtab.map_default (t, rs)
      (fn rs' => if length rs' < length rs then rs' else rs)

    (*rules are accumulated innermost-first, hence the rev*)
    fun explode1 rules t =
      (case dest t of
        SOME (t1, t2) =>
          (*dest and dest_rules accept the same shapes, so the is safe*)
          let val (rule1, rule2) = the (dest_rules t)
          in
            explode1 (rule1 :: rules) t1 #>
            explode1 (rule2 :: rules) t2 #>
            add (t, rev rules)
          end
      | NONE => add (t, rev rules))

    fun explode0 (@{const Not} $ (@{const Not} $ t)) =
          Termtab.make [(t, [dneg_rule])]
      | explode0 t = explode1 [] t Termtab.empty

  in explode0 end

(*
  extract a literal by applying previously collected rules
  (in list order) to thm
*)
fun extract_lit thm rules = fold Z3_Replay_Util.compose rules thm


(*
  explode a theorem into its literals

  is_conj selects decomposition of conjunctions (true) or of negated
  disjunctions (false).  A theorem whose proposition occurs in stop_lits
  is returned as is.  Below the top level, a component is decomposed
  further only if full is set or some stop literal occurs inside the
  theorem being decomposed (checked via exists_lit); otherwise it is
  returned undecomposed.  With keep_intermediate, decomposed intermediate
  theorems are returned as well.  In negated-disjunction mode, a theorem
  ~~P yields just P.
*)
fun explode is_conj full keep_intermediate stop_lits =
  let
    val dest_rules = if is_conj then dest_conj_rules else dest_disj_rules
    (*stop_lits as a set*)
    val tab = fold (Termtab.update o rpair ()) stop_lits Termtab.empty

    fun explode1 thm =
      if Termtab.defined tab (SMT_Util.prop_of thm) then cons thm
      else
        (case dest_rules (SMT_Util.prop_of thm) of
          SOME (rule1, rule2) =>
            explode2 rule1 thm #>
            explode2 rule2 thm #>
            keep_intermediate ? cons thm
        | NONE => cons thm)

    (*note: the stop-literal test inspects the parent thm, not the
      component produced by dest_rule*)
    and explode2 dest_rule thm =
      if full orelse
        exists_lit is_conj (Termtab.defined tab) (SMT_Util.prop_of thm)
      then explode1 (Z3_Replay_Util.compose dest_rule thm)
      else cons (Z3_Replay_Util.compose dest_rule thm)

    fun explode0 thm =
      if not is_conj andalso is_dneg (SMT_Util.prop_of thm)
      then [Z3_Replay_Util.compose dneg_rule thm]
      else explode1 thm []

  in explode0 end

end


(** joining of literals to conjunctions or disjunctions **)

local
  (*apply f to the i-th premise resp. to the proposition of thm*)
  fun on_cprem i f thm = f (Thm.cprem_of thm i)
  fun on_cprop f thm = f (Thm.cprop_of thm)
  (*pair a two-premise rule with the cterms that f and g select from its
    premises (presumably its schematic variables); comp2 instantiates
    these with the cterms selected from thm1 and thm2, then discharges
    both premises with those theorems*)
  fun precomp2 f g thm = (on_cprem 1 f thm, on_cprem 2 g thm, f, g, thm)
  fun comp2 (cv1, cv2, f, g, rule) thm1 thm2 =
    Thm.instantiate ([], [(cv1, on_cprop f thm1), (cv2, on_cprop g thm2)]) rule
    |> Z3_Replay_Util.discharge thm1 |> Z3_Replay_Util.discharge thm2

  (*d1 strips one application (e.g. the Trueprop of a proposition);
    d2 strips one more (a negation)*)
  fun d1 ct = Thm.dest_arg ct and d2 ct = Thm.dest_arg (Thm.dest_arg ct)

  (*introduce t & u from literals t and u*)
  val conj_rule = precomp2 d1 d1 @{thm conjI}
  fun comp_conj ((_, thm1), (_, thm2)) = comp2 conj_rule thm1 thm2

  (*introduce ~(t | u); the flag of each part is true when that disjunct is
    a negation ~v, in which case the literal v (not ~~v) is supplied*)
  val disj1 = precomp2 d2 d2 @{lemma "~P ==> ~Q ==> ~(P | Q)" by fast}
  val disj2 = precomp2 d2 d1 @{lemma "~P ==> Q ==> ~(P | ~Q)" by fast}
  val disj3 = precomp2 d1 d2 @{lemma "P ==> ~Q ==> ~(~P | Q)" by fast}
  val disj4 = precomp2 d1 d1 @{lemma "P ==> Q ==> ~(~P | ~Q)" by fast}

  fun comp_disj ((false, thm1), (false, thm2)) = comp2 disj1 thm1 thm2
    | comp_disj ((false, thm1), (true, thm2)) = comp2 disj2 thm1 thm2
    | comp_disj ((true, thm1), (false, thm2)) = comp2 disj3 thm1 thm2
    | comp_disj ((true, thm1), (true, thm2)) = comp2 disj4 thm1 thm2

  (*split a goal into its two flagged parts; raise TERM on other shapes*)
  fun dest_conj (@{const HOL.conj} $ t $ u) = ((false, t), (false, u))
    | dest_conj t = raise TERM ("dest_conj", [t])

  (*literal needed for a disjunct: (true, v) for ~v, (false, ~t) else*)
  val neg = (fn @{const Not} $ t => (true, t) | t => (false, @{const Not} $ t))
  fun dest_disj (@{const Not} $ (@{const HOL.disj} $ t $ u)) = (neg t, neg u)
    | dest_disj t = raise TERM ("dest_disj", [t])

  (*double negation in both directions; as_dneg f t applies f to ~~t*)
  val precomp = Z3_Replay_Util.precompose
  val dnegE = precomp (single o d2 o d1) @{thm notnotD}
  val dnegI = precomp (single o d1) @{lemma "P ==> ~~P" by fast}
  fun as_dneg f t = f (@{const Not} $ (@{const Not} $ t))

  (*from here on precomp2 is Z3_Replay_Util.precompose2, shadowing the
    local helper defined above*)
  val precomp2 = Z3_Replay_Util.precompose2
  (*relate the boolean equation t = u with ~(u = ~t) in both directions;
    as_negIff f t applies f to ~(u = ~t) for t of the form t = u, and
    yields NONE for any other term*)
  fun dni f = apsnd f o Thm.dest_binop o f o d1
  val negIffE = precomp2 (dni d1) @{lemma "~(P = (~Q)) ==> Q = P" by fast}
  val negIffI = precomp2 (dni I) @{lemma "P = Q ==> ~(Q = (~P))" by fast}
  val iff_const = @{const HOL.eq (bool)}
  fun as_negIff f (@{const HOL.eq (bool)} $ t $ u) =
        f (@{const Not} $ (iff_const $ u $ (@{const Not} $ t)))
    | as_negIff _ _ = NONE
in

(*
  Prove the conjunction (is_conj = true) or negated disjunction
  (is_conj = false) t from the literals in littab.  Each sub-term is
  first looked up directly, then via one of the rewritings in
  lookup_rule; failing both, it is split and its parts are joined
  recursively.  Raises TERM (from dest_conj/dest_disj) when a sub-term
  can be neither found nor split.
*)
fun join is_conj littab t =
  let
    val comp = if is_conj then comp_conj else comp_disj
    val dest = if is_conj then dest_conj else dest_disj

    val lookup = lookup_lit littab

    (*a rewriting rule together with the looked-up premise it needs*)
    fun lookup_rule t =
      (case t of
        (*~~t from t*)
        @{const Not} $ (@{const Not} $ t) => (Z3_Replay_Util.compose dnegI, lookup t)
        (*~(t = ~u) from u = t*)
      | @{const Not} $ (@{const HOL.eq (bool)} $ t $ (@{const Not} $ u)) =>
          (Z3_Replay_Util.compose negIffI, lookup (iff_const $ u $ t))
        (*~(t = u) from ~(u = t), by symmetry*)
      | @{const Not} $ ((eq as Const (@{const_name HOL.eq}, _)) $ t $ u) =>
          let fun rewr lit = lit COMP @{thm not_sym}
          in (rewr, lookup (@{const Not} $ (eq $ u $ t))) end
        (*t from ~~t, or a boolean t = u from ~(u = ~t)*)
      | _ =>
          (case as_dneg lookup t of
            NONE => (Z3_Replay_Util.compose negIffE, as_negIff lookup t)
          | x => (Z3_Replay_Util.compose dnegE, x)))

    (*s is the flag consumed by comp_disj; it is passed through unchanged*)
    fun join1 (s, t) =
      (case lookup t of
        SOME lit => (s, lit)
      | NONE =>
          (case lookup_rule t of
            (rewrite, SOME lit) => (s, rewrite lit)
          | (_, NONE) => (s, comp (apply2 join1 (dest t)))))

  in snd (join1 (if is_conj then (false, t) else (true, t))) end

end


(** proving equality of conjunctions or disjunctions **)

(*combine thm1: P ==> Q and thm2: Q ==> P into P = Q using iffI*)
fun iff_intro thm1 thm2 = thm2 COMP (thm1 COMP @{thm iffI})

local
  val cp1 = @{lemma "(~P) = (~Q) ==> P = Q" by simp}
  val cp2 = @{lemma "(~P) = Q ==> P = (~Q)" by fastforce}
  val cp3 = @{lemma "P = (~Q) ==> (~P) = Q" by simp}
in
(*Reduce an equation ct = cu to a related one and prove that with prove,
  which receives the pair of sides of the reduced equation:
    contrapos1: ct = cu        from (~ct) = (~cu)
    contrapos2: ct = ~Q        from (~ct) = Q     (cu is ~Q)
    contrapos3: ~P = cu        from P = (~cu)     (ct is ~P)*)
fun contrapos1 prove (ct, cu) = prove (negate ct, negate cu) COMP cp1
fun contrapos2 prove (ct, cu) = prove (negate ct, Thm.dest_arg cu) COMP cp2
fun contrapos3 prove (ct, cu) = prove (Thm.dest_arg ct, negate cu) COMP cp3
end

local
  val contra_rule = @{lemma "P ==> ~P ==> False" by (rule notE)}
  (*derive False from thm: explode its proposition (as a conjunction if
    conj, else as a negated disjunction) and extract either the literal
    False or a complementary pair ~u and u; raises Option (from the) if
    neither exists*)
  fun contra_left conj thm =
    let
      val rules = explode_term conj (SMT_Util.prop_of thm)
      fun contra_lits (t, rs) =
        (case t of
          @{const Not} $ u => Termtab.lookup rules u |> Option.map (pair rs)
        | _ => NONE)
    in
      (case Termtab.lookup rules @{const False} of
        SOME rs => extract_lit thm rs
      | NONE =>
          the (Termtab.get_first contra_lits rules)
          |> apply2 (extract_lit thm)
          |> (fn (nlit, plit) => nlit COMP (plit COMP contra_rule)))
    end

  (*False ==> ct, obtained by instantiating the conclusion of FalseE*)
  val falseE_v = Thm.dest_arg (Thm.dest_arg (Thm.cprop_of @{thm FalseE}))
  fun contra_right ct = Thm.instantiate ([], [(falseE_v, ct)]) @{thm FalseE}
in

(*prove ct = False for a term ct whose literals are contradictory*)
fun contradict conj ct =
  iff_intro (Z3_Replay_Util.under_assumption (contra_left conj) ct) (contra_right ct)

end

local
  (*Prove cl = cr, where l and r tell whether the respective side is
    handled as a conjunction (true) or as a negated disjunction (false).
    Each direction assumes one side, explodes it fully into a literal
    table (with ~False added), and joins the other side from that table.
    Intermediate theorems are kept when the two sides are of different
    kinds.*)
  fun prove_eq l r (cl, cr) =
    let
      fun explode' is_conj = explode is_conj true (l <> r) []
      fun make_tab is_conj thm = make_littab (true_thm :: explode' is_conj thm)
      fun prove is_conj ct tab = join is_conj tab (Thm.term_of ct)

      val thm1 = Z3_Replay_Util.under_assumption (prove r cr o make_tab l) cl
      val thm2 = Z3_Replay_Util.under_assumption (prove l cl o make_tab r) cr
    in iff_intro thm1 thm2 end

  (*shape of a side: t & u, t | u, ~(t & u), ~(t | u)*)
  datatype conj_disj = CONJ | DISJ | NCON | NDIS
  fun kind_of t =
    if is_conj t then SOME CONJ
    else if is_disj t then SOME DISJ
    else if is_neg' is_conj t then SOME NCON
    else if is_neg' is_disj t then SOME NDIS
    else NONE
in

(*
  Prove an equation l = r (ct is split by Thm.dest_arg then
  Thm.dest_binop, presumably ct is Trueprop (l = r)) whose sides are
  (possibly negated) conjunctions or disjunctions.  A conjunction equal to
  False and a disjunction equal to ~False are proved by contradiction.
  Disjunctions and negated conjunctions are reduced via the contrapos
  rules to negated disjunctions and conjunctions.  Raises CTERM for an
  unsupported combination of shapes.  The case order matters: specific
  pairs must precede the (SOME CONJ, _) and (SOME DISJ, _) fallbacks.
*)
fun prove_conj_disj_eq ct =
  let val cp as (cl, cr) = Thm.dest_binop (Thm.dest_arg ct)
  in
    (case (kind_of (Thm.term_of cl), Thm.term_of cr) of
      (SOME CONJ, @{const False}) => contradict true cl
    | (SOME DISJ, @{const Not} $ @{const False}) =>
        contrapos2 (contradict false o fst) cp
    | (kl, _) =>
        (case (kl, kind_of (Thm.term_of cr)) of
          (SOME CONJ, SOME CONJ) => prove_eq true true cp
        | (SOME CONJ, SOME NDIS) => prove_eq true false cp
        | (SOME CONJ, _) => prove_eq true true cp
        | (SOME DISJ, SOME DISJ) => contrapos1 (prove_eq false false) cp
        | (SOME DISJ, SOME NCON) => contrapos2 (prove_eq false true) cp
        | (SOME DISJ, _) => contrapos1 (prove_eq false false) cp
        | (SOME NCON, SOME NCON) => contrapos1 (prove_eq true true) cp
        | (SOME NCON, SOME DISJ) => contrapos3 (prove_eq true false) cp
        | (SOME NCON, NONE) => contrapos3 (prove_eq true false) cp
        | (SOME NDIS, SOME NDIS) => prove_eq false false cp
        | (SOME NDIS, SOME CONJ) => prove_eq false true cp
        | (SOME NDIS, NONE) => prove_eq false true cp
        | _ => raise CTERM ("prove_conj_disj_eq", [ct])))
  end

end

end;