use proper context operations (for fresh names of type and term variables, and for hypothetical definitions), monomorphize theorems (instead of terms, necessary for hypothetical definitions made during lambda lifting)
authorboehmes
Wed, 12 May 2010 23:54:00 +0200
changeset 36896 c030819254d3
parent 36895 a96f9793d9c5
child 36897 6d1ecdb81ff0
use proper context operations (for fresh names of type and term variables, and for hypothetical definitions), monomorphize theorems (instead of terms, necessary for hypothetical definitions made during lambda lifting)
src/HOL/SMT/SMT_Base.thy
src/HOL/SMT/Tools/cvc3_solver.ML
src/HOL/SMT/Tools/smt_monomorph.ML
src/HOL/SMT/Tools/smt_normalize.ML
src/HOL/SMT/Tools/smt_solver.ML
src/HOL/SMT/Tools/yices_solver.ML
src/HOL/SMT/Tools/z3_proof_reconstruction.ML
src/HOL/SMT/Tools/z3_proof_tools.ML
src/HOL/SMT/Tools/z3_solver.ML
--- a/src/HOL/SMT/SMT_Base.thy	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/SMT_Base.thy	Wed May 12 23:54:00 2010 +0200
@@ -9,8 +9,8 @@
 uses
   "~~/src/Tools/cache_io.ML"
   ("Tools/smt_additional_facts.ML")
+  ("Tools/smt_monomorph.ML")
   ("Tools/smt_normalize.ML")
-  ("Tools/smt_monomorph.ML")
   ("Tools/smt_translate.ML")
   ("Tools/smt_solver.ML")
   ("Tools/smtlib_interface.ML")
@@ -119,8 +119,8 @@
 section {* Setup *}
 
 use "Tools/smt_additional_facts.ML"
+use "Tools/smt_monomorph.ML"
 use "Tools/smt_normalize.ML"
-use "Tools/smt_monomorph.ML"
 use "Tools/smt_translate.ML"
 use "Tools/smt_solver.ML"
 use "Tools/smtlib_interface.ML"
--- a/src/HOL/SMT/Tools/cvc3_solver.ML	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/Tools/cvc3_solver.ML	Wed May 12 23:54:00 2010 +0200
@@ -23,7 +23,7 @@
 
 fun raise_cex real = raise SMT_Solver.SMT_COUNTEREXAMPLE (real, [])
 
-fun core_oracle ({output, ...} : SMT_Solver.proof_data) =
+fun core_oracle (output, _) =
   let
     val empty_line = (fn "" => true | _ => false)
     val split_first = (fn [] => ("", []) | l :: ls => (l, ls))
@@ -39,7 +39,7 @@
   command = {env_var=env_var, remote_name=SOME solver_name},
   arguments = options,
   interface = SMTLIB_Interface.interface,
-  reconstruct = oracle }
+  reconstruct = pair o oracle }
 
 val setup =
   Thm.add_oracle (Binding.name solver_name, core_oracle) #-> (fn (_, oracle) =>
--- a/src/HOL/SMT/Tools/smt_monomorph.ML	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/Tools/smt_monomorph.ML	Wed May 12 23:54:00 2010 +0200
@@ -1,128 +1,201 @@
 (*  Title:      HOL/SMT/Tools/smt_monomorph.ML
     Author:     Sascha Boehme, TU Muenchen
 
-Monomorphization of terms, i.e., computation of all (necessary) instances.
+Monomorphization of theorems, i.e., computation of all (necessary) instances.
 *)
 
 signature SMT_MONOMORPH =
 sig
-  val monomorph: theory -> term list -> term list
+  val monomorph: thm list -> Proof.context -> thm list * Proof.context
 end
 
 structure SMT_Monomorph: SMT_MONOMORPH =
 struct
 
-fun selection [] = []
-  | selection (x :: xs) = (x, xs) :: map (apsnd (cons x)) (selection xs)
-
-fun permute [] = []
-  | permute [x] = [[x]]
-  | permute xs = maps (fn (y, ys) => map (cons y) (permute ys)) (selection xs)
-
-fun fold_all f = fold (fn x => maps (f x))
-
-
 val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)
-val term_has_tvars = Term.exists_type typ_has_tvars
 
 val ignored = member (op =) [
   @{const_name All}, @{const_name Ex}, @{const_name Let}, @{const_name If},
   @{const_name "op ="}, @{const_name zero_class.zero},
   @{const_name one_class.one}, @{const_name number_of}]
-fun consts_of ts = AList.group (op =) (fold Term.add_consts ts [])
-  |> filter_out (ignored o fst)
+
+fun is_const f (n, T) = not (ignored n) andalso f T
+fun add_const_if f g (Const c) = if is_const f c then g c else I
+  | add_const_if _ _ _ = I
+
+fun collect_consts_if f g thm =
+  Term.fold_aterms (add_const_if f g) (Thm.prop_of thm)
+
+fun add_consts f =
+  collect_consts_if f (fn (n, T) => Symtab.map_entry n (insert (op =) T))
+
+val insert_const = OrdList.insert (prod_ord fast_string_ord Term_Ord.typ_ord)
+fun tvar_consts_of thm = collect_consts_if typ_has_tvars insert_const thm []
+
+
+fun incr_indexes thms =
+  let fun inc thm idx = (Thm.incr_indexes idx thm, Thm.maxidx_of thm + idx + 1)
+  in fst (fold_map inc thms 0) end
+
+
+(* Compute all substitutions from the types "Ts" to all relevant
+   types in "grounds", with respect to the given substitution. *)
+fun new_substitutions thy grounds (n, T) subst =
+  if not (typ_has_tvars T) then [subst]
+  else
+    Symtab.lookup_list grounds n
+    |> map_filter (try (fn U => Sign.typ_match thy (T, U) subst))
+    |> cons subst
+
 
