case translations performed in a separate check phase (with adjustments by traytel)
authorberghofe
Tue, 22 Jan 2013 13:32:41 +0100
changeset 51672 d5c5e088ebdf
parent 51671 0d142a78fb7c
child 51673 4dfa00e264d8
case translations performed in a separate check phase (with adjustments by traytel)
src/HOL/Inductive.thy
src/HOL/List.thy
src/HOL/Tools/Datatype/datatype_case.ML
src/HOL/Tools/Datatype/rep_datatype.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
src/HOL/Tools/Quickcheck/exhaustive_generators.ML
src/HOL/Tools/Quickcheck/narrowing_generators.ML
--- a/src/HOL/Inductive.thy	Wed Apr 10 13:10:38 2013 +0200
+++ b/src/HOL/Inductive.thy	Tue Jan 22 13:32:41 2013 +0100
@@ -273,7 +273,14 @@
 ML_file "Tools/Datatype/datatype_aux.ML"
 ML_file "Tools/Datatype/datatype_prop.ML"
 ML_file "Tools/Datatype/datatype_data.ML" setup Datatype_Data.setup
+
+consts
+  case_nil :: "'a \<Rightarrow> 'b"
+  case_cons :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b"
+  case_elem :: "'a \<Rightarrow> 'b \<Rightarrow> 'a \<Rightarrow> 'b"
+  case_abs :: "('c \<Rightarrow> 'b) \<Rightarrow> 'b"
 ML_file "Tools/Datatype/datatype_case.ML" setup Datatype_Case.setup
+
 ML_file "Tools/Datatype/rep_datatype.ML"
 ML_file "Tools/Datatype/datatype_codegen.ML" setup Datatype_Codegen.setup
 ML_file "Tools/Datatype/primrec.ML"
@@ -290,7 +297,7 @@
   fun fun_tr ctxt [cs] =
     let
       val x = Syntax.free (fst (Name.variant "x" (Term.declare_term_frees cs Name.context)));
-      val ft = Datatype_Case.case_tr true ctxt [x, cs];
+      val ft = Datatype_Case.case_tr ctxt [x, cs];
     in lambda x ft end
 in [(@{syntax_const "_lam_pats_syntax"}, fun_tr)] end
 *}
--- a/src/HOL/List.thy	Wed Apr 10 13:10:38 2013 +0200
+++ b/src/HOL/List.thy	Tue Jan 22 13:32:41 2013 +0100
@@ -407,7 +407,7 @@
           Syntax.const @{syntax_const "_case1"} $
             Syntax.const @{const_syntax dummy_pattern} $ NilC;
         val cs = Syntax.const @{syntax_const "_case2"} $ case1 $ case2;
-      in Syntax_Trans.abs_tr [x, Datatype_Case.case_tr false ctxt [x, cs]] end;
+      in Syntax_Trans.abs_tr [x, Datatype_Case.case_tr ctxt [x, cs]] end;
 
     fun abs_tr ctxt p e opti =
       (case Term_Position.strip_positions p of
--- a/src/HOL/Tools/Datatype/datatype_case.ML	Wed Apr 10 13:10:38 2013 +0200
+++ b/src/HOL/Tools/Datatype/datatype_case.ML	Tue Jan 22 13:32:41 2013 +0100
@@ -5,12 +5,6 @@
 Datatype package: nested case expressions on datatypes.
 
 TODO:
-  * Avoid fragile operations on syntax trees (with type constraints
-    getting in the way).  Instead work with auxiliary "destructor"
-    constants in translations and introduce the actual case
-    combinators in a separate term check phase (similar to term
-    abbreviations).
-
   * Avoid hard-wiring with datatype package.  Instead provide generic
     generic declarations of case splits based on an internal data slot.
 *)
@@ -19,12 +13,10 @@
 sig
   datatype config = Error | Warning | Quiet
   type info = Datatype_Aux.info
-  val make_case :  Proof.context -> config -> string list -> term -> (term * term) list -> term
-  val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option
-  val case_tr: bool -> Proof.context -> term list -> term
+  val case_tr: Proof.context -> term list -> term
+  val make_case:  Proof.context -> config -> Name.context -> term -> (term * term) list -> term
+  val strip_case: Proof.context -> bool -> term -> term
   val show_cases: bool Config.T
