--- a/src/HOL/Tools/Datatype/datatype_case.ML Tue Mar 22 13:32:20 2011 +0100
+++ b/src/HOL/Tools/Datatype/datatype_case.ML Tue Mar 22 13:55:39 2011 +0100
@@ -37,7 +37,7 @@
*---------------------------------------------------------------------------*)
fun ty_info tab sT =
- case tab sT of
+ (case tab sT of
SOME ({descr, case_name, index, sorts, ...} : info) =>
let
val (_, (tname, dts, constrs)) = nth descr index;
@@ -48,7 +48,7 @@
constructors = map (fn (cname, dts') =>
Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs}
end
- | NONE => NONE;
+ | NONE => NONE);
(*---------------------------------------------------------------------------
@@ -93,8 +93,7 @@
raise CASE_ERROR ("type mismatch", ~1)
val c' = ty_inst ty_theta c
val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
- in (c', gvars)
- end;
+ in (c', gvars) end;
(*---------------------------------------------------------------------------
@@ -102,21 +101,22 @@
* pattern with constructor = name.
*---------------------------------------------------------------------------*)
fun mk_group (name, T) rows =
- let val k = length (binder_types T)
- in fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) =>
- fn ((in_group, not_in_group), (names, cnstrts)) => (case strip_comb p of
- (Const (name', _), args) =>
- if name = name' then
- if length args = k then
- let val (args', cnstrts') = split_list (map strip_constraints args)
- in
- ((((prfx, args' @ rst), rhs) :: in_group, not_in_group),
- (default_names names args', map2 append cnstrts cnstrts'))
- end
- else raise CASE_ERROR
- ("Wrong number of arguments for constructor " ^ name, i)
- else ((in_group, row :: not_in_group), (names, cnstrts))
- | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
+ let val k = length (binder_types T) in
+ fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) =>
+ fn ((in_group, not_in_group), (names, cnstrts)) =>
+ (case strip_comb p of
+ (Const (name', _), args) =>
+ if name = name' then
+ if length args = k then
+ let val (args', cnstrts') = split_list (map strip_constraints args)
+ in
+ ((((prfx, args' @ rst), rhs) :: in_group, not_in_group),
+ (default_names names args', map2 append cnstrts cnstrts'))
+ end
+ else raise CASE_ERROR
+ ("Wrong number of arguments for constructor " ^ name, i)
+ else ((in_group, row :: not_in_group), (names, cnstrts))
+ | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
end;
@@ -158,8 +158,7 @@
constraints = cnstrts,
group = in_group'} :: A}
end
- in part {constrs = constructors, rows = rows, A = []}
- end;
+ in part {constrs = constructors, rows = rows, A = []} end;
(*---------------------------------------------------------------------------
* Misc. routines used in mk_case
@@ -210,48 +209,46 @@
| mk {path, rows as ((row as ((_, [Free _]), _)) :: _ :: _)} =
mk {path = path, rows = [row]}
| mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} =
- let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows
- in case Option.map (apfst head_of)
- (find_first (not o is_Free o fst) col0) of
+ let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows in
+ (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of
NONE =>
let
val rows' = map (fn ((v, _), row) => row ||>
pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows);
val (pref_patl, tm) = mk {path = rstp, rows = rows'}
in (map v_to_pats pref_patl, tm) end
- | SOME (Const (cname, cT), i) => (case ty_info tab (cname, cT) of
- NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
- | SOME {case_name, constructors} =>
- let
- val pty = body_type cT;
- val used' = fold Term.add_free_names rstp used;
- val nrows = maps (expand constructors used' pty) rows;
- val subproblems = partition ty_match ty_inst type_of used'
- constructors pty range_ty nrows;
- val constructors' = map #constructor subproblems
- val news = map (fn {new_formals, group, ...} =>
- {path = new_formals @ rstp, rows = group}) subproblems;
- val (pat_rect, dtrees) = split_list (map mk news);
- val case_functions = map2
- (fn {new_formals, names, constraints, ...} =>
- fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
- Abs (if s = "" then name else s, T,
- abstract_over (x, t)) |>
- fold mk_fun_constrain cnstrts)
- (new_formals ~~ names ~~ constraints))
- subproblems dtrees;
- val types = map type_of (case_functions @ [u]);
- val case_const = Const (case_name, types ---> range_ty)
- val tree = list_comb (case_const, case_functions @ [u])
- val pat_rect1 = maps mk_pat (constructors ~~ constructors' ~~ pat_rect)
- in (pat_rect1, tree)
- end)
+ | SOME (Const (cname, cT), i) =>
+ (case ty_info tab (cname, cT) of
+ NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
+ | SOME {case_name, constructors} =>
+ let
+ val pty = body_type cT;
+ val used' = fold Term.add_free_names rstp used;
+ val nrows = maps (expand constructors used' pty) rows;
+ val subproblems = partition ty_match ty_inst type_of used'
+ constructors pty range_ty nrows;
+ val constructors' = map #constructor subproblems
+ val news = map (fn {new_formals, group, ...} =>
+ {path = new_formals @ rstp, rows = group}) subproblems;
+ val (pat_rect, dtrees) = split_list (map mk news);
+ val case_functions = map2
+ (fn {new_formals, names, constraints, ...} =>
+ fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
+ Abs (if s = "" then name else s, T,
+ abstract_over (x, t)) |>
+ fold mk_fun_constrain cnstrts)
+ (new_formals ~~ names ~~ constraints))
+ subproblems dtrees;
+ val types = map type_of (case_functions @ [u]);
+ val case_const = Const (case_name, types ---> range_ty)
+ val tree = list_comb (case_const, case_functions @ [u])
+ val pat_rect1 = maps mk_pat (constructors ~~ constructors' ~~ pat_rect)
+ in (pat_rect1, tree) end)
| SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^
- Syntax.string_of_term ctxt t, i)
+ Syntax.string_of_term ctxt t, i))
end
| mk _ = raise CASE_ERROR ("Malformed row matrix", ~1)
- in mk
- end;
+ in mk end;
fun case_error s = error ("Error in case expression:\n" ^ s);
@@ -289,12 +286,12 @@
val finals = map row_of_pat patts2
val originals = map (row_of_pat o #2) rows
val _ =
- case subtract (op =) finals originals of
- [] => ()
- | is =>
- (case config of Error => case_error | Warning => warning | Quiet => fn _ => {})
- ("The following clauses are redundant (covered by preceding clauses):\n" ^
- cat_lines (map (string_of_clause o nth clauses) is));
+ (case subtract (op =) finals originals of
+ [] => ()
+ | is =>
+ (case config of Error => case_error | Warning => warning | Quiet => fn _ => {})
+ ("The following clauses are redundant (covered by preceding clauses):\n" ^
+ cat_lines (map (string_of_clause o nth clauses) is)));
in
(case_tm, patts2)
end;
@@ -308,48 +305,46 @@
(* parse translation *)
fun case_tr err tab_of ctxt [t, u] =
- let
- val thy = ProofContext.theory_of ctxt;
- val intern_const_syntax = Consts.intern_syntax (Sign.consts_of thy);
+ let
+ val thy = ProofContext.theory_of ctxt;
+ val intern_const_syntax = Consts.intern_syntax (Sign.consts_of thy);
- (* replace occurrences of dummy_pattern by distinct variables *)
- (* internalize constant names *)
- fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used =
- let val (t', used') = prep_pat t used
- in (c $ t' $ tT, used') end
- | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used =
- let val x = Name.variant used "x"
- in (Free (x, T), x :: used) end
- | prep_pat (Const (s, T)) used =
- (Const (intern_const_syntax s, T), used)
- | prep_pat (v as Free (s, T)) used =
- let val s' = Sign.intern_const thy s
- in
- if Sign.declared_const thy s' then
- (Const (s', T), used)
- else (v, used)
- end
- | prep_pat (t $ u) used =
- let
- val (t', used') = prep_pat t used;
- val (u', used'') = prep_pat u used'
- in
- (t' $ u', used'')
- end
- | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
- fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
- let val (l', cnstrts) = strip_constraints l
- in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts)
- end
- | dest_case1 t = case_error "dest_case1";
- fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
- | dest_case2 t = [t];
- val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
- val (case_tm, _) = make_case_untyped (tab_of thy) ctxt
- (if err then Error else Warning) []
- (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
- (flat cnstrts) t) cases;
- in case_tm end
+ (* replace occurrences of dummy_pattern by distinct variables *)
+ (* internalize constant names *)
+ fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used =
+ let val (t', used') = prep_pat t used
+ in (c $ t' $ tT, used') end
+ | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used =
+ let val x = Name.variant used "x"
+ in (Free (x, T), x :: used) end
+ | prep_pat (Const (s, T)) used =
+ (Const (intern_const_syntax s, T), used)
+ | prep_pat (v as Free (s, T)) used =
+ let val s' = Sign.intern_const thy s in
+ if Sign.declared_const thy s' then
+ (Const (s', T), used)
+ else (v, used)
+ end
+ | prep_pat (t $ u) used =
+ let
+ val (t', used') = prep_pat t used;
+ val (u', used'') = prep_pat u used'
+ in
+ (t' $ u', used'')
+ end
+ | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
+ fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
+ let val (l', cnstrts) = strip_constraints l
+ in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end
+ | dest_case1 t = case_error "dest_case1";
+ fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
+ | dest_case2 t = [t];
+ val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
+ val (case_tm, _) = make_case_untyped (tab_of thy) ctxt
+ (if err then Error else Warning) []
+ (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
+ (flat cnstrts) t) cases;
+ in case_tm end
| case_tr _ _ _ ts = case_error "case_tr";
@@ -360,7 +355,7 @@
(* destruct one level of pattern matching *)
fun gen_dest_case name_of type_of tab d used t =
- case apfst name_of (strip_comb t) of
+ (case apfst name_of (strip_comb t) of
(SOME cname, ts as _ :: _) =>
let
val (fs, x) = split_last ts;
@@ -375,15 +370,14 @@
in (xs', subst_bounds (rev xs', u)) end;
fun is_dependent i t =
let val k = length (strip_abs_vars t) - i
- in k < 0 orelse exists (fn j => j >= k)
- (loose_bnos (strip_abs_body t))
- end;
+ in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end;
fun count_cases (_, _, true) = I
| count_cases (c, (_, body), false) =
AList.map_default op aconv (body, []) (cons c);
val is_undefined = name_of #> equal (SOME @{const_name undefined});
fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body)
- in case ty_info tab cname of
+ in
+ (case ty_info tab cname of
SOME {constructors, case_name} =>
if length fs = length constructors then
let
@@ -400,26 +394,28 @@
val R = type_of t;
val dummy =
if d then Const (@{const_name dummy_pattern}, R)
- else Free (Name.variant used "x", R)
+ else Free (Name.variant used "x", R);
in
- SOME (x, map mk_case (case find_first (is_undefined o fst) cases' of
- SOME (_, cs) =>
- if length cs = length constructors then [hd cases]
- else filter_out (fn (_, (_, body), _) => is_undefined body) cases
- | NONE => case cases' of
- [] => cases
- | (default, cs) :: _ =>
- if length cs = 1 then cases
- else if length cs = length constructors then
- [hd cases, (dummy, ([], default), false)]
- else
- filter_out (fn (c, _, _) => member op aconv cs c) cases @
- [(dummy, ([], default), false)]))
+ SOME (x,
+ map mk_case
+ (case find_first (is_undefined o fst) cases' of
+ SOME (_, cs) =>
+ if length cs = length constructors then [hd cases]
+ else filter_out (fn (_, (_, body), _) => is_undefined body) cases
+ | NONE => case cases' of
+ [] => cases
+ | (default, cs) :: _ =>
+ if length cs = 1 then cases
+ else if length cs = length constructors then
+ [hd cases, (dummy, ([], default), false)]
+ else
+ filter_out (fn (c, _, _) => member op aconv cs c) cases @
+ [(dummy, ([], default), false)]))
end handle CASE_ERROR _ => NONE
else NONE
- | _ => NONE
+ | _ => NONE)
end
- | _ => NONE;
+ | _ => NONE);
val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of;
val dest_case' = gen_dest_case (try (dest_Const #> fst #> Syntax.unmark_const)) (K dummyT);
@@ -428,7 +424,7 @@
(* destruct nested patterns *)
fun strip_case'' dest (pat, rhs) =
- case dest (Term.add_free_names pat []) rhs of
+ (case dest (Term.add_free_names pat []) rhs of
SOME (exp as Free _, clauses) =>
if member op aconv (OldTerm.term_frees pat) exp andalso
not (exists (fn (_, rhs') =>
@@ -437,13 +433,13 @@
maps (strip_case'' dest) (map (fn (pat', rhs') =>
(subst_free [(exp, pat')] pat, rhs')) clauses)
else [(pat, rhs)]
- | _ => [(pat, rhs)];
+ | _ => [(pat, rhs)]);
fun gen_strip_case dest t =
- case dest [] t of
+ (case dest [] t of
SOME (x, clauses) =>
SOME (x, maps (strip_case'' dest) clauses)
- | NONE => NONE;
+ | NONE => NONE);
val strip_case = gen_strip_case oo dest_case;
val strip_case' = gen_strip_case oo dest_case';
@@ -455,8 +451,7 @@
let
val thy = ProofContext.theory_of ctxt;
fun mk_clause (pat, rhs) =
- let val xs = Term.add_frees pat []
- in
+ let val xs = Term.add_frees pat [] in
Syntax.const @{syntax_const "_case1"} $
map_aterms
(fn Free p => Syntax.mark_boundT p
@@ -466,14 +461,14 @@
(fn x as Free (s, T) =>
if member (op =) xs (s, T) then Syntax.mark_bound s else x
| t => t) rhs
- end
+ end;
in
- case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
+ (case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
SOME (x, clauses) =>
Syntax.const @{syntax_const "_case_syntax"} $ x $
foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
(map mk_clause clauses)
- | NONE => raise Match
+ | NONE => raise Match)
end;
end;