src/Tools/subtyping.ML
changeset 40281 3c6198fd0937
child 40282 329cd9dd5949
equal deleted inserted replaced
40280:0dd2827e8596 40281:3c6198fd0937
       
     1 (*  Title:      Tools/subtyping.ML
       
     2     Author:     Dmitriy Traytel, TU Muenchen
       
     3 
       
     4 Coercive subtyping via subtype constraints.
       
     5 *)
       
     6 
       
     7 signature SUBTYPING =
       
     8 sig
       
     9   datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT
       
    10   val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
       
    11     term list -> term list
       
    12 end;
       
    13 
       
    14 structure Subtyping =
       
    15 struct
       
    16 
       
    17 
       
    18 
       
    19 (** coercions data **)
       
    20 
       
    21 datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT
       
    22 
       
    23 datatype data = Data of
       
    24   {coes: term Symreltab.table, (* coercions table *)
       
    25    coes_graph: unit Graph.T, (* coercions graph *)
       
    26    tmaps: (term * variance list) Symtab.table}; (* map functions *)
       
    27 
       
    28 fun make_data (coes, coes_graph, tmaps) =
       
    29   Data {coes = coes, coes_graph = coes_graph, tmaps = tmaps};
       
    30 
       
    31 structure Data = Generic_Data
       
    32 (
       
    33   type T = data;
       
    34   val empty = make_data (Symreltab.empty, Graph.empty, Symtab.empty);
       
    35   val extend = I;
       
    36   fun merge
       
    37     (Data {coes = coes1, coes_graph = coes_graph1, tmaps = tmaps1},
       
    38       Data {coes = coes2, coes_graph = coes_graph2, tmaps = tmaps2}) =
       
    39     make_data (Symreltab.merge (op aconv) (coes1, coes2),
       
    40       Graph.merge (op =) (coes_graph1, coes_graph2),
       
    41       Symtab.merge (eq_pair (op aconv) (op =)) (tmaps1, tmaps2));
       
    42 );
       
    43 
       
    44 fun map_data f =
       
    45   Data.map (fn Data {coes, coes_graph, tmaps} =>
       
    46     make_data (f (coes, coes_graph, tmaps)));
       
    47 
       
    48 fun map_coes f =
       
    49   map_data (fn (coes, coes_graph, tmaps) =>
       
    50     (f coes, coes_graph, tmaps));
       
    51 
       
    52 fun map_coes_graph f =
       
    53   map_data (fn (coes, coes_graph, tmaps) =>
       
    54     (coes, f coes_graph, tmaps));
       
    55 
       
    56 fun map_coes_and_graph f =
       
    57   map_data (fn (coes, coes_graph, tmaps) =>
       
    58     let val (coes', coes_graph') = f (coes, coes_graph);
       
    59     in (coes', coes_graph', tmaps) end);
       
    60 
       
    61 fun map_tmaps f =
       
    62   map_data (fn (coes, coes_graph, tmaps) =>
       
    63     (coes, coes_graph, f tmaps));
       
    64 
       
    65 fun rep_data context = Data.get context |> (fn Data args => args);
       
    66 
       
    67 val coes_of = #coes o rep_data;
       
    68 val coes_graph_of = #coes_graph o rep_data;
       
    69 val tmaps_of = #tmaps o rep_data;
       
    70 
       
    71 
       
    72 
       
    73 (** utils **)
       
    74 
       
    75 val is_param = Type_Infer.is_param
       
    76 val is_paramT = Type_Infer.is_paramT
       
    77 val deref = Type_Infer.deref
       
    78 fun mk_param i S = TVar (("?'a", i), S); (* TODO dup? see src/Pure/type_infer.ML *)
       
    79 
       
    80 fun nameT (Type (s, [])) = s;
       
    81 fun t_of s = Type (s, []);
       
    82 fun sort_of (TFree (_, S)) = SOME S
       
    83   | sort_of (TVar (_, S)) = SOME S
       
    84   | sort_of _ = NONE;
       
    85 
       
    86 val is_typeT = fn (Type _) => true | _ => false;
       
    87 val is_compT = fn (Type (_, _::_)) => true | _ => false;
       
    88 val is_freeT = fn (TFree _) => true | _ => false;
       
    89 val is_fixedvarT = fn (TVar (xi, _)) => not (is_param xi) | _ => false;
       
    90 
       
    91 
       
    92 (* unification TODO dup? needed for weak unification *)
       
    93 
       
(*Raised by unify: failure message and the type environment at the point
  of failure (the environment is handed to the error reporting functions).*)
exception NO_UNIFIER of string * typ Vartab.table;

(*unify weak ctxt (T, U) (tye, idx): unify T and U, extending the type
  environment tye (bindings of inference parameters) and the next free
  parameter index idx.  Raises NO_UNIFIER on failure.  Only inference
  parameters (is_paramT / Type_Infer.is_param) are ever bound.
  With weak = true, sort constraints are ignored (weak_meet does nothing)
  and any two base types (constructors without arguments) unify.*)
