src/Pure/type.ML
author wenzelm
Fri Dec 14 11:54:47 2001 +0100 (2001-12-14)
changeset 12501 36b2ac65e18d
parent 12314 160013745a92
child 12528 b8bc541a4544
permissions -rw-r--r--
varify returns newly introduced variables;
     1 (*  Title:      Pure/type.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow & Lawrence C Paulson
     4 
     5 Type signatures, unification of types, interface to type inference.
     6 *)
     7 
signature TYPE =
sig
  (*TFrees and TVars*)
  (*raise TYPE if the type contains schematic type variables, else return it*)
  val no_tvars: typ -> typ
  (*turn every TFree into a TVar with index 0*)
  val varifyT: typ -> typ
  (*turn every TVar with index 0 back into a TFree*)
  val unvarifyT: typ -> typ
  (*turn the TFrees of a term whose names are not listed as fixed into fresh
    TVars (index 0); also return the map from old TFree names to new indexnames*)
  val varify: term * string list -> term * (string * indexname) list
  (*replace TVars by fresh TFrees; the returned function undoes the renaming*)
  val freeze_thaw_type : typ -> typ * (typ -> typ)
  val freeze_thaw : term -> term * (term -> term)

  (*type signatures*)
  type type_sig
  (*expose all components of a type signature*)
  val rep_tsig: type_sig ->
    {classes: class list,
     classrel: Sorts.classrel,
     default: sort,
     tycons: int Symtab.table,
     log_types: string list,
     univ_witness: (typ * sort) option,
     abbrs: (string list * typ) Symtab.table,
     arities: Sorts.arities}
  val classes: type_sig -> class list
  val defaultS: type_sig -> sort
  val logical_types: type_sig -> string list
  val univ_witness: type_sig -> (typ * sort) option
  (*sort ordering / equality / normalization w.r.t. the class relation*)
  val subsort: type_sig -> sort * sort -> bool
  val eq_sort: type_sig -> sort * sort -> bool
  val norm_sort: type_sig -> sort -> sort
  (*check classes against the declared ones (raising TYPE); cert_sort also normalizes*)
  val cert_class: type_sig -> class -> class
  val cert_sort: type_sig -> sort -> sort
  val witness_sorts: type_sig -> sort list -> sort list -> (typ * sort) list
  (*erase the sorts of all type variables*)
  val rem_sorts: typ -> typ
  (*the empty type signature*)
  val tsig0: type_sig
  (*extension operations; all report problems via error / TYPE / TERM*)
  val ext_tsig_classes: type_sig -> (class * class list) list -> type_sig
  val ext_tsig_classrel: type_sig -> (class * class) list -> type_sig
  val ext_tsig_defsort: type_sig -> sort -> type_sig
  val ext_tsig_types: type_sig -> (string * int) list -> type_sig
  val ext_tsig_abbrs: type_sig -> (string * string list * typ) list -> type_sig
  val ext_tsig_arities: type_sig -> (string * sort list * sort)list -> type_sig
  val merge_tsigs: type_sig * type_sig -> type_sig
  (*accumulate (rather than raise) error messages about an ill-formed type*)
  val typ_errors: type_sig -> typ * string list -> string list
  (*check a type (raising TYPE); cert_typ additionally applies norm_typ*)
  val cert_typ: type_sig -> typ -> typ
  val cert_typ_no_norm: type_sig -> typ -> typ
  (*expand type abbreviations and normalize sorts*)
  val norm_typ: type_sig -> typ -> typ
  val norm_term: type_sig -> term -> term
  (*instantiate TVars, checking that each instance has the variable's sort*)
  val inst_term_tvars: type_sig * (indexname * typ) list -> term -> term
  val inst_typ_tvars: type_sig * (indexname * typ) list -> typ -> typ

  (*type matching*)
  exception TYPE_MATCH
  (*extend a matcher so that the first type (pattern) matches the second
    (object); raise TYPE_MATCH on failure*)
  val typ_match: type_sig -> typ Vartab.table * (typ * typ)
    -> typ Vartab.table
  (*typ_instance (tsig, T, U): is T an instance of U?*)
  val typ_instance: type_sig * typ * typ -> bool
  val of_sort: type_sig -> typ * sort -> bool

  (*type unification*)
  exception TUNIFY
  (*unify two types under an environment; the int is the largest index in use,
    fresh TVars are numbered from it; returns the new environment together with
    the new counter value; raise TUNIFY on failure*)
  val unify: type_sig -> int -> typ Vartab.table -> typ * typ -> typ Vartab.table * int
  (*purely structural unifiability test, ignoring sorts*)
  val raw_unify: typ * typ -> bool

  (*type inference*)
  (*sort of a type variable: an explicit constraint from the environment
    (checked against def_sort), else def_sort, else the default sort*)
  val get_sort: type_sig -> (indexname -> sort option) -> (sort -> sort)
    -> (indexname * sort) list -> indexname -> sort
  (*wrap a term in a type constraint, unless the type is dummyT*)
  val constrain: term -> typ -> term
  (*TVar (index 0) whose name is a "?"-prefixed variant avoiding the given names*)
  val param: string list -> string * sort -> typ
  val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T)
    -> type_sig -> (string -> typ option) -> (indexname -> typ option)
    -> (indexname -> sort option) -> (string -> string) -> (typ -> typ)
    -> (sort -> sort) -> string list -> bool -> typ list -> term list
    -> term list * (indexname * typ) list
