recur in the expression to be matched (do not rely on repetitive execution of a check phase);
separate ML-interface function that is not recurring (strip_case)
--- a/src/HOL/Tools/case_translation.ML Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/Tools/case_translation.ML Fri Apr 05 22:08:42 2013 +0200
@@ -15,7 +15,8 @@
val lookup_by_case: Proof.context -> string -> (term * term list) option
val make_case: Proof.context -> config -> Name.context -> term -> (term * term) list -> term
val print_case_translations: Proof.context -> unit
- val strip_case: Proof.context -> bool -> term -> term
+ val strip_case: Proof.context -> bool -> term -> (term * (term * term) list) option
+ val strip_case_full: Proof.context -> bool -> term -> term
val show_cases: bool Config.T
val setup: theory -> theory
val register: term -> term list -> Context.generic -> Context.generic
@@ -501,25 +502,25 @@
(* destruct nested patterns *)
-fun encode_clause S T (pat, rhs) =
+fun encode_clause recur S T (pat, rhs) =
fold (fn x as (_, U) => fn t =>
Const (@{const_name case_abs}, (U --> T) --> T) $ Term.absfree x t)
(Term.add_frees pat [])
- (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ rhs);
+ (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ recur rhs);
-fun encode_cases S T [] = Const (@{const_name case_nil}, S --> T)
- | encode_cases S T (p :: ps) =
+fun encode_cases _ S T [] = Const (@{const_name case_nil}, S --> T)
+ | encode_cases recur S T (p :: ps) =
Const (@{const_name case_cons}, (S --> T) --> (S --> T) --> S --> T) $
- encode_clause S T p $ encode_cases S T ps;
+ encode_clause recur S T p $ encode_cases recur S T ps;
-fun encode_case (t, ps as (pat, rhs) :: _) =
+fun encode_case recur (t, ps as (pat, rhs) :: _) =
let
val T = fastype_of rhs;
in
Const (@{const_name case_guard}, T --> T) $
- (encode_cases (fastype_of pat) (fastype_of rhs) ps $ t)
+ (encode_cases recur (fastype_of pat) (fastype_of rhs) ps $ t)
end
- | encode_case _ = case_error "encode_case";
+ | encode_case _ _ = case_error "encode_case";
fun strip_case' ctxt d (pat, rhs) =
(case dest_case ctxt d (Term.declare_term_frees pat Name.context) rhs of
@@ -535,14 +536,21 @@
fun strip_case ctxt d t =
(case dest_case ctxt d Name.context t of
- SOME (x, clauses) => encode_case (x, maps (strip_case' ctxt d) clauses)
+ SOME (x, clauses) => SOME (x, maps (strip_case' ctxt d) clauses)
+ | NONE => NONE);
+
+fun strip_case_full ctxt d t =
+ (case dest_case ctxt d Name.context t of
+ SOME (x, clauses) =>
+ encode_case (strip_case_full ctxt d)
+ (strip_case_full ctxt d x, maps (strip_case' ctxt d) clauses)
| NONE =>
- (case t of
- (t $ u) => strip_case ctxt d t $ strip_case ctxt d u
- | (Abs (x, T, u)) =>
- let val (x', u') = Term.dest_abs (x, T, u);
- in Term.absfree (x', T) (strip_case ctxt d u') end
- | _ => t));
+ (case t of
+ (t $ u) => strip_case_full ctxt d t $ strip_case_full ctxt d u
+ | (Abs (x, T, u)) =>
+ let val (x', u') = Term.dest_abs (x, T, u);
+ in Term.absfree (x', T) (strip_case_full ctxt d u') end
+ | _ => t));
(* term uncheck *)
@@ -550,7 +558,7 @@
val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true);
fun uncheck_case ctxt ts =
- if Config.get ctxt show_cases then map (strip_case ctxt true) ts else ts;
+ if Config.get ctxt show_cases then map (strip_case_full ctxt true) ts else ts;
val term_uncheck_setup =
Context.theory_map (Syntax_Phases.term_uncheck 1 "case" uncheck_case);