src/HOL/Tools/case_translation.ML
changeset 51674 2b1498a2ce85
parent 51673 4dfa00e264d8
child 51675 18bbc78888aa
--- a/src/HOL/Tools/case_translation.ML	Tue Jan 22 14:33:45 2013 +0100
+++ b/src/HOL/Tools/case_translation.ML	Fri Apr 05 22:08:42 2013 +0200
@@ -123,11 +123,11 @@
         fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
           | dest_case2 t = [t];
       in
-        fold_rev
+        Syntax.const @{const_syntax case_guard} $ (fold_rev
           (fn t => fn u =>
              Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
           (dest_case2 u)
-          (Syntax.const @{const_syntax case_nil}) $ t
+          (Syntax.const @{const_syntax case_nil}) $ t)
       end
   | case_tr _ _ = case_error "case_tr";
 
@@ -139,8 +139,9 @@
 
 (* print translation *)
 
-fun case_tr' [t, u, x] =
+fun case_tr' [tx] =
       let
+        val (t, x) = Term.dest_comb tx;
         fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
               let val (s', used') = Name.variant s used
               in mk_clause t ((s', T) :: xs) used' end
@@ -151,18 +152,16 @@
 
         fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
           | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
-              mk_clauses' t u
-        and mk_clauses' t u =
               mk_clause t [] (Term.declare_term_frees t Name.context) ::
               mk_clauses u
       in
         Syntax.const @{syntax_const "_case_syntax"} $ x $
           foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
-            (mk_clauses' t u)
+            (mk_clauses t)
       end;
 
 val trfun_setup' = Sign.add_trfuns
-  ([], [], [(@{const_syntax "case_cons"}, case_tr')], []);
+  ([], [], [(@{const_syntax "case_guard"}, case_tr')], []);
 
 
 (* declarations *)
@@ -413,7 +412,7 @@
 
 fun check_case ctxt =
   let
-    fun decode_case ((t as Const (@{const_name case_cons}, _) $ _ $ _) $ u) =
+    fun decode_case (Const (@{const_name case_guard}, _) $ (t $ u)) =
         make_case ctxt Error Name.context (decode_case u) (decode_cases t)
     | decode_case (t $ u) = decode_case t $ decode_case u
     | decode_case (Abs (x, T, u)) =
@@ -517,7 +516,12 @@
         encode_clause S T p $ encode_cases S T ps;
 
 fun encode_case (t, ps as (pat, rhs) :: _) =
-      encode_cases (fastype_of pat) (fastype_of rhs) ps $ t
+    let
+      val T = fastype_of rhs;
+    in
+      Const (@{const_name case_guard}, T --> T) $
+        (encode_cases (fastype_of pat) (fastype_of rhs) ps $ t)
+    end
   | encode_case _ = case_error "encode_case";
 
 fun strip_case' ctxt d (pat, rhs) =