src/HOL/Tools/case_translation.ML
changeset 52154 b707a26d8fe0
parent 52143 36ffe23b25f8
child 52155 761c325a65d4
--- a/src/HOL/Tools/case_translation.ML	Sun May 26 14:02:03 2013 +0200
+++ b/src/HOL/Tools/case_translation.ML	Sun May 26 18:37:43 2013 +0200
@@ -90,12 +90,14 @@
 val lookup_by_case = Symtab.lookup o cases_of;
 
 
+
 (** installation **)
 
 fun case_error s = error ("Error in case expression:\n" ^ s);
 
 val name_of = try (dest_Const #> fst);
 
+
 (* parse translation *)
 
 fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
@@ -129,9 +131,9 @@
           | replace_dummies t used = (t, used);
 
         fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
-              let val (l', _) = replace_dummies l (Term.declare_term_frees t Name.context)
-              in abs_pat l' []
-                (Syntax.const @{const_syntax case_elem} $ Term_Position.strip_positions l' $ r)
+              let val (l', _) = replace_dummies l (Term.declare_term_frees t Name.context) in
+                abs_pat l' []
+                  (Syntax.const @{const_syntax case_elem} $ Term_Position.strip_positions l' $ r)
               end
           | dest_case1 _ = case_error "dest_case1";
 
@@ -140,11 +142,11 @@
 
         val errt = if err then @{term True} else @{term False};
       in
-        Syntax.const @{const_syntax case_guard} $ errt $ t $ (fold_rev
-          (fn t => fn u =>
-             Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
-          (dest_case2 u)
-          (Syntax.const @{const_syntax case_nil}))
+        Syntax.const @{const_syntax case_guard} $ errt $ t $
+          (fold_rev
+            (fn t => fn u => Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
+            (dest_case2 u)
+            (Syntax.const @{const_syntax case_nil}))
       end
   | case_tr _ _ _ = case_error "case_tr";
 
@@ -154,24 +156,24 @@
 (* print translation *)
 
 fun case_tr' (_ :: x :: t :: ts) =
-      let
-        fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
-              let val (s', used') = Name.variant s used
-              in mk_clause t ((s', T) :: xs) used' end
-          | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) xs _ =
-              Syntax.const @{syntax_const "_case1"} $
-                subst_bounds (map Syntax_Trans.mark_bound_abs xs, pat) $
-                subst_bounds (map Syntax_Trans.mark_bound_body xs, rhs);
+  let
+    fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
+          let val (s', used') = Name.variant s used
+          in mk_clause t ((s', T) :: xs) used' end
+      | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) xs _ =
+          Syntax.const @{syntax_const "_case1"} $
+            subst_bounds (map Syntax_Trans.mark_bound_abs xs, pat) $
+            subst_bounds (map Syntax_Trans.mark_bound_body xs, rhs);
 
-        fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
-          | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
-              mk_clause t [] (Term.declare_term_frees t Name.context) ::
-              mk_clauses u
-      in
-        list_comb (Syntax.const @{syntax_const "_case_syntax"} $ x $
-          foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
-            (mk_clauses t), ts)
-      end;
+    fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
+      | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
+          mk_clause t [] (Term.declare_term_frees t Name.context) ::
+          mk_clauses u
+  in
+    list_comb (Syntax.const @{syntax_const "_case_syntax"} $ x $
+      foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
+        (mk_clauses t), ts)
+  end;
 
 val trfun_setup' = Sign.print_translation [(@{const_syntax "case_guard"}, K case_tr')];
 
@@ -204,13 +206,13 @@
 
 
 (*Each pattern carries with it a tag i, which denotes the clause it
-came from. i = ~1 indicates that the clause was added by pattern
-completion.*)
+  came from. i = ~1 indicates that the clause was added by pattern
+  completion.*)
 
 fun add_row_used ((prfx, pats), (tm, tag)) =
   fold Term.declare_term_frees (tm :: pats @ map Free prfx);
 
-(* try to preserve names given by user *)
+(*try to preserve names given by user*)
 fun default_name "" (Free (name', _)) = name'
   | default_name name _ = name;
 
@@ -220,8 +222,9 @@
   let
     val (_, T) = dest_Const c;
     val Ts = binder_types T;
-    val (names, _) = fold_map Name.variant
-      (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)) used;
+    val (names, _) =
+      fold_map Name.variant
+        (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)) used;
     val ty = body_type T;
     val ty_theta = Type.raw_match (ty, colty) Vartab.empty
       handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1);
@@ -251,8 +254,7 @@
 (* Partitioning *)
 
 fun partition _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
-  | partition used constructors colty res_ty
-        (rows as (((prfx, _ :: ps), _) :: _)) =
+  | partition used constructors colty res_ty (rows as (((prfx, _ :: ps), _) :: _)) =
       let
         fun part [] [] = []
           | part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i)
@@ -311,8 +313,7 @@
       | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
       | mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
           let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
-            (case Option.map (apfst head_of)
-                (find_first (not o is_Free o fst) col0) of
+            (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of
               NONE =>
                 let
                   val rows' = map (fn ((v, _), row) => row ||>
@@ -326,8 +327,7 @@
                       val pty = body_type cT;
                       val used' = fold Term.declare_term_frees us used;
                       val nrows = maps (expand constructors used' pty) rows;
-                      val subproblems =
-                        partition used' constructors pty range_ty nrows;
+                      val subproblems = partition used' constructors pty range_ty nrows;
                       val (pat_rect, dtrees) =
                         split_list (map (fn {new_formals, group, ...} =>
                           mk (new_formals @ us) group) subproblems);
@@ -509,13 +509,13 @@
         encode_clause recur S T p $ encode_cases recur S T ps;
 
 fun encode_case recur (t, ps as (pat, rhs) :: _) =
-    let
-      val tT = fastype_of t;
-      val T = fastype_of rhs;
-    in
-      Const (@{const_name case_guard}, @{typ bool} --> tT --> (tT --> T) --> T) $
-        @{term True} $ t $ (encode_cases recur (fastype_of pat) (fastype_of rhs) ps)
-    end
+      let
+        val tT = fastype_of t;
+        val T = fastype_of rhs;
+      in
+        Const (@{const_name case_guard}, @{typ bool} --> tT --> (tT --> T) --> T) $
+          @{term True} $ t $ (encode_cases recur (fastype_of pat) (fastype_of rhs) ps)
+      end
   | encode_case _ _ = case_error "encode_case";
 
 fun strip_case' ctxt d (pat, rhs) =
@@ -542,8 +542,8 @@
         (strip_case_full ctxt d x, maps (strip_case' ctxt d) clauses)
   | NONE =>
       (case t of
-        (t $ u) => strip_case_full ctxt d t $ strip_case_full ctxt d u
-      | (Abs (x, T, u)) =>
+        t $ u => strip_case_full ctxt d t $ strip_case_full ctxt d u
+      | Abs (x, T, u) =>
           let val (x', u') = Term.dest_abs (x, T, u);
           in Term.absfree (x', T) (strip_case_full ctxt d u') end
       | _ => t));