--- a/src/HOL/Inductive.thy Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/Inductive.thy Sat Apr 06 01:42:07 2013 +0200
@@ -275,11 +275,16 @@
ML_file "Tools/Datatype/datatype_data.ML" setup Datatype_Data.setup
consts
- case_guard :: "bool \<Rightarrow> 'a \<Rightarrow> 'a"
+ case_guard :: "bool \<Rightarrow> 'a \<Rightarrow> ('a \<Rightarrow> 'b) \<Rightarrow> 'b"
case_nil :: "'a \<Rightarrow> 'b"
case_cons :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b"
case_elem :: "'a \<Rightarrow> 'b \<Rightarrow> 'a \<Rightarrow> 'b"
case_abs :: "('c \<Rightarrow> 'b) \<Rightarrow> 'b"
+declare [[coercion_args case_guard - + -]]
+declare [[coercion_args case_cons - -]]
+declare [[coercion_args case_abs -]]
+declare [[coercion_args case_elem - +]]
+
ML_file "Tools/case_translation.ML"
setup Case_Translation.setup
--- a/src/HOL/Tools/case_translation.ML Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/Tools/case_translation.ML Sat Apr 06 01:42:07 2013 +0200
@@ -126,11 +126,11 @@
val errt = if err then @{term True} else @{term False};
in
- Syntax.const @{const_syntax case_guard} $ errt $ (fold_rev
+ Syntax.const @{const_syntax case_guard} $ errt $ t $ (fold_rev
(fn t => fn u =>
Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
(dest_case2 u)
- (Syntax.const @{const_syntax case_nil}) $ t)
+ (Syntax.const @{const_syntax case_nil}))
end
| case_tr _ _ _ = case_error "case_tr";
@@ -142,9 +142,8 @@
(* print translation *)
-fun case_tr' [_, tx] =
+fun case_tr' [_, x, t] =
let
- val (t, x) = Term.dest_comb tx;
fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
let val (s', used') = Name.variant s used
in mk_clause t ((s', T) :: xs) used' end
@@ -412,7 +411,7 @@
fun check_case ctxt =
let
- fun decode_case (Const (@{const_name case_guard}, _) $ b $ (t $ u)) =
+ fun decode_case (Const (@{const_name case_guard}, _) $ b $ u $ t) =
make_case ctxt (if b = @{term True} then Error else Warning)
Name.context (decode_case u) (decode_cases t)
| decode_case (t $ u) = decode_case t $ decode_case u
@@ -518,10 +517,11 @@
fun encode_case recur (t, ps as (pat, rhs) :: _) =
let
+ val tT = fastype_of t;
val T = fastype_of rhs;
in
- Const (@{const_name case_guard}, @{typ bool} --> T --> T) $ @{term True} $
- (encode_cases recur (fastype_of pat) (fastype_of rhs) ps $ t)
+ Const (@{const_name case_guard}, @{typ bool} --> tT --> (tT --> T) --> T) $
+ @{term True} $ t $ (encode_cases recur (fastype_of pat) (fastype_of rhs) ps)
end
| encode_case _ _ = case_error "encode_case";