typ_match, unify: canonical argument order;
authorwenzelm
Thu, 28 Jul 2005 15:20:03 +0200
changeset 16946 7f9a7fe413f3
parent 16945 5d3ae25673a8
child 16947 c6a90f04924e
typ_match, unify: canonical argument order; added raw_match, raw_instance; proper implementation of raw_unify;
src/Pure/type.ML
--- a/src/Pure/type.ML	Thu Jul 28 15:20:02 2005 +0200
+++ b/src/Pure/type.ML	Thu Jul 28 15:20:03 2005 +0200
@@ -51,11 +51,13 @@
   exception TYPE_MATCH
   type tyenv
   val lookup: tyenv * (indexname * sort) -> typ option
-  val typ_match: tsig -> tyenv * (typ * typ) -> tyenv
+  val typ_match: tsig -> typ * typ -> tyenv -> tyenv
   val typ_instance: tsig -> typ * typ -> bool
+  val raw_match: typ * typ -> tyenv -> tyenv
+  val raw_instance: typ * typ -> bool
   exception TUNIFY
-  val unify: tsig -> tyenv * int -> typ * typ -> tyenv * int
-  val raw_unify: typ * typ -> bool
+  val unify: tsig -> typ * typ -> tyenv * int -> tyenv * int
+  val raw_unify: typ * typ -> tyenv -> tyenv
   val eq_type: tyenv -> typ * typ -> bool
 
   (*extend and merge type signatures*)
@@ -213,10 +215,7 @@
 (* varify, unvarify *)
 
 val varifyT = map_type_tfree (fn (a, S) => TVar ((a, 0), S));
-
-fun unvarifyT (Type (a, Ts)) = Type (a, map unvarifyT Ts)
-  | unvarifyT (TVar ((a, 0), S)) = TFree (a, S)
-  | unvarifyT T = T;
+val unvarifyT = map_type_tvar (fn ((a, 0), S) => TFree (a, S) | v => TVar v);
 
 fun varify (t, fixed) =
   let
@@ -224,9 +223,9 @@
     val ixns = add_term_tvar_ixns (t, []);
     val fmap = fs ~~ map (rpair 0) (variantlist (map fst fs, map #1 ixns))
     fun thaw (f as (a, S)) =
-      (case assoc (fmap, f) of
+      (case gen_assoc (op =) (fmap, f) of
         NONE => TFree f
-      | SOME b => TVar (b, S));
+      | SOME xi => TVar (xi, S));
   in (map_term_types (map_type_tfree thaw) t, fmap) end;
 
 
@@ -294,7 +293,7 @@
 
 exception TYPE_MATCH;
 
-fun typ_match tsig (tyenv, TU) =
+fun typ_match tsig =
   let
     fun match (TVar (v, S), T) subs =
           (case lookup (subs, (v, S)) of
@@ -310,10 +309,27 @@
       | match _ _ = raise TYPE_MATCH
     and matches (T :: Ts, U :: Us) subs = matches (Ts, Us) (match (T, U) subs)
       | matches _ subs = subs;
-  in match TU tyenv end;
+  in match end;
 
 fun typ_instance tsig (T, U) =
-  (typ_match tsig (Vartab.empty, (U, T)); true) handle TYPE_MATCH => false;
+  (typ_match tsig (U, T) Vartab.empty; true) handle TYPE_MATCH => false;
+
+(*purely structural matching*)
+fun raw_match (TVar (v, S), T) subs =
+      (case lookup (subs, (v, S)) of
+        NONE => Vartab.update_new ((v, (S, T)), subs)
+      | SOME U => if U = T then subs else raise TYPE_MATCH)
+  | raw_match (Type (a, Ts), Type (b, Us)) subs =
+      if a <> b then raise TYPE_MATCH
+      else raw_matches (Ts, Us) subs
+  | raw_match (TFree x, TFree y) subs =
+      if x = y then subs else raise TYPE_MATCH
+  | raw_match _ _ = raise TYPE_MATCH
+and raw_matches (T :: Ts, U :: Us) subs = raw_matches (Ts, Us) (raw_match (T, U) subs)
+  | raw_matches _ subs = subs;
+
+fun raw_instance (T, U) =
+  (raw_match (U, T) Vartab.empty; true) handle TYPE_MATCH => false;
 
 
 (* unification *)
@@ -339,7 +355,7 @@
       | NONE => T)
   | devar tye T = T;
 
-fun unify (tsig as TSig {classes = (_, classes), arities, ...}) (tyenv, maxidx) TU =
+fun unify (tsig as TSig {classes = (_, classes), arities, ...}) TU (tyenv, maxidx) =
   let
     val tyvar_count = ref maxidx;
     fun gen_tyvar S = TVar (("'a", inc tyvar_count), S);
@@ -386,10 +402,26 @@
       | unifs _ tye = tye;
   in (unif TU tyenv, ! tyvar_count) end;
 
-(*purely structural unification *)
-fun raw_unify (ty1, ty2) =
-  (unify empty_tsig (Vartab.empty, 0) (strip_sorts ty1, strip_sorts ty2); true)
-    handle TUNIFY => false;
+(*purely structural unification*)
+fun raw_unify (ty1, ty2) tye =
+  (case (devar tye ty1, devar tye ty2) of
+    (T as TVar (v, S1), U as TVar (w, S2)) =>
+      if eq_ix (v, w) then
+        if S1 = S2 then tye else tvar_clash v S1 S2
+      else Vartab.update_new ((w, (S2, T)), tye)
+  | (TVar (v, S), T) =>
+      if occurs v tye T then raise TUNIFY
+      else Vartab.update_new ((v, (S, T)), tye)
+  | (T, TVar (v, S)) =>
+      if occurs v tye T then raise TUNIFY
+      else Vartab.update_new ((v, (S, T)), tye)
+  | (Type (a, Ts), Type (b, Us)) =>
+      if a <> b then raise TUNIFY
+      else raw_unifys (Ts, Us) tye
+  | (T, U) => if T = U then tye else raise TUNIFY)
+and raw_unifys (T :: Ts, U :: Us) tye = raw_unifys (Ts, Us) (raw_unify (T, U) tye)
+  | raw_unifys _ tye = tye;
+
 
 (*check whether two types are equal with respect to a type environment*)
 fun eq_type tye (T, T') =
@@ -527,7 +559,7 @@
       Graph.merge_trans_acyclic (op =) (classes1, classes2)
         handle Graph.DUPS cs => err_dup_classes cs
           | Graph.CYCLES css => err_cyclic_classes pp css;
-  in (space, classes) end;    
+  in (space, classes) end;
 
 end;
 
@@ -578,7 +610,7 @@
   let
     fun err msg =
       error (msg ^ "\nThe error(s) above occurred in type abbreviation: " ^ quote a);
-    val rhs' = compress_type (strip_sorts (no_tvars (cert_typ_syntax tsig rhs)))
+    val rhs' = strip_sorts (no_tvars (cert_typ_syntax tsig rhs))
       handle TYPE (msg, _, _) => err msg;
   in
     (case duplicates vs of