end;
    79 
    80 structure Type: TYPE =
    81 struct
    82 
    83 
    84 (*** TFrees and TVars ***)
    85 
    86 fun no_tvars T =
    87   (case typ_tvars T of [] => T
    88   | vs => raise TYPE ("Illegal schematic type variable(s): " ^
    89       commas (map (Syntax.string_of_vname o #1) vs), [T], []));
    90 
    91 
    92 (* varify, unvarify *)
    93 
    94 val varifyT = map_type_tfree (fn (a, S) => TVar ((a, 0), S));
    95 
    96 fun unvarifyT (Type (a, Ts)) = Type (a, map unvarifyT Ts)
    97   | unvarifyT (TVar ((a, 0), S)) = TFree (a, S)
    98   | unvarifyT T = T;
    99 
   100 fun varify (t, fixed) =
   101   let
   102     val fs = add_term_tfree_names (t, []) \\ fixed;
   103     val ixns = add_term_tvar_ixns (t, []);
   104     val fmap = fs ~~ map (rpair 0) (variantlist (fs, map #1 ixns))
   105     fun thaw (f as (a, S)) =
   106       (case assoc (fmap, a) of
   107         None => TFree f
   108       | Some b => TVar (b, S));
   109   in (map_term_types (map_type_tfree thaw) t, fmap) end;
   110 
   111 
   112 (* freeze_thaw: freeze TVars in a term; return the "thaw" inverse *)
   113 
   114 local
   115 
   116 fun new_name (ix, (pairs,used)) =
   117       let val v = variant used (string_of_indexname ix)
   118       in  ((ix,v)::pairs, v::used)  end;
   119 
   120 fun freeze_one alist (ix,sort) =
   121   TFree (the (assoc (alist, ix)), sort)
   122     handle OPTION =>
   123       raise TYPE ("Failure during freezing of ?" ^ string_of_indexname ix, [], []);
   124 
   125 fun thaw_one alist (a,sort) = TVar (the (assoc (alist,a)), sort)
   126       handle OPTION => TFree(a,sort);
   127 
   128 in
   129 
   130 (*this sort of code could replace unvarifyT*)
   131 fun freeze_thaw_type T =
   132   let
   133     val used = add_typ_tfree_names (T, [])
   134     and tvars = map #1 (add_typ_tvars (T, []));
   135     val (alist, _) = foldr new_name (tvars, ([], used));
   136   in (map_type_tvar (freeze_one alist) T, map_type_tfree (thaw_one (map swap alist))) end;
   137 
   138 fun freeze_thaw t =
   139   let
   140     val used = it_term_types add_typ_tfree_names (t, [])
   141     and tvars = map #1 (it_term_types add_typ_tvars (t, []));
   142     val (alist, _) = foldr new_name (tvars, ([], used));
   143   in
   144     (case alist of
   145       [] => (t, fn x => x) (*nothing to do!*)
   146     | _ => (map_term_types (map_type_tvar (freeze_one alist)) t,
   147       map_term_types (map_type_tfree (thaw_one (map swap alist)))))
   148   end;
   149 
   150 end;
   151 
   152 
   153 
   154 (*** type signatures ***)
   155 
   156 (* type type_sig *)
   157 
   158 (*
   159   classes: list of all declared classes;
   160   classrel: (see Pure/sorts.ML)
   161   default: default sort attached to all unconstrained type vars;
   162   tycons: table of all declared types with the number of their arguments;
   163   log_types: list of logical type constructors sorted by number of arguments;
   164   univ_witness: type witnessing non-emptiness of least sort
   165   abbrs: table of type abbreviations;
   166   arities: (see Pure/sorts.ML)
   167 *)
   168 
   169 datatype type_sig =
   170   TySg of {
   171     classes: class list,
   172     classrel: Sorts.classrel,
   173     default: sort,
   174     tycons: int Symtab.table,
   175     log_types: string list,
   176     univ_witness: (typ * sort) option,
   177     abbrs: (string list * typ) Symtab.table,
   178     arities: Sorts.arities};
   179 
   180 fun rep_tsig (TySg comps) = comps;
   181 
   182 fun classes (TySg {classes = cs, ...}) = cs;
   183 fun defaultS (TySg {default, ...}) = default;
   184 fun logical_types (TySg {log_types, ...}) = log_types;
   185 fun univ_witness (TySg {univ_witness, ...}) = univ_witness;
   186 
   187 
   188 (* error messages *)
   189 
   190 fun undeclared_class c = "Undeclared class: " ^ quote c;
   191 fun undeclared_classes cs = "Undeclared class(es): " ^ commas_quote cs;
   192 
   193 fun err_undeclared_class s = error (undeclared_class s);
   194 
   195 fun err_dup_classes cs =
   196   error ("Duplicate declaration of class(es): " ^ commas_quote cs);
   197 
   198 fun undeclared_type c = "Undeclared type constructor: " ^ quote c;
   199 
   200 fun err_neg_args c =
   201   error ("Negative number of arguments of type constructor: " ^ quote c);
   202 
   203 fun err_dup_tycon c =
   204   error ("Duplicate declaration of type constructor: " ^ quote c);
   205 
   206 fun dup_tyabbrs ts =
   207   "Duplicate declaration of type abbreviation(s): " ^ commas_quote ts;
   208 
   209 fun ty_confl c = "Conflicting type constructor and abbreviation: " ^ quote c;
   210 
   211 
   212 (* sorts *)
   213 
   214 fun subsort (TySg {classrel, ...}) = Sorts.sort_le classrel;
   215 fun eq_sort (TySg {classrel, ...}) = Sorts.sort_eq classrel;
   216 fun norm_sort (TySg {classrel, ...}) = Sorts.norm_sort classrel;
   217 
   218 fun cert_class (TySg {classes, ...}) c =
   219   if c mem_string classes then c else raise TYPE (undeclared_class c, [], []);
   220 
   221 fun cert_sort tsig S = norm_sort tsig (map (cert_class tsig) S);
   222 
   223 fun witness_sorts (tsig as TySg {classrel, arities, log_types, ...}) =
   224   Sorts.witness_sorts (classrel, arities, log_types);
   225 
   226 fun rem_sorts (Type (a, tys)) = Type (a, map rem_sorts tys)
   227   | rem_sorts (TFree (x, _)) = TFree (x, [])
   228   | rem_sorts (TVar (xi, _)) = TVar (xi, []);
   229 
   230 
   231 (* FIXME err_undeclared_class! *)
   232 (* 'leq' checks the partial order on classes according to the
   233    statements in classrel 'a'
   234 *)
   235 
   236 fun less a (C, D) = case Symtab.lookup (a, C) of
   237      Some ss => D mem_string ss
   238    | None => err_undeclared_class C;
   239 
   240 fun leq a (C, D)  =  C = D orelse less a (C, D);
   241 
   242 
   243 
   244 (* FIXME *)
   245 (*Instantiation of type variables in types*)
   246 (*Pre: instantiations obey restrictions! *)
   247 fun inst_typ tye =
   248   let fun inst(var as (v, _)) = case assoc(tye, v) of
   249                                   Some U => inst_typ tye U
   250                                 | None => TVar(var)
   251   in map_type_tvar inst end;
   252 
   253 
   254 
   255 fun of_sort (TySg {classrel, arities, ...}) = Sorts.of_sort (classrel, arities);
   256 
   257 fun check_has_sort (tsig, T, S) =
   258   if of_sort tsig (T, S) then ()
   259   else raise TYPE ("Type not of sort " ^ Sorts.str_of_sort S, [T], []);
   260 
   261 
   262 (*Instantiation of type variables in types *)
   263 fun inst_typ_tvars(tsig, tye) =
   264   let fun inst(var as (v, S)) = case assoc(tye, v) of
   265               Some U => (check_has_sort(tsig, U, S); U)
   266             | None => TVar(var)
   267   in map_type_tvar inst end;
   268 
   269 (*Instantiation of type variables in terms *)
   270 fun inst_term_tvars (_,[]) t = t
   271   | inst_term_tvars arg    t = map_term_types (inst_typ_tvars arg) t;
   272 
   273 
   274 (* norm_typ, norm_term *)
   275 
   276 (*expand abbreviations and normalize sorts*)
   277 fun norm_typ (tsig as TySg {abbrs, ...}) ty =
   278   let
   279     val idx = maxidx_of_typ ty + 1;
   280 
   281     fun norm (Type (a, Ts)) =
   282           (case Symtab.lookup (abbrs, a) of
   283             Some (vs, U) => norm (inst_typ (map (rpair idx) vs ~~ Ts) (incr_tvar idx U))
   284           | None => Type (a, map norm Ts))
   285       | norm (TFree (x, S)) = TFree (x, norm_sort tsig S)
   286       | norm (TVar (xi, S)) = TVar (xi, norm_sort tsig S);
   287 
   288     val ty' = norm ty;
   289   in if ty = ty' then ty else ty' end;  (*dumb tuning to avoid copying*)
   290 
   291 fun norm_term tsig t =
   292   let val t' = map_term_types (norm_typ tsig) t
   293   in if t = t' then t else t' end;  (*dumb tuning to avoid copying*)
   294 
   295 
   296 
   297 (** build type signatures **)
   298 
   299 fun make_tsig (classes, classrel, default, tycons, log_types, univ_witness, abbrs, arities) =
   300   TySg {classes = classes, classrel = classrel, default = default, tycons = tycons,
   301     log_types = log_types, univ_witness = univ_witness, abbrs = abbrs, arities = arities};
   302 
   303 fun rebuild_tsig (TySg {classes, classrel, default, tycons, log_types = _, univ_witness = _, abbrs, arities}) =
   304   let
   305     fun log_class c = Sorts.class_le classrel (c, logicC);
   306     fun log_type (t, _) = exists (log_class o #1) (Symtab.lookup_multi (arities, t));
   307     val ts = filter log_type (Symtab.dest tycons);
   308 
   309     val log_types = map #1 (Library.sort (Library.int_ord o pairself #2) ts);
   310     val univ_witness =
   311       (case Sorts.witness_sorts (classrel, arities, log_types) [] [classes] of
   312         [w] => Some w | _ => None);
   313   in make_tsig (classes, classrel, default, tycons, log_types, univ_witness, abbrs, arities) end;
   314 
   315 val tsig0 =
   316   make_tsig ([], Symtab.empty, [], Symtab.empty, [], None, Symtab.empty, Symtab.empty)
   317   |> rebuild_tsig;
   318 
   319 
   320 (* typ_errors *)
   321 
   322 (*check validity of (not necessarily normal) type; accumulate error messages*)
   323 
   324 fun typ_errors tsig (typ, errors) =
   325   let
   326     val {classes, tycons, abbrs, ...} = rep_tsig tsig;
   327 
   328     fun class_err (errs, c) =
   329       if c mem_string classes then errs
   330       else undeclared_class c ins_string errs;
   331 
   332     val sort_err = foldl class_err;
   333 
   334     fun typ_errs (errs, Type (c, Us)) =
   335           let
   336             val errs' = foldl typ_errs (errs, Us);
   337             fun nargs n =
   338               if n = length Us then errs'
   339               else ("Wrong number of arguments: " ^ quote c) ins_string errs';
   340           in
   341             (case Symtab.lookup (tycons, c) of
   342               Some n => nargs n
   343             | None =>
   344                 (case Symtab.lookup (abbrs, c) of
   345                   Some (vs, _) => nargs (length vs)
   346                 | None => undeclared_type c ins_string errs))
   347           end
   348     | typ_errs (errs, TFree (_, S)) = sort_err (errs, S)
   349     | typ_errs (errs, TVar ((x, i), S)) =
   350         if i < 0 then
   351           ("Negative index for TVar " ^ quote x) ins_string sort_err (errs, S)
   352         else sort_err (errs, S);
   353   in typ_errs (errors, typ) end;
   354 
   355 
   356 (* cert_typ *)           (*exception TYPE*)
   357 
   358 fun cert_typ_no_norm tsig T =
   359   (case typ_errors tsig (T, []) of
   360     [] => T
   361   | errs => raise TYPE (cat_lines errs, [T], []));
   362 
   363 fun cert_typ tsig T = norm_typ tsig (cert_typ_no_norm tsig T);
   364 
   365 
   366 
   367 (** merge type signatures **)
   368 
   369 (* merge classrel *)
   370 
   371 fun assoc_union (as1, []) = as1
   372   | assoc_union (as1, (key, l2) :: as2) =
   373       (case assoc_string (as1, key) of
   374         Some l1 => assoc_union (overwrite (as1, (key, l1 union_string l2)), as2)
   375       | None => assoc_union ((key, l2) :: as1, as2));
   376 
   377 fun merge_classrel (classrel1, classrel2) =
   378   let
   379     val classrel = transitive_closure (assoc_union (Symtab.dest classrel1, Symtab.dest classrel2))
   380   in
   381     if exists (op mem_string) classrel then
   382       error ("Cyclic class structure!")   (* FIXME improve msg, raise TERM *)
   383     else Symtab.make classrel
   384   end;
   385 
   386 
(* coregularity *)

local

(* 'is_unique_decl' checks that there is at most one declaration t:(Ss)C per
   result class C: if ars already maps C to a domain w1, the new domain w must
   be identical to it, otherwise an error is raised *)

fun is_unique_decl ars (t,(C,w)) = case assoc (ars, C) of
      Some(w1) => if w = w1 then () else
        error("There are two declarations\n" ^
              Sorts.str_of_arity(t, w, [C]) ^ " and\n" ^
              Sorts.str_of_arity(t, w1, [C]) ^ "\n" ^
              "with the same result class.")
    | None => ();

(* 'coreg' checks the new declaration t:(Ss1)C1 against every existing
   declaration t:(Ss2)C2, in both directions: whenever C1 <= C2 it requires
   Ss1 <= Ss2 (elementwise), and symmetrically; a violation is reported by
   'coreg_err' *)

fun coreg_err(t, (C1,w1), (C2,w2)) =
    error("Declarations " ^ Sorts.str_of_arity(t, w1, [C1]) ^ " and "
                          ^ Sorts.str_of_arity(t, w2, [C2]) ^ " are in conflict");

fun coreg classrel (t, Cw1) =
  let
    (*one direction of the check: C1 <= C2 must imply w1 <= w2*)
    fun check1(Cw1 as (C1,w1), Cw2 as (C2,w2)) =
      if leq classrel (C1,C2) then
        if Sorts.sorts_le classrel (w1,w2) then ()
        else coreg_err(t, Cw1, Cw2)
      else ()
    fun check(Cw2) = (check1(Cw1,Cw2); check1(Cw2,Cw1))
  in seq check end;   (*to be applied to the list of existing declarations*)

in

(*add_arity classrel ars (t, (C, w)): check the declaration t:(w)C against the
  existing declarations ars (uniqueness first, then coregularity) and add
  (C, w) to ars via 'ins'*)
fun add_arity classrel ars (tCw as (_,Cw)) =
      (is_unique_decl ars tCw; coreg classrel tCw ars; Cw ins ars);

end;
   424 
   425 
   426 (* 'merge_arities' builds the union of two 'arities' lists;
   427    it only checks the two restriction conditions and inserts afterwards
   428    all elements of the second list into the first one *)
   429 
   430 local
   431 
   432 fun merge_arities_aux classrel =
   433   let fun test_ar t (ars1, sw) = add_arity classrel ars1 (t,sw);
   434 
   435       fun merge_c (arities1, (c as (t, ars2))) = case assoc (arities1, t) of
   436           Some(ars1) =>
   437             let val ars = foldl (test_ar t) (ars1, ars2)
   438             in overwrite (arities1, (t,ars)) end
   439         | None => c::arities1
   440   in foldl merge_c end;
   441 
   442 in
   443 
   444 fun merge_arities classrel (a1, a2) =
   445   Symtab.make (merge_arities_aux classrel (Symtab.dest a1, Symtab.dest a2));
   446 
   447 end;
   448 
   449 
   450 (* tycons *)
   451 
   452 fun varying_decls t =
   453   error ("Type constructor " ^ quote t ^ " has varying number of arguments");
   454 
   455 fun add_tycons (tycons, tn as (t,n)) =
   456   (case Symtab.lookup (tycons, t) of
   457     Some m => if m = n then tycons else varying_decls t
   458   | None => Symtab.update (tn, tycons));
   459 
   460 
   461 (* merge_abbrs *)
   462 
   463 fun merge_abbrs abbrs =
   464   Symtab.merge (op =) abbrs handle Symtab.DUPS dups => raise TERM (dup_tyabbrs dups, []);
   465 
   466 
   467 (* merge_tsigs *)
   468 
   469 fun merge_tsigs
   470  (TySg {classes = classes1, default = default1, classrel = classrel1, tycons = tycons1,
   471     log_types = _, univ_witness = _, arities = arities1, abbrs = abbrs1},
   472   TySg {classes = classes2, default = default2, classrel = classrel2, tycons = tycons2,
   473     log_types = _, univ_witness = _, arities = arities2, abbrs = abbrs2}) =
   474   let
   475     val classes' = classes1 union_string classes2;
   476     val classrel' = merge_classrel (classrel1, classrel2);
   477     val arities' = merge_arities classrel' (arities1, arities2);
   478     val tycons' = foldl add_tycons (tycons1, Symtab.dest tycons2);
   479     val default' = Sorts.norm_sort classrel' (default1 @ default2);
   480     val abbrs' = merge_abbrs (abbrs1, abbrs2);
   481   in
   482     make_tsig (classes', classrel', default', tycons', [], None, abbrs', arities')
   483     |> rebuild_tsig
   484   end;
   485 
   486 
   487 
   488 (*** extend type signatures ***)
   489 
   490 (** add classes and classrel relations **)
   491 
   492 fun add_classes classes cs =
   493   (case cs inter_string classes of
   494     [] => cs @ classes
   495   | dups => err_dup_classes cs);
   496 
   497 
   498 (*'add_classrel' adds a tuple consisting of a new class (the new class has
   499   already been inserted into the 'classes' list) and its superclasses (they
   500   must be declared in 'classes' too) to the 'classrel' list of the given type
   501   signature; furthermore all inherited superclasses according to the
   502   superclasses brought with are inserted and there is a check that there are
   503   no cycles (i.e. C <= D <= C, with C <> D);*)
   504 
   505 fun add_classrel classes (classrel, (s, ges)) =
   506   let
   507     fun upd (classrel, s') =
   508       if s' mem_string classes then
   509         let val ges' = the (Symtab.lookup (classrel, s))
   510         in case Symtab.lookup (classrel, s') of
   511              Some sups => if s mem_string sups
   512                            then error(" Cycle :" ^ s^" <= "^ s'^" <= "^ s )
   513                            else Symtab.update ((s, sups union_string ges'), classrel)
   514            | None => classrel
   515         end
   516       else err_undeclared_class s'
   517   in foldl upd (Symtab.update ((s, ges), classrel), ges) end;
   518 
   519 
   520 (* 'extend_classes' inserts all new classes into the corresponding
   521    lists ('classes', 'classrel') if possible *)
   522 
   523 fun extend_classes (classes, classrel, new_classes) =
   524   let
   525     val classes' = add_classes classes (map fst new_classes);
   526     val classrel' = foldl (add_classrel classes') (classrel, new_classes);
   527   in (classes', classrel') end;
   528 
   529 
   530 (* ext_tsig_classes *)
   531 
   532 fun ext_tsig_classes tsig new_classes =
   533   let
   534     val TySg {classes, classrel, default, tycons, log_types, univ_witness, abbrs, arities} = tsig;
   535     val (classes', classrel') = extend_classes (classes,classrel, new_classes);
   536   in
   537     make_tsig (classes', classrel', default, tycons, log_types, univ_witness, abbrs, arities)
   538     |> rebuild_tsig
   539   end;
   540 
   541 
   542 (* ext_tsig_classrel *)
   543 
   544 fun ext_tsig_classrel tsig pairs =
   545   let
   546     val TySg {classes, classrel, default, tycons, log_types, univ_witness, abbrs, arities} = tsig;
   547     val cert = cert_class tsig;
   548 
   549     (* FIXME clean! *)
   550     val classrel' =
   551       merge_classrel (classrel, Symtab.make (map (fn (c1, c2) => (cert c1, [cert c2])) pairs));
   552   in
   553     make_tsig (classes, classrel', default, tycons, log_types, univ_witness, abbrs, arities)
   554     |> rebuild_tsig
   555   end;
   556 
   557 
   558 (* ext_tsig_defsort *)
   559 
   560 fun ext_tsig_defsort
   561     (TySg {classes, classrel, default = _, tycons, log_types, univ_witness, abbrs, arities, ...}) default =
   562   make_tsig (classes, classrel, default, tycons, log_types, univ_witness, abbrs, arities);
   563 
   564 
   565 
   566 (** add types **)
   567 
   568 fun ext_tsig_types (TySg {classes, classrel, default, tycons, log_types, univ_witness, abbrs, arities}) ts =
   569   let
   570     fun check_type (c, n) =
   571       if n < 0 then err_neg_args c
   572       else if is_some (Symtab.lookup (tycons, c)) then err_dup_tycon c
   573       else if is_some (Symtab.lookup (abbrs, c)) then error (ty_confl c)
   574       else ();
   575     val _ = seq check_type ts;
   576     val tycons' = Symtab.extend (tycons, ts);
   577     val arities' = Symtab.extend (arities, map (rpair [] o #1) ts);
   578   in make_tsig (classes, classrel, default, tycons', log_types, univ_witness, abbrs, arities') end;
   579 
   580 
   581 
   582 (** add type abbreviations **)
   583 
   584 fun abbr_errors tsig (a, (lhs_vs, rhs)) =
   585   let
   586     val TySg {tycons, abbrs, ...} = tsig;
   587     val rhs_vs = map (#1 o #1) (typ_tvars rhs);
   588 
   589     val dup_lhs_vars =
   590       (case duplicates lhs_vs of
   591         [] => []
   592       | vs => ["Duplicate variables on lhs: " ^ commas_quote vs]);
   593 
   594     val extra_rhs_vars =
   595       (case gen_rems (op =) (rhs_vs, lhs_vs) of
   596         [] => []
   597       | vs => ["Extra variables on rhs: " ^ commas_quote vs]);
   598 
   599     val tycon_confl =
   600       if is_none (Symtab.lookup (tycons, a)) then []
   601       else [ty_confl a];
   602 
   603     val dup_abbr =
   604       if is_none (Symtab.lookup (abbrs, a)) then []
   605       else ["Duplicate declaration of abbreviation"];
   606   in
   607     dup_lhs_vars @ extra_rhs_vars @ tycon_confl @ dup_abbr @
   608       typ_errors tsig (rhs, [])
   609   end;
   610 
   611 fun prep_abbr tsig (a, vs, raw_rhs) =
   612   let
   613     fun err msgs = (seq error_msg msgs;
   614       error ("The error(s) above occurred in type abbreviation " ^ quote a));
   615 
   616     val rhs = rem_sorts (varifyT (no_tvars raw_rhs))
   617       handle TYPE (msg, _, _) => err [msg];
   618     val abbr = (a, (vs, rhs));
   619   in
   620     (case abbr_errors tsig abbr of
   621       [] => abbr
   622     | msgs => err msgs)
   623   end;
   624 
   625 fun add_abbr
   626     (tsig as TySg {classes, classrel, default, tycons, log_types, univ_witness, arities, abbrs}, abbr) =
   627   make_tsig (classes, classrel, default, tycons, log_types, univ_witness,
   628     Symtab.update (prep_abbr tsig abbr, abbrs), arities);
   629 
   630 fun ext_tsig_abbrs tsig raw_abbrs = foldl add_abbr (tsig, raw_abbrs);
   631 
   632 
   633 
   634 (** add arities **)
   635 
   636 (* 'coregular' checks
   637    - the two restrictions 'is_unique_decl' and 'coreg'
   638    - if the classes in the new type declarations are known in the
   639      given type signature
   640    - if one type constructor has always the same number of arguments;
   641    if one type declaration has passed all checks it is inserted into
   642    the 'arities' association list of the given type signatrure  *)
   643 
   644 fun coregular (classes, classrel, tycons) =
   645   let fun ex C = if C mem_string classes then () else err_undeclared_class(C);
   646 
   647       fun addar(arities, (t, (w, C))) = case Symtab.lookup (tycons, t) of
   648             Some(n) => if n <> length w then varying_decls(t) else
   649                      ((seq o seq) ex w; ex C;
   650                       let val ars = the (Symtab.lookup (arities, t))
   651                           val ars' = add_arity classrel ars (t,(C,w))
   652                       in Symtab.update ((t,ars'), arities) end)
   653           | None => error (undeclared_type t);
   654 
   655   in addar end;
   656 
   657 
(* 'close' extends the 'arities' association list after all new type
   declarations have been inserted successfully:
   for every declaration t:(Ss)C , for all classes D with C <= D:
      if there is no declaration t:(Ss')C' with C < C' and C' <= D
      then insert the declaration t:(Ss)D into 'arities'
   this means, if there exists a declaration t:(Ss)C and there is
   no declaration t:(Ss')D with C <=D then the declaration holds
   for all range classes more general than C *)

(*arities: list of (type constructor, list of (result class, domain)) pairs*)
fun close classrel arities =
  (*sl: the result classes declared for t before closing; l: the accumulator.
    For the classes 'sups' listed for s in classrel, add (sup, dom) unless some
    declared class strictly above s lies below sup*)
  let fun check sl (l, (s, dom)) = case Symtab.lookup (classrel, s) of
          Some sups =>
            let fun close_sup (l, sup) =
                  if exists (fn s'' => less classrel (s, s'') andalso
                                       leq classrel (s'', sup)) sl
                  then l
                  else (sup, dom)::l
            in foldl close_sup (l, sups) end
        | None => l;
      (*sl is taken from the original list, so added entries are not re-examined*)
      fun ext (s, l) = (s, foldl (check (map #1 l)) (l, l));
  in map ext arities end;
   679 
   680 
   681 (* ext_tsig_arities *)
   682 
   683 fun norm_domain classrel =
   684   let fun one_min (f, (doms, ran)) = (f, (map (Sorts.norm_sort classrel) doms, ran))
   685   in map one_min end;
   686 
   687 fun ext_tsig_arities tsig sarities =
   688   let
   689     val TySg {classes, classrel, default, tycons, log_types, univ_witness, arities, abbrs} = tsig;
   690     val arities1 =
   691       flat (map (fn (t, ss, cs) => map (fn c => (t, (ss, c))) cs) sarities);
   692     val arities2 =
   693       foldl (coregular (classes, classrel, tycons)) (arities, norm_domain classrel arities1)
   694       |> Symtab.dest |> close classrel |> Symtab.make;
   695   in
   696     make_tsig (classes, classrel, default, tycons, log_types, univ_witness, abbrs, arities2)
   697     |> rebuild_tsig
   698   end;
   699 
   700 
   701 
   702 (*** type unification and friends ***)
   703 
   704 (** matching **)
   705 
   706 exception TYPE_MATCH;
   707 
   708 fun typ_match tsig =
   709   let
   710     fun match (subs, (TVar (v, S), T)) =
   711           (case Vartab.lookup (subs, v) of
   712             None => (Vartab.update_new ((v, (check_has_sort (tsig, T, S); T)), subs)
   713               handle TYPE _ => raise TYPE_MATCH)
   714           | Some U => if U = T then subs else raise TYPE_MATCH)
   715       | match (subs, (Type (a, Ts), Type (b, Us))) =
   716           if a <> b then raise TYPE_MATCH
   717           else foldl match (subs, Ts ~~ Us)
   718       | match (subs, (TFree x, TFree y)) =
   719           if x = y then subs else raise TYPE_MATCH
   720       | match _ = raise TYPE_MATCH;
   721   in match end;
   722 
   723 fun typ_instance (tsig, T, U) =
   724   (typ_match tsig (Vartab.empty, (U, T)); true) handle TYPE_MATCH => false;
   725 
   726 
   727 
   728 (** unification **)
   729 
   730 exception TUNIFY;
   731 
   732 
   733 (* occurs check *)
   734 
   735 fun occurs v tye =
   736   let
   737     fun occ (Type (_, Ts)) = exists occ Ts
   738       | occ (TFree _) = false
   739       | occ (TVar (w, _)) =
   740           eq_ix (v, w) orelse
   741             (case Vartab.lookup (tye, w) of
   742               None => false
   743             | Some U => occ U);
   744   in occ end;
   745 
   746 
   747 (* chase variable assignments *)
   748 
   749 (*if devar returns a type var then it must be unassigned*)
   750 fun devar (T as TVar (v, _), tye) =
   751       (case  Vartab.lookup (tye, v) of
   752         Some U => devar (U, tye)
   753       | None => T)
   754   | devar (T, tye) = T;
   755 
   756 
   757 (* add_env *)
   758 
   759 (*avoids chains 'a |-> 'b |-> 'c ...*)
   760 fun add_env (vT as (v, T), tab) = Vartab.update_new (vT, Vartab.map
   761   (fn (U as (TVar (w, S))) => if eq_ix (v, w) then T else U | U => U) tab);
   762 
(* unify *)

(*unify the pair TU under environment tyenv, respecting sorts.
  maxidx seeds the counter for fresh type variables; the result is the new
  environment together with the final counter value. Raises TUNIFY on failure.
  NOTE(review): the alias 'tsig' is not used in the body.*)
fun unify (tsig as TySg {classrel, arities, ...}) maxidx tyenv TU =
  let
    val tyvar_count = ref maxidx;
    (*fresh type variable 'a of sort S, indexed via tyvar_count*)
    fun gen_tyvar S = TVar (("'a", inc tyvar_count), S);

    (*most general argument sorts for type constructor a to have sort S;
      failure of Sorts.mg_domain becomes a unification failure*)
    fun mg_domain a S =
      Sorts.mg_domain (classrel, arities) a S handle Sorts.DOMAIN _ => raise TUNIFY;

    (*make a (dereferenced) type satisfy sort S: a TVar of a weaker sort is bound
      to a fresh TVar of the intersection sort; a TFree must already have it;
      a Type propagates the required sorts to its arguments*)
    fun meet ((_, []), tye) = tye
      | meet ((TVar (xi, S'), S), tye) =
          if Sorts.sort_le classrel (S', S) then tye
          else add_env ((xi, gen_tyvar (Sorts.inter_sort classrel (S', S))), tye)
      | meet ((TFree (_, S'), S), tye) =
          if Sorts.sort_le classrel (S', S) then tye
          else raise TUNIFY
      | meet ((Type (a, Ts), S), tye) = meets ((Ts, mg_domain a S), tye)
    (*pointwise meet of argument types with argument sorts*)
    and meets (([], []), tye) = tye
      | meets ((T :: Ts, S :: Ss), tye) =
          meets ((Ts, Ss), meet ((devar (T, tye), S), tye))
      | meets _ = sys_error "meets";

    fun unif ((ty1, ty2), tye) =
      (case (devar (ty1, tye), devar (ty2, tye)) of
        (*two variables: bind the one with the larger sort to the other, or both
          to a fresh variable of the intersection sort*)
        (T as TVar (v, S1), U as TVar (w, S2)) =>
          if eq_ix (v, w) then tye
          else if Sorts.sort_le classrel (S1, S2) then add_env ((w, T), tye)
          else if Sorts.sort_le classrel (S2, S1) then add_env ((v, U), tye)
          else
            let val S = gen_tyvar (Sorts.inter_sort classrel (S1, S2)) in
              add_env ((v, S), add_env ((w, S), tye))
            end
      (*variable vs. non-variable: occurs check, bind, then enforce the sort*)
      | (TVar (v, S), T) =>
          if occurs v tye T then raise TUNIFY
          else meet ((T, S), add_env ((v, T), tye))
      | (T, TVar (v, S)) =>
          if occurs v tye T then raise TUNIFY
          else meet ((T, S), add_env ((v, T), tye))
      | (Type (a, Ts), Type (b, Us)) =>
          if a <> b then raise TUNIFY
          else foldr unif (Ts ~~ Us, tye)
      | (T, U) => if T = U then tye else raise TUNIFY);
  in
    (unif (TU, tyenv), ! tyvar_count)
  end;
   809 
   810 
   811 (* raw_unify *)
   812 
   813 (*purely structural unification -- ignores sorts*)
   814 fun raw_unify (ty1, ty2) =
   815   (unify tsig0 0 Vartab.empty (rem_sorts ty1, rem_sorts ty2); true)
   816     handle TUNIFY => false;
   817 
   818 
   819 
   820 (** type inference **)
   821 
   822 (* sort constraints *)
   823 
   824 fun get_sort tsig def_sort map_sort raw_env =
   825   let
   826     fun eq ((xi, S), (xi', S')) =
   827       xi = xi' andalso eq_sort tsig (S, S');
   828 
   829     val env = gen_distinct eq (map (apsnd map_sort) raw_env);
   830     val _ =
   831       (case gen_duplicates eq_fst env of
   832         [] => ()
   833       | dups => error ("Inconsistent sort constraints for type variable(s) " ^
   834           commas (map (quote o Syntax.string_of_vname' o fst) dups)));
   835 
   836     fun get xi =
   837       (case (assoc (env, xi), def_sort xi) of
   838         (None, None) => defaultS tsig
   839       | (None, Some S) => S
   840       | (Some S, None) => S
   841       | (Some S, Some S') =>
   842           if eq_sort tsig (S, S') then S'
   843           else error ("Sort constraint inconsistent with default for type variable " ^
   844             quote (Syntax.string_of_vname' xi)));
   845   in get end;
   846 
   847 
   848 (* type constraints *)
   849 
   850 fun constrain t T =
   851   if T = dummyT then t
   852   else Const ("_type_constraint_", T) $ t;
   853 
   854 
   855 (* user parameters *)
   856 
   857 fun is_param (x, _) = size x > 0 andalso ord x = ord "?";
   858 fun param used (x, S) = TVar ((variant used ("?" ^ x), 0), S);
   859 
   860 
   861 (* decode_types *)
   862 
(*transform parse tree into raw term: resolve names via map_const, attach
  declared types from def_type, decode and certify type annotations against
  tsig, and turn _constrain/_constrainAbs nodes into type constraints*)
fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm =
  let
    (*declared type of a Free/Var indexname, or dummyT if def_type has none*)
    fun get_type xi = if_none (def_type xi) dummyT;
    (*x counts as a declared free variable iff def_type knows (x, ~1)*)
    fun is_free x = is_some (def_type (x, ~1));
    (*sort constraints found in tm (presumably collected by the syntax
      module), combined with def_sort into a sort lookup*)
    val raw_env = Syntax.raw_term_sorts tm;
    val sort_of = get_sort tsig def_sort map_sort raw_env;

    (*map a type with map_type, then certify it against tsig*)
    val certT = cert_typ tsig o map_type;
    (*decode a type annotation written as a term in the parse tree*)
    fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t);

    (*clause order matters: the constraint forms must be matched before the
      generic application and abstraction cases*)
    fun decode (Const ("_constrain", _) $ t $ typ) =
          constrain (decode t) (decodeT typ)
      | decode (Const ("_constrainAbs", _) $ (Abs (x, T, t)) $ typ) =
          (*untyped binder: the annotation becomes the binder's type;
            otherwise constrain the abstraction's argument type, leaving the
            result type as dummyT*)
          if T = dummyT then Abs (x, decodeT typ, decode t)
          else constrain (Abs (x, certT T, decode t)) (decodeT typ --> dummyT)
      | decode (Abs (x, T, t)) = Abs (x, certT T, decode t)
      | decode (t $ u) = decode t $ decode u
      | decode (Free (x, T)) =
          (*unless x is a declared free variable, a name that is a constant
            (after map_const) or a qualified name becomes a Const; otherwise
            a Free is typed or constrained by its declared type*)
          let val c = map_const x in
            if not (is_free x) andalso (is_const c orelse NameSpace.is_qualified c) then
              Const (c, certT T)
            else if T = dummyT then Free (x, get_type (x, ~1))
            else constrain (Free (x, certT T)) (get_type (x, ~1))
          end
      | decode (Var (xi, T)) =
          if T = dummyT then Var (xi, get_type xi)
          else constrain (Var (xi, certT T)) (get_type xi)
      | decode (t as Bound _) = t
      | decode (Const (c, T)) = Const (map_const c, certT T);
  in decode tm end;
   894 
   895 
   896 (* infer_types *)
   897 
   898 (*
   899   Given [T1,...,Tn] and [t1,...,tn], ensure that the type of ti
   900   unifies with Ti (for i=1,...,n).
   901 
   902   tsig: type signature
   903   const_type: name mapping and signature lookup
   904   def_type: partial map from indexnames to types (constrains Frees, Vars)
   905   def_sort: partial map from indexnames to sorts (constrains TFrees, TVars)
   906   used: list of already used type variables
   907   freeze: if true then generated parameters are turned into TFrees, else TVars
   908 *)
   909 
   910 fun infer_types prt prT tsig const_type def_type def_sort
   911     map_const map_type map_sort used freeze pat_Ts raw_ts =
   912   let
   913     val TySg {classrel, arities, ...} = tsig;
   914     val pat_Ts' = map (cert_typ tsig) pat_Ts;
   915     val is_const = is_some o const_type;
   916     val raw_ts' =
   917       map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts;
   918     val (ts, Ts, unifier) =
   919       TypeInfer.infer_types prt prT const_type classrel arities used freeze
   920         is_param raw_ts' pat_Ts';
   921   in (ts, unifier) end;
   922 
   923 
   924 end;