src/HOL/Tools/quickcheck_generators.ML
changeset 31608 a98a47ffdd8d
parent 31603 fa30cd74d7d6
child 31609 8d353e3214d0
--- a/src/HOL/Tools/quickcheck_generators.ML	Wed Jun 10 16:10:30 2009 +0200
+++ b/src/HOL/Tools/quickcheck_generators.ML	Wed Jun 10 16:10:31 2009 +0200
@@ -330,33 +330,31 @@
     |> Class.prove_instantiation_exit (K (Class.intro_classes_tac []))
   end;
 
+
+
 fun ensure_random_datatype raw_tycos thy =
   let
     val pp = Syntax.pp_global thy;
     val algebra = Sign.classes_of thy;
-    val (descr, vs, tycos, (names, auxnames), (raw_Ts, raw_Us)) =
+    val (descr, raw_vs, tycos, (names, auxnames), raw_TUs) =
       DatatypePackage.the_datatype_descr thy raw_tycos;
-
-    (*FIXME this is only an approximation*)
-    val (raw_vs :: _, raw_coss) = (split_list
-      o map (DatatypePackage.the_datatype_spec thy)) tycos;
-    val raw_Ts' = maps (maps snd) raw_coss;
-    val vs' = (fold o fold_atyps) (fn TFree (v, _) => insert (op =) v) raw_Ts' [];
-    val vs = map (fn (v, sort) => (v, if member (op =) vs' v
-      then Sorts.inter_sort algebra (sort, @{sort random}) else sort)) raw_vs;
-    val constrain = map_atyps
-      (fn TFree (v, _) => TFree (v, (the o AList.lookup (op =) vs) v));
-
-    val (Ts, Us) = (pairself o map) constrain (raw_Ts, raw_Us);
-    val sorts = map snd vs;
-    val algebra' = algebra
-      |> fold (fn tyco => Sorts.add_arities pp
-           (tyco, map (rpair sorts) @{sort random})) tycos;
-    val can_inst = forall (fn T => Sorts.of_sort algebra' (T, @{sort random})) Ts;
-    val hast_inst = exists (fn tyco =>
+    val aTs = (flat o maps snd o maps snd) (DatatypeAux.interpret_construction descr raw_vs
+      { atyp = single, dtyp = K o K });
+    fun meet_random T = Sorts.meet_sort (Sign.classes_of thy) (Logic.varifyT T, @{sort random});
+    val vtab = (Vartab.empty
+      |> fold (fn (v, sort) => Vartab.update ((v, 0), sort)) raw_vs
+      |> fold meet_random aTs
+      |> SOME) handle CLASS_ERROR => NONE;
+    val vconstrain = case vtab of SOME vtab => (fn (v, _) =>
+          (v, (the o Vartab.lookup vtab) (v, 0)))
+      | NONE => I;
+    val vs = map vconstrain raw_vs;
+    val constrain = map_atyps (fn TFree v => TFree (vconstrain v));
+    val has_inst = exists (fn tyco =>
       can (Sorts.mg_domain algebra tyco) @{sort random}) tycos;
-  in if can_inst andalso not hast_inst then
-    (mk_random_datatype descr vs tycos (names, auxnames) (raw_Ts, raw_Us) thy
+  in if is_some vtab andalso not has_inst then
+    (mk_random_datatype descr vs tycos (names, auxnames)
+      ((pairself o map) constrain raw_TUs) thy
     (*FIXME ephemeral handles*)
     handle Datatype_Fun => thy
          | e as TERM (msg, ts) => (tracing (cat_lines (msg :: map (Syntax.string_of_term_global thy) ts)); raise e)