-fun join_consts cs ds = AList.join (op =) (K (merge (op =))) (cs, ds)
-fun diff_consts cs ds = 
-  let fun diff (n, Ts) =
-    (case AList.lookup (op =) cs n of
-      NONE => SOME (n, Ts)
-    | SOME Us =>
-        let val Ts' = fold (remove (op =)) Us Ts
-        in if null Ts' then NONE else SOME (n, Ts') end)
-  in map_filter diff ds end
+(* Instantiate a set of constants with a substitution.  Also collect
+   all new ground instances for the next round of specialization. *)
+fun apply_subst grounds consts subst =
+  let
+    fun is_new_ground (n, T) = not (typ_has_tvars T) andalso
+      not (member (op =) (Symtab.lookup_list grounds n) T)
+
+    fun apply_const (n, T) new_grounds =
+      let val c = (n, Envir.subst_type subst T)
+      in
+        new_grounds
+        |> is_new_ground c ? Symtab.insert_list (op =) c
+        |> pair c
+      end
+  in fold_map apply_const consts #>> pair subst end
+
 
-fun instances thy is (n, Ts) env =
+(* Compute new substitutions for the theorem "thm", based on
+   previously found substitutions.
+     Also collect new grounds, i.e., instantiated constants
+   (without schematic types) which do not occur in any of the
+   previous rounds. Note that thus no schematic type variables are
+   shared among theorems. *)
+fun specialize thy all_grounds new_grounds (thm, scs) =
   let
-    val Us = these (AList.lookup (op =) is n)
-    val Ts' = filter typ_has_tvars (map (Envir.subst_type env) Ts)
+    fun spec (subst, consts) next_grounds =
+      [subst]
+      |> fold (maps o new_substitutions thy new_grounds) consts
+      |> rpair next_grounds
+      |-> fold_map (apply_subst all_grounds consts)
   in
-    (case map_product pair Ts' Us of
-      [] => [env]
-    | TUs => map_filter (try (fn TU => Sign.typ_match thy TU env)) TUs)
+    fold_map spec scs #>> (fn scss =>
+    (thm, fold (fold (insert (eq_snd (op =)))) scss []))
   end
 
-fun proper_match ps env =
-  forall (forall (not o typ_has_tvars o Envir.subst_type env) o snd) ps
 
