src/HOL/Tools/datatype_aux.ML
changeset 7015 85be09eb136c
parent 6394 3d9fd50fcc43
child 8305 93aa21ec5494
--- a/src/HOL/Tools/datatype_aux.ML	Fri Jul 16 12:09:48 1999 +0200
+++ b/src/HOL/Tools/datatype_aux.ML	Fri Jul 16 12:14:04 1999 +0200
@@ -13,8 +13,6 @@
   
   val foldl1 : ('a * 'a -> 'a) -> 'a list -> 'a
 
-  val get_thy : string -> theory -> theory option
-
   val add_path : bool -> string -> theory -> theory
   val parent_path : bool -> theory -> theory
 
@@ -28,6 +26,10 @@
   val indtac : thm -> int -> tactic
   val exh_tac : (string -> thm) -> int -> tactic
 
+  datatype simproc_dist = QuickAndDirty
+                        | FewConstrs of thm list
+                        | ManyConstrs of thm * simpset;
+
   datatype dtyp =
       DtTFree of string
     | DtType of string * (dtyp list)
@@ -35,6 +37,7 @@
 
   type datatype_info
 
+  exception Datatype
   val dtyp_of_typ : (string * string list) list -> typ -> dtyp
   val mk_Free : string -> typ -> int -> term
   val is_rec_type : dtyp -> bool
@@ -46,14 +49,16 @@
   val dest_conj : term -> term list
   val get_nonrec_types : (int * (string * dtyp list *
     (string * dtyp list) list)) list -> (string * sort) list -> typ list
+  val get_branching_types : (int * (string * dtyp list *
+    (string * dtyp list) list)) list -> (string * sort) list -> typ list
   val get_rec_types : (int * (string * dtyp list *
     (string * dtyp list) list)) list -> (string * sort) list -> typ list
   val check_nonempty : (int * (string * dtyp list *
     (string * dtyp list) list)) list list -> unit
   val unfold_datatypes : 
-    datatype_info Symtab.table ->
-      (int * (string * dtyp list *
-        (string * dtyp list) list)) list -> int ->
+    Sign.sg -> (int * (string * dtyp list * (string * dtyp list) list)) list ->
+      (string * sort) list -> datatype_info Symtab.table ->
+        (int * (string * dtyp list * (string * dtyp list) list)) list -> int ->
           (int * (string * dtyp list *
             (string * dtyp list) list)) list list * int
 end;
@@ -67,9 +72,6 @@
 (* FIXME: move to library ? *)
 fun foldl1 f (x::xs) = foldl f (x, xs);
 
-fun get_thy name thy = find_first
-  (equal name o Sign.name_of o Theory.sign_of) (Theory.ancestors_of thy);
-
 fun add_path flat_names s = if flat_names then I else Theory.add_path s;
 fun parent_path flat_names = if flat_names then I else Theory.parent_path;
 
@@ -92,7 +94,7 @@
 (* split theorem thm_1 & ... & thm_n into n theorems *)
 
 fun split_conj_thm th =
-  ((th RS conjunct1)::(split_conj_thm (th RS conjunct2))) handle _ => [th];
+  ((th RS conjunct1)::(split_conj_thm (th RS conjunct2))) handle THM _ => [th];
 
 val mk_conj = foldr1 (HOLogic.mk_binop "op &");
 val mk_disj = foldr1 (HOLogic.mk_binop "op |");
@@ -138,6 +140,12 @@
   in compose_tac (false, exhaustion', nprems_of exhaustion) i state
   end;
 
+(* handling of distinctness theorems *)
+
+datatype simproc_dist = QuickAndDirty
+                      | FewConstrs of thm list
+                      | ManyConstrs of thm * simpset;
+
 (********************** Internal description of datatypes *********************)
 
 datatype dtyp =
@@ -157,7 +165,7 @@
    case_rewrites : thm list,
    induction : thm,
    exhaustion : thm,
-   distinct : thm list,
+   distinct : simproc_dist,
    inject : thm list,
    nchotomy : thm,
    case_cong : thm};
@@ -172,8 +180,13 @@
       DtType (name, map (subst_DtTFree i substs) ts)
   | subst_DtTFree i _ (DtRec j) = DtRec (i + j);
 
-fun dest_DtTFree (DtTFree a) = a;
-fun dest_DtRec (DtRec i) = i;
+exception Datatype;
+
+fun dest_DtTFree (DtTFree a) = a
+  | dest_DtTFree _ = raise Datatype;
+
+fun dest_DtRec (DtRec i) = i
+  | dest_DtRec _ = raise Datatype;
 
 fun is_rec_type (DtType (_, dts)) = exists is_rec_type dts
   | is_rec_type (DtRec _) = true
@@ -201,6 +214,7 @@
 
 fun get_nonrec_types descr sorts =
   let fun add (Ts, T as DtTFree _) = T ins Ts
+        | add (Ts, T as DtType ("fun", [_, DtRec _])) = Ts
         | add (Ts, T as DtType _) = T ins Ts
         | add (Ts, _) = Ts
   in map (typ_of_dtyp descr sorts) (foldl (fn (Ts, (_, (_, _, constrs))) =>
@@ -213,6 +227,16 @@
 fun get_rec_types descr sorts = map (fn (_ , (s, ds, _)) =>
   Type (s, map (typ_of_dtyp descr sorts) ds)) descr;
 
+(* get all branching types *)
+
+fun get_branching_types descr sorts = 
+  let fun add (Ts, DtType ("fun", [T, DtRec _])) = T ins Ts
+        | add (Ts, _) = Ts
+  in map (typ_of_dtyp descr sorts) (foldl (fn (Ts, (_, (_, _, constrs))) =>
+    foldl (fn (Ts', (_, cargs)) =>
+      foldl add (Ts', cargs)) (Ts, constrs)) ([], descr))
+  end;
+
 (* nonemptiness check for datatypes *)
 
 fun check_nonempty descr =
@@ -223,6 +247,7 @@
         val (_, _, constrs) = the (assoc (descr', i));
         fun arg_nonempty (DtRec i) = if i mem is then false
               else is_nonempty_dt (i::is) i
+          | arg_nonempty (DtType ("fun", [_, T])) = arg_nonempty T
           | arg_nonempty _ = true;
       in exists ((forall arg_nonempty) o snd) constrs
       end
@@ -234,16 +259,19 @@
 (* all types of the form DtType (dt_name, [..., DtRec _, ...]) *)
 (* need to be unfolded                                         *)
 
-fun unfold_datatypes (dt_info : datatype_info Symtab.table) descr i =
+fun unfold_datatypes sign orig_descr sorts (dt_info : datatype_info Symtab.table) descr i =
   let
-    fun get_dt_descr i tname dts =
+    fun typ_error T msg = error ("Non-admissible type expression\n" ^
+      Sign.string_of_typ sign (typ_of_dtyp (orig_descr @ descr) sorts T) ^ "\n" ^ msg);
+
+    fun get_dt_descr T i tname dts =
       (case Symtab.lookup (dt_info, tname) of
-         None => error (tname ^ " is not a datatype - can't use it in\
-           \ indirect recursion")
+         None => typ_error T (tname ^ " is not a datatype - can't use it in\
+           \ nested recursion")
        | (Some {index, descr, ...}) =>
            let val (_, vars, _) = the (assoc (descr, index));
-               val subst = ((map dest_DtTFree vars) ~~ dts) handle _ =>
-                 error ("Type constructor " ^ tname ^ " used with wrong\
+               val subst = ((map dest_DtTFree vars) ~~ dts) handle LIST _ =>
+                 typ_error T ("Type constructor " ^ tname ^ " used with wrong\
                   \ number of arguments")
            in (i + index, map (fn (j, (tn, args, cs)) => (i + j,
              (tn, map (subst_DtTFree i subst) args,
@@ -254,9 +282,18 @@
 
     fun unfold_arg ((i, Ts, descrs), T as (DtType (tname, dts))) =
           if is_rec_type T then
-            let val (index, descr) = get_dt_descr i tname dts;
-                val (descr', i') = unfold_datatypes dt_info descr (i + length descr)
-            in (i', Ts @ [DtRec index], descrs @ descr') end
+            if tname = "fun" then
+              if is_rec_type (hd dts) then
+                typ_error T "Non-strictly positive recursive occurrence of type"
+              else
+                (case hd (tl dts) of
+                   DtType ("fun", _) => typ_error T "Curried function types not allowed"
+                 | T' => let val (i', [T''], descrs') = unfold_arg ((i, [], descrs), T')
+                         in (i', Ts @ [DtType (tname, [hd dts, T''])], descrs') end)
+            else
+              let val (index, descr) = get_dt_descr T i tname dts;
+                  val (descr', i') = unfold_datatypes sign orig_descr sorts dt_info descr (i + length descr)
+              in (i', Ts @ [DtRec index], descrs @ descr') end
           else (i, Ts @ [T], descrs)
       | unfold_arg ((i, Ts, descrs), T) = (i, Ts @ [T], descrs);