src/Tools/subtyping.ML
author wenzelm
Fri Oct 29 21:49:33 2010 +0200 (2010-10-29)
changeset 40282 329cd9dd5949
parent 40281 3c6198fd0937
child 40283 805ce1d64af0
permissions -rw-r--r--
proper header;
tuned whitespace;
wenzelm@40281
     1
(*  Title:      Tools/subtyping.ML
wenzelm@40281
     2
    Author:     Dmitriy Traytel, TU Muenchen
wenzelm@40281
     3
wenzelm@40281
     4
Coercive subtyping via subtype constraints.
wenzelm@40281
     5
*)
wenzelm@40281
     6
wenzelm@40281
     7
(*Public interface: variances of type constructor arguments and
  type inference with coercive subtyping.*)
signature SUBTYPING =
sig
  datatype variance =
    COVARIANT | CONTRAVARIANT | INVARIANT
  val infer_types:
    Proof.context ->
    (string -> typ option) ->
    (indexname -> typ option) ->
    term list -> term list
end;
wenzelm@40281
    13
wenzelm@40281
    14
structure Subtyping =
wenzelm@40281
    15
struct
wenzelm@40281
    16
wenzelm@40281
    17
(** coercions data **)
wenzelm@40281
    18
wenzelm@40281
    19
(*variance of a type constructor in one of its arguments*)
datatype variance =
    COVARIANT
  | CONTRAVARIANT
  | INVARIANT;
wenzelm@40281
    20
wenzelm@40281
    21
(*Context data of coercive subtyping:
    coes        coercion terms, keyed by pairs of type constructor names;
    coes_graph  graph whose edges are those pairs;
    tmaps       for each type constructor: its map function together with
                the variance of each of its arguments.*)
datatype data =
  Data of
   {coes: term Symreltab.table,
    coes_graph: unit Graph.T,
    tmaps: (term * variance list) Symtab.table};
wenzelm@40281
    25
wenzelm@40281
    26
(*Pack the three data components into a data record.*)
fun make_data (c, g, m) = Data {coes = c, coes_graph = g, tmaps = m};
wenzelm@40281
    28
wenzelm@40281
    29
(*Generic context data.  Merging combines the three components pointwise;
  coercion terms (also inside tmaps) are compared up to aconv, variance
  lists and graph node labels with =.*)
