disallow coercions to interfere with case translations
authortraytel
Sat, 06 Apr 2013 01:42:07 +0200
changeset 51679 e7316560928b
parent 51678 1e33b81c328a
child 51680 8b8cd5a527bc
disallow coercions to interfere with case translations
src/HOL/Inductive.thy
src/HOL/Tools/case_translation.ML
--- a/src/HOL/Inductive.thy	Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/Inductive.thy	Sat Apr 06 01:42:07 2013 +0200
@@ -275,11 +275,16 @@
 ML_file "Tools/Datatype/datatype_data.ML" setup Datatype_Data.setup
 
 consts
-  case_guard :: "bool \<Rightarrow> 'a \<Rightarrow> 'a"
+  case_guard :: "bool \<Rightarrow> 'a \<Rightarrow> ('a \<Rightarrow> 'b) \<Rightarrow> 'b"
   case_nil :: "'a \<Rightarrow> 'b"
   case_cons :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b"
   case_elem :: "'a \<Rightarrow> 'b \<Rightarrow> 'a \<Rightarrow> 'b"
   case_abs :: "('c \<Rightarrow> 'b) \<Rightarrow> 'b"
+declare [[coercion_args case_guard - + -]]
+declare [[coercion_args case_cons - -]]
+declare [[coercion_args case_abs -]]
+declare [[coercion_args case_elem - +]]
+
 ML_file "Tools/case_translation.ML"
 setup Case_Translation.setup
 
--- a/src/HOL/Tools/case_translation.ML	Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/Tools/case_translation.ML	Sat Apr 06 01:42:07 2013 +0200
@@ -126,11 +126,11 @@
 
         val errt = if err then @{term True} else @{term False};
       in
-        Syntax.const @{const_syntax case_guard} $ errt $ (fold_rev
+        Syntax.const @{const_syntax case_guard} $ errt $ t $ (fold_rev
           (fn t => fn u =>
              Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
           (dest_case2 u)
-          (Syntax.const @{const_syntax case_nil}) $ t)
+          (Syntax.const @{const_syntax case_nil}))
       end
   | case_tr _ _ _ = case_error "case_tr";
 
@@ -142,9 +142,8 @@
 
 (* print translation *)
 
-fun case_tr' [_, tx] =
+fun case_tr' [_, x, t] =
       let
-        val (t, x) = Term.dest_comb tx;
         fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
               let val (s', used') = Name.variant s used
               in mk_clause t ((s', T) :: xs) used' end
@@ -412,7 +411,7 @@
 
 fun check_case ctxt =
   let
-    fun decode_case (Const (@{const_name case_guard}, _) $ b $ (t $ u)) =
+    fun decode_case (Const (@{const_name case_guard}, _) $ b $ u $ t) =
           make_case ctxt (if b = @{term True} then Error else Warning)
             Name.context (decode_case u) (decode_cases t)
       | decode_case (t $ u) = decode_case t $ decode_case u
@@ -518,10 +517,11 @@
 
 fun encode_case recur (t, ps as (pat, rhs) :: _) =
     let
+      val tT = fastype_of t;
       val T = fastype_of rhs;
     in
-      Const (@{const_name case_guard}, @{typ bool} --> T --> T) $ @{term True} $
-        (encode_cases recur (fastype_of pat) (fastype_of rhs) ps $ t)
+      Const (@{const_name case_guard}, @{typ bool} --> tT --> (tT --> T) --> T) $
+        @{term True} $ t $ (encode_cases recur (fastype_of pat) (fastype_of rhs) ps)
     end
   | encode_case _ _ = case_error "encode_case";