src/Tools/subtyping.ML
author wenzelm
Wed Nov 26 20:05:34 2014 +0100 (2014-11-26)
changeset 59058 a78612c67ec0
parent 58893 9e0ecb66d6a7
child 59840 0ab8750c9342
permissions -rw-r--r--
renamed "pairself" to "apply2", in accordance to @{apply 2};
wenzelm@40281
     1
(*  Title:      Tools/subtyping.ML
wenzelm@40281
     2
    Author:     Dmitriy Traytel, TU Muenchen
wenzelm@40281
     3
wenzelm@40281
     4
Coercive subtyping via subtype constraints.
wenzelm@40281
     5
*)
wenzelm@40281
     6
wenzelm@40281
     7
(*Public interface of coercive subtyping.*)
signature SUBTYPING =
sig
  (*configuration option*)
  val coercion_enabled: bool Config.T

  (*declarations, given as terms, applied to a generic context*)
  val add_coercion: term -> Context.generic -> Context.generic
  val add_type_map: term -> Context.generic -> Context.generic

  (*diagnostic output*)
  val print_coercions: Proof.context -> unit
end;
wenzelm@40281
    14
wenzelm@40283
    15
structure Subtyping: SUBTYPING =
wenzelm@40281
    16
struct
wenzelm@40281
    17
wenzelm@40281
    18
(** coercions data **)
wenzelm@40281
    19
traytel@41353
    20
(*How an argument position of a type constructor behaves w.r.t. its map function:
  COVARIANT keeps a subtype constraint, CONTRAVARIANT swaps it, INVARIANT forces
  both sides to unify, INVARIANT_TO T forces both sides to unify with T.*)
datatype variance =
    COVARIANT
  | CONTRAVARIANT
  | INVARIANT
  | INVARIANT_TO of typ;

(*Per-argument policy of special constants: PERMIT or FORBID coercions of that
  argument, or LEAVE the inherited setting untouched.*)
datatype coerce_arg = PERMIT | FORBID | LEAVE;

datatype data = Data of
 {(*coercions table*)
  coes: (term * ((typ list * typ list) * term list)) Symreltab.table,
  (*full coercions graph - only used at coercion declaration/deletion*)
  full_graph: int Graph.T,
  (*coercions graph restricted to base types - for efficiency reasons stored in the context*)
  coes_graph: int Graph.T,
  (*map functions*)
  tmaps: (term * variance list) Symtab.table,
  (*special constants with non-coercible arguments*)
  coerce_args: coerce_arg list Symtab.table};

(*build the data record from a 5-tuple of its components*)
fun make_data (coercions, full, restricted, maps, arg_policies) =
  Data {coes = coercions, full_graph = full, coes_graph = restricted,
    tmaps = maps, coerce_args = arg_policies};
wenzelm@40281
    35
traytel@45935
    36
(*Error reporters for conflicting declarations found while merging context data;
  each one raises ERROR via error.*)
fun merge_error_coes (src, dst) =
  let
    val detail = quote src ^ " to " ^ quote dst ^ ".";
  in
    error ("Cannot merge coercion tables.\nConflicting declarations for coercions from " ^
      detail)
  end;

fun merge_error_tmaps name =
  error ("Cannot merge coercion map tables.\n" ^
    "Conflicting declarations for the constructor " ^ quote name ^ ".");

fun merge_error_coerce_args name =
  let
    val culprit = quote name;
  in
    error ("Cannot merge tables for constants with coercion-invariant arguments.\n" ^
      "Conflicting declarations for the constant " ^ culprit ^ ".")
  end;
traytel@51319
    47
wenzelm@40281
    48
(*Generic context data for coercive subtyping.  Merging fails with an error on
  conflicting entries in the coercion, map-function or coerce-argument tables;
  the graphs are merged with (op =) on their node labels.*)
