allow redundant cases in the list comprehension translation
authortraytel
Fri, 05 Apr 2013 22:08:42 +0200
changeset 51678 1e33b81c328a
parent 51677 d2b3372e6033
child 51679 e7316560928b
allow redundant cases in the list comprehension translation
src/HOL/Inductive.thy
src/HOL/List.thy
src/HOL/Tools/case_translation.ML
--- a/src/HOL/Inductive.thy	Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/Inductive.thy	Fri Apr 05 22:08:42 2013 +0200
@@ -275,7 +275,7 @@
 ML_file "Tools/Datatype/datatype_data.ML" setup Datatype_Data.setup
 
 consts
-  case_guard :: "'a \<Rightarrow> 'a"
+  case_guard :: "bool \<Rightarrow> 'a \<Rightarrow> 'a"
   case_nil :: "'a \<Rightarrow> 'b"
   case_cons :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b"
   case_elem :: "'a \<Rightarrow> 'b \<Rightarrow> 'a \<Rightarrow> 'b"
@@ -299,7 +299,7 @@
   fun fun_tr ctxt [cs] =
     let
       val x = Syntax.free (fst (Name.variant "x" (Term.declare_term_frees cs Name.context)));
-      val ft = Case_Translation.case_tr ctxt [x, cs];
+      val ft = Case_Translation.case_tr true ctxt [x, cs];
     in lambda x ft end
 in [(@{syntax_const "_lam_pats_syntax"}, fun_tr)] end
 *}
--- a/src/HOL/List.thy	Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/List.thy	Fri Apr 05 22:08:42 2013 +0200
@@ -407,7 +407,7 @@
           Syntax.const @{syntax_const "_case1"} $
             Syntax.const @{const_syntax dummy_pattern} $ NilC;
         val cs = Syntax.const @{syntax_const "_case2"} $ case1 $ case2;
-      in Syntax_Trans.abs_tr [x, Case_Translation.case_tr ctxt [x, cs]] end;
+      in Syntax_Trans.abs_tr [x, Case_Translation.case_tr false ctxt [x, cs]] end;
 
     fun abs_tr ctxt p e opti =
       (case Term_Position.strip_positions p of
--- a/src/HOL/Tools/case_translation.ML	Fri Apr 05 22:08:42 2013 +0200
+++ b/src/HOL/Tools/case_translation.ML	Fri Apr 05 22:08:42 2013 +0200
@@ -9,7 +9,7 @@
 signature CASE_TRANSLATION =
 sig
   datatype config = Error | Warning | Quiet
-  val case_tr: Proof.context -> term list -> term
+  val case_tr: bool -> Proof.context -> term list -> term
   val lookup_by_constr: Proof.context -> string * typ -> (term * term list) option
   val lookup_by_constr_permissive: Proof.context -> string * typ -> (term * term list) option
   val lookup_by_case: Proof.context -> string -> (term * term list) option
@@ -100,7 +100,7 @@
 
 fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
 
-fun case_tr ctxt [t, u] =
+fun case_tr err ctxt [t, u] =
       let
         val thy = Proof_Context.theory_of ctxt;
 
@@ -123,24 +123,26 @@
 
         fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
           | dest_case2 t = [t];
+
+        val errt = if err then @{term True} else @{term False};
       in
-        Syntax.const @{const_syntax case_guard} $ (fold_rev
+        Syntax.const @{const_syntax case_guard} $ errt $ (fold_rev
           (fn t => fn u =>
              Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
           (dest_case2 u)
           (Syntax.const @{const_syntax case_nil}) $ t)
       end
-  | case_tr _ _ = case_error "case_tr";
+  | case_tr _ _ _ = case_error "case_tr";
 
 val trfun_setup =
   Sign.add_advanced_trfuns ([],
-    [(@{syntax_const "_case_syntax"}, case_tr)],
+    [(@{syntax_const "_case_syntax"}, case_tr true)],
     [], []);
 
 
 (* print translation *)
 
-fun case_tr' [tx] =
+fun case_tr' [_, tx] =
       let
         val (t, x) = Term.dest_comb tx;
         fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
@@ -410,8 +412,9 @@
 
 fun check_case ctxt =
   let
-    fun decode_case (Const (@{const_name case_guard}, _) $ (t $ u)) =
-          make_case ctxt Error Name.context (decode_case u) (decode_cases t)
+    fun decode_case (Const (@{const_name case_guard}, _) $ b $ (t $ u)) =
+          make_case ctxt (if b = @{term True} then Error else Warning)
+            Name.context (decode_case u) (decode_cases t)
       | decode_case (t $ u) = decode_case t $ decode_case u
       | decode_case (Abs (x, T, u)) =
           let val (x', u') = Term.dest_abs (x, T, u);
@@ -517,7 +520,7 @@
     let
       val T = fastype_of rhs;
     in
-      Const (@{const_name case_guard}, T --> T) $
+      Const (@{const_name case_guard}, @{typ bool} --> T --> T) $ @{term True} $
         (encode_cases recur (fastype_of pat) (fastype_of rhs) ps $ t)
     end
   | encode_case _ _ = case_error "encode_case";