simplified Type_Infer: eliminated separate datatypes pretyp/preterm -- only assign is_paramT TVars;
authorwenzelm
Mon, 13 Sep 2010 11:35:55 +0200
changeset 39294 27fae73fe769
parent 39293 651e5a3e8cfd
child 39295 6e8b0672c6a2
simplified Type_Infer: eliminated separate datatypes pretyp/preterm -- only assign is_paramT TVars;
src/Pure/type_infer.ML
--- a/src/Pure/type_infer.ML	Mon Sep 13 00:10:29 2010 +0200
+++ b/src/Pure/type_infer.ML	Mon Sep 13 11:35:55 2010 +0200
@@ -8,6 +8,7 @@
 sig
   val anyT: sort -> typ
   val is_param: indexname -> bool
+  val is_paramT: typ -> bool
   val param: int -> string * sort -> typ
   val paramify_vars: typ -> typ
   val paramify_dummies: typ -> int -> typ * int
@@ -28,8 +29,14 @@
 (* type inference parameters -- may get instantiated *)
 
 fun is_param (x, _: int) = String.isPrefix "?" x;
+
+fun is_paramT (TVar (xi, _)) = is_param xi
+  | is_paramT _ = false;
+
 fun param i (x, S) = TVar (("?" ^ x, i), S);
 
