--- a/src/HOL/Tools/Datatype/datatype_case.ML Fri Jan 06 20:39:50 2012 +0100
+++ b/src/HOL/Tools/Datatype/datatype_case.ML Fri Jan 06 20:48:52 2012 +0100
@@ -59,17 +59,17 @@
fun add_row_used ((prfx, pats), (tm, tag)) =
fold Term.add_free_names (tm :: pats @ map Free prfx);
-(*try to preserve names given by user*)
-fun default_names names ts =
- map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts);
+fun default_name name (t, cs) =
+ let
+ val name' = if name = "" then (case t of Free (name', _) => name' | _ => name) else name;
+ val cs' = if is_Free t then cs else filter_out Term_Position.is_position cs;
+ in (name, cs') end;
fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) =
strip_constraints t ||> cons tT
| strip_constraints t = (t, []);
-val recover_constraints =
- fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT);
-
+fun constrain tT t = Syntax.const @{syntax_const "_constrain"} $ t $ tT;
fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
@@ -90,7 +90,7 @@
fun strip_comb_positions tm =
let
fun result t ts = (Term_Position.strip_positions t, ts);
- fun strip (t as (Const (@{syntax_const "_constrain"}, _) $ _ $ _)) ts = result t ts
+ fun strip (t as Const (@{syntax_const "_constrain"}, _) $ _ $ _) ts = result t ts
| strip (f $ t) ts = strip f (t :: ts)
| strip t ts = result t ts;
in strip tm [] end;
@@ -105,11 +105,15 @@
(Const (name', _), args) =>
if name = name' then
if length args = k then
- let val (args', cnstrts') = split_list (map strip_constraints args) in
+ let
+ val constraints' = map strip_constraints args;
+ val (args', cnstrts') = split_list constraints';
+ val (names', cnstrts'') = split_list (map2 default_name names constraints');
+ in
((((prfx, args' @ ps), rhs) :: in_group, not_in_group),
- (default_names names args', map2 append cnstrts cnstrts'))
+ (names', map2 append cnstrts cnstrts''))
end
- else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ name, i)
+ else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i)
else ((in_group, row :: not_in_group), (names, cnstrts))
| _ => raise CASE_ERROR ("Not a constructor pattern", i)))
rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
@@ -142,7 +146,7 @@
[((prfx, gvars @ map Free (xs ~~ Ts)),
(Const (@{const_syntax undefined}, res_ty), ~1))]
end
- else in_group
+ else in_group;
in
{constructor = c',
new_formals = gvars,
@@ -162,14 +166,14 @@
let
val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt);
- fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand_var_row", ~1)
+ fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand", ~1)
| expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
if is_Free p then
let
val used' = add_row_used row used;
fun expnd c =
let val capp = list_comb (fresh_constr ty_match ty_inst ty used' c)
- in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end
+ in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end;
in map expnd constructors end
else [row];
@@ -189,7 +193,7 @@
in mk us rows' end
| SOME (Const (cname, cT), i) =>
(case Option.map ty_info (get_info (cname, cT)) of
- NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
+ NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ quote cname, i)
| SOME {case_name, constructors} =>
let
val pty = body_type cT;
@@ -312,7 +316,7 @@
in
make_case_untyped ctxt
(if err then Error else Warning) []
- (recover_constraints (filter_out Term_Position.is_position (flat cnstrts)) t)
+ (fold constrain (filter_out Term_Position.is_position (flat cnstrts)) t)
cases
end
| case_tr _ _ _ = case_error "case_tr";