re-implemented support for datatypes (including records and typedefs);
added test cases for datatypes, records and typedefs
--- a/src/HOL/IsaMakefile Fri Dec 31 00:11:24 2010 +0100
+++ b/src/HOL/IsaMakefile Mon Jan 03 16:22:08 2011 +0100
@@ -297,7 +297,6 @@
Tools/ATP/atp_systems.ML \
Tools/choice_specification.ML \
Tools/code_evaluation.ML \
- Tools/Datatype/datatype_selectors.ML \
Tools/int_arith.ML \
Tools/groebner.ML \
Tools/list_code.ML \
@@ -356,6 +355,7 @@
Tools/SMT/smtlib_interface.ML \
Tools/SMT/smt_builtin.ML \
Tools/SMT/smt_config.ML \
+ Tools/SMT/smt_datatypes.ML \
Tools/SMT/smt_failure.ML \
Tools/SMT/smt_monomorph.ML \
Tools/SMT/smt_normalize.ML \
--- a/src/HOL/SMT.thy Fri Dec 31 00:11:24 2010 +0100
+++ b/src/HOL/SMT.thy Mon Jan 03 16:22:08 2011 +0100
@@ -5,14 +5,14 @@
header {* Bindings to Satisfiability Modulo Theories (SMT) solvers *}
theory SMT
-imports List
+imports Record
uses
- "Tools/Datatype/datatype_selectors.ML"
"Tools/SMT/smt_utils.ML"
"Tools/SMT/smt_failure.ML"
"Tools/SMT/smt_config.ML"
("Tools/SMT/smt_monomorph.ML")
("Tools/SMT/smt_builtin.ML")
+ ("Tools/SMT/smt_datatypes.ML")
("Tools/SMT/smt_normalize.ML")
("Tools/SMT/smt_translate.ML")
("Tools/SMT/smt_solver.ML")
@@ -123,7 +123,6 @@
-
subsection {* Integer division and modulo for Z3 *}
definition z3div :: "int \<Rightarrow> int \<Rightarrow> int" where
@@ -138,6 +137,7 @@
use "Tools/SMT/smt_monomorph.ML"
use "Tools/SMT/smt_builtin.ML"
+use "Tools/SMT/smt_datatypes.ML"
use "Tools/SMT/smt_normalize.ML"
use "Tools/SMT/smt_translate.ML"
use "Tools/SMT/smt_solver.ML"
@@ -380,13 +380,4 @@
hide_const Pattern fun_app term_true term_false z3div z3mod
hide_const (open) trigger pat nopat weight
-
-
-subsection {* Selectors for datatypes *}
-
-setup {* Datatype_Selectors.setup *}
-
-declare [[ selector Pair 1 = fst, selector Pair 2 = snd ]]
-declare [[ selector Cons 1 = hd, selector Cons 2 = tl ]]
-
end
--- a/src/HOL/SMT_Examples/SMT_Tests.thy Fri Dec 31 00:11:24 2010 +0100
+++ b/src/HOL/SMT_Examples/SMT_Tests.thy Mon Jan 03 16:22:08 2011 +0100
@@ -607,7 +607,11 @@
-section {* Pairs *} (* FIXME: tests for datatypes and records *)
+section {* Datatypes, Records, and Typedefs *}
+
+subsection {* Without support by the SMT solver *}
+
+subsubsection {* Algebraic datatypes *}
lemma
"x = fst (x, y)"
@@ -625,6 +629,252 @@
using fst_conv snd_conv pair_collapse
by smt+
+lemma
+ "[x] \<noteq> Nil"
+ "[x, y] \<noteq> Nil"
+ "x \<noteq> y \<longrightarrow> [x] \<noteq> [y]"
+ "hd (x # xs) = x"
+ "tl (x # xs) = xs"
+ "hd [x, y, z] = x"
+ "tl [x, y, z] = [y, z]"
+ "hd (tl [x, y, z]) = y"
+ "tl (tl [x, y, z]) = [z]"
+ using hd.simps tl.simps(2) list.simps
+ by smt+
+
+lemma
+ "fst (hd [(a, b)]) = a"
+ "snd (hd [(a, b)]) = b"
+ using fst_conv snd_conv pair_collapse hd.simps tl.simps(2) list.simps
+ by smt+
+
+
+subsubsection {* Records *}
+
+record point =
+ x :: int
+ y :: int
+
+record bw_point = point +
+ black :: bool
+
+lemma
+ "p1 = p2 \<longrightarrow> x p1 = x p2"
+ "p1 = p2 \<longrightarrow> y p1 = y p2"
+ "x p1 \<noteq> x p2 \<longrightarrow> p1 \<noteq> p2"
+ "y p1 \<noteq> y p2 \<longrightarrow> p1 \<noteq> p2"
+ using point.simps
+ by smt+
+
+lemma
+ "x \<lparr> x = 3, y = 4 \<rparr> = 3"
+ "y \<lparr> x = 3, y = 4 \<rparr> = 4"
+ "x \<lparr> x = 3, y = 4 \<rparr> \<noteq> y \<lparr> x = 3, y = 4 \<rparr>"
+ "\<lparr> x = 3, y = 4 \<rparr> \<lparr> x := 5 \<rparr> = \<lparr> x = 5, y = 4 \<rparr>"
+ "\<lparr> x = 3, y = 4 \<rparr> \<lparr> y := 6 \<rparr> = \<lparr> x = 3, y = 6 \<rparr>"
+ "p = \<lparr> x = 3, y = 4 \<rparr> \<longrightarrow> p \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> = p"
+ "p = \<lparr> x = 3, y = 4 \<rparr> \<longrightarrow> p \<lparr> y := 4 \<rparr> \<lparr> x := 3 \<rparr> = p"
+ using point.simps
+ using [[z3_options="AUTO_CONFIG=false"]]
+ by smt+
+
+lemma
+ "y (p \<lparr> x := a \<rparr>) = y p"
+ "x (p \<lparr> y := a \<rparr>) = x p"
+ "p \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> = p \<lparr> y := 4 \<rparr> \<lparr> x := 3 \<rparr>"
+ sorry
+
+lemma
+ "p1 = p2 \<longrightarrow> x p1 = x p2"
+ "p1 = p2 \<longrightarrow> y p1 = y p2"
+ "p1 = p2 \<longrightarrow> black p1 = black p2"
+ "x p1 \<noteq> x p2 \<longrightarrow> p1 \<noteq> p2"
+ "y p1 \<noteq> y p2 \<longrightarrow> p1 \<noteq> p2"
+ "black p1 \<noteq> black p2 \<longrightarrow> p1 \<noteq> p2"
+ using point.simps bw_point.simps
+ by smt+
+
+lemma
+ "x \<lparr> x = 3, y = 4, black = b \<rparr> = 3"
+ "y \<lparr> x = 3, y = 4, black = b \<rparr> = 4"
+ "black \<lparr> x = 3, y = 4, black = b \<rparr> = b"
+ "x \<lparr> x = 3, y = 4, black = b \<rparr> \<noteq> y \<lparr> x = 3, y = 4, black = b \<rparr>"
+ "\<lparr> x = 3, y = 4, black = b \<rparr> \<lparr> x := 5 \<rparr> = \<lparr> x = 5, y = 4, black = b \<rparr>"
+ "\<lparr> x = 3, y = 4, black = b \<rparr> \<lparr> y := 6 \<rparr> = \<lparr> x = 3, y = 6, black = b \<rparr>"
+ "\<lparr> x = 3, y = 4, black = b \<rparr> \<lparr> black := w \<rparr> = \<lparr> x = 3, y = 4, black = w \<rparr>"
+ "\<lparr> x = 3, y = 4, black = True \<rparr> \<lparr> black := False \<rparr> =
+ \<lparr> x = 3, y = 4, black = False \<rparr>"
+ "p = \<lparr> x = 3, y = 4, black = True \<rparr> \<longrightarrow>
+ p \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> \<lparr> black := True \<rparr> = p"
+ "p = \<lparr> x = 3, y = 4, black = True \<rparr> \<longrightarrow>
+ p \<lparr> y := 4 \<rparr> \<lparr> black := True \<rparr> \<lparr> x := 3 \<rparr> = p"
+ "p = \<lparr> x = 3, y = 4, black = True \<rparr> \<longrightarrow>
+ p \<lparr> black := True \<rparr> \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> = p"
+ using point.simps bw_point.simps
+ using [[z3_options="AUTO_CONFIG=false"]]
+ by smt+
+
+lemma
+ "p \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> \<lparr> black := True \<rparr> =
+ p \<lparr> black := True \<rparr> \<lparr> y := 4 \<rparr> \<lparr> x := 3 \<rparr>"
+ sorry
+
+
+subsubsection {* Type definitions *}
+
+typedef three = "{1, 2, 3::int}" by auto
+
+definition n1 where "n1 = Abs_three 1"
+definition n2 where "n2 = Abs_three 2"
+definition n3 where "n3 = Abs_three 3"
+definition nplus where "nplus n m = Abs_three (Rep_three n + Rep_three m)"
+
+lemma three_def': "(x \<in> three) = (x = 1 \<or> x = 2 \<or> x = 3)"
+ by (auto simp add: three_def)
+
+lemma
+ "n1 = n1"
+ "n2 = n2"
+ "n1 \<noteq> n2"
+ "nplus n1 n1 = n2"
+ "nplus n1 n2 = n3"
+ using n1_def n2_def n3_def nplus_def
+ using three_def' Rep_three Abs_three_inverse
+ using [[z3_options="AUTO_CONFIG=false"]]
+ by smt+
+
+
+subsection {* With support by the SMT solver (but without proofs) *}
+
+subsubsection {* Algebraic datatypes *}
+
+lemma
+ "x = fst (x, y)"
+ "y = snd (x, y)"
+ "((x, y) = (y, x)) = (x = y)"
+ "((x, y) = (u, v)) = (x = u \<and> y = v)"
+ "(fst (x, y, z) = fst (u, v, w)) = (x = u)"
+ "(snd (x, y, z) = snd (u, v, w)) = (y = v \<and> z = w)"
+ "(fst (snd (x, y, z)) = fst (snd (u, v, w))) = (y = v)"
+ "(snd (snd (x, y, z)) = snd (snd (u, v, w))) = (z = w)"
+ "(fst (x, y) = snd (x, y)) = (x = y)"
+ "p1 = (x, y) \<and> p2 = (y, x) \<longrightarrow> fst p1 = snd p2"
+ "(fst (x, y) = snd (x, y)) = (x = y)"
+ "(fst p = snd p) = (p = (snd p, fst p))"
+ using fst_conv snd_conv pair_collapse
+ using [[smt_datatypes, smt_oracle]]
+ by smt+
+
+lemma
+ "[x] \<noteq> Nil"
+ "[x, y] \<noteq> Nil"
+ "x \<noteq> y \<longrightarrow> [x] \<noteq> [y]"
+ "hd (x # xs) = x"
+ "tl (x # xs) = xs"
+ "hd [x, y, z] = x"
+ "tl [x, y, z] = [y, z]"
+ "hd (tl [x, y, z]) = y"
+ "tl (tl [x, y, z]) = [z]"
+ using hd.simps tl.simps(2)
+ using [[smt_datatypes, smt_oracle]]
+ by smt+
+
+lemma
+ "fst (hd [(a, b)]) = a"
+ "snd (hd [(a, b)]) = b"
+ using fst_conv snd_conv pair_collapse hd.simps tl.simps(2)
+ using [[smt_datatypes, smt_oracle]]
+ by smt+
+
+
+subsubsection {* Records *}
+
+lemma
+ "p1 = p2 \<longrightarrow> x p1 = x p2"
+ "p1 = p2 \<longrightarrow> y p1 = y p2"
+ "x p1 \<noteq> x p2 \<longrightarrow> p1 \<noteq> p2"
+ "y p1 \<noteq> y p2 \<longrightarrow> p1 \<noteq> p2"
+ using point.simps
+ using [[smt_datatypes, smt_oracle]]
+ using [[z3_options="AUTO_CONFIG=false"]]
+ by smt+
+
+lemma
+ "x \<lparr> x = 3, y = 4 \<rparr> = 3"
+ "y \<lparr> x = 3, y = 4 \<rparr> = 4"
+ "x \<lparr> x = 3, y = 4 \<rparr> \<noteq> y \<lparr> x = 3, y = 4 \<rparr>"
+ "\<lparr> x = 3, y = 4 \<rparr> \<lparr> x := 5 \<rparr> = \<lparr> x = 5, y = 4 \<rparr>"
+ "\<lparr> x = 3, y = 4 \<rparr> \<lparr> y := 6 \<rparr> = \<lparr> x = 3, y = 6 \<rparr>"
+ "p = \<lparr> x = 3, y = 4 \<rparr> \<longrightarrow> p \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> = p"
+ "p = \<lparr> x = 3, y = 4 \<rparr> \<longrightarrow> p \<lparr> y := 4 \<rparr> \<lparr> x := 3 \<rparr> = p"
+ using point.simps
+ using [[smt_datatypes, smt_oracle]]
+ using [[z3_options="AUTO_CONFIG=false"]]
+ by smt+
+
+lemma
+ "y (p \<lparr> x := a \<rparr>) = y p"
+ "x (p \<lparr> y := a \<rparr>) = x p"
+ "p \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> = p \<lparr> y := 4 \<rparr> \<lparr> x := 3 \<rparr>"
+ using point.simps
+ using [[smt_datatypes, smt_oracle]]
+ using [[z3_options="AUTO_CONFIG=false"]]
+ by smt+
+
+lemma
+ "p1 = p2 \<longrightarrow> x p1 = x p2"
+ "p1 = p2 \<longrightarrow> y p1 = y p2"
+ "p1 = p2 \<longrightarrow> black p1 = black p2"
+ "x p1 \<noteq> x p2 \<longrightarrow> p1 \<noteq> p2"
+ "y p1 \<noteq> y p2 \<longrightarrow> p1 \<noteq> p2"
+ "black p1 \<noteq> black p2 \<longrightarrow> p1 \<noteq> p2"
+ using point.simps bw_point.simps
+ using [[smt_datatypes, smt_oracle]]
+ by smt+
+
+lemma
+ "x \<lparr> x = 3, y = 4, black = b \<rparr> = 3"
+ "y \<lparr> x = 3, y = 4, black = b \<rparr> = 4"
+ "black \<lparr> x = 3, y = 4, black = b \<rparr> = b"
+ "x \<lparr> x = 3, y = 4, black = b \<rparr> \<noteq> y \<lparr> x = 3, y = 4, black = b \<rparr>"
+ "\<lparr> x = 3, y = 4, black = b \<rparr> \<lparr> x := 5 \<rparr> = \<lparr> x = 5, y = 4, black = b \<rparr>"
+ "\<lparr> x = 3, y = 4, black = b \<rparr> \<lparr> y := 6 \<rparr> = \<lparr> x = 3, y = 6, black = b \<rparr>"
+ "\<lparr> x = 3, y = 4, black = b \<rparr> \<lparr> black := w \<rparr> = \<lparr> x = 3, y = 4, black = w \<rparr>"
+ "\<lparr> x = 3, y = 4, black = True \<rparr> \<lparr> black := False \<rparr> =
+ \<lparr> x = 3, y = 4, black = False \<rparr>"
+ "p = \<lparr> x = 3, y = 4, black = True \<rparr> \<longrightarrow>
+ p \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> \<lparr> black := True \<rparr> = p"
+ "p = \<lparr> x = 3, y = 4, black = True \<rparr> \<longrightarrow>
+ p \<lparr> y := 4 \<rparr> \<lparr> black := True \<rparr> \<lparr> x := 3 \<rparr> = p"
+ "p = \<lparr> x = 3, y = 4, black = True \<rparr> \<longrightarrow>
+ p \<lparr> black := True \<rparr> \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> = p"
+ using point.simps bw_point.simps
+ using [[smt_datatypes, smt_oracle]]
+ using [[z3_options="AUTO_CONFIG=false"]]
+ by smt+
+
+lemma
+ "p \<lparr> x := 3 \<rparr> \<lparr> y := 4 \<rparr> \<lparr> black := True \<rparr> =
+ p \<lparr> black := True \<rparr> \<lparr> y := 4 \<rparr> \<lparr> x := 3 \<rparr>"
+ using point.simps bw_point.simps
+ using [[smt_datatypes, smt_oracle]]
+ using [[z3_options="AUTO_CONFIG=false"]]
+ by smt
+
+
+subsubsection {* Type definitions *}
+
+lemma
+ "n1 = n1"
+ "n2 = n2"
+ "n1 \<noteq> n2"
+ "nplus n1 n1 = n2"
+ "nplus n1 n2 = n3"
+ using n1_def n2_def n3_def nplus_def
+ using [[smt_datatypes, smt_oracle]]
+ using [[z3_options="AUTO_CONFIG=false"]]
+ by smt+
+
section {* Function updates *}
--- a/src/HOL/Tools/Datatype/datatype_selectors.ML Fri Dec 31 00:11:24 2010 +0100
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,83 +0,0 @@
-(* Title: HOL/Tools/Datatype/datatype_selectors.ML
- Author: Sascha Boehme, TU Muenchen
-
-Selector functions for datatype constructor arguments.
-*)
-
-signature DATATYPE_SELECTORS =
-sig
- val add_selector: ((string * typ) * int) * (string * typ) ->
- Context.generic -> Context.generic
- val lookup_selector: Proof.context -> string * int -> string option
- val setup: theory -> theory
-end
-
-structure Datatype_Selectors: DATATYPE_SELECTORS =
-struct
-
-structure Stringinttab = Table
-(
- type key = string * int
- val ord = prod_ord fast_string_ord int_ord
-)
-
-structure Data = Generic_Data
-(
- type T = string Stringinttab.table
- val empty = Stringinttab.empty
- val extend = I
- fun merge data : T = Stringinttab.merge (K true) data
-)
-
-fun pretty_term context = Syntax.pretty_term (Context.proof_of context)
-
-fun sanity_check context (((con as (n, _), i), sel as (m, _))) =
- let
- val thy = Context.theory_of context
- val varify_const =
- Const #> Type.varify_global [] #> snd #> Term.dest_Const #>
- snd #> Term.strip_type
-
- val (Ts, T) = varify_const con
- val (Us, U) = varify_const sel
- val _ = (0 < i andalso i <= length Ts) orelse
- error (Pretty.string_of (Pretty.block [
- Pretty.str "The constructor ",
- Pretty.quote (pretty_term context (Const con)),
- Pretty.str " has no argument position ",
- Pretty.str (string_of_int i),
- Pretty.str "."]))
- val _ = length Us = 1 orelse
- error (Pretty.string_of (Pretty.block [
- Pretty.str "The term ", Pretty.quote (pretty_term context (Const sel)),
- Pretty.str " might not be a selector ",
- Pretty.str "(it accepts more than one argument)."]))
- val _ =
- (Sign.typ_equiv thy (T, hd Us) andalso
- Sign.typ_equiv thy (nth Ts (i-1), U)) orelse
- error (Pretty.string_of (Pretty.block [
- Pretty.str "The types of the constructor ",
- Pretty.quote (pretty_term context (Const con)),
- Pretty.str " and of the selector ",
- Pretty.quote (pretty_term context (Const sel)),
- Pretty.str " do not fit to each other."]))
- in ((n, i), m) end
-
-fun add_selector (entry as ((con as (n, _), i), (_, T))) context =
- (case Stringinttab.lookup (Data.get context) (n, i) of
- NONE => Data.map (Stringinttab.update (sanity_check context entry)) context
- | SOME c => error (Pretty.string_of (Pretty.block [
- Pretty.str "There is already a selector assigned to constructor ",
- Pretty.quote (pretty_term context (Const con)), Pretty.str ", namely ",
- Pretty.quote (pretty_term context (Const (c, T))), Pretty.str "."])))
-
-fun lookup_selector ctxt = Stringinttab.lookup (Data.get (Context.Proof ctxt))
-
-val setup =
- Attrib.setup @{binding selector}
- ((Args.term >> Term.dest_Const) -- Scan.lift (Parse.nat) --|
- Scan.lift (Parse.$$$ "=") -- (Args.term >> Term.dest_Const) >>
- (Thm.declaration_attribute o K o add_selector))
- "assign a selector function to a datatype constructor argument"
-
-end
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/SMT/smt_datatypes.ML Mon Jan 03 16:22:08 2011 +0100
@@ -0,0 +1,126 @@
+(* Title: HOL/Tools/SMT/smt_datatypes.ML
+ Author: Sascha Boehme, TU Muenchen
+
+Collector functions for common type declarations and their representation
+as algebraic datatypes.
+*)
+
+signature SMT_DATATYPES =
+sig
+ val add_decls: typ ->
+ (typ * (term * term list) list) list list * Proof.context ->
+ (typ * (term * term list) list) list list * Proof.context
+end
+
+structure SMT_Datatypes: SMT_DATATYPES =
+struct
+
+val lhs_head_of = Term.head_of o fst o Logic.dest_equals o Thm.prop_of
+
+fun mk_selectors T Ts ctxt =
+ let
+ val (sels, ctxt') =
+ Variable.variant_fixes (replicate (length Ts) "select") ctxt
+ in (map2 (fn n => fn U => Free (n, T --> U)) sels Ts, ctxt') end
+
+
+(* datatype declarations *)
+
+fun get_datatype_decl ({descr, ...} : Datatype.info) n Ts ctxt =
+ let
+ fun get_vars (_, (m, vs, _)) = if m = n then SOME vs else NONE
+ val vars = the (get_first get_vars descr) ~~ Ts
+ val lookup_var = the o AList.lookup (op =) vars
+
+ val dTs = map (apsnd (fn (m, vs, _) => Type (m, map lookup_var vs))) descr
+ val lookup_typ = the o AList.lookup (op =) dTs
+
+ fun typ_of (dt as Datatype.DtTFree _) = lookup_var dt
+ | typ_of (Datatype.DtType (n, dts)) = Type (n, map typ_of dts)
+ | typ_of (Datatype.DtRec i) = lookup_typ i
+
+ fun mk_constr T (m, dts) ctxt =
+ let
+ val Ts = map typ_of dts
+ val constr = Const (m, Ts ---> T)
+ val (selects, ctxt') = mk_selectors T Ts ctxt
+ in ((constr, selects), ctxt') end
+
+ fun mk_decl (i, (_, _, constrs)) ctxt =
+ let
+ val T = lookup_typ i
+ val (css, ctxt') = fold_map (mk_constr T) constrs ctxt
+ in ((T, css), ctxt') end
+
+ in fold_map mk_decl descr ctxt end
+
+
+(* record declarations *)
+
+val record_name_of = Long_Name.implode o fst o split_last o Long_Name.explode
+
+fun get_record_decl ({ext_def, ...} : Record.info) T ctxt =
+ let
+ val (con, _) = Term.dest_Const (lhs_head_of ext_def)
+ val (fields, more) = Record.get_extT_fields (ProofContext.theory_of ctxt) T
+ val fieldTs = map snd fields @ [snd more]
+
+ val constr = Const (con, fieldTs ---> T)
+ val (selects, ctxt') = mk_selectors T fieldTs ctxt
+ in ((T, [(constr, selects)]), ctxt') end
+
+
+(* typedef declarations *)
+
+fun get_typedef_decl (info : Typedef.info) T Ts =
+ let
+ val ({Abs_name, Rep_name, abs_type, rep_type, ...}, _) = info
+
+ val env = snd (Term.dest_Type abs_type) ~~ Ts
+ val instT = Term.map_atyps (perhaps (AList.lookup (op =) env))
+
+ val constr = Const (Abs_name, instT (rep_type --> abs_type))
+ val select = Const (Rep_name, instT (abs_type --> rep_type))
+ in (T, [(constr, [select])]) end
+
+
+(* collection of declarations *)
+
+fun declared declss T = exists (exists (equal T o fst)) declss
+
+fun get_decls T n Ts ctxt =
+ let val thy = ProofContext.theory_of ctxt
+ in
+ (case Datatype.get_info thy n of
+ SOME info => get_datatype_decl info n Ts ctxt
+ | NONE =>
+ (case Record.get_info thy (record_name_of n) of
+ SOME info => get_record_decl info T ctxt |>> single
+ | NONE =>
+ (case Typedef.get_info ctxt n of
+ [] => ([], ctxt)
+ | info :: _ => ([get_typedef_decl info T Ts], ctxt))))
+ end
+
+fun add_decls T (declss, ctxt) =
+ let
+ fun add (TFree _) = I
+ | add (TVar _) = I
+ | add (T as Type (@{type_name fun}, _)) =
+ fold add (Term.body_type T :: Term.binder_types T)
+ | add @{typ bool} = I
+ | add (T as Type (n, Ts)) = (fn (dss, ctxt1) =>
+ if declared declss T orelse declared dss T then (dss, ctxt1)
+ else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
+ else
+ (case get_decls T n Ts ctxt1 of
+ ([], _) => (dss, ctxt1)
+ | (ds, ctxt2) =>
+ let
+ val constrTs =
+ maps (map (snd o Term.dest_Const o fst) o snd) ds
+ val Us = fold (union (op =) o Term.binder_types) constrTs []
+ in fold add Us (ds :: dss, ctxt2) end))
+ in add T ([], ctxt) |>> append declss end
+
+end
--- a/src/HOL/Tools/SMT/smt_translate.ML Fri Dec 31 00:11:24 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_translate.ML Mon Jan 03 16:22:08 2011 +0100
@@ -134,20 +134,35 @@
(* preprocessing *)
-(** FIXME **)
-
-local
- (*
- force eta-expansion for constructors and selectors,
- add missing datatype selectors via hypothetical definitions,
- also return necessary datatype and record theorems
- *)
-in
+(** datatype declarations **)
fun collect_datatypes_and_records (tr_context, ctxt) ts =
- (([], tr_context, ctxt), ts)
+ let
+ val (declss, ctxt') =
+ fold (Term.fold_types SMT_Datatypes.add_decls) ts ([], ctxt)
+
+ fun is_decl_typ T = exists (exists (equal T o fst)) declss
+
+ fun add_typ' T proper =
+ (case SMT_Builtin.dest_builtin_typ ctxt' T of
+ SOME n => pair n
+ | NONE => add_typ T proper)
-end
+ fun tr_select sel =
+ let val T = Term.range_type (Term.fastype_of sel)
+ in add_fun sel NONE ##>> add_typ' T (not (is_decl_typ T)) end
+ fun tr_constr (constr, selects) =
+ add_fun constr NONE ##>> fold_map tr_select selects
+ fun tr_typ (T, cases) = add_typ' T false ##>> fold_map tr_constr cases
+ val (declss', tr_context') = fold_map (fold_map tr_typ) declss tr_context
+
+ fun add (constr, selects) =
+ Termtab.update (constr, length selects) #>
+ fold (Termtab.update o rpair 1) selects
+ val funcs = fold (fold (fold add o snd)) declss Termtab.empty
+
+ in ((funcs, declss', tr_context', ctxt'), ts) end
+ (* FIXME: also return necessary datatype and record theorems *)
(** eta-expand quantifiers, let expressions and built-ins *)
@@ -174,8 +189,15 @@
end
in
-fun eta_expand ctxt =
+fun eta_expand ctxt funcs =
let
+ fun exp_func t T ts =
+ (case Termtab.lookup funcs t of
+ SOME k =>
+ Term.list_comb (t, ts)
+ |> k <> length ts ? expf k (length ts) T
+ | NONE => Term.list_comb (t, ts))
+
fun expand ((q as Const (@{const_name All}, _)) $ Abs a) = q $ abs_expand a
| expand ((q as Const (@{const_name All}, T)) $ t) = q $ exp T t
| expand (q as Const (@{const_name All}, T)) = exp2 T q
@@ -196,7 +218,8 @@
SOME (_, k, us, mk) =>
if k = length us then mk (map expand us)
else expf k (length ts) T (mk (map expand us))
- | NONE => Term.list_comb (u, map expand ts))
+ | NONE => exp_func u T (map expand ts))
+ | (u as Free (_, T), ts) => exp_func u T (map expand ts)
| (Abs a, ts) => Term.list_comb (abs_expand a, map expand ts)
| (u, ts) => Term.list_comb (u, map expand ts))
@@ -530,17 +553,18 @@
val with_datatypes =
has_datatypes andalso Config.get ctxt SMT_Config.datatypes
- fun no_dtyps (tr_context, ctxt) ts = (([], tr_context, ctxt), ts)
+ fun no_dtyps (tr_context, ctxt) ts =
+ ((Termtab.empty, [], tr_context, ctxt), ts)
val ts1 = map (Envir.beta_eta_contract o SMT_Utils.prop_of o snd) ithms
- val ((dtyps, tr_context, ctxt1), ts2) =
+ val ((funcs, dtyps, tr_context, ctxt1), ts2) =
((make_tr_context prefixes, ctxt), ts1)
|-> (if with_datatypes then collect_datatypes_and_records else no_dtyps)
val (ctxt2, ts3) =
ts2
- |> eta_expand ctxt1
+ |> eta_expand ctxt1 funcs
|> lift_lambdas ctxt1
||> intro_explicit_application
--- a/src/HOL/Tools/SMT/smtlib_interface.ML Fri Dec 31 00:11:24 2010 +0100
+++ b/src/HOL/Tools/SMT/smtlib_interface.ML Mon Jan 03 16:22:08 2011 +0100
@@ -130,10 +130,10 @@
fun sdatatypes decls =
let
- fun con (n, []) = add n
+ fun con (n, []) = sep (add n)
| con (n, sels) = par (add n #>
- fold (fn (n, s) => sep (par (add n #> sep (add s)))) sels)
- fun dtyp (n, decl) = add n #> fold (sep o con) decl
+ fold (fn (n, s) => par (add n #> sep (add s))) sels)
+ fun dtyp (n, decl) = add n #> fold con decl
in line (add ":datatypes " #> par (fold (par o dtyp) decls)) end
fun serialize comments {header, sorts, dtyps, funcs} ts =
--- a/src/HOL/Tools/SMT/z3_proof_reconstruction.ML Fri Dec 31 00:11:24 2010 +0100
+++ b/src/HOL/Tools/SMT/z3_proof_reconstruction.ML Mon Jan 03 16:22:08 2011 +0100
@@ -851,7 +851,8 @@
|> tap (check_after idx r ps prop)
in (thm, (ctxt', Inttab.update (idx, thm) ptab')) end
- val disch_rules = [@{thm allI}, @{thm refl}, @{thm reflexive}]
+ val disch_rules = [@{thm allI}, @{thm refl}, @{thm reflexive},
+ Z3_Proof_Literals.true_thm]
fun all_disch_rules rules = map (pair false) (disch_rules @ rules)
fun disch_assm rules thm =