+fun mk_param i S = TVar (("?'a", i), S);
+
 val paramify_vars =
   Same.commit
     (Term_Subst.map_atypsT_same
@@ -62,76 +69,42 @@
 
 
 
-(** pretyps and preterms **)
-
-datatype pretyp =
-  PType of string * pretyp list |
-  PTFree of string * sort |
-  PTVar of indexname * sort |
-  Param of int * sort;
-
-datatype preterm =
-  PConst of string * pretyp |
-  PFree of string * pretyp |
-  PVar of indexname * pretyp |
-  PBound of int |
-  PAbs of string * pretyp * preterm |
-  PAppl of preterm * preterm;
-
-
-(* utils *)
+(** prepare types/terms: create inference parameters **)
 
-fun deref tye (T as Param (i, S)) =
-      (case Inttab.lookup tye i of
-        NONE => T
-      | SOME U => deref tye U)
-  | deref tye T = T;
+(* prepare_typ *)
 
-fun fold_pretyps f (PConst (_, T)) x = f T x
-  | fold_pretyps f (PFree (_, T)) x = f T x
-  | fold_pretyps f (PVar (_, T)) x = f T x
-  | fold_pretyps _ (PBound _) x = x
-  | fold_pretyps f (PAbs (_, T, t)) x = fold_pretyps f t (f T x)
-  | fold_pretyps f (PAppl (t, u)) x = fold_pretyps f u (fold_pretyps f t x);
-
-
-
-(** raw typs/terms to pretyps/preterms **)
-
-(* pretyp_of *)
-
-fun pretyp_of typ params_idx =
+fun prepare_typ typ params_idx =
   let
     val (params', idx) = fold_atyps
       (fn TVar (xi as (x, _), S) =>
           (fn ps_idx as (ps, idx) =>
             if is_param xi andalso not (Vartab.defined ps xi)
-            then (Vartab.update (xi, Param (idx, S)) ps, idx + 1) else ps_idx)
+            then (Vartab.update (xi, mk_param idx S) ps, idx + 1) else ps_idx)
         | _ => I) typ params_idx;
 
-    fun pre_of (TVar (v as (xi, _))) idx =
+    fun prepare (T as Type (a, Ts)) idx =
+          if T = dummyT then (mk_param idx [], idx + 1)
+          else
+            let val (Ts', idx') = fold_map prepare Ts idx
+            in (Type (a, Ts'), idx') end
+      | prepare (T as TVar (xi, _)) idx =
           (case Vartab.lookup params' xi of
-            NONE => PTVar v
+            NONE => T
           | SOME p => p, idx)
-      | pre_of (TFree ("'_dummy_", S)) idx = (Param (idx, S), idx + 1)
-      | pre_of (TFree v) idx = (PTFree v, idx)
-      | pre_of (T as Type (a, Ts)) idx =
-          if T = dummyT then (Param (idx, []), idx + 1)
-          else
-            let val (Ts', idx') = fold_map pre_of Ts idx
-            in (PType (a, Ts'), idx') end;
+      | prepare (TFree ("'_dummy_", S)) idx = (mk_param idx S, idx + 1)
+      | prepare (T as TFree _) idx = (T, idx);
 
-    val (ptyp, idx') = pre_of typ idx;
-  in (ptyp, (params', idx')) end;
+    val (typ', idx') = prepare typ idx;
+  in (typ', (params', idx')) end;
 
 
-(* preterm_of *)
+(* prepare_term *)
 
-fun preterm_of const_type tm (vparams, params, idx) =
+fun prepare_term const_type tm (vparams, params, idx) =
   let
     fun add_vparm xi (ps_idx as (ps, idx)) =
       if not (Vartab.defined ps xi) then
-        (Vartab.update (xi, Param (idx, [])) ps, idx + 1)
+        (Vartab.update (xi, mk_param idx []) ps, idx + 1)
       else ps_idx;
 
     val (vparams', idx') = fold_aterms
@@ -142,109 +115,96 @@
       tm (vparams, idx);
     fun var_param xi = the (Vartab.lookup vparams' xi);
 
-    fun polyT_of T idx = apsnd snd (pretyp_of (paramify_vars T) (Vartab.empty, idx));
+    fun polyT_of T idx = apsnd snd (prepare_typ (paramify_vars T) (Vartab.empty, idx));
 
     fun constraint T t ps =
       if T = dummyT then (t, ps)
       else
-        let val (T', ps') = pretyp_of T ps
-        in (PAppl (PConst ("_type_constraint_", PType ("fun", [T', T'])), t), ps') end;
+        let val (T', ps') = prepare_typ T ps
+        in (Type.constraint T' t, ps') end;
 
-    fun pre_of (Const (c, T)) (ps, idx) =
+    fun prepare (Const ("_type_constraint_", T) $ t) ps_idx =
+          let
+            val (T', ps_idx') = prepare_typ T ps_idx;
+            val (t', ps_idx'') = prepare t ps_idx';
+          in (Const ("_type_constraint_", T') $ t', ps_idx'') end
+      | prepare (Const (c, T)) (ps, idx) =
           (case const_type c of
             SOME U =>
-              let val (pU, idx') = polyT_of U idx
-              in constraint T (PConst (c, pU)) (ps, idx') end
+              let val (U', idx') = polyT_of U idx
+              in constraint T (Const (c, U')) (ps, idx') end
           | NONE => error ("Undeclared constant: " ^ quote c))
-      | pre_of (Const ("_type_constraint_", T) $ t) ps_idx =
+      | prepare (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) =
+          let val (T', idx') = polyT_of T idx
+          in (Var (xi, T'), (ps, idx')) end
+      | prepare (Var (xi, T)) ps_idx = constraint T (Var (xi, var_param xi)) ps_idx
+      | prepare (Free (x, T)) ps_idx = constraint T (Free (x, var_param (x, ~1))) ps_idx
+      | prepare (Bound i) ps_idx = (Bound i, ps_idx)
+      | prepare (Abs (x, T, t)) ps_idx =
           let
-            val (T', ps_idx') = pretyp_of T ps_idx;
-            val (t', ps_idx'') = pre_of t ps_idx';
-          in (PAppl (PConst ("_type_constraint_", T'), t'), ps_idx'') end
-      | pre_of (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) =
-          let val (pT, idx') = polyT_of T idx
-          in (PVar (xi, pT), (ps, idx')) end
-      | pre_of (Var (xi, T)) ps_idx = constraint T (PVar (xi, var_param xi)) ps_idx
-      | pre_of (Free (x, T)) ps_idx = constraint T (PFree (x, var_param (x, ~1))) ps_idx
-      | pre_of (Bound i) ps_idx = (PBound i, ps_idx)
-      | pre_of (Abs (x, T, t)) ps_idx =
+            val (T', ps_idx') = prepare_typ T ps_idx;
+            val (t', ps_idx'') = prepare t ps_idx';
+          in (Abs (x, T', t'), ps_idx'') end
+      | prepare (t $ u) ps_idx =
           let
-            val (T', ps_idx') = pretyp_of T ps_idx;
-            val (t', ps_idx'') = pre_of t ps_idx';
-          in (PAbs (x, T', t'), ps_idx'') end
-      | pre_of (t $ u) ps_idx =
-          let
-            val (t', ps_idx') = pre_of t ps_idx;
-            val (u', ps_idx'') = pre_of u ps_idx';
-          in (PAppl (t', u'), ps_idx'') end;
+            val (t', ps_idx') = prepare t ps_idx;
+            val (u', ps_idx'') = prepare u ps_idx';
+          in (t' $ u', ps_idx'') end;
 
-    val (tm', (params', idx'')) = pre_of tm (params, idx');
+    val (tm', (params', idx'')) = prepare tm (params, idx');
   in (tm', (vparams', params', idx'')) end;
 
 
 
-(** pretyps/terms to typs/terms **)
+(** finish types/terms: standardize remaining parameters **)
 
-(* add_parms *)
+(* dereferenced views *)
 
-fun add_parmsT tye T =
+fun deref tye (T as TVar (xi, _)) =
+      (case Vartab.lookup tye xi of
+        NONE => T
+      | SOME U => deref tye U)
+  | deref tye T = T;
+
+fun add_parms tye T =
   (case deref tye T of
-    PType (_, Ts) => fold (add_parmsT tye) Ts
-  | Param (i, _) => insert (op =) i
+    Type (_, Ts) => fold (add_parms tye) Ts
+  | TVar (xi, _) => if is_param xi then insert (op =) xi else I
   | _ => I);
 
-fun add_parms tye = fold_pretyps (add_parmsT tye);
-
-
-(* add_names *)
-
-fun add_namesT tye T =
+fun add_names tye T =
   (case deref tye T of
-    PType (_, Ts) => fold (add_namesT tye) Ts
-  | PTFree (x, _) => Name.declare x
-  | PTVar ((x, _), _) => Name.declare x
-  | Param _ => I);
-
-fun add_names tye = fold_pretyps (add_namesT tye);
+    Type (_, Ts) => fold (add_names tye) Ts
+  | TFree (x, _) => Name.declare x
+  | TVar ((x, i), _) => if is_param (x, i) then I else Name.declare x);
 
 
-(* simple_typ/term_of *)
-
-fun simple_typ_of tye f T =
-  (case deref tye T of
-    PType (a, Ts) => Type (a, map (simple_typ_of tye f) Ts)
-  | PTFree v => TFree v
-  | PTVar v => TVar v
-  | Param (i, S) => TVar (f i, S));
+(* finish *)
 
-fun simple_term_of tye f (PConst (c, T)) = Const (c, simple_typ_of tye f T)
-  | simple_term_of tye f (PFree (x, T)) = Free (x, simple_typ_of tye f T)
-  | simple_term_of tye f (PVar (xi, T)) = Var (xi, simple_typ_of tye f T)
-  | simple_term_of tye f (PBound i) = Bound i
-  | simple_term_of tye f (PAbs (x, T, t)) =
-      Abs (x, simple_typ_of tye f T, simple_term_of tye f t)
-  | simple_term_of tye f (PAppl (t, u)) =
-      simple_term_of tye f t $ simple_term_of tye f u;
-
+fun finish ctxt tye (Ts, ts) =
+  let
+    val used =
+      (fold o fold_types) (add_names tye) ts (fold (add_names tye) Ts (Variable.names_of ctxt));
+    val parms = rev ((fold o fold_types) (add_parms tye) ts (fold (add_parms tye) Ts []));
+    val names = Name.invents used ("?" ^ Name.aT) (length parms);
+    val tab = Vartab.make (parms ~~ names);
+    val idx = Variable.maxidx_of ctxt + 1;
 
-(* typs_terms_of *)
-
-fun typs_terms_of ctxt tye (Ts, ts) =
-  let
-    val used = fold (add_names tye) ts (fold (add_namesT tye) Ts (Variable.names_of ctxt));
-    val parms = rev (fold (add_parms tye) ts (fold (add_parmsT tye) Ts []));
-    val names = Name.invents used ("?" ^ Name.aT) (length parms);
-    val tab = Inttab.make (parms ~~ names);
-
-    val maxidx = Variable.maxidx_of ctxt;
-    fun f i = (the (Inttab.lookup tab i), maxidx + 1);
-  in (map (simple_typ_of tye f) Ts, map (Type.strip_constraints o simple_term_of tye f) ts) end;
+    fun finish_typ T =
+      (case deref tye T of
+        Type (a, Ts) => Type (a, map finish_typ Ts)
+      | U as TFree _ => U
+      | U as TVar (xi, S) =>
+          (case Vartab.lookup tab xi of
+            NONE => U
+          | SOME a => TVar ((a, idx), S)));
+  in (map finish_typ Ts, map (Type.strip_constraints o Term.map_types finish_typ) ts) end;
 
 
 
 (** order-sorted unification of types **)
 
-exception NO_UNIFIER of string * pretyp Inttab.table;
+exception NO_UNIFIER of string * typ Vartab.table;
 
 fun unify ctxt pp =
   let
@@ -259,17 +219,15 @@
         Syntax.string_of_sort ctxt S;
 
     fun meet (_, []) tye_idx = tye_idx
-      | meet (Param (i, S'), S) (tye_idx as (tye, idx)) =
-          if Sign.subsort thy (S', S) then tye_idx
-          else (Inttab.update_new (i,
-            Param (idx, Sign.inter_sort thy (S', S))) tye, idx + 1)
-      | meet (PType (a, Ts), S) (tye_idx as (tye, _)) =
+      | meet (Type (a, Ts), S) (tye_idx as (tye, _)) =
           meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
-      | meet (PTFree (x, S'), S) (tye_idx as (tye, _)) =
+      | meet (TFree (x, S'), S) (tye_idx as (tye, _)) =
           if Sign.subsort thy (S', S) then tye_idx
           else raise NO_UNIFIER (not_of_sort x S' S, tye)
-      | meet (PTVar (xi, S'), S) (tye_idx as (tye, _)) =
+      | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
           if Sign.subsort thy (S', S) then tye_idx
+          else if is_param xi then
+            (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
           else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
     and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
           meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
@@ -278,20 +236,20 @@
 
     (* occurs check and assignment *)
 
-    fun occurs_check tye i (Param (i', S)) =
-          if i = i' then raise NO_UNIFIER ("Occurs check!", tye)
+    fun occurs_check tye xi (TVar (xi', S)) =
+          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
           else
-            (case Inttab.lookup tye i' of
+            (case Vartab.lookup tye xi' of
               NONE => ()
-            | SOME T => occurs_check tye i T)
-      | occurs_check tye i (PType (_, Ts)) = List.app (occurs_check tye i) Ts
+            | SOME T => occurs_check tye xi T)
+      | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
       | occurs_check _ _ _ = ();
 
-    fun assign i (T as Param (i', _)) S tye_idx =
-          if i = i' then tye_idx
-          else tye_idx |> meet (T, S) |>> Inttab.update_new (i, T)
-      | assign i T S (tye_idx as (tye, _)) =
-          (occurs_check tye i T; tye_idx |> meet (T, S) |>> Inttab.update_new (i, T));
+    fun assign xi (T as TVar (xi', _)) S env =
+          if xi = xi' then env
+          else env |> meet (T, S) |>> Vartab.update_new (xi, T)
+      | assign xi T S (env as (tye, _)) =
+          (occurs_check tye xi T; env |> meet (T, S) |>> Vartab.update_new (xi, T));
 
 
     (* unification *)
@@ -299,16 +257,16 @@
     fun show_tycon (a, Ts) =
       quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
 
-    fun unif (T1, T2) (tye_idx as (tye, idx)) =
-      (case (deref tye T1, deref tye T2) of
-        (Param (i, S), T) => assign i T S tye_idx
-      | (T, Param (i, S)) => assign i T S tye_idx
-      | (PType (a, Ts), PType (b, Us)) =>
+    fun unif (T1, T2) (env as (tye, _)) =
+      (case pairself (`is_paramT o deref tye) (T1, T2) of
+        ((true, TVar (xi, S)), (_, T)) => assign xi T S env
+      | ((_, T), (true, TVar (xi, S))) => assign xi T S env
+      | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
           if a <> b then
             raise NO_UNIFIER
               ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
-          else fold unif (Ts ~~ Us) tye_idx
-      | (T, U) => if T = U then tye_idx else raise NO_UNIFIER ("", tye));
+          else fold unif (Ts ~~ Us) env
+      | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye));
 
   in unif end;
 
@@ -327,7 +285,7 @@
 
     fun prep_output tye bs ts Ts =
       let
-        val (Ts_bTs', ts') = typs_terms_of ctxt tye (Ts @ map snd bs, ts);
+        val (Ts_bTs', ts') = finish ctxt tye (Ts @ map snd bs, ts);
         val (Ts', Ts'') = chop (length Ts) Ts_bTs';
         fun prep t =
           let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
@@ -346,21 +304,20 @@
 
     (* main *)
 
-    fun inf _ (PConst (_, T)) tye_idx = (T, tye_idx)
-      | inf _ (PFree (_, T)) tye_idx = (T, tye_idx)
-      | inf _ (PVar (_, T)) tye_idx = (T, tye_idx)
-      | inf bs (PBound i) tye_idx =
+    fun inf _ (Const (_, T)) tye_idx = (T, tye_idx)
+      | inf _ (Free (_, T)) tye_idx = (T, tye_idx)
+      | inf _ (Var (_, T)) tye_idx = (T, tye_idx)
+      | inf bs (Bound i) tye_idx =
           (snd (nth bs i handle Subscript => err_loose i), tye_idx)
-      | inf bs (PAbs (x, T, t)) tye_idx =
+      | inf bs (Abs (x, T, t)) tye_idx =
           let val (U, tye_idx') = inf ((x, T) :: bs) t tye_idx
-          in (PType ("fun", [T, U]), tye_idx') end
-      | inf bs (PAppl (t, u)) tye_idx =
+          in (T --> U, tye_idx') end
+      | inf bs (t $ u) tye_idx =
           let
             val (T, tye_idx') = inf bs t tye_idx;
             val (U, (tye, idx)) = inf bs u tye_idx';
-            val V = Param (idx, []);
-            val U_to_V = PType ("fun", [U, V]);
-            val tye_idx'' = unify ctxt pp (U_to_V, T) (tye, idx + 1)
+            val V = mk_param idx [];
+            val tye_idx'' = unify ctxt pp (U --> V, T) (tye, idx + 1)
               handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
           in (V, tye_idx'') end;
 
@@ -381,11 +338,11 @@
     (*convert to preterms*)
     val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
     val (ts', (_, _, idx)) =
-      fold_map (preterm_of const_type o constrain_vars) ts
+      fold_map (prepare_term const_type o constrain_vars) ts
       (Vartab.empty, Vartab.empty, 0);
 
     (*do type inference*)
-    val (tye, _) = fold (snd oo infer ctxt) ts' (Inttab.empty, idx);
-  in #2 (typs_terms_of ctxt tye ([], ts')) end;
+    val (tye, _) = fold (snd oo infer ctxt) ts' (Vartab.empty, idx);
+  in #2 (finish ctxt tye ([], ts')) end;
 
 end;