-fun eq_tab (tab1, tab2) = eq_set (op =) (Vartab.dest tab1, Vartab.dest tab2)
+(* Compute all necessary substitutions.
+     Instead of operating on the propositions of the theorems, the
+   computation uses only the constants occurring with schematic type
+   variables in the propositions. To ease comparisons, such sets of
+   costants are always kept in their initial order. *)
+fun incremental_monomorph thy limit all_grounds new_grounds ths =
+  let
+    val all_grounds' = Symtab.merge_list (op =) (all_grounds, new_grounds)
+    val spec = specialize thy all_grounds' new_grounds
+    val (ths', new_grounds') = fold_map spec ths Symtab.empty
+  in
+    if Symtab.is_empty new_grounds' then ths'
+    else if limit > 0
+    then incremental_monomorph thy (limit-1) all_grounds' new_grounds' ths'
+    else (warning "SMT: monomorphization limit reached"; ths')
+  end
 
-fun specialize thy cs is ((r, ps), ces) (ts, ns) =
-  let
-    val ps' = filter (AList.defined (op =) is o fst) ps
 
-    val envs = permute ps'
-      |> maps (fn ps => fold_all (instances thy is) ps [Vartab.empty])
-      |> filter (proper_match ps')
-      |> filter_out (member eq_tab ces)
-      |> distinct eq_tab
+fun filter_most_specific thy =
+  let
+    fun typ_match (_, T) (_, U) = Sign.typ_match thy (T, U)
+
+    fun is_trivial subst = Vartab.is_empty subst orelse
+      forall (fn (v, (S, T)) => TVar (v, S) = T) (Vartab.dest subst)
 
-    val us = map (fn env => Envir.subst_term_types env r) envs
-    val ns' = join_consts (diff_consts is (diff_consts cs (consts_of us))) ns
-  in (envs @ ces, (fold (insert (op aconv)) us ts, ns')) end
+    fun match general specific =
+      (case try (fold2 typ_match general specific) Vartab.empty of
+        NONE => false
+      | SOME subst => not (is_trivial subst))
+
+    fun most_specific _ [] = []
+      | most_specific css ((ss, cs) :: scs) =
+          let val substs = most_specific (cs :: css) scs
+          in
+            if exists (match cs) css orelse exists (match cs o snd) scs
+            then substs else ss :: substs
+          end
+
+  in most_specific [] end
 
 
-fun incr_tvar_indices i t =
+fun instantiate thy Tenv =
   let
-    val incrT = Logic.incr_tvar_same i
+    fun replace (v, (_, T)) (U as TVar (u, _)) = if u = v then T else U
+      | replace _ T = T
+
+    fun complete (vT as (v, _)) subst =
+      subst
+      |> not (Vartab.defined subst v) ? Vartab.update vT
+      |> Vartab.map (apsnd (Term.map_atyps (replace vT)))
+
+    fun cert (ix, (S, T)) = pairself (Thm.ctyp_of thy) (TVar (ix, S), T)
+
+    fun inst thm subst =
+      let val cTs = Vartab.fold (cons o cert) (fold complete Tenv subst) []
+      in Thm.instantiate (cTs, []) thm end
+
+  in uncurry (map o inst) end
+
 
-    fun incr t =
-      (case t of
-        Const (n, T) => Const (n, incrT T)
-      | Free (n, T) => Free (n, incrT T)
-      | Abs (n, T, t1) => (Abs (n, incrT T, incr t1 handle Same.SAME => t1)
-          handle Same.SAME => Abs (n, T, incr t1))
-      | t1 $ t2 => (incr t1 $ (incr t2 handle Same.SAME => t2)
-          handle Same.SAME => t1 $ incr t2)
-      | _ => Same.same t)
-  in incr t handle Same.SAME => t end
+fun mono_all ctxt _ [] monos = (monos, ctxt)
+  | mono_all ctxt limit polys monos =
+      let
+        fun invent_types thm ctxt =
+          let val (vs, Ss) = split_list (Term.add_tvars (Thm.prop_of thm) [])
+          in
+            ctxt
+            |> Variable.invent_types Ss
+            |>> map2 (fn v => fn (n, S) => (v, (S, TFree (n, S)))) vs
+          end
+        val (Tenvs, ctxt') = fold_map invent_types polys ctxt
+
+        val thy = ProofContext.theory_of ctxt'
+
+        val ths = polys
+          |> map (fn thm => (thm, [(Vartab.empty, tvar_consts_of thm)]))
+
+        (* all constant names occurring with schematic types *)
+        val ns = fold (fold (fold (insert (op =) o fst) o snd) o snd) ths []
+
+        (* all known instances with non-schematic types *)
+        val grounds =
+          Symtab.make (map (rpair []) ns)
+          |> fold (add_consts (K true)) monos
+          |> fold (add_consts (not o typ_has_tvars)) polys
+      in
+        polys
+        |> map (fn thm => (thm, [(Vartab.empty, tvar_consts_of thm)]))
+        |> incremental_monomorph thy limit Symtab.empty grounds
+        |> map (apsnd (filter_most_specific thy))
+        |> flat o map2 (instantiate thy) Tenvs
+        |> append monos
+        |> rpair ctxt'
+      end
 
 
 val monomorph_limit = 10
 
-(* Instantiate all polymorphic constants (i.e., constants occurring both with
-   ground types and type variables) with all (necessary) ground types; thereby
-   create copies of terms containing those constants.
-   To prevent non-termination, there is an upper limit for the number of
-   recursions involved in the fixpoint construction. *)
-fun monomorph thy =
-  let
-    fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
-    fun incr_indices ts = fst (fold_map incr ts 0)
 
-    fun with_tvar (n, Ts) =
-      let val Ts' = filter typ_has_tvars Ts
-      in if null Ts' then NONE else SOME (n, Ts') end
-    fun extract_consts_with_tvar t = (t, map_filter with_tvar (consts_of [t]))
-
-    fun mono rps count is ces cs ts =
-      let
-        val spec = specialize thy cs is
-        val (ces', (ts', is')) = fold_map spec (rps ~~ ces) (ts, [])
-        val cs' = join_consts is cs
-      in
-        if null is' then ts'
-        else if count > monomorph_limit then
-          (warning "monomorphization limit reached"; ts')
-        else mono rps (count + 1) is' ces' cs' ts'
-      end
-    fun mono_all rps ms = if null rps then ms
-      else mono rps 0 (consts_of ms) (map (K []) rps) [] ms
-  in
-    List.partition term_has_tvars
-    #>> incr_indices
-    #>> map extract_consts_with_tvar
-    #-> mono_all
-  end
+(* Instantiate all polymorphic constants (i.e., constants occurring
+   both with ground types and type variables) with all (necessary)
+   ground types; thereby create copies of theorems containing those
+   constants.
+     To prevent non-termination, there is an upper limit for the
+   number of recursions involved in the fixpoint construction.
+     The initial set of theorems must not contain any schematic term
+   variables, and the final list of theorems does not contain any
+   schematic type variables anymore. *)
+fun monomorph thms ctxt =
+  thms
+  |> List.partition (Term.exists_type typ_has_tvars o Thm.prop_of)
+  |>> incr_indexes
+  |-> mono_all ctxt monomorph_limit
 
 end
--- a/src/HOL/SMT/Tools/smt_normalize.ML	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/Tools/smt_normalize.ML	Wed May 12 23:54:00 2010 +0200
@@ -16,10 +16,7 @@
 
 signature SMT_NORMALIZE =
 sig
-  val instantiate_free: cterm * cterm -> thm -> thm
-  val discharge_definition: cterm -> thm -> thm
-
-  val normalize: Proof.context -> thm list -> cterm list * thm list
+  val normalize: thm list -> Proof.context -> thm list * Proof.context
 end
 
 structure SMT_Normalize: SMT_NORMALIZE =
@@ -31,18 +28,6 @@
 fun if_conv c cv1 cv2 ct = (if c (Thm.term_of ct) then cv1 else cv2) ct
 fun if_true_conv c cv = if_conv c cv Conv.all_conv
 
-fun instantiate_free (cv, ct) =
-  (Term.exists_subterm (equal (Thm.term_of cv)) o Thm.prop_of) ??
-  (Thm.forall_elim ct o Thm.forall_intr cv)
-
-fun discharge_definition ct thm =
-  let val (cv, cu) = Thm.dest_equals ct
-  in
-    Thm.implies_intr ct thm
-    |> instantiate_free (cv, cu)
-    |> (fn thm => Thm.implies_elim thm (Thm.reflexive cu))
-  end
-
 
 
 (* simplification of trivial distincts (distinct should have at least
@@ -332,35 +317,34 @@
   fun inst_meta cT = Thm.instantiate_cterm ([(meta_eqT, cT)], []) meta_eq
   fun mk_meta_eq ct cu = Thm.mk_binop (inst_meta (Thm.ctyp_of_term ct)) ct cu
 
-  fun norm_meta_def cv thm = 
-    let val thm' = Thm.combination thm (Thm.reflexive cv)
-    in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end
-
   fun cert ctxt = Thm.cterm_of (ProofContext.theory_of ctxt)
 
-  val fresh_name = yield_singleton Name.variants
-
   fun used_vars cvs ct =
     let
       val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs)
-      val add = (fn (SOME ct) => insert (op aconvc) ct | _ => I)
+      val add = (fn SOME ct => insert (op aconvc) ct | _ => I)
     in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end
-  fun make_def cvs eq = Thm.symmetric (fold norm_meta_def cvs eq)
-  fun add_def ct thm = Termtab.update (Thm.term_of ct, (serial (), thm))
 
-  fun replace ctxt cvs ct (cx as (nctxt, defs)) =
+  fun apply cv thm = 
+    let val thm' = Thm.combination thm (Thm.reflexive cv)
+    in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end
+  fun apply_def cvs eq = Thm.symmetric (fold apply cvs eq)
+
+  fun replace_lambda cvs ct (cx as (ctxt, defs)) =
     let
       val cvs' = used_vars cvs ct
       val ct' = fold_rev Thm.cabs cvs' ct
     in
       (case Termtab.lookup defs (Thm.term_of ct') of
-        SOME (_, eq) => (make_def cvs' eq, cx)
+        SOME eq => (apply_def cvs' eq, cx)
       | NONE =>
           let
-            val {T, ...} = Thm.rep_cterm ct'
-            val (n, nctxt') = fresh_name "" nctxt
-            val eq = Thm.assume (mk_meta_eq (cert ctxt (Free (n, T))) ct')
-          in (make_def cvs' eq, (nctxt', add_def ct' eq defs)) end)
+            val {T, ...} = Thm.rep_cterm ct' and n = Name.uu
+            val (n', ctxt') = yield_singleton Variable.variant_fixes n ctxt
+            val cu = mk_meta_eq (cert ctxt (Free (n', T))) ct'
+            val (eq, ctxt'') = yield_singleton Assumption.add_assumes cu ctxt'
+            val defs' = Termtab.update (Thm.term_of ct', eq) defs
+          in (apply_def cvs' eq, (ctxt'', defs')) end)
     end
 
   fun none ct cx = (Thm.reflexive ct, cx)
@@ -368,28 +352,25 @@
     let val (cu1, cu2) = Thm.dest_comb ct
     in cx |> f cu1 ||>> g cu2 |>> uncurry Thm.combination end
   fun in_arg f = in_comb none f
-  fun in_abs f cvs ct (nctxt, defs) =
-    let
-      val (n, nctxt') = fresh_name Name.uu nctxt
-      val (cv, cu) = Thm.dest_abs (SOME n) ct
-    in f (cv :: cvs) cu (nctxt', defs) |>> Thm.abstract_rule n cv end
-
-  fun replace_lambdas ctxt =
+  fun in_abs f cvs ct (ctxt, defs) =
     let
-      fun repl cvs ct =
-        (case Thm.term_of ct of
-          Const (@{const_name All}, _) $ Abs _ => in_arg (in_abs repl cvs)
-        | Const (@{const_name Ex}, _) $ Abs _ => in_arg (in_abs repl cvs)
-        | Const _ $ Abs _ => in_arg (at_lambda cvs)
-        | Const (@{const_name Let}, _) $ _ $ Abs _ =>
-            in_comb (in_arg (repl cvs)) (in_abs repl cvs)
-        | Abs _ => at_lambda cvs
-        | _ $ _ => in_comb (repl cvs) (repl cvs)
-        | _ => none) ct
-      and at_lambda cvs ct =
-        in_abs repl cvs ct #-> (fn thm =>
-        replace ctxt cvs (Thm.rhs_of thm) #>> Thm.transitive thm)
-    in repl [] end
+      val (n, ctxt') = yield_singleton Variable.variant_fixes Name.uu ctxt
+      val (cv, cu) = Thm.dest_abs (SOME n) ct
+    in  (ctxt', defs) |> f (cv :: cvs) cu |>> Thm.abstract_rule n cv end
+
+  fun traverse cvs ct =
+    (case Thm.term_of ct of
+      Const (@{const_name All}, _) $ Abs _ => in_arg (in_abs traverse cvs)
+    | Const (@{const_name Ex}, _) $ Abs _ => in_arg (in_abs traverse cvs)
+    | Const (@{const_name Let}, _) $ _ $ Abs _ =>
+        in_comb (in_arg (traverse cvs)) (in_abs traverse cvs)
+    | Abs _ => at_lambda cvs
+    | _ $ _ => in_comb (traverse cvs) (traverse cvs)
+    | _ => none) ct
+
+  and at_lambda cvs ct =
+    in_abs traverse cvs ct #-> (fn thm =>
+    replace_lambda cvs (Thm.rhs_of thm) #>> Thm.transitive thm)
 
   fun has_free_lambdas t =
     (case t of
@@ -400,26 +381,17 @@
     | Abs _ => true
     | u1 $ u2 => has_free_lambdas u1 orelse has_free_lambdas u2
     | _ => false)
+
+  fun lift_lm f thm cx =
+    if not (has_free_lambdas (Thm.prop_of thm)) then (thm, cx)
+    else cx |> f (Thm.cprop_of thm) |>> (fn thm' => Thm.equal_elim thm' thm)
 in
-fun lift_lambdas ctxt thms =
+fun lift_lambdas thms ctxt =
   let
-    val declare_frees = fold (Thm.fold_terms Term.declare_term_frees)
-    fun rewrite f thm cx =
-      if not (has_free_lambdas (Thm.prop_of thm)) then (thm, cx)
-      else f (Thm.cprop_of thm) cx |>> (fn thm' => Thm.equal_elim thm' thm)
-
-    val rev_int_fst_ord = rev_order o int_ord o pairself fst
-    fun ordered_values tab =
-      Termtab.fold (fn (_, x) => OrdList.insert rev_int_fst_ord x) tab []
-      |> map snd
-
-    val (thms', (_, defs)) =
-      (declare_frees thms (Name.make_context []), Termtab.empty)
-      |> fold_map (rewrite (replace_lambdas ctxt)) thms
-    val eqs = ordered_values defs
-  in
-    (maps (#hyps o Thm.crep_thm) eqs, map (normalize_rule ctxt) eqs @ thms')
-  end
+    val cx = (ctxt, Termtab.empty)
+    val (thms', (ctxt', defs)) = fold_map (lift_lm (traverse [])) thms cx
+    val eqs = Termtab.fold (cons o normalize_rule ctxt' o snd) defs []
+  in (eqs @ thms', ctxt') end
 end
 
 
@@ -483,14 +455,16 @@
 
 (* combined normalization *)
 
-fun normalize ctxt thms =
+fun normalize thms ctxt =
   thms
   |> trivial_distinct ctxt
   |> rewrite_bool_cases ctxt
   |> normalize_numerals ctxt
   |> nat_as_int ctxt
   |> map (unfold_defs ctxt #> normalize_rule ctxt)
-  |> lift_lambdas ctxt
-  |> apsnd (explicit_application ctxt)
+  |> rpair ctxt
+  |-> SMT_Monomorph.monomorph
+  |-> lift_lambdas
+  |-> (fn thms' => `(fn ctxt' => explicit_application ctxt' thms'))
 
 end
--- a/src/HOL/SMT/Tools/smt_solver.ML	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/Tools/smt_solver.ML	Wed May 12 23:54:00 2010 +0200
@@ -9,15 +9,12 @@
   exception SMT of string
   exception SMT_COUNTEREXAMPLE of bool * term list
 
-  type proof_data = {
-    context: Proof.context,
-    output: string list,
-    recon: SMT_Translate.recon }
   type solver_config = {
     command: {env_var: string, remote_name: string option},
     arguments: string list,
     interface: string list -> SMT_Translate.config,
-    reconstruct: proof_data -> thm }
+    reconstruct: (string list * SMT_Translate.recon) -> Proof.context ->
+      thm * Proof.context }
 
   (*options*)
   val timeout: int Config.T
@@ -56,16 +53,12 @@
 exception SMT_COUNTEREXAMPLE of bool * term list
 
 
-type proof_data = {
-  context: Proof.context,
-  output: string list,
-  recon: SMT_Translate.recon }
-
 type solver_config = {
   command: {env_var: string, remote_name: string option},
   arguments: string list,
   interface: string list -> SMT_Translate.config,
-  reconstruct: proof_data -> thm }
+  reconstruct: (string list * SMT_Translate.recon) -> Proof.context ->
+    thm * Proof.context }
 
 
 
@@ -173,12 +166,16 @@
       Pretty.big_list "functions:" (map pretty_term (Symtab.dest terms))])) ()
   end
 
-fun invoke translate_config command arguments ctxt thms =
+fun invoke translate_config command arguments thms ctxt =
   thms
   |> SMT_Translate.translate translate_config ctxt
   ||> tap (trace_recon_data ctxt)
   |>> run_solver ctxt command arguments
-  |> (fn (ls, recon) => {context=ctxt, output=ls, recon=recon})
+  |> rpair ctxt
+
+fun discharge_definitions thm =
+  if Thm.nprems_of thm = 0 then thm
+  else discharge_definitions (@{thm reflexive} RS thm)
 
 fun gen_solver name solver ctxt prems =
   let
@@ -188,10 +185,13 @@
       "arguments:" :: arguments
   in
     SMT_Additional_Facts.add_facts prems
-    |> SMT_Normalize.normalize ctxt 
-    ||> invoke (interface comments) command arguments ctxt
-    ||> reconstruct
-    |-> fold SMT_Normalize.discharge_definition
+    |> rpair ctxt
+    |-> SMT_Normalize.normalize
+    |-> invoke (interface comments) command arguments
+    |-> reconstruct
+    |-> (fn thm => fn ctxt' => thm
+    |> singleton (ProofContext.export ctxt' ctxt)
+    |> discharge_definitions)
   end
 
 
--- a/src/HOL/SMT/Tools/yices_solver.ML	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/Tools/yices_solver.ML	Wed May 12 23:54:00 2010 +0200
@@ -19,7 +19,7 @@
 
 fun raise_cex real = raise SMT_Solver.SMT_COUNTEREXAMPLE (real, [])
 
-fun core_oracle ({output, ...} : SMT_Solver.proof_data) =
+fun core_oracle (output, _) =
   let
     val empty_line = (fn "" => true | _ => false)
     val split_first = (fn [] => ("", []) | l :: ls => (l, ls))
@@ -35,7 +35,7 @@
   command = {env_var=env_var, remote_name=NONE},
   arguments = options,
   interface = SMTLIB_Interface.interface,
-  reconstruct = oracle }
+  reconstruct = pair o oracle }
 
 val setup =
   Thm.add_oracle (Binding.name solver_name, core_oracle) #-> (fn (_, oracle) =>
--- a/src/HOL/SMT/Tools/z3_proof_reconstruction.ML	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/Tools/z3_proof_reconstruction.ML	Wed May 12 23:54:00 2010 +0200
@@ -7,7 +7,8 @@
 signature Z3_PROOF_RECONSTRUCTION =
 sig
   val trace_assms: bool Config.T
-  val reconstruct: Proof.context -> SMT_Translate.recon -> string list -> thm
+  val reconstruct: string list * SMT_Translate.recon -> Proof.context ->
+    thm * Proof.context
   val setup: theory -> theory
 end
 
@@ -118,9 +119,7 @@
 
 (* proof representation *)
 
-datatype proof =
-  Unproved of P.proof_step |
-  Sequent of { hyps: cterm list, thm: theorem }
+datatype proof = Unproved of P.proof_step | Proved of theorem
 
 
 
@@ -156,7 +155,7 @@
 fun prepare_assms unfolds assms =
   let
     val unfolds' = rewrite_rules [L.rewrite_true] unfolds
-    val assms' = rewrite_rules (unfolds' @ prep_rules) assms
+    val assms' = rewrite_rules (union Thm.eq_thm unfolds' prep_rules) assms
   in (unfolds', T.thm_net_of assms') end
 
 fun asserted _ NONE ct = Thm (Thm.assume ct)
@@ -196,8 +195,7 @@
       val ls = L.explode conj false false [t] lit
       val lits' = fold L.insert_lit ls (L.delete_lit lit lits)
 
-      fun upd (Sequent {hyps, thm}) =
-            Sequent {hyps = hyps, thm = Literals (thm_of thm, lits')}
+      fun upd (Proved thm) = Proved (Literals (thm_of thm, lits'))
         | upd p = p
     in (the (L.lookup_lit lits' t), Inttab.map_entry idx upd ptab) end
 
@@ -349,7 +347,7 @@
       SOME thm => thm
     | NONE => raise CTERM ("intro_def", [ct]))
 in
-fun intro_def ct = apsnd Thm (T.make_hyp_def (apply_rule ct))
+fun intro_def ct = T.make_hyp_def (apply_rule ct) #>> Thm
 
 fun apply_def thm =
   get_first (try (fn rule => MetaEq (thm COMP rule))) apply_rules
@@ -590,7 +588,7 @@
   fun kind (Const (@{const_name Ex}, _) $ _) = (sk_ex_rule, I, I)
     | kind (@{term Not} $ (Const (@{const_name All}, _) $ _)) =
         (sk_all_rule, Thm.dest_arg, Thm.capply @{cterm Not})
-    | kind _ = z3_exn "skolemize: no quantifier"
+    | kind t = raise TERM ("skolemize", [t])
 
   fun dest_abs_type (Abs (_, T, _)) = T
     | dest_abs_type t = raise TERM ("dest_abs_type", [t])
@@ -614,22 +612,23 @@
 
   fun transitive f thm = Thm.transitive thm (f (Thm.rhs_of thm))
 
-  fun sk_step (rule, elim) (cv, mct, cb) (is, thm) =
+  fun sk_step (rule, elim) (cv, mct, cb) ((is, thm), ctxt) =
     (case mct of
       SOME ct =>
-        T.make_hyp_def (inst_sk rule (Thm.instantiate_cterm ([], is) cb) ct)
-        |> apsnd (pair ((cv, ct) :: is) o Thm.transitive thm)
-    | NONE => ([], (is, transitive (Conv.rewr_conv elim) thm)))
+        ctxt
+        |> T.make_hyp_def (inst_sk rule (Thm.instantiate_cterm ([], is) cb) ct)
+        |>> pair ((cv, ct) :: is) o Thm.transitive thm
+    | NONE => ((is, transitive (Conv.rewr_conv elim) thm), ctxt))
 in
-fun skolemize ctxt ct =
+fun skolemize ct ctxt =
   let
     val (lhs, rhs) = Thm.dest_binop (Thm.dest_arg ct)
     val (rule, (ctab, cbs)) = bodies_of (ProofContext.theory_of ctxt) lhs rhs
     fun lookup_var (cv, cb) = (cv, AList.lookup (op aconvc) ctab cv, cb)
   in
-    ([], Thm.reflexive lhs)
-    |> fold_map (sk_step rule) (map lookup_var cbs)
-    |> apfst (rev o flat) o apsnd (MetaEq o snd)
+    (([], Thm.reflexive lhs), ctxt)
+    |> fold (sk_step rule) (map lookup_var cbs)
+    |>> MetaEq o snd
   end
 end
 
@@ -702,14 +701,14 @@
   fun count_rules ptab =
     let
       fun count (_, Unproved _) (solved, total) = (solved, total + 1)
-        | count (_, Sequent _) (solved, total) = (solved + 1, total + 1)
+        | count (_, Proved _) (solved, total) = (solved + 1, total + 1)
     in Inttab.fold count ptab (0, 0) end
 
   fun header idx r (solved, total) = 
     "Z3: #" ^ string_of_int idx ^ ": " ^ P.string_of_rule r ^ " (goal " ^
     string_of_int (solved + 1) ^ " of " ^ string_of_int total ^ ")"
 
-  fun check ctxt idx r ps ct ((_, p), _) =
+  fun check ctxt idx r ps ct p =
     let val thm = thm_of p |> tap (Thm.join_proofs o single)
     in
       if (Thm.cprop_of thm) aconvc ct then ()
@@ -720,12 +719,12 @@
             Syntax.pretty_term ctxt (Thm.term_of ct)]])))
     end
 in
-fun trace_rule ctxt idx prove r ps ct ptab =
+fun trace_rule idx prove r ps ct (cxp as (ctxt, ptab)) =
   let
     val _ = SMT_Solver.trace_msg ctxt (header idx r o count_rules) ptab
-    val result = prove r ps ct ptab
-    val _ = if not (Config.get ctxt SMT_Solver.trace) then ()
-      else check ctxt idx r ps ct result
+    val result as (p, cxp' as (ctxt', _)) = prove r ps ct cxp
+    val _ = if not (Config.get ctxt' SMT_Solver.trace) then ()
+      else check ctxt' idx r ps ct p
   in result end
 end
 
@@ -733,96 +732,87 @@
 (* overall reconstruction procedure *)
 
 fun not_supported r =
-  z3_exn ("proof rule not implemented: " ^ quote (P.string_of_rule r))
+  raise Fail ("Z3: proof rule not implemented: " ^ quote (P.string_of_rule r))
 
 fun prove ctxt unfolds assms vars =
   let
     val assms' = Option.map (prepare_assms unfolds) assms
     val simpset = T.make_simpset ctxt (Z3_Simps.get ctxt)
 
-    fun step r ps ct ptab =
+    fun step r ps ct (cxp as (cx, ptab)) =
       (case (r, ps) of
         (* core rules *)
-        (P.TrueAxiom, _) => (([], Thm L.true_thm), ptab)
-      | (P.Asserted, _) => (([], asserted ctxt assms' ct), ptab)
-      | (P.Goal, _) => (([], asserted ctxt assms' ct), ptab)
-      | (P.ModusPonens, [(p, _), (q, _)]) => (([], mp q (thm_of p)), ptab)
-      | (P.ModusPonensOeq, [(p, _), (q, _)]) => (([], mp q (thm_of p)), ptab)
-      | (P.AndElim, [(p, i)]) => apfst (pair []) (and_elim (p, i) ct ptab)
-      | (P.NotOrElim, [(p, i)]) => apfst (pair []) (not_or_elim (p, i) ct ptab)
-      | (P.Hypothesis, _) => (([], Thm (Thm.assume ct)), ptab)
-      | (P.Lemma, [(p, _)]) => (([], lemma (thm_of p) ct), ptab)
+        (P.TrueAxiom, _) => (Thm L.true_thm, cxp)
+      | (P.Asserted, _) => (asserted cx assms' ct, cxp)
+      | (P.Goal, _) => (asserted cx assms' ct, cxp)
+      | (P.ModusPonens, [(p, _), (q, _)]) => (mp q (thm_of p), cxp)
+      | (P.ModusPonensOeq, [(p, _), (q, _)]) => (mp q (thm_of p), cxp)
+      | (P.AndElim, [(p, i)]) => and_elim (p, i) ct ptab ||> pair cx
+      | (P.NotOrElim, [(p, i)]) => not_or_elim (p, i) ct ptab ||> pair cx
+      | (P.Hypothesis, _) => (Thm (Thm.assume ct), cxp)
+      | (P.Lemma, [(p, _)]) => (lemma (thm_of p) ct, cxp)
       | (P.UnitResolution, (p, _) :: ps) =>
-          (([], unit_resolution (thm_of p) (map (thm_of o fst) ps) ct), ptab)
-      | (P.IffTrue, [(p, _)]) => (([], iff_true (thm_of p)), ptab)
-      | (P.IffFalse, [(p, _)]) => (([], iff_false (thm_of p)), ptab)
-      | (P.Distributivity, _) => (([], distributivity ctxt ct), ptab)
-      | (P.DefAxiom, _) => (([], def_axiom ctxt ct), ptab)
-      | (P.IntroDef, _) => (intro_def ct, ptab)
-      | (P.ApplyDef, [(p, _)]) => (([], apply_def (thm_of p)), ptab)
-      | (P.IffOeq, [(p, _)]) => (([], p), ptab)
-      | (P.NnfPos, _) => (([], nnf ctxt vars (map fst ps) ct), ptab)
-      | (P.NnfNeg, _) => (([], nnf ctxt vars (map fst ps) ct), ptab)
+          (unit_resolution (thm_of p) (map (thm_of o fst) ps) ct, cxp)
+      | (P.IffTrue, [(p, _)]) => (iff_true (thm_of p), cxp)
+      | (P.IffFalse, [(p, _)]) => (iff_false (thm_of p), cxp)
+      | (P.Distributivity, _) => (distributivity cx ct, cxp)
+      | (P.DefAxiom, _) => (def_axiom cx ct, cxp)
+      | (P.IntroDef, _) => intro_def ct cx ||> rpair ptab
+      | (P.ApplyDef, [(p, _)]) => (apply_def (thm_of p), cxp)
+      | (P.IffOeq, [(p, _)]) => (p, cxp)
+      | (P.NnfPos, _) => (nnf cx vars (map fst ps) ct, cxp)
+      | (P.NnfNeg, _) => (nnf cx vars (map fst ps) ct, cxp)
 
         (* equality rules *)
-      | (P.Reflexivity, _) => (([], refl ct), ptab)
-      | (P.Symmetry, [(p, _)]) => (([], symm p), ptab)
-      | (P.Transitivity, [(p, _), (q, _)]) => (([], trans p q), ptab)
-      | (P.Monotonicity, _) => (([], monotonicity (map fst ps) ct), ptab)
-      | (P.Commutativity, _) => (([], commutativity ct), ptab)
+      | (P.Reflexivity, _) => (refl ct, cxp)
+      | (P.Symmetry, [(p, _)]) => (symm p, cxp)
+      | (P.Transitivity, [(p, _), (q, _)]) => (trans p q, cxp)
+      | (P.Monotonicity, _) => (monotonicity (map fst ps) ct, cxp)
+      | (P.Commutativity, _) => (commutativity ct, cxp)
 
         (* quantifier rules *)
-      | (P.QuantIntro, [(p, _)]) => (([], quant_intro vars p ct), ptab)
-      | (P.PullQuant, _) => (([], pull_quant ctxt ct), ptab)
-      | (P.PushQuant, _) => (([], push_quant ctxt ct), ptab)
-      | (P.ElimUnusedVars, _) => (([], elim_unused_vars ctxt ct), ptab)
-      | (P.DestEqRes, _) => (([], dest_eq_res ctxt ct), ptab)
-      | (P.QuantInst, _) => (([], quant_inst ct), ptab)
-      | (P.Skolemize, _) => (skolemize ctxt ct, ptab)
+      | (P.QuantIntro, [(p, _)]) => (quant_intro vars p ct, cxp)
+      | (P.PullQuant, _) => (pull_quant cx ct, cxp)
+      | (P.PushQuant, _) => (push_quant cx ct, cxp)
+      | (P.ElimUnusedVars, _) => (elim_unused_vars cx ct, cxp)
+      | (P.DestEqRes, _) => (dest_eq_res cx ct, cxp)
+      | (P.QuantInst, _) => (quant_inst ct, cxp)
+      | (P.Skolemize, _) => skolemize ct cx ||> rpair ptab
 
         (* theory rules *)
       | (P.ThLemma, _) =>
-          (([], th_lemma ctxt simpset (map (thm_of o fst) ps) ct), ptab)
-      | (P.Rewrite, _) => (([], rewrite ctxt simpset [] ct), ptab)
+          (th_lemma cx simpset (map (thm_of o fst) ps) ct, cxp)
+      | (P.Rewrite, _) => (rewrite cx simpset [] ct, cxp)
       | (P.RewriteStar, ps) =>
-          (([], rewrite ctxt simpset (map fst ps) ct), ptab)
+          (rewrite cx simpset (map fst ps) ct, cxp)
 
       | (P.NnfStar, _) => not_supported r
       | (P.CnfStar, _) => not_supported r
       | (P.TransitivityStar, _) => not_supported r
       | (P.PullQuantStar, _) => not_supported r
 
-      | _ => z3_exn ("Proof rule " ^ quote (P.string_of_rule r) ^
+      | _ => raise Fail ("Z3: proof rule " ^ quote (P.string_of_rule r) ^
          " has an unexpected number of arguments."))
 
-    fun eq_hyp_def (ct, cu) = Thm.dest_arg1 ct aconvc Thm.dest_arg1 cu
-      (* compare only the defined Frees, not the whole definitions *)
+    fun conclude idx rule prop (ps, cxp) =
+      trace_rule idx step rule ps prop cxp
+      |-> (fn p => apsnd (Inttab.update (idx, Proved p)) #> pair p)
 
-    fun conclude idx rule prop ((hypss, ps), ptab) =
-      trace_rule ctxt idx step rule ps prop ptab
-      |>> apfst (distinct eq_hyp_def o fold append hypss)
-
-    fun add_sequent idx (hyps, thm) ptab =
-      ((hyps, thm), Inttab.update (idx, Sequent {hyps=hyps, thm=thm}) ptab)
-
-    fun lookup idx ptab =
+    fun lookup idx (cxp as (cx, ptab)) =
       (case Inttab.lookup ptab idx of
         SOME (Unproved (P.Proof_Step {rule, prems, prop})) =>
-          fold_map lookup prems ptab
-          |>> split_list
-          |>> apsnd (fn ps => ps ~~ prems)
+          fold_map lookup prems cxp
+          |>> map2 rpair prems
           |> conclude idx rule prop
-          |-> add_sequent idx
-      | SOME (Sequent {hyps, thm}) => ((hyps, thm), ptab)
+      | SOME (Proved p) => (p, cxp)
       | NONE => z3_exn ("unknown proof id: " ^ quote (string_of_int idx)))
 
-    fun result (hyps, thm) =
-      fold SMT_Normalize.discharge_definition hyps (thm_of thm)
+    fun result (p, (cx, _)) = (thm_of p, cx)
   in
-    (fn (idx, ptab) => result (fst (lookup idx (Inttab.map Unproved ptab))))
+    (fn (idx, ptab) => result (lookup idx (ctxt, Inttab.map Unproved ptab)))
   end
 
-fun reconstruct ctxt {typs, terms, unfolds, assms} output =
+fun reconstruct (output, {typs, terms, unfolds, assms}) ctxt =
   P.parse ctxt typs terms output
   |> (fn (idx, (ptab, vars, cx)) => prove cx unfolds assms vars (idx, ptab))
 
--- a/src/HOL/SMT/Tools/z3_proof_tools.ML	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/Tools/z3_proof_tools.ML	Wed May 12 23:54:00 2010 +0200
@@ -24,7 +24,7 @@
   val unfold_eqs: Proof.context -> thm list -> conv
   val match_instantiate: (cterm -> cterm) -> cterm -> thm -> thm
   val by_tac: (int -> tactic) -> cterm -> thm
-  val make_hyp_def: thm -> cterm list * thm
+  val make_hyp_def: thm -> Proof.context -> thm * Proof.context
   val by_abstraction: Proof.context -> thm list -> (Proof.context -> cterm ->
     thm) -> cterm -> thm
 
@@ -103,7 +103,7 @@
 fun by_tac tac ct = Goal.norm_result (Goal.prove_internal [] ct (K (tac 1)))
 
 (* |- c x == t x ==> P (c x)  ~~>  c == t |- P (c x) *) 
-fun make_hyp_def thm =
+fun make_hyp_def thm ctxt =
   let
     val (lhs, rhs) = Thm.dest_binop (Thm.cprem_of thm 1)
     val (cf, cvs) = Drule.strip_comb lhs
@@ -111,7 +111,10 @@
     fun apply cv th =
       Thm.combination th (Thm.reflexive cv)
       |> Conv.fconv_rule (Conv.arg_conv (Thm.beta_conversion false))
-  in ([eq], Thm.implies_elim thm (fold apply cvs (Thm.assume eq))) end
+  in
+    yield_singleton Assumption.add_assumes eq ctxt
+    |>> Thm.implies_elim thm o fold apply cvs
+  end
 
 
 
@@ -336,7 +339,7 @@
 in
 
 fun make_simpset ctxt rules = Simplifier.context ctxt (HOL_ss
-  addsimps @{thms ring_distribs} addsimps @{thms field_simps}
+  addsimps @{thms field_simps}
   addsimps [@{thm times_divide_eq_right}, @{thm times_divide_eq_left}]
   addsimps @{thms arith_special} addsimps @{thms less_bin_simps}
   addsimps @{thms le_bin_simps} addsimps @{thms eq_bin_simps}
--- a/src/HOL/SMT/Tools/z3_solver.ML	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/Tools/z3_solver.ML	Wed May 12 23:54:00 2010 +0200
@@ -43,7 +43,7 @@
   let val cex = Z3_Model.parse_counterex recon ls
   in raise SMT_Solver.SMT_COUNTEREXAMPLE (real, cex) end
 
-fun check_unsat recon output =
+fun if_unsat f (output, recon) =
   let
     fun jnk l =
       String.isPrefix "WARNING" l orelse
@@ -51,27 +51,23 @@
       forall Symbol.is_ascii_blank (Symbol.explode l)
     val (ls, l) = the_default ([], "") (try split_last (filter_out jnk output))
   in
-    if String.isPrefix "unsat" l then ls
+    if String.isPrefix "unsat" l then f (ls, recon)
     else if String.isPrefix "sat" l then raise_cex true recon ls
     else if String.isPrefix "unknown" l then raise_cex false recon ls
     else raise SMT_Solver.SMT (solver_name ^ " failed")
   end
 
-fun core_oracle ({output, recon, ...} : SMT_Solver.proof_data) =
-  check_unsat recon output
-  |> K @{cprop False}
+val core_oracle = if_unsat (K @{cprop False})
 
-fun prover ({context, output, recon} : SMT_Solver.proof_data) =
-  check_unsat recon output
-  |> Z3_Proof_Reconstruction.reconstruct context recon
+val prover = if_unsat Z3_Proof_Reconstruction.reconstruct
 
 fun solver oracle ctxt =
   let val with_proof = Config.get ctxt proofs
   in
-    {command = {env_var=env_var, remote_name=SOME solver_name},
+   {command = {env_var=env_var, remote_name=SOME solver_name},
     arguments = cmdline_options ctxt,
     interface = Z3_Interface.interface,
-    reconstruct = if with_proof then prover else oracle}
+    reconstruct = if with_proof then prover else pair o oracle}
   end
 
 val setup =