Datatype.ML
changeset 96 d94d0b324b4b
parent 92 bcd0ee8d71aa
child 101 5f99df1e26c4
--- a/Datatype.ML	Fri Jul 15 13:53:18 1994 +0200
+++ b/Datatype.ML	Fri Jul 15 14:04:28 1994 +0200
@@ -6,28 +6,28 @@
 
 
 (*choice between Ci_neg1 and Ci_neg2 axioms depends on number of constructors*)
-local val dtK = 5
+local
+
+val dtK = 5
+val pars = parents "(" ")";
+val brackets = parents "[" "]";
+
 in
 
 local open ThyParse in
 val datatype_decls =
-  let fun cat s1 s2 = s1 ^ " " ^ s2;
+  let val mk_list = brackets o commas;
 
-      val pars = parents "(" ")";
-      val brackets = parents "[" "]";
-
-      val mk_list = brackets o commas;
-
-      val tvar = type_var >> cat "dtVar";
+      val tvar = type_var >> (fn s => "dtVar" ^ s);
 
       val type_var_list = 
         tvar >> (fn s => [s]) || "(" $$-- list1 tvar --$$ ")";
     
       val typ =
-         ident                  >> (cat "dtId" o quote)
+         ident                  >> (fn s => "dtTyp([]," ^ quote s ^")")
         ||
-         type_var_list -- ident >> (fn (ts, id) => "dtComp (" ^ mk_list ts ^
-  				  ", " ^ quote id ^ ")")
+         type_var_list -- ident >> (fn (ts, id) => "dtTyp(" ^ mk_list ts ^
+  				  "," ^ quote id ^ ")")
         ||
          tvar;
     
@@ -44,25 +44,19 @@
                            pars (commas [s, mk_list ts, syn]));
   
       (*remove all quotes from a string*)
-      fun rem_quotes s = implode (filter (fn c => c <> "\"") (explode s));
+      val rem_quotes = implode o filter (fn c => c <> "\"") o explode;
             
