src/HOL/Tools/case_translation.ML
changeset 51678 1e33b81c328a
parent 51677 d2b3372e6033
child 51679 e7316560928b
--- a/src/HOL/Tools/case_translation.ML	Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/Tools/case_translation.ML	Fri Apr 05 22:08:42 2013 +0200
@@ -9,7 +9,7 @@
 signature CASE_TRANSLATION =
 sig
   datatype config = Error | Warning | Quiet
-  val case_tr: Proof.context -> term list -> term
+  val case_tr: bool -> Proof.context -> term list -> term
   val lookup_by_constr: Proof.context -> string * typ -> (term * term list) option
   val lookup_by_constr_permissive: Proof.context -> string * typ -> (term * term list) option
   val lookup_by_case: Proof.context -> string -> (term * term list) option
@@ -100,7 +100,7 @@
 
 fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
 
-fun case_tr ctxt [t, u] =
+fun case_tr err ctxt [t, u] =
       let
         val thy = Proof_Context.theory_of ctxt;
 
@@ -123,24 +123,26 @@
 
         fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
           | dest_case2 t = [t];
+
+        val errt = if err then @{term True} else @{term False};
       in
-        Syntax.const @{const_syntax case_guard} $ (fold_rev
+        Syntax.const @{const_syntax case_guard} $ errt $ (fold_rev
           (fn t => fn u =>
              Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
           (dest_case2 u)
           (Syntax.const @{const_syntax case_nil}) $ t)
       end
-  | case_tr _ _ = case_error "case_tr";
+  | case_tr _ _ _ = case_error "case_tr";
 
 val trfun_setup =
   Sign.add_advanced_trfuns ([],
-    [(@{syntax_const "_case_syntax"}, case_tr)],
+    [(@{syntax_const "_case_syntax"}, case_tr true)],
     [], []);
 
 
 (* print translation *)
 
-fun case_tr' [tx] =
+fun case_tr' [_, tx] =
       let
         val (t, x) = Term.dest_comb tx;
         fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
@@ -410,8 +412,9 @@
 
 fun check_case ctxt =
   let
-    fun decode_case (Const (@{const_name case_guard}, _) $ (t $ u)) =
-          make_case ctxt Error Name.context (decode_case u) (decode_cases t)
+    fun decode_case (Const (@{const_name case_guard}, _) $ b $ (t $ u)) =
+          make_case ctxt (if b = @{term True} then Error else Warning)
+            Name.context (decode_case u) (decode_cases t)
       | decode_case (t $ u) = decode_case t $ decode_case u
       | decode_case (Abs (x, T, u)) =
           let val (x', u') = Term.dest_abs (x, T, u);
@@ -517,7 +520,7 @@
     let
       val T = fastype_of rhs;
     in
-      Const (@{const_name case_guard}, T --> T) $
+      Const (@{const_name case_guard}, @{typ bool} --> T --> T) $ @{term True} $
         (encode_cases recur (fastype_of pat) (fastype_of rhs) ps $ t)
     end
   | encode_case _ _ = case_error "encode_case";