--- a/src/HOL/Tools/datatype_package.ML Fri May 21 21:47:07 2004 +0200
+++ b/src/HOL/Tools/datatype_package.ML Fri May 21 21:48:03 2004 +0200
@@ -393,6 +393,116 @@
fn thy => (simpset_ref_of thy := simpset_of thy addsimprocs [distinct_simproc]; thy)];
+(**** translation rules for case ****)
+
+fun case_tr sg [t, u] =
+ let
+ fun case_error s name ts = raise TERM ("Error in case expression" ^
+ if_none (apsome (curry op ^ " for datatype ") name) "" ^ ":\n" ^ s, ts);
+ fun dest_case1 (Const ("_case1", _) $ t $ u) = (case strip_comb t of
+ (Const (s, _), ts) => (Sign.intern_const sg s, ts)
+ | (Free (s, _), ts) => (Sign.intern_const sg s, ts)
+ | _ => case_error "Head is not a constructor" None [t, u], u)
+ | dest_case1 t = raise TERM ("dest_case1", [t]);
+ fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
+ | dest_case2 t = [t];
+ val cases as ((cname, _), _) :: _ = map dest_case1 (dest_case2 u);
+ val tab = Symtab.dest (get_datatypes_sg sg);
+ val (cases', default) = (case split_last cases of
+ (cases', (("dummy_pattern", []), t)) => (cases', Some t)
+ | _ => (cases, None))
+ fun abstr (Free (x, T), body) = Term.absfree (x, T, body)
+ | abstr (Const ("_constrain", _) $ Free (x, T) $ tT, body) =
+ Syntax.const Syntax.constrainAbsC $ Term.absfree (x, T, body) $ tT
+ | abstr (Const ("Pair", _) $ x $ y, body) =
+ Syntax.const "split" $ abstr (x, abstr (y, body))
+ | abstr (t, _) = case_error "Illegal pattern" None [t];
+ in case find_first (fn (_, {descr, index, ...}) =>
+ exists (equal cname o fst) (#3 (snd (nth_elem (index, descr))))) tab of
+ None => case_error ("Not a datatype constructor: " ^ cname) None [u]
+ | Some (tname, {descr, case_name, index, ...}) =>
+ let
+ val _ = if exists (equal "dummy_pattern" o fst o fst) cases' then
+ case_error "Illegal occurrence of '_' dummy pattern" (Some tname) [u] else ();
+ val (_, (_, dts, constrs)) = nth_elem (index, descr);
+ val sorts = map (rpair [] o dest_DtTFree) dts;
+ fun find_case (cases, (s, dt)) =
+ (case find_first (equal s o fst o fst) cases' of
+ None => (case default of
+ None => case_error ("No clause for constructor " ^ s) (Some tname) [u]
+ | Some t => (cases, list_abs (map (rpair dummyT) (DatatypeProp.make_tnames
+ (map (typ_of_dtyp descr sorts) dt)), t)))
+ | Some (c as ((_, vs), t)) =>
+ if length dt <> length vs then
+ case_error ("Wrong number of arguments for constructor " ^ s)
+ (Some tname) vs
+ else (cases \ c, foldr abstr (vs, t)))
+ val (cases'', fs) = foldl_map find_case (cases', constrs)
+ in case (cases'', length constrs = length cases', default) of
+ ([], true, Some _) =>
+ case_error "Extra '_' dummy pattern" (Some tname) [u]
+ | (_ :: _, _, _) =>
+ let val extra = distinct (map (fst o fst) cases'')
+ in case extra \\ map fst constrs of
+ [] => case_error ("More than one clause for constructor(s) " ^
+ commas extra) (Some tname) [u]
+ | extra' => case_error ("Illegal constructor(s): " ^ commas extra')
+ (Some tname) [u]
+ end
+ | _ => list_comb (Syntax.const case_name, fs) $ t
+ end
+ end
+ | case_tr sg ts = raise TERM ("case_tr", ts);
+
+fun case_tr' constrs sg ts =
+ if length ts <> length constrs + 1 then raise Match else
+ let
+ val (fs, x) = split_last ts;
+ fun strip_abs 0 t = ([], t)
+ | strip_abs i (Abs p) =
+ let val (x, u) = Syntax.atomic_abs_tr' p
+ in apfst (cons x) (strip_abs (i-1) u) end
+ | strip_abs i (Const ("split", _) $ t) = (case strip_abs (i+1) t of
+ (v :: v' :: vs, u) => (Syntax.const "Pair" $ v $ v' :: vs, u));
+ fun is_dependent i t =
+ let val k = length (strip_abs_vars t) - i
+ in k < 0 orelse exists (fn j => j >= k)
+ (loose_bnos (strip_abs_body t))
+ end;
+ val cases = map (fn ((cname, dts), t) =>
+ (Sign.cond_extern sg Sign.constK cname,
+ strip_abs (length dts) t, is_dependent (length dts) t))
+ (constrs ~~ fs);
+ fun count_cases (cs, (_, _, true)) = cs
+ | count_cases (cs, (cname, (_, body), false)) = (case assoc (cs, body) of
+ None => (body, [cname]) :: cs
+ | Some cnames => overwrite (cs, (body, cnames @ [cname])));
+ val cases' = sort (int_ord o Library.swap o pairself (length o snd))
+ (foldl count_cases ([], cases));
+ fun mk_case1 (cname, (vs, body), _) = Syntax.const "_case1" $
+ list_comb (Syntax.const cname, vs) $ body;
+ in
+ Syntax.const "_case_syntax" $ x $
+ foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) (map mk_case1
+ (case cases' of
+ [] => cases
+ | (default, cnames) :: _ =>
+ if length cnames = 1 then cases
+ else if length cnames = length constrs then
+ [hd cases, ("dummy_pattern", ([], default), false)]
+ else
+ filter_out (fn (cname, _, _) => cname mem cnames) cases @
+ [("dummy_pattern", ([], default), false)]))
+ end;
+
+fun make_case_tr' case_names descr = flat (map
+ (fn ((_, (_, _, constrs)), case_name) => map (rpair (case_tr' constrs))
+ (NameSpace.accesses' case_name)) (descr ~~ case_names));
+
+val trfun_setup =
+ [Theory.add_advanced_trfuns ([], [("_case_syntax", case_tr)], [], [])];
+
+
(* prepare types *)
fun read_typ sign ((Ts, sorts), str) =
@@ -529,8 +639,7 @@
Theory.add_consts_i (map (fn ((name, T), Ts) =>
(name, Ts @ [T] ---> freeT, NoSyn))
- (case_names ~~ newTs ~~ case_fn_Ts)) |>
- Theory.add_trrules_i (DatatypeProp.make_case_trrules new_type_names descr);
+ (case_names ~~ newTs ~~ case_fn_Ts));
val reccomb_names' = map (Sign.intern_const (Theory.sign_of thy2')) reccomb_names;
val case_names' = map (Sign.intern_const (Theory.sign_of thy2')) case_names;
@@ -600,6 +709,7 @@
val split_thms = split ~~ split_asm;
val thy12 = thy11 |>
+ Theory.add_advanced_trfuns ([], [], make_case_tr' case_names' (hd descr), []) |>
Theory.add_path (space_implode "_" new_type_names) |>
add_rules simps case_thms size_thms rec_thms inject distinct
weak_case_congs Simplifier.cong_add_global |>
@@ -657,6 +767,7 @@
val simps = flat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
val thy12 = thy11 |>
+ Theory.add_advanced_trfuns ([], [], make_case_tr' case_names (hd descr), []) |>
Theory.add_path (space_implode "_" new_type_names) |>
add_rules simps case_thms size_thms rec_thms inject distinct
weak_case_congs (Simplifier.change_global_ss (op addcongs)) |>
@@ -765,6 +876,7 @@
val simps = flat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
val thy11 = thy10 |>
+ Theory.add_advanced_trfuns ([], [], make_case_tr' case_names descr, []) |>
add_rules simps case_thms size_thms rec_thms inject distinct
weak_case_congs (Simplifier.change_global_ss (op addcongs)) |>
put_datatypes (foldr Symtab.update (dt_infos, dt_info)) |>
@@ -870,7 +982,7 @@
(* setup theory *)
-val setup = [DatatypesData.init, Method.add_methods tactic_emulations] @ simproc_setup;
+val setup = [DatatypesData.init, Method.add_methods tactic_emulations] @ simproc_setup @ trfun_setup;
(* outer syntax *)
--- a/src/HOL/Tools/datatype_prop.ML Fri May 21 21:47:07 2004 +0200
+++ b/src/HOL/Tools/datatype_prop.ML Fri May 21 21:48:03 2004 +0200
@@ -22,8 +22,6 @@
(string * sort) list -> theory -> term list list
val make_splits : string list -> DatatypeAux.descr list ->
(string * sort) list -> theory -> (term * term) list
- val make_case_trrules : string list -> DatatypeAux.descr list ->
- ast Syntax.trrule list
val make_size : DatatypeAux.descr list -> (string * sort) list ->
theory -> term list
val make_weak_case_congs : string list -> DatatypeAux.descr list ->
@@ -344,39 +342,6 @@
(make_case_combs new_type_names descr sorts thy "f"))
end;
-(************************ translation rules for case **************************)
-
-fun make_case_trrules new_type_names descr =
- let
- fun mk_asts i j ((cname, cargs)::constrs) =
- let
- val k = length cargs;
- val xs = map (fn i => Variable ("x" ^ string_of_int i)) (i upto i + k - 1);
- val t = Variable ("t" ^ string_of_int j);
- val ast = Syntax.mk_appl (Constant "_case1")
- [Syntax.mk_appl (Constant (Sign.base_name cname)) xs, t];
- val ast' = foldr (fn (x, y) =>
- Syntax.mk_appl (Constant "_abs") [x, y]) (xs, t)
- in
- (case constrs of
- [] => (ast, [ast'])
- | cs => let val (ast'', asts) = mk_asts (i + k) (j + 1) cs
- in (Syntax.mk_appl (Constant "_case2") [ast, ast''],
- ast'::asts)
- end)
- end;
-
- fun mk_trrule ((_, (_, _, constrs)), tname) =
- let val (ast, asts) = mk_asts 1 1 constrs
- in Syntax.ParsePrintRule
- (Syntax.mk_appl (Constant "_case_syntax") [Variable "t", ast],
- Syntax.mk_appl (Constant (tname ^ "_case"))
- (asts @ [Variable "t"]))
- end
-
- in
- map mk_trrule (hd descr ~~ new_type_names)
- end;
(******************************* size functions *******************************)