-  val case_tr': string -> Proof.context -> term list -> term
-  val add_case_tr' : string list -> theory -> theory
   val setup: theory -> theory
 end;
 
@@ -36,7 +28,8 @@
 
 exception CASE_ERROR of string * int;
 
-fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty;
+fun match_type ctxt pat ob =
+  Sign.typ_match (Proof_Context.theory_of ctxt) (pat, ob) Vartab.empty;
 
 (* Get information about datatypes *)
 
@@ -57,101 +50,77 @@
 completion.*)
 
 fun add_row_used ((prfx, pats), (tm, tag)) =
-  fold Term.add_free_names (tm :: pats @ map Free prfx);
+  fold Term.declare_term_frees (tm :: pats @ map Free prfx);
 
-fun default_name name (t, cs) =
-  let
-    val name' = if name = "" then (case t of Free (name', _) => name' | _ => name) else name;
-    val cs' = if is_Free t then cs else filter_out Term_Position.is_position cs;
-  in (name', cs') end;
-
-fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) =
-      strip_constraints t ||> cons tT
-  | strip_constraints t = (t, []);
-
-fun constrain tT t = Syntax.const @{syntax_const "_constrain"} $ t $ tT;
-fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
+(* try to preserve names given by user *)
+fun default_name "" (Free (name', _)) = name'
+  | default_name name _ = name;
 
 
 (*Produce an instance of a constructor, plus fresh variables for its arguments.*)
-fun fresh_constr ty_match ty_inst colty used c =
+fun fresh_constr ctxt colty used c =
   let
     val (_, T) = dest_Const c;
     val Ts = binder_types T;
-    val names =
-      Name.variant_list used (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts));
+    val (names, _) = fold_map Name.variant
+      (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)) used;
     val ty = body_type T;
-    val ty_theta = ty_match ty colty
+    val ty_theta = match_type ctxt ty colty
       handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1);
