Made gen_dest_case more robust against eta contraction
authorberghofe
Fri, 30 Dec 2011 18:12:00 +0100
changeset 46059 f805747f8571
parent 46057 8664713db181
child 46060 f94b7179a75d
Made gen_dest_case more robust against eta contraction
src/HOL/Tools/Datatype/datatype_case.ML
--- a/src/HOL/Tools/Datatype/datatype_case.ML	Fri Dec 30 11:11:57 2011 +0100
+++ b/src/HOL/Tools/Datatype/datatype_case.ML	Fri Dec 30 18:12:00 2011 +0100
@@ -310,22 +310,25 @@
 
 local
 
-(* FIXME proper name context!? *)
 fun gen_dest_case name_of type_of ctxt d used t =
   (case apfst name_of (strip_comb t) of
     (SOME cname, ts as _ :: _) =>
       let
         val (fs, x) = split_last ts;
-        fun strip_abs i t =
+        fun strip_abs i Us t =
           let
             val zs = strip_abs_vars t;
-            val _ = if length zs < i then raise CASE_ERROR ("", 0) else ();
-            val (xs, ys) = chop i zs;
+            val j = length zs;
+            val (xs, ys) =
+              if j < i then (zs @ map (pair "x") (drop j Us), [])
+              else chop i zs;
             val u = list_abs (ys, strip_abs_body t);
-            val xs' =
-              map Free
-                (Name.variant_list (Misc_Legacy.add_term_names (u, used)) (map #1 xs) ~~ map #2 xs);
-          in (xs', subst_bounds (rev xs', u)) end;
+            val xs' = map Free
+              ((fold_map Name.variant (map fst xs)
+                  (Term.declare_term_names u used) |> fst) ~~
+               map snd xs);
+            val (xs1, xs2) = chop j xs'
+          in (xs', list_comb (subst_bounds (rev xs1, u), xs2)) end;
         fun is_dependent i t =
           let val k = length (strip_abs_vars t) - i
           in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end;
@@ -341,8 +344,9 @@
               let
                 val cases = map (fn (Const (s, U), t) =>
                   let
-                    val k = length (binder_types U);
-                    val p as (xs, _) = strip_abs k t;
+                    val Us = binder_types U;
+                    val k = length Us;
+                    val p as (xs, _) = strip_abs k Us t;
                   in
                     (Const (s, map type_of xs ---> type_of x), p, is_dependent k t)
                   end) (constructors ~~ fs);
@@ -352,7 +356,7 @@
                 val R = type_of t;
                 val dummy =
                   if d then Term.dummy_pattern R
-                  else Free (singleton (Name.variant_list used) "x", R);
+                  else Free (Name.variant "x" used |> fst, R);
               in
                 SOME (x,
                   map mk_case
@@ -370,7 +374,7 @@
                             else
                               filter_out (fn (c, _, _) => member op aconv cs c) cases @
                                 [(dummy, ([], default), false)])))
-              end handle CASE_ERROR _ => NONE
+              end
             else NONE
         | _ => NONE)
       end
@@ -389,7 +393,7 @@
 local
 
 fun strip_case'' dest (pat, rhs) =
-  (case dest (Term.add_free_names pat []) rhs of
+  (case dest (Term.declare_term_frees pat Name.context) rhs of
     SOME (exp as Free _, clauses) =>
       if Term.exists_subterm (curry (op aconv) exp) pat andalso
         not (exists (fn (_, rhs') =>
@@ -401,7 +405,7 @@
   | _ => [(pat, rhs)]);
 
 fun gen_strip_case dest t =
-  (case dest [] t of
+  (case dest Name.context t of
     SOME (x, clauses) => SOME (x, maps (strip_case'' dest) clauses)
   | NONE => NONE);