--- a/src/HOL/Tools/datatype_codegen.ML Thu Jan 10 19:18:14 2008 +0100
+++ b/src/HOL/Tools/datatype_codegen.ML Thu Jan 10 19:21:56 2008 +0100
@@ -19,11 +19,6 @@
open Codegen;
-fun mk_tuple [p] = p
- | mk_tuple ps = Pretty.block (Pretty.str "(" ::
- List.concat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @
- [Pretty.str ")"]);
-
(**** datatype definition ****)
(* find shortest path to constructor with no recursive arguments *)
@@ -43,7 +38,7 @@
(max (map (arg_nonempty o DatatypeAux.strip_dtyp) dts))) constrs)
in case xs of [] => NONE | x :: _ => SOME x end;
-fun add_dt_defs thy defs dep module gr (descr: DatatypeAux.descr) =
+fun add_dt_defs thy defs dep module gr (descr: DatatypeAux.descr) sorts =
let
val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) =>
@@ -57,7 +52,6 @@
| mk_dtdef gr prfx ((_, (tname, dts, cs))::xs) =
let
val tvs = map DatatypeAux.dest_DtTFree dts;
- val sorts = map (rpair []) tvs;
val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
val (gr', (_, type_id)) = mk_type_id module' tname gr;
val (gr'', ps) =
@@ -81,11 +75,14 @@
(map single ps'))))]) ps))) :: rest)
end;
+ fun mk_constr_term cname Ts T ps =
+ List.concat (separate [Pretty.str " $", Pretty.brk 1]
+ ([Pretty.str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
+ mk_type false (Ts ---> T), Pretty.str ")"] :: ps));
+
fun mk_term_of_def gr prfx [] = []
| mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) =
let
- val tvs = map DatatypeAux.dest_DtTFree dts;
- val sorts = map (rpair []) tvs;
val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
val T = Type (tname, dts');
@@ -100,11 +97,9 @@
[Pretty.str (snd (get_const_id cname gr)),
Pretty.brk 1, mk_tuple args]),
Pretty.str " =", Pretty.brk 1] @
- List.concat (separate [Pretty.str " $", Pretty.brk 1]
- ([Pretty.str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
- mk_type false (Ts ---> T), Pretty.str ")"] ::
- map (fn (x, U) => [Pretty.block [mk_term_of gr module' false U,
- Pretty.brk 1, x]]) (args ~~ Ts)))))
+ mk_constr_term cname Ts T
+ (map (fn (x, U) => [Pretty.block [mk_term_of gr module' false U,
+ Pretty.brk 1, x]]) (args ~~ Ts))))
end) (prfx, cs')
in eqs @ rest end;
@@ -112,7 +107,8 @@
| mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) =
let
val tvs = map DatatypeAux.dest_DtTFree dts;
- val sorts = map (rpair []) tvs;
+ val Us = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
+ val T = Type (tname, Us);
val (cs1, cs2) =
List.partition (exists DatatypeAux.is_rec_type o snd) cs;
val SOME (cname, _) = find_nonempty descr [i] i;
@@ -120,17 +116,30 @@
fun mk_delay p = Pretty.block
[Pretty.str "fn () =>", Pretty.brk 1, p];
+ fun mk_force p = Pretty.block [p, Pretty.brk 1, Pretty.str "()"];
+
fun mk_constr s b (cname, dts) =
let
val gs = map (fn dt => mk_app false (mk_gen gr module' false rtnames s
(DatatypeAux.typ_of_dtyp descr sorts dt))
[Pretty.str (if b andalso DatatypeAux.is_rec_type dt then "0"
else "j")]) dts;
+ val Ts = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
+ val xs = map Pretty.str
+ (DatatypeProp.indexify_names (replicate (length dts) "x"));
+ val ts = map Pretty.str
+ (DatatypeProp.indexify_names (replicate (length dts) "t"));
val (_, id) = get_const_id cname gr
- in case gs of
- _ :: _ :: _ => Pretty.block
- [Pretty.str id, Pretty.brk 1, mk_tuple gs]
- | _ => mk_app false (Pretty.str id) (map parens gs)
+ in
+ mk_let
+ (map2 (fn p => fn q => mk_tuple [p, q]) xs ts ~~ gs)
+ (mk_tuple
+ [case xs of
+ _ :: _ :: _ => Pretty.block
+ [Pretty.str id, Pretty.brk 1, mk_tuple xs]
+ | _ => mk_app false (Pretty.str id) xs,
+ mk_delay (Pretty.block (mk_constr_term cname Ts T
+ (map (single o mk_force) ts)))])
end;
fun mk_choice [c] = mk_constr "(i-1)" false c
@@ -140,7 +149,9 @@
(map (single o mk_delay o mk_constr "(i-1)" false) cs)) @
[Pretty.str "]"]), Pretty.brk 1, Pretty.str "()"];
- val gs = map (Pretty.str o suffix "G" o strip_tname) tvs;
+ val gs = maps (fn s =>
+ let val s' = strip_tname s
+ in [Pretty.str (s' ^ "G"), Pretty.str (s' ^ "T")] end) tvs;
val gen_name = "gen_" ^ snd (get_type_id tname gr)
in
@@ -284,12 +295,12 @@
fun datatype_tycodegen thy defs gr dep module brack (Type (s, Ts)) =
(case DatatypePackage.get_datatype thy s of
NONE => NONE
- | SOME {descr, ...} =>
+ | SOME {descr, sorts, ...} =>
if is_some (get_assoc_type thy s) then NONE else
let
val (gr', ps) = foldl_map
(invoke_tycodegen thy defs dep module false) (gr, Ts);
- val (gr'', module') = add_dt_defs thy defs dep module gr' descr;
+ val (gr'', module') = add_dt_defs thy defs dep module gr' descr sorts;
val (gr''', tyid) = mk_type_id module' s gr''
in SOME (gr''',
Pretty.block ((if null Ts then [] else