src/HOL/SMT/Tools/z3_proof_terms.ML
changeset 33010 39f73a59e855
child 33029 2fefe039edf1
equal deleted inserted replaced
33008:b0ff69f0a248 33010:39f73a59e855
       
     1 (*  Title:      HOL/SMT/Tools/z3_proof_terms.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3 
       
     4 Reconstruction functions for terms occurring in Z3 proofs.
       
     5 *)
       
     6 
       
     7 signature Z3_PROOF_TERMS =
       
     8 sig
       
     9   val mk_prop: Thm.cterm -> Thm.cterm
       
    10   val mk_meta_eq: Thm.cterm -> Thm.cterm -> Thm.cterm
       
    11 
       
    12   type preterm
       
    13 
       
    14   val compile: theory -> Name.context -> preterm -> Thm.cterm * Thm.cterm list
       
    15 
       
    16   val mk_bound: theory -> int -> typ -> preterm
       
    17   val mk_fun: Thm.cterm -> preterm list -> preterm
       
    18   val mk_forall: theory -> string * typ -> preterm -> preterm
       
    19   val mk_exists: theory -> string * typ -> preterm -> preterm
       
    20 
       
    21   val mk_true: preterm
       
    22   val mk_false: preterm
       
    23   val mk_not: preterm -> preterm
       
    24   val mk_and: preterm list -> preterm
       
    25   val mk_or: preterm list -> preterm
       
    26   val mk_implies: preterm -> preterm -> preterm
       
    27   val mk_iff: preterm -> preterm -> preterm
       
    28 
       
    29   val mk_eq: preterm -> preterm -> preterm
       
    30   val mk_if: preterm -> preterm -> preterm -> preterm
       
    31   val mk_distinct: preterm list -> preterm
       
    32 
       
    33   val mk_pat: preterm list -> preterm
       
    34   val mk_nopat: preterm list -> preterm
       
    35   val mk_trigger: preterm list -> preterm -> preterm
       
    36 
       
    37   val mk_access: preterm -> preterm -> preterm
       
    38   val mk_update: preterm -> preterm -> preterm -> preterm
       
    39 
       
    40   val mk_int_num: int -> preterm
       
    41   val mk_real_frac_num: int * int option -> preterm
       
    42   val mk_uminus: preterm -> preterm
       
    43   val mk_add: preterm -> preterm -> preterm
       
    44   val mk_sub: preterm -> preterm -> preterm
       
    45   val mk_mul: preterm -> preterm -> preterm
       
    46   val mk_int_div: preterm -> preterm -> preterm
       
    47   val mk_real_div: preterm -> preterm -> preterm
       
    48   val mk_rem: preterm -> preterm -> preterm
       
    49   val mk_mod: preterm -> preterm -> preterm
       
    50   val mk_lt: preterm -> preterm -> preterm
       
    51   val mk_le: preterm -> preterm -> preterm
       
    52 
       
    53   val wordT : int -> typ
       
    54   val mk_bv_num : theory -> int -> int -> preterm
       
    55 
       
    56   val var_prefix: string
       
    57 end
       
    58 
       
    59 structure Z3_Proof_Terms: Z3_PROOF_TERMS =
       
    60 struct
       
    61 
       
    62 fun instTs cUs (cTs, ct) = Thm.instantiate_cterm (cTs ~~ cUs, []) ct
       
    63 fun instT cU (cT, ct) = instTs [cU] ([cT], ct)
       
    64 fun mk_inst_pair destT cpat = (destT (Thm.ctyp_of_term cpat), cpat)
       
    65 val destT1 = hd o Thm.dest_ctyp
       
    66 val destT2 = hd o tl o Thm.dest_ctyp
       
    67 
       
    68 
       
    69 val mk_prop = Thm.capply @{cterm Trueprop}
       
    70 
       
    71 val meta_eq = mk_inst_pair destT1 @{cpat "op =="}
       
    72 fun mk_meta_eq ct = Thm.mk_binop (instT (Thm.ctyp_of_term ct) meta_eq) ct
       
    73 
       
    74 
       
    75 datatype preterm = Preterm of {
       
    76   cterm: Thm.cterm,
       
    77   vars: (int * Thm.cterm) list }
       
    78 
       
    79 fun mk_preterm (ct, vs) = Preterm {cterm=ct, vars=vs}
       
    80 fun dest_preterm (Preterm {cterm, vars}) = (cterm, vars)
       
    81 fun ctyp_of_preterm (Preterm {cterm, ...}) = Thm.ctyp_of_term cterm
       
    82 
       
    83 fun instT' e = instT (ctyp_of_preterm e)
       
    84 
       
    85 val maxidx_of = #maxidx o Thm.rep_cterm
       
    86 
       
    87 val var_prefix = "v"
       
    88 
       
    89 local
       
    90 fun mk_inst nctxt cert vs =
       
    91   let
       
    92     val max = fold (curry Int.max o fst) vs 0
       
    93     val names = fst (Name.variants (replicate (max + 1) var_prefix) nctxt)
       
    94     fun mk (i, v) = (v, cert (Free (nth names i, #T (Thm.rep_cterm v))))
       
    95   in map mk vs end
       
    96 
       
    97 fun fix_vars _ _ ct [] = (ct, [])
       
    98   | fix_vars thy nctxt ct vs =
       
    99       let
       
   100         val cert = Thm.cterm_of thy
       
   101         val inst = mk_inst nctxt cert vs
       
   102       in (Thm.instantiate_cterm ([], inst) ct, map snd inst) end
       
   103 in
       
   104 fun compile thy nctxt (Preterm {cterm, vars}) = fix_vars thy nctxt cterm vars
       
   105 end
       
   106 
       
   107 local
       
   108 fun app e (ct1, vs1) =
       
   109   let
       
   110     fun part (var as (i, v)) (inst, vs) =
       
   111       (case AList.lookup (op =) vs1 i of
       
   112         NONE => (inst, var :: vs)
       
   113       | SOME v' => ((v, v') :: inst, vs))
       
   114 
       
   115     val (ct2, vs2) = dest_preterm e
       
   116     val incr =
       
   117       if maxidx_of ct1 < 0 orelse maxidx_of ct2 < 0 then I
       
   118       else Thm.incr_indexes_cterm (maxidx_of ct1 + 1)
       
   119 
       
   120     val (inst, vs) = fold (part o apsnd incr) vs2 ([], vs1)
       
   121     val ct2' = Thm.instantiate_cterm ([], inst) (incr ct2)
       
   122   in (Thm.capply ct1 ct2', vs) end
       
   123 in
       
   124 fun mk_fun ct es = mk_preterm (fold app es (ct, []))
       
   125 fun mk_binop f t u = mk_fun f [t, u]
       
   126 fun mk_nary _ e [] = e
       
   127   | mk_nary ct _ es = uncurry (fold_rev (mk_binop ct)) (split_last es)
       
   128 end
       
   129 
       
   130 fun mk_bound thy i T =
       
   131   let val ct = Thm.cterm_of thy (Var ((Name.uu, 0), T))
       
   132   in mk_preterm (ct, [(i, ct)]) end
       
   133 
       
   134 local
       
   135 fun mk_quant q thy (n, T) e =
       
   136   let
       
   137     val (ct, vs) = dest_preterm e
       
   138     val cv =
       
   139       (case AList.lookup (op =) vs 0 of
       
   140         SOME cv => cv
       
   141       | _ => Thm.cterm_of thy (Var ((Name.uu, maxidx_of ct + 1), T)))
       
   142     val cq = instT (Thm.ctyp_of_term cv) q
       
   143     fun dec (i, v) = if i = 0 then NONE else SOME (i - 1, v)
       
   144   in mk_preterm (Thm.capply cq (Thm.cabs cv ct), map_filter dec vs) end
       
   145 in
       
   146 val mk_forall = mk_quant (mk_inst_pair (destT1 o destT1) @{cpat All})
       
   147 val mk_exists = mk_quant (mk_inst_pair (destT1 o destT1) @{cpat Ex})
       
   148 end
       
   149 
       
   150 
       
   151 val mk_false = mk_fun @{cterm False} []
       
   152 val mk_not = mk_fun @{cterm Not} o single
       
   153 val mk_true = mk_not mk_false
       
   154 val mk_and = mk_nary @{cterm "op &"} mk_true
       
   155 val mk_or = mk_nary @{cterm "op |"} mk_false
       
   156 val mk_implies = mk_binop @{cterm "op -->"}
       
   157 val mk_iff = mk_binop @{cterm "op = :: bool => _"}
       
   158 
       
   159 val eq = mk_inst_pair destT1 @{cpat "op ="}
       
   160 fun mk_eq t u = mk_binop (instT' t eq) t u
       
   161 
       
   162 val if_term = mk_inst_pair (destT1 o destT2) @{cpat If}
       
   163 fun mk_if c t u = mk_fun (instT' t if_term) [c, t, u]
       
   164 
       
   165 val nil_term = mk_inst_pair destT1 @{cpat Nil}
       
   166 val cons_term = mk_inst_pair destT1 @{cpat Cons}
       
   167 fun mk_list cT es =
       
   168   fold_rev (mk_binop (instT cT cons_term)) es (mk_fun (instT cT nil_term) [])
       
   169 
       
   170 val distinct = mk_inst_pair (destT1 o destT1) @{cpat distinct}
       
   171 fun mk_distinct [] = mk_true
       
   172   | mk_distinct (es as (e :: _)) =
       
   173       mk_fun (instT' e distinct) [mk_list (ctyp_of_preterm e) es]
       
   174 
       
   175 val pat = mk_inst_pair destT1 @{cpat pat}
       
   176 val nopat = mk_inst_pair destT1 @{cpat nopat}
       
   177 val andpat = mk_inst_pair (destT1 o destT2) @{cpat "op andpat"}
       
   178 fun mk_gen_pat _ [] = raise TERM ("mk_gen_pat: empty pattern", [])
       
   179   | mk_gen_pat pat (e :: es) =
       
   180       let fun mk t p = mk_fun (instT' t andpat) [p, t]
       
   181       in fold mk es (mk_fun (instT' e pat) [e]) end
       
   182 val mk_pat = mk_gen_pat pat
       
   183 val mk_nopat = mk_gen_pat nopat
       
   184 
       
   185 fun mk_trigger es e = mk_fun @{cterm trigger} [mk_list @{ctyp pattern} es, e]
       
   186 
       
   187 
       
   188 val access = mk_inst_pair (Thm.dest_ctyp o destT1) @{cpat apply}
       
   189 fun mk_access array index =
       
   190   let val cTs = Thm.dest_ctyp (ctyp_of_preterm array)
       
   191   in mk_fun (instTs cTs access) [array, index] end
       
   192 
       
   193 val update = mk_inst_pair (Thm.dest_ctyp o destT1) @{cpat fun_upd}
       
   194 fun mk_update array index value =
       
   195   let val cTs = Thm.dest_ctyp (ctyp_of_preterm array)
       
   196   in mk_fun (instTs cTs update) [array, index, value] end
       
   197 
       
   198 
       
   199 fun mk_int_num i = mk_fun (Numeral.mk_cnumber @{ctyp int} i) []
       
   200 fun mk_real_num i = mk_fun (Numeral.mk_cnumber @{ctyp real} i) []
       
   201 
       
   202 fun mk_real_frac_num (e, NONE) = mk_real_num e
       
   203   | mk_real_frac_num (e, SOME d) =
       
   204       mk_binop @{cterm "op / :: real => _"} (mk_real_num e) (mk_real_num d)
       
   205 
       
   206 fun has_int_type e = (Thm.typ_of (ctyp_of_preterm e) = @{typ int})
       
   207 fun choose e i r = if has_int_type e then i else r
       
   208 
       
   209 val uminus_i = @{cterm "uminus :: int => _"}
       
   210 val uminus_r = @{cterm "uminus :: real => _"}
       
   211 fun mk_uminus e = mk_fun (choose e uminus_i uminus_r) [e]
       
   212 
       
   213 fun arith_op int_op real_op t u = mk_binop (choose t int_op real_op) t u
       
   214 
       
   215 val mk_add = arith_op @{cterm "op + :: int => _"} @{cterm "op + :: real => _"}
       
   216 val mk_sub = arith_op @{cterm "op - :: int => _"} @{cterm "op - :: real => _"}
       
   217 val mk_mul = arith_op @{cterm "op * :: int => _"} @{cterm "op * :: real => _"}
       
   218 val mk_int_div = mk_binop @{cterm "op div :: int => _"}
       
   219 val mk_real_div = mk_binop @{cterm "op / :: real => _"}
       
   220 val mk_rem = mk_binop @{cterm "op rem :: int => _"}
       
   221 val mk_mod = mk_binop @{cterm "op mod :: int => _"}
       
   222 val mk_lt = arith_op @{cterm "op < :: int => _"} @{cterm "op < :: real => _"}
       
   223 val mk_le = arith_op @{cterm "op <= :: int => _"} @{cterm "op <= :: real => _"}
       
   224 
       
   225 fun binT size =
       
   226   let
       
   227     fun bitT i T =
       
   228       if i = 0
       
   229       then Type (@{type_name "Numeral_Type.bit0"}, [T])
       
   230       else Type (@{type_name "Numeral_Type.bit1"}, [T])
       
   231 
       
   232     fun binT i =
       
   233       if i = 0 then @{typ "Numeral_Type.num0"}
       
   234       else if i = 1 then @{typ "Numeral_Type.num1"}
       
   235       else let val (q, r) = Integer.div_mod i 2 in bitT r (binT q) end
       
   236   in
       
   237     if size >= 0 then binT size
       
   238     else raise TYPE ("mk_binT: " ^ string_of_int size, [], [])
       
   239   end
       
   240 
       
   241 fun wordT size = Type (@{type_name "word"}, [binT size])
       
   242 
       
   243 fun mk_bv_num thy num size =
       
   244   mk_fun (Numeral.mk_cnumber (Thm.ctyp_of thy (wordT size)) num) []
       
   245 
       
   246 end