src/HOL/Tools/datatype_codegen.ML
changeset 25889 c93803252748
parent 25864 11f531354852
child 26513 6f306c8c2c54
--- a/src/HOL/Tools/datatype_codegen.ML	Thu Jan 10 19:18:14 2008 +0100
+++ b/src/HOL/Tools/datatype_codegen.ML	Thu Jan 10 19:21:56 2008 +0100
@@ -19,11 +19,6 @@
 
 open Codegen;
 
-fun mk_tuple [p] = p
-  | mk_tuple ps = Pretty.block (Pretty.str "(" ::
-      List.concat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @
-        [Pretty.str ")"]);
-
 (**** datatype definition ****)
 
 (* find shortest path to constructor with no recursive arguments *)
@@ -43,7 +38,7 @@
         (max (map (arg_nonempty o DatatypeAux.strip_dtyp) dts))) constrs)
   in case xs of [] => NONE | x :: _ => SOME x end;
 
-fun add_dt_defs thy defs dep module gr (descr: DatatypeAux.descr) =
+fun add_dt_defs thy defs dep module gr (descr: DatatypeAux.descr) sorts =
   let
     val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
     val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) =>
@@ -57,7 +52,6 @@
       | mk_dtdef gr prfx ((_, (tname, dts, cs))::xs) =
           let
             val tvs = map DatatypeAux.dest_DtTFree dts;
-            val sorts = map (rpair []) tvs;
             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
             val (gr', (_, type_id)) = mk_type_id module' tname gr;
             val (gr'', ps) =
@@ -81,11 +75,14 @@
                          (map single ps'))))]) ps))) :: rest)
           end;
 
+    fun mk_constr_term cname Ts T ps =
+      List.concat (separate [Pretty.str " $", Pretty.brk 1]
+        ([Pretty.str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
+          mk_type false (Ts ---> T), Pretty.str ")"] :: ps));
+
     fun mk_term_of_def gr prfx [] = []
       | mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) =
           let
-            val tvs = map DatatypeAux.dest_DtTFree dts;
-            val sorts = map (rpair []) tvs;
             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
             val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
             val T = Type (tname, dts');
@@ -100,11 +97,9 @@
                    [Pretty.str (snd (get_const_id cname gr)),
                     Pretty.brk 1, mk_tuple args]),
                  Pretty.str " =", Pretty.brk 1] @
-                 List.concat (separate [Pretty.str " $", Pretty.brk 1]
-                   ([Pretty.str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
-                     mk_type false (Ts ---> T), Pretty.str ")"] ::
-                    map (fn (x, U) => [Pretty.block [mk_term_of gr module' false U,
-                      Pretty.brk 1, x]]) (args ~~ Ts)))))
+                 mk_constr_term cname Ts T
+                   (map (fn (x, U) => [Pretty.block [mk_term_of gr module' false U,
+                      Pretty.brk 1, x]]) (args ~~ Ts))))
               end) (prfx, cs')
           in eqs @ rest end;
 
@@ -112,7 +107,8 @@
       | mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) =
           let
             val tvs = map DatatypeAux.dest_DtTFree dts;
-            val sorts = map (rpair []) tvs;
+            val Us = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
+            val T = Type (tname, Us);
             val (cs1, cs2) =
               List.partition (exists DatatypeAux.is_rec_type o snd) cs;
             val SOME (cname, _) = find_nonempty descr [i] i;
@@ -120,17 +116,30 @@
             fun mk_delay p = Pretty.block
               [Pretty.str "fn () =>", Pretty.brk 1, p];
 
+            fun mk_force p = Pretty.block [p, Pretty.brk 1, Pretty.str "()"];
+
             fun mk_constr s b (cname, dts) =
               let
                 val gs = map (fn dt => mk_app false (mk_gen gr module' false rtnames s
                     (DatatypeAux.typ_of_dtyp descr sorts dt))
                   [Pretty.str (if b andalso DatatypeAux.is_rec_type dt then "0"
                      else "j")]) dts;
+                val Ts = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
+                val xs = map Pretty.str
+                  (DatatypeProp.indexify_names (replicate (length dts) "x"));
+                val ts = map Pretty.str
+                  (DatatypeProp.indexify_names (replicate (length dts) "t"));
                 val (_, id) = get_const_id cname gr
-              in case gs of
-                  _ :: _ :: _ => Pretty.block
-                    [Pretty.str id, Pretty.brk 1, mk_tuple gs]
-                | _ => mk_app false (Pretty.str id) (map parens gs)
+              in
+                mk_let
+                  (map2 (fn p => fn q => mk_tuple [p, q]) xs ts ~~ gs)
+                  (mk_tuple
+                    [case xs of
+                       _ :: _ :: _ => Pretty.block
+                         [Pretty.str id, Pretty.brk 1, mk_tuple xs]
+                     | _ => mk_app false (Pretty.str id) xs,
+                     mk_delay (Pretty.block (mk_constr_term cname Ts T
+                       (map (single o mk_force) ts)))])
               end;
 
             fun mk_choice [c] = mk_constr "(i-1)" false c
@@ -140,7 +149,9 @@
                     (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @
                   [Pretty.str "]"]), Pretty.brk 1, Pretty.str "()"];
 
-            val gs = map (Pretty.str o suffix "G" o strip_tname) tvs;
+            val gs = maps (fn s =>
+              let val s' = strip_tname s
+              in [Pretty.str (s' ^ "G"), Pretty.str (s' ^ "T")] end) tvs;
             val gen_name = "gen_" ^ snd (get_type_id tname gr)
 
           in
@@ -284,12 +295,12 @@
 fun datatype_tycodegen thy defs gr dep module brack (Type (s, Ts)) =
       (case DatatypePackage.get_datatype thy s of
          NONE => NONE
-       | SOME {descr, ...} =>
+       | SOME {descr, sorts, ...} =>
            if is_some (get_assoc_type thy s) then NONE else
            let
              val (gr', ps) = foldl_map
                (invoke_tycodegen thy defs dep module false) (gr, Ts);
-             val (gr'', module') = add_dt_defs thy defs dep module gr' descr;
+             val (gr'', module') = add_dt_defs thy defs dep module gr' descr sorts;
              val (gr''', tyid) = mk_type_id module' s gr''
            in SOME (gr''',
              Pretty.block ((if null Ts then [] else