src/HOL/Tools/quickcheck_generators.ML
changeset 31603 fa30cd74d7d6
parent 31595 bd2f7211a420
child 31608 a98a47ffdd8d
--- a/src/HOL/Tools/quickcheck_generators.ML	Wed Jun 10 15:04:31 2009 +0200
+++ b/src/HOL/Tools/quickcheck_generators.ML	Wed Jun 10 15:04:32 2009 +0200
@@ -13,10 +13,10 @@
   val ensure_random_typecopy: string -> theory -> theory
   val random_aux_specification: string -> term list -> local_theory -> local_theory
   val mk_random_aux_eqs: theory -> DatatypeAux.descr -> (string * sort) list
-    -> typ list -> typ list -> string list -> string list
+    -> string list -> string list * string list -> typ list * typ list
     -> string * (term list * (term * term) list)
   val ensure_random_datatype: string list -> theory -> theory
-  val eval_ref: (unit -> int -> int * int -> term list option * (int * int)) option ref
+  val eval_ref: (unit -> int -> seed -> term list option * seed) option ref
   val setup: theory -> theory
 end;
 
@@ -65,10 +65,7 @@
 
 type seed = Random_Engine.seed;
 
-fun random_fun (T1 : typ) (T2 : typ) (eq : 'a -> 'a -> bool) (term_of : 'a -> term)
-    (random : seed -> ('b * (unit -> term)) * seed)
-    (random_split : seed -> seed * seed)
-    (seed : seed) =
+fun random_fun T1 T2 eq term_of random random_split seed =
   let
     val (seed', seed'') = random_split seed;
     val state = ref (seed', [], Const (@{const_name undefined}, T1 --> T2));
@@ -240,7 +237,9 @@
 
 (* constructing random instances on datatypes *)
 
-fun mk_random_aux_eqs thy descr vs Ts rtyps tycos names =
+exception Datatype_Fun; (*FIXME*)
+
+fun mk_random_aux_eqs thy descr vs tycos (names, auxnames) (Ts, Us) =
   let
     val mk_const = curry (Sign.mk_const thy);
     val i = @{term "i\<Colon>code_numeral"};
@@ -248,10 +247,9 @@
     val j = @{term "j\<Colon>code_numeral"};
     val seed = @{term "s\<Colon>Random.seed"};
     val random_auxN = "random_aux";
-    val random_auxsN = map (prefix (random_auxN ^ "_"))
-      (map Long_Name.base_name names @ map DatatypeAux.name_of_typ rtyps);
+    val random_auxsN = map (prefix (random_auxN ^ "_")) (names @ auxnames);
     fun termifyT T = HOLogic.mk_prodT (T, @{typ "unit \<Rightarrow> term"});
-    val rTs = Ts @ rtyps;
+    val rTs = Ts @ Us;
     fun random_resultT T = @{typ Random.seed}
       --> HOLogic.mk_prodT (termifyT T,@{typ Random.seed});
     val pTs = map random_resultT rTs;
@@ -259,19 +257,19 @@
     val random_auxT = sizeT o random_resultT;
     val random_auxs = map2 (fn s => fn rT => Free (s, random_auxT rT))
       random_auxsN rTs;
+
     fun mk_random_call T = (NONE, (HOLogic.mk_random T j, T));
-    fun mk_random_aux_call T =
+    fun mk_random_aux_call fTs (k, _) (tyco, Ts) =
       let
-        val k = find_index (fn T' => T = T') rTs;
+        val _ = if null fTs then () else raise Datatype_Fun; (*FIXME*)
         val random = nth random_auxs k;
         val size = Option.map snd (DatatypeCodegen.find_shortest_path descr k)
           |> the_default 0;
-      in (SOME size, (random $ i1 $ j, T)) end;
-    fun atom T = mk_random_call T;
-    fun dtyp tyco = mk_random_aux_call (Type (tyco, map TFree vs));
-    fun rtyp (tyco, Ts) _ = mk_random_aux_call (Type (tyco, Ts));
-    val (tss1, tss2) = DatatypePackage.construction_interpretation thy
-      {atom = atom, dtyp = dtyp, rtyp = rtyp} vs tycos;
+      in (SOME size, (random $ i1 $ j, Type (tyco, Ts))) end;
+
+    val tss = DatatypeAux.interpret_construction descr vs
+      { atyp = mk_random_call, dtyp = mk_random_aux_call };
+
     fun mk_consexpr simpleT (c, xs) =
       let
         val (ks, simple_tTs) = split_list xs;
@@ -294,8 +292,8 @@
     fun sort_rec xs =
       map_filter (fn (true, t) => SOME t | _ =>  NONE) xs
       @ map_filter (fn (false, t) => SOME t | _ =>  NONE) xs;
-    val gen_exprss = (map o apfst) (fn tyco => Type (tyco, map TFree vs)) tss1
-      @ (map o apfst) Type tss2
+    val gen_exprss = tss
+      |> (map o apfst) Type
       |> map (fn (T, cs) => (T, (sort_rec o map (mk_consexpr T)) cs));
     fun mk_select (rT, xs) =
       mk_const @{const_name Quickcheck.collapse} [@{typ "Random.seed"}, termifyT rT]
@@ -307,18 +305,17 @@
     val prefix = space_implode "_" (random_auxN :: names);
   in (prefix, (random_auxs, auxs_lhss ~~ auxs_rhss)) end;
 
-fun mk_random_datatype descr vs rtyps tycos names thy =
+fun mk_random_datatype descr vs tycos (names, auxnames) (Ts, Us) thy =
   let
     val i = @{term "i\<Colon>code_numeral"};
     val mk_prop_eq = HOLogic.mk_Trueprop o HOLogic.mk_eq;
-    val Ts = map (fn tyco => Type (tyco, map TFree vs)) tycos;
     fun mk_size_arg k = case DatatypeCodegen.find_shortest_path descr k
      of SOME (_, l) => if l = 0 then i
           else @{term "max :: code_numeral \<Rightarrow> code_numeral \<Rightarrow> code_numeral"}
             $ HOLogic.mk_number @{typ code_numeral} l $ i
       | NONE => i;
     val (prefix, (random_auxs, auxs_eqs)) = (apsnd o apsnd o map) mk_prop_eq
-      (mk_random_aux_eqs thy descr vs Ts rtyps tycos names);
+      (mk_random_aux_eqs thy descr vs tycos (names, auxnames) (Ts, Us));
     val random_defs = map_index (fn (k, T) => mk_prop_eq
       (HOLogic.mk_random T i, nth random_auxs k $ mk_size_arg k $ i)) Ts;
   in
@@ -333,38 +330,36 @@
     |> Class.prove_instantiation_exit (K (Class.intro_classes_tac []))
   end;
 
-fun ensure_random_datatype (raw_tycos as tyco :: _) thy =
+fun ensure_random_datatype raw_tycos thy =
   let
     val pp = Syntax.pp_global thy;
     val algebra = Sign.classes_of thy;
-    val info = DatatypePackage.the_datatype thy tyco;
-    val descr = #descr info;
-    val tycos = Library.take (length raw_tycos, descr)
-      |> map (fn (_, (tyco, dTs, _)) => tyco);
-    val names = map Long_Name.base_name (the_default tycos (#alt_names info));
+    val (descr, vs, tycos, (names, auxnames), (raw_Ts, raw_Us)) =
+      DatatypePackage.the_datatype_descr thy raw_tycos;
+
+    (*FIXME this is only an approximation*)
     val (raw_vs :: _, raw_coss) = (split_list
       o map (DatatypePackage.the_datatype_spec thy)) tycos;
-    val raw_Ts = maps (maps snd) raw_coss;
-    val vs' = (fold o fold_atyps) (fn TFree (v, _) => insert (op =) v) raw_Ts [];
+    val raw_Ts' = maps (maps snd) raw_coss;
+    val vs' = (fold o fold_atyps) (fn TFree (v, _) => insert (op =) v) raw_Ts' [];
     val vs = map (fn (v, sort) => (v, if member (op =) vs' v
       then Sorts.inter_sort algebra (sort, @{sort random}) else sort)) raw_vs;
-    val rtyps = Library.drop (length tycos, descr)
-      |> map (fn (_, (tyco, dTs, _)) =>
-          Type (tyco, map (DatatypeAux.typ_of_dtyp descr vs) dTs));
-    val sorts = map snd vs;
     val constrain = map_atyps
       (fn TFree (v, _) => TFree (v, (the o AList.lookup (op =) vs) v));
-    val Ts = map constrain raw_Ts;
+
+    val (Ts, Us) = (pairself o map) constrain (raw_Ts, raw_Us);
+    val sorts = map snd vs;
     val algebra' = algebra
       |> fold (fn tyco => Sorts.add_arities pp
            (tyco, map (rpair sorts) @{sort random})) tycos;
-    val can_inst = forall (fn T =>
-      Sorts.of_sort algebra' (T, @{sort random})) Ts;
+    val can_inst = forall (fn T => Sorts.of_sort algebra' (T, @{sort random})) Ts;
     val hast_inst = exists (fn tyco =>
       can (Sorts.mg_domain algebra tyco) @{sort random}) tycos;
-  in if can_inst andalso not hast_inst then (mk_random_datatype descr vs rtyps tycos names thy
+  in if can_inst andalso not hast_inst then
+    (mk_random_datatype descr vs tycos (names, auxnames) (raw_Ts, raw_Us) thy
     (*FIXME ephemeral handles*)
-    handle e as TERM (msg, ts) => (tracing (cat_lines (msg :: map (Syntax.string_of_term_global thy) ts)); raise e)
+    handle Datatype_Fun => thy
+         | e as TERM (msg, ts) => (tracing (cat_lines (msg :: map (Syntax.string_of_term_global thy) ts)); raise e)
          | e as TYPE (msg, _, _) =>  (tracing msg; raise e)
          | e as ERROR msg =>  (tracing msg; raise e))
   else thy end;