Cleaned up code.
authornipkow
Mon, 15 Aug 1994 15:20:34 +0200
changeset 103 c57ab3ce997e
parent 102 18d44ab74672
child 104 a0e6613dfbee
Cleaned up code.
Datatype.ML
--- a/Datatype.ML	Sat Aug 13 16:34:30 1994 +0200
+++ b/Datatype.ML	Mon Aug 15 15:20:34 1994 +0200
@@ -1,7 +1,7 @@
 (*  Title:       HOL/Datatype
     ID:          $Id$
-    Author:      Max Breitling / Carsten Clasohm /
-                 Norbert Voelker / Tobias Nipkow
+    Author:      Max Breitling, Carsten Clasohm,
+                 Tobias Nipkow, Norbert Voelker
     Copyright    1994 TU Muenchen
 *)
 
@@ -10,9 +10,6 @@
 local
 
 val dtK = 5
-val pars = enclose "(" ")";
-val brackets = enclose "[" "]";
-val mk_list = brackets o commas;
 
 in
 
@@ -40,9 +37,11 @@
          ||
           cons                        >> (fn c => [c])) ts;  
   
-      val mk_cons = map (fn ((s, ts), syn) => 
-                           pars (commas [s, mk_list ts, syn]));
-
+      fun mk_cons cs =
+        case findrep (map (fst o fst) cs) of
+           [] => map (fn ((s,ts),syn) => parens (commas [s,mk_list ts,syn])) cs
+         | c::_ => error("Constructor \"" ^ c ^ "\" occurs twice");
+      
       (*remove all quotes from a string*)
       val rem_quotes = implode o filter (fn c => c <> "\"") o explode;
 
@@ -57,17 +56,15 @@
            else quote (tname ^ "_ord_distinct") ::
                 map (fn c => quote (tname ^ "_ord_" ^ c)) uqcs
         end;
-
-       fun rule_names tname cons pre =
+         
+       fun rules tname cons pre =
+         " map (get_axiom thy) " ^
          mk_list (map (fn ((s,_),_) => quote(tname ^ pre ^ rem_quotes s)) cons)
 
