src/HOL/Tools/case_translation.ML
changeset 51677 d2b3372e6033
parent 51676 d602caf11e48
child 51678 1e33b81c328a
--- a/src/HOL/Tools/case_translation.ML	Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/Tools/case_translation.ML	Fri Apr 05 22:08:42 2013 +0200
@@ -15,7 +15,8 @@
   val lookup_by_case: Proof.context -> string -> (term * term list) option
   val make_case:  Proof.context -> config -> Name.context -> term -> (term * term) list -> term
   val print_case_translations: Proof.context -> unit
-  val strip_case: Proof.context -> bool -> term -> term
+  val strip_case: Proof.context -> bool -> term -> (term * (term * term) list) option
+  val strip_case_full: Proof.context -> bool -> term -> term
   val show_cases: bool Config.T
   val setup: theory -> theory
   val register: term -> term list -> Context.generic -> Context.generic
@@ -501,25 +502,25 @@
 
 (* destruct nested patterns *)
 
-fun encode_clause S T (pat, rhs) =
+fun encode_clause recur S T (pat, rhs) =
   fold (fn x as (_, U) => fn t =>
     Const (@{const_name case_abs}, (U --> T) --> T) $ Term.absfree x t)
       (Term.add_frees pat [])
-      (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ rhs);
+      (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ recur rhs);
 
-fun encode_cases S T [] = Const (@{const_name case_nil}, S --> T)
-  | encode_cases S T (p :: ps) =
+fun encode_cases _ S T [] = Const (@{const_name case_nil}, S --> T)
+  | encode_cases recur S T (p :: ps) =
       Const (@{const_name case_cons}, (S --> T) --> (S --> T) --> S --> T) $
-        encode_clause S T p $ encode_cases S T ps;
+        encode_clause recur S T p $ encode_cases recur S T ps;
 
-fun encode_case (t, ps as (pat, rhs) :: _) =
+fun encode_case recur (t, ps as (pat, rhs) :: _) =
     let
       val T = fastype_of rhs;
     in
       Const (@{const_name case_guard}, T --> T) $
-        (encode_cases (fastype_of pat) (fastype_of rhs) ps $ t)
+        (encode_cases recur (fastype_of pat) (fastype_of rhs) ps $ t)
     end
-  | encode_case _ = case_error "encode_case";
+  | encode_case _ _ = case_error "encode_case";
 
 fun strip_case' ctxt d (pat, rhs) =
   (case dest_case ctxt d (Term.declare_term_frees pat Name.context) rhs of
@@ -535,14 +536,21 @@
 
 fun strip_case ctxt d t =
   (case dest_case ctxt d Name.context t of
-    SOME (x, clauses) => encode_case (x, maps (strip_case' ctxt d) clauses)
+    SOME (x, clauses) => SOME (x, maps (strip_case' ctxt d) clauses)
+  | NONE => NONE);
+
+fun strip_case_full ctxt d t =
+  (case dest_case ctxt d Name.context t of
+    SOME (x, clauses) =>
+      encode_case (strip_case_full ctxt d)
+        (strip_case_full ctxt d x, maps (strip_case' ctxt d) clauses)
   | NONE =>
-    (case t of
-      (t $ u) => strip_case ctxt d t $ strip_case ctxt d u
-    | (Abs (x, T, u)) =>
-        let val (x', u') = Term.dest_abs (x, T, u);
-        in Term.absfree (x', T) (strip_case ctxt d u') end
-    | _ => t));
+      (case t of
+        (t $ u) => strip_case_full ctxt d t $ strip_case_full ctxt d u
+      | (Abs (x, T, u)) =>
+          let val (x', u') = Term.dest_abs (x, T, u);
+          in Term.absfree (x', T) (strip_case_full ctxt d u') end
+      | _ => t));
 
 
 (* term uncheck *)
@@ -550,7 +558,7 @@
 val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true);
 
 fun uncheck_case ctxt ts =
-  if Config.get ctxt show_cases then map (strip_case ctxt true) ts else ts;
+  if Config.get ctxt show_cases then map (strip_case_full ctxt true) ts else ts;
 
 val term_uncheck_setup =
   Context.theory_map (Syntax_Phases.term_uncheck 1 "case" uncheck_case);