refined case syntax again, improved treatment of constructors without arguments, e.g. "case a of (True, x) => x";
authorwenzelm
Fri, 06 Jan 2012 20:18:49 +0100
changeset 46140 463b594e186a
parent 46139 df2aad3f0ecf
child 46144 cc374091999b
refined case syntax again, improved treatment of constructors without arguments, e.g. "case a of (True, x) => x";
src/HOL/Tools/Datatype/datatype_case.ML
--- a/src/HOL/Tools/Datatype/datatype_case.ML	Fri Jan 06 19:24:23 2012 +0100
+++ b/src/HOL/Tools/Datatype/datatype_case.ML	Fri Jan 06 20:18:49 2012 +0100
@@ -59,9 +59,11 @@
 fun add_row_used ((prfx, pats), (tm, tag)) =
   fold Term.add_free_names (tm :: pats @ map Free prfx);
 
-(*try to preserve names given by user*)
-fun default_names names ts =
-  map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts);
+fun default_name name (t, cs) =
+  let
+    val name' = if name = "" then (case t of Free (name', _) => name' | _ => name) else name;
+    val cs' = if is_Free t then cs else filter_out Term_Position.is_position cs;
+  in (name, cs') end;
 
 fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) =
       strip_constraints t ||> cons tT
@@ -103,9 +105,13 @@
           (Const (name', _), args) =>
             if name = name' then
               if length args = k then
-                let val (args', cnstrts') = split_list (map strip_constraints args) in
+                let
+                  val constraints' = map strip_constraints args;
+                  val (args', cnstrts') = split_list constraints';
+                  val (names', cnstrts'') = split_list (map2 default_name names constraints');
+                in
                   ((((prfx, args' @ ps), rhs) :: in_group, not_in_group),
-                   (default_names names args', map2 append cnstrts cnstrts'))
+                   (names', map2 append cnstrts cnstrts''))
                 end
               else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i)
             else ((in_group, row :: not_in_group), (names, cnstrts))
@@ -167,7 +173,7 @@
               val used' = add_row_used row used;
               fun expnd c =
                 let val capp = list_comb (fresh_constr ty_match ty_inst ty used' c)
-                in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end
+                in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end;
             in map expnd constructors end
           else [row];