Test data generation and conversion to terms is now more closely
authorberghofe
Thu, 10 Jan 2008 19:21:56 +0100
changeset 25889 c93803252748
parent 25888 48cc198b9ac5
child 25890 0ba401ddbaed
Test data generation and conversion to terms is now more closely intertwined, to allow displaying of functions in test data.
src/HOL/Tools/datatype_codegen.ML
--- a/src/HOL/Tools/datatype_codegen.ML	Thu Jan 10 19:18:14 2008 +0100
+++ b/src/HOL/Tools/datatype_codegen.ML	Thu Jan 10 19:21:56 2008 +0100
@@ -19,11 +19,6 @@
 
 open Codegen;
 
-fun mk_tuple [p] = p
-  | mk_tuple ps = Pretty.block (Pretty.str "(" ::
-      List.concat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @
-        [Pretty.str ")"]);
-
 (**** datatype definition ****)
 
 (* find shortest path to constructor with no recursive arguments *)
@@ -43,7 +38,7 @@
         (max (map (arg_nonempty o DatatypeAux.strip_dtyp) dts))) constrs)
   in case xs of [] => NONE | x :: _ => SOME x end;
 
-fun add_dt_defs thy defs dep module gr (descr: DatatypeAux.descr) =
+fun add_dt_defs thy defs dep module gr (descr: DatatypeAux.descr) sorts =
   let
     val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
     val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) =>
@@ -57,7 +52,6 @@
       | mk_dtdef gr prfx ((_, (tname, dts, cs))::xs) =
           let
             val tvs = map DatatypeAux.dest_DtTFree dts;
-            val sorts = map (rpair []) tvs;
             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
             val (gr', (_, type_id)) = mk_type_id module' tname gr;
             val (gr'', ps) =
@@ -81,11 +75,14 @@
                          (map single ps'))))]) ps))) :: rest)
           end;
 
+    fun mk_constr_term cname Ts T ps =
+      List.concat (separate [Pretty.str " $", Pretty.brk 1]
+        ([Pretty.str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
+          mk_type false (Ts ---> T), Pretty.str ")"] :: ps));
+
     fun mk_term_of_def gr prfx [] = []
       | mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) =
           let
-            val tvs = map DatatypeAux.dest_DtTFree dts;
-            val sorts = map (rpair []) tvs;
             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
             val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
             val T = Type (tname, dts');
@@ -100,11 +97,9 @@
                    [Pretty.str (snd (get_const_id cname gr)),
                     Pretty.brk 1, mk_tuple args]),
                  Pretty.str " =", Pretty.brk 1] @
-                 List.concat (separate [Pretty.str " $", Pretty.brk 1]
-                   ([Pretty.str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
-                     mk_type false (Ts ---> T), Pretty.str ")"] ::
-                    map (fn (x, U) => [Pretty.block [mk_term_of gr module' false U,
-                      Pretty.brk 1, x]]) (args ~~ Ts)))))
+                 mk_constr_term cname Ts T
+                   (map (fn (x, U) => [Pretty.block [mk_term_of gr module' false U,
+                      Pretty.brk 1, x]]) (args ~~ Ts))))
               end) (prfx, cs')
           in eqs @ rest end;
 
@@ -112,7 +107,8 @@
       | mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) =
           let
             val tvs = map DatatypeAux.dest_DtTFree dts;
-            val sorts = map (rpair []) tvs;
+            val Us = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
+            val T = Type (tname, Us);
             val (cs1, cs2) =
               List.partition (exists DatatypeAux.is_rec_type o snd) cs;
             val SOME (cname, _) = find_nonempty descr [i] i;
@@ -120,17 +116,30 @@
             fun mk_delay p = Pretty.block
               [Pretty.str "fn () =>", Pretty.brk 1, p];
 
+            fun mk_force p = Pretty.block [p, Pretty.brk 1, Pretty.str "()"];
+
             fun mk_constr s b (cname, dts) =
               let
                 val gs = map (fn dt => mk_app false (mk_gen gr module' false rtnames s
                     (DatatypeAux.typ_of_dtyp descr sorts dt))
                   [Pretty.str (if b andalso DatatypeAux.is_rec_type dt then "0"
                      else "j")]) dts;
+                val Ts = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
+                val xs = map Pretty.str
+                  (DatatypeProp.indexify_names (replicate (length dts) "x"));
+                val ts = map Pretty.str
+                  (DatatypeProp.indexify_names (replicate (length dts) "t"));
                 val (_, id) = get_const_id cname gr
-              in case gs of
-                  _ :: _ :: _ => Pretty.block
-                    [Pretty.str id, Pretty.brk 1, mk_tuple gs]
-                | _ => mk_app false (Pretty.str id) (map parens gs)
+              in
+                mk_let
+                  (map2 (fn p => fn q => mk_tuple [p, q]) xs ts ~~ gs)
+                  (mk_tuple
+                    [case xs of
+                       _ :: _ :: _ => Pretty.block
+                         [Pretty.str id, Pretty.brk 1, mk_tuple xs]
+                     | _ => mk_app false (Pretty.str id) xs,
+                     mk_delay (Pretty.block (mk_constr_term cname Ts T
+                       (map (single o mk_force) ts)))])
               end;
 
             fun mk_choice [c] = mk_constr "(i-1)" false c
@@ -140,7 +149,9 @@
                     (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @
                   [Pretty.str "]"]), Pretty.brk 1, Pretty.str "()"];
 
-            val gs = map (Pretty.str o suffix "G" o strip_tname) tvs;
+            val gs = maps (fn s =>
+              let val s' = strip_tname s
+              in [Pretty.str (s' ^ "G"), Pretty.str (s' ^ "T")] end) tvs;
             val gen_name = "gen_" ^ snd (get_type_id tname gr)
 
           in
@@ -284,12 +295,12 @@
 fun datatype_tycodegen thy defs gr dep module brack (Type (s, Ts)) =
       (case DatatypePackage.get_datatype thy s of
          NONE => NONE
-       | SOME {descr, ...} =>
+       | SOME {descr, sorts, ...} =>
            if is_some (get_assoc_type thy s) then NONE else
            let
              val (gr', ps) = foldl_map
                (invoke_tycodegen thy defs dep module false) (gr, Ts);
-             val (gr'', module') = add_dt_defs thy defs dep module gr' descr;
+             val (gr'', module') = add_dt_defs thy defs dep module gr' descr sorts;
              val (gr''', tyid) = mk_type_id module' s gr''
            in SOME (gr''',
              Pretty.block ((if null Ts then [] else