fun unify weak ctxt =
  let
    val thy = ProofContext.theory_of ctxt;
    val pp = Syntax.pp ctxt;
    val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);


    (* adjust sorts of parameters *)

    fun not_of_sort x S' S =
      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
        Syntax.string_of_sort ctxt S;

    (*meet (T, S) (tye, idx): make T belong to sort S.  A type constructor
      propagates the requirement to its arguments via arity_sorts; a free
      variable or fixed schematic variable must already have a subsort of S;
      a parameter whose sort is too weak is rebound to a fresh parameter
      (index idx) carrying the intersection of both sorts.*)
    fun meet (_, []) tye_idx = tye_idx
      | meet (Type (a, Ts), S) (tye_idx as (tye, _)) =
          meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
      | meet (TFree (x, S'), S) (tye_idx as (tye, _)) =
          if Sign.subsort thy (S', S) then tye_idx
          else raise NO_UNIFIER (not_of_sort x S' S, tye)
      | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
          if Sign.subsort thy (S', S) then tye_idx
          else if Type_Infer.is_param xi then
            (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
          else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
    (*pointwise meet of argument types (dereferenced) with argument sorts*)
    and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
          meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
      | meets _ tye_idx = tye_idx;

    (*weak unification skips all sort adjustments*)
    val weak_meet = if weak then fn _ => I else meet



    (* occurs check and assignment *)

    (*raise NO_UNIFIER if parameter xi occurs in the type, following the
      bindings recorded in tye*)
    fun occurs_check tye xi (TVar (xi', _)) =
          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
          else
            (case Vartab.lookup tye xi' of
              NONE => ()
            | SOME T => occurs_check tye xi T)
      | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
      | occurs_check _ _ _ = ();

    (*bind parameter xi (of sort S) to T after adjusting T to sort S;
      binding a variable to itself leaves env unchanged, and only
      non-variable types need the occurs check*)
    fun assign xi (T as TVar (xi', _)) S env =
          if xi = xi' then env
          else env |> weak_meet (T, S) |>> Vartab.update_new (xi, T)
      | assign xi T S (env as (tye, _)) =
          (occurs_check tye xi T; env |> weak_meet (T, S) |>> Vartab.update_new (xi, T));


    (* unification *)

    fun show_tycon (a, Ts) =
      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));

    (*dereference both sides; bind a parameter if either side is one,
      otherwise require equal constructors and unify arguments pairwise*)
    fun unif (T1, T2) (env as (tye, _)) =
      (case pairself (`is_paramT o deref tye) (T1, T2) of
        ((true, TVar (xi, S)), (_, T)) => assign xi T S env
      | ((_, T), (true, TVar (xi, S))) => assign xi T S env
      | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
          (*weak mode: distinct base types are considered unifiable*)
          if weak andalso null Ts andalso null Us then env
          else if a <> b then
            raise NO_UNIFIER
              ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
          else fold unif (Ts ~~ Us) env
      | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye));

  in unif end;

(*sort-agnostic unification that lets all base types unify*)
val weak_unify = unify true;
(*ordinary unification respecting sorts*)
val strong_unify = unify false;
       
   166 
       
   167 
       
   168 (* Typ_Graph shortcuts *)
       
   169 
       
   170 val add_edge = Typ_Graph.add_edge_acyclic;
       
   171 fun get_preds G T = Typ_Graph.all_preds G [T];
       
   172 fun get_succs G T = Typ_Graph.all_succs G [T];
       
   173 fun maybe_new_typnode T G = perhaps (try (Typ_Graph.new_node (T, ()))) G;
       
   174 fun maybe_new_typnodes Ts G = fold maybe_new_typnode Ts G;
       
   175 fun new_imm_preds G Ts = 
       
   176   subtract (op =) Ts (distinct (op =) (maps (Typ_Graph.imm_preds G) Ts));
       
   177 fun new_imm_succs G Ts = 
       
   178   subtract op= Ts (distinct (op =) (maps (Typ_Graph.imm_succs G) Ts));
       
   179 
       
   180 
       
   181 (* Graph shortcuts *)
       
   182 
       
   183 fun maybe_new_node s G = perhaps (try (Graph.new_node (s, ()))) G
       
   184 fun maybe_new_nodes ss G = fold maybe_new_node ss G
       
   185 
       
   186 
       
   187 
       
   188 (** error messages **)
       
   189 
       
   190 fun prep_output ctxt tye bs ts Ts =
       
   191   let
       
   192     val (Ts_bTs', ts') = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts);
       
   193     val (Ts', Ts'') = chop (length Ts) Ts_bTs';
       
   194     fun prep t =
       
   195       let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
       
   196       in Term.subst_bounds (map Syntax.mark_boundT xs, t) end;
       
   197   in (map prep ts', Ts') end;
       
   198 
       
   199 fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);
       
   200 
       
   201 fun inf_failed msg =
       
   202   "Subtype inference failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
       
   203 
       
   204 fun err_appl ctxt msg tye bs t T u U =
       
   205   let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U]
       
   206   in error (inf_failed msg ^ Type.appl_error (Syntax.pp ctxt) t' T' u' U' ^ "\n") end;
       
   207 
       
   208 fun err_subtype ctxt msg tye (bs, t $ u, U, V, U') =
       
   209   err_appl ctxt msg tye bs t (U --> V) u U';
       
   210 
       
   211 fun err_list ctxt msg tye Ts =
       
   212   let
       
   213     val (_, Ts') = prep_output ctxt tye [] [] Ts;
       
   214     val text = cat_lines ([inf_failed msg,
       
   215       "Cannot unify a list of types that should be the same,",
       
   216       "according to suptype dependencies:",
       
   217       (Pretty.string_of (Pretty.list "[" "]" (map (Pretty.typ (Syntax.pp ctxt)) Ts')))]);
       
   218   in
       
   219     error text
       
   220   end;
       
   221 
       
   222 fun err_bound ctxt msg tye packs =
       
   223   let
       
   224     val pp = Syntax.pp ctxt;
       
   225     val (ts, Ts) = fold
       
   226       (fn (bs, t $ u, U, _, U') => fn (ts, Ts) =>
       
   227         let val (t', T') = prep_output ctxt tye bs [t, u] [U, U']
       
   228         in (t'::ts, T'::Ts) end)
       
   229       packs ([], []);
       
   230     val text = cat_lines ([inf_failed msg, "Cannot fullfill subtype constraints:"] @
       
   231         (map2 (fn [t, u] => fn [T, U] => Pretty.string_of (
       
   232           Pretty.block [
       
   233             Pretty.typ pp T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2, Pretty.typ pp U,
       
   234             Pretty.brk 3, Pretty.str "from function application", Pretty.brk 2,
       
   235             Pretty.block [Pretty.term pp t, Pretty.brk 1, Pretty.term pp u]]))
       
   236         ts Ts))
       
   237   in
       
   238     error text
       
   239   end;
       
   240 
       
   241 
       
   242 
       
   243 (** constraint generation **)
       
   244 
       
   245 fun generate_constraints ctxt =
       
   246   let
       
   247     fun gen cs _ (Const (_, T)) tye_idx = (T, tye_idx, cs)
       
   248       | gen cs _ (Free (_, T)) tye_idx = (T, tye_idx, cs)
       
   249       | gen cs _ (Var (_, T)) tye_idx = (T, tye_idx, cs)
       
   250       | gen cs bs (Bound i) tye_idx =
       
   251           (snd (nth bs i handle Subscript => err_loose i), tye_idx, cs)
       
   252       | gen cs bs (Abs (x, T, t)) tye_idx =
       
   253           let val (U, tye_idx', cs') = gen cs ((x, T) :: bs) t tye_idx
       
   254           in (T --> U, tye_idx', cs') end
       
   255       | gen cs bs (t $ u) tye_idx =
       
   256           let
       
   257             val (T, tye_idx', cs') = gen cs bs t tye_idx;
       
   258             val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
       
   259             val U = mk_param idx [];
       
   260             val V = mk_param (idx + 1) [];
       
   261             val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
       
   262               handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs t T u U;
       
   263             val error_pack = (bs, t $ u, U, V, U');
       
   264           in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
       
   265   in
       
   266     gen [] []
       
   267   end;
       
   268 
       
   269 
       
   270 
       
   271 (** constraint resolution **)
       
   272 
       
(*Raised while computing bounds of parameters; converted into a user error
  via err_bound.*)
exception BOUND_ERROR of string;

(*Solve the subtype constraints cs.  Each element has the form
  ((T, U), error_pack) and means T <: U.  The solver runs four steps:
    1. check termination of simplification by weak unification;
    2. simplify cs to constraints between parameters and base types;
    3. build the constraint graph, unifying and collapsing cycles;
    4. assign lower/upper bounds to parameters and unify what remains.
  Returns the final (tye, idx).*)
fun process_constraints ctxt cs tye_idx =
  let
    val coes_graph = coes_graph_of (Context.Proof ctxt);
    val tmaps = tmaps_of (Context.Proof ctxt);
    val tsig = Sign.tsig_of (ProofContext.theory_of ctxt);
    val pp = Syntax.pp ctxt;
    val arity_sorts = Type.arity_sorts pp tsig;
    val subsort = Type.subsort tsig;

    (*partition constraints: those where f holds for at least one side go to
      the first list, all others to the second*)
    fun split_cs _ [] = ([], [])
      | split_cs f (c::cs) =
          (case pairself f (fst c) of
            (false, false) => apsnd (cons c) (split_cs f cs)
          | _ => apfst (cons c) (split_cs f cs));


    (* check whether constraint simplification will terminate using weak unification *)

    val _ = fold (fn (TU, error_pack) => fn tye_idx =>
      (weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, tye) =>
        err_subtype ctxt ("Weak unification of subtype constraints fails:\n" ^ msg)
          tye error_pack)) cs tye_idx;


    (* simplify constraints *)

    (*Decompose constraints between structured types until only constraints
      between parameters and base types remain; these are returned as "done"
      together with the updated (tye, idx).  Constraints in "done" whose sides
      get instantiated later are moved back to the todo list.*)
    fun simplify_constraints cs tye_idx =
      let
        (*constraint between Type (a, Ts) and Type (a, Us): split it into
          argument constraints according to the variances registered in tmaps;
          invariant arguments are unified strongly instead*)
        fun contract a Ts Us error_pack done todo tye idx =
          let
            val arg_var =
              (case Symtab.lookup tmaps a of
                (*everything is invariant for unknown constructors*)
                NONE => replicate (length Ts) INVARIANT
              | SOME av => snd av);
            fun new_constraints (variance, constraint) (cs, tye_idx) =
              (case variance of
                COVARIANT => (constraint :: cs, tye_idx)
              | CONTRAVARIANT => (swap constraint :: cs, tye_idx)
              | INVARIANT => (cs, strong_unify ctxt constraint tye_idx
                  handle NO_UNIFIER (msg, tye) => err_subtype ctxt msg tye error_pack));
            val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
              (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
            val test_update = is_compT orf is_freeT orf is_fixedvarT;
            (*done is re-examined only when no co-/contravariant argument
              produced a new constraint.
              NOTE(review): with mixed variances the invariant arguments still
              extend tye', but done is then left untouched -- confirm intended*)
            val (ch, done') =
              if not (null new) then ([],   done)
              else split_cs (test_update o deref tye') done;
            val todo' = ch @ todo;
          in
            simplify done' (new @ todo') (tye', idx')
          end
        (*xi is definitely a parameter*)
        (*bind xi to a applied to fresh parameters (sorts from arity_sorts a S),
          then relate that type to Type (a, Ts); varleq means the parameter
          was on the left of the constraint*)
        and expand varleq xi S a Ts error_pack done todo tye idx =
          let
            val n = length Ts;
            val args = map2 mk_param (idx upto idx + n - 1) (arity_sorts a S);
            val tye' = Vartab.update_new (xi, Type(a, args)) tye;
            val (ch, done') = split_cs (is_compT o deref tye') done;
            val todo' = ch @ todo;
            val new =
              if varleq then (Type(a, args), Type (a, Ts))
              else (Type (a, Ts), Type(a, args));
          in
            simplify done' ((new, error_pack) :: todo') (tye', idx + n)
          end
        (*TU is a pair of a parameter and a free/fixed variable*)
        (*bind the parameter to the free/fixed variable T, provided the sort
          of T is a subsort of the parameter's sort*)
        and eliminate TU error_pack done todo tye idx =
          let
            val [TVar (xi, S)] = filter is_paramT TU;
            val [T] = filter_out is_paramT TU;
            val SOME S' = sort_of T;
            val test_update = if is_freeT T then is_freeT else is_fixedvarT;
            val tye' = Vartab.update_new (xi, T) tye;
            val (ch, done') = split_cs (test_update o deref tye') done;
            val todo' = ch @ todo;
          in
            if subsort (S', S) (*TODO check this*)
            then simplify done' todo' (tye', idx)
            else err_subtype ctxt "Sort mismatch" tye error_pack
          end
        (*main loop over the todo list; constraints that cannot be simplified
          further are accumulated in done*)
        and simplify done [] tye_idx = (done, tye_idx)
          | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
              (case (deref tye T, deref tye U) of
                (Type (a, []), Type (b, [])) =>
                  (*base types: must be equal or connected by a coercion edge*)
                  if a = b then simplify done todo tye_idx
                  else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
                  else err_subtype ctxt (a ^" is not a subtype of " ^ b) (fst tye_idx) error_pack
              | (Type (a, Ts), Type (b, Us)) =>
                  if a<>b then err_subtype ctxt "Different constructors" (fst tye_idx) error_pack
                  else contract a Ts Us error_pack done todo tye idx
              | (TVar (xi, S), Type (a, Ts as (_::_))) =>
                  expand true xi S a Ts error_pack done todo tye idx
              | (Type (a, Ts as (_::_)), TVar (xi, S)) =>
                  expand false xi S a Ts error_pack done todo tye idx
              | (T, U) =>
                  if T = U then simplify done todo tye_idx
                  else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
                    exists is_paramT [T, U]
                  then eliminate [T, U] error_pack done todo tye idx
                  else if exists (is_freeT orf is_fixedvarT) [T, U]
                  then err_subtype ctxt "Not eliminated free/fixed variables"
                        (fst tye_idx) error_pack
                  else simplify (((T, U), error_pack)::done) todo tye_idx);
      in
        simplify [] cs tye_idx
      end;


    (* do simplification *)

    val (cs', tye_idx') = simplify_constraints cs tye_idx;

    (*error packs of the simplified constraints that give T' a lower bound
      (lower = true: T' on the right) or an upper bound (T' on the left)*)
    fun find_error_pack lower T' =
      map snd (filter (fn ((T, U), _) => if lower then T' = U else T' = T) cs');

    (*strongly unify T with every element of Ts; failures go to err_list*)
    fun unify_list (T::Ts) tye_idx =
      fold (fn U => fn tye_idx => strong_unify ctxt (T, U) tye_idx
        handle NO_UNIFIER (msg, tye) => err_list ctxt msg tye (T::Ts))
      Ts tye_idx;

    (*styps stands either for supertypes or for subtypes of a type T
      in terms of the subtype-relation (excluding T itself)*)
    (*only immediate neighbours in coes_graph are returned; unknown T gives []*)
    fun styps super T =
      (if super then Graph.imm_succs else Graph.imm_preds) coes_graph T
        handle Graph.UNDEF _ => [];

    (*select the element of T::Ts that has a coes_graph edge to every other
      element examined: the least one if sup, the greatest one otherwise;
      raises BOUND_ERROR if two elements are not related by an edge*)
    fun minmax sup (T::Ts) =
      let
        fun adjust T U = if sup then (T, U) else (U, T);
        fun extract T [] = T
          | extract T (U::Us) =
              if Graph.is_edge coes_graph (adjust T U) then extract T Us
              else if Graph.is_edge coes_graph (adjust U T) then extract U Us
              else raise BOUND_ERROR "Uncomparable types in type list";
      in
        t_of (extract T Ts)
      end;

    (*for every (Ts, S) in styps_and_sorts, some type among T and its
      immediate super-/subtypes (styps super T) has sort S and is equal or
      directly related via coes_graph to all of Ts*)
    fun ex_styp_of_sort super T styps_and_sorts =
      let
        fun adjust T U = if super then (T, U) else (U, T);
        fun styp_test U Ts = forall
          (fn T => T = U orelse Graph.is_edge coes_graph (adjust U T)) Ts;
        fun fitting Ts S U = Type.of_sort tsig (t_of U, S) andalso styp_test U Ts
      in
        forall (fn (Ts, S) => exists (fitting Ts S) (T :: styps super T)) styps_and_sorts
      end;

    (* computes the tightest possible, correct assignment for 'a::S
       e.g. in the supremum case (sup = true):
               ------- 'a::S---
              /        /    \  \
             /        /      \  \
        'b::C1   'c::C2 ...  T1 T2 ...

       styps_and_sorts - for the bound variables 'b::C1, 'c::C2, ...:
                         the names of their non-parameter bounds paired
                         with their sorts C1, C2, ...
       T::Ts - non-empty list of base types [T1, T2, ...]
    *)
    fun tightest sup S styps_and_sorts (T::Ts) =
      let
        fun restriction T = Type.of_sort tsig (t_of T, S)
          andalso ex_styp_of_sort (not sup) T styps_and_sorts;
        fun candidates T = inter (op =) (filter restriction (T :: styps sup T));
      in
        (case fold candidates Ts (filter restriction (T :: styps sup T)) of
          [] => raise BOUND_ERROR ("No " ^ (if sup then "supremum" else "infimum"))
        | [T] => t_of T
        | Ts => minmax sup Ts)
      end;

    (*add every constraint (T, U) as an edge T -> U; an edge that would close
      a cycle instead unifies all types on the cycles and collapses them*)
    fun build_graph G [] tye_idx = (G, tye_idx)
      | build_graph G ((T, U)::cs) tye_idx =
        if T = U then build_graph G cs tye_idx
        else
          let
            val G' = maybe_new_typnodes [T, U] G;
            val (G'', tye_idx') = (add_edge (T, U) G', tye_idx)
              handle Typ_Graph.CYCLES cycles =>
                let
                  val (tye, idx) = fold unify_list cycles tye_idx
                in
                  (*all cycles collapse to one node,
                    because all of them share at least the nodes x and y*)
                  collapse (tye, idx) (distinct (op =) (flat cycles)) G
                end;
          in
            build_graph G'' cs tye_idx'
          end
    (*keep the first node, delete the others and reconnect their outside
      immediate predecessors/successors to the kept node*)
    and collapse (tye, idx) nodes G = (*nodes non-empty list*)
      let
        val T = hd nodes;
        val P = new_imm_preds G nodes;
        val S = new_imm_succs G nodes;
        val G' = Typ_Graph.del_nodes (tl nodes) G;
      in
        build_graph G' (map (fn x => (x, T)) P @ map (fn x => (T, x)) S) (tye, idx)
      end;

    (*If key is still an unassigned parameter, bind it to the tightest base
      type admitted by its non-parameter lower bounds (lower = true) or upper
      bounds (lower = false).  A lower-bound assignment must also be
      compatible with the key's base-type upper bounds: each of them has to be
      the assigned type or one of its immediate supertypes.*)
    fun assign_bound lower G key (tye_idx as (tye, _)) =
      if is_paramT (deref tye key) then
        let
          val TVar (xi, S) = deref tye key;
          val get_bound = if lower then get_preds else get_succs;
          val raw_bound = get_bound G key;
          val bound = map (deref tye) raw_bound;
          val not_params = filter_out is_paramT bound;
          (*for a bound that is a type variable: names of its own
            non-parameter bounds together with its sort*)
          fun to_fulfil T =
            (case sort_of T of
              NONE => NONE
            | SOME S =>
                SOME (map nameT (filter_out is_paramT (map (deref tye) (get_bound G T))), S));
          val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
          val assignment =
            if null bound orelse null not_params then NONE
            else SOME (tightest lower S styps_and_sorts (map nameT not_params)
                handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key))
        in
          (case assignment of
            NONE => tye_idx
          | SOME T =>
              if is_paramT T then tye_idx
              else if lower then (*upper bound check*)
                let
                  val other_bound = map (deref tye) (get_succs G key);
                  val s = nameT T;
                in
                  if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
                  then apfst (Vartab.update (xi, T)) tye_idx
                  else err_bound ctxt ("Assigned simple type " ^ s ^
                    " clashes with the upper bound of variable " ^
                    Syntax.string_of_typ ctxt (TVar(xi, S))) tye (find_error_pack (not lower) key)
                end
              else apfst (Vartab.update (xi, T)) tye_idx)
        end
      else tye_idx;

    val assign_lb = assign_bound true;
    val assign_ub = assign_bound false;

    (*alternate lower- and upper-bound assignment over the nodes ts until the
      list of still unassigned parameters stops changing*)
    fun assign_alternating ts' ts G tye_idx =
      if ts' = ts then tye_idx
      else
        let
          val (tye_idx' as (tye, _)) = fold (assign_lb G) ts tye_idx
            |> fold (assign_ub G) ts;
        in
          assign_alternating ts (filter (is_paramT o deref tye) ts) G tye_idx'
        end;

    (*Unify all weakly connected components of the constraint forest,
      that contain only params. These are the only WCCs that contain
      params anyway.*)
    fun unify_params G (tye_idx as (tye, _)) =
      let
        val max_params = filter (is_paramT o deref tye) (Typ_Graph.maximals G);
        val to_unify = map (fn T => T :: get_preds G T) max_params;
      in
        fold unify_list to_unify tye_idx
      end;

    fun solve_constraints G tye_idx = tye_idx
      |> assign_alternating [] (Typ_Graph.keys G) G
      |> unify_params G;
  in
    build_graph Typ_Graph.empty (map fst cs') tye_idx'
      |-> solve_constraints
  end;
       
   542 
       
   543 
       
   544 
       
   545 (** coercion insertion **)
       
   546 
       
   547 fun insert_coercions ctxt tye ts =
       
   548   let
       
   549     fun deep_deref T =
       
   550       (case deref tye T of
       
   551         Type (a, Ts) => Type (a, map deep_deref Ts)
       
   552       | U => U);
       
   553 
       
   554     fun gen_coercion ((Type (a, [])), (Type (b, []))) =
       
   555           if a = b
       
   556           then Abs (Name.uu, Type (a, []), Bound 0)
       
   557           else
       
   558             (case Symreltab.lookup (coes_of (Context.Proof ctxt)) (a, b) of
       
   559               NONE => raise Fail (a ^ " is not a subtype of " ^ b)
       
   560             | SOME co => co)
       
   561       | gen_coercion ((Type (a, Ts)), (Type (b, Us))) =
       
   562           if a <> b
       
   563           then raise raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
       
   564           else
       
   565             let
       
   566               fun inst t Ts = 
       
   567                 Term.subst_vars 
       
   568                   (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
       
   569               fun sub_co (COVARIANT, TU) = gen_coercion TU
       
   570                 | sub_co (CONTRAVARIANT, TU) = gen_coercion (swap TU);
       
   571               fun ts_of [] = []
       
   572                 | ts_of (Type ("fun", [x1, x2])::xs) = x1::x2::(ts_of xs);
       
   573             in
       
   574               (case Symtab.lookup (tmaps_of (Context.Proof ctxt)) a of
       
   575                 NONE => raise Fail ("No map function for " ^ a ^ " known")
       
   576               | SOME tmap =>
       
   577                   let
       
   578                     val used_coes = map sub_co ((snd tmap) ~~ (Ts ~~ Us));
       
   579                   in
       
   580                     Term.list_comb
       
   581                       (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
       
   582                   end)
       
   583             end
       
   584       | gen_coercion (T, U) =
       
   585           if Type.could_unify (T, U)
       
   586           then Abs (Name.uu, T, Bound 0)
       
   587           else raise Fail ("Cannot generate coercion from "
       
   588             ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U);
       
   589 
       
   590     fun insert _ (Const (c, T)) =
       
   591           let val T' = deep_deref T;
       
   592           in (Const (c, T'), T') end
       
   593       | insert _ (Free (x, T)) =
       
   594           let val T' = deep_deref T;
       
   595           in (Free (x, T'), T') end
       
   596       | insert _ (Var (xi, T)) =
       
   597           let val T' = deep_deref T;
       
   598           in (Var (xi, T'), T') end
       
   599       | insert bs (Bound i) =
       
   600           let val T = nth bs i handle Subscript =>
       
   601             raise TYPE ("Loose bound variable: B." ^ string_of_int i, [], []);
       
   602           in (Bound i, T) end
       
   603       | insert bs (Abs (x, T, t)) =
       
   604           let
       
   605             val T' = deep_deref T;
       
   606             val (t', T'') = insert (T'::bs) t;
       
   607           in
       
   608             (Abs (x, T', t'), T' --> T'')
       
   609           end
       
   610       | insert bs (t $ u) =
       
   611           let
       
   612             val (t', Type ("fun", [U, T])) = insert bs t;
       
   613             val (u', U') = insert bs u;
       
   614           in
       
   615             if U <> U'
       
   616             then (t' $ (gen_coercion (U', U) $ u'), T)
       
   617             else (t' $ u', T)
       
   618           end
       
   619   in
       
   620     map (fst o insert []) ts
       
   621   end;
       
   622 
       
   623 
       
   624 
       
   625 (** assembling the pipeline **)
       
   626 
       
   (*Coercion-aware type inference for a list of raw terms: prepare the terms,
     gather subtype constraints from all of them, solve those constraints,
     insert coercions into the terms via insert_coercions, and hand the result
     to Type_Infer.finish, whose term component is returned.*)
   fun infer_types ctxt const_type var_type raw_ts =
     let
       val (idx, ts) = Type_Infer.prepare ctxt const_type var_type raw_ts;

       (*thread the (environment, index) pair through every term, prepending the
         constraints generated for each one to the accumulator*)
       fun collect t (env_idx, acc) =
         let val (_, env_idx', cs) = generate_constraints ctxt t env_idx
         in (env_idx', cs @ acc) end;

       val (env_idx, all_constraints) = fold collect ts ((Vartab.empty, idx), []);
       val tye = fst (process_constraints ctxt all_constraints env_idx);
       val coerced = insert_coercions ctxt tye ts;
     in
       snd (Type_Infer.finish ctxt tye ([], coerced))
     end;
       
   642 
       
   643 
       
   644 
       
   645 (** installation **)
       
   646 
       
   (*infer_types specialised to a proof context: constant types are looked up
     among the context's constant constraints (wrapped in try, so a failed
     lookup yields NONE), and variable types via ProofContext.def_type.
     The result still expects the raw term list.*)
   fun coercion_infer_types ctxt =
     let
       val consts = ProofContext.consts_of ctxt;
       val const_type = try (Consts.the_constraint consts);
       val var_type = ProofContext.def_type ctxt;
     in
       infer_types ctxt const_type var_type
     end;
       
   651 
       
   local

   (*Register f as a check through the given setup function (here
     Syntax.add_term_check).  The registered check yields NONE when f returns
     terms equal to its input under eq, and SOME (new terms, same context)
     otherwise.*)
   fun install eq setup f =
     let
       fun check ts ctxt =
         let val ts' = f ctxt ts
         in if eq_list eq (ts, ts') then NONE else SOME (ts', ctxt) end;
     in
       Context.>> (setup check)
     end;

   in

   (*coercion-aware inference is installed as the term check "coercions"*)
   val _ = install (op aconv) (Syntax.add_term_check ~100 "coercions") coercion_infer_types;

   end;
       
   662 
       
   663 
       
   664 (* interface *)
       
   665 
       
   (*Declare a map function for a type constructor, given as a term string.
     Its type must have the shape
       f1 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]
     where each fi is (xi => yi) or (yi => xi).  The term is stored together
     with its variance list (COVARIANT resp. CONTRAVARIANT per argument) under
     the constructor name C.*)
   fun add_type_map map_fun context =
     let
       val ctxt = Context.proof_of context;
       val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt map_fun);

       (*the result type of a map function is C [y1, ..., yn], not C [x1, ..., xn]*)
       fun err_str () = "\n\nthe general type signature for a map function is" ^
         "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]" ^
         "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)";

       (*Compare each function argument type (T => T') with the corresponding
         pair of constructor arguments (U, U'): same direction is covariant,
         swapped direction is contravariant, anything else is rejected.*)
       fun gen_arg_var ([], []) = []
         | gen_arg_var ((T, T') :: Ts, (U, U') :: Us) =
             if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
             else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us)
             else error ("Functions do not apply to arguments correctly:" ^ err_str ())
         | gen_arg_var (_, _) =
             error ("Different numbers of functions and arguments\n" ^ err_str ());

       (*Two types are balanced if they have the same shape: the same type
         constructor at every Type node (nullary ones included, so nat and int
         are NOT balanced) and type variables at the same positions.
         TODO: This function is only needed to introduce the fun type map
         function: "% f g h . g o h o f". There must be a better solution.*)
       fun balanced (Type (a, Ts)) (Type (b, Us)) =
             a = b andalso forall I (map2 balanced Ts Us)
         | balanced (TFree _) (TFree _) = true
         | balanced (TVar _) (TVar _) = true
         | balanced _ _ = false;

       (*Peel off leading function arguments f1, ..., fn (collected as
         (domain, range) pairs) until the remaining type is C [xs] => C [ys]
         with balanced sides; return the pairs, the zipped constructor
         arguments and the constructor name.*)
       fun check_map_fun pairs (Type ("fun", [T as Type (C, Ts), U as Type (_, Us)])) =
             if balanced T U then ((pairs, Ts ~~ Us), C)
             else
               (case T of
                 Type ("fun", [x, y]) => check_map_fun (pairs @ [(x, y)]) U
               | _ => error ("Not a proper map function:" ^ err_str ()))
         | check_map_fun _ _ = error ("Not a proper map function:" ^ err_str ());

       val (arg_pairs, tycon) = check_map_fun [] (fastype_of t);
       val variances = gen_arg_var arg_pairs;
     in
       map_tmaps (Symtab.update (tycon, (t, variances))) context
     end;
       
   705 
       
   (*Declare a coercion between two base types, given as a term string whose
     type is A => B with A and B nullary type constructors.  The term is stored
     in the coercion table under (A, B) and the edge is added to the coercion
     graph via Graph.add_edge_trans_acyclic.  Every further edge that appears
     in the resulting graph (relative to the graph before insertion) gets a
     coercion composed along the first irreducible path.*)
   fun add_coercion coercion context =
     let
       val ctxt = Context.proof_of context;
       val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt coercion);

       fun err_coercion () = error ("Bad type for coercion " ^
           Syntax.string_of_term ctxt t ^ ":\n" ^
           Syntax.string_of_typ ctxt (fastype_of t));

       (*A "handle Bind" attached to the right-hand side of a val binding cannot
         catch a failing pattern on its left-hand side, so match explicitly.*)
       val (T1, T2) =
         (case fastype_of t of
           Type ("fun", [T1, T2]) => (T1, T2)
         | _ => err_coercion ());

       (*only nullary type constructors are accepted as endpoints*)
       fun base_type_name (Type (x, [])) = x
         | base_type_name _ = err_coercion ();

       val a = base_type_name T1;
       val b = base_type_name T2;

       fun coercion_data_update (tab, G) =
         let
           val G' = maybe_new_nodes [a, b] G
           (*adding a -> b closes a cycle only if b already reaches a, i.e. b is
             already a subtype of a*)
           val G'' = Graph.add_edge_trans_acyclic (a, b) G'
             handle Graph.CYCLES _ => error (b ^ " is already a subtype of " ^ a ^
               "!\n\nCannot add coercion of type: " ^ a ^ " => " ^ b);
           (*edges present in G'' but not in G'*)
           val new_edges =
             flat (Graph.dest G'' |> map (fn (x, ys) => ys |> map_filter (fn y =>
               if Graph.is_edge G' (x, y) then NONE else SOME (x, y))));
           val G_and_new = Graph.add_edge (a, b) G';

           (*compose the tabulated coercions along the first irreducible path
             from a to b in G*)
           fun complex_coercion tab G (a, b) =
             let
               val path = hd (Graph.irreducible_paths G (a, b))
               val path' = (fst (split_last path)) ~~ tl path
             in Abs (Name.uu, Type (a, []),
                 fold (fn coe => fn acc => coe $ acc)
                   (map (the o Symreltab.lookup tab) path') (Bound 0))
             end;

           (*the direct coercion is inserted first so composed ones can use it*)
           val tab' = fold
             (fn pair => fn tab => Symreltab.update (pair, complex_coercion tab G_and_new pair) tab)
             (filter (fn pair => pair <> (a, b)) new_edges)
             (Symreltab.update ((a, b), t) tab);
         in
           (tab', G'')
         end;
     in
       map_coes_and_graph coercion_data_update context
     end;
       
   757 
       
   (*Theory setup: the attributes "coercion" and "map_function" each parse a
     term argument and apply the corresponding declaration function to the
     context, leaving the theorem untouched.*)
   val _ =
     let
       fun term_attribute declare =
         Scan.lift Parse.term >> (fn raw => fn (context, thm) => (declare raw context, thm));
     in
       Context.>> (Context.map_theory
         (Attrib.setup (Binding.name "coercion") (term_attribute add_coercion)
           "declaration of new coercions" #>
          Attrib.setup (Binding.name "map_function") (term_attribute add_type_map)
           "declaration of new map functions"))
     end;
       
   765 
       
   766 end;