local coercion insertion algorithm to support complex coercions
authortraytel
Wed, 17 Aug 2011 19:49:07 +0200
changeset 45060 9c2568c0a504
parent 45059 28d3e387f22e
child 45061 39519609abe0
local coercion insertion algorithm to support complex coercions
src/Tools/subtyping.ML
--- a/src/Tools/subtyping.ML	Wed Aug 17 19:49:07 2011 +0200
+++ b/src/Tools/subtyping.ML	Wed Aug 17 19:49:07 2011 +0200
@@ -22,46 +22,52 @@
 datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;
 
 datatype data = Data of
-  {coes: (term * term list) Symreltab.table,  (*coercions table*)
-   coes_graph: unit Graph.T,  (*coercions graph*)
+  {coes: (term * ((typ list * typ list) * term list)) Symreltab.table,  (*coercions table*)
+   (*full coercions graph - only used at coercion declaration/deletion*)
+   full_graph: int Graph.T,
+   (*coercions graph restricted to base types - for efficiency reasons strored in the context*)
+   coes_graph: int Graph.T,
    tmaps: (term * variance list) Symtab.table};  (*map functions*)
 
-fun make_data (coes, coes_graph, tmaps) =
-  Data {coes = coes, coes_graph = coes_graph, tmaps = tmaps};
+fun make_data (coes, full_graph, coes_graph, tmaps) =
+  Data {coes = coes, full_graph = full_graph, coes_graph = coes_graph, tmaps = tmaps};
 
 structure Data = Generic_Data
 (
   type T = data;
-  val empty = make_data (Symreltab.empty, Graph.empty, Symtab.empty);
+  val empty = make_data (Symreltab.empty, Graph.empty, Graph.empty, Symtab.empty);
   val extend = I;
   fun merge
-    (Data {coes = coes1, coes_graph = coes_graph1, tmaps = tmaps1},
-      Data {coes = coes2, coes_graph = coes_graph2, tmaps = tmaps2}) =
-    make_data (Symreltab.merge (eq_pair (op aconv) (eq_list (op aconv))) (coes1, coes2),
+    (Data {coes = coes1, full_graph = full_graph1, coes_graph = coes_graph1, tmaps = tmaps1},
+      Data {coes = coes2, full_graph = full_graph2, coes_graph = coes_graph2, tmaps = tmaps2}) =
+    make_data (Symreltab.merge (eq_pair (op aconv)
+        (eq_pair (eq_pair (eq_list (op =)) (eq_list (op =))) (eq_list (op aconv))))
+        (coes1, coes2),
+      Graph.merge (op =) (full_graph1, full_graph2),
       Graph.merge (op =) (coes_graph1, coes_graph2),
       Symtab.merge (eq_pair (op aconv) (op =)) (tmaps1, tmaps2));
 );
 
 fun map_data f =
-  Data.map (fn Data {coes, coes_graph, tmaps} =>
-    make_data (f (coes, coes_graph, tmaps)));
+  Data.map (fn Data {coes, full_graph, coes_graph, tmaps} =>
+    make_data (f (coes, full_graph, coes_graph, tmaps)));
 
 fun map_coes f =
-  map_data (fn (coes, coes_graph, tmaps) =>
-    (f coes, coes_graph, tmaps));
+  map_data (fn (coes, full_graph, coes_graph, tmaps) =>
+    (f coes, full_graph, coes_graph, tmaps));
 
 fun map_coes_graph f =
-  map_data (fn (coes, coes_graph, tmaps) =>
-    (coes, f coes_graph, tmaps));
+  map_data (fn (coes, full_graph, coes_graph, tmaps) =>
+    (coes, full_graph, f coes_graph, tmaps));
 
-fun map_coes_and_graph f =
-  map_data (fn (coes, coes_graph, tmaps) =>
-    let val (coes', coes_graph') = f (coes, coes_graph);
-    in (coes', coes_graph', tmaps) end);
+fun map_coes_and_graphs f =
+  map_data (fn (coes, full_graph, coes_graph, tmaps) =>
+    let val (coes', full_graph', coes_graph') = f (coes, full_graph, coes_graph);
+    in (coes', full_graph', coes_graph', tmaps) end);
 
 fun map_tmaps f =
-  map_data (fn (coes, coes_graph, tmaps) =>
-    (coes, coes_graph, f tmaps));
+  map_data (fn (coes, full_graph, coes_graph, tmaps) =>
+    (coes, full_graph, coes_graph, f tmaps));
 
 val rep_data = (fn Data args => args) o Data.get o Context.Proof;
 
@@ -73,6 +79,9 @@
 
 (** utils **)
 
+fun restrict_graph G =
+  Graph.subgraph (fn key => if Graph.get_node G key = 0 then true else false) G;
+
 fun nameT (Type (s, [])) = s;
 fun t_of s = Type (s, []);
 
@@ -88,10 +97,23 @@
 val is_funtype = fn (Type ("fun", [_, _])) => true | _ => false;
 val is_identity = fn (Abs (_, _, Bound 0)) => true | _ => false;
 
+fun instantiate t Ts = Term.subst_TVars
+  ((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts) t;
+
+exception COERCION_GEN_ERROR of unit -> string;
+
+fun inst_collect tye err T U =
+  (case (T, Type_Infer.deref tye U) of
+    (TVar (xi, S), U) => [(xi, U)]
+  | (Type (a, Ts), Type (b, Us)) =>
+      if a <> b then raise error (err ()) else inst_collects tye err Ts Us
+  | (_, U') => if T <> U' then error (err ()) else [])
+and inst_collects tye err Ts Us =
+  fold2 (fn T => fn U => fn is => inst_collect tye err T U @ is) Ts Us [];
+
 
 (* unification *)
 
-exception TYPE_INFERENCE_ERROR of unit -> string;
 exception NO_UNIFIER of string * typ Vartab.table;
 
 fun unify weak ctxt =
@@ -180,16 +202,20 @@
 
 (* Graph shortcuts *)
 
-fun maybe_new_node s G = perhaps (try (Graph.new_node (s, ()))) G
+fun maybe_new_node s G = perhaps (try (Graph.new_node s)) G
 fun maybe_new_nodes ss G = fold maybe_new_node ss G
 
 
 
 (** error messages **)
 
+infixr ++> (* lazy error msg composition *)
+
+fun err ++> str = err #> suffix str
+
 fun gen_msg err msg =
-  err () ^ "\nNow trying to infer coercions:\n\nCoercion inference failed" ^
-  (if msg = "" then "" else ": " ^ msg) ^ "\n";
+  err () ^ "\nNow trying to infer coercions globally.\n\nCoercion inference failed" ^
+  (if msg = "" then "" else ":\n" ^ msg) ^ "\n";
 
 fun prep_output ctxt tye bs ts Ts =
   let
@@ -213,7 +239,7 @@
   let
     val (_, Ts') = prep_output ctxt tye [] [] Ts;
     val text =
-      msg ^ "\n" ^ "Cannot unify a list of types that should be the same:" ^ "\n" ^
+      msg ^ "\nCannot unify a list of types that should be the same:\n" ^
         Pretty.string_of (Pretty.list "[" "]" (map (Syntax.pretty_typ ctxt) Ts'));
   in
     error text
@@ -226,13 +252,14 @@
         let val (t', T') = prep_output ctxt tye bs [t, u] [U', U]
         in (t' :: ts, T' :: Ts) end)
       packs ([], []);
-    val text = cat_lines ([msg, "Cannot fulfil subtype constraints:"] @
-        (map2 (fn [t, u] => fn [T, U] => Pretty.string_of (
+    val text = msg ^ "\n" ^ Pretty.string_of (
+        Pretty.big_list "Cannot fulfil subtype constraints:"
+        (map2 (fn [t, u] => fn [T, U] =>
           Pretty.block [
             Syntax.pretty_typ ctxt T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2,
             Syntax.pretty_typ ctxt U, Pretty.brk 3,
             Pretty.str "from function application", Pretty.brk 2,
-            Pretty.block [Syntax.pretty_term ctxt (t $ u)]]))
+            Pretty.block [Syntax.pretty_term ctxt (t $ u)]])
         ts Ts))
   in
     error text
@@ -258,10 +285,11 @@
             val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
             val U = Type_Infer.mk_param idx [];
             val V = Type_Infer.mk_param (idx + 1) [];
-            val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
+            val tye_idx'' = strong_unify ctxt (U --> V, T) (tye, idx + 2)
               handle NO_UNIFIER (msg, _) => error (gen_msg err msg);
             val error_pack = (bs, t $ u, U, V, U');
-          in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
+          in (V, (tye_idx''),
+            ((U', U), error_pack) :: cs'') end;
   in
     gen [] []
   end;
@@ -316,7 +344,7 @@
                   handle NO_UNIFIER (msg, _) =>
                     err_list ctxt (gen_msg err
                       "failed to unify invariant arguments w.r.t. to the known map function" ^ msg)
-                      (fst tye_idx) Ts)
+                      (fst tye_idx) (T :: Ts))
               | INVARIANT => (cs, strong_unify ctxt constraint tye_idx
                   handle NO_UNIFIER (msg, _) =>
                     error (gen_msg err ("failed to unify invariant arguments" ^ msg))));
@@ -362,10 +390,11 @@
         and simplify done [] tye_idx = (done, tye_idx)
           | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
               (case (Type_Infer.deref tye T, Type_Infer.deref tye U) of
-                (Type (a, []), Type (b, [])) =>
+                (T1 as Type (a, []), T2 as Type (b, [])) =>
                   if a = b then simplify done todo tye_idx
                   else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
-                  else error (gen_msg err (a ^ " is not a subtype of " ^ b))
+                  else error (gen_msg err (quote (Syntax.string_of_typ ctxt T1) ^
+                    " is not a subtype of " ^ quote (Syntax.string_of_typ ctxt T2)))
               | (Type (a, Ts), Type (b, Us)) =>
                   if a <> b then error (gen_msg err "different constructors")
                     (fst tye_idx) error_pack
@@ -538,7 +567,8 @@
                 in
                   if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
                   then apfst (Vartab.update (xi, T)) tye_idx
-                  else err_bound ctxt (gen_msg err ("assigned simple type " ^ s ^
+                  else err_bound ctxt (gen_msg err ("assigned base type " ^
+                    quote (Syntax.string_of_typ ctxt T) ^
                     " clashes with the upper bound of variable " ^
                     Syntax.string_of_typ ctxt (TVar(xi, S)))) tye (find_error_pack (not lower) key)
                 end
@@ -587,77 +617,105 @@
 
 (** coercion insertion **)
 
-fun gen_coercion ctxt tye (T1, T2) =
-  (case pairself (Type_Infer.deref tye) (T1, T2) of
-    ((Type (a, [])), (Type (b, []))) =>
-        if a = b
-        then Abs (Name.uu, Type (a, []), Bound 0)
-        else
-          (case Symreltab.lookup (coes_of ctxt) (a, b) of
-            NONE => raise Fail (a ^ " is not a subtype of " ^ b)
-          | SOME (co, _) => co)
-  | ((Type (a, Ts)), (Type (b, Us))) =>
-        if a <> b
-        then raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
-        else
-          let
-            fun inst t Ts =
-              Term.subst_vars
-                (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
-            fun sub_co (COVARIANT, TU) = SOME (gen_coercion ctxt tye TU)
-              | sub_co (CONTRAVARIANT, TU) = SOME (gen_coercion ctxt tye (swap TU))
-              | sub_co (INVARIANT_TO _, _) = NONE;
-            fun ts_of [] = []
-              | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
-          in
-            (case Symtab.lookup (tmaps_of ctxt) a of
-              NONE => raise Fail ("No map function for " ^ a ^ " known")
-            | SOME tmap =>
-                let
-                  val used_coes = map_filter sub_co ((snd tmap) ~~ (Ts ~~ Us));
-                in
-                  if null (filter (not o is_identity) used_coes)
-                  then Abs (Name.uu, Type (a, Ts), Bound 0)
-                  else Term.list_comb
-                    (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
-                end)
-          end
-  | (T, U) =>
-        if Type.could_unify (T, U)
-        then Abs (Name.uu, T, Bound 0)
-        else raise Fail ("Cannot generate coercion from "
-          ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U));
+fun gen_coercion ctxt err tye TU =
+  let
+    fun gen (T1, T2) = (case pairself (Type_Infer.deref tye) (T1, T2) of
+        (T1 as (Type (a, [])), T2 as (Type (b, []))) =>
+            if a = b
+            then Abs (Name.uu, Type (a, []), Bound 0)
+            else
+              (case Symreltab.lookup (coes_of ctxt) (a, b) of
+                NONE => raise COERCION_GEN_ERROR (err ++> quote (Syntax.string_of_typ ctxt T1) ^
+                  " is not a subtype of " ^ quote (Syntax.string_of_typ ctxt T2))
+              | SOME (co, _) => co)
+      | ((Type (a, Ts)), (Type (b, Us))) =>
+            if a <> b
+            then
+              (case Symreltab.lookup (coes_of ctxt) (a, b) of
+                (*immediate error - cannot fix complex coercion with the global algorithm*)
+                NONE => error (err () ^ "No coercion known for type constructors: " ^
+                  quote a ^ " and " ^ quote b)
+              | SOME (co, ((Ts', Us'), _)) =>
+                  let
+                    val co_before = gen (T1, Type (a, Ts'));
+                    val coT = range_type (fastype_of co_before);
+                    val insts = inst_collect tye (err ++> "Could not insert complex coercion")
+                      (domain_type (fastype_of co)) coT;
+                    val co' = Term.subst_TVars insts co;
+                    val co_after = gen (Type (b, (map (typ_subst_TVars insts) Us')), T2);
+                  in
+                    Abs (Name.uu, T1, Library.foldr (op $)
+                      (filter (not o is_identity) [co_after, co', co_before], Bound 0))
+                  end)
+            else
+              let
+                fun sub_co (COVARIANT, TU) = SOME (gen TU)
+                  | sub_co (CONTRAVARIANT, TU) = SOME (gen (swap TU))
+                  | sub_co (INVARIANT_TO _, _) = NONE;
+                fun ts_of [] = []
+                  | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
+              in
+                (case Symtab.lookup (tmaps_of ctxt) a of
+                  NONE => raise COERCION_GEN_ERROR
+                    (err ++> "No map function for " ^ quote a ^ " known")
+                | SOME tmap =>
+                    let
+                      val used_coes = map_filter sub_co ((snd tmap) ~~ (Ts ~~ Us));
+                    in
+                      if null (filter (not o is_identity) used_coes)
+                      then Abs (Name.uu, Type (a, Ts), Bound 0)
+                      else Term.list_comb
+                        (instantiate (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
+                    end)
+              end
+      | (T, U) =>
+            if Type.could_unify (T, U)
+            then Abs (Name.uu, T, Bound 0)
+            else raise COERCION_GEN_ERROR (err ++> "Cannot generate coercion from " ^
+              quote (Syntax.string_of_typ ctxt T) ^ " to " ^
+              quote (Syntax.string_of_typ ctxt U)));
+  in
+    gen TU
+  end;
 
-fun insert_coercions ctxt tye ts =
+fun function_of ctxt err tye T =
+  (case Type_Infer.deref tye T of
+    Type (C, Ts) =>
+      (case Symreltab.lookup (coes_of ctxt) (C, "fun") of
+        NONE => error (err () ^ "No complex coercion from " ^ quote C ^ " to fun")
+      | SOME (co, ((Ts', _), _)) =>
+        let
+          val co_before = gen_coercion ctxt err tye (Type (C, Ts), Type (C, Ts'));
+          val coT = range_type (fastype_of co_before);
+          val insts = inst_collect tye (err ++> "Could not insert complex coercion")
+            (domain_type (fastype_of co)) coT;
+          val co' = Term.subst_TVars insts co;
+        in
+          Abs (Name.uu, Type (C, Ts), Library.foldr (op $)
+            (filter (not o is_identity) [co', co_before], Bound 0))
+        end)
+  | T' => error (err () ^ "No complex coercion from " ^
+      quote (Syntax.string_of_typ ctxt T') ^ " to fun"));
+
+fun insert_coercions ctxt (tye, idx) ts =
   let
-    fun insert _ (Const (c, T)) =
-          let val T' = T;
-          in (Const (c, T'), T') end
-      | insert _ (Free (x, T)) =
-          let val T' = T;
-          in (Free (x, T'), T') end
-      | insert _ (Var (xi, T)) =
-          let val T' = T;
-          in (Var (xi, T'), T') end
+    fun insert _ (Const (c, T)) = (Const (c, T), T)
+      | insert _ (Free (x, T)) = (Free (x, T), T)
+      | insert _ (Var (xi, T)) = (Var (xi, T), T)
       | insert bs (Bound i) =
           let val T = nth bs i handle General.Subscript => err_loose i;
           in (Bound i, T) end
       | insert bs (Abs (x, T, t)) =
-          let
-            val T' = T;
-            val (t', T'') = insert (T' :: bs) t;
-          in
-            (Abs (x, T', t'), T' --> T'')
-          end
+          let val (t', T') = insert (T :: bs) t;
+          in (Abs (x, T, t'), T --> T') end
       | insert bs (t $ u) =
           let
-            val (t', Type ("fun", [U, T])) =
-              apsnd (Type_Infer.deref tye) (insert bs t);
+            val (t', Type ("fun", [U, T])) = apsnd (Type_Infer.deref tye) (insert bs t);
             val (u', U') = insert bs u;
           in
             if can (fn TU => strong_unify ctxt TU (tye, 0)) (U, U')
             then (t' $ u', T)
-            else (t' $ (gen_coercion ctxt tye (U', U) $ u'), T)
+            else (t' $ (gen_coercion ctxt (K "") tye (U', U) $ u'), T)
           end
   in
     map (fst o insert []) ts
@@ -686,24 +744,51 @@
             val V = Type_Infer.mk_param idx [];
             val (tu, tye_idx'') = (t' $ u', strong_unify ctxt (U --> V, T) (tye, idx + 1))
               handle NO_UNIFIER (msg, tye') =>
-                raise TYPE_INFERENCE_ERROR (err_appl_msg ctxt msg tye' bs t T u U);
+                let
+                  val err = err_appl_msg ctxt msg tye' bs t T u U;
+                  val W = Type_Infer.mk_param (idx + 1) [];
+                  val (t'', (tye', idx')) =
+                    (t', strong_unify ctxt (W --> V, T) (tye, idx + 2))
+                      handle NO_UNIFIER _ =>
+                        let
+                          val err' =
+                            err ++> "\nLocal coercion insertion on the operator failed:\n";
+                          val co = function_of ctxt err' tye T;
+                          val (t'', T'', tye_idx'') = inf bs (co $ t') (tye, idx + 2);
+                        in
+                          (t'', strong_unify ctxt (W --> V, T'') tye_idx''
+                             handle NO_UNIFIER (msg, _) => error (err' () ^ msg))
+                        end;
+                  val err' = err ++> (if t' aconv t'' then ""
+                    else "\nSuccessfully coerced the operand to a function of type:\n" ^
+                      Syntax.string_of_typ ctxt
+                        (the_single (snd (prep_output ctxt tye' bs [] [W --> V]))) ^ "\n") ^
+                      "\nLocal coercion insertion on the operand failed:\n";
+                  val co = gen_coercion ctxt err' tye' (U, W);
+                  val (u'', U', tye_idx') =
+                    inf bs (if is_identity co then u else co $ u) (tye', idx');
+                in
+                  (t'' $ u'', strong_unify ctxt (U', W) tye_idx'
+                    handle NO_UNIFIER (msg, _) => raise COERCION_GEN_ERROR (err' ++> msg))
+                end;
           in (tu, V, tye_idx'') end;
 
     fun infer_single t tye_idx =
-      let val (t, _, tye_idx') = inf [] t tye_idx;
+      let val (t, _, tye_idx') = inf [] t tye_idx
       in (t, tye_idx') end;
 
     val (ts', (tye, _)) = (fold_map infer_single ts (Vartab.empty, idx)
-      handle TYPE_INFERENCE_ERROR err =>
+      handle COERCION_GEN_ERROR err =>
         let
           fun gen_single t (tye_idx, constraints) =
-            let val (_, tye_idx', constraints') = generate_constraints ctxt err t tye_idx
+            let val (_, tye_idx', constraints') =
+              generate_constraints ctxt (err ++> "\n") t tye_idx
             in (tye_idx', constraints' @ constraints) end;
 
           val (tye_idx, constraints) = fold gen_single ts ((Vartab.empty, idx), []);
-          val (tye, idx) = process_constraints ctxt err constraints tye_idx;
+          val (tye, idx) = process_constraints ctxt (err ++> "\n") constraints tye_idx;
         in
-          (insert_coercions ctxt tye ts, (tye, idx))
+          (insert_coercions ctxt (tye, idx) ts, (tye, idx))
         end);
 
     val (_, ts'') = Type_Infer.finish ctxt tye ([], ts');
@@ -767,13 +852,22 @@
     map_tmaps (Symtab.update (snd res, (t, res_av))) context
   end;
 
-fun transitive_coercion tab G (a, b) =
+fun transitive_coercion ctxt tab G (a, b) =
   let
+    fun safe_app t (Abs (x, T', u)) =
+      let
+        val t' = map_types Type_Infer.paramify_vars t;
+      in
+        singleton (coercion_infer_types ctxt) (Abs(x, T', (t' $ u)))
+      end;
     val path = hd (Graph.irreducible_paths G (a, b));
     val path' = fst (split_last path) ~~ tl path;
     val coercions = map (fst o the o Symreltab.lookup tab) path';
-  in (Abs (Name.uu, Type (a, []),
-      fold (fn t => fn u => t $ u) coercions (Bound 0)), coercions)
+    val trans_co = singleton (Variable.polymorphic ctxt)
+      (fold safe_app coercions (Abs (Name.uu, dummyT, Bound 0)));
+    val (Ts, Us) = pairself (snd o Term.dest_Type) (Term.dest_funT (type_of trans_co))
+  in
+    (trans_co, ((Ts, Us), coercions))
   end;
 
 fun add_coercion raw_t context =
@@ -788,23 +882,19 @@
     val (T1, T2) = Term.dest_funT (fastype_of t)
       handle TYPE _ => err_coercion ();
 
-    val a =
-      (case T1 of
-        Type (x, []) => x
-      | _ => err_coercion ());
+    val (a, Ts) = Term.dest_Type T1
+      handle TYPE _ => err_coercion ();
 
-    val b =
-      (case T2 of
-        Type (x, []) => x
-      | _ => err_coercion ());
+    val (b, Us) = Term.dest_Type T2
+      handle TYPE _ => err_coercion ();
 
-    fun coercion_data_update (tab, G) =
+    fun coercion_data_update (tab, G, _) =
       let
-        val G' = maybe_new_nodes [a, b] G
+        val G' = maybe_new_nodes [(a, length Ts), (b, length Us)] G
         val G'' = Graph.add_edge_trans_acyclic (a, b) G'
           handle Graph.CYCLES _ => error (
-            Syntax.string_of_typ ctxt T1 ^ " is already a subtype of " ^
-            Syntax.string_of_typ ctxt T2 ^ "!\n\nCannot add coercion of type: " ^
+            Syntax.string_of_typ ctxt T2 ^ " is already a subtype of " ^
+            Syntax.string_of_typ ctxt T1 ^ "!\n\nCannot add coercion of type: " ^
             Syntax.string_of_typ ctxt (T1 --> T2));
         val new_edges =
           flat (Graph.dest G'' |> map (fn (x, ys) => ys |> map_filter (fn y =>
@@ -813,14 +903,14 @@
 
         val tab' = fold
           (fn pair => fn tab =>
-            Symreltab.update (pair, transitive_coercion tab G_and_new pair) tab)
+            Symreltab.update (pair, transitive_coercion ctxt tab G_and_new pair) tab)
           (filter (fn pair => pair <> (a, b)) new_edges)
-          (Symreltab.update ((a, b), (t, [])) tab);
+          (Symreltab.update ((a, b), (t, ((Ts, Us), []))) tab);
       in
-        (tab', G'')
+        (tab', G'', restrict_graph G'')
       end;
   in
-    map_coes_and_graph coercion_data_update context
+    map_coes_and_graphs coercion_data_update context
   end;
 
 fun delete_coercion raw_t context =
@@ -836,41 +926,37 @@
     val (T1, T2) = Term.dest_funT (fastype_of t)
       handle TYPE _ => err_coercion false;
 
-    val a =
-      (case T1 of
-        Type (x, []) => x
-      | _ => err_coercion false);
+    val (a, Ts) = dest_Type T1
+      handle TYPE _ => err_coercion false;
 
-    val b =
-      (case T2 of
-        Type (x, []) => x
-      | _ => err_coercion false);
+    val (b, Us) = dest_Type T2
+      handle TYPE _ => err_coercion false;
 
     fun delete_and_insert tab G =
       let
         val pairs =
-          Symreltab.fold (fn ((a, b), (_, ts)) => fn pairs =>
+          Symreltab.fold (fn ((a, b), (_, (_, ts))) => fn pairs =>
             if member (op aconv) ts t then (a, b) :: pairs else pairs) tab [(a, b)];
         fun delete pair (G, tab) = (Graph.del_edge pair G, Symreltab.delete_safe pair tab);
         val (G', tab') = fold delete pairs (G, tab);
         fun reinsert pair (G, xs) = (case (Graph.irreducible_paths G pair) of
               [] => (G, xs)
-            | _ => (Graph.add_edge pair G, (pair, transitive_coercion tab' G' pair) :: xs));
+            | _ => (Graph.add_edge pair G, (pair, transitive_coercion ctxt tab' G' pair) :: xs));
         val (G'', ins) = fold reinsert pairs (G', []);
       in
-        (fold Symreltab.update ins tab', G'')
+        (fold Symreltab.update ins tab', G'', restrict_graph G'')
       end
 
     fun show_term t = Pretty.block [Syntax.pretty_term ctxt t,
       Pretty.str " :: ", Syntax.pretty_typ ctxt (fastype_of t)]
 
-    fun coercion_data_update (tab, G) =
+    fun coercion_data_update (tab, G, _) =
         (case Symreltab.lookup tab (a, b) of
           NONE => err_coercion false
-        | SOME (t', []) => if t aconv t'
+        | SOME (t', (_, [])) => if t aconv t'
             then delete_and_insert tab G
             else err_coercion true
-        | SOME (t', ts) => if t aconv t'
+        | SOME (t', (_, ts)) => if t aconv t'
             then error ("Cannot delete the automatically derived coercion:\n" ^
               Syntax.string_of_term ctxt t ^ " :: " ^
               Syntax.string_of_typ ctxt (fastype_of t) ^
@@ -879,19 +965,26 @@
               "\nwill also remove the transitive coercion.")
             else err_coercion true);
   in
-    map_coes_and_graph coercion_data_update context
+    map_coes_and_graphs coercion_data_update context
   end;
 
 fun print_coercions ctxt =
   let
-    fun show_coercion ((a, b), (t, _)) = Pretty.block [
-      Syntax.pretty_typ ctxt (Type (a, [])),
+    fun separate _ [] = ([], [])
+      | separate P (x::xs) = (if P x then apfst else apsnd) (cons x) (separate P xs);
+    val (simple, complex) =
+      separate (fn (_, (_, ((Ts, Us), _))) => null Ts andalso null Us)
+        (Symreltab.dest (coes_of ctxt));
+    fun show_coercion ((a, b), (t, ((Ts, Us), _))) = Pretty.block [
+      Syntax.pretty_typ ctxt (Type (a, Ts)),
       Pretty.brk 1, Pretty.str "<:", Pretty.brk 1,
-      Syntax.pretty_typ ctxt (Type (b, [])),
+      Syntax.pretty_typ ctxt (Type (b, Us)),
       Pretty.brk 3, Pretty.block [Pretty.str "using", Pretty.brk 1,
       Pretty.quote (Syntax.pretty_term ctxt t)]];
   in
-    Pretty.big_list "Coercions:" (map show_coercion (Symreltab.dest (coes_of ctxt)))
+    Pretty.big_list "Coercions:"
+    [Pretty.big_list "between base types:" (map show_coercion simple),
+     Pretty.big_list "other:" (map show_coercion complex)]
     |> Pretty.writeln
   end;