diff -r d602caf11e48 -r d2b3372e6033 src/HOL/Tools/case_translation.ML --- a/src/HOL/Tools/case_translation.ML Fri Apr 05 22:08:42 2013 +0200 +++ b/src/HOL/Tools/case_translation.ML Fri Apr 05 22:08:42 2013 +0200 @@ -15,7 +15,8 @@ val lookup_by_case: Proof.context -> string -> (term * term list) option val make_case: Proof.context -> config -> Name.context -> term -> (term * term) list -> term val print_case_translations: Proof.context -> unit - val strip_case: Proof.context -> bool -> term -> term + val strip_case: Proof.context -> bool -> term -> (term * (term * term) list) option + val strip_case_full: Proof.context -> bool -> term -> term val show_cases: bool Config.T val setup: theory -> theory val register: term -> term list -> Context.generic -> Context.generic @@ -501,25 +502,25 @@ (* destruct nested patterns *) -fun encode_clause S T (pat, rhs) = +fun encode_clause recur S T (pat, rhs) = fold (fn x as (_, U) => fn t => Const (@{const_name case_abs}, (U --> T) --> T) $ Term.absfree x t) (Term.add_frees pat []) - (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ rhs); + (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ recur rhs); -fun encode_cases S T [] = Const (@{const_name case_nil}, S --> T) - | encode_cases S T (p :: ps) = +fun encode_cases _ S T [] = Const (@{const_name case_nil}, S --> T) + | encode_cases recur S T (p :: ps) = Const (@{const_name case_cons}, (S --> T) --> (S --> T) --> S --> T) $ - encode_clause S T p $ encode_cases S T ps; + encode_clause recur S T p $ encode_cases recur S T ps; -fun encode_case (t, ps as (pat, rhs) :: _) = +fun encode_case recur (t, ps as (pat, rhs) :: _) = let val T = fastype_of rhs; in Const (@{const_name case_guard}, T --> T) $ - (encode_cases (fastype_of pat) (fastype_of rhs) ps $ t) + (encode_cases recur (fastype_of pat) (fastype_of rhs) ps $ t) end - | encode_case _ = case_error "encode_case"; + | encode_case _ _ = case_error "encode_case"; fun strip_case' ctxt d (pat, rhs) = (case dest_case ctxt d (Term.declare_term_frees pat Name.context) rhs of @@ -535,14 +536,21 @@ fun strip_case ctxt d t = (case dest_case ctxt d Name.context t of - SOME (x, clauses) => encode_case (x, maps (strip_case' ctxt d) clauses) + SOME (x, clauses) => SOME (x, maps (strip_case' ctxt d) clauses) + | NONE => NONE); + +fun strip_case_full ctxt d t = + (case dest_case ctxt d Name.context t of + SOME (x, clauses) => + encode_case (strip_case_full ctxt d) + (strip_case_full ctxt d x, maps (strip_case' ctxt d) clauses) | NONE => - (case t of - (t $ u) => strip_case ctxt d t $ strip_case ctxt d u - | (Abs (x, T, u)) => - let val (x', u') = Term.dest_abs (x, T, u); - in Term.absfree (x', T) (strip_case ctxt d u') end - | _ => t)); + (case t of + (t $ u) => strip_case_full ctxt d t $ strip_case_full ctxt d u + | (Abs (x, T, u)) => + let val (x', u') = Term.dest_abs (x, T, u); + in Term.absfree (x', T) (strip_case_full ctxt d u') end + | _ => t)); (* term uncheck *) @@ -550,7 +558,7 @@ val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true); fun uncheck_case ctxt ts = - if Config.get ctxt show_cases then map (strip_case ctxt true) ts else ts; + if Config.get ctxt show_cases then map (strip_case_full ctxt true) ts else ts; val term_uncheck_setup = Context.theory_map (Syntax_Phases.term_uncheck 1 "case" uncheck_case);