-      (*generate names of ineq axioms*)
-      fun rules_ineq cs tname = 
-        let (*combine all constructor names with all others w/o duplicates*)
-            fun negOne _ [] = [] 
-              | negOne (c : (string * 'a) * 'b) ((c2 : (string * 'a) * 'b) 
-                                                 :: cs) = 
-                  quote ("ineq_" ^ rem_quotes (#1 (#1 c)) ^ "_" ^ 
-                  rem_quotes (#1 (#1 c2))) :: negOne c cs;
-  
+      (*generate names of distinct axioms*)
+      fun rules_distinct cs tname = 
+        let val uqcs = map (fn ((s,_),_) => rem_quotes s) cs;
+            (*combine all constructor names with all others w/o duplicates*)
+            fun negOne c = map (fn c2 => quote (c ^ "_not_" ^ c2));
             fun neg1 [] = []
               | neg1 (c1 :: cs) = (negOne c1 cs) @ (neg1 cs)
-        in if length cs < dtK then neg1 cs
-           else map (fn n => quote (tname ^ "_ord" ^ string_of_int n)) 
-                    (0 upto (length cs))
+        in if length uqcs < dtK then neg1 uqcs
+           else quote (tname ^ "_ord_distinct") ::
+                map (fn c => quote (tname ^ "_ord_" ^ c)) uqcs
         end;
-
-      fun arg1 ((_, ts), _) = not (null ts);
           
       (*generate string for calling 'add_datatype'*)
       fun mk_params ((ts, tname), cons) =
@@ -72,40 +66,32 @@
        \struct\n\
        \  val inject = map (get_axiom thy) " ^
          mk_list (map (fn ((s,_), _) => quote ("inject_" ^ rem_quotes s)) 
-                      (filter arg1 cons)) ^ ";\n\
-       \  val ineq = " ^ (if length cons < dtK then "let val ineq' = " else "")
-         ^ "map (get_axiom thy) " ^ mk_list (rules_ineq cons tname) ^ 
+                      (filter_out (null o snd o fst) cons)) ^ ";\n\
+       \  val distinct = " ^ (if length cons < dtK then "let val distinct' = " else "")
+         ^ "map (get_axiom thy) " ^ mk_list (rules_distinct cons tname) ^ 
          (if length cons < dtK then 
-           "  in ineq' @ (map (fn t => sym COMP (t RS contrapos)) ineq') end"
+           "  in distinct' @ (map (fn t => sym COMP (t RS contrapos)) distinct') end"
           else "") ^ ";\n\
        \  val induct = get_axiom thy \"" ^ tname ^ "_induct\";\n\
        \  val cases = map (get_axiom thy) " ^
          mk_list (map (fn ((s,_),_) => 
                          quote(tname ^ "_case_" ^ rem_quotes s)) cons) ^ ";\n\
-       \  val simps = inject @ ineq @ cases;\n\
-       \  fun induct_tac a = res_inst_tac [(" ^ quote tname ^ ", a)] induct;\n\
+       \  val simps = inject @ distinct @ cases;\n\
+       \  fun induct_tac a = res_inst_tac[(" ^ quote tname ^ ", a)]induct;\n\
        \end;\n");
   in (type_var_list || empty) -- ident --$$ "=" -- constructs >> mk_params end
 end;
 
 (*used for constructor parameters*)
 datatype dt_type = dtVar of string |
-                   dtId  of string |
-                   dtComp of dt_type list * string |
+                   dtTyp of dt_type list * string |
                    dtRek of dt_type list * string;
 
 local open Syntax.Mixfix
       exception Impossible
 in
 fun add_datatype (typevars, tname, cons_list') thy = 
-  let fun cat s1 s2 = s1 ^ " " ^ s2;
-
-      val pars = parents "(" ")";
-      val brackets = parents "[" "]";
-
-      val mk_list = brackets o commas;
-
-      (*check if constructor names are unique*)
+  let (*check if constructor names are unique*)
       fun check_cons (cs : (string * 'b * 'c) list) =
         (case findrep (map #1 cs) of
            [] => true
@@ -113,25 +99,16 @@
 
       (*search for free type variables and convert recursive *)
       fun analyse_types (cons, typlist, syn) =
-            let fun analyse ((dtVar v) :: typlist) =
-                     if ((dtVar v) mem typevars) then
-                       (dtVar v) :: analyse typlist
+            let fun analyse(t as dtVar v) =
+                     if t mem typevars then t
                      else error ("Variable " ^ v ^ " is free.")
-                  | analyse ((dtId s) :: typlist) =
-                     if tname<>s then (dtId s) :: analyse typlist
-                     else if null typevars then 
-                       dtRek ([], tname) :: analyse typlist
+                  | analyse(dtTyp(typl,s)) =
+                     if tname <> s then dtTyp(analyses typl, s)
+                     else if typevars = typl then dtRek(typl, s)
                      else error (s ^ " used in different ways")
-                  | analyse (dtComp (typl,s) :: typlist) =
-                     if tname <> s then dtComp (analyse typl, s)
-                                     :: analyse typlist
-                     else if typevars = typl then
-                       dtRek (typl, s) :: analyse typlist
-                     else 
-                       error (s ^ " used in different ways")
-                  | analyse [] = []
-                  | analyse ((dtRek _) :: _) = raise Impossible;
-            in (cons, analyse typlist, syn) end;
+                  | analyse(dtRek _) = raise Impossible
+                 and analyses ts = map analyse ts;
+            in (cons, analyses typlist, syn) end;
 
       (*test if there are elements that are not recursive, i.e. if the type is
         not empty*)
@@ -147,8 +124,8 @@
       (*Pretty printers for type lists;
         pp_typlist1: parentheses, pp_typlist2: brackets*)
       fun pp_typ (dtVar s) = s
-        | pp_typ (dtId s) = s
-        | pp_typ (dtComp (typvars, id)) = (pp_typlist1 typvars) ^ id
+        | pp_typ (dtTyp (typvars, id)) =
+            if null typvars then id else (pp_typlist1 typvars) ^ id
         | pp_typ (dtRek (typvars, id)) = (pp_typlist1 typvars) ^ id
       and
           pp_typlist' ts = commas (map pp_typ ts)
@@ -185,12 +162,10 @@
          end;
 
       (*type declarations for constructors*)
-      fun const_types ((id, typlist, syn) :: cs) =
+      fun const_type (id, typlist, syn) =
            (id,  
             (if null typlist then "" else pp_typlist2 typlist ^ " => ") ^
-             pp_typlist1 typevars ^ tname, syn)
-           :: const_types cs
-        | const_types [] = [];
+             pp_typlist1 typevars ^ tname, syn);
 
       fun create_typevar (dtVar s) typlist =
             if (dtVar s) mem typlist then 
@@ -199,29 +174,24 @@
         | create_typevar _ _ = raise Impossible;
 
       fun assumpt (dtRek _ :: ts, v :: vs ,found) =
-            let val h = if found then ";P(" ^ v ^ ")"
-                                 else "[| P(" ^ v ^ ")"
+            let val h = if found then ";P(" ^ v ^ ")" else "[| P(" ^ v ^ ")"
             in h ^ (assumpt (ts, vs, true)) end
         | assumpt (t :: ts, v :: vs, found) = assumpt (ts, vs, found)
         | assumpt ([], [], found) = if found then "|] ==>" else ""
         | assumpt _ = raise Impossible;
 
       (*insert type with suggested name 'varname' into table*)
-      fun insert typ varname ((t, s, n) :: xs) = 
+      fun insert typ varname ((tri as (t, s, n)) :: xs) = 
             if typ = t then (t, s, n+1) :: xs
-            else if varname = s then (t,s,n) :: (insert typ (varname ^ "'") xs)
-                                else (t,s,n) :: (insert typ varname xs)
+            else tri :: (if varname = s then insert typ (varname ^ "'") xs
+                         else insert typ varname xs)
         | insert typ varname [] = [(typ, varname, 1)];
 
-      fun insert_types (dtRek (l,id) :: ts) tab =
-            insert_types ts (insert (dtRek(l,id)) id tab)
-        | insert_types ((dtVar s) :: ts) tab =
-            insert_types ts (insert (dtVar s) (implode (tl (explode s))) tab)
-        | insert_types ((dtId s) :: ts) tab =
-            insert_types ts (insert (dtId s) s tab)
-        | insert_types (dtComp (l,id) :: ts) tab =
-            insert_types ts (insert (dtComp(l,id)) id tab)
-        | insert_types [] tab = tab;
+      fun typid(dtRek(_,id)) = id
+        | typid(dtVar s) = implode (tl (explode s))
+        | typid(dtTyp(_,id)) = id;
+
+      val insert_types = foldl (fn (tab,typ) => insert typ (typid typ) tab);
 
       fun update(dtRek _, s, v :: vs, (dtRek _) :: ts) = s :: vs
         | update(t, s, v :: vs, t1 :: ts) = 
@@ -241,19 +211,19 @@
         | update_n _ = raise Impossible;
 
       (*insert type variables into table*)
-      fun convert ((t, s, n) :: ts) var_list typ_list =
-            let val h = if n=1 then update (t, s, var_list, typ_list)
-                               else update_n (t, s, var_list, typ_list, 1)
-            in convert ts h typ_list end
-        | convert [] var_list _ = var_list;
+      fun convert typs =
+        let fun conv(vars, (t, s, n)) =
+              if n=1 then update (t, s, vars, typs)
+                     else update_n (t, s, vars, typs, 1)
+        in foldl conv end;
 
       fun empty_list n = replicate n "";
 
       fun t_inducting ((id, typl, syn) :: cs) =
             let val name = const_name id syn;
-                val tab = insert_types typl [];
+                val tab = insert_types([],typl);
                 val arity = length typl;
-                val var_list = convert tab (empty_list arity) typl; 
+                val var_list = convert typl (empty_list arity,tab); 
                 val h = if arity = 0 then " P(" ^ name ^ ")"
                         else " !!" ^ (space_implode " " var_list) ^ "." ^
                              (assumpt (typl, var_list, false)) ^ "P(" ^ 
@@ -275,7 +245,7 @@
       fun case_rules t_case arity n ((id, typlist, syn) :: cs) =
             let val name = const_name id syn;
                 val args = if null typlist then ""
-  			   else "(" ^ Args ("x", ",", 1, length typlist) ^ ")"
+  			   else pars(Args("x", ",", 1, length typlist))
             in (t_case ^ "_" ^ id,
                 t_case ^ "(" ^ name ^ args ^ "," ^ Args ("f", ",", 1, arity) 
                 ^ ") = f" ^ string_of_int(n) ^ args)
@@ -304,7 +274,7 @@
          in (dekl, rules) end;
 
       val consts = 
-        const_types cons_list
+        map const_type cons_list
 	@ (if length cons_list < dtK then []
 	   else [(tname ^ "_ord", datatype_name ^ "=>nat", NoSyn)])
 	@ case_const;
@@ -315,8 +285,7 @@
 
       (*generate 'name_1', ..., 'name_n'*)
       fun C_exp(name, n, var) =
-        if n > 0 then name ^ "(" ^ Args (var, ",", 1, n) ^ ")"
-                 else name;
+        if n > 0 then name ^ pars(Args(var, ",", 1, n)) else name;
 
       (*generate 'x_n = y_n, ..., x_m = y_m'*)
       fun Arg_eql(n,m) = 
@@ -327,57 +296,46 @@
       fun Ci_ing ((id, typlist, syn) :: cs) =
             let val name = const_name id syn;
                 val arity = length typlist;
-            in if arity > 0 
-               then ("inject_" ^ id,
+            in if arity = 0 then Ci_ing cs
+               else ("inject_" ^ id,
                      "(" ^ C_exp(name,arity,"x") ^ "=" ^ C_exp(name,arity,"y") 
                      ^ ") = (" ^ Arg_eql (1, arity) ^ ")") :: (Ci_ing cs)
-               else (Ci_ing cs)      
             end
         | Ci_ing [] = [];
 
-      fun Ci_negOne _ [] = []
-        | Ci_negOne c (c1::cs) =
-           let val (id1, tl1, syn1) = c
-               val (id2, tl2, syn2) = c1
-               val name1 = const_name id1 syn1;
+      fun Ci_negOne (id1, tl1, syn1) (id2, tl2, syn2) =
+           let val name1 = const_name id1 syn1;
                val name2 = const_name id2 syn2;
-               val arit1 = length tl1
-               val arit2 = length tl2
-               val h = "(" ^ C_exp(name1, arit1, "x") ^ "~=" ^
-                             C_exp(name2, arit2, "y") ^ ")"
-           in ("ineq_" ^ id1 ^ "_" ^ id2, h):: (Ci_negOne c cs) 
-	   end;
+               val ax = C_exp(name1, length tl1, "x") ^ "~=" ^
+                        C_exp(name2, length tl2, "y")
+           in (id1 ^ "_not_" ^ id2, ax) end;
 
       fun Ci_neg1 [] = []
-        | Ci_neg1 (c1::cs) = Ci_negOne c1 cs @ Ci_neg1 cs;
+        | Ci_neg1 (c1::cs) = (map (Ci_negOne c1) cs) @ Ci_neg1 cs;
 
       fun suc_expr n = 
         if n=0 then "0" else "Suc(" ^ suc_expr(n-1) ^ ")";
 
-      fun Ci_neg2equals (ord_t, ((id, typlist, syn) :: cs), n) =
-          let val name = const_name id syn;
-              val h = ord_t ^ "(" ^ (C_exp(name, length typlist, "x")) 
-                      ^ ") = " ^ (suc_expr n)
-          in (ord_t ^ (string_of_int (n+1)), h) 
-             :: (Ci_neg2equals (ord_t, cs , n+1))
-          end
-        | Ci_neg2equals (_, [], _) = [];
-
-      val Ci_neg2 =
+      fun Ci_neg2() =
         let val ord_t = tname ^ "_ord";
-        in (Ci_neg2equals (ord_t, cons_list, 0)) @
-           [(ord_t ^ "0",
-            "(" ^ ord_t ^ "(x) ~= " ^ ord_t ^ "(y)) ==> (x ~= y)")]
+            val cis = cons_list ~~ (0 upto (length cons_list - 1))
+            fun Ci_neg2equals ((id, typlist, syn), n) =
+              let val name = const_name id syn;
+                  val ax = ord_t ^ "(" ^ (C_exp(name, length typlist, "x")) 
+                                 ^ ") = " ^ (suc_expr n)
+              in (ord_t ^ "_" ^ id, ax) end
+        in (ord_t ^ "_distinct", ord_t^"(x) ~= "^ord_t^"(y) ==> x ~= y") ::
+           (map Ci_neg2equals cis)
         end;
 
-      val rules_ineq = if length cons_list < dtK then Ci_neg1 cons_list
-                                                 else Ci_neg2;
+      val rules_distinct = if length cons_list < dtK then Ci_neg1 cons_list
+                           else Ci_neg2();
 
       val rules_inject = Ci_ing cons_list;
 
       val rule_induct = (tname ^ "_induct", t_induct cons_list tname);
 
-      val rules = rule_induct :: (rules_inject @ rules_ineq @ rules_case);
+      val rules = rule_induct :: (rules_inject @ rules_distinct @ rules_case);
   in thy
      |> add_types types
      |> add_arities arities