-       fun rules tname cons pre =
-         " map (get_axiom thy) " ^ rule_names tname cons pre       
-
       (*generate string for calling 'add_datatype'*)
       fun mk_params ((ts, tname), cons) =
        ("val (thy," ^ tname ^ "_add_primrec) =  add_datatype\n" ^
-       pars (commas [mk_list ts, quote tname, mk_list (mk_cons cons)]) ^
+       parens (commas [mk_list ts, quote tname, mk_list (mk_cons cons)]) ^
        " thy\n\
        \val thy=thy",
        "structure " ^ tname ^ " =\n\
@@ -107,9 +104,10 @@
                    dtRek of dt_type list * string;
 
 local open Syntax.Mixfix
+           ThyParse
       exception Impossible
 
-val is_rek = (fn dtRek _ => true  |  _  => false);
+val is_Rek = (fn dtRek _ => true  |  _  => false);
 
 (* ------------------------------------------------------------------------- *)
 (* Die Funktionen fuer das Umsetzen von Gleichungen in eine Definition mit   *)
@@ -121,7 +119,7 @@
 
 fun rek_args (args, targs) = 
 let fun h (x :: xs, tx :: txs, res) 
-           = h(xs,txs,if is_rek tx then x :: res else res )
+           = h(xs,txs,if is_Rek tx then x :: res else res )
      |  h ([],[],res) = res
 in h (args,targs,[])
 end;
@@ -197,9 +195,9 @@
      |  h reqs (eq::eqs) (c::cs)  res =
 	let
           val (f,(Const(cname_eq,_),args),rhs) = dest_eq eq;
-          val (cname,targs,syn) = c;
+          val (_,cname,targs,_) = c;
         in
-	  if (cname_eq <> const_name cname syn) then h reqs eqs (c::cs) res
+	  if cname_eq <> cname then h reqs eqs (c::cs) res
           else
           if fst(dest_Const(f)) = fname
              andalso (duplicates (map (fst o dest_Free) args) = [])
@@ -224,35 +222,27 @@
 in
 
 fun add_datatype (typevars, tname, cons_list') thy = 
-  let (*check if constructor names are unique*)
-      fun check_cons (cs : (string * 'b * 'c) list) =
-        (case findrep (map #1 cs) of
-           [] => true
-         | c::_ => error("Constructor \"" ^ c ^ "\" occurs twice"));
-
-      (*search for free type variables and convert recursive *)
+  let (*search for free type variables and convert recursive *)
       fun analyse_types (cons, typlist, syn) =
             let fun analyse(t as dtVar v) =
                      if t mem typevars then t
-                     else error ("Variable " ^ v ^ " is free.")
+                     else error ("Free type variable " ^ v ^ " on rhs.")
                   | analyse(dtTyp(typl,s)) =
                      if tname <> s then dtTyp(analyses typl, s)
                      else if typevars = typl then dtRek(typl, s)
                      else error (s ^ " used in different ways")
                   | analyse(dtRek _) = raise Impossible
                  and analyses ts = map analyse ts;
-            in (cons, analyses typlist, syn) end;
+            in (cons, const_name cons syn, analyses typlist, syn) end;
 
-      (*test if there are elements that are not recursive, i.e. if the type is
-        not empty*)
-      fun one_not_rek (cs : ('a * dt_type list * 'c) list) = 
-        let val contains_no_rek = forall (fn dtRek _ => false | _ => true);
-        in exists (contains_no_rek o #2) cs orelse
-           error("Empty type not allowed!") end;
+      (*test if all elements are recursive, i.e. if the type is empty*)
+      fun non_empty (cs : ('a * 'b * dt_type list * 'c) list) = 
+        not(forall (exists is_Rek o #3) cs) orelse
+        error("Empty datatype not allowed!");
 
-      val dummy = check_cons cons_list';
       val cons_list = map analyse_types cons_list';
-      val dummy = one_not_rek cons_list;
+      val dummy = non_empty cons_list;
+      val num_of_cons = length cons_list;
 
       (*Pretty printers for type lists;
         pp_typlist1: parentheses, pp_typlist2: brackets*)
@@ -263,7 +253,7 @@
       and
           pp_typlist' ts = commas (map pp_typ ts)
       and
-          pp_typlist1 ts = if null ts then "" else pars (pp_typlist' ts);
+          pp_typlist1 ts = if null ts then "" else parens (pp_typlist' ts);
 
       fun pp_typlist2 ts = if null ts then "" else brackets (pp_typlist' ts);
 
@@ -272,12 +262,11 @@
 			    	        Args(var, delim, n+1, m);
 
       (* Generate syntax translation for case rules *)
-      fun calc_xrules c_nr y_nr ((id, typlist, syn) :: cs) = 
-            let val name = const_name id syn;
-                val arity = length typlist;
+      fun calc_xrules c_nr y_nr ((_, name, typlist, _) :: cs) = 
+            let val arity = length typlist;
                 val body  = "z" ^ string_of_int(c_nr);
                 val args1 = if arity=0 then ""
-                            else pars (Args ("y", ",", y_nr, y_nr+arity-1));
+                            else parens (Args ("y", ",", y_nr, y_nr+arity-1));
                 val args2 = if arity=0 then ""
                             else "% " ^ Args ("y", " ", y_nr, y_nr+arity-1) 
                             ^ ". ";
@@ -295,16 +284,11 @@
          end;
 
       (*type declarations for constructors*)
-      fun const_type (id, typlist, syn) =
+      fun const_type (id, _, typlist, syn) =
            (id,  
             (if null typlist then "" else pp_typlist2 typlist ^ " => ") ^
              pp_typlist1 typevars ^ tname, syn);
 
-      fun create_typevar (dtVar s) typlist =
-            if (dtVar s) mem typlist then 
-	      create_typevar (dtVar (s ^ "'")) typlist 
-            else s
-        | create_typevar _ _ = raise Impossible;
 
       fun assumpt (dtRek _ :: ts, v :: vs ,found) =
             let val h = if found then ";P(" ^ v ^ ")" else "[| P(" ^ v ^ ")"
@@ -352,9 +336,8 @@
 
       fun empty_list n = replicate n "";
 
-      fun t_inducting ((id, typl, syn) :: cs) =
-            let val name = const_name id syn;
-                val tab = insert_types([],typl);
+      fun t_inducting ((_, name, typl, _) :: cs) =
+            let val tab = insert_types([],typl);
                 val arity = length typl;
                 val var_list = convert typl (empty_list arity,tab); 
                 val h = if arity = 0 then " P(" ^ name ^ ")"
@@ -365,28 +348,28 @@
             in if rest = "" then h else h ^ "; " ^ rest end
         | t_inducting [] = "";
 
-      fun t_induct cl typ_name=
+      fun t_induct cl typ_name =
         "[|" ^ t_inducting cl ^ "|] ==> P(" ^ typ_name ^ ")";
 
-      fun case_typlist typevar ((_, typlist, _) :: cs) =
-           let val h = if (length typlist) > 0 then 
-		         (pp_typlist2 typlist) ^ "=>"
+      fun gen_typlist typevar f ((_, _, ts, _) :: cs) =
+           let val h = if (length ts) > 0
+                       then pp_typlist2(f ts) ^ "=>"
                        else ""
-           in "," ^ h ^ typevar ^ (case_typlist typevar cs) end
-        | case_typlist _ [] = "";
+           in "," ^ h ^ typevar ^ (gen_typlist typevar f cs) end
+        | gen_typlist _ _ [] = "";
 
       val t_case = tname ^ "_case";
 
-      fun case_rules arity n ((id, typlist, syn) :: cs) =
-            let val name = const_name id syn;
-                val args = if null typlist then ""
-  			   else pars(Args("x", ",", 1, length typlist))
+      fun case_rules n ((id, name, typlist, _) :: cs) =
+            let val args = if null typlist then ""
+  			   else parens(Args("x", ",", 1, length typlist))
             in (t_case ^ "_" ^ id,
-                t_case ^ "(" ^ name ^ args ^ "," ^ Args ("f", ",", 1, arity) 
-                ^ ") = f" ^ string_of_int(n) ^ args)
-               :: (case_rules arity (n+1) cs)
+                t_case ^ "(" ^ name ^ args ^ "," ^
+                  Args("f", ",", 1, num_of_cons)
+                  ^ ") = f" ^ string_of_int(n) ^ args)
+               :: (case_rules (n+1) cs)
             end
-        | case_rules _ _ [] = [];
+        | case_rules _ [] = [];
 
       val datatype_arity = length typevars;
 
@@ -398,14 +381,15 @@
 
       val datatype_name = pp_typlist1 typevars ^ tname;
 
-      val (case_const, rules_case) =
-         let val typevar = create_typevar (dtVar "'beta") typevars;
-             val arity = length cons_list;
-             val dekl = (t_case, "[" ^ pp_typlist1 typevars ^ tname ^
-                       case_typlist typevar cons_list ^ "]=>" ^ typevar, NoSyn)
-             val rules = case_rules arity 1 cons_list;
-         in (dekl, rules) end;
+      val new_tvar_name = variant (map (fn dtVar s => s) typevars) "'z";
 
+      val case_const =
+         (t_case,
+          "[" ^ pp_typlist1 typevars ^ tname ^
+                gen_typlist new_tvar_name I cons_list ^ "] =>" ^ new_tvar_name,
+          NoSyn);
+
+      val rules_case = case_rules 1 cons_list;
 
 
 (* -------------------------------------------------------------------- *)
@@ -414,62 +398,47 @@
 
       val t_rec = tname ^ "_rec"
 
-fun add_reks typevar ts = 
-  let val tv = dtVar typevar; 
-      fun h (t::ts) res = h ts (if is_rek(t) then tv::res else res)
-	| h [] res  = res        
-  in  h ts ts
-  end;
+      fun add_reks ts = 
+        let val tv = dtVar new_tvar_name; 
+            fun h (t::ts) res = h ts (if is_Rek(t) then tv::res else res)
+	      | h [] res  = res        
+        in  h ts ts  end;
 
-fun rec_typlist typevar ((c,ts,_)::cs) = 
-    let val h = if (length ts) > 0 
-	        then (pp_typlist2 (add_reks typevar ts)) ^ "=>"
-                else ""
-    in "," ^ h ^ typevar ^ (rec_typlist typevar cs)
-    end
-  | rec_typlist _ [] = "";
-
-fun arg_reks arity ts = 
+fun arg_reks ts = 
   let fun arg_rek (t::ts) n res  = 
         let val h = t_rec ^"(" ^ "x" ^string_of_int(n) 
-			       ^"," ^Args("f",",",1,arity) ^")," 
-        in arg_rek ts (n+1) (if is_rek(t) then res ^ h else res)
+			       ^"," ^Args("f",",",1,num_of_cons) ^")," 
+        in arg_rek ts (n+1) (if is_Rek(t) then res ^ h else res)
         end 
       | arg_rek [] _ res = res        
-  in  arg_rek ts 1 ""
-  end;
+  in  arg_rek ts 1 ""  end;
 
-fun rec_rules arity n ((id,ts,syn)::cs) =
-  let val name = const_name id syn;
-      val lts = length ts 
-      val args = if (lts = 0) then ""
-	         else "(" ^ Args("x",",",1,lts) ^ ")" 
+fun rec_rules n ((id,name,ts,_)::cs) =
+  let val lts = length ts 
+      val args = if lts = 0 then ""
+	         else parens(Args("x",",",1,lts)) 
       val rargs = if (lts = 0) then ""
-	          else "("^ arg_reks arity ts ^ Args("x",",",1,lts) ^")"
+	          else "("^ arg_reks ts ^ Args("x",",",1,lts) ^")"
   in     
     ( t_rec ^ "_" ^ id
-    , t_rec ^ "(" ^ name ^ args ^ "," ^ Args("f",",",1,arity) ^ ") = f"
+    , t_rec ^ "(" ^ name ^ args ^ "," ^ Args("f",",",1,num_of_cons) ^ ") = f"
       ^ string_of_int(n) ^ rargs) 
-     :: (rec_rules arity (n+1) cs)
+     :: (rec_rules (n+1) cs)
   end
-  | rec_rules _ _ [] = [];
+  | rec_rules _ [] = [];
 
-val (rec_const,rules_rec) =
-   let val typevar = create_typevar (dtVar "'beta") typevars
-       val arity = length cons_list
-       val dekl = (t_rec,
-                    "[" ^ (pp_typlist1 typevars) ^ tname ^
-                      (rec_typlist typevar cons_list) ^ "]=>" ^ typevar,
-                    NoSyn)
-       val rules = rec_rules arity 1 cons_list
-       in (dekl,rules)
-   end;
+      val rec_const =
+        (t_rec,
+         "[" ^ (pp_typlist1 typevars) ^ tname ^
+               (gen_typlist new_tvar_name add_reks cons_list) ^
+               "] =>" ^ new_tvar_name,
+         NoSyn);
 
-
+      val rules_rec = rec_rules 1 cons_list
 
       val consts = 
         map const_type cons_list
-	@ (if length cons_list < dtK then []
+	@ (if num_of_cons < dtK then []
 	   else [(tname ^ "_ord", datatype_name ^ "=>nat", NoSyn)])
 	@ [case_const,rec_const];
 
@@ -479,7 +448,7 @@
 
       (*generate 'name_1', ..., 'name_n'*)
       fun C_exp(name, n, var) =
-        if n > 0 then name ^ pars(Args(var, ",", 1, n)) else name;
+        if n > 0 then name ^ parens(Args(var, ",", 1, n)) else name;
 
       (*generate 'x_n = y_n, ..., x_m = y_m'*)
       fun Arg_eql(n,m) = 
@@ -487,9 +456,8 @@
         else "x" ^ string_of_int(n) ^ "=y" ^ string_of_int(n) ^ " & " ^ 
              Arg_eql(n+1, m);
 
-      fun Ci_ing ((id, typlist, syn) :: cs) =
-            let val name = const_name id syn;
-                val arity = length typlist;
+      fun Ci_ing ((id, name, typlist, _) :: cs) =
+            let val arity = length typlist;
             in if arity = 0 then Ci_ing cs
                else ("inject_" ^ id,
                      "(" ^ C_exp(name,arity,"x") ^ "=" ^ C_exp(name,arity,"y") 
@@ -497,10 +465,8 @@
             end
         | Ci_ing [] = [];
 
-      fun Ci_negOne (id1, tl1, syn1) (id2, tl2, syn2) =
-           let val name1 = const_name id1 syn1;
-               val name2 = const_name id2 syn2;
-               val ax = C_exp(name1, length tl1, "x") ^ "~=" ^
+      fun Ci_negOne (id1, name1, tl1, _) (id2, name2, tl2, _) =
+           let val ax = C_exp(name1, length tl1, "x") ^ "~=" ^
                         C_exp(name2, length tl2, "y")
            in (id1 ^ "_not_" ^ id2, ax) end;
 
@@ -512,17 +478,16 @@
 
       fun Ci_neg2() =
         let val ord_t = tname ^ "_ord";
-            val cis = cons_list ~~ (0 upto (length cons_list - 1))
-            fun Ci_neg2equals ((id, typlist, syn), n) =
-              let val name = const_name id syn;
-                  val ax = ord_t ^ "(" ^ (C_exp(name, length typlist, "x")) 
+            val cis = cons_list ~~ (0 upto (num_of_cons - 1))
+            fun Ci_neg2equals ((id, name, typlist, _), n) =
+              let val ax = ord_t ^ "(" ^ (C_exp(name, length typlist, "x")) 
                                  ^ ") = " ^ (suc_expr n)
               in (ord_t ^ "_" ^ id, ax) end
         in (ord_t ^ "_distinct", ord_t^"(x) ~= "^ord_t^"(y) ==> x ~= y") ::
            (map Ci_neg2equals cis)
         end;
 
-      val rules_distinct = if length cons_list < dtK then Ci_neg1 cons_list
+      val rules_distinct = if num_of_cons < dtK then Ci_neg1 cons_list
                            else Ci_neg2();
 
       val rules_inject = Ci_ing cons_list;