revised interpretation combinator for datatype constructions
authorhaftmann
Wed, 10 Jun 2009 15:04:32 +0200
changeset 31603 fa30cd74d7d6
parent 31602 59df8222c204
child 31604 eb2f9d709296
revised interpretation combinator for datatype constructions
src/HOL/Quickcheck.thy
src/HOL/Tools/quickcheck_generators.ML
--- a/src/HOL/Quickcheck.thy	Wed Jun 10 15:04:31 2009 +0200
+++ b/src/HOL/Quickcheck.thy	Wed Jun 10 15:04:32 2009 +0200
@@ -87,6 +87,28 @@
 
 subsection {* Complex generators *}
 
+text {* Towards @{typ "'a \<Rightarrow> 'b"} *}
+
+axiomatization random_fun_aux :: "typerep \<Rightarrow> typerep \<Rightarrow> ('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> term)
+  \<Rightarrow> (Random.seed \<Rightarrow> ('b \<times> (unit \<Rightarrow> term)) \<times> Random.seed) \<Rightarrow> (Random.seed \<Rightarrow> Random.seed \<times> Random.seed)
+  \<Rightarrow> Random.seed \<Rightarrow> (('a \<Rightarrow> 'b) \<times> (unit \<Rightarrow> term)) \<times> Random.seed"
+
+definition random_fun_lift :: "(code_numeral \<Rightarrow> Random.seed \<Rightarrow> ('b \<times> (unit \<Rightarrow> term)) \<times> Random.seed)
+  \<Rightarrow> code_numeral \<Rightarrow> Random.seed \<Rightarrow> (('a\<Colon>term_of \<Rightarrow> 'b\<Colon>typerep) \<times> (unit \<Rightarrow> term)) \<times> Random.seed" where
+  "random_fun_lift f i = random_fun_aux TYPEREP('a) TYPEREP('b) (op =) Code_Eval.term_of (f i) Random.split_seed"
+
+instantiation "fun" :: ("{eq, term_of}", "{type, random}") random
+begin
+
+definition random_fun :: "code_numeral \<Rightarrow> Random.seed \<Rightarrow> (('a \<Rightarrow> 'b) \<times> (unit \<Rightarrow> term)) \<times> Random.seed" where
+  "random = random_fun_lift random"
+
+instance ..
+
+end
+
+text {* Towards type copies and datatypes *}
+
 definition collapse :: "('a \<Rightarrow> ('a \<Rightarrow> 'b \<times> 'a) \<times> 'a) \<Rightarrow> 'a \<Rightarrow> 'b \<times> 'a" where
   "collapse f = (f o\<rightarrow> id)"
 
@@ -109,29 +131,13 @@
 
 code_reserved Quickcheck Quickcheck_Generators
 
-text {* Type @{typ "'a \<Rightarrow> 'b"} *}
-
-axiomatization random_fun_aux :: "typerep \<Rightarrow> typerep \<Rightarrow> ('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> term)
-  \<Rightarrow> (Random.seed \<Rightarrow> ('b \<times> (unit \<Rightarrow> term)) \<times> Random.seed) \<Rightarrow> (Random.seed \<Rightarrow> Random.seed \<times> Random.seed)
-  \<Rightarrow> Random.seed \<Rightarrow> (('a \<Rightarrow> 'b) \<times> (unit \<Rightarrow> term)) \<times> Random.seed"
-
 code_const random_fun_aux (Quickcheck "Quickcheck'_Generators.random'_fun")
   -- {* With enough criminal energy this can be abused to derive @{prop False};
   for this reason we use a distinguished target @{text Quickcheck}
   not spoiling the regular trusted code generation *}
 
-instantiation "fun" :: ("{eq, term_of}", "{type, random}") random
-begin
 
-definition random_fun :: "code_numeral \<Rightarrow> Random.seed \<Rightarrow> (('a \<Rightarrow> 'b) \<times> (unit \<Rightarrow> term)) \<times> Random.seed" where
-  "random n = random_fun_aux TYPEREP('a) TYPEREP('b) (op =) Code_Eval.term_of (random n) Random.split_seed"
-
-instance ..
-
-end
-
-
-hide (open) const collapse beyond
+hide (open) const collapse beyond random_fun_aux random_fun_lift
 
 no_notation fcomp (infixl "o>" 60)
 no_notation scomp (infixl "o\<rightarrow>" 60)
--- a/src/HOL/Tools/quickcheck_generators.ML	Wed Jun 10 15:04:31 2009 +0200
+++ b/src/HOL/Tools/quickcheck_generators.ML	Wed Jun 10 15:04:32 2009 +0200
@@ -13,10 +13,10 @@
   val ensure_random_typecopy: string -> theory -> theory
   val random_aux_specification: string -> term list -> local_theory -> local_theory
   val mk_random_aux_eqs: theory -> DatatypeAux.descr -> (string * sort) list
-    -> typ list -> typ list -> string list -> string list
+    -> string list -> string list * string list -> typ list * typ list
     -> string * (term list * (term * term) list)
   val ensure_random_datatype: string list -> theory -> theory
-  val eval_ref: (unit -> int -> int * int -> term list option * (int * int)) option ref
+  val eval_ref: (unit -> int -> seed -> term list option * seed) option ref
   val setup: theory -> theory
 end;
 
@@ -65,10 +65,7 @@
 
 type seed = Random_Engine.seed;
 
-fun random_fun (T1 : typ) (T2 : typ) (eq : 'a -> 'a -> bool) (term_of : 'a -> term)
-    (random : seed -> ('b * (unit -> term)) * seed)
-    (random_split : seed -> seed * seed)
-    (seed : seed) =
+fun random_fun T1 T2 eq term_of random random_split seed =
   let
     val (seed', seed'') = random_split seed;
     val state = ref (seed', [], Const (@{const_name undefined}, T1 --> T2));
@@ -240,7 +237,9 @@
 
 (* constructing random instances on datatypes *)
 
-fun mk_random_aux_eqs thy descr vs Ts rtyps tycos names =
+exception Datatype_Fun; (*FIXME*)
+
+fun mk_random_aux_eqs thy descr vs tycos (names, auxnames) (Ts, Us) =
   let
     val mk_const = curry (Sign.mk_const thy);
     val i = @{term "i\<Colon>code_numeral"};
@@ -248,10 +247,9 @@
     val j = @{term "j\<Colon>code_numeral"};
     val seed = @{term "s\<Colon>Random.seed"};
     val random_auxN = "random_aux";
-    val random_auxsN = map (prefix (random_auxN ^ "_"))
-      (map Long_Name.base_name names @ map DatatypeAux.name_of_typ rtyps);
+    val random_auxsN = map (prefix (random_auxN ^ "_")) (names @ auxnames);
     fun termifyT T = HOLogic.mk_prodT (T, @{typ "unit \<Rightarrow> term"});
-    val rTs = Ts @ rtyps;
+    val rTs = Ts @ Us;
     fun random_resultT T = @{typ Random.seed}
       --> HOLogic.mk_prodT (termifyT T,@{typ Random.seed});
     val pTs = map random_resultT rTs;
@@ -259,19 +257,19 @@
     val random_auxT = sizeT o random_resultT;
     val random_auxs = map2 (fn s => fn rT => Free (s, random_auxT rT))
       random_auxsN rTs;
+
     fun mk_random_call T = (NONE, (HOLogic.mk_random T j, T));
-    fun mk_random_aux_call T =
+    fun mk_random_aux_call fTs (k, _) (tyco, Ts) =
       let
-        val k = find_index (fn T' => T = T') rTs;
+        val _ = if null fTs then () else raise Datatype_Fun; (*FIXME*)
         val random = nth random_auxs k;
         val size = Option.map snd (DatatypeCodegen.find_shortest_path descr k)
           |> the_default 0;
-      in (SOME size, (random $ i1 $ j, T)) end;
-    fun atom T = mk_random_call T;
-    fun dtyp tyco = mk_random_aux_call (Type (tyco, map TFree vs));
-    fun rtyp (tyco, Ts) _ = mk_random_aux_call (Type (tyco, Ts));
-    val (tss1, tss2) = DatatypePackage.construction_interpretation thy
-      {atom = atom, dtyp = dtyp, rtyp = rtyp} vs tycos;
+      in (SOME size, (random $ i1 $ j, Type (tyco, Ts))) end;
+
+    val tss = DatatypeAux.interpret_construction descr vs
+      { atyp = mk_random_call, dtyp = mk_random_aux_call };
+
     fun mk_consexpr simpleT (c, xs) =
       let
         val (ks, simple_tTs) = split_list xs;
@@ -294,8 +292,8 @@
     fun sort_rec xs =
       map_filter (fn (true, t) => SOME t | _ =>  NONE) xs
       @ map_filter (fn (false, t) => SOME t | _ =>  NONE) xs;
-    val gen_exprss = (map o apfst) (fn tyco => Type (tyco, map TFree vs)) tss1
-      @ (map o apfst) Type tss2
+    val gen_exprss = tss
+      |> (map o apfst) Type
       |> map (fn (T, cs) => (T, (sort_rec o map (mk_consexpr T)) cs));
     fun mk_select (rT, xs) =
       mk_const @{const_name Quickcheck.collapse} [@{typ "Random.seed"}, termifyT rT]
@@ -307,18 +305,17 @@
     val prefix = space_implode "_" (random_auxN :: names);
   in (prefix, (random_auxs, auxs_lhss ~~ auxs_rhss)) end;
 
-fun mk_random_datatype descr vs rtyps tycos names thy =
+fun mk_random_datatype descr vs tycos (names, auxnames) (Ts, Us) thy =
   let
     val i = @{term "i\<Colon>code_numeral"};
     val mk_prop_eq = HOLogic.mk_Trueprop o HOLogic.mk_eq;
-    val Ts = map (fn tyco => Type (tyco, map TFree vs)) tycos;
     fun mk_size_arg k = case DatatypeCodegen.find_shortest_path descr k
      of SOME (_, l) => if l = 0 then i
           else @{term "max :: code_numeral \<Rightarrow> code_numeral \<Rightarrow> code_numeral"}
             $ HOLogic.mk_number @{typ code_numeral} l $ i
       | NONE => i;
     val (prefix, (random_auxs, auxs_eqs)) = (apsnd o apsnd o map) mk_prop_eq
-      (mk_random_aux_eqs thy descr vs Ts rtyps tycos names);
+      (mk_random_aux_eqs thy descr vs tycos (names, auxnames) (Ts, Us));
     val random_defs = map_index (fn (k, T) => mk_prop_eq
       (HOLogic.mk_random T i, nth random_auxs k $ mk_size_arg k $ i)) Ts;
   in
@@ -333,38 +330,36 @@
     |> Class.prove_instantiation_exit (K (Class.intro_classes_tac []))
   end;
 
-fun ensure_random_datatype (raw_tycos as tyco :: _) thy =
+fun ensure_random_datatype raw_tycos thy =
   let
     val pp = Syntax.pp_global thy;
     val algebra = Sign.classes_of thy;
-    val info = DatatypePackage.the_datatype thy tyco;
-    val descr = #descr info;
-    val tycos = Library.take (length raw_tycos, descr)
-      |> map (fn (_, (tyco, dTs, _)) => tyco);
-    val names = map Long_Name.base_name (the_default tycos (#alt_names info));
+    val (descr, vs, tycos, (names, auxnames), (raw_Ts, raw_Us)) =
+      DatatypePackage.the_datatype_descr thy raw_tycos;
+
+    (*FIXME this is only an approximation*)
     val (raw_vs :: _, raw_coss) = (split_list
       o map (DatatypePackage.the_datatype_spec thy)) tycos;
-    val raw_Ts = maps (maps snd) raw_coss;
-    val vs' = (fold o fold_atyps) (fn TFree (v, _) => insert (op =) v) raw_Ts [];
+    val raw_Ts' = maps (maps snd) raw_coss;
+    val vs' = (fold o fold_atyps) (fn TFree (v, _) => insert (op =) v) raw_Ts' [];
     val vs = map (fn (v, sort) => (v, if member (op =) vs' v
       then Sorts.inter_sort algebra (sort, @{sort random}) else sort)) raw_vs;
-    val rtyps = Library.drop (length tycos, descr)
-      |> map (fn (_, (tyco, dTs, _)) =>
-          Type (tyco, map (DatatypeAux.typ_of_dtyp descr vs) dTs));
-    val sorts = map snd vs;
     val constrain = map_atyps
       (fn TFree (v, _) => TFree (v, (the o AList.lookup (op =) vs) v));
-    val Ts = map constrain raw_Ts;
+
+    val (Ts, Us) = (pairself o map) constrain (raw_Ts, raw_Us);
+    val sorts = map snd vs;
     val algebra' = algebra
       |> fold (fn tyco => Sorts.add_arities pp
            (tyco, map (rpair sorts) @{sort random})) tycos;
-    val can_inst = forall (fn T =>
-      Sorts.of_sort algebra' (T, @{sort random})) Ts;
+    val can_inst = forall (fn T => Sorts.of_sort algebra' (T, @{sort random})) Ts;
     val hast_inst = exists (fn tyco =>
       can (Sorts.mg_domain algebra tyco) @{sort random}) tycos;
-  in if can_inst andalso not hast_inst then (mk_random_datatype descr vs rtyps tycos names thy
+  in if can_inst andalso not hast_inst then
+    (mk_random_datatype descr vs tycos (names, auxnames) (raw_Ts, raw_Us) thy
     (*FIXME ephemeral handles*)
-    handle e as TERM (msg, ts) => (tracing (cat_lines (msg :: map (Syntax.string_of_term_global thy) ts)); raise e)
+    handle Datatype_Fun => thy
+         | e as TERM (msg, ts) => (tracing (cat_lines (msg :: map (Syntax.string_of_term_global thy) ts)); raise e)
          | e as TYPE (msg, _, _) =>  (tracing msg; raise e)
          | e as ERROR msg =>  (tracing msg; raise e))
   else thy end;