added prove reconstruction for injective functions;
added SMT_Utils to collect frequently used functions
--- a/src/HOL/IsaMakefile Mon Nov 22 14:27:42 2010 +0100
+++ b/src/HOL/IsaMakefile Mon Nov 22 15:45:42 2010 +0100
@@ -349,9 +349,11 @@
Tools/SMT/smt_setup_solvers.ML \
Tools/SMT/smt_solver.ML \
Tools/SMT/smt_translate.ML \
+ Tools/SMT/smt_utils.ML \
Tools/SMT/z3_interface.ML \
Tools/SMT/z3_model.ML \
Tools/SMT/z3_proof_literals.ML \
+ Tools/SMT/z3_proof_methods.ML \
Tools/SMT/z3_proof_parser.ML \
Tools/SMT/z3_proof_reconstruction.ML \
Tools/SMT/z3_proof_tools.ML \
--- a/src/HOL/SMT.thy Mon Nov 22 14:27:42 2010 +0100
+++ b/src/HOL/SMT.thy Mon Nov 22 15:45:42 2010 +0100
@@ -10,6 +10,7 @@
"Tools/Datatype/datatype_selectors.ML"
"Tools/SMT/smt_failure.ML"
"Tools/SMT/smt_config.ML"
+ "Tools/SMT/smt_utils.ML"
"Tools/SMT/smt_monomorph.ML"
("Tools/SMT/smt_builtin.ML")
("Tools/SMT/smt_normalize.ML")
@@ -19,6 +20,7 @@
("Tools/SMT/z3_proof_parser.ML")
("Tools/SMT/z3_proof_tools.ML")
("Tools/SMT/z3_proof_literals.ML")
+ ("Tools/SMT/z3_proof_methods.ML")
("Tools/SMT/z3_proof_reconstruction.ML")
("Tools/SMT/z3_model.ML")
("Tools/SMT/z3_interface.ML")
@@ -137,6 +139,7 @@
use "Tools/SMT/z3_proof_parser.ML"
use "Tools/SMT/z3_proof_tools.ML"
use "Tools/SMT/z3_proof_literals.ML"
+use "Tools/SMT/z3_proof_methods.ML"
use "Tools/SMT/z3_proof_reconstruction.ML"
use "Tools/SMT/z3_model.ML"
use "Tools/SMT/smt_setup_solvers.ML"
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/SMT/smt_utils.ML Mon Nov 22 15:45:42 2010 +0100
@@ -0,0 +1,121 @@
+(* Title: HOL/Tools/SMT/smt_utils.ML
+ Author: Sascha Boehme, TU Muenchen
+
+General utility functions.
+*)
+
+signature SMT_UTILS =
+sig
+ val repeat: ('a -> 'a option) -> 'a -> 'a
+ val repeat_yield: ('a -> 'b -> ('a * 'b) option) -> 'a -> 'b -> 'a * 'b
+
+ (* terms *)
+ val dest_conj: term -> term * term
+ val dest_disj: term -> term * term
+
+ (* patterns and instantiations *)
+ val mk_const_pat: theory -> string -> (ctyp -> 'a) -> 'a * cterm
+ val destT1: ctyp -> ctyp
+ val destT2: ctyp -> ctyp
+ val instTs: ctyp list -> ctyp list * cterm -> cterm
+ val instT: ctyp -> ctyp * cterm -> cterm
+ val instT': cterm -> ctyp * cterm -> cterm
+
+ (* certified terms *)
+ val certify: Proof.context -> term -> cterm
+ val dest_cabs: cterm -> Proof.context -> cterm * Proof.context
+ val dest_all_cabs: cterm -> Proof.context -> cterm * Proof.context
+ val dest_cbinder: cterm -> Proof.context -> cterm * Proof.context
+ val dest_all_cbinders: cterm -> Proof.context -> cterm * Proof.context
+ val mk_cprop: cterm -> cterm
+ val dest_cprop: cterm -> cterm
+ val mk_cequals: cterm -> cterm -> cterm
+
+ (* conversions *)
+ val if_conv: (term -> bool) -> conv -> conv -> conv
+ val if_true_conv: (term -> bool) -> conv -> conv
+ val binders_conv: (Proof.context -> conv) -> Proof.context -> conv
+ val prop_conv: conv -> conv
+end
+
+structure SMT_Utils: SMT_UTILS =
+struct
+
+fun repeat f =
+ let fun rep x = (case f x of SOME y => rep y | NONE => x)
+ in rep end
+
+fun repeat_yield f =
+ let fun rep x y = (case f x y of SOME (x', y') => rep x' y' | NONE => (x, y))
+ in rep end
+
+
+(* terms *)
+
+fun dest_conj (@{const HOL.conj} $ t $ u) = (t, u)
+ | dest_conj t = raise TERM ("not a conjunction", [t])
+
+fun dest_disj (@{const HOL.disj} $ t $ u) = (t, u)
+ | dest_disj t = raise TERM ("not a disjunction", [t])
+
+
+(* patterns and instantiations *)
+
+fun mk_const_pat thy name destT =
+ let val cpat = Thm.cterm_of thy (Const (name, Sign.the_const_type thy name))
+ in (destT (Thm.ctyp_of_term cpat), cpat) end
+
+val destT1 = hd o Thm.dest_ctyp
+val destT2 = hd o tl o Thm.dest_ctyp
+
+fun instTs cUs (cTs, ct) = Thm.instantiate_cterm (cTs ~~ cUs, []) ct
+fun instT cU (cT, ct) = instTs [cU] ([cT], ct)
+fun instT' ct = instT (Thm.ctyp_of_term ct)
+
+
+(* certified terms *)
+
+fun certify ctxt = Thm.cterm_of (ProofContext.theory_of ctxt)
+
+fun dest_cabs ct ctxt =
+ (case Thm.term_of ct of
+ Abs _ =>
+ let val (n, ctxt') = yield_singleton Variable.variant_fixes Name.uu ctxt
+ in (snd (Thm.dest_abs (SOME n) ct), ctxt') end
+ | _ => raise CTERM ("no abstraction", [ct]))
+
+val dest_all_cabs = repeat_yield (try o dest_cabs)
+
+fun dest_cbinder ct ctxt =
+ (case Thm.term_of ct of
+ Const _ $ Abs _ => dest_cabs (Thm.dest_arg ct) ctxt
+ | _ => raise CTERM ("not a binder", [ct]))
+
+val dest_all_cbinders = repeat_yield (try o dest_cbinder)
+
+val mk_cprop = Thm.capply @{cterm Trueprop}
+
+fun dest_cprop ct =
+ (case Thm.term_of ct of
+ @{const Trueprop} $ _ => Thm.dest_arg ct
+ | _ => raise CTERM ("not a property", [ct]))
+
+val equals = mk_const_pat @{theory} @{const_name "=="} destT1
+fun mk_cequals ct cu = Thm.mk_binop (instT' ct equals) ct cu
+
+
+(* conversions *)
+
+fun if_conv f cv1 cv2 ct = if f (Thm.term_of ct) then cv1 ct else cv2 ct
+
+fun if_true_conv f cv = if_conv f cv Conv.all_conv
+
+fun binders_conv cv ctxt =
+ Conv.binder_conv (binders_conv cv o snd) ctxt else_conv cv ctxt
+
+fun prop_conv cv ct =
+ (case Thm.term_of ct of
+ @{const Trueprop} $ _ => Conv.arg_conv cv ct
+ | _ => raise CTERM ("not a property", [ct]))
+
+end
--- a/src/HOL/Tools/SMT/z3_interface.ML Mon Nov 22 14:27:42 2010 +0100
+++ b/src/HOL/Tools/SMT/z3_interface.ML Mon Nov 22 15:45:42 2010 +0100
@@ -21,16 +21,13 @@
val mk_builtin_fun: Proof.context -> sym -> cterm list -> cterm option
val is_builtin_theory_term: Proof.context -> term -> bool
-
- val mk_inst_pair: (ctyp -> 'a) -> cterm -> 'a * cterm
- val destT1: ctyp -> ctyp
- val destT2: ctyp -> ctyp
- val instT': cterm -> ctyp * cterm -> cterm
end
structure Z3_Interface: Z3_INTERFACE =
struct
+structure U = SMT_Utils
+
(** Z3-specific builtins **)
@@ -163,13 +160,6 @@
| mk_builtin_num ctxt i T =
chained_mk_builtin_num ctxt (get_mk_builtins ctxt) i T
-fun instTs cUs (cTs, ct) = Thm.instantiate_cterm (cTs ~~ cUs, []) ct
-fun instT cU (cT, ct) = instTs [cU] ([cT], ct)
-fun instT' ct = instT (Thm.ctyp_of_term ct)
-fun mk_inst_pair destT cpat = (destT (Thm.ctyp_of_term cpat), cpat)
-val destT1 = hd o Thm.dest_ctyp
-val destT2 = hd o tl o Thm.dest_ctyp
-
val mk_true = Thm.cterm_of @{theory} (@{const Not} $ @{const False})
val mk_false = Thm.cterm_of @{theory} @{const False}
val mk_not = Thm.capply (Thm.cterm_of @{theory} @{const Not})
@@ -181,31 +171,34 @@
fun mk_nary _ cu [] = cu
| mk_nary ct _ cts = uncurry (fold_rev (Thm.mk_binop ct)) (split_last cts)
-val eq = mk_inst_pair destT1 @{cpat HOL.eq}
-fun mk_eq ct cu = Thm.mk_binop (instT' ct eq) ct cu
+val eq = U.mk_const_pat @{theory} @{const_name HOL.eq} U.destT1
+fun mk_eq ct cu = Thm.mk_binop (U.instT' ct eq) ct cu
-val if_term = mk_inst_pair (destT1 o destT2) @{cpat If}
-fun mk_if cc ct cu = Thm.mk_binop (Thm.capply (instT' ct if_term) cc) ct cu
+val if_term = U.mk_const_pat @{theory} @{const_name If} (U.destT1 o U.destT2)
+fun mk_if cc ct cu = Thm.mk_binop (Thm.capply (U.instT' ct if_term) cc) ct cu
-val nil_term = mk_inst_pair destT1 @{cpat Nil}
-val cons_term = mk_inst_pair destT1 @{cpat Cons}
+val nil_term = U.mk_const_pat @{theory} @{const_name Nil} U.destT1
+val cons_term = U.mk_const_pat @{theory} @{const_name Cons} U.destT1
fun mk_list cT cts =
- fold_rev (Thm.mk_binop (instT cT cons_term)) cts (instT cT nil_term)
+ fold_rev (Thm.mk_binop (U.instT cT cons_term)) cts (U.instT cT nil_term)
-val distinct = mk_inst_pair (destT1 o destT1) @{cpat SMT.distinct}
+val distinct = U.mk_const_pat @{theory} @{const_name SMT.distinct}
+ (U.destT1 o U.destT1)
fun mk_distinct [] = mk_true
| mk_distinct (cts as (ct :: _)) =
- Thm.capply (instT' ct distinct) (mk_list (Thm.ctyp_of_term ct) cts)
+ Thm.capply (U.instT' ct distinct) (mk_list (Thm.ctyp_of_term ct) cts)
-val access = mk_inst_pair (Thm.dest_ctyp o destT1) @{cpat fun_app}
+val access = U.mk_const_pat @{theory} @{const_name fun_app}
+ (Thm.dest_ctyp o U.destT1)
fun mk_access array index =
let val cTs = Thm.dest_ctyp (Thm.ctyp_of_term array)
- in Thm.mk_binop (instTs cTs access) array index end
+ in Thm.mk_binop (U.instTs cTs access) array index end
-val update = mk_inst_pair (Thm.dest_ctyp o destT1) @{cpat fun_upd}
+val update = U.mk_const_pat @{theory} @{const_name fun_upd}
+ (Thm.dest_ctyp o U.destT1)
fun mk_update array index value =
let val cTs = Thm.dest_ctyp (Thm.ctyp_of_term array)
- in Thm.capply (Thm.mk_binop (instTs cTs update) array index) value end
+ in Thm.capply (Thm.mk_binop (U.instTs cTs update) array index) value end
val mk_uminus = Thm.capply (Thm.cterm_of @{theory} @{const uminus (int)})
val mk_add = Thm.mk_binop (Thm.cterm_of @{theory} @{const plus (int)})
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/SMT/z3_proof_methods.ML Mon Nov 22 15:45:42 2010 +0100
@@ -0,0 +1,137 @@
+(* Title: HOL/Tools/SMT/z3_proof_methods.ML
+ Author: Sascha Boehme, TU Muenchen
+
+Proof methods for Z3 proof reconstruction.
+*)
+
+signature Z3_PROOF_METHODS =
+sig
+ val prove_injectivity: Proof.context -> cterm -> thm
+end
+
+structure Z3_Proof_Methods: Z3_PROOF_METHODS =
+struct
+
+structure U = SMT_Utils
+
+
+fun apply tac st =
+ (case Seq.pull (tac 1 st) of
+ NONE => raise THM ("tactic failed", 1, [st])
+ | SOME (st', _) => st')
+
+
+
+(* injectivity *)
+
+local
+
+val B = @{typ bool}
+fun mk_univ T = Const (@{const_name top}, T --> B)
+fun mk_inj_on T U =
+ Const (@{const_name inj_on}, (T --> U) --> (T --> B) --> B)
+fun mk_inv_into T U =
+ Const (@{const_name inv_into}, [T --> B, T --> U, U] ---> T)
+
+fun mk_inv_of ctxt ct =
+ let
+ val T = #T (Thm.rep_cterm ct)
+ val dT = Term.domain_type T
+ val inv = U.certify ctxt (mk_inv_into dT (Term.range_type T))
+ val univ = U.certify ctxt (mk_univ dT)
+ in Thm.mk_binop inv univ ct end
+
+fun mk_inj_prop ctxt ct =
+ let
+ val T = #T (Thm.rep_cterm ct)
+ val dT = Term.domain_type T
+ val inj = U.certify ctxt (mk_inj_on dT (Term.range_type T))
+ val univ = U.certify ctxt (mk_univ dT)
+ in U.mk_cprop (Thm.mk_binop inj ct univ) end
+
+
+val disjE = @{lemma "~P | Q ==> P ==> Q" by fast}
+
+fun prove_inj_prop ctxt hdef lhs =
+ let
+ val (ct, ctxt') = U.dest_all_cabs (Thm.rhs_of hdef) ctxt
+ val rule = disjE OF [Object_Logic.rulify (Thm.assume lhs)]
+ in
+ Goal.init (mk_inj_prop ctxt' (Thm.dest_arg ct))
+ |> apply (Tactic.rtac @{thm injI})
+ |> apply (Tactic.solve_tac [rule, rule RS @{thm sym}])
+ |> Goal.norm_result o Goal.finish ctxt'
+ |> singleton (Variable.export ctxt' ctxt)
+ end
+
+fun prove_rhs ctxt hdef lhs rhs =
+ Goal.init rhs
+ |> apply (CONVERSION (Conv.top_sweep_conv (K (Conv.rewr_conv hdef)) ctxt))
+ |> apply (REPEAT_ALL_NEW (Tactic.match_tac @{thms allI}))
+ |> apply (Tactic.rtac (@{thm inv_f_f} OF [prove_inj_prop ctxt hdef lhs]))
+ |> Goal.norm_result o Goal.finish ctxt
+
+
+fun expand thm ct =
+ let
+ val cpat = Thm.dest_arg (Thm.rhs_of thm)
+ val (cl, cr) = Thm.dest_binop (Thm.dest_arg (Thm.dest_arg1 ct))
+ val thm1 = Thm.instantiate (Thm.match (cpat, cl)) thm
+ val thm2 = Thm.instantiate (Thm.match (cpat, cr)) thm
+ in Conv.arg_conv (Conv.binop_conv (Conv.rewrs_conv [thm1, thm2])) ct end
+
+fun prove_lhs ctxt rhs lhs =
+ let
+ val eq = Thm.symmetric (mk_meta_eq (Object_Logic.rulify (Thm.assume rhs)))
+ in
+ Goal.init lhs
+ |> apply (CONVERSION (U.prop_conv (U.binders_conv (K (expand eq)) ctxt)))
+ |> apply (Simplifier.simp_tac HOL_ss)
+ |> Goal.finish ctxt
+ end
+
+
+fun mk_hdef ctxt rhs =
+ let
+ val (ct, ctxt') = U.dest_all_cbinders (U.dest_cprop rhs) ctxt
+ val (cl, cv) = Thm.dest_binop ct
+ val (cg, (cargs, cf)) = Drule.strip_comb cl ||> split_last
+ val cu = fold_rev Thm.cabs cargs (mk_inv_of ctxt' (Thm.cabs cv cf))
+ in Thm.assume (U.mk_cequals cg cu) end
+
+fun prove_inj_eq ctxt ct =
+ let
+ val (lhs, rhs) = pairself U.mk_cprop (Thm.dest_binop (U.dest_cprop ct))
+ val hdef = mk_hdef ctxt rhs
+ val lhs_thm = Thm.implies_intr rhs (prove_lhs ctxt rhs lhs)
+ val rhs_thm = Thm.implies_intr lhs (prove_rhs ctxt hdef lhs rhs)
+ in lhs_thm COMP (rhs_thm COMP @{thm iffI}) end
+
+
+val swap_eq_thm = mk_meta_eq @{thm eq_commute}
+val swap_disj_thm = mk_meta_eq @{thm disj_commute}
+
+fun swap_conv dest eq =
+ U.if_true_conv ((op <) o pairself Term.size_of_term o dest)
+ (Conv.rewr_conv eq)
+
+val swap_eq_conv = swap_conv HOLogic.dest_eq swap_eq_thm
+val swap_disj_conv = swap_conv U.dest_disj swap_disj_thm
+
+fun norm_conv ctxt =
+ swap_eq_conv then_conv
+ Conv.arg1_conv (U.binders_conv (K swap_disj_conv) ctxt) then_conv
+ Conv.arg_conv (U.binders_conv (K swap_eq_conv) ctxt)
+
+in
+
+fun prove_injectivity ctxt ct =
+ ct
+ |> Goal.init
+ |> apply (CONVERSION (U.prop_conv (norm_conv ctxt)))
+ |> apply (CSUBGOAL (uncurry (Tactic.rtac o prove_inj_eq ctxt)))
+ |> Goal.norm_result o Goal.finish ctxt
+
+end
+
+end
--- a/src/HOL/Tools/SMT/z3_proof_parser.ML Mon Nov 22 14:27:42 2010 +0100
+++ b/src/HOL/Tools/SMT/z3_proof_parser.ML Mon Nov 22 15:45:42 2010 +0100
@@ -29,6 +29,7 @@
structure Z3_Proof_Parser: Z3_PROOF_PARSER =
struct
+structure U = SMT_Utils
structure I = Z3_Interface
@@ -134,10 +135,11 @@
SOME cv => cv
| _ => Thm.cterm_of thy (Var ((Name.uu, maxidx_of ct + 1), T)))
fun dec (i, v) = if i = 0 then NONE else SOME (i-1, v)
- in (Thm.capply (I.instT' cv q) (Thm.cabs cv ct), map_filter dec vars) end
+ in (Thm.capply (U.instT' cv q) (Thm.cabs cv ct), map_filter dec vars) end
- val forall = I.mk_inst_pair (I.destT1 o I.destT1) @{cpat All}
- val exists = I.mk_inst_pair (I.destT1 o I.destT1) @{cpat Ex}
+ fun quant name = U.mk_const_pat @{theory} name (U.destT1 o U.destT1)
+ val forall = quant @{const_name All}
+ val exists = quant @{const_name Ex}
in
fun mk_forall thy = fold_rev (mk_quant thy forall)
fun mk_exists thy = fold_rev (mk_quant thy exists)
--- a/src/HOL/Tools/SMT/z3_proof_reconstruction.ML Mon Nov 22 14:27:42 2010 +0100
+++ b/src/HOL/Tools/SMT/z3_proof_reconstruction.ML Mon Nov 22 15:45:42 2010 +0100
@@ -18,6 +18,7 @@
structure P = Z3_Proof_Parser
structure T = Z3_Proof_Tools
structure L = Z3_Proof_Literals
+structure M = Z3_Proof_Methods
fun z3_exn msg = raise SMT_Failure.SMT (SMT_Failure.Other_Failure
("Z3 proof reconstruction: " ^ msg))
@@ -684,7 +685,7 @@
val prove_conj_disj_eq = T.with_conv unfold_conv L.prove_conj_disj_eq
in
-fun rewrite ctxt simpset ths = Thm o with_conv ctxt ths (try_apply ctxt [] [
+fun rewrite' ctxt simpset ths = Thm o with_conv ctxt ths (try_apply ctxt [] [
named ctxt "conj/disj/distinct" prove_conj_disj_eq,
T.by_abstraction (true, false) ctxt [] (fn ctxt' => T.by_tac (
NAMED ctxt' "simp (logic)" (Simplifier.simp_tac simpset)
@@ -698,7 +699,17 @@
NAMED ctxt' "simp (full)" (Simplifier.simp_tac simpset)
THEN_ALL_NEW (
NAMED ctxt' "fast (full)" (Classical.fast_tac HOL_cs)
- ORELSE' NAMED ctxt' "arith (full)" (Arith_Data.arith_tac ctxt'))))])
+ ORELSE' NAMED ctxt' "arith (full)" (Arith_Data.arith_tac ctxt')))),
+ named ctxt "injectivity" (M.prove_injectivity ctxt)])
+
+fun rewrite simpset thms ct ctxt = (* FIXME: join with rewrite' *)
+ let
+ val thm = rewrite' ctxt simpset thms ct
+ val ord = Term_Ord.fast_term_ord o pairself Thm.term_of
+ val chyps = fold (Ord_List.union ord o #hyps o Thm.crep_thm o thm_of) thms []
+ val new_chyps = Ord_List.subtract ord chyps (#hyps (Thm.crep_thm (thm_of thm)))
+ val (_, ctxt') = Assumption.add_assumes new_chyps ctxt
+ in (thm, ctxt') end
end
@@ -789,9 +800,8 @@
(* theory rules *)
| (P.ThLemma _, _) => (* FIXME: use arguments *)
(th_lemma cx simpset (map (thm_of o fst) ps) ct, cxp)
- | (P.Rewrite, _) => (rewrite cx simpset [] ct, cxp)
- | (P.RewriteStar, ps) =>
- (rewrite cx simpset (map fst ps) ct, cxp)
+ | (P.Rewrite, _) => rewrite simpset [] ct cx ||> rpair ptab
+ | (P.RewriteStar, ps) => rewrite simpset (map fst ps) ct cx ||> rpair ptab
| (P.NnfStar, _) => not_supported r
| (P.CnfStar, _) => not_supported r
--- a/src/HOL/Tools/SMT/z3_proof_tools.ML Mon Nov 22 14:27:42 2010 +0100
+++ b/src/HOL/Tools/SMT/z3_proof_tools.ML Mon Nov 22 15:45:42 2010 +0100
@@ -47,6 +47,7 @@
structure Z3_Proof_Tools: Z3_PROOF_TOOLS =
struct
+structure U = SMT_Utils
structure I = Z3_Interface
@@ -60,10 +61,7 @@
val mk_prop = Thm.capply (Thm.cterm_of @{theory} @{const Trueprop})
-val eq = I.mk_inst_pair I.destT1 @{cpat "op =="}
-fun mk_meta_eq_cterm ct cu = Thm.mk_binop (I.instT' ct eq) ct cu
-
-fun as_meta_eq ct = uncurry mk_meta_eq_cterm (Thm.dest_binop (Thm.dest_arg ct))
+fun as_meta_eq ct = uncurry U.mk_cequals (Thm.dest_binop (Thm.dest_arg ct))
@@ -112,7 +110,7 @@
let
val (lhs, rhs) = Thm.dest_binop (Thm.cprem_of thm 1)
val (cf, cvs) = Drule.strip_comb lhs
- val eq = mk_meta_eq_cterm cf (fold_rev Thm.cabs cvs rhs)
+ val eq = U.mk_cequals cf (fold_rev Thm.cabs cvs rhs)
fun apply cv th =
Thm.combination th (Thm.reflexive cv)
|> Conv.fconv_rule (Conv.arg_conv (Thm.beta_conversion false))