structure Data = Generic_Data
(
  type T = data;
  val empty = make_data (Symreltab.empty, Graph.empty, Symtab.empty);
  val extend = I;
  fun merge (Data d1, Data d2) =
    make_data
     (Symreltab.merge (op aconv) (#coes d1, #coes d2),
      Graph.merge (op =) (#coes_graph d1, #coes_graph d2),
      Symtab.merge (eq_pair (op aconv) (op =)) (#tmaps d1, #tmaps d2));
);
wenzelm@40281
    41
wenzelm@40281
    42
(*Update the context data by applying f to its (coes, coes_graph, tmaps) triple.*)
fun map_data f =
  let
    fun upd (Data {coes = c, coes_graph = g, tmaps = m}) = make_data (f (c, g, m));
  in Data.map upd end;
wenzelm@40281
    45
wenzelm@40281
    46
(*Update the coercion table and coercion graph together.*)
fun map_coes_and_graph f =
  map_data (fn (coes, coes_graph, tmaps) =>
    let val (coes', coes_graph') = f (coes, coes_graph)
    in (coes', coes_graph', tmaps) end);

(*Update a single component of the context data.*)
fun map_coes f = map_coes_and_graph (apfst f);
fun map_coes_graph f = map_coes_and_graph (apsnd f);
fun map_tmaps f = map_data (fn (coes, coes_graph, tmaps) => (coes, coes_graph, f tmaps));
wenzelm@40281
    62
wenzelm@40281
    63
(*The raw record {coes, coes_graph, tmaps} stored in the context.*)
fun rep_data context = let val Data args = Data.get context in args end;
wenzelm@40281
    64
wenzelm@40281
    65
(*Accessors for the individual components of the context data.*)
fun coes_of context = #coes (rep_data context);
fun coes_graph_of context = #coes_graph (rep_data context);
fun tmaps_of context = #tmaps (rep_data context);
wenzelm@40281
    68
wenzelm@40281
    69
wenzelm@40281
    70
wenzelm@40281
    71
(** utils **)
wenzelm@40281
    72
wenzelm@40281
    73
(*Thin wrappers around the inference-parameter utilities of Type_Infer.*)
fun is_param xi = Type_Infer.is_param xi;
fun is_paramT T = Type_Infer.is_paramT T;
fun deref tye T = Type_Infer.deref tye T;

(*Type variable ?'a with index idx and the given sort; used as a fresh
  inference parameter (presumably recognized by is_param -- TODO confirm).
  TODO dup? see src/Pure/type_infer.ML*)
fun mk_param idx sort = TVar (("?'a", idx), sort);
wenzelm@40281
    77
wenzelm@40281
    78
(*Name of a base type (type constructor without arguments); Match otherwise.*)
fun nameT T = (case T of Type (a, []) => a);

(*Base type with the given name.*)
fun t_of a = Type (a, []);

(*Sort of a type variable; NONE for type constructor applications.*)
fun sort_of T =
  (case T of
    TFree (_, S) => SOME S
  | TVar (_, S) => SOME S
  | Type _ => NONE);
wenzelm@40281
    83
wenzelm@40281
    84
(*Shape tests on types.  is_compT: type constructor with at least one argument;
  is_fixedvarT: schematic type variable that is not an inference parameter.*)
fun is_typeT T = (case T of Type _ => true | _ => false);
fun is_compT T = (case T of Type (_, _ :: _) => true | _ => false);
fun is_freeT T = (case T of TFree _ => true | _ => false);
fun is_fixedvarT T = (case T of TVar (xi, _) => not (is_param xi) | _ => false);
wenzelm@40281
    88
wenzelm@40281
    89
wenzelm@40282
    90
(* unification *)  (* TODO: possibly duplicates the unifier of src/Pure/type_infer.ML; needed here for weak unification *)
wenzelm@40281
    91
wenzelm@40281
    92
(*Raised by unify: an error message (possibly empty) and the type
  environment at the point of failure, used for error reporting.*)
exception NO_UNIFIER of string * typ Vartab.table;
wenzelm@40281
    93
wenzelm@40281
    94
(*Unification of types under an environment (tye, idx), where idx is the next
  free index for fresh parameters.  With weak = false this is ordinary
  unification in which the sorts of parameters are adjusted (intersected) as
  needed.  With weak = true sorts are not adjusted and any two base types
  (type constructors without arguments) unify.  unify weak ctxt (T, U) env
  returns the extended environment; failure raises NO_UNIFIER.*)
fun unify weak ctxt =
  let
    val thy = ProofContext.theory_of ctxt;
    val arity_sorts = Type.arity_sorts (Syntax.pp ctxt) (Sign.tsig_of thy);


    (* sort adjustment of parameters *)

    fun not_of_sort x S' S =
      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
        Syntax.string_of_sort ctxt S;

    (*make type T (first component) meet sort S, instantiating parameters
      with fresh ones of the intersected sort where necessary*)
    fun meet (_, []) env = env
      | meet (Type (a, Ts), S) (env as (tye, _)) =
          let
            val Ss = arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye);
            fun meet_arg (T, S', env' as (tye', _)) = meet (deref tye' T, S') env';
          in ListPair.foldl meet_arg env (Ts, Ss) end
      | meet (TFree (x, S'), S) (env as (tye, _)) =
          if Sign.subsort thy (S', S) then env
          else raise NO_UNIFIER (not_of_sort x S' S, tye)
      | meet (TVar (xi, S'), S) (env as (tye, idx)) =
          if Sign.subsort thy (S', S) then env
          else if is_param xi then
            (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
          else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye);

    (*weak unification ignores sorts*)
    fun weak_meet TS env = if weak then env else meet TS env;


    (* occurs check and assignment *)

    fun occurs_check tye xi =
      let
        fun check (TVar (xi', _)) =
              if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
              else (case Vartab.lookup tye xi' of NONE => () | SOME T => check T)
          | check (Type (_, Ts)) = List.app check Ts
          | check (TFree _) = ();
      in check end;

    (*bind parameter xi :: S to T (already dereferenced)*)
    fun assign xi T S (env as (tye, _)) =
      let
        fun bind () = env |> weak_meet (T, S) |>> Vartab.update_new (xi, T);
      in
        (case T of
          TVar (xi', _) => if xi = xi' then env else bind ()
        | _ => (occurs_check tye xi T; bind ()))
      end;


    (* unification *)

    fun show_tycon (a, Ts) =
      quote (Syntax.string_of_typ ctxt (Type (a, map (K dummyT) Ts)));

    fun unif (T1, T2) (env as (tye, _)) =
      let
        val U1 = deref tye T1;
        val U2 = deref tye T2;
      in
        (case (is_paramT U1, U1, is_paramT U2, U2) of
          (true, TVar (xi, S), _, _) => assign xi U2 S env
        | (_, _, true, TVar (xi, S)) => assign xi U1 S env
        | (_, Type (a, Ts), _, Type (b, Us)) =>
            if weak andalso null Ts andalso null Us then env
            else if a <> b then
              raise NO_UNIFIER
                ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
            else fold unif (Ts ~~ Us) env
        | _ => if U1 = U2 then env else raise NO_UNIFIER ("", tye))
      end;

  in unif end;
wenzelm@40281
   161
wenzelm@40281
   162
(*Weak unification (any two base types unify, sorts ignored) and
  ordinary unification with sort adjustment.*)
fun weak_unify ctxt = unify true ctxt;
fun strong_unify ctxt = unify false ctxt;
wenzelm@40281
   164
wenzelm@40281
   165
wenzelm@40281
   166
(* Typ_Graph shortcuts *)
wenzelm@40281
   167
wenzelm@40281
   168
val add_edge = Typ_Graph.add_edge_acyclic;

(*all (transitive) predecessors / successors of a single node*)
fun get_preds G T = Typ_Graph.all_preds G [T];
fun get_succs G T = Typ_Graph.all_succs G [T];

(*add node T; G is returned unchanged if new_node fails (e.g. existing node)*)
fun maybe_new_typnode T G = the_default G (try (Typ_Graph.new_node (T, ())) G);
fun maybe_new_typnodes Ts G = fold maybe_new_typnode Ts G;

(*immediate predecessors / successors of the node set Ts, excluding Ts itself*)
local
  fun new_neighbours imm G Ts =
    subtract (op =) Ts (distinct (op =) (maps (imm G) Ts));
in
  fun new_imm_preds G Ts = new_neighbours Typ_Graph.imm_preds G Ts;
  fun new_imm_succs G Ts = new_neighbours Typ_Graph.imm_succs G Ts;
end;
wenzelm@40281
   177
wenzelm@40281
   178
wenzelm@40281
   179
(* Graph shortcuts *)
wenzelm@40281
   180
wenzelm@40281
   181
(*add node s; G is returned unchanged if new_node fails (e.g. existing node)*)
fun maybe_new_node s G = the_default G (try (Graph.new_node (s, ())) G);
fun maybe_new_nodes ss G = List.foldl (fn (s, G') => maybe_new_node s G') G ss;
wenzelm@40281
   183
wenzelm@40281
   184
wenzelm@40281
   185
wenzelm@40281
   186
(** error messages **)
wenzelm@40281
   187
wenzelm@40281
   188
(*Prepare terms ts and types Ts for printing: finish them under tye via
  Type_Infer.finish and replace loose bound variables, described by bs
  (name/type pairs, innermost first), by free variables marked as bound,
  with names made distinct from the frees of each term.
  Returns the prepared terms and the finished Ts.*)
fun prep_output ctxt tye bs ts Ts =
  let
    val bound_types = map snd bs;
    val (all_Ts, ts') = Type_Infer.finish ctxt tye (Ts @ bound_types, ts);
    val (Ts', bound_types') = chop (length Ts) all_Ts;
    fun prep t =
      (map fst bs ~~ bound_types')
      |> rev
      |> Term.variant_frees t
      |> rev
      |> map Syntax.mark_boundT
      |> (fn frees => Term.subst_bounds (frees, t));
  in (map prep ts', Ts') end;
wenzelm@40281
   196
wenzelm@40281
   197
(*Report a loose bound variable with de Bruijn index i.*)
fun err_loose i =
  let val msg = "Loose bound variable: B." ^ string_of_int i
  in error msg end;
wenzelm@40281
   198
wenzelm@40281
   199
(*Headline of subtype inference errors, with msg appended unless it is empty.*)
fun inf_failed "" = "Subtype inference failed\n\n"
  | inf_failed msg = "Subtype inference failed: " ^ msg ^ "\n\n";
wenzelm@40281
   201
wenzelm@40281
   202
(*Report an ill-typed application of t :: T to u :: U.  Never returns.*)
fun err_appl ctxt msg tye bs t T u U =
  let
    val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U];
    val header = inf_failed msg;
    val details = Type.appl_error (Syntax.pp ctxt) t' T' u' U';
  in error (header ^ details ^ "\n") end;
wenzelm@40281
   205
wenzelm@40281
   206
(*Report a failed subtype constraint from its error pack (bs, t $ u, U, V, U'),
  where U --> V is the type of t and U' the type of u.  Never returns.*)
fun err_subtype ctxt msg tye (bs, app, U, V, U') =
  (case app of
    t $ u => err_appl ctxt msg tye bs t (U --> V) u U');
wenzelm@40281
   208
wenzelm@40281
   209
(*Report that the types Ts, which the subtype constraints force to be equal
  (e.g. members of a cycle in the constraint graph), cannot be unified.
  Never returns.*)
fun err_list ctxt msg tye Ts =
  let
    val (_, Ts') = prep_output ctxt tye [] [] Ts;
    val text = cat_lines ([inf_failed msg,
      "Cannot unify a list of types that should be the same,",
      "according to subtype dependencies:",
      (Pretty.string_of (Pretty.list "[" "]" (map (Pretty.typ (Syntax.pp ctxt)) Ts')))]);
  in
    error text
  end;
wenzelm@40281
   219
wenzelm@40281
   220
(*Report unsatisfiable subtype constraints.  Each element of packs is an error
  pack (bs, t $ u, U, V, U') recorded for the constraint U' <: U, where U' is
  the type of the argument u and U the domain type expected by t.  The
  constraints are listed in reverse order of packs.  Never returns.*)
fun err_bound ctxt msg tye packs =
  let
    val pp = Syntax.pp ctxt;
    fun pretty_pack (bs, tu, U, _, U') =
      let
        (*print the constraint in the direction it was generated: U' <: U*)
        val ([tu'], [T_sub, T_sup]) = prep_output ctxt tye bs [tu] [U', U];
      in
        Pretty.string_of (Pretty.block
          [Pretty.typ pp T_sub, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2, Pretty.typ pp T_sup,
           Pretty.brk 3, Pretty.str "from function application", Pretty.brk 2,
           (*print the whole application, so that the argument is parenthesised
             and bound variables get the same names in function and argument*)
           Pretty.term pp tu'])
      end;
    val text =
      cat_lines ([inf_failed msg, "Cannot fulfil subtype constraints:"] @
        map pretty_pack (rev packs));
  in
    error text
  end;
wenzelm@40281
   238
wenzelm@40281
   239
wenzelm@40281
   240
wenzelm@40281
   241
(** constraint generation **)
wenzelm@40281
   242
wenzelm@40281
   243
(*Compute the type of a term while collecting subtype constraints.
  generate_constraints ctxt t (tye, idx) returns (T, (tye', idx'), cs), where
  each element of cs is ((U', U), error_pack): for every application t $ u,
  the argument type U' must be a subtype of the domain U of t's type.*)
fun generate_constraints ctxt =
  let
    fun gen cs bs t tye_idx =
      (case t of
        Const (_, T) => (T, tye_idx, cs)
      | Free (_, T) => (T, tye_idx, cs)
      | Var (_, T) => (T, tye_idx, cs)
      | Bound i => (snd (nth bs i handle Subscript => err_loose i), tye_idx, cs)
      | Abs (x, T, body) =>
          let val (U, tye_idx1, cs1) = gen cs ((x, T) :: bs) body tye_idx
          in (T --> U, tye_idx1, cs1) end
      | f $ arg =>
          let
            val (T, tye_idx1, cs1) = gen cs bs f tye_idx;
            val (U', (tye, idx), cs2) = gen cs1 bs arg tye_idx1;
            (*fresh parameters for domain and range of f*)
            val U = mk_param idx [];
            val V = mk_param (idx + 1) [];
            val tye_idx2 = strong_unify ctxt (U --> V, T) (tye, idx + 2)
              handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs f T arg U;
          in (V, tye_idx2, ((U', U), (bs, t, U, V, U')) :: cs2) end);
  in gen [] [] end;
wenzelm@40281
   266
wenzelm@40281
   267
wenzelm@40281
   268
wenzelm@40281
   269
(** constraint resolution **)
wenzelm@40281
   270
wenzelm@40281
   271
(*Raised while computing bounds for parameters (no supremum/infimum,
  incomparable candidate types); caught in process_constraints and
  turned into an error message.*)
exception BOUND_ERROR of string;
wenzelm@40281
   272
wenzelm@40281
   273
(*Solve the subtype constraints cs = [((T, U), error_pack), ...] (T <: U) from
  generate_constraints, starting with tye_idx = (tye, idx).  Phases:
  (1) check via weak unification that simplification terminates;
  (2) simplify the constraints, decomposing type constructor applications
      according to the variances of their map functions and instantiating
      parameters;
  (3) build the graph of the remaining constraints, unifying cycles;
  (4) assign base types to parameters from their lower and upper bounds and
      unify the remaining parameter-only components.
  Returns the final (tye, idx); failures are reported via error.*)
fun process_constraints ctxt cs tye_idx =
  let
    val context = Context.Proof ctxt;
    val coes_graph = coes_graph_of context;
    val tmaps = tmaps_of context;
    val tsig = Sign.tsig_of (ProofContext.theory_of ctxt);
    val arity_sorts = Type.arity_sorts (Syntax.pp ctxt) tsig;
    val subsort = Type.subsort tsig;

    (*partition constraints (order preserved): those where f holds for at
      least one side come first*)
    fun split_cs f = List.partition (fn ((T, U), _) => f T orelse f U);


    (* check whether constraint simplification will terminate using weak unification *)

    fun weak_check (TU, error_pack) env =
      weak_unify ctxt TU env
        handle NO_UNIFIER (msg, tye) =>
          err_subtype ctxt ("Weak unification of subtype constraints fails:\n" ^ msg)
            tye error_pack;
    val _ = fold weak_check cs tye_idx;


    (* simplify constraints *)

    fun simplify_constraints cs tye_idx =
      let
        (*constraint between two applications of the same constructor a:
          decompose it argument-wise according to the variances of a*)
        fun contract a Ts Us error_pack done todo tye idx =
          let
            val variances =
              (case Symtab.lookup tmaps a of
                SOME (_, vs) => vs
              | NONE => map (K INVARIANT) Ts);  (*everything is invariant for unknown constructors*)
            fun decompose (COVARIANT, TU) (acc, env) = (TU :: acc, env)
              | decompose (CONTRAVARIANT, TU) (acc, env) = (swap TU :: acc, env)
              | decompose (INVARIANT, TU) (acc, env) =
                  (acc, strong_unify ctxt TU env
                    handle NO_UNIFIER (msg, tye_err) => err_subtype ctxt msg tye_err error_pack);
            val (args, (tye', idx')) = fold decompose (variances ~~ (Ts ~~ Us)) ([], (tye, idx));
            val new = map (rpair error_pack) args;
            val needs_update = is_compT orf is_freeT orf is_fixedvarT;
            val (ch, done') =
              if null new then split_cs (needs_update o deref tye') done
              else ([], done);
          in simplify done' (new @ ch @ todo) (tye', idx') end

        (*xi is definitely a parameter: instantiate it with constructor a
          applied to fresh parameters*)
        and expand varleq xi S a Ts error_pack done todo tye idx =
          let
            val n = length Ts;
            val fresh = Type (a, map2 mk_param (idx upto idx + n - 1) (arity_sorts a S));
            val tye' = Vartab.update_new (xi, fresh) tye;
            val (ch, done') = split_cs (is_compT o deref tye') done;
            val given = Type (a, Ts);
            val new = if varleq then (fresh, given) else (given, fresh);
          in simplify done' ((new, error_pack) :: (ch @ todo)) (tye', idx + n) end

        (*TU is a pair of a parameter and a free/fixed variable: instantiate
          the parameter with the variable, provided the sorts permit*)
        and eliminate TU error_pack done todo tye idx =
          let
            val ([TVar (xi, S)], [T]) = List.partition is_paramT TU;
            val SOME S' = sort_of T;
            val same_kind = if is_freeT T then is_freeT else is_fixedvarT;
            val tye' = Vartab.update_new (xi, T) tye;
            val (ch, done') = split_cs (same_kind o deref tye') done;
          in
            if not (subsort (S', S)) (*TODO check this*)
            then err_subtype ctxt "Sort mismatch" tye error_pack
            else simplify done' (ch @ todo) (tye', idx)
          end

        and simplify done [] tye_idx = (done, tye_idx)
          | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
              (case (deref tye T, deref tye U) of
                (Type (a, []), Type (b, [])) =>
                  if a = b orelse Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
                  else err_subtype ctxt (a ^" is not a subtype of " ^ b) tye error_pack
              | (Type (a, Ts), Type (b, Us)) =>
                  if a = b then contract a Ts Us error_pack done todo tye idx
                  else err_subtype ctxt "Different constructors" tye error_pack
              | (TVar (xi, S), Type (a, Ts as (_ :: _))) =>
                  expand true xi S a Ts error_pack done todo tye idx
              | (Type (a, Ts as (_ :: _)), TVar (xi, S)) =>
                  expand false xi S a Ts error_pack done todo tye idx
              | (T', U') =>
                  let
                    val rigid = exists (is_freeT orf is_fixedvarT) [T', U'];
                  in
                    if T' = U' then simplify done todo tye_idx
                    else if rigid andalso exists is_paramT [T', U'] then
                      eliminate [T', U'] error_pack done todo tye idx
                    else if rigid then
                      err_subtype ctxt "Not eliminated free/fixed variables" tye error_pack
                    else simplify (((T', U'), error_pack) :: done) todo tye_idx
                  end);
      in
        simplify [] cs tye_idx
      end;


    (* do simplification *)

    val (cs', tye_idx') = simplify_constraints cs tye_idx;

    (*error packs of the simplified constraints having T' as upper side
      (lower = true) or as lower side (lower = false)*)
    fun find_error_pack lower T' =
      map_filter (fn ((T, U), pack) =>
          if (if lower then U else T) = T' then SOME pack else NONE) cs';

    (*unify all types of a non-empty list with its head*)
    fun unify_list (Ts as T :: rest) env =
      fold (fn U => fn env' =>
          strong_unify ctxt (T, U) env'
            handle NO_UNIFIER (msg, tye) => err_list ctxt msg tye Ts)
        rest env;

    (*styps stands either for supertypes or for subtypes of a type T
      in terms of the subtype-relation (excluding T itself);
      [] if T does not occur in the coercion graph*)
    fun styps super T =
      let val neighbours = if super then Graph.imm_succs else Graph.imm_preds
      in neighbours coes_graph T handle Graph.UNDEF _ => [] end;

    (*least (sup = true) or greatest (sup = false) element of a non-empty list
      of base type names w.r.t. direct coercion edges*)
    fun minmax sup (T :: Ts) =
      let
        fun edge (x, y) = Graph.is_edge coes_graph (if sup then (x, y) else (y, x));
        fun pick U best =
          if edge (best, U) then best
          else if edge (U, best) then U
          else raise BOUND_ERROR "Uncomparable types in type list";
      in t_of (fold pick Ts T) end;

    (*for every requirement (Ts, S): some type among T and its immediate
      super/subtypes has sort S and is equal or directly related to all of Ts*)
    fun ex_styp_of_sort super T styps_and_sorts =
      let
        fun related U V =
          U = V orelse Graph.is_edge coes_graph (if super then (U, V) else (V, U));
        fun fits (Ts, S) U = Type.of_sort tsig (t_of U, S) andalso forall (related U) Ts;
      in
        forall (fn req => exists (fits req) (T :: styps super T)) styps_and_sorts
      end;

    (* computes the tightest possible, correct assignment for 'a::S
       e.g. in the supremum case (sup = true):
               ------- 'a::S---
              /        /    \  \
             /        /      \  \
        'b::C1   'c::C2 ...  T1 T2 ...

       sorts - list of sorts [C1, C2, ...]
       T::Ts - non-empty list of base types [T1, T2, ...]
    *)
    fun tightest sup S styps_and_sorts (T :: Ts) =
      let
        fun admissible U =
          Type.of_sort tsig (t_of U, S) andalso ex_styp_of_sort (not sup) U styps_and_sorts;
        fun bounds_of U = filter admissible (U :: styps sup U);
        val common = fold (fn U => inter (op =) (bounds_of U)) Ts (bounds_of T);
      in
        (case common of
          [] => raise BOUND_ERROR ("No " ^ (if sup then "supremum" else "infimum"))
        | [U] => t_of U
        | _ => minmax sup common)
      end;

    (*insert the constraints as edges; cycles are unified and collapsed*)
    fun build_graph G [] tye_idx = (G, tye_idx)
      | build_graph G ((T, U) :: rest) tye_idx =
          if T = U then build_graph G rest tye_idx
          else
            let
              val G_TU = maybe_new_typnodes [T, U] G;
              val (G', tye_idx') =
                (add_edge (T, U) G_TU, tye_idx)
                  handle Typ_Graph.CYCLES cycles =>
                    (*all cycles collapse to one node,
                      because all of them share at least the nodes T and U*)
                    collapse (fold unify_list cycles tye_idx) (distinct (op =) (flat cycles)) G;
            in build_graph G' rest tye_idx' end
    and collapse tye_idx nodes G = (*nodes is a non-empty list*)
      let
        val rep = hd nodes;
        val edges =
          map (rpair rep) (new_imm_preds G nodes) @ map (pair rep) (new_imm_succs G nodes);
      in build_graph (Typ_Graph.del_nodes (tl nodes) G) edges tye_idx end;

    (*assign to key, if it is still a parameter, the tightest base type
      compatible with its lower bounds (lower = true) or upper bounds
      (lower = false) in G*)
    fun assign_bound lower G key (tye_idx as (tye, _)) =
      let
        val key' = deref tye key;
      in
        if not (is_paramT key') then tye_idx
        else
          let
            val TVar (xi, S) = key';
            val get_bound = if lower then get_preds else get_succs;
            val raw_bound = get_bound G key;
            val bound = map (deref tye) raw_bound;
            val not_params = filter_out is_paramT bound;
            fun to_fulfil T =
              Option.map
                (fn S' => (map nameT (filter_out is_paramT (map (deref tye) (get_bound G T))), S'))
                (sort_of T);
            val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
            fun assign T = apfst (Vartab.update (xi, T)) tye_idx;
          in
            if null bound orelse null not_params then tye_idx
            else
              let
                val T =
                  tightest lower S styps_and_sorts (map nameT not_params)
                    handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key);
              in
                if is_paramT T then tye_idx
                else if not lower then assign T
                else (*upper bound check*)
                  let
                    val upper = map (deref tye) (get_succs G key);
                    val s = nameT T;
                  in
                    if subset (op = o apfst nameT) (filter is_typeT upper, s :: styps true s)
                    then assign T
                    else err_bound ctxt ("Assigned simple type " ^ s ^
                      " clashes with the upper bound of variable " ^
                      Syntax.string_of_typ ctxt (TVar (xi, S))) tye (find_error_pack (not lower) key)
                  end
              end
          end
      end;

    (*alternate lower and upper bound assignment until the set of remaining
      parameters no longer changes*)
    fun assign_alternating prev current G tye_idx =
      if prev = current then tye_idx
      else
        let
          val (tye_idx' as (tye, _)) =
            tye_idx
            |> fold (assign_bound true G) current
            |> fold (assign_bound false G) current;
          val open_params = filter (is_paramT o deref tye) current;
        in assign_alternating current open_params G tye_idx' end;

    (*Unify all weakly connected components of the constraint forest,
      that contain only params. These are the only WCCs that contain
      params anyway.*)
    fun unify_params G (tye_idx as (tye, _)) =
      let
        val groups =
          Typ_Graph.maximals G
          |> filter (is_paramT o deref tye)
          |> map (fn T => T :: get_preds G T);
      in fold unify_list groups tye_idx end;

    fun solve_constraints G env =
      env
      |> assign_alternating [] (Typ_Graph.keys G) G
      |> unify_params G;
  in
    let val (G, env) = build_graph Typ_Graph.empty (map fst cs') tye_idx'
    in solve_constraints G env end
  end;
wenzelm@40281
   540
wenzelm@40281
   541
wenzelm@40281
   542
wenzelm@40281
   543
(** coercion insertion **)
wenzelm@40281
   544
wenzelm@40281
   545
(*Rebuild every term of ts with its types dereferenced through the type
  environment tye; wherever an argument's type U' differs from the domain
  type U expected by the applied function, wrap the argument in the coercion
  produced by gen_coercion (U', U).*)
fun insert_coercions ctxt tye ts =
  let
    (*apply tye to T, also inside the arguments of type constructors*)
    fun deep_deref T =
      (case deref tye T of
        Type (a, Ts) => Type (a, map deep_deref Ts)
      | U => U);

    (*coercion term from type T to type U:
        - nullary constructors: identity if equal, else the declared coercion;
        - equal constructors with arguments: the declared map function,
          instantiated and applied to coercions between the arguments,
          oriented according to the recorded variances;
        - anything else: identity if T and U could unify, else Fail*)
    fun gen_coercion ((Type (a, [])), (Type (b, []))) =
          if a = b
          then Abs (Name.uu, Type (a, []), Bound 0)
          else
            (case Symreltab.lookup (coes_of (Context.Proof ctxt)) (a, b) of
              NONE => raise Fail (a ^ " is not a subtype of " ^ b)
            | SOME co => co)
      | gen_coercion ((Type (a, Ts)), (Type (b, Us))) =
          if a <> b
          then raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
          else
            let
              (*instantiate the schematic type variables of t by Ts*)
              fun inst t Ts =
                Term.subst_vars
                  (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
              (*INVARIANT is never produced by add_type_map; reject it
                explicitly instead of failing with a bare Match*)
              fun sub_co (COVARIANT, TU) = gen_coercion TU
                | sub_co (CONTRAVARIANT, TU) = gen_coercion (swap TU)
                | sub_co (INVARIANT, _) =
                    raise Fail ("Invariant argument in map function for " ^ a);
              (*domain and range types of the argument coercions, in order*)
              fun ts_of [] = []
                | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs)
                | ts_of _ = raise Fail ("Coercion of non-function type for " ^ a);
            in
              (case Symtab.lookup (tmaps_of (Context.Proof ctxt)) a of
                NONE => raise Fail ("No map function for " ^ a ^ " known")
              | SOME tmap =>
                  let
                    val used_coes = map sub_co ((snd tmap) ~~ (Ts ~~ Us));
                  in
                    Term.list_comb
                      (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
                  end)
            end
      | gen_coercion (T, U) =
          if Type.could_unify (T, U)
          then Abs (Name.uu, T, Bound 0)
          else raise Fail ("Cannot generate coercion from "
            ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U);

    (*rebuild a term bottom-up, returning it together with its type;
      bs lists the types of the enclosing bound variables, innermost first*)
    fun insert _ (Const (c, T)) =
          let val T' = deep_deref T;
          in (Const (c, T'), T') end
      | insert _ (Free (x, T)) =
          let val T' = deep_deref T;
          in (Free (x, T'), T') end
      | insert _ (Var (xi, T)) =
          let val T' = deep_deref T;
          in (Var (xi, T'), T') end
      | insert bs (Bound i) =
          let val T = nth bs i handle Subscript =>
            raise TYPE ("Loose bound variable: B." ^ string_of_int i, [], []);
          in (Bound i, T) end
      | insert bs (Abs (x, T, t)) =
          let
            val T' = deep_deref T;
            val (t', T'') = insert (T' :: bs) t;
          in
            (Abs (x, T', t'), T' --> T'')
          end
      | insert bs (t $ u) =
          let
            val (t', Type ("fun", [U, T])) = insert bs t;
            val (u', U') = insert bs u;
          in
            (*coerce the argument when its type differs from the expected domain*)
            if U <> U'
            then (t' $ (gen_coercion (U', U) $ u'), T)
            else (t' $ u', T)
          end
  in
    map (fst o insert []) ts
  end;
wenzelm@40281
   620
wenzelm@40281
   621
wenzelm@40281
   622
wenzelm@40281
   623
(** assembling the pipeline **)
wenzelm@40281
   624
wenzelm@40281
   625
(*Coercive type inference: prepare raw_ts, collect the subtype constraints
  of every term, solve them into a type environment, insert the required
  coercions and finish the resulting terms under that environment.*)
fun infer_types ctxt const_type var_type raw_ts =
  let
    val (idx, ts) = Type_Infer.prepare ctxt const_type var_type raw_ts;

    (*thread the type environment through all terms; the constraints of
      each later term are placed in front of those collected so far*)
    fun collect t (tye_idx, acc) =
      let val (_, tye_idx', new_cs) = generate_constraints ctxt t tye_idx
      in (tye_idx', new_cs @ acc) end;

    val (tye_idx, all_cs) = fold collect ts ((Vartab.empty, idx), []);
    val tye = fst (process_constraints ctxt all_cs tye_idx);
    val coerced = insert_coercions ctxt tye ts;
  in
    snd (Type_Infer.finish ctxt tye ([], coerced))
  end;
wenzelm@40281
   640
wenzelm@40281
   641
wenzelm@40281
   642
wenzelm@40281
   643
(** installation **)
wenzelm@40281
   644
wenzelm@40281
   645
(*infer_types specialised to ctxt: constant types come from the context's
  constant table (via try, hence optional), variable types from its default
  type assignment*)
fun coercion_infer_types ctxt =
  let
    val consts = ProofContext.consts_of ctxt;
    val const_type = try (Consts.the_constraint consts);
    val var_type = ProofContext.def_type ctxt;
  in infer_types ctxt const_type var_type end;
wenzelm@40281
   649
wenzelm@40281
   650
local

(*install f through the check installer what; the installed check answers
  NONE when f returns a list equal to its input elementwise (by eq), and
  otherwise SOME with the new list and the unchanged context*)
fun add eq what f =
  let
    fun check xs ctxt =
      let val xs' = f ctxt xs in
        if eq_list eq (xs, xs') then NONE
        else SOME (xs', ctxt)
      end;
  in Context.>> (what check) end;

in

(*coercion insertion as the term check "coercions"*)
val _ = add (op aconv) (Syntax.add_term_check ~100 "coercions") coercion_infer_types;

end;
wenzelm@40281
   660
wenzelm@40281
   661
wenzelm@40281
   662
(* interface *)
wenzelm@40281
   663
wenzelm@40281
   664
(*Declare the term read from the string map_fun as the map function of a
  type constructor C.  Its type must have the shape
  f1 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]; the variance of
  each argument position is read off from the type of the corresponding fi
  and stored together with the term in the map-function table under C.*)
fun add_type_map map_fun context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt map_fun);

    fun err_str () = "\n\nthe general type signature for a map function is" ^
      "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]" ^
      "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)";

    (*variance per argument position: fi : xi => yi is covariant,
      fi : yi => xi is contravariant, anything else is rejected*)
    fun gen_arg_var ([], []) = []
      | gen_arg_var ((T, T') :: Ts, (U, U') :: Us) =
          if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
          else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us)
          else error ("Functions do not apply to arguments correctly:" ^ err_str ())
      | gen_arg_var (_, _) =
          error ("Different numbers of functions and arguments:" ^ err_str ());

    (* TODO: This function is only needed to introduce the fun type map
      function: "% f g h . g o h o f". There must be a better solution. *)
    fun balanced (Type (_, [])) (Type (_, [])) = true
      | balanced (Type (a, Ts)) (Type (b, Us)) =
          a = b andalso forall I (map2 balanced Ts Us)
      | balanced (TFree _) (TFree _) = true
      | balanced (TVar _) (TVar _) = true
      | balanced _ _ = false;

    (*peel f1 => f2 => ... off the map function's type, accumulating the
      (domain, range) pair of each fi in order, until the remaining function
      type T => U has balanced sides; return the accumulated pairs, the
      zipped type arguments of T and U, and the constructor name of T*)
    fun check_map_fun pairs (Type ("fun", [T as Type (C, Ts), U as Type (_, Us)])) =
          if balanced T U
          then ((pairs, Ts ~~ Us), C)
          else if C = "fun"
            then check_map_fun (pairs @ [(hd Ts, hd (tl Ts))]) U
            else error ("Not a proper map function:" ^ err_str ())
      | check_map_fun _ _ = error ("Not a proper map function:" ^ err_str ());

    val (arg_pairs, tname) = check_map_fun [] (fastype_of t);
    val variances = gen_arg_var arg_pairs;
  in
    map_tmaps (Symtab.update (tname, (t, variances))) context
  end;
wenzelm@40281
   703
wenzelm@40281
   704
(*Declare the term read from the string coercion, which must have type
  a => b for nullary type constructors a and b, as a coercion.  It is stored
  in the coercions table under (a, b), the edge a -> b is added to the
  coercion graph, and every pair newly related by transitivity gets a
  composed coercion.*)
fun add_coercion coercion context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt coercion);

    fun err_coercion () = error ("Bad type for coercion " ^
        Syntax.string_of_term ctxt t ^ ":\n" ^
        Syntax.string_of_typ ctxt (fastype_of t));

    (*match the function type explicitly: a "handle Bind" attached to the
      right-hand side of a val binding does not catch that binding's own
      pattern-match failure*)
    val (T1, T2) =
      (case fastype_of t of
        Type ("fun", [T1, T2]) => (T1, T2)
      | _ => err_coercion ());

    val a =
      (case T1 of
        Type (x, []) => x
      | _ => err_coercion ());

    val b =
      (case T2 of
        Type (x, []) => x
      | _ => err_coercion ());

    fun coercion_data_update (tab, G) =
      let
        val G' = maybe_new_nodes [a, b] G;
        (*the new edge a -> b can only close a cycle if b already reaches a,
          i.e. if b is already a subtype of a*)
        val G'' = Graph.add_edge_trans_acyclic (a, b) G'
          handle Graph.CYCLES _ => error (b ^ " is already a subtype of " ^ a ^
            "!\n\nCannot add coercion of type: " ^ a ^ " => " ^ b);
        (*edges of G'' that are not already edges of G'*)
        val new_edges =
          flat (Graph.dest G'' |> map (fn (x, ys) => ys |> map_filter (fn y =>
            if Graph.is_edge G' (x, y) then NONE else SOME (x, y))));
        val G_and_new = Graph.add_edge (a, b) G';

        (*compose the coercions stored in tab along the first irreducible
          path from x to y in G; the first edge's coercion is innermost*)
        fun complex_coercion tab G (x, y) =
          let
            val path = hd (Graph.irreducible_paths G (x, y));
            val path' = (fst (split_last path)) ~~ tl path;
          in
            Abs (Name.uu, Type (x, []),
              fold (fn co => fn u => co $ u) (map (the o Symreltab.lookup tab) path') (Bound 0))
          end;

        (*store t under (a, b) first, then a composed coercion for every
          other newly related pair*)
        val tab' = fold
          (fn pair => fn tab => Symreltab.update (pair, complex_coercion tab G_and_new pair) tab)
          (filter (fn pair => pair <> (a, b)) new_edges)
          (Symreltab.update ((a, b), t) tab);
      in
        (tab', G'')
      end;
  in
    map_coes_and_graph coercion_data_update context
  end;
wenzelm@40281
   755
wenzelm@40281
   756
(*attributes "coercion" and "map_function": each parses a term argument and
  declares it in the context via add_coercion resp. add_type_map, leaving
  the theorem unchanged*)
val _ =
  let
    fun term_decl add =
      Scan.lift Parse.term >> (fn t => fn (context, thm) => (add t context, thm));
  in
    Context.>> (Context.map_theory
      (Attrib.setup (Binding.name "coercion") (term_decl add_coercion)
        "declaration of new coercions" #>
       Attrib.setup (Binding.name "map_function") (term_decl add_type_map)
        "declaration of new map functions"))
  end;
wenzelm@40281
   763
wenzelm@40281
   764
end;