added combinator for interpretation of construction of datatype
authorhaftmann
Fri, 14 Mar 2008 08:52:52 +0100
changeset 26267 ba710daf77a7
parent 26266 35ae83ca190a
child 26268 80aaf4d034be
added combinator for interpretation of construction of datatype
src/HOL/Library/Eval.thy
src/HOL/Tools/datatype_package.ML
src/HOL/ex/Quickcheck.thy
--- a/src/HOL/Library/Eval.thy	Fri Mar 14 08:52:51 2008 +0100
+++ b/src/HOL/Library/Eval.thy	Fri Mar 14 08:52:52 2008 +0100
@@ -46,7 +46,8 @@
       @{term Const} $ Message_String.mk c $ g ty
   | mk_term f g (t1 $ t2) =
       @{term App} $ mk_term f g t1 $ mk_term f g t2
-  | mk_term f g (Free v) = f v;
+  | mk_term f g (Free v) = f v
+  | mk_term f g (Bound i) = Bound i;
 
 fun mk_term_of ty t = Const (@{const_name term_of}, ty --> @{typ term}) $ t;
 
@@ -157,6 +158,7 @@
 ML {*
 signature EVAL =
 sig
+  val mk_term: ((string * typ) -> term) -> (typ -> term) -> term -> term
   val eval_ref: (unit -> term) option ref
   val eval_term: theory -> term -> term
   val evaluate: Proof.context -> term -> unit
@@ -234,6 +236,7 @@
 *}
 
 hide (open) const term_of
-hide const Const App dummy_term
+hide (open) const Const App
+hide const dummy_term
 
 end
--- a/src/HOL/Tools/datatype_package.ML	Fri Mar 14 08:52:51 2008 +0100
+++ b/src/HOL/Tools/datatype_package.ML	Fri Mar 14 08:52:52 2008 +0100
@@ -65,6 +65,10 @@
   val datatype_of_constr : theory -> string -> DatatypeAux.datatype_info option
   val datatype_of_case : theory -> string -> DatatypeAux.datatype_info option
   val get_datatype_constrs : theory -> string -> (string * typ) list option