structure Data = Generic_Data
(
  type T = data;
  val empty = make_data (Symreltab.empty, Graph.empty, Graph.empty, Symtab.empty, Symtab.empty);
  val extend = I;
  fun merge (Data d1, Data d2) =
    let
      (*equality of coercion entries: coercion term, argument type lists, map terms*)
      val eq_coe =
        eq_pair (op aconv)
          (eq_pair (eq_pair (eq_list (op =)) (eq_list (op =))) (eq_list (op aconv)));
      val coes = Symreltab.merge eq_coe (#coes d1, #coes d2)
        handle Symreltab.DUP key => merge_error_coes key;
      val full_graph = Graph.merge (op =) (#full_graph d1, #full_graph d2);
      val coes_graph = Graph.merge (op =) (#coes_graph d1, #coes_graph d2);
      val tmaps = Symtab.merge (eq_pair (op aconv) (op =)) (#tmaps d1, #tmaps d2)
        handle Symtab.DUP key => merge_error_tmaps key;
      val coerce_args = Symtab.merge (eq_list (op =)) (#coerce_args d1, #coerce_args d2)
        handle Symtab.DUP key => merge_error_coerce_args key;
    in
      make_data (coes, full_graph, coes_graph, tmaps, coerce_args)
    end;
);
wenzelm@40281
    68
wenzelm@40281
    69
(*Functional update of the stored data, viewed as the 5-tuple
  (coes, full_graph, coes_graph, tmaps, coerce_args).*)
fun map_data f =
  let
    fun upd (Data {coes, full_graph, coes_graph, tmaps, coerce_args}) =
      make_data (f (coes, full_graph, coes_graph, tmaps, coerce_args));
  in Data.map upd end;

(*update the coercion table and both graphs together*)
fun map_coes_and_graphs f =
  map_data (fn (c, fg, cg, tm, ca) =>
    (case f (c, fg, cg) of
      (c', fg', cg') => (c', fg', cg', tm, ca)));

fun map_tmaps f = map_data (fn (c, fg, cg, tm, ca) => (c, fg, cg, f tm, ca));

fun map_coerce_args f = map_data (fn (c, fg, cg, tm, ca) => (c, fg, cg, tm, f ca));

(*record of the data stored in a proof context*)
fun rep_data ctxt =
  let val Data args = Data.get (Context.Proof ctxt)
  in args end;

fun coes_of ctxt = #coes (rep_data ctxt);
fun coes_graph_of ctxt = #coes_graph (rep_data ctxt);
fun tmaps_of ctxt = #tmaps (rep_data ctxt);
fun coerce_args_of ctxt = #coerce_args (rep_data ctxt);
wenzelm@40281
    92
wenzelm@40281
    93
wenzelm@40281
    94
wenzelm@40281
    95
(** utils **)
wenzelm@40281
    96
wenzelm@46614
    97
(*keep only the nodes of G whose label is 0*)
fun restrict_graph G =
  let fun has_zero_label x = Graph.get_node G x = 0
  in Graph.restrict has_zero_label G end;

(*name of a nullary type constructor; raises Match on anything else*)
fun nameT T = (case T of Type (s, []) => s);
fun t_of name = Type (name, []);

fun sort_of T =
  (case T of
    TFree (_, S) => SOME S
  | TVar (_, S) => SOME S
  | Type _ => NONE);

fun is_typeT (Type _) = true
  | is_typeT _ = false;
(*type constructor applied to at least one argument*)
fun is_compT (Type (_, Ts)) = not (null Ts)
  | is_compT _ = false;
fun is_freeT (TFree _) = true
  | is_freeT _ = false;
(*schematic type variable that is not an inference parameter*)
fun is_fixedvarT (TVar (xi, _)) = not (Type_Infer.is_param xi)
  | is_fixedvarT _ = false;
fun is_funtype (Type ("fun", [_, _])) = true
  | is_funtype _ = false;

fun mk_identity argT = Abs (Name.uu, argT, Bound 0);
fun is_identity (Abs (_, _, Bound 0)) = true
  | is_identity _ = false;

(*instantiate the TVars of the type of t with Ts, pairing them with
  Term.add_tvar_namesT's result against rev Ts*)
fun instantiate t Ts =
  let
    val names = Term.add_tvar_namesT (fastype_of t) [];
    val inst = names ~~ rev Ts;
  in Term.subst_TVars inst t end;
traytel@45060
   117
wenzelm@55303
   118
(*carries a deferred error: a thunk yielding a headline and a buffer of details*)
exception COERCION_GEN_ERROR of unit -> string * Buffer.T;

infixr ++>  (*composition with deferred error message*)
(*err ++> s: deferred error whose detail buffer is extended by s*)
fun (err : unit -> string * Buffer.T) ++> s =
  fn () =>
    let val (headline, details) = err ()
    in (headline, Buffer.add s details) end;

(*force a deferred error: headline followed by the details in a text fold*)
fun eval_err err =
  (case err () of
    (headline, details) => headline ^ Markup.markup Markup.text_fold (Buffer.content details));

fun eval_error mk_err = error (eval_err mk_err);

(*pair every TVar of T with the corresponding part of U (dereferenced in tye);
  a constructor or atom mismatch raises the deferred error err*)
fun inst_collect tye err T U =
  (case (T, Type_Infer.deref tye U) of
    (TVar (xi, _), U') => [(xi, U')]
  | (Type (a, Ts), Type (b, Us)) =>
      if a = b then inst_collects tye err Ts Us else eval_error err
  | (_, U') => if T = U' then [] else eval_error err)
and inst_collects tye err Ts Us =
  flat (rev (map2 (inst_collect tye err) Ts Us));
traytel@45060
   138
wenzelm@40281
   139
traytel@40836
   140
(* unification *)
wenzelm@40281
   141
wenzelm@40281
   142
(*unification failure: message and the type environment at the point of failure*)
exception NO_UNIFIER of string * typ Vartab.table;

(*unify weak ctxt (T, U) (tye, idx): extend the environment tye, whose next
  fresh parameter index is idx, so that T and U become equal.  In weak mode
  sorts are not met on assignment and any two nullary constructors unify.*)
fun unify weak ctxt =
  let
    val thy = Proof_Context.theory_of ctxt;
    val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy);


    (* adjust sorts of parameters *)

    fun sort_mismatch name S_actual S_required =
      "Variable " ^ name ^ "::" ^ Syntax.string_of_sort ctxt S_actual ^ " not of sort " ^
        Syntax.string_of_sort ctxt S_required;

    fun meet (_, []) env = env
      | meet (Type (a, Ts), S) (env as (tye, _)) =
          let
            val arg_sorts = arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye);
          in meets (Ts, arg_sorts) env end
      | meet (TFree (x, S'), S) (env as (tye, _)) =
          if Sign.subsort thy (S', S) then env
          else raise NO_UNIFIER (sort_mismatch x S' S, tye)
      | meet (TVar (xi, S'), S) (env as (tye, idx)) =
          if Sign.subsort thy (S', S) then env
          else if Type_Infer.is_param xi then
            let val refined = Type_Infer.mk_param idx (Sign.inter_sort thy (S', S))
            in (Vartab.update_new (xi, refined) tye, idx + 1) end
          else raise NO_UNIFIER (sort_mismatch (Term.string_of_vname xi) S' S, tye)
    and meets (T :: Ts, S :: Ss) (env as (tye, _)) =
          meets (Ts, Ss) (meet (Type_Infer.deref tye T, S) env)
      | meets _ env = env;

    fun weak_meet TS env = if weak then env else meet TS env;


    (* occurs check and assignment *)

    fun occurs_check tye xi T =
      (case T of
        TVar (xi', _) =>
          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
          else
            (case Vartab.lookup tye xi' of
              SOME T' => occurs_check tye xi T'
            | NONE => ())
      | Type (_, Ts) => List.app (occurs_check tye xi) Ts
      | TFree _ => ());

    (*bind parameter xi (of sort S) to T; the occurs check is skipped for variables*)
    fun assign xi T S (env as (tye, _)) =
      let
        fun bind () = env |> weak_meet (T, S) |>> Vartab.update_new (xi, T);
      in
        (case T of
          TVar (xi', _) => if xi = xi' then env else bind ()
        | _ => (occurs_check tye xi T; bind ()))
      end;


    (* unification *)

    fun show_tycon (a, Ts) =
      quote (Syntax.string_of_typ ctxt (Type (a, map (K dummyT) Ts)));

    fun unif (T1, T2) (env as (tye, _)) =
      let
        (*dereference and tag with "is an inference parameter"*)
        fun inspect T =
          let val T' = Type_Infer.deref tye T
          in (Type_Infer.is_paramT T', T') end;
      in
        (case (inspect T1, inspect T2) of
          ((true, TVar (xi, S)), (_, T)) => assign xi T S env
        | ((_, T), (true, TVar (xi, S))) => assign xi T S env
        | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
            if weak andalso null Ts andalso null Us then env
            else if a = b then fold unif (Ts ~~ Us) env
            else
              raise NO_UNIFIER
                ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
        | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye))
      end;

  in unif end;

fun weak_unify ctxt = unify true ctxt;
fun strong_unify ctxt = unify false ctxt;
wenzelm@40281
   214
wenzelm@40281
   215
wenzelm@40281
   216
(* Typ_Graph shortcuts *)
wenzelm@40281
   217
wenzelm@40281
   218
fun get_preds G T = Typ_Graph.all_preds G (single T);
fun get_succs G T = Typ_Graph.all_succs G (single T);

(*add T as a node with unit label; G is returned unchanged if Typ_Graph.new_node
  raises (presumably because T is already present)*)
fun maybe_new_typnode T G = the_default G (try (Typ_Graph.new_node (T, ())) G);
fun maybe_new_typnodes Ts G = G |> fold maybe_new_typnode Ts;

local
  (*immediate neighbours (via next) of the nodes Ts, excluding Ts themselves*)
  fun new_neighbours next G Ts =  (* FIXME inefficient *)
    subtract (op =) Ts (distinct (op =) (maps (next G) Ts));
in
  fun new_imm_preds G Ts = new_neighbours Typ_Graph.immediate_preds G Ts;
  fun new_imm_succs G Ts = new_neighbours Typ_Graph.immediate_succs G Ts;
end;


(* Graph shortcuts *)

fun maybe_new_node s G = the_default G (try (Graph.new_node s) G);
fun maybe_new_nodes ss G = G |> fold maybe_new_node ss;
wenzelm@40281
   232
wenzelm@40281
   233
wenzelm@40281
   234
wenzelm@40281
   235
(** error messages **)
wenzelm@40281
   236
traytel@54584
   237
(*extend the deferred error err with the notice that global coercion inference
  failed, followed by msg when it is non-empty*)
fun gen_err err msg =
  let
    val reason = if msg = "" then "" else ":\n" ^ msg;
  in
    err ++> ("\nNow trying to infer coercions globally.\n\nCoercion inference failed" ^
      reason ^ "\n")
  end;

(*gen_err, immediately rendered to a string*)
fun gen_msg err msg = eval_err (gen_err err msg);
traytel@40836
   242
wenzelm@40281
   243
(*Finish the types Ts, the types of the bound variables bs and the terms ts for
  printing.  Loose bounds in each term are replaced by marked, renamed free
  variables.  Returns (prepared terms, finished Ts).*)
fun prep_output ctxt tye bs ts Ts =
  let
    val n = length Ts;
    val (finished_Ts, finished_ts) = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts);
    val (Ts', bound_Ts) = chop n finished_Ts;
    val bound_names = map fst bs;
    fun prep t =
      Term.subst_bounds
        (map Syntax_Trans.mark_bound_abs
          (rev (Term.variant_frees t (rev (bound_names ~~ bound_Ts)))), t);
  in (map prep finished_ts, Ts') end;

fun err_loose idx = error ("Loose bound variable: B." ^ string_of_int idx);
wenzelm@42383
   253
traytel@40836
   254
(*headline for unification failures; the reason is appended only if non-empty*)
fun unif_failed "" = "Type unification failed\n\n"
  | unif_failed msg = "Type unification failed: " ^ msg ^ "\n\n";
wenzelm@42383
   256
traytel@40836
   257
(*Deferred error for an ill-typed application t u (T: type of t, U: type of u):
  the headline combines unif_failed and Type.appl_error; the detail buffer starts
  with "Coercion Inference:".*)
fun err_appl_msg ctxt msg tye bs t T u U () =
  let
    val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U];
    val headline = unif_failed msg ^ Type.appl_error ctxt t' T' u' U' ^ "\n\n";
    val details = Buffer.add "Coercion Inference:\n\n" Buffer.empty;
  in (headline, details) end;

(*raise err, extended with the (finished) types Ts that could not be unified*)
fun err_list ctxt err tye Ts =
  let
    val (_, Ts') = prep_output ctxt tye [] [] Ts;
    val shown = Pretty.string_of (Pretty.list "[" "]" (map (Syntax.pretty_typ ctxt) Ts'));
  in
    eval_error (err ++> ("\nCannot unify a list of types that should be the same:\n" ^ shown))
  end;

(*raise err, extended with the subtype constraints recorded in the error packs
  (bs, t $ u, U, V, U'), each shown as U' <: U together with the application*)
fun err_bound ctxt err tye packs =
  let
    fun prepare (bs, t $ u, U, _, U') (ts, Ts) =
      let val (t', T') = prep_output ctxt tye bs [t, u] [U', U]
      in (t' :: ts, T' :: Ts) end;
    val (ts, Ts) = fold prepare packs ([], []);
    fun pretty_constraint [t, u] [T, U] =
      Pretty.block [
        Syntax.pretty_typ ctxt T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2,
        Syntax.pretty_typ ctxt U, Pretty.brk 3,
        Pretty.str "from function application", Pretty.brk 2,
        Pretty.block [Syntax.pretty_term ctxt (t $ u)]];
    val msg =
      Pretty.string_of
        (Pretty.big_list "Cannot fulfil subtype constraints:" (map2 pretty_constraint ts Ts));
  in eval_error (err ++> ("\n" ^ msg)) end;
wenzelm@40281
   287
wenzelm@40281
   288
wenzelm@40281
   289
wenzelm@40281
   290
(** constraint generation **)
wenzelm@40281
   291
traytel@51319
   292
(*Decide whether the argument about to be applied to t may be coerced.  The head
  of t (type constraints stripped) is looked up in the coerce-argument table; the
  policy at position n, where n is the number of arguments t already has, decides.
  LEAVE, a missing entry or a non-constant head keep the inherited setting old.*)
fun update_coerce_arg ctxt old t =
  let
    val (head, args) = Term.strip_comb (Type.strip_constraints t);
    val policies =
      (case head of
        Const (name, _) => the_default [] (Symtab.lookup (coerce_args_of ctxt) name)
      | _ => []);
  in
    (case drop (length args) policies of
      [] => old
    | LEAVE :: _ => old
    | PERMIT :: _ => true
    | FORBID :: _ => false)
  end;
traytel@51319
   302
traytel@40836
   303
(*Infer a type for a term and collect subtype constraints.  For an application
  t $ u whose argument may be coerced, the pair (type of u, domain of t) is
  recorded together with an error pack (bs, t $ u, U, V, U').  Otherwise the two
  are unified strongly.  The result of gen is (type, (tye, idx), constraints).*)
fun generate_constraints ctxt err =
  let
    fun unify_or_fail TU tye_idx =
      strong_unify ctxt TU tye_idx
        handle NO_UNIFIER (msg, _) => error (gen_msg err msg);

    fun gen coerce cs bs tm tye_idx =
      (case tm of
        Const (_, T) => (T, tye_idx, cs)
      | Free (_, T) => (T, tye_idx, cs)
      | Var (_, T) => (T, tye_idx, cs)
      | Bound i => (snd (nth bs i handle General.Subscript => err_loose i), tye_idx, cs)
      | Abs (x, T, body) =>
          let val (U, tye_idx', cs') = gen coerce cs ((x, T) :: bs) body tye_idx
          in (T --> U, tye_idx', cs') end
      | t $ u =>
          let
            val (T, tye_idx', cs') = gen coerce cs bs t tye_idx;
            val coerce' = update_coerce_arg ctxt coerce t;
            val (U', (tye, idx), cs'') = gen coerce' cs' bs u tye_idx';
            (*fresh parameters for the domain and range of t*)
            val U = Type_Infer.mk_param idx [];
            val V = Type_Infer.mk_param (idx + 1) [];
            val tye_idx'' = unify_or_fail (U --> V, T) (tye, idx + 2);
            val error_pack = (bs, t $ u, U, V, U');
          in
            if coerce' then (V, tye_idx'', ((U', U), error_pack) :: cs'')
            else (V, unify_or_fail (U, U') tye_idx'', cs'')
          end);
  in
    gen true [] []
  end;
wenzelm@40281
   334
wenzelm@40281
   335
wenzelm@40281
   336
wenzelm@40281
   337
(** constraint resolution **)
wenzelm@40281
   338
wenzelm@40281
   339
(*raised by the bound computations of constraint resolution, e.g. when no
  supremum/infimum exists or when types in a list are incomparable; the string
  describes the reason*)
exception BOUND_ERROR of string;
wenzelm@40281
   340
traytel@40836
   341
fun process_constraints ctxt err cs tye_idx =
wenzelm@40281
   342
  let
wenzelm@42388
   343
    val thy = Proof_Context.theory_of ctxt;
wenzelm@42388
   344
wenzelm@40285
   345
    val coes_graph = coes_graph_of ctxt;
wenzelm@40285
   346
    val tmaps = tmaps_of ctxt;
wenzelm@42388
   347
    val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy);
wenzelm@40281
   348
wenzelm@40281
   349
    fun split_cs _ [] = ([], [])
wenzelm@40282
   350
      | split_cs f (c :: cs) =
wenzelm@59058
   351
          (case apply2 f (fst c) of
wenzelm@40281
   352
            (false, false) => apsnd (cons c) (split_cs f cs)
wenzelm@40281
   353
          | _ => apfst (cons c) (split_cs f cs));
wenzelm@42383
   354
traytel@41353
   355
    fun unify_list (T :: Ts) tye_idx =
wenzelm@42383
   356
      fold (fn U => fn tye_idx' => strong_unify ctxt (T, U) tye_idx') Ts tye_idx;
wenzelm@40281
   357
wenzelm@40282
   358
wenzelm@40281
   359
    (* check whether constraint simplification will terminate using weak unification *)
wenzelm@40282
   360
traytel@41353
   361
    val _ = fold (fn (TU, _) => fn tye_idx =>
traytel@41353
   362
      weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, _) =>
traytel@40836
   363
        error (gen_msg err ("weak unification of subtype constraints fails\n" ^ msg))) cs tye_idx;
wenzelm@40281
   364
wenzelm@40281
   365
wenzelm@40281
   366
    (* simplify constraints *)
wenzelm@40282
   367
wenzelm@40281
   368
    fun simplify_constraints cs tye_idx =
wenzelm@40281
   369
      let
wenzelm@40281
   370
        fun contract a Ts Us error_pack done todo tye idx =
wenzelm@40281
   371
          let
wenzelm@40281
   372
            val arg_var =
wenzelm@40281
   373
              (case Symtab.lookup tmaps a of
wenzelm@40281
   374
                (*everything is invariant for unknown constructors*)
wenzelm@40281
   375
                NONE => replicate (length Ts) INVARIANT
wenzelm@40281
   376
              | SOME av => snd av);
wenzelm@40281
   377
            fun new_constraints (variance, constraint) (cs, tye_idx) =
wenzelm@40281
   378
              (case variance of
wenzelm@40281
   379
                COVARIANT => (constraint :: cs, tye_idx)
wenzelm@40281
   380
              | CONTRAVARIANT => (swap constraint :: cs, tye_idx)
traytel@41353
   381
              | INVARIANT_TO T => (cs, unify_list [T, fst constraint, snd constraint] tye_idx
wenzelm@42383
   382
                  handle NO_UNIFIER (msg, _) =>
traytel@54584
   383
                    err_list ctxt (gen_err err
traytel@54584
   384
                      ("failed to unify invariant arguments w.r.t. to the known map function\n" ^
traytel@54584
   385
                        msg))
traytel@45060
   386
                      (fst tye_idx) (T :: Ts))
wenzelm@40281
   387
              | INVARIANT => (cs, strong_unify ctxt constraint tye_idx
wenzelm@42383
   388
                  handle NO_UNIFIER (msg, _) =>
traytel@51248
   389
                    error (gen_msg err ("failed to unify invariant arguments\n" ^ msg))));
wenzelm@40281
   390
            val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
wenzelm@40281
   391
              (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
traytel@49142
   392
            val test_update = is_typeT orf is_freeT orf is_fixedvarT;
wenzelm@40281
   393
            val (ch, done') =
traytel@51246
   394
              done
wenzelm@59058
   395
              |> map (apfst (apply2 (Type_Infer.deref tye')))
traytel@51246
   396
              |> (if not (null new) then rpair []  else split_cs test_update);
wenzelm@40281
   397
            val todo' = ch @ todo;
wenzelm@40281
   398
          in
wenzelm@40281
   399
            simplify done' (new @ todo') (tye', idx')
wenzelm@40281
   400
          end
wenzelm@40281
   401
        (*xi is definitely a parameter*)
wenzelm@40281
   402
        and expand varleq xi S a Ts error_pack done todo tye idx =
wenzelm@40281
   403
          let
wenzelm@40281
   404
            val n = length Ts;
wenzelm@40286
   405
            val args = map2 Type_Infer.mk_param (idx upto idx + n - 1) (arity_sorts a S);
wenzelm@40281
   406
            val tye' = Vartab.update_new (xi, Type(a, args)) tye;
wenzelm@40286
   407
            val (ch, done') = split_cs (is_compT o Type_Infer.deref tye') done;
wenzelm@40281
   408
            val todo' = ch @ todo;
wenzelm@40281
   409
            val new =
wenzelm@40281
   410
              if varleq then (Type(a, args), Type (a, Ts))
wenzelm@40286
   411
              else (Type (a, Ts), Type (a, args));
wenzelm@40281
   412
          in
wenzelm@40281
   413
            simplify done' ((new, error_pack) :: todo') (tye', idx + n)
wenzelm@40281
   414
          end
wenzelm@40281
   415
        (*TU is a pair of a parameter and a free/fixed variable*)
traytel@41353
   416
        and eliminate TU done todo tye idx =
wenzelm@40281
   417
          let
wenzelm@40286
   418
            val [TVar (xi, S)] = filter Type_Infer.is_paramT TU;
wenzelm@40286
   419
            val [T] = filter_out Type_Infer.is_paramT TU;
wenzelm@40281
   420
            val SOME S' = sort_of T;
wenzelm@40281
   421
            val test_update = if is_freeT T then is_freeT else is_fixedvarT;
wenzelm@40281
   422
            val tye' = Vartab.update_new (xi, T) tye;
wenzelm@40286
   423
            val (ch, done') = split_cs (test_update o Type_Infer.deref tye') done;
wenzelm@40281
   424
            val todo' = ch @ todo;
wenzelm@40281
   425
          in
wenzelm@42388
   426
            if Sign.subsort thy (S', S) (*TODO check this*)
wenzelm@40281
   427
            then simplify done' todo' (tye', idx)
traytel@40836
   428
            else error (gen_msg err "sort mismatch")
wenzelm@40281
   429
          end
wenzelm@40281
   430
        and simplify done [] tye_idx = (done, tye_idx)
wenzelm@40281
   431
          | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
wenzelm@40286
   432
              (case (Type_Infer.deref tye T, Type_Infer.deref tye U) of
traytel@45060
   433
                (T1 as Type (a, []), T2 as Type (b, [])) =>
wenzelm@40281
   434
                  if a = b then simplify done todo tye_idx
wenzelm@40281
   435
                  else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
wenzelm@55303
   436
                  else
wenzelm@55303
   437
                    error (gen_msg err (quote (Syntax.string_of_typ ctxt T1) ^
wenzelm@55303
   438
                      " is not a subtype of " ^ quote (Syntax.string_of_typ ctxt T2)))
wenzelm@40281
   439
              | (Type (a, Ts), Type (b, Us)) =>
wenzelm@55303
   440
                  if a <> b then
wenzelm@55303
   441
                    error (gen_msg err "different constructors") (fst tye_idx) error_pack
wenzelm@40281
   442
                  else contract a Ts Us error_pack done todo tye idx
wenzelm@40282
   443
              | (TVar (xi, S), Type (a, Ts as (_ :: _))) =>
wenzelm@40281
   444
                  expand true xi S a Ts error_pack done todo tye idx
wenzelm@40282
   445
              | (Type (a, Ts as (_ :: _)), TVar (xi, S)) =>
wenzelm@40281
   446
                  expand false xi S a Ts error_pack done todo tye idx
wenzelm@40281
   447
              | (T, U) =>
wenzelm@40281
   448
                  if T = U then simplify done todo tye_idx
wenzelm@40282
   449
                  else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
wenzelm@40286
   450
                    exists Type_Infer.is_paramT [T, U]
traytel@41353
   451
                  then eliminate [T, U] done todo tye idx
wenzelm@40281
   452
                  else if exists (is_freeT orf is_fixedvarT) [T, U]
traytel@40836
   453
                  then error (gen_msg err "not eliminated free/fixed variables")
wenzelm@40282
   454
                  else simplify (((T, U), error_pack) :: done) todo tye_idx);
wenzelm@40281
   455
      in
wenzelm@40281
   456
        simplify [] cs tye_idx
wenzelm@40281
   457
      end;
wenzelm@40281
   458
wenzelm@40281
   459
wenzelm@40281
   460
    (* do simplification *)
wenzelm@40282
   461
wenzelm@40281
   462
    val (cs', tye_idx') = simplify_constraints cs tye_idx;
wenzelm@42383
   463
wenzelm@42383
   464
    fun find_error_pack lower T' = map_filter
traytel@40836
   465
      (fn ((T, U), pack) => if if lower then T' = U else T' = T then SOME pack else NONE) cs';
wenzelm@42383
   466
wenzelm@42383
   467
    fun find_cycle_packs nodes =
traytel@40836
   468
      let
traytel@40836
   469
        val (but_last, last) = split_last nodes
traytel@40836
   470
        val pairs = (last, hd nodes) :: (but_last ~~ tl nodes);
traytel@40836
   471
      in
traytel@40836
   472
        map_filter
wenzelm@40838
   473
          (fn (TU, pack) => if member (op =) pairs TU then SOME pack else NONE)
traytel@40836
   474
          cs'
traytel@40836
   475
      end;
wenzelm@40281
   476
wenzelm@40281
   477
    (*styps stands either for supertypes or for subtypes of a type T
wenzelm@40281
   478
      in terms of the subtype-relation (excluding T itself)*)
wenzelm@40282
   479
    fun styps super T =
wenzelm@44338
   480
      (if super then Graph.immediate_succs else Graph.immediate_preds) coes_graph T
wenzelm@40281
   481
        handle Graph.UNDEF _ => [];
wenzelm@40281
   482
wenzelm@40282
   483
    fun minmax sup (T :: Ts) =
wenzelm@40281
   484
      let
wenzelm@40281
   485
        fun adjust T U = if sup then (T, U) else (U, T);
wenzelm@40281
   486
        fun extract T [] = T
wenzelm@40282
   487
          | extract T (U :: Us) =
wenzelm@40281
   488
              if Graph.is_edge coes_graph (adjust T U) then extract T Us
wenzelm@40281
   489
              else if Graph.is_edge coes_graph (adjust U T) then extract U Us
traytel@40836
   490
              else raise BOUND_ERROR "uncomparable types in type list";
wenzelm@40281
   491
      in
wenzelm@40281
   492
        t_of (extract T Ts)
wenzelm@40281
   493
      end;
wenzelm@40281
   494
wenzelm@40282
   495
    fun ex_styp_of_sort super T styps_and_sorts =
wenzelm@40281
   496
      let
wenzelm@40281
   497
        fun adjust T U = if super then (T, U) else (U, T);
wenzelm@40282
   498
        fun styp_test U Ts = forall
wenzelm@40281
   499
          (fn T => T = U orelse Graph.is_edge coes_graph (adjust U T)) Ts;
wenzelm@55301
   500
        fun fitting Ts S U = Sign.of_sort thy (t_of U, S) andalso styp_test U Ts;
wenzelm@40281
   501
      in
wenzelm@40281
   502
        forall (fn (Ts, S) => exists (fitting Ts S) (T :: styps super T)) styps_and_sorts
wenzelm@40281
   503
      end;
wenzelm@40281
   504
wenzelm@40281
   505
    (* computes the tightest possible, correct assignment for 'a::S
wenzelm@40281
   506
       e.g. in the supremum case (sup = true):
wenzelm@40281
   507
               ------- 'a::S---
wenzelm@40281
   508
              /        /    \  \
wenzelm@40281
   509
             /        /      \  \
wenzelm@40281
   510
        'b::C1   'c::C2 ...  T1 T2 ...
wenzelm@40281
   511
wenzelm@40281
   512
       sorts - list of sorts [C1, C2, ...]
wenzelm@40281
   513
       T::Ts - non-empty list of base types [T1, T2, ...]
wenzelm@40281
   514
    *)
wenzelm@40282
   515
    fun tightest sup S styps_and_sorts (T :: Ts) =
wenzelm@40281
   516
      let
wenzelm@42388
   517
        fun restriction T = Sign.of_sort thy (t_of T, S)
wenzelm@40281
   518
          andalso ex_styp_of_sort (not sup) T styps_and_sorts;
wenzelm@40281
   519
        fun candidates T = inter (op =) (filter restriction (T :: styps sup T));
wenzelm@40281
   520
      in
wenzelm@40281
   521
        (case fold candidates Ts (filter restriction (T :: styps sup T)) of
traytel@40836
   522
          [] => raise BOUND_ERROR ("no " ^ (if sup then "supremum" else "infimum"))
wenzelm@40281
   523
        | [T] => t_of T
wenzelm@40281
   524
        | Ts => minmax sup Ts)
wenzelm@40281
   525
      end;
wenzelm@40281
   526
wenzelm@40281
   527
    (*build_graph G cs tye_idx: add each constraint (T, U) of cs as an edge
      T -> U of the type graph G, threading the type environment tye_idx.
      Constraints with T = U are skipped; missing nodes are created with
      maybe_new_typnodes. If the new edge would close cycles
      (Typ_Graph.CYCLES), the nodes of every cycle are unified with
      unify_list (a failure is reported via err_bound together with
      find_cycle_packs) and the cycles are merged by collapse.
      Returns the final graph and type environment.*)
    fun build_graph G [] tye_idx = (G, tye_idx)
wenzelm@40282
   528
      | build_graph G ((T, U) :: cs) tye_idx =
wenzelm@40281
   529
        if T = U then build_graph G cs tye_idx
wenzelm@40281
   530
        else
wenzelm@40281
   531
          let
wenzelm@40281
   532
            val G' = maybe_new_typnodes [T, U] G;
traytel@45059
   533
            val (G'', tye_idx') = (Typ_Graph.add_edge_acyclic (T, U) G', tye_idx)
wenzelm@40281
   534
              handle Typ_Graph.CYCLES cycles =>
wenzelm@40281
   535
                let
wenzelm@42383
   536
                  (*unify the nodes of each cycle in turn, threading the environment*)
                  val (tye, idx) =
wenzelm@42383
   537
                    fold
traytel@40836
   538
                      (fn cycle => fn tye_idx' => (unify_list cycle tye_idx'
wenzelm@42383
   539
                        handle NO_UNIFIER (msg, _) =>
wenzelm@42383
   540
                          err_bound ctxt
traytel@54584
   541
                            (gen_err err ("constraint cycle not unifiable\n" ^ msg)) (fst tye_idx)
traytel@40836
   542
                            (find_cycle_packs cycle)))
traytel@40836
   543
                      cycles tye_idx
wenzelm@40281
   544
                in
traytel@40836
   545
                  collapse (tye, idx) cycles G
traytel@40836
   546
                end
wenzelm@40281
   547
          in
wenzelm@40281
   548
            build_graph G'' cs tye_idx'
wenzelm@40281
   549
          end
traytel@40836
   550
    (*collapse (tye, idx) cycles G: merge all nodes of the (already unified)
      cycles into the first node, hd nodes. The remaining cycle nodes are
      deleted from G; edges from their immediate predecessors P and to their
      immediate successors S (new_imm_preds / new_imm_succs) are re-added
      as fresh constraints through build_graph.*)
    and collapse (tye, idx) cycles G = (*nodes non-empty list*)
wenzelm@40281
   551
      let
traytel@40836
   552
        (*all cycles collapse to one node,
traytel@40836
   553
          because all of them share at least the nodes x and y*)
traytel@40836
   554
        val nodes = (distinct (op =) (flat cycles));
traytel@40836
   555
        val T = Type_Infer.deref tye (hd nodes);
wenzelm@40281
   556
        val P = new_imm_preds G nodes;
wenzelm@40281
   557
        val S = new_imm_succs G nodes;
wenzelm@46665
   558
        val G' = fold Typ_Graph.del_node (tl nodes) G;
traytel@40836
   559
        (*Re-route one neighbour edge to hd nodes: a predecessor T' yields
          (T', hd nodes) when super = false, a successor yields
          (hd nodes, T') when super = true. If the merged node and the
          neighbour both dereference to base types (is_typeT) and differ, the
          corresponding edge must already exist in coes_graph; otherwise the
          collapse is reported as inconsistent.*)
        fun check_and_gen super T' =
traytel@40836
   560
          let val U = Type_Infer.deref tye T';
traytel@40836
   561
          in
traytel@40836
   562
            if not (is_typeT T) orelse not (is_typeT U) orelse T = U
traytel@40836
   563
            then if super then (hd nodes, T') else (T', hd nodes)
wenzelm@42383
   564
            else
wenzelm@42383
   565
              if super andalso
traytel@40836
   566
                Graph.is_edge coes_graph (nameT T, nameT U) then (hd nodes, T')
wenzelm@42383
   567
              else if not super andalso
traytel@40836
   568
                Graph.is_edge coes_graph (nameT U, nameT T) then (T', hd nodes)
wenzelm@55303
   569
              else
wenzelm@55303
   570
                (*NOTE(review): tye_idx is not a parameter of collapse; it refers
                  to a binding from an enclosing scope, so the reported
                  environment may differ from the current (tye, idx) -- confirm
                  this is intended.*)
                err_bound ctxt (gen_err err "cycle elimination produces inconsistent graph")
wenzelm@55303
   571
                  (fst tye_idx)
wenzelm@55303
   572
                  (maps find_cycle_packs cycles @ find_error_pack super T')
traytel@40836
   573
          end;
wenzelm@40281
   574
      in
traytel@40836
   575
        build_graph G' (map (check_and_gen false) P @ map (check_and_gen true) S) (tye, idx)
wenzelm@40281
   576
      end;
wenzelm@40281
   577
wenzelm@40281
   578
    (*assign_bound lower G key tye_idx: if key still dereferences to a type
      parameter ?'a::S, try to assign it a base type computed from its bounds
      in G -- its predecessors when lower = true, its successors otherwise.
      The non-parameter bounds are passed (by name) to tightest with
      sup = lower. Bounds for which sort_of returns SOME S' contribute
      requirements (styps_and_sorts) built from their own non-parameter
      bounds in the same direction. Nothing is assigned if there are no
      bounds, only parameter bounds, or if the result is itself a parameter.
      An assignment from lower bounds is additionally checked against the
      successors of key: every base-type successor must be named s or be in
      styps true s (s = name of the assigned type); otherwise err_bound
      reports the clash.*)
    fun assign_bound lower G key (tye_idx as (tye, _)) =
wenzelm@40286
   579
      if Type_Infer.is_paramT (Type_Infer.deref tye key) then
wenzelm@40281
   580
        let
wenzelm@40286
   581
          (*cannot fail: guarded by the is_paramT test above*)
          val TVar (xi, S) = Type_Infer.deref tye key;
wenzelm@40281
   582
          val get_bound = if lower then get_preds else get_succs;
wenzelm@40281
   583
          val raw_bound = get_bound G key;
wenzelm@40286
   584
          val bound = map (Type_Infer.deref tye) raw_bound;
wenzelm@40286
   585
          val not_params = filter_out Type_Infer.is_paramT bound;
wenzelm@40282
   586
          (*requirement contributed by a sorted bound T: the names of T's own
            non-parameter bounds (same direction) together with T's sort*)
          fun to_fulfil T =
wenzelm@40281
   587
            (case sort_of T of
wenzelm@40281
   588
              NONE => NONE
wenzelm@40282
   589
            | SOME S =>
wenzelm@40286
   590
                SOME
wenzelm@40286
   591
                  (map nameT
wenzelm@42405
   592
                    (filter_out Type_Infer.is_paramT
wenzelm@42405
   593
                      (map (Type_Infer.deref tye) (get_bound G T))), S));
wenzelm@40281
   594
          val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
wenzelm@40281
   595
          val assignment =
wenzelm@40281
   596
            if null bound orelse null not_params then NONE
wenzelm@40281
   597
            else SOME (tightest lower S styps_and_sorts (map nameT not_params)
traytel@54584
   598
                handle BOUND_ERROR msg => err_bound ctxt (gen_err err msg) tye
wenzelm@55301
   599
                  (maps (find_error_pack (not lower)) raw_bound));
wenzelm@40281
   600
        in
wenzelm@40281
   601
          (case assignment of
wenzelm@40281
   602
            NONE => tye_idx
wenzelm@40281
   603
          | SOME T =>
wenzelm@40286
   604
              if Type_Infer.is_paramT T then tye_idx
wenzelm@40281
   605
              else if lower then (*upper bound check*)
wenzelm@40281
   606
                let
wenzelm@40286
   607
                  val other_bound = map (Type_Infer.deref tye) (get_succs G key);
wenzelm@40281
   608
                  val s = nameT T;
wenzelm@40281
   609
                in
wenzelm@40281
   610
                  (*every base-type upper bound must be named s or occur in styps true s*)
                  if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
wenzelm@40281
   611
                  then apfst (Vartab.update (xi, T)) tye_idx
wenzelm@55303
   612
                  else
wenzelm@55303
   613
                    err_bound ctxt
wenzelm@55303
   614
                      (gen_err err
wenzelm@55303
   615
                        (Pretty.string_of (Pretty.block
wenzelm@55303
   616
                          [Pretty.str "assigned base type", Pretty.brk 1,
wenzelm@55303
   617
                            Pretty.quote (Syntax.pretty_typ ctxt T), Pretty.brk 1,
wenzelm@55303
   618
                            Pretty.str "clashes with the upper bound of variable", Pretty.brk 1,
wenzelm@55303
   619
                            Syntax.pretty_typ ctxt (TVar (xi, S))])))
wenzelm@55303
   620
                      tye
wenzelm@55303
   621
                      (maps (find_error_pack lower) other_bound)
wenzelm@40281
   622
                end
wenzelm@40281
   623
              else apfst (Vartab.update (xi, T)) tye_idx)
wenzelm@40281
   624
        end
wenzelm@40281
   625
      else tye_idx;
wenzelm@40281
   626
wenzelm@40281
   627
    (*Assignment from lower bounds (predecessors) resp. upper bounds (successors).*)
    val assign_lb = assign_bound true;
    val assign_ub = assign_bound false;

    (*Fixpoint iteration: assign all nodes of ts from their lower bounds, then
      from their upper bounds, and repeat with those nodes of ts that still
      dereference to parameters. ts' is the node list of the previous round;
      the iteration stops as soon as a round starts with the same list.*)
    fun assign_alternating ts' ts G tye_idx =
      if ts' = ts then tye_idx
      else
        let
          val tye_idx' =
            tye_idx
            |> fold (assign_lb G) ts
            |> fold (assign_ub G) ts;
          val still_params =
            filter (Type_Infer.is_paramT o Type_Infer.deref (fst tye_idx')) ts;
        in assign_alternating ts still_params G tye_idx' end;
wenzelm@40281
   640
wenzelm@40281
   641
    (*Unify every weakly connected component of the constraint forest that
      consists only of parameters (these are the only components that still
      contain parameters at this point). A component is a maximal node of G
      whose type dereferences to a parameter, together with get_preds G of
      that node. Unification failures are reported via err_list against the
      environment as it was on entry.*)
    fun unify_params G tye_idx =
      let
        val tye = fst tye_idx;
        fun is_param_node T = Type_Infer.is_paramT (Type_Infer.deref tye T);
        fun component T = T :: get_preds G T;
        fun unify_component Ts acc =
          unify_list Ts acc
            handle NO_UNIFIER (msg, _) => err_list ctxt (gen_err err msg) tye Ts;
        val components = map component (filter is_param_node (Typ_Graph.maximals G));
      in
        fold unify_component components tye_idx
      end;
wenzelm@40281
   655
wenzelm@40281
   656
    (*Solve the constraint graph G: first assign base types to parameters by
      alternating lower/upper bound assignment until nothing changes, then
      unify the remaining all-parameter components.*)
    fun solve_constraints G tye_idx =
      let val assigned = assign_alternating [] (Typ_Graph.keys G) G tye_idx
      in unify_params G assigned end;
wenzelm@40281
   659
  in
wenzelm@40281
   660
    build_graph Typ_Graph.empty (map fst cs') tye_idx'
wenzelm@40281
   661
      |-> solve_constraints
wenzelm@40281
   662
  end;
wenzelm@40281
   663
wenzelm@40281
   664
wenzelm@40281
   665
wenzelm@40281
   666
(** coercion insertion **)
wenzelm@40281
   667
traytel@45060
   668
(*gen_coercion ctxt err tye (T1, T2): build a coercion term from T1 to T2
  under the type environment tye (both types are dereferenced first).

  - nullary constructors a, b: the identity if a = b, otherwise the coercion
    registered for (a, b) in coes_of ctxt; COERCION_GEN_ERROR if none;
  - applications of different constructors a, b: the registered complex
    coercion for (a, b), instantiated via inst_collect and composed with a
    coercion into its expected argument type and one out of its result type;
    eval_error if none is registered;
  - applications of the same constructor: the argument coercions lifted
    through the map function from tmaps_of ctxt according to the recorded
    variances; the identity if no map function is known but the types could
    unify, COERCION_GEN_ERROR otherwise;
  - anything else: the identity if the types could unify, otherwise
    COERCION_GEN_ERROR.

  err is extended with ++> to form the error message before raising.*)
fun gen_coercion ctxt err tye TU =
traytel@45060
   669
  let
wenzelm@59058
   670
    fun gen (T1, T2) =
wenzelm@59058
   671
      (case apply2 (Type_Infer.deref tye) (T1, T2) of
traytel@45060
   672
        (T1 as (Type (a, [])), T2 as (Type (b, []))) =>
traytel@45060
   673
            if a = b
traytel@51335
   674
            then mk_identity T1
traytel@45060
   675
            else
traytel@45060
   676
              (case Symreltab.lookup (coes_of ctxt) (a, b) of
wenzelm@55303
   677
                NONE =>
wenzelm@55303
   678
                  raise COERCION_GEN_ERROR (err ++>
wenzelm@55303
   679
                    (Pretty.string_of o Pretty.block)
wenzelm@55303
   680
                      [Pretty.quote (Syntax.pretty_typ ctxt T1), Pretty.brk 1,
wenzelm@55303
   681
                        Pretty.str "is not a subtype of", Pretty.brk 1,
wenzelm@55303
   682
                        Pretty.quote (Syntax.pretty_typ ctxt T2)])
traytel@45060
   683
              | SOME (co, _) => co)
traytel@45102
   684
      (*type constructor applications: complex coercion or map function*)
      | (T1 as Type (a, Ts), T2 as Type (b, Us)) =>
traytel@45060
   685
            if a <> b
traytel@45060
   686
            then
traytel@45060
   687
              (case Symreltab.lookup (coes_of ctxt) (a, b) of
traytel@45060
   688
                (*immediate error - cannot fix complex coercion with the global algorithm*)
wenzelm@55303
   689
                NONE =>
wenzelm@55303
   690
                  eval_error (err ++>
wenzelm@55304
   691
                    ("No coercion known for type constructors: " ^
wenzelm@55304
   692
                      quote (Proof_Context.markup_type ctxt a) ^ " and " ^
wenzelm@55304
   693
                      quote (Proof_Context.markup_type ctxt b)))
traytel@45060
   694
              | SOME (co, ((Ts', Us'), _)) =>
traytel@45060
   695
                  let
traytel@45060
   696
                    (*coerce T1 into the argument form Type (a, Ts') expected by co*)
                    val co_before = gen (T1, Type (a, Ts'));
traytel@45060
   697
                    val coT = range_type (fastype_of co_before);
wenzelm@55303
   698
                    val insts =
wenzelm@55303
   699
                      inst_collect tye (err ++> "Could not insert complex coercion")
wenzelm@55303
   700
                        (domain_type (fastype_of co)) coT;
traytel@45060
   701
                    val co' = Term.subst_TVars insts co;
traytel@45060
   702
                    (*coerce the instantiated result type Type (b, Us') into T2*)
                    val co_after = gen (Type (b, (map (typ_subst_TVars insts) Us')), T2);
traytel@45060
   703
                  in
traytel@45060
   704
                    (*fn x => co_after (co' (co_before x)), dropping identity steps*)
                    Abs (Name.uu, T1, Library.foldr (op $)
traytel@45060
   705
                      (filter (not o is_identity) [co_after, co', co_before], Bound 0))
traytel@45060
   706
                  end)
traytel@45060
   707
            else
traytel@45060
   708
              let
traytel@51335
   709
                (*sub_co (variance, (Ti, Ui)) = (coercion for this argument, invariant type):
                  covariant arguments are coerced Ti -> Ui, contravariant ones Ui -> Ti,
                  INVARIANT contributes Ti as a type instantiation, INVARIANT_TO nothing*)
                fun sub_co (COVARIANT, TU) = (SOME (gen TU), NONE)
traytel@51335
   710
                  | sub_co (CONTRAVARIANT, TU) = (SOME (gen (swap TU)), NONE)
traytel@51335
   711
                  | sub_co (INVARIANT, (T, _)) = (NONE, SOME T)
traytel@51335
   712
                  | sub_co (INVARIANT_TO T, _) = (NONE, NONE);
traytel@45060
   713
                (*flatten function types [d1 => r1, d2 => r2, ...] into [d1, r1, d2, r2, ...]*)
                fun ts_of [] = []
traytel@45060
   714
                  | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
traytel@45060
   715
              in
traytel@45060
   716
                (case Symtab.lookup (tmaps_of ctxt) a of
traytel@45102
   717
                  NONE =>
traytel@45102
   718
                    if Type.could_unify (T1, T2)
traytel@51335
   719
                    then mk_identity T1
wenzelm@55303
   720
                    else
wenzelm@55303
   721
                      raise COERCION_GEN_ERROR
wenzelm@55304
   722
                        (err ++>
wenzelm@55304
   723
                          ("No map function for " ^ quote (Proof_Context.markup_type ctxt a)
wenzelm@55304
   724
                            ^ " known"))
traytel@51335
   725
                | SOME (tmap, variances) =>
traytel@45060
   726
                    let
traytel@51335
   727
                      val (used_coes, invarTs) =
traytel@51335
   728
                        map_split sub_co (variances ~~ (Ts ~~ Us))
traytel@51335
   729
                        |>> map_filter I
traytel@51335
   730
                        ||> map_filter I;
traytel@51335
   731
                      (*NOTE(review): this order (coercion domains/ranges, then invariant
                        types) must match what instantiate expects for tmap -- confirm*)
                      val Tinsts = ts_of (map fastype_of used_coes) @ invarTs;
traytel@45060
   732
                    in
traytel@45060
   733
                      if null (filter (not o is_identity) used_coes)
traytel@51335
   734
                      then mk_identity (Type (a, Ts))
traytel@51335
   735
                      else Term.list_comb (instantiate tmap Tinsts, used_coes)
traytel@45060
   736
                    end)
traytel@45060
   737
              end
traytel@45060
   738
      (*remaining cases, e.g. involving type variables*)
      | (T, U) =>
traytel@45060
   739
            if Type.could_unify (T, U)
traytel@51335
   740
            then mk_identity T
wenzelm@55303
   741
            else raise COERCION_GEN_ERROR (err ++>
wenzelm@55303
   742
              (Pretty.string_of o Pretty.block)
wenzelm@55303
   743
               [Pretty.str "Cannot generate coercion from", Pretty.brk 1,
wenzelm@55303
   744
                Pretty.quote (Syntax.pretty_typ ctxt T), Pretty.brk 1,
wenzelm@55303
   745
                Pretty.str "to", Pretty.brk 1,
wenzelm@55303
   746
                Pretty.quote (Syntax.pretty_typ ctxt U)]));
traytel@45060
   747
  in
traytel@45060
   748
    gen TU
traytel@45060
   749
  end;
traytel@40836
   750
traytel@45060
   751
(*Build a coercion that turns a term of type T (dereferenced in tye) into a
  function. The coercion registered from T's type constructor C to "fun" is
  used: T's type arguments are first coerced to the ones that coercion expects
  (via gen_coercion), and its type variables are instantiated accordingly.
  Fails via eval_error if T is not a type constructor application or if no
  such coercion is registered.*)
fun function_of ctxt err tye T =
  let
    fun no_coercion_to_fun C =
      eval_error (err ++> ("No complex coercion from " ^
        quote (Proof_Context.markup_type ctxt C) ^ " to " ^
        quote (Proof_Context.markup_type ctxt "fun")));

    fun not_a_type_app T' =
      eval_error (err ++>
        (Pretty.string_of o Pretty.block)
         [Pretty.str "No complex coercion from", Pretty.brk 1,
          Pretty.quote (Syntax.pretty_typ ctxt T'), Pretty.brk 1,
          Pretty.str "to", Pretty.brk 1, Proof_Context.pretty_type ctxt "fun"]);

    fun compose C Ts (co, ((Ts', _), _)) =
      let
        val source = Type (C, Ts);
        val co_before = gen_coercion ctxt err tye (source, Type (C, Ts'));
        val coT = range_type (fastype_of co_before);
        val insts =
          inst_collect tye (err ++> "Could not insert complex coercion")
            (domain_type (fastype_of co)) coT;
        val co' = Term.subst_TVars insts co;
        (*fn x => co' (co_before x), dropping identity steps*)
        val steps = filter_out is_identity [co', co_before];
      in
        Abs (Name.uu, source, Library.foldr (op $) (steps, Bound 0))
      end;
  in
    (case Type_Infer.deref tye T of
      Type (C, Ts) =>
        (case Symreltab.lookup (coes_of ctxt) (C, "fun") of
          SOME entry => compose C Ts entry
        | NONE => no_coercion_to_fun C)
    | T' => not_a_type_app T')
  end;
traytel@45060
   777
traytel@45060
   778
(*Insert coercions into the terms ts, using the solved type environment tye;
  idx is the next free parameter index for that environment. At every
  application t $ u, if the operator's argument type U and the operand's type
  U' do not unify under tye, the operand is wrapped in a coercion from U' to U
  built by gen_coercion. Loose bound variables are reported via err_loose.*)
fun insert_coercions ctxt (tye, idx) ts =
  let
    fun insert _ (Const (c, T)) = (Const (c, T), T)
      | insert _ (Free (x, T)) = (Free (x, T), T)
      | insert _ (Var (xi, T)) = (Var (xi, T), T)
      | insert bs (Bound i) =
          let val T = nth bs i handle General.Subscript => err_loose i;
          in (Bound i, T) end
      | insert bs (Abs (x, T, t)) =
          let val (t', T') = insert (T :: bs) t;
          in (Abs (x, T, t'), T --> T') end
      | insert bs (t $ u) =
          let
            val (t', Type ("fun", [U, T])) = apsnd (Type_Infer.deref tye) (insert bs t);
            val (u', U') = insert bs u;
          in
            (*Trial unification only (result discarded). The index passed to
              strong_unify is the next fresh parameter index (cf. its use with
              mk_param idx / idx + 1 in coercion_infer_types); start at idx
              rather than 0 so parameters created during the trial cannot
              clash with parameters already occurring in tye.*)
            if can (fn TU => strong_unify ctxt TU (tye, idx)) (U, U')
            then (t' $ u', T)
            else (t' $ (gen_coercion ctxt (K ("", Buffer.empty)) tye (U', U) $ u'), T)
          end
  in
    map (fst o insert []) ts
  end;
wenzelm@40281
   801
wenzelm@40281
   802
wenzelm@40281
   803
wenzelm@40281
   804
(** assembling the pipeline **)
wenzelm@40281
   805
wenzelm@42398
   806
(*coercion_infer_types ctxt raw_ts: type inference with coercion insertion,
  installed as a term check. After Type_Infer_Context.prepare, each term is
  first processed by the local algorithm inf, which inserts coercions
  directly at applications that fail to type-check. If a COERCION_GEN_ERROR
  escapes, all terms are reprocessed from scratch by the global algorithm
  (generate_constraints, process_constraints, insert_coercions). The
  resulting type environment is applied with Type_Infer.finish.*)
fun coercion_infer_types ctxt raw_ts =
wenzelm@40281
   807
  let
wenzelm@42405
   808
    val (idx, ts) = Type_Infer_Context.prepare ctxt raw_ts;
wenzelm@40281
   809
traytel@51319
   810
    (*inf coerce bs t tye_idx = (t', T, tye_idx'): t' is t with locally
      inserted coercions, T its type and tye_idx' the updated environment.
      bs lists the (name, type) pairs of the enclosing abstractions; coerce
      controls whether the operand of an application may be coerced (it is
      recomputed for each operand by update_coerce_arg).*)
    fun inf _ _ (t as (Const (_, T))) tye_idx = (t, T, tye_idx)
traytel@51319
   811
      | inf _ _ (t as (Free (_, T))) tye_idx = (t, T, tye_idx)
traytel@51319
   812
      | inf _ _ (t as (Var (_, T))) tye_idx = (t, T, tye_idx)
traytel@51319
   813
      | inf _ bs (t as (Bound i)) tye_idx =
wenzelm@43278
   814
          (t, snd (nth bs i handle General.Subscript => err_loose i), tye_idx)
traytel@51319
   815
      | inf coerce bs (Abs (x, T, t)) tye_idx =
traytel@51319
   816
          let val (t', U, tye_idx') = inf coerce ((x, T) :: bs) t tye_idx
traytel@40836
   817
          in (Abs (x, T, t'), T --> U, tye_idx') end
traytel@51319
   818
      (*Application: infer t and u, then unify T with U --> V for a fresh V.
        If that fails: (1) make the operator a function W --> V (fresh W),
        coercing it with function_of if plain unification fails -- errors
        here are reported with eval_error; (2) coerce the operand from U to W
        with gen_coercion (only if coerce' allows it) and unify U' with W --
        a failure here raises COERCION_GEN_ERROR.*)
      | inf coerce bs (t $ u) tye_idx =
traytel@40836
   819
          let
traytel@51319
   820
            val (t', T, tye_idx') = inf coerce bs t tye_idx;
traytel@51319
   821
            val coerce' = update_coerce_arg ctxt coerce t;
traytel@51319
   822
            val (u', U, (tye, idx)) = inf coerce' bs u tye_idx';
traytel@40836
   823
            val V = Type_Infer.mk_param idx [];
traytel@40836
   824
            val (tu, tye_idx'') = (t' $ u', strong_unify ctxt (U --> V, T) (tye, idx + 1))
wenzelm@42383
   825
              handle NO_UNIFIER (msg, tye') =>
traytel@45060
   826
                let
traytel@45060
   827
                  val err = err_appl_msg ctxt msg tye' bs t T u U;
traytel@45060
   828
                  val W = Type_Infer.mk_param (idx + 1) [];
traytel@45060
   829
                  (*step (1): operator*)
                  val (t'', (tye', idx')) =
traytel@45060
   830
                    (t', strong_unify ctxt (W --> V, T) (tye, idx + 2))
traytel@45060
   831
                      handle NO_UNIFIER _ =>
traytel@45060
   832
                        let
wenzelm@55303
   833
                          val err' = err ++> "Local coercion insertion on the operator failed:\n";
traytel@45060
   834
                          val co = function_of ctxt err' tye T;
traytel@51319
   835
                          val (t'', T'', tye_idx'') = inf coerce bs (co $ t') (tye, idx + 2);
traytel@45060
   836
                        in
traytel@45060
   837
                          (t'', strong_unify ctxt (W --> V, T'') tye_idx''
wenzelm@55303
   838
                            handle NO_UNIFIER (msg, _) => eval_error (err' ++> msg))
traytel@45060
   839
                        end;
wenzelm@55303
   840
                  val err' = err ++>
traytel@54584
   841
                    ((if t' aconv t'' then ""
wenzelm@55303
   842
                      else "Successfully coerced the operator to a function of type:\n" ^
wenzelm@55303
   843
                        Syntax.string_of_typ ctxt
wenzelm@55303
   844
                          (the_single (snd (prep_output ctxt tye' bs [] [W --> V]))) ^ "\n") ^
wenzelm@55303
   845
                     (if coerce' then "Local coercion insertion on the operand failed:\n"
wenzelm@55303
   846
                      else "Local coercion insertion on the operand disallowed:\n"));
traytel@45060
   847
                  (*step (2): operand -- re-infer the original u, wrapped in a
                    coercion to W unless that coercion is the identity*)
                  val (u'', U', tye_idx') =
wenzelm@52432
   848
                    if coerce' then
traytel@51319
   849
                      let val co = gen_coercion ctxt err' tye' (U, W);
traytel@51319
   850
                      in inf coerce' bs (if is_identity co then u else co $ u) (tye', idx') end
traytel@51319
   851
                    else (u, U, (tye', idx'));
traytel@45060
   852
                in
traytel@45060
   853
                  (t'' $ u'', strong_unify ctxt (U', W) tye_idx'
wenzelm@55303
   854
                    handle NO_UNIFIER (msg, _) => raise COERCION_GEN_ERROR (err' ++> msg))
traytel@45060
   855
                end;
traytel@40836
   856
          in (tu, V, tye_idx'') end;
wenzelm@40281
   857
wenzelm@42383
   858
    (*run inf on a top-level term (coercions allowed, no enclosing binders)*)
    fun infer_single t tye_idx =
traytel@51319
   859
      let val (t, _, tye_idx') = inf true [] t tye_idx
traytel@40938
   860
      in (t, tye_idx') end;
wenzelm@42383
   861
traytel@40938
   862
    (*Local phase on all terms; on COERCION_GEN_ERROR restart with the global
      constraint-based algorithm, using the collected error text err as a
      prefix for its messages.*)
    val (ts', (tye, _)) = (fold_map infer_single ts (Vartab.empty, idx)
traytel@45060
   863
      handle COERCION_GEN_ERROR err =>
traytel@40836
   864
        let
traytel@40836
   865
          fun gen_single t (tye_idx, constraints) =
traytel@45060
   866
            let val (_, tye_idx', constraints') =
wenzelm@55303
   867
              generate_constraints ctxt (err ++> "\n") t tye_idx
traytel@40836
   868
            in (tye_idx', constraints' @ constraints) end;
wenzelm@42383
   869
traytel@40836
   870
          val (tye_idx, constraints) = fold gen_single ts ((Vartab.empty, idx), []);
wenzelm@55303
   871
          val (tye, idx) = process_constraints ctxt (err ++> "\n") constraints tye_idx;
wenzelm@42383
   872
        in
traytel@45060
   873
          (insert_coercions ctxt (tye, idx) ts, (tye, idx))
traytel@40836
   874
        end);
wenzelm@40281
   875
wenzelm@40281
   876
    val (_, ts'') = Type_Infer.finish ctxt tye ([], ts');
wenzelm@40281
   877
  in ts'' end;
wenzelm@40281
   878
wenzelm@40281
   879
wenzelm@40281
   880
wenzelm@40281
   881
(** installation **)
wenzelm@40281
   882
wenzelm@40283
   883
(* term check *)
wenzelm@40283
   884
wenzelm@42616
   885
(*Configuration option "coercion_enabled" (default false): whether the
  coercion term check below is active.*)
val coercion_enabled = Attrib.setup_config_bool @{binding coercion_enabled} (K false);

(*Register coercion inference as term check "coercions" with priority ~100;
  when coercion_enabled is off, the terms pass through unchanged.*)
val _ =
  let
    fun check ctxt ts =
      if Config.get ctxt coercion_enabled then coercion_infer_types ctxt ts else ts;
  in
    Theory.setup (Context.theory_map (Syntax_Phases.term_check ~100 "coercions" check))
  end;
wenzelm@40281
   892
wenzelm@40281
   893
wenzelm@40283
   894
(* declarations *)
wenzelm@40281
   895
wenzelm@40284
   896
(*Declare the (generalized) term raw_t as the map function of a type
  constructor C. Its type must have the shape

    f1 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]

  If the final argument and result are not both type constructor
  applications, the last fi is taken as the argument type and the rest as the
  result type (needed for maps over "fun" itself). For every type argument a
  variance is recorded: COVARIANT if the next fi has type xi => yi,
  CONTRAVARIANT if it has type yi => xi, INVARIANT if xi = yi is a type
  variable and INVARIANT_TO if xi = yi is variable-free. The term and its
  variances are stored in the map-function table under C. Malformed input is
  reported with error.*)
fun add_type_map raw_t context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) raw_t;

    fun err_str t = "\n\nThe provided function has the type:\n" ^
      Syntax.string_of_typ ctxt (fastype_of t) ^
      "\n\nThe general type signature of a map function is:" ^
      "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]" ^
      "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi).";

    val ((fis, T1), T2) = apfst split_last (strip_type (fastype_of t))
      handle List.Empty => error ("Not a proper map function:" ^ err_str t);

    (*first argument: (domain, range) of the remaining functions fi;
      second argument: the remaining pairs (xi, yi) of type arguments*)
    fun gen_arg_var ([], []) = []
      | gen_arg_var (Ts, (U, U') :: Us) =
          if U = U' then
            if null (Term.add_tvarsT U []) then INVARIANT_TO U :: gen_arg_var (Ts, Us)
            else if Term.is_TVar U then INVARIANT :: gen_arg_var (Ts, Us)
            else error ("Invariant xi and yi should be variables or variable-free:" ^ err_str t)
          else
            (case Ts of
              [] => error ("Different numbers of functions and variant arguments\n" ^ err_str t)
            | (T, T') :: Ts =>
              if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
              else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us)
              else error ("Functions do not apply to arguments correctly:" ^ err_str t))
      | gen_arg_var (_ :: _, []) =
          (*more functions than variant type arguments; formerly an uncaught Match*)
          error ("Different numbers of functions and variant arguments\n" ^ err_str t);

    (*retry flag needed to adjust the type lists, when given a map over type constructor fun*)
    fun check_map_fun fis (Type (C1, Ts)) (Type (C2, Us)) _ =
          if C1 = C2 andalso not (null fis) andalso forall is_funtype fis
          then ((map dest_funT fis, Ts ~~ Us), C1)
          else error ("Not a proper map function:" ^ err_str t)
      | check_map_fun (fis as _ :: _) T1 T2 true =
          (*only retry with at least one function left: split_last [] would
            raise an uncaught List.Empty*)
          let val (fis', T') = split_last fis
          in check_map_fun fis' T' (T1 --> T2) false end
      | check_map_fun _ _ _ _ = error ("Not a proper map function:" ^ err_str t);

    val res = check_map_fun fis T1 T2 true;
    val res_av = gen_arg_var (fst res);
  in
    map_tmaps (Symtab.update (snd res, (t, res_av))) context
  end;
wenzelm@40281
   939
traytel@45060
   940
(*Compose the coercions along the first irreducible path from a to b in G
  into a single coercion term. Each step (x, y) of the path is looked up in
  tab (the lookup must succeed) and applied inside the accumulated
  abstraction, starting from mk_identity dummyT; coercion_infer_types
  re-infers the types of each intermediate composite. Returns the
  generalized composite, the type arguments of its domain and range types,
  and the list of composed coercions.*)
fun transitive_coercion ctxt tab G (a, b) =
  let
    fun extend co (Abs (x, T, body)) =
      let val co' = map_types Type_Infer.paramify_vars co
      in singleton (coercion_infer_types ctxt) (Abs (x, T, co' $ body)) end;

    val path = hd (Graph.irreducible_paths G (a, b));
    val steps = fst (split_last path) ~~ tl path;
    fun coercion_of step = fst (the (Symreltab.lookup tab step));
    val coercions = map coercion_of steps;

    val composite =
      fold extend coercions (mk_identity dummyT)
      |> singleton (Variable.polymorphic ctxt);
    val (dom, ran) = Term.dest_funT (type_of composite);
    val dom_args = snd (Term.dest_Type dom);
    val ran_args = snd (Term.dest_Type ran);
  in
    (composite, ((dom_args, ran_args), coercions))
  end;
traytel@45059
   957
wenzelm@40284
   958
(*Declare the (generalized) term raw_t as a coercion from type constructor a
  (of its domain type) to b (of its range type). The coercion table receives
  t itself plus a composite coercion (transitive_coercion) for every other
  edge that the transitive closure of the full graph gains. Both the full
  graph and its restriction are updated. Errors if t's type is not
  T1 => T2 with T1, T2 type constructor applications, or if adding the edge
  would make the graph cyclic.*)
fun add_coercion raw_t context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) raw_t;

    fun err_coercion () =
      error ("Bad type for a coercion:\n" ^
        Syntax.string_of_term ctxt t ^ " :: " ^
        Syntax.string_of_typ ctxt (fastype_of t));

    (*apply a destructor, turning its TYPE exception into err_coercion*)
    fun checked f x = f x handle TYPE _ => err_coercion ();

    val (T1, T2) = checked (Term.dest_funT o fastype_of) t;
    val (a, Ts) = checked Term.dest_Type T1;
    val (b, Us) = checked Term.dest_Type T2;

    fun coercion_data_update (tab, G, _) =
      let
        val G' = maybe_new_nodes [(a, length Ts), (b, length Us)] G;
        val G'' = Graph.add_edge_trans_acyclic (a, b) G'
          handle Graph.CYCLES _ => error (
            Syntax.string_of_typ ctxt T2 ^ " is already a subtype of " ^
            Syntax.string_of_typ ctxt T1 ^ "!\n\nCannot add coercion of type: " ^
            Syntax.string_of_typ ctxt (T1 --> T2));

        (*edges of the closed graph G'' that were not present in G'*)
        fun fresh_edges ((x, _), ys) =
          map_filter (fn y => if Graph.is_edge G' (x, y) then NONE else SOME (x, y)) ys;
        val derived = filter_out (fn pair => pair = (a, b)) (maps fresh_edges (Graph.dest G''));

        val G_and_new = Graph.add_edge (a, b) G';
        fun add_derived pair acc =
          Symreltab.update (pair, transitive_coercion ctxt acc G_and_new pair) acc;
        val tab' =
          Symreltab.update ((a, b), (t, ((Ts, Us), []))) tab
          |> fold add_derived derived;
      in
        (tab', G'', restrict_graph G'')
      end;
  in
    map_coes_and_graphs coercion_data_update context
  end;
wenzelm@40281
  1001
traytel@45059
  1002
(*Remove the user-declared coercion raw_t (generalized to t) for its type
  constructors (a, b). The edge (a, b) and every edge whose derived coercion
  lists t among its components are deleted from the table and the full graph.
  Each deleted edge that is still implied by a path in the remaining graph is
  re-added with a freshly composed transitive coercion. Errors if t is not a
  coercion type, is not the coercion registered for (a, b), or is itself an
  automatically derived coercion.*)
fun delete_coercion raw_t context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) raw_t;

    fun term_with_type u =
      Syntax.string_of_term ctxt u ^ " :: " ^ Syntax.string_of_typ ctxt (fastype_of u);

    fun err_coercion is_defined =
      error ("Not" ^
        (if is_defined then " the defined " else  " a ") ^ "coercion:\n" ^
        term_with_type t);

    (*apply a destructor, turning its TYPE exception into err_coercion false*)
    fun checked f x = f x handle TYPE _ => err_coercion false;

    val (T1, T2) = checked (Term.dest_funT o fastype_of) t;
    val (a, _) = checked dest_Type T1;
    val (b, _) = checked dest_Type T2;

    (*edges whose derived coercion was composed from t, plus (a, b) itself*)
    fun affected_edges tab =
      Symreltab.fold
        (fn (edge, (_, (_, components))) => fn acc =>
          if member (op aconv) components t then edge :: acc else acc)
        tab [(a, b)];

    fun delete_and_insert tab G =
      let
        val pairs = affected_edges tab;
        fun delete pair (g, tb) = (Graph.del_edge pair g, Symreltab.delete_safe pair tb);
        val (G', tab') = fold delete pairs (G, tab);
        fun reinsert pair (g, acc) =
          if null (Graph.irreducible_paths g pair) then (g, acc)
          else (Graph.add_edge pair g, (pair, transitive_coercion ctxt tab' G' pair) :: acc);
        val (G'', restored) = fold reinsert pairs (G', []);
      in
        (fold Symreltab.update restored tab', G'', restrict_graph G'')
      end;

    fun pretty_with_type u =
      Pretty.block [Syntax.pretty_term ctxt u,
        Pretty.str " :: ", Syntax.pretty_typ ctxt (fastype_of u)];

    fun err_derived components =
      error ("Cannot delete the automatically derived coercion:\n" ^
        term_with_type t ^ "\n\n" ^
        Pretty.string_of
          (Pretty.big_list "Deleting one of the coercions:" (map pretty_with_type components)) ^
        "\nwill also remove the transitive coercion.");

    fun coercion_data_update (tab, G, _) =
      (case Symreltab.lookup tab (a, b) of
        NONE => err_coercion false
      | SOME (t', (_, components)) =>
          if not (t aconv t') then err_coercion true
          else if null components then delete_and_insert tab G
          else err_derived components);
  in
    map_coes_and_graphs coercion_data_update context
  end;
traytel@45059
  1061
traytel@45059
  1062
(* Print the coercion setup of ctxt as three chunks:
   - declared coercions whose source and target type argument lists are both
     empty ("coercions between base types:"),
   - all remaining coercions ("other coercions:"),
   - the registered map functions, sorted by the external name of their type
     constructor ("coercion maps:"). *)
fun print_coercions ctxt =
  let
    (*List.partition keeps the original relative order within both halves*)
    val (simple, complex) =
      List.partition (fn (_, (_, ((Ts, Us), _))) => null Ts andalso null Us)
        (Symreltab.dest (coes_of ctxt));
    (*render "A <: B   using t" for a coercion t from A to B*)
    fun show_coercion ((a, b), (t, ((Ts, Us), _))) =
      Pretty.item [Pretty.block
       [Syntax.pretty_typ ctxt (Type (a, Ts)), Pretty.brk 1,
        Pretty.str "<:", Pretty.brk 1,
        Syntax.pretty_typ ctxt (Type (b, Us)), Pretty.brk 3,
        Pretty.block
         [Pretty.keyword2 "using", Pretty.brk 1,
          Pretty.quote (Syntax.pretty_term ctxt t)]]];

    val type_space = Proof_Context.type_space ctxt;
    val tmaps =
      sort (Name_Space.extern_ord ctxt type_space o apply2 #1)
        (Symtab.dest (tmaps_of ctxt));
    (*render "c: t" for the map function t of type constructor c*)
    fun show_map (c, (t, _)) =
      Pretty.block
       [Name_Space.pretty ctxt type_space c, Pretty.str ":",
        Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt t)];
  in
   [Pretty.big_list "coercions between base types:" (map show_coercion simple),
    Pretty.big_list "other coercions:" (map show_coercion complex),
    Pretty.big_list "coercion maps:" (map show_map tmaps)]
  end |> Pretty.writeln_chunks;
traytel@45059
  1091
traytel@45059
  1092
wenzelm@58826
  1093
(* attribute setup *)
wenzelm@40283
  1094
traytel@51319
  1095
(* Parse a single coercion_args marker: "+" yields PERMIT, "-" yields FORBID,
   "0" yields LEAVE. *)
val parse_coerce_args =
  let
    fun keyword s result = Args.$$$ s >> K result;
  in
    keyword "+" PERMIT ||
    keyword "-" FORBID ||
    keyword "0" LEAVE
  end
wenzelm@40283
  1097
wenzelm@58826
  1098
(* Register the attributes coercion, coercion_delete, coercion_map and
   coercion_args in the theory. *)
val _ =
  let
    (*attribute parser: read a term and apply the given context declaration to it*)
    fun term_declaration decl =
      Args.term >> (fn t => Thm.declaration_attribute (K (decl t)));
    (*attribute parser: a constant followed by one or more "+"/"-"/"0" markers*)
    val coerce_args_declaration =
      Args.const {proper = false, strict = false} -- Scan.lift (Scan.repeat1 parse_coerce_args) >>
        (fn spec => Thm.declaration_attribute (K (map_coerce_args (Symtab.update spec))));
  in
    Theory.setup
     (Attrib.setup @{binding coercion} (term_declaration add_coercion)
        "declaration of new coercions" #>
      Attrib.setup @{binding coercion_delete} (term_declaration delete_coercion)
        "deletion of coercions" #>
      Attrib.setup @{binding coercion_map} (term_declaration add_type_map)
        "declaration of new map functions" #>
      Attrib.setup @{binding coercion_args} coerce_args_declaration
        "declaration of new constants with coercion-invariant arguments")
  end;
wenzelm@40281
  1113
traytel@45059
  1114
traytel@45059
  1115
(* outer syntax commands *)
traytel@45059
  1116
traytel@45059
  1117
(* Command print_coercions: print the coercions of the context taken from the
   current toplevel state, via print_coercions. *)
val _ =
  Outer_Syntax.command @{command_spec "print_coercions"}
    "print information about coercions"
    (Scan.succeed
      (Toplevel.keep (fn state => print_coercions (Toplevel.context_of state))));
traytel@45059
  1121
wenzelm@40281
  1122
end;