src/HOL/Tools/Datatype/datatype_case.ML
changeset 42049 e945717b2b15
parent 35845 e5980f0ad025
child 42050 5a505dfec04e
--- a/src/HOL/Tools/Datatype/datatype_case.ML	Tue Mar 22 13:32:20 2011 +0100
+++ b/src/HOL/Tools/Datatype/datatype_case.ML	Tue Mar 22 13:55:39 2011 +0100
@@ -37,7 +37,7 @@
  *---------------------------------------------------------------------------*)
 
 fun ty_info tab sT =
-  case tab sT of
+  (case tab sT of
     SOME ({descr, case_name, index, sorts, ...} : info) =>
       let
         val (_, (tname, dts, constrs)) = nth descr index;
@@ -48,7 +48,7 @@
           constructors = map (fn (cname, dts') =>
             Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs}
       end
-  | NONE => NONE;
+  | NONE => NONE);
 
 
 (*---------------------------------------------------------------------------
@@ -93,8 +93,7 @@
       raise CASE_ERROR ("type mismatch", ~1)
     val c' = ty_inst ty_theta c
     val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
-  in (c', gvars)
-  end;
+  in (c', gvars) end;
 
 
 (*---------------------------------------------------------------------------
@@ -102,21 +101,22 @@
  * pattern with constructor = name.
  *---------------------------------------------------------------------------*)
 fun mk_group (name, T) rows =
-  let val k = length (binder_types T)
-  in fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) =>
-    fn ((in_group, not_in_group), (names, cnstrts)) => (case strip_comb p of
-        (Const (name', _), args) =>
-          if name = name' then
-            if length args = k then
-              let val (args', cnstrts') = split_list (map strip_constraints args)
-              in
-                ((((prfx, args' @ rst), rhs) :: in_group, not_in_group),
-                 (default_names names args', map2 append cnstrts cnstrts'))
-              end
-            else raise CASE_ERROR
-              ("Wrong number of arguments for constructor " ^ name, i)
-          else ((in_group, row :: not_in_group), (names, cnstrts))
-      | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
+  let val k = length (binder_types T) in
+    fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) =>
+      fn ((in_group, not_in_group), (names, cnstrts)) =>
+        (case strip_comb p of
+          (Const (name', _), args) =>
+            if name = name' then
+              if length args = k then
+                let val (args', cnstrts') = split_list (map strip_constraints args)
+                in
+                  ((((prfx, args' @ rst), rhs) :: in_group, not_in_group),
+                   (default_names names args', map2 append cnstrts cnstrts'))
+                end
+              else raise CASE_ERROR
+                ("Wrong number of arguments for constructor " ^ name, i)
+            else ((in_group, row :: not_in_group), (names, cnstrts))
+        | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
     rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
   end;
 
@@ -158,8 +158,7 @@
                     constraints = cnstrts,
                     group = in_group'} :: A}
               end
-      in part {constrs = constructors, rows = rows, A = []}
-      end;
+      in part {constrs = constructors, rows = rows, A = []} end;
 
 (*---------------------------------------------------------------------------
  * Misc. routines used in mk_case
@@ -210,48 +209,46 @@
       | mk {path, rows as ((row as ((_, [Free _]), _)) :: _ :: _)} =
           mk {path = path, rows = [row]}
       | mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} =
-          let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows
-          in case Option.map (apfst head_of)
-            (find_first (not o is_Free o fst) col0) of
+          let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows in
+            (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of
               NONE =>
                 let
                   val rows' = map (fn ((v, _), row) => row ||>
                     pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows);
                   val (pref_patl, tm) = mk {path = rstp, rows = rows'}
                 in (map v_to_pats pref_patl, tm) end
-            | SOME (Const (cname, cT), i) => (case ty_info tab (cname, cT) of
-                NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
-              | SOME {case_name, constructors} =>
-                let
-                  val pty = body_type cT;
-                  val used' = fold Term.add_free_names rstp used;
-                  val nrows = maps (expand constructors used' pty) rows;
-                  val subproblems = partition ty_match ty_inst type_of used'
-                    constructors pty range_ty nrows;
-                  val constructors' = map #constructor subproblems
-                  val news = map (fn {new_formals, group, ...} =>
-                    {path = new_formals @ rstp, rows = group}) subproblems;
-                  val (pat_rect, dtrees) = split_list (map mk news);
-                  val case_functions = map2
-                    (fn {new_formals, names, constraints, ...} =>
-                       fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
-                         Abs (if s = "" then name else s, T,
-                           abstract_over (x, t)) |>
-                         fold mk_fun_constrain cnstrts)
-                           (new_formals ~~ names ~~ constraints))
-                    subproblems dtrees;
-                  val types = map type_of (case_functions @ [u]);
-                  val case_const = Const (case_name, types ---> range_ty)
-                  val tree = list_comb (case_const, case_functions @ [u])
-                  val pat_rect1 = maps mk_pat (constructors ~~ constructors' ~~ pat_rect)
-                in (pat_rect1, tree)
-                end)
+            | SOME (Const (cname, cT), i) =>
+                (case ty_info tab (cname, cT) of
+                  NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
+                | SOME {case_name, constructors} =>
+                  let
+                    val pty = body_type cT;
+                    val used' = fold Term.add_free_names rstp used;
+                    val nrows = maps (expand constructors used' pty) rows;
+                    val subproblems = partition ty_match ty_inst type_of used'
+                      constructors pty range_ty nrows;
+                    val constructors' = map #constructor subproblems
+                    val news = map (fn {new_formals, group, ...} =>
+                      {path = new_formals @ rstp, rows = group}) subproblems;
+                    val (pat_rect, dtrees) = split_list (map mk news);
+                    val case_functions = map2
+                      (fn {new_formals, names, constraints, ...} =>
+                         fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
+                           Abs (if s = "" then name else s, T,
+                             abstract_over (x, t)) |>
+                           fold mk_fun_constrain cnstrts)
+                             (new_formals ~~ names ~~ constraints))
+                      subproblems dtrees;
+                    val types = map type_of (case_functions @ [u]);
+                    val case_const = Const (case_name, types ---> range_ty)
+                    val tree = list_comb (case_const, case_functions @ [u])
+                    val pat_rect1 = maps mk_pat (constructors ~~ constructors' ~~ pat_rect)
+                  in (pat_rect1, tree) end)
             | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^
-                Syntax.string_of_term ctxt t, i)
+                Syntax.string_of_term ctxt t, i))
           end
       | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1)
-  in mk
-  end;
+  in mk end;
 
 fun case_error s = error ("Error in case expression:\n" ^ s);
 
@@ -289,12 +286,12 @@
     val finals = map row_of_pat patts2
     val originals = map (row_of_pat o #2) rows
     val _ =
-        case subtract (op =) finals originals of
-          [] => ()
-        | is =>
-            (case config of Error => case_error | Warning => warning | Quiet => fn _ => {})
-              ("The following clauses are redundant (covered by preceding clauses):\n" ^
-               cat_lines (map (string_of_clause o nth clauses) is));
+      (case subtract (op =) finals originals of
+        [] => ()
+      | is =>
+          (case config of Error => case_error | Warning => warning | Quiet => fn _ => {})
+            ("The following clauses are redundant (covered by preceding clauses):\n" ^
+             cat_lines (map (string_of_clause o nth clauses) is)));
   in
     (case_tm, patts2)
   end;
@@ -308,48 +305,46 @@
 (* parse translation *)
 
 fun case_tr err tab_of ctxt [t, u] =
-    let
-      val thy = ProofContext.theory_of ctxt;
-      val intern_const_syntax = Consts.intern_syntax (Sign.consts_of thy);
+      let
+        val thy = ProofContext.theory_of ctxt;
+        val intern_const_syntax = Consts.intern_syntax (Sign.consts_of thy);
 
-      (* replace occurrences of dummy_pattern by distinct variables *)
-      (* internalize constant names                                 *)
-      fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used =
-            let val (t', used') = prep_pat t used
-            in (c $ t' $ tT, used') end
-        | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used =
-            let val x = Name.variant used "x"
-            in (Free (x, T), x :: used) end
-        | prep_pat (Const (s, T)) used =
-            (Const (intern_const_syntax s, T), used)
-        | prep_pat (v as Free (s, T)) used =
-            let val s' = Sign.intern_const thy s
-            in
-              if Sign.declared_const thy s' then
-                (Const (s', T), used)
-              else (v, used)
-            end
-        | prep_pat (t $ u) used =
-            let
-              val (t', used') = prep_pat t used;
-              val (u', used'') = prep_pat u used'
-            in
-              (t' $ u', used'')
-            end
-        | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
-      fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
-            let val (l', cnstrts) = strip_constraints l
-            in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts)
-            end
-        | dest_case1 t = case_error "dest_case1";
-      fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
-        | dest_case2 t = [t];
-      val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
-      val (case_tm, _) = make_case_untyped (tab_of thy) ctxt
-        (if err then Error else Warning) []
-        (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
-           (flat cnstrts) t) cases;
-    in case_tm end
+        (* replace occurrences of dummy_pattern by distinct variables *)
+        (* internalize constant names                                 *)
+        fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used =
+              let val (t', used') = prep_pat t used
+              in (c $ t' $ tT, used') end
+          | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used =
+              let val x = Name.variant used "x"
+              in (Free (x, T), x :: used) end
+          | prep_pat (Const (s, T)) used =
+              (Const (intern_const_syntax s, T), used)
+          | prep_pat (v as Free (s, T)) used =
+              let val s' = Sign.intern_const thy s in
+                if Sign.declared_const thy s' then
+                  (Const (s', T), used)
+                else (v, used)
+              end
+          | prep_pat (t $ u) used =
+              let
+                val (t', used') = prep_pat t used;
+                val (u', used'') = prep_pat u used'
+              in
+                (t' $ u', used'')
+              end
+          | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
+        fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
+              let val (l', cnstrts) = strip_constraints l
+              in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end
+          | dest_case1 t = case_error "dest_case1";
+        fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
+          | dest_case2 t = [t];
+        val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
+        val (case_tm, _) = make_case_untyped (tab_of thy) ctxt
+          (if err then Error else Warning) []
+          (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
+             (flat cnstrts) t) cases;
+      in case_tm end
   | case_tr _ _ _ ts = case_error "case_tr";
 
 
@@ -360,7 +355,7 @@
 (* destruct one level of pattern matching *)
 
 fun gen_dest_case name_of type_of tab d used t =
-  case apfst name_of (strip_comb t) of
+  (case apfst name_of (strip_comb t) of
     (SOME cname, ts as _ :: _) =>
       let
         val (fs, x) = split_last ts;
@@ -375,15 +370,14 @@
           in (xs', subst_bounds (rev xs', u)) end;
         fun is_dependent i t =
           let val k = length (strip_abs_vars t) - i
-          in k < 0 orelse exists (fn j => j >= k)
-            (loose_bnos (strip_abs_body t))
-          end;
+          in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end;
         fun count_cases (_, _, true) = I
           | count_cases (c, (_, body), false) =
               AList.map_default op aconv (body, []) (cons c);
         val is_undefined = name_of #> equal (SOME @{const_name undefined});
         fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body)
-      in case ty_info tab cname of
+      in
+        (case ty_info tab cname of
           SOME {constructors, case_name} =>
             if length fs = length constructors then
               let
@@ -400,26 +394,28 @@
                 val R = type_of t;
                 val dummy =
                   if d then Const (@{const_name dummy_pattern}, R)
-                  else Free (Name.variant used "x", R)
+                  else Free (Name.variant used "x", R);
               in
-                SOME (x, map mk_case (case find_first (is_undefined o fst) cases' of
-                  SOME (_, cs) =>
-                  if length cs = length constructors then [hd cases]
-                  else filter_out (fn (_, (_, body), _) => is_undefined body) cases
-                | NONE => case cases' of
-                  [] => cases
-                | (default, cs) :: _ =>
-                  if length cs = 1 then cases
-                  else if length cs = length constructors then
-                    [hd cases, (dummy, ([], default), false)]
-                  else
-                    filter_out (fn (c, _, _) => member op aconv cs c) cases @
-                    [(dummy, ([], default), false)]))
+                SOME (x,
+                  map mk_case
+                    (case find_first (is_undefined o fst) cases' of
+                      SOME (_, cs) =>
+                      if length cs = length constructors then [hd cases]
+                      else filter_out (fn (_, (_, body), _) => is_undefined body) cases
+                    | NONE => case cases' of
+                      [] => cases
+                    | (default, cs) :: _ =>
+                      if length cs = 1 then cases
+                      else if length cs = length constructors then
+                        [hd cases, (dummy, ([], default), false)]
+                      else
+                        filter_out (fn (c, _, _) => member op aconv cs c) cases @
+                        [(dummy, ([], default), false)]))
               end handle CASE_ERROR _ => NONE
             else NONE
-        | _ => NONE
+        | _ => NONE)
       end
-  | _ => NONE;
+  | _ => NONE);
 
 val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of;
 val dest_case' = gen_dest_case (try (dest_Const #> fst #> Syntax.unmark_const)) (K dummyT);
@@ -428,7 +424,7 @@
 (* destruct nested patterns *)
 
 fun strip_case'' dest (pat, rhs) =
-  case dest (Term.add_free_names pat []) rhs of
+  (case dest (Term.add_free_names pat []) rhs of
     SOME (exp as Free _, clauses) =>
       if member op aconv (OldTerm.term_frees pat) exp andalso
         not (exists (fn (_, rhs') =>
@@ -437,13 +433,13 @@
         maps (strip_case'' dest) (map (fn (pat', rhs') =>
           (subst_free [(exp, pat')] pat, rhs')) clauses)
       else [(pat, rhs)]
-  | _ => [(pat, rhs)];
+  | _ => [(pat, rhs)]);
 
 fun gen_strip_case dest t =
-  case dest [] t of
+  (case dest [] t of
     SOME (x, clauses) =>
       SOME (x, maps (strip_case'' dest) clauses)
-  | NONE => NONE;
+  | NONE => NONE);
 
 val strip_case = gen_strip_case oo dest_case;
 val strip_case' = gen_strip_case oo dest_case';
@@ -455,8 +451,7 @@
   let
     val thy = ProofContext.theory_of ctxt;
     fun mk_clause (pat, rhs) =
-      let val xs = Term.add_frees pat []
-      in
+      let val xs = Term.add_frees pat [] in
         Syntax.const @{syntax_const "_case1"} $
           map_aterms
             (fn Free p => Syntax.mark_boundT p
@@ -466,14 +461,14 @@
             (fn x as Free (s, T) =>
                   if member (op =) xs (s, T) then Syntax.mark_bound s else x
               | t => t) rhs
-      end
+      end;
   in
-    case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
+    (case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
       SOME (x, clauses) =>
         Syntax.const @{syntax_const "_case_syntax"} $ x $
           foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
             (map mk_clause clauses)
-    | NONE => raise Match
+    | NONE => raise Match)
   end;
 
 end;