-    val c' = ty_inst ty_theta c;
-    val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts);
+    val c' = Envir.subst_term_types ty_theta c;
+    val gvars = map (Envir.subst_term_types ty_theta o Free) (names ~~ Ts);
   in (c', gvars) end;
 
-fun strip_comb_positions tm =
-  let
-    fun result t ts = (Term_Position.strip_positions t, ts);
-    fun strip (t as Const (@{syntax_const "_constrain"}, _) $ _ $ _) ts = result t ts
-      | strip (f $ t) ts = strip f (t :: ts)
-      | strip t ts = result t ts;
-  in strip tm [] end;
-
 (*Go through a list of rows and pick out the ones beginning with a
   pattern with constructor = name.*)
 fun mk_group (name, T) rows =
   let val k = length (binder_types T) in
     fold (fn (row as ((prfx, p :: ps), rhs as (_, i))) =>
-      fn ((in_group, not_in_group), (names, cnstrts)) =>
-        (case strip_comb_positions p of
+      fn ((in_group, not_in_group), names) =>
+        (case strip_comb p of
           (Const (name', _), args) =>
             if name = name' then
               if length args = k then
-                let
-                  val constraints' = map strip_constraints args;
-                  val (args', cnstrts') = split_list constraints';
-                  val (names', cnstrts'') = split_list (map2 default_name names constraints');
-                in
-                  ((((prfx, args' @ ps), rhs) :: in_group, not_in_group),
-                   (names', map2 append cnstrts cnstrts''))
-                end
+                ((((prfx, args @ ps), rhs) :: in_group, not_in_group),
+                 map2 default_name names args)
               else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i)
-            else ((in_group, row :: not_in_group), (names, cnstrts))
+            else ((in_group, row :: not_in_group), names)
         | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
-    rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
+    rows (([], []), replicate k "") |>> pairself rev
   end;
 
 
 (* Partitioning *)
 
-fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
-  | partition ty_match ty_inst type_of used constructors colty res_ty
+fun partition _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
+  | partition ctxt used constructors colty res_ty
         (rows as (((prfx, _ :: ps), _) :: _)) =
       let
         fun part [] [] = []
           | part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i)
           | part (c :: cs) rows =
               let
-                val ((in_group, not_in_group), (names, cnstrts)) = mk_group (dest_Const c) rows;
+                val ((in_group, not_in_group), names) = mk_group (dest_Const c) rows;
                 val used' = fold add_row_used in_group used;
-                val (c', gvars) = fresh_constr ty_match ty_inst colty used' c;
+                val (c', gvars) = fresh_constr ctxt colty used' c;
                 val in_group' =
                   if null in_group  (* Constructor not given *)
                   then
                     let
-                      val Ts = map type_of ps;
-                      val xs =
-                        Name.variant_list
-                          (fold Term.add_free_names gvars used')
-                          (replicate (length ps) "x");
+                      val Ts = map fastype_of ps;
+                      val (xs, _) =
+                        fold_map Name.variant
+                          (replicate (length ps) "x")
+                          (fold Term.declare_term_frees gvars used');
                     in
                       [((prfx, gvars @ map Free (xs ~~ Ts)),
-                        (Const (@{const_syntax undefined}, res_ty), ~1))]
+                        (Const (@{const_name undefined}, res_ty), ~1))]
                     end
                   else in_group;
               in
                 {constructor = c',
                  new_formals = gvars,
                  names = names,
-                 constraints = cnstrts,
                  group = in_group'} :: part cs not_in_group
               end;
       in part constructors rows end;
@@ -162,7 +131,7 @@
 
 (* Translation of pattern terms into nested case expressions. *)
 
-fun mk_case ctxt ty_match ty_inst type_of used range_ty =
+fun mk_case ctxt used range_ty =
   let
     val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt);
 
@@ -172,19 +141,19 @@
             let
               val used' = add_row_used row used;
               fun expnd c =
-                let val capp = list_comb (fresh_constr ty_match ty_inst ty used' c)
+                let val capp = list_comb (fresh_constr ctxt ty used' c)
                 in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end;
             in map expnd constructors end
           else [row];
 
-    val name = singleton (Name.variant_list used) "a";
+    val (name, _) = Name.variant "a" used;
 
     fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
       | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *)
       | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
       | mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
           let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
-            (case Option.map (apfst (fst o strip_comb_positions))
+            (case Option.map (apfst head_of)
                 (find_first (not o is_Free o fst) col0) of
               NONE =>
                 let
@@ -197,21 +166,20 @@
                 | SOME {case_name, constructors} =>
                     let
                       val pty = body_type cT;
-                      val used' = fold Term.add_free_names us used;
+                      val used' = fold Term.declare_term_frees us used;
                       val nrows = maps (expand constructors used' pty) rows;
                       val subproblems =
-                        partition ty_match ty_inst type_of used'
-                          constructors pty range_ty nrows;
+                        partition ctxt used' constructors pty range_ty nrows;
                       val (pat_rect, dtrees) =
                         split_list (map (fn {new_formals, group, ...} =>
                           mk (new_formals @ us) group) subproblems);
                       val case_functions =
-                        map2 (fn {new_formals, names, constraints, ...} =>
-                          fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
-                            Abs (if s = "" then name else s, T, abstract_over (x, t))
-                            |> fold constrain_Abs cnstrts) (new_formals ~~ names ~~ constraints))
+                        map2 (fn {new_formals, names, ...} =>
+                          fold_rev (fn (x as Free (_, T), s) => fn t =>
+                            Abs (if s = "" then name else s, T, abstract_over (x, t)))
+                              (new_formals ~~ names))
                         subproblems dtrees;
-                      val types = map type_of (case_functions @ [u]);
+                      val types = map fastype_of (case_functions @ [u]);
                       val case_const = Const (case_name, types ---> range_ty);
                       val tree = list_comb (case_const, case_functions @ [u]);
                     in (flat pat_rect, tree) end)
@@ -221,9 +189,19 @@
       | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1)
   in mk end;
 
-fun case_error s = error ("Error in case expression:\n" ^ s);
 
-local
+(* replace occurrences of dummy_pattern by distinct variables *)
+fun replace_dummies (Const (@{const_name dummy_pattern}, T)) used =
+      let val (x, used') = Name.variant "x" used
+      in (Free (x, T), used') end
+  | replace_dummies (t $ u) used =
+      let
+        val (t', used') = replace_dummies t used;
+        val (u', used'') = replace_dummies u used';
+      in (t' $ u', used'') end
+  | replace_dummies t used = (t, used);
+
+fun case_error s = error ("Error in case expression:\n" ^ s);
 
 (*Repeated variable occurrences in a pattern are not allowed.*)
 fun no_repeat_vars ctxt pat = fold_aterms
@@ -233,22 +211,28 @@
           case_error (quote s ^ " occurs repeatedly in the pattern " ^
             quote (Syntax.string_of_term ctxt pat))
         else x :: xs)
-    | _ => I) (Term_Position.strip_positions pat) [];
+    | _ => I) pat [];
 
-fun gen_make_case ty_match ty_inst type_of ctxt config used x clauses =
+fun make_case ctxt config used x clauses =
   let
     fun string_of_clause (pat, rhs) =
       Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs);
     val _ = map (no_repeat_vars ctxt o fst) clauses;
-    val rows = map_index (fn (i, (pat, rhs)) => (([], [pat]), (rhs, i))) clauses;
+    val (rows, used') = used |>
+      fold (fn (pat, rhs) =>
+        Term.declare_term_frees pat #> Term.declare_term_frees rhs) clauses |>
+      fold_map (fn (i, (pat, rhs)) => fn used =>
+        let val (pat', used') = replace_dummies pat used
+        in ((([], [pat']), (rhs, i)), used') end)
+          (map_index I clauses);
     val rangeT =
-      (case distinct (op =) (map (type_of o snd) clauses) of
+      (case distinct (op =) (map (fastype_of o snd) clauses) of
         [] => case_error "no clauses given"
       | [T] => T
       | _ => case_error "all cases must have the same result type");
     val used' = fold add_row_used rows used;
     val (tags, case_tm) =
-      mk_case ctxt ty_match ty_inst type_of used' rangeT [x] rows
+      mk_case ctxt used' rangeT [x] rows
         handle CASE_ERROR (msg, i) =>
           case_error
             (msg ^ (if i < 0 then "" else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
@@ -263,77 +247,88 @@
     case_tm
   end;
 
-in
+
+(* term check *)
+
+fun decode_clause (Const (@{const_name case_abs}, _) $ Abs (s, T, t)) xs used =
+      let val (s', used') = Name.variant s used
+      in decode_clause t (Free (s', T) :: xs) used' end
+  | decode_clause (Const (@{const_name case_elem}, _) $ t $ u) xs _ =
+      (subst_bounds (xs, t), subst_bounds (xs, u))
+  | decode_clause _ _ _ = case_error "decode_clause";
+
+fun decode_cases (Const (@{const_name case_nil}, _)) = []
+  | decode_cases (Const (@{const_name case_cons}, _) $ t $ u) =
+      decode_clause t [] (Term.declare_term_frees t Name.context) ::
+      decode_cases u
+  | decode_cases _ = case_error "decode_cases";
 
-fun make_case ctxt =
-  gen_make_case (match_type (Proof_Context.theory_of ctxt))
-    Envir.subst_term_types fastype_of ctxt;
+fun check_case ctxt =
+  let
+    fun decode_case ((t as Const (@{const_name case_cons}, _) $ _ $ _) $ u) =
+        make_case ctxt Error Name.context (decode_case u) (decode_cases t)
+    | decode_case (t $ u) = decode_case t $ decode_case u
+    | decode_case (Abs (x, T, u)) =
+        let val (x', u') = Term.dest_abs (x, T, u);
+        in Term.absfree (x', T) (decode_case u') end
+    | decode_case t = t;
+  in
+    map decode_case
+  end;
 
-val make_case_untyped =
-  gen_make_case (K (K Vartab.empty)) (K (Term.map_types (K dummyT))) (K dummyT);
-
-end;
+val term_check_setup =
+  Context.theory_map (Syntax_Phases.term_check 1 "case" check_case);
 
 
 (* parse translation *)
 
-fun case_tr err ctxt [t, u] =
+fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
+
+fun case_tr ctxt [t, u] =
       let
         val thy = Proof_Context.theory_of ctxt;
-        val intern_const_syntax = Consts.intern_syntax (Proof_Context.consts_of ctxt);
+
+        fun is_const s =
+          Sign.declared_const thy (Proof_Context.intern_const ctxt s);
+
+        fun abs p tTs t = Syntax.const @{const_syntax case_abs} $
+          fold constrain_Abs tTs (absfree p t);
 
-        (* replace occurrences of dummy_pattern by distinct variables *)
-        (* internalize constant names                                 *)
-        (* FIXME proper name context!? *)
-        fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used =
-              let val (t', used') = prep_pat t used
-              in (c $ t' $ tT, used') end
-          | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used =
-              let val x = singleton (Name.variant_list used) "x"
-              in (Free (x, T), x :: used) end
-          | prep_pat (Const (s, T)) used = (Const (intern_const_syntax s, T), used)
-          | prep_pat (v as Free (s, T)) used =
-              let val s' = Proof_Context.intern_const ctxt s in
-                if Sign.declared_const thy s' then (Const (s', T), used)
-                else (v, used)
-              end
-          | prep_pat (t $ u) used =
-              let
-                val (t', used') = prep_pat t used;
-                val (u', used'') = prep_pat u used';
-              in (t' $ u', used'') end
-          | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
+        fun abs_pat (Const ("_constrain", _) $ t $ tT) tTs = abs_pat t (tT :: tTs)
+          | abs_pat (Free (p as (x, _))) tTs =
+              if is_const x then I else abs p tTs
+          | abs_pat (t $ u) _ = abs_pat u [] #> abs_pat t []
+          | abs_pat _ _ = I;
 
-        fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
-              let val (l', cnstrts) = strip_constraints l
-              in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end
-          | dest_case1 t = case_error "dest_case1";
+        fun dest_case1 (Const (@{syntax_const "_case1"}, _) $ l $ r) =
+              abs_pat l []
+                (Syntax.const @{const_syntax case_elem} $ Term_Position.strip_positions l $ r)
+          | dest_case1 _ = case_error "dest_case1";
 
         fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
           | dest_case2 t = [t];
-
-        val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
       in
-        make_case_untyped ctxt
-          (if err then Error else Warning) []
-          (fold constrain (filter_out Term_Position.is_position (flat cnstrts)) t)
-          cases
+        fold_rev
+          (fn t => fn u =>
+             Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
+          (dest_case2 u)
+          (Syntax.const @{const_syntax case_nil}) $ t
       end
-  | case_tr _ _ _ = case_error "case_tr";
+  | case_tr _ _ = case_error "case_tr";
 
 val trfun_setup =
   Sign.add_advanced_trfuns ([],
-    [(@{syntax_const "_case_syntax"}, case_tr true)],
+    [(@{syntax_const "_case_syntax"}, case_tr)],
     [], []);
 
 
 (* Pretty printing of nested case expressions *)
 
+val name_of = try (dest_Const #> fst);
+
 (* destruct one level of pattern matching *)
 
-local
-
-fun gen_dest_case name_of type_of ctxt d used t =
+fun dest_case ctxt d used t =
   (case apfst name_of (strip_comb t) of
     (SOME cname, ts as _ :: _) =>
       let
@@ -371,12 +366,12 @@
                     val k = length Us;
                     val p as (xs, _) = strip_abs k Us t;
                   in
-                    (Const (s, map type_of xs ---> type_of x), p, is_dependent k t)
+                    (Const (s, map fastype_of xs ---> fastype_of x), p, is_dependent k t)
                   end) (constructors ~~ fs);
                 val cases' =
                   sort (int_ord o swap o pairself (length o snd))
                     (fold_rev count_cases cases []);
-                val R = type_of t;
+                val R = fastype_of t;
                 val dummy =
                   if d then Term.dummy_pattern R
                   else Free (Name.variant "x" used |> fst, R);
@@ -403,81 +398,93 @@
       end
   | _ => NONE);
 
-in
-
-val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of;
-val dest_case' = gen_dest_case (try (dest_Const #> fst #> Lexicon.unmark_const)) (K dummyT);
-
-end;
-
 
 (* destruct nested patterns *)
 
-local
+fun encode_clause S T (pat, rhs) =
+  fold (fn x as (_, U) => fn t =>
+    Const (@{const_name case_abs}, (U --> T) --> T) $ Term.absfree x t)
+      (Term.add_frees pat [])
+      (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ rhs);
 
-fun strip_case'' dest (pat, rhs) =
-  (case dest (Term.declare_term_frees pat Name.context) rhs of
+fun encode_cases S T [] = Const (@{const_name case_nil}, S --> T)
+  | encode_cases S T (p :: ps) =
+      Const (@{const_name case_cons}, (S --> T) --> (S --> T) --> S --> T) $
+        encode_clause S T p $ encode_cases S T ps;
+
+fun encode_case (t, ps as (pat, rhs) :: _) =
+      encode_cases (fastype_of pat) (fastype_of rhs) ps $ t
+  | encode_case _ = case_error "encode_case";
+
+fun strip_case' ctxt d (pat, rhs) =
+  (case dest_case ctxt d (Term.declare_term_frees pat Name.context) rhs of
     SOME (exp as Free _, clauses) =>
       if Term.exists_subterm (curry (op aconv) exp) pat andalso
         not (exists (fn (_, rhs') =>
           Term.exists_subterm (curry (op aconv) exp) rhs') clauses)
       then
-        maps (strip_case'' dest) (map (fn (pat', rhs') =>
+        maps (strip_case' ctxt d) (map (fn (pat', rhs') =>
           (subst_free [(exp, pat')] pat, rhs')) clauses)
       else [(pat, rhs)]
   | _ => [(pat, rhs)]);
 
-fun gen_strip_case dest t =
-  (case dest Name.context t of
-    SOME (x, clauses) => SOME (x, maps (strip_case'' dest) clauses)
-  | NONE => NONE);
+fun strip_case ctxt d t =
+  (case dest_case ctxt d Name.context t of
+    SOME (x, clauses) => encode_case (x, maps (strip_case' ctxt d) clauses)
+  | NONE =>
+    (case t of
+      (t $ u) => strip_case ctxt d t $ strip_case ctxt d u
+    | (Abs (x, T, u)) =>
+        let val (x', u') = Term.dest_abs (x, T, u);
+        in Term.absfree (x', T) (strip_case ctxt d u') end
+    | _ => t));
 
-in
+
+(* term uncheck *)
+
+val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true);
 
-val strip_case = gen_strip_case oo dest_case;
-val strip_case' = gen_strip_case oo dest_case';
+fun uncheck_case ctxt ts =
+  if Config.get ctxt show_cases then map (strip_case ctxt true) ts else ts;
 
-end;
+val term_uncheck_setup =
+  Context.theory_map (Syntax_Phases.term_uncheck 1 "case" uncheck_case);
 
 
 (* print translation *)
 
-val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true);
+fun case_tr' [t, u, x] =
+      let
+        fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
+              let val (s', used') = Name.variant s used
+              in mk_clause t ((s', T) :: xs) used' end
+          | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) xs _ =
+              Syntax.const @{syntax_const "_case1"} $
+                subst_bounds (map Syntax_Trans.mark_bound_abs xs, pat) $
+                subst_bounds (map Syntax_Trans.mark_bound_body xs, rhs);
 
-fun case_tr' cname ctxt ts =
-  if Config.get ctxt show_cases then
-    let
-      fun mk_clause (pat, rhs) =
-        let val xs = Term.add_frees pat [] in
-          Syntax.const @{syntax_const "_case1"} $
-            map_aterms
-              (fn Free p => Syntax_Trans.mark_bound_abs p
-                | Const (s, _) => Syntax.const (Lexicon.mark_const s)
-                | t => t) pat $
-            map_aterms
-              (fn x as Free v =>
-                  if member (op =) xs v then Syntax_Trans.mark_bound_body v else x
-                | t => t) rhs
-        end;
-    in
-      (case strip_case' ctxt true (list_comb (Syntax.const cname, ts)) of
-        SOME (x, clauses) =>
-          Syntax.const @{syntax_const "_case_syntax"} $ x $
-            foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
-              (map mk_clause clauses)
-      | NONE => raise Match)
-    end
-  else raise Match;
+        fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
+          | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
+              mk_clauses' t u
+        and mk_clauses' t u =
+              mk_clause t [] (Term.declare_term_frees t Name.context) ::
+              mk_clauses u
+      in
+        Syntax.const @{syntax_const "_case_syntax"} $ x $
+          foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
+            (mk_clauses' t u)
+      end;
 
-fun add_case_tr' case_names thy =
-  Sign.add_advanced_trfuns ([], [],
-    map (fn case_name =>
-      let val case_name' = Lexicon.mark_const case_name
-      in (case_name', case_tr' case_name') end) case_names, []) thy;
+val trfun_setup' = Sign.add_trfuns
+  ([], [], [(@{const_syntax "case_cons"}, case_tr')], []);
 
 
 (* theory setup *)
 
-val setup = trfun_setup;
+val setup =
+  trfun_setup #>
+  trfun_setup' #>
+  term_check_setup #>
+  term_uncheck_setup;
 
 end;
--- a/src/HOL/Tools/Datatype/rep_datatype.ML	Wed Apr 10 13:10:38 2013 +0200
+++ b/src/HOL/Tools/Datatype/rep_datatype.ML	Tue Jan 22 13:32:41 2013 +0100
@@ -536,7 +536,6 @@
     |> snd
     |> Datatype_Data.register dt_infos
     |> Datatype_Data.interpretation_data (config, dt_names)
-    |> Datatype_Case.add_case_tr' case_names
     |> pair dt_names
   end;
 
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Wed Apr 10 13:10:38 2013 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Tue Jan 22 13:32:41 2013 +0100
@@ -640,7 +640,7 @@
     val v = Free (name, T);
     val v' = Free (name', T);
   in
-    lambda v (Datatype_Case.make_case ctxt Datatype_Case.Quiet [] v
+    lambda v (Datatype_Case.make_case ctxt Datatype_Case.Quiet Name.context v
       [(HOLogic.mk_tuple out_ts,
         if null eqs'' then success_t
         else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
@@ -916,7 +916,7 @@
         in
           (pattern, compilation)
         end
-        val switch = Datatype_Case.make_case ctxt Datatype_Case.Quiet [] inp_var
+        val switch = Datatype_Case.make_case ctxt Datatype_Case.Quiet Name.context inp_var
           ((map compile_single_case switched_clauses) @
             [(xt, mk_empty compfuns (HOLogic.mk_tupleT outTs))])
       in
--- a/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Wed Apr 10 13:10:38 2013 +0200
+++ b/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Tue Jan 22 13:32:41 2013 +0100
@@ -292,8 +292,9 @@
                 val bound_vars' = union (op =) (vars_of lhs) (union (op =) varnames bound_vars)
                 val cont_t = mk_smart_test_term' concl bound_vars' (new_assms @ assms) genuine
               in
-                mk_test (vars_of lhs, Datatype_Case.make_case ctxt Datatype_Case.Quiet [] lhs
-                  [(list_comb (constr, vars), cont_t), (dummy_var, none_t)])
+                mk_test (vars_of lhs,
+                  Datatype_Case.make_case ctxt Datatype_Case.Quiet Name.context lhs
+                    [(list_comb (constr, vars), cont_t), (dummy_var, none_t)])
               end
             else c (assm, assms)
         fun default (assm, assms) =
--- a/src/HOL/Tools/Quickcheck/narrowing_generators.ML	Wed Apr 10 13:10:38 2013 +0200
+++ b/src/HOL/Tools/Quickcheck/narrowing_generators.ML	Tue Jan 22 13:32:41 2013 +0100
@@ -427,7 +427,7 @@
   end
 
 fun mk_case_term ctxt p ((@{const_name Ex}, (x, T)) :: qs') (Existential_Counterexample cs) =
-    Datatype_Case.make_case ctxt Datatype_Case.Quiet [] (Free (x, T)) (map (fn (t, c) =>
+    Datatype_Case.make_case ctxt Datatype_Case.Quiet Name.context (Free (x, T)) (map (fn (t, c) =>
       (t, mk_case_term ctxt (p - 1) qs' c)) cs)
   | mk_case_term ctxt p ((@{const_name All}, _) :: qs') (Universal_Counterexample (t, c)) =
     if p = 0 then t else mk_case_term ctxt (p - 1) qs' c