+  val construction_interpretation: theory
+    -> { atom: typ -> 'a, dtyp: string -> 'a, rtyp: string -> 'a list -> 'a }
+    -> (string * Term.sort) list -> string list
+    -> (string * (string * 'a list) list) list
   val interpretation: (string list -> theory -> theory) -> theory -> theory
   val print_datatypes : theory -> unit
   val make_case :  Proof.context -> bool -> string list -> term ->
@@ -159,6 +163,22 @@
         in SOME (map mk_co cos) end
     | NONE => NONE;
 
+fun construction_interpretation thy { atom, dtyp, rtyp } sorts tycos =
+  let
+    val descr = (#descr o the_datatype thy o hd) tycos;
+    val k = length tycos;
+    val descr_of = the o AList.lookup (op =) descr;
+    fun interpT (T as DtTFree _) = atom (typ_of_dtyp descr sorts T)
+      | interpT (T as DtType (tyco, Ts)) = if is_rec_type T
+          then rtyp tyco (map interpT Ts)
+          else atom (typ_of_dtyp descr sorts T)
+      | interpT (DtRec l) = if l < k then (dtyp o #1 o descr_of) l
+          else let val (tyco, Ts, _) = descr_of l
+          in rtyp tyco (map interpT Ts) end;
+    fun interpC (c, Ts) = (c, map interpT Ts);
+    fun interpK (_, (tyco, _, cs)) = (tyco, map interpC cs);
+  in map interpK (Library.take (k, descr)) end;
+
 fun find_tname var Bi =
   let val frees = map dest_Free (term_frees Bi)
       val params = rename_wrt_term Bi (Logic.strip_params Bi);
--- a/src/HOL/ex/Quickcheck.thy	Fri Mar 14 08:52:51 2008 +0100
+++ b/src/HOL/ex/Quickcheck.thy	Fri Mar 14 08:52:52 2008 +0100
@@ -13,7 +13,7 @@
 class random = type +
   fixes random :: "index \<Rightarrow> seed \<Rightarrow> 'a \<times> seed"
 
-print_classes
+text {* Type @{typ "'a itself"} *}
 
 instantiation itself :: (type) random
 begin
@@ -25,7 +25,9 @@
 
 end
 
-lemma random_aux_if:
+text {* Datatypes *}
+
+lemma random'_if:
   fixes random' :: "index \<Rightarrow> index \<Rightarrow> seed \<Rightarrow> 'a \<times> seed"
   assumes "random' 0 j = undefined"
     and "\<And>i. random' (Suc_index i) j = rhs2 i"
@@ -34,108 +36,107 @@
 
 setup {*
 let
-  exception REC;
-  fun random_inst tyco thy =
+  exception REC of string;
+  fun mk_collapse thy ty = Sign.mk_const thy
+    (@{const_name collapse}, [@{typ seed}, ty]);
+  fun mk_cons this_ty (c, args) =
     let
-      val { descr, index, ... } = DatatypePackage.the_datatype thy tyco;
-      val _ = if length descr > 1 then raise REC else ();
-      val (raw_vs, _) = DatatypePackage.the_datatype_spec thy tyco;
+      val tys = map (fst o fst) args;
+      val return = StateMonad.return this_ty @{typ seed}
+        (list_comb (Const (c, tys ---> this_ty),
+           map Bound (length tys - 1 downto 0)));
+      val t = fold_rev (fn ((ty, _), random) => fn t =>
+        StateMonad.mbind ty this_ty @{typ seed} random (Abs ("", ty, t)))
+          args return;
+      val is_rec = exists (snd o fst) args;
+    in (is_rec, StateMonad.run this_ty @{typ seed} t) end;
+  fun mk_conss thy ty [] = NONE
+    | mk_conss thy ty [(_, t)] = SOME t
+    | mk_conss thy ty ts = SOME (mk_collapse thy ty $
+          (Sign.mk_const thy (@{const_name select}, [StateMonad.liftT ty @{typ seed}]) $
+            HOLogic.mk_list (StateMonad.liftT ty @{typ seed}) (map snd ts)));
+  fun mk_clauses thy ty (tyco, (ts_rec, ts_atom)) = 
+    let
+      val SOME t_atom = mk_conss thy ty ts_atom;
+    in case mk_conss thy ty ts_rec
+     of SOME t_rec => mk_collapse thy ty $
+          (Sign.mk_const thy (@{const_name select_default}, [StateMonad.liftT ty @{typ seed}]) $
+             @{term "i\<Colon>index"} $ t_rec $ t_atom)
+      | NONE => t_atom
+    end;
+  fun mk_random_eqs thy tycos =
+    let
+      val (raw_vs, _) = DatatypePackage.the_datatype_spec thy (hd tycos);
       val vs = (map o apsnd)
         (curry (Sorts.inter_sort (Sign.classes_of thy)) @{sort random}) raw_vs;
-      val ty = Type (tyco, map TFree vs);
-      val typ_of = DatatypeAux.typ_of_dtyp descr vs;
-      val SOME (_, _, constrs) = AList.lookup (op =) descr index;
-      val randomN = NameSpace.base @{const_name random};
-      val random_aux_name = randomN ^ "_" ^ Class.type_name tyco ^ "'";
-      fun lift_ty ty = StateMonad.liftT ty @{typ seed};
-      val ty_aux = @{typ index} --> @{typ index} --> lift_ty ty;
-      fun random ty =
-        Const (@{const_name random}, @{typ index} --> lift_ty ty);
-      val random_aux = Free (random_aux_name, ty_aux);
-      fun add_cons_arg dty (is_rec, t) =
-        let
-          val ty' = typ_of dty;
-          val rec_call = case try DatatypeAux.dest_DtRec dty
-           of SOME index' => index = index'
-            | NONE => false
-          val random' = if rec_call
-            then random_aux $ @{term "i\<Colon>index"} $ @{term "j\<Colon>index"}
-            else random ty' $ @{term "j\<Colon>index"}
-          val is_rec' = is_rec orelse DatatypeAux.is_rec_type dty;
-          val t' = StateMonad.mbind ty' ty @{typ seed} random' (Abs ("", ty', t))
-        in (is_rec', t') end;
-      fun mk_cons_t (c, dtys) =
-        let
-          val ty' = map typ_of dtys ---> ty;
-          val t = StateMonad.return ty @{typ seed} (list_comb (Const (c, ty'),
-            map Bound (length dtys - 1 downto 0)));
-          val (is_rec, t') = fold_rev add_cons_arg dtys (false, t);
-        in (is_rec, StateMonad.run ty @{typ seed} t') end;
-      fun check_empty [] = NONE
-        | check_empty xs = SOME xs;
-      fun bundle_cons_ts cons_ts =
-        let
-          val ts = map snd cons_ts;
-          val t = HOLogic.mk_list (lift_ty ty) ts;
-          val t' = Const (@{const_name select}, HOLogic.listT (lift_ty ty) --> lift_ty (lift_ty ty)) $ t;
-          val t'' = Const (@{const_name collapse}, lift_ty (lift_ty ty) --> lift_ty ty) $ t';
-        in t'' end;
-      fun bundle_conss (some_rec_t, nonrec_t) =
+      val this_ty = Type (hd tycos, map TFree vs);
+      val this_ty' = StateMonad.liftT this_ty @{typ seed};
+      val random_name = NameSpace.base @{const_name random};
+      val random'_name = random_name ^ "_" ^ Class.type_name (hd tycos) ^ "'";
+      fun random ty = Sign.mk_const thy (@{const_name random}, [ty]);
+      val random' = Free (random'_name,
+        @{typ index} --> @{typ index} --> this_ty');
+      fun atom ty = ((ty, false), random ty $ @{term "j\<Colon>index"});
+      fun dtyp tyco = ((this_ty, true), random' $ @{term "i\<Colon>index"} $ @{term "j\<Colon>index"});
+      fun rtyp tyco tys = raise REC
+        ("Will not generate random elements for mutual recursive type " ^ quote (hd tycos));
+      val rhss = DatatypePackage.construction_interpretation thy
+            { atom = atom, dtyp = dtyp, rtyp = rtyp } vs tycos
+        |> (map o apsnd o map) (mk_cons this_ty) 
+        |> (map o apsnd) (List.partition fst)
+        |> map (mk_clauses thy this_ty)
+      val eqss = map ((apsnd o map) (HOLogic.mk_Trueprop o HOLogic.mk_eq) o (fn rhs => ((this_ty, random'), [
+          (random' $ @{term "0\<Colon>index"} $ @{term "j\<Colon>index"}, Const (@{const_name undefined}, this_ty')),
+          (random' $ @{term "Suc_index i"} $ @{term "j\<Colon>index"}, rhs)
+        ]))) rhss;
+    in eqss end;
+  fun random_inst [tyco] thy =
         let
-          val t = case some_rec_t
-           of SOME rec_t => Const (@{const_name collapse}, lift_ty (lift_ty ty) --> lift_ty ty)
-               $ (Const (@{const_name select_default},
-                   @{typ index} --> lift_ty ty --> lift_ty ty --> lift_ty (lift_ty ty))
-                  $ @{term "i\<Colon>index"} $ rec_t $ nonrec_t)
-            | NONE => nonrec_t
-        in t end;
-      val random_rhs = constrs
-        |> map mk_cons_t 
-        |> List.partition fst
-        |> apfst (Option.map bundle_cons_ts o check_empty)
-        |> apsnd bundle_cons_ts
-        |> bundle_conss;
-      val random_aux_undef = (HOLogic.mk_Trueprop o HOLogic.mk_eq)
-        (random_aux $ @{term "0\<Colon>index"} $ @{term "j\<Colon>index"}, Const (@{const_name undefined}, lift_ty ty))
-      val random_aux_eq = (HOLogic.mk_Trueprop o HOLogic.mk_eq)
-        (random_aux $ @{term "Suc_index i"} $ @{term "j\<Colon>index"}, random_rhs);
-      val random_eq = (HOLogic.mk_Trueprop o HOLogic.mk_eq) (Const (@{const_name random},
-        @{typ index} --> lift_ty ty) $ @{term "i\<Colon>index"},
-          random_aux $ @{term "i\<Colon>index"} $ @{term "i\<Colon>index"});
-      val del_func = Attrib.internal (fn _ => Thm.declaration_attribute
-        (fn thm => Context.mapping (Code.del_func thm) I));
-      fun add_code simps lthy =
-        let
-          val thy = ProofContext.theory_of lthy;
-          val thm = @{thm random_aux_if}
-            |> Drule.instantiate' [SOME (Thm.ctyp_of thy ty)] [SOME (Thm.cterm_of thy random_aux)]
-            |> (fn thm => thm OF simps)
-            |> singleton (ProofContext.export lthy (ProofContext.init thy))
+          val (raw_vs, _) = DatatypePackage.the_datatype_spec thy tyco;
+          val vs = (map o apsnd)
+            (curry (Sorts.inter_sort (Sign.classes_of thy)) @{sort random}) raw_vs;
+          val { descr, index, ... } = DatatypePackage.the_datatype thy tyco;
+          val ((this_ty, random'), eqs') = singleton (mk_random_eqs thy) tyco;
+          val eq = (HOLogic.mk_Trueprop o HOLogic.mk_eq)
+            (Sign.mk_const thy (@{const_name random}, [this_ty]) $ @{term "i\<Colon>index"},
+               random' $ @{term "i\<Colon>index"} $ @{term "i\<Colon>index"})
+          val del_func = Attrib.internal (fn _ => Thm.declaration_attribute
+            (fn thm => Context.mapping (Code.del_func thm) I));
+          fun add_code simps lthy =
+            let
+              val thy = ProofContext.theory_of lthy;
+              val thm = @{thm random'_if}
+                |> Drule.instantiate' [SOME (Thm.ctyp_of thy this_ty)] [SOME (Thm.cterm_of thy random')]
+                |> (fn thm => thm OF simps)
+                |> singleton (ProofContext.export lthy (ProofContext.init thy))
+            in
+              lthy
+              |> LocalTheory.theory (PureThy.note Thm.internalK (fst (dest_Free random') ^ "_code", thm)
+                   #-> Code.add_func)
+            end;
         in
-          lthy
-          |> LocalTheory.theory (PureThy.note Thm.internalK (random_aux_name ^ "_code", thm)
-               #-> Code.add_func)
-        end;
-    in
-      thy
-      |> TheoryTarget.instantiation ([tyco], vs, @{sort random})
-      |> PrimrecPackage.add_primrec [(random_aux_name, SOME ty_aux, NoSyn)]
-           [(("", [del_func]), random_aux_undef), (("", [del_func]), random_aux_eq)]
-      |-> add_code
-      |> `(fn lthy => Syntax.check_term lthy random_eq)
-      |-> (fn eq => Specification.definition (NONE, (("", []), eq)))
-      |> snd
-      |> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
-      |> LocalTheory.exit
-      |> ProofContext.theory_of
-    end;
-  fun add_random_inst [tyco] = (fn thy => random_inst tyco thy handle REC =>
-        (warning ("Will not generated random elements for mutual recursive type " ^ quote tyco); thy))
-    | add_random_inst tycos = tap (fn _ => warning
-        ("Will not generated random elements for mutual recursive type(s) " ^ commas (map quote tycos)));
+          thy
+          |> TheoryTarget.instantiation ([tyco], vs, @{sort random})
+          |> PrimrecPackage.add_primrec
+               [(fst (dest_Free random'), SOME (snd (dest_Free random')), NoSyn)]
+                 (map (fn eq => (("", [del_func]), eq)) eqs')
+          |-> add_code
+          |> `(fn lthy => Syntax.check_term lthy eq)
+          |-> (fn eq => Specification.definition (NONE, (("", []), eq)))
+          |> snd
+          |> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
+          |> LocalTheory.exit
+          |> ProofContext.theory_of
+        end
+    | random_inst tycos thy = raise REC
+        ("Will not generate random elements for mutual recursive type(s) " ^ commas (map quote tycos));
+  fun add_random_inst tycos thy = random_inst tycos thy
+     handle REC msg => (warning msg; thy);
 in DatatypePackage.interpretation add_random_inst end
 *}
 
+text {* Type @{typ int} *}
+
 instantiation int :: random
 begin
 
@@ -149,6 +150,8 @@
 
 end
 
+text {* Type @{typ "'a set"} *}
+
 instantiation set :: (random) random
 begin
 
@@ -162,7 +165,7 @@
   "random_set' i j s = (if i = 0 then undefined else collapse (select_default (i - 1)
        (do x \<leftarrow> random (i - 1); xs \<leftarrow> random_set' (i - 1) j; return (insert x xs) done)
        (return {})) s)"
-  by (rule random_aux_if random_set'.simps)+
+  by (rule random'_if random_set'.simps)+
 
 definition
   "random i = random_set' i i"
@@ -171,7 +174,63 @@
 
 end
 
-code_reserved SML Quickcheck
+text {* Type @{typ "'a \<Rightarrow> 'b"} *}
+
+ML {*
+structure Random_Engine =
+struct
+
+open Random_Engine;
+
+fun random_fun (eq : 'a -> 'a -> bool)
+    (random : Random_Engine.seed -> 'b * Random_Engine.seed)
+    (random_split : Random_Engine.seed -> Random_Engine.seed * Random_Engine.seed)
+    (seed : Random_Engine.seed) =
+  let
+    val (seed', seed'') = random_split seed;
+    val state = ref (seed', []);
+    fun random_fun' x =
+      let
+        val (seed, fun_map) = ! state;
+      in case AList.lookup (uncurry eq) fun_map x
+       of SOME y => y
+        | NONE => let
+              val (y, seed') = random seed;
+              val _ = state := (seed', (x, y) :: fun_map);
+            in y end
+      end;
+  in (random_fun', seed'') end;
+
+end
+*}
+
+axiomatization
+  random_fun_aux :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> (seed \<Rightarrow> 'b \<times> seed)
+    \<Rightarrow> (seed \<Rightarrow> seed \<times> seed) \<Rightarrow> seed \<Rightarrow> ('a \<Rightarrow> 'b) \<times> seed"
+
+code_const random_fun_aux (SML "Random'_Engine.random'_fun")
+
+instantiation "fun" :: (term_of, term_of) term_of
+begin
+
+instance ..
+
+end
+
+code_const "Eval.term_of :: ('a\<Colon>term_of \<Rightarrow> 'b\<Colon>term_of) \<Rightarrow> _"
+  (SML "(fn '_ => Const (\"arbitrary\", dummyT))")
+
+instantiation "fun" :: (eq, "{type, random}") random
+begin
+
+definition
+  "random n = random_fun_aux (op =) (random n) split_seed"
+
+instance ..
+
+end
+
+code_reserved SML Random_Engine
 
 
 subsection {* Quickcheck generator *}
@@ -220,6 +279,9 @@
 end
 *}
 
+
+subsection {* Examples *}
+
 (*setup {* Quickcheck.VALUE @{term "\<lambda>(n::nat) (m::nat) (q::nat). n = m + q + 1"} [@{typ nat}, @{typ nat}, @{typ nat}] *}
 export_code VALUE in SML module_name QuickcheckExample file "~~/../../gen_code/quickcheck.ML"
 use "~~/../../gen_code/quickcheck.ML"
@@ -287,61 +349,6 @@
 ML {* f 8 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 ML {* f 88 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 
-subsection {* Incremental function generator *}
-
-ML {*
-structure Quickcheck =
-struct
-
-open Quickcheck;
-
-fun random_fun (eq : 'a -> 'a -> bool)
-    (random : Random_Engine.seed -> 'b * Random_Engine.seed)
-    (random_split : Random_Engine.seed -> Random_Engine.seed * Random_Engine.seed)
-    (seed : Random_Engine.seed) =
-  let
-    val (seed', seed'') = random_split seed;
-    val state = ref (seed', []);
-    fun random_fun' x =
-      let
-        val (seed, fun_map) = ! state;
-      in case AList.lookup (uncurry eq) fun_map x
-       of SOME y => y
-        | NONE => let
-              val (y, seed') = random seed;
-              val _ = state := (seed', (x, y) :: fun_map);
-            in y end
-      end;
-  in (random_fun', seed'') end;
-
-end
-*}
-
-axiomatization
-  random_fun_aux :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> (seed \<Rightarrow> 'b \<times> seed)
-    \<Rightarrow> (seed \<Rightarrow> seed \<times> seed) \<Rightarrow> seed \<Rightarrow> ('a \<Rightarrow> 'b) \<times> seed"
-
-code_const random_fun_aux (SML "Quickcheck.random'_fun")
-
-instantiation "fun" :: (term_of, term_of) term_of
-begin
-
-instance ..
-
-end
-
-code_const "Eval.term_of :: ('a\<Colon>term_of \<Rightarrow> 'b\<Colon>term_of) \<Rightarrow> _"
-  (SML "(fn '_ => Const (\"arbitrary\", dummyT))")
-
-instantiation "fun" :: (eq, "{type, random}") random
-begin
-
-definition
-  "random n = random_fun_aux (op =) (random n) split_seed"
-
-instance ..
-
-end
 
 ML {* val f = Quickcheck.compile_generator_expr @{theory}
   @{term "\<lambda>f k. int (f k) = k"} [@{typ "int \<Rightarrow> nat"}, @{typ int}] *}
@@ -380,4 +387,11 @@
 ML {* f 8 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 ML {* f 8 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 
+definition "map2 f xs ys = map (split f) (zip xs ys)"
+
+lemma
+  assumes "\<And>x. f x x = x"
+  shows "map2 f xs xs = xs"
+  by (induct xs) (simp_all add: map2_def assms)
+
 end