src/Tools/subtyping.ML
author wenzelm
Wed Nov 09 20:47:11 2011 +0100 (2011-11-09)
changeset 45429 fd58cbf8cae3
parent 45102 7bb89635eb51
child 45935 32f769f94ea4
permissions -rw-r--r--
tuned signature;
tuned;
wenzelm@40281
     1
(*  Title:      Tools/subtyping.ML
wenzelm@40281
     2
    Author:     Dmitriy Traytel, TU Muenchen
wenzelm@40281
     3
wenzelm@40281
     4
Coercive subtyping via subtype constraints.
wenzelm@40281
     5
*)
wenzelm@40281
     6
wenzelm@40281
     7
signature SUBTYPING =
sig
  (*boolean configuration option; presumably switches coercion inference on/off -- confirm*)
  val coercion_enabled: bool Config.T

  (*declarations that extend a generic context with a map function or a coercion term*)
  val add_type_map: term -> Context.generic -> Context.generic
  val add_coercion: term -> Context.generic -> Context.generic

  (*diagnostic output of the coercions / map functions known in a proof context*)
  val print_coercions: Proof.context -> unit
  val print_coercion_maps: Proof.context -> unit

  (*theory setup for this module*)
  val setup: theory -> theory
end;
wenzelm@40281
    16
wenzelm@40283
    17
structure Subtyping: SUBTYPING =
wenzelm@40281
    18
struct
wenzelm@40281
    19
wenzelm@40281
    20
(** coercions data **)
wenzelm@40281
    21
traytel@41353
    22
(*Variance of a type constructor argument w.r.t. its map function:
  COVARIANT keeps a subtype constraint on that argument as is,
  CONTRAVARIANT reverses it, INVARIANT forces both sides to unify,
  INVARIANT_TO T forces both sides to unify with the fixed type T.*)
datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;

datatype data = Data of
  {coes: (term * ((typ list * typ list) * term list)) Symreltab.table,  (*coercions table*)
   (*full coercions graph - only used at coercion declaration/deletion*)
   full_graph: int Graph.T,
   (*coercions graph restricted to base types - for efficiency reasons stored in the context*)
   coes_graph: int Graph.T,
   tmaps: (term * variance list) Symtab.table};  (*map functions*)
wenzelm@40281
    31
traytel@45060
    32
(*pack the four data components into the Data record*)
fun make_data (c, fg, cg, tm) =
  Data {coes = c, full_graph = fg, coes_graph = cg, tmaps = tm};
wenzelm@40281
    34
wenzelm@40281
    35
structure Data = Generic_Data
(
  type T = data;
  val empty = make_data (Symreltab.empty, Graph.empty, Graph.empty, Symtab.empty);
  val extend = I;
  (*componentwise merge; equal keys must carry equal entries*)
  fun merge (Data d1, Data d2) =
    let
      (*coercion entries agree on the coercion term, the two type lists and the map terms*)
      val eq_coe =
        eq_pair (op aconv)
          (eq_pair (eq_pair (eq_list (op =)) (eq_list (op =))) (eq_list (op aconv)));
    in
      make_data
        (Symreltab.merge eq_coe (#coes d1, #coes d2),
         Graph.merge (op =) (#full_graph d1, #full_graph d2),
         Graph.merge (op =) (#coes_graph d1, #coes_graph d2),
         Symtab.merge (eq_pair (op aconv) (op =)) (#tmaps d1, #tmaps d2))
    end;
);
wenzelm@40281
    50
wenzelm@40281
    51
(*transform all four components of the context data at once*)
fun map_data f =
  Data.map (fn Data r =>
    make_data (f (#coes r, #full_graph r, #coes_graph r, #tmaps r)));

(*transform only the coercions table*)
fun map_coes f =
  map_data (fn (c, fg, cg, tm) => (f c, fg, cg, tm));

(*transform only the restricted coercions graph*)
fun map_coes_graph f =
  map_data (fn (c, fg, cg, tm) => (c, fg, f cg, tm));

(*transform the coercions table together with both graphs*)
fun map_coes_and_graphs f =
  map_data (fn (c, fg, cg, tm) =>
    (case f (c, fg, cg) of (c', fg', cg') => (c', fg', cg', tm)));

(*transform only the map-function table*)
fun map_tmaps f =
  map_data (fn (c, fg, cg, tm) => (c, fg, cg, f tm));
wenzelm@40281
    71
wenzelm@40285
    72
(*raw data record stored in a proof context*)
fun rep_data ctxt =
  let val Data args = Data.get (Context.Proof ctxt) in args end;

(*individual components of the context data*)
fun coes_of ctxt = #coes (rep_data ctxt);
fun coes_graph_of ctxt = #coes_graph (rep_data ctxt);
fun tmaps_of ctxt = #tmaps (rep_data ctxt);
wenzelm@40281
    77
wenzelm@40281
    78
wenzelm@40281
    79
wenzelm@40281
    80
(** utils **)
wenzelm@40281
    81
traytel@45060
    82
(*Subgraph of G containing exactly the nodes labelled 0.
  NOTE(review): 0 presumably marks base-type nodes, matching the restricted
  coes_graph stored in the data -- confirm at the declaration site.*)
fun restrict_graph G =
  Graph.subgraph (fn key => Graph.get_node G key = 0) G;
traytel@45060
    84
wenzelm@40281
    85
(*name of a base type (type constructor without arguments); raises Match otherwise*)
fun nameT T = (case T of Type (s, []) => s);

(*base type with the given name*)
fun t_of s = Type (s, []);

(*sort of a type variable; NONE for type constructors*)
fun sort_of T =
  (case T of
    TFree (_, S) => SOME S
  | TVar (_, S) => SOME S
  | Type _ => NONE);

(*shape predicates on types and terms*)
fun is_typeT (Type _) = true
  | is_typeT _ = false;
fun is_stypeT (Type (_, [])) = true
  | is_stypeT _ = false;
fun is_compT (Type (_, _ :: _)) = true
  | is_compT _ = false;
fun is_freeT (TFree _) = true
  | is_freeT _ = false;
(*schematic type variable that is not an inference parameter*)
fun is_fixedvarT (TVar (xi, _)) = not (Type_Infer.is_param xi)
  | is_fixedvarT _ = false;
fun is_funtype (Type ("fun", [_, _])) = true
  | is_funtype _ = false;
fun is_identity (Abs (_, _, Bound 0)) = true
  | is_identity _ = false;
wenzelm@40281
    99
traytel@45060
   100
(*Instantiate the schematic type variables of t's type with Ts.
  The reversal of Ts presumably compensates for the accumulator order of
  Term.add_tvar_namesT -- confirm.*)
fun instantiate t Ts =
  let val tvar_names = Term.add_tvar_namesT (fastype_of t) []
  in Term.subst_TVars (tvar_names ~~ rev Ts) t end;
traytel@45060
   102
traytel@45060
   103
(*carries a delayed error message*)
exception COERCION_GEN_ERROR of unit -> string;

(*Match the pattern type T against U (dereferenced through tye) and collect
  the assignments for the schematic variables of T.  On a constructor clash or
  a mismatch of non-variable types, ERROR is raised with the message err ().*)
fun inst_collect tye err T U =
  (case (T, Type_Infer.deref tye U) of
    (TVar (xi, _), U') => [(xi, U')]
  | (Type (a, Ts), Type (b, Us)) =>
      if a <> b then error (err ()) else inst_collects tye err Ts Us
  | (_, U') => if T <> U' then error (err ()) else [])
(*pointwise inst_collect over two argument lists of equal length*)
and inst_collects tye err Ts Us =
  fold2 (fn T => fn U => fn is => inst_collect tye err T U @ is) Ts Us [];
traytel@45060
   113
wenzelm@40281
   114
traytel@40836
   115
(* unification *)
wenzelm@40281
   116
wenzelm@40281
   117
exception NO_UNIFIER of string * typ Vartab.table;

(*Unification of types in an environment (tye, idx): tye records instantiated
  schematic variables, idx is the next index for fresh inference parameters.
  With weak = true sort constraints are ignored and any two argument-free type
  constructors are considered unifiable.  Raises NO_UNIFIER on failure.*)
fun unify weak ctxt =
  let
    val thy = Proof_Context.theory_of ctxt;
    val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy);


    (* adjust sorts of parameters *)

    fun not_of_sort x S' S =
      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
        Syntax.string_of_sort ctxt S;

    fun meet (_, []) tye_idx = tye_idx
      | meet (Type (a, Ts), S) (tye_idx as (tye, _)) =
          let
            val arg_sorts =
              arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye);
          in meets (Ts, arg_sorts) tye_idx end
      | meet (TFree (x, S'), S) (tye_idx as (tye, _)) =
          if not (Sign.subsort thy (S', S))
          then raise NO_UNIFIER (not_of_sort x S' S, tye)
          else tye_idx
      | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
          if Sign.subsort thy (S', S) then tye_idx
          else if not (Type_Infer.is_param xi) then
            raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
          else
            (*refine the parameter by a fresh one of the intersected sort*)
            let val param = Type_Infer.mk_param idx (Sign.inter_sort thy (S', S))
            in (Vartab.update_new (xi, param) tye, idx + 1) end
    and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
          meets (Ts, Ss) (meet (Type_Infer.deref tye T, S) tye_idx)
      | meets _ tye_idx = tye_idx;

    (*weak unification skips sort adjustment entirely*)
    val weak_meet = if weak then (fn _ => I) else meet;


    (* occurs check and assignment *)

    fun occurs_check tye xi (TVar (xi', _)) =
          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
          else
            (case Vartab.lookup tye xi' of
              SOME T => occurs_check tye xi T
            | NONE => ())
      | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
      | occurs_check _ _ _ = ();

    fun assign xi T S (env as (tye, _)) =
      let
        fun bind () = env |> weak_meet (T, S) |>> Vartab.update_new (xi, T);
      in
        (case T of
          TVar (xi', _) => if xi = xi' then env else bind ()
        | _ => (occurs_check tye xi T; bind ()))
      end;


    (* unification *)

    fun show_tycon (a, Ts) =
      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));

    fun unif (T1, T2) (env as (tye, _)) =
      let
        (*dereferenced type, tagged with whether it is an inference parameter*)
        fun classify T =
          let val T' = Type_Infer.deref tye T
          in (Type_Infer.is_paramT T', T') end;
      in
        (case (classify T1, classify T2) of
          ((true, TVar (xi, S)), (_, T)) => assign xi T S env
        | ((_, T), (true, TVar (xi, S))) => assign xi T S env
        | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
            if weak andalso null Ts andalso null Us then env
            else if a = b then fold unif (Ts ~~ Us) env
            else
              raise NO_UNIFIER
                ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
        | ((_, T), (_, U)) => if T <> U then raise NO_UNIFIER ("", tye) else env)
      end;

  in unif end;

val weak_unify = unify true;
val strong_unify = unify false;
wenzelm@40281
   189
wenzelm@40281
   190
wenzelm@40281
   191
(* Typ_Graph shortcuts *)
wenzelm@40281
   192
wenzelm@40281
   193
(*all (transitive) predecessors / successors of T in G*)
fun get_preds G T = Typ_Graph.all_preds G [T];
fun get_succs G T = Typ_Graph.all_succs G [T];

(*add T as a node of G unless adding it fails (e.g. T is already present)*)
fun maybe_new_typnode T G =
  (case try (Typ_Graph.new_node (T, ())) G of
    SOME G' => G'
  | NONE => G);
fun maybe_new_typnodes Ts G = fold maybe_new_typnode Ts G;

local
  (*immediate neighbours of the node set Ts, excluding Ts itself*)
  fun new_imm_neighbours neighbours G Ts =  (* FIXME inefficient *)
    subtract (op =) Ts (distinct (op =) (maps (neighbours G) Ts));
in
  fun new_imm_preds G Ts = new_imm_neighbours Typ_Graph.immediate_preds G Ts;
  fun new_imm_succs G Ts = new_imm_neighbours Typ_Graph.immediate_succs G Ts;
end;
wenzelm@40281
   201
wenzelm@40281
   202
wenzelm@40281
   203
(* Graph shortcuts *)
wenzelm@40281
   204
traytel@45060
   205
(*add node s = (key, label) to G unless adding it fails (e.g. key already present)*)
fun maybe_new_node s G =
  (case try (Graph.new_node s) G of
    SOME G' => G'
  | NONE => G);
fun maybe_new_nodes ss G = fold maybe_new_node ss G;
wenzelm@40281
   207
wenzelm@40281
   208
wenzelm@40281
   209
wenzelm@40281
   210
(** error messages **)
wenzelm@40281
   211
traytel@45060
   212
infixr ++> (* lazy error msg composition *)

(*err ++> str: a new message thunk that appends str to err's message;
  err itself is only invoked when the result is*)
fun err ++> str = suffix str o err
traytel@45060
   215
wenzelm@42383
   216
(*Final message after local inference failed (err) and global coercion
  inference failed too; msg is an optional detail.*)
fun gen_msg err msg =
  let
    val detail = if msg = "" then "" else ":\n" ^ msg;
  in
    err () ^ "\nNow trying to infer coercions globally.\n\nCoercion inference failed" ^
      detail ^ "\n"
  end;
traytel@40836
   219
wenzelm@40281
   220
(*Prepare terms ts and types Ts for error output: finish type inference
  w.r.t. tye (together with the types of the bound variables bs) and replace
  loose bounds in each term by marked free variables.*)
fun prep_output ctxt tye bs ts Ts =
  let
    val n = length Ts;
    val (all_Ts, ts') = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts);
    val (Ts', bound_Ts) = chop n all_Ts;
    fun prep t =
      let val frees = rev (Term.variant_frees t (rev (map fst bs ~~ bound_Ts)))
      in Term.subst_bounds (map Syntax_Trans.mark_boundT frees, t) end;
  in (map prep ts', Ts') end;
wenzelm@40281
   228
wenzelm@40281
   229
(*abort on a de Bruijn index not covered by the bound-variable list*)
fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);

(*headline of a unification failure, with optional detail msg*)
fun unif_failed "" = "Type unification failed\n\n"
  | unif_failed msg = "Type unification failed: " ^ msg ^ "\n\n";

(*delayed message for a failed application t u (T, U: types reported for t and u)*)
fun err_appl_msg ctxt msg tye bs t T u U () =
  let
    val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U];
    val appl_text = Type.appl_error ctxt t' T' u' U';
  in unif_failed msg ^ appl_text ^ "\n" end;
wenzelm@40281
   237
wenzelm@40281
   238
(*abort because the types Ts (w.r.t. tye) cannot all be unified*)
fun err_list ctxt msg tye Ts =
  let
    val (_, Ts') = prep_output ctxt tye [] [] Ts;
    val pretty_Ts = Pretty.list "[" "]" (map (Syntax.pretty_typ ctxt) Ts');
  in
    error (msg ^ "\nCannot unify a list of types that should be the same:\n" ^
      Pretty.string_of pretty_Ts)
  end;
wenzelm@40281
   247
wenzelm@40281
   248
(*Abort with a listing of unsatisfiable subtype constraints.  Each pack is
  (bs, t $ u, U, _, U'), reported as U' <: U for the application t u; the
  listing is in reverse order of packs.*)
fun err_bound ctxt msg tye packs =
  let
    fun prep_pack (bs, t $ u, U, _, U') = prep_output ctxt tye bs [t, u] [U', U];
    val (ts, Ts) = split_list (rev (map prep_pack packs));
    fun pretty_constraint [t, u] [T, U] =
      Pretty.block [
        Syntax.pretty_typ ctxt T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2,
        Syntax.pretty_typ ctxt U, Pretty.brk 3,
        Pretty.str "from function application", Pretty.brk 2,
        Pretty.block [Syntax.pretty_term ctxt (t $ u)]];
    val listing =
      Pretty.big_list "Cannot fulfil subtype constraints:" (map2 pretty_constraint ts Ts);
  in
    error (msg ^ "\n" ^ Pretty.string_of listing)
  end;
wenzelm@40281
   267
wenzelm@40281
   268
wenzelm@40281
   269
wenzelm@40281
   270
(** constraint generation **)
wenzelm@40281
   271
traytel@40836
   272
(*Infer the type of a term while collecting subtype constraints: for each
  application f u, the argument type U' of u must be a subtype of the fresh
  domain parameter U, where U --> V is unified with the type of f.  Each
  constraint ((U', U), pack) carries pack = (bs, f $ u, U, V, U') for error
  reporting.  The result is curried: term -> (tye, idx) ->
  (type, (tye', idx'), constraints).*)
fun generate_constraints ctxt err =
  let
    fun gen cs bs tm tye_idx =
      (case tm of
        Const (_, T) => (T, tye_idx, cs)
      | Free (_, T) => (T, tye_idx, cs)
      | Var (_, T) => (T, tye_idx, cs)
      | Bound i =>
          let val (_, T) = nth bs i handle General.Subscript => err_loose i
          in (T, tye_idx, cs) end
      | Abs (x, T, body) =>
          let val (U, tye_idx', cs') = gen cs ((x, T) :: bs) body tye_idx
          in (T --> U, tye_idx', cs') end
      | f $ u =>
          let
            val (T, tye_idx1, cs1) = gen cs bs f tye_idx;
            val (U', (tye, idx), cs2) = gen cs1 bs u tye_idx1;
            val U = Type_Infer.mk_param idx [];
            val V = Type_Infer.mk_param (idx + 1) [];
            val tye_idx2 = strong_unify ctxt (U --> V, T) (tye, idx + 2)
              handle NO_UNIFIER (msg, _) => error (gen_msg err msg);
            val pack = (bs, f $ u, U, V, U');
          in (V, tye_idx2, ((U', U), pack) :: cs2) end);
  in
    gen [] []
  end;
wenzelm@40281
   296
wenzelm@40281
   297
wenzelm@40281
   298
wenzelm@40281
   299
(** constraint resolution **)
wenzelm@40281
   300
wenzelm@40281
   301
(*raised when no suitable supremum/infimum can be computed for a bound;
  the string is a short reason, turned into a user error by the handler*)
exception BOUND_ERROR of string;
wenzelm@40281
   302
traytel@40836
   303
fun process_constraints ctxt err cs tye_idx =
wenzelm@40281
   304
  let
wenzelm@42388
   305
    val thy = Proof_Context.theory_of ctxt;
wenzelm@42388
   306
wenzelm@40285
   307
    val coes_graph = coes_graph_of ctxt;
wenzelm@40285
   308
    val tmaps = tmaps_of ctxt;
wenzelm@42388
   309
    val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy);
wenzelm@40281
   310
wenzelm@40281
   311
    fun split_cs _ [] = ([], [])
wenzelm@40282
   312
      | split_cs f (c :: cs) =
wenzelm@40281
   313
          (case pairself f (fst c) of
wenzelm@40281
   314
            (false, false) => apsnd (cons c) (split_cs f cs)
wenzelm@40281
   315
          | _ => apfst (cons c) (split_cs f cs));
wenzelm@42383
   316
traytel@41353
   317
    fun unify_list (T :: Ts) tye_idx =
wenzelm@42383
   318
      fold (fn U => fn tye_idx' => strong_unify ctxt (T, U) tye_idx') Ts tye_idx;
wenzelm@40281
   319
wenzelm@40282
   320
wenzelm@40281
   321
    (* check whether constraint simplification will terminate using weak unification *)
wenzelm@40282
   322
traytel@41353
   323
    val _ = fold (fn (TU, _) => fn tye_idx =>
traytel@41353
   324
      weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, _) =>
traytel@40836
   325
        error (gen_msg err ("weak unification of subtype constraints fails\n" ^ msg))) cs tye_idx;
wenzelm@40281
   326
wenzelm@40281
   327
wenzelm@40281
   328
    (* simplify constraints *)
wenzelm@40282
   329
wenzelm@40281
   330
    fun simplify_constraints cs tye_idx =
wenzelm@40281
   331
      let
wenzelm@40281
   332
        fun contract a Ts Us error_pack done todo tye idx =
wenzelm@40281
   333
          let
wenzelm@40281
   334
            val arg_var =
wenzelm@40281
   335
              (case Symtab.lookup tmaps a of
wenzelm@40281
   336
                (*everything is invariant for unknown constructors*)
wenzelm@40281
   337
                NONE => replicate (length Ts) INVARIANT
wenzelm@40281
   338
              | SOME av => snd av);
wenzelm@40281
   339
            fun new_constraints (variance, constraint) (cs, tye_idx) =
wenzelm@40281
   340
              (case variance of
wenzelm@40281
   341
                COVARIANT => (constraint :: cs, tye_idx)
wenzelm@40281
   342
              | CONTRAVARIANT => (swap constraint :: cs, tye_idx)
traytel@41353
   343
              | INVARIANT_TO T => (cs, unify_list [T, fst constraint, snd constraint] tye_idx
wenzelm@42383
   344
                  handle NO_UNIFIER (msg, _) =>
wenzelm@42383
   345
                    err_list ctxt (gen_msg err
traytel@45059
   346
                      "failed to unify invariant arguments w.r.t. to the known map function" ^ msg)
traytel@45060
   347
                      (fst tye_idx) (T :: Ts))
wenzelm@40281
   348
              | INVARIANT => (cs, strong_unify ctxt constraint tye_idx
wenzelm@42383
   349
                  handle NO_UNIFIER (msg, _) =>
traytel@41353
   350
                    error (gen_msg err ("failed to unify invariant arguments" ^ msg))));
wenzelm@40281
   351
            val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
wenzelm@40281
   352
              (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
wenzelm@40281
   353
            val test_update = is_compT orf is_freeT orf is_fixedvarT;
wenzelm@40281
   354
            val (ch, done') =
wenzelm@40286
   355
              if not (null new) then ([], done)
wenzelm@40286
   356
              else split_cs (test_update o Type_Infer.deref tye') done;
wenzelm@40281
   357
            val todo' = ch @ todo;
wenzelm@40281
   358
          in
wenzelm@40281
   359
            simplify done' (new @ todo') (tye', idx')
wenzelm@40281
   360
          end
wenzelm@40281
   361
        (*xi is definitely a parameter*)
wenzelm@40281
   362
        and expand varleq xi S a Ts error_pack done todo tye idx =
wenzelm@40281
   363
          let
wenzelm@40281
   364
            val n = length Ts;
wenzelm@40286
   365
            val args = map2 Type_Infer.mk_param (idx upto idx + n - 1) (arity_sorts a S);
wenzelm@40281
   366
            val tye' = Vartab.update_new (xi, Type(a, args)) tye;
wenzelm@40286
   367
            val (ch, done') = split_cs (is_compT o Type_Infer.deref tye') done;
wenzelm@40281
   368
            val todo' = ch @ todo;
wenzelm@40281
   369
            val new =
wenzelm@40281
   370
              if varleq then (Type(a, args), Type (a, Ts))
wenzelm@40286
   371
              else (Type (a, Ts), Type (a, args));
wenzelm@40281
   372
          in
wenzelm@40281
   373
            simplify done' ((new, error_pack) :: todo') (tye', idx + n)
wenzelm@40281
   374
          end
wenzelm@40281
   375
        (*TU is a pair of a parameter and a free/fixed variable*)
traytel@41353
   376
        and eliminate TU done todo tye idx =
wenzelm@40281
   377
          let
wenzelm@40286
   378
            val [TVar (xi, S)] = filter Type_Infer.is_paramT TU;
wenzelm@40286
   379
            val [T] = filter_out Type_Infer.is_paramT TU;
wenzelm@40281
   380
            val SOME S' = sort_of T;
wenzelm@40281
   381
            val test_update = if is_freeT T then is_freeT else is_fixedvarT;
wenzelm@40281
   382
            val tye' = Vartab.update_new (xi, T) tye;
wenzelm@40286
   383
            val (ch, done') = split_cs (test_update o Type_Infer.deref tye') done;
wenzelm@40281
   384
            val todo' = ch @ todo;
wenzelm@40281
   385
          in
wenzelm@42388
   386
            if Sign.subsort thy (S', S) (*TODO check this*)
wenzelm@40281
   387
            then simplify done' todo' (tye', idx)
traytel@40836
   388
            else error (gen_msg err "sort mismatch")
wenzelm@40281
   389
          end
wenzelm@40281
   390
        and simplify done [] tye_idx = (done, tye_idx)
wenzelm@40281
   391
          | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
wenzelm@40286
   392
              (case (Type_Infer.deref tye T, Type_Infer.deref tye U) of
traytel@45060
   393
                (T1 as Type (a, []), T2 as Type (b, [])) =>
wenzelm@40281
   394
                  if a = b then simplify done todo tye_idx
wenzelm@40281
   395
                  else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
traytel@45060
   396
                  else error (gen_msg err (quote (Syntax.string_of_typ ctxt T1) ^
traytel@45060
   397
                    " is not a subtype of " ^ quote (Syntax.string_of_typ ctxt T2)))
wenzelm@40281
   398
              | (Type (a, Ts), Type (b, Us)) =>
traytel@40836
   399
                  if a <> b then error (gen_msg err "different constructors")
traytel@40836
   400
                    (fst tye_idx) error_pack
wenzelm@40281
   401
                  else contract a Ts Us error_pack done todo tye idx
wenzelm@40282
   402
              | (TVar (xi, S), Type (a, Ts as (_ :: _))) =>
wenzelm@40281
   403
                  expand true xi S a Ts error_pack done todo tye idx
wenzelm@40282
   404
              | (Type (a, Ts as (_ :: _)), TVar (xi, S)) =>
wenzelm@40281
   405
                  expand false xi S a Ts error_pack done todo tye idx
wenzelm@40281
   406
              | (T, U) =>
wenzelm@40281
   407
                  if T = U then simplify done todo tye_idx
wenzelm@40282
   408
                  else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
wenzelm@40286
   409
                    exists Type_Infer.is_paramT [T, U]
traytel@41353
   410
                  then eliminate [T, U] done todo tye idx
wenzelm@40281
   411
                  else if exists (is_freeT orf is_fixedvarT) [T, U]
traytel@40836
   412
                  then error (gen_msg err "not eliminated free/fixed variables")
wenzelm@40282
   413
                  else simplify (((T, U), error_pack) :: done) todo tye_idx);
wenzelm@40281
   414
      in
wenzelm@40281
   415
        simplify [] cs tye_idx
wenzelm@40281
   416
      end;
wenzelm@40281
   417
wenzelm@40281
   418
wenzelm@40281
   419
    (* do simplification *)
wenzelm@40282
   420
wenzelm@40281
   421
    val (cs', tye_idx') = simplify_constraints cs tye_idx;
wenzelm@42383
   422
wenzelm@42383
   423
    fun find_error_pack lower T' = map_filter
traytel@40836
   424
      (fn ((T, U), pack) => if if lower then T' = U else T' = T then SOME pack else NONE) cs';
wenzelm@42383
   425
wenzelm@42383
   426
    fun find_cycle_packs nodes =
traytel@40836
   427
      let
traytel@40836
   428
        val (but_last, last) = split_last nodes
traytel@40836
   429
        val pairs = (last, hd nodes) :: (but_last ~~ tl nodes);
traytel@40836
   430
      in
traytel@40836
   431
        map_filter
wenzelm@40838
   432
          (fn (TU, pack) => if member (op =) pairs TU then SOME pack else NONE)
traytel@40836
   433
          cs'
traytel@40836
   434
      end;
wenzelm@40281
   435
wenzelm@40281
   436
    (*styps stands either for supertypes or for subtypes of a type T
wenzelm@40281
   437
      in terms of the subtype-relation (excluding T itself)*)
wenzelm@40282
   438
    fun styps super T =
wenzelm@44338
   439
      (if super then Graph.immediate_succs else Graph.immediate_preds) coes_graph T
wenzelm@40281
   440
        handle Graph.UNDEF _ => [];
wenzelm@40281
   441
wenzelm@40282
   442
    fun minmax sup (T :: Ts) =
wenzelm@40281
   443
      let
wenzelm@40281
   444
        fun adjust T U = if sup then (T, U) else (U, T);
wenzelm@40281
   445
        fun extract T [] = T
wenzelm@40282
   446
          | extract T (U :: Us) =
wenzelm@40281
   447
              if Graph.is_edge coes_graph (adjust T U) then extract T Us
wenzelm@40281
   448
              else if Graph.is_edge coes_graph (adjust U T) then extract U Us
traytel@40836
   449
              else raise BOUND_ERROR "uncomparable types in type list";
wenzelm@40281
   450
      in
wenzelm@40281
   451
        t_of (extract T Ts)
wenzelm@40281
   452
      end;
wenzelm@40281
   453
wenzelm@40282
   454
    fun ex_styp_of_sort super T styps_and_sorts =
wenzelm@40281
   455
      let
wenzelm@40281
   456
        fun adjust T U = if super then (T, U) else (U, T);
wenzelm@40282
   457
        fun styp_test U Ts = forall
wenzelm@40281
   458
          (fn T => T = U orelse Graph.is_edge coes_graph (adjust U T)) Ts;
wenzelm@42388
   459
        fun fitting Ts S U = Sign.of_sort thy (t_of U, S) andalso styp_test U Ts
wenzelm@40281
   460
      in
wenzelm@40281
   461
        forall (fn (Ts, S) => exists (fitting Ts S) (T :: styps super T)) styps_and_sorts
wenzelm@40281
   462
      end;
wenzelm@40281
   463
wenzelm@40281
   464
    (* computes the tightest possible, correct assignment for 'a::S
wenzelm@40281
   465
       e.g. in the supremum case (sup = true):
wenzelm@40281
   466
               ------- 'a::S---
wenzelm@40281
   467
              /        /    \  \
wenzelm@40281
   468
             /        /      \  \
wenzelm@40281
   469
        'b::C1   'c::C2 ...  T1 T2 ...
wenzelm@40281
   470
wenzelm@40281
   471
       sorts - list of sorts [C1, C2, ...]
wenzelm@40281
   472
       T::Ts - non-empty list of base types [T1, T2, ...]
wenzelm@40281
   473
    *)
wenzelm@40282
   474
    fun tightest sup S styps_and_sorts (T :: Ts) =
wenzelm@40281
   475
      let
wenzelm@42388
   476
        fun restriction T = Sign.of_sort thy (t_of T, S)
wenzelm@40281
   477
          andalso ex_styp_of_sort (not sup) T styps_and_sorts;
wenzelm@40281
   478
        fun candidates T = inter (op =) (filter restriction (T :: styps sup T));
wenzelm@40281
   479
      in
wenzelm@40281
   480
        (case fold candidates Ts (filter restriction (T :: styps sup T)) of
traytel@40836
   481
          [] => raise BOUND_ERROR ("no " ^ (if sup then "supremum" else "infimum"))
wenzelm@40281
   482
        | [T] => t_of T
wenzelm@40281
   483
        | Ts => minmax sup Ts)
wenzelm@40281
   484
      end;
wenzelm@40281
   485
wenzelm@40281
   486
    fun build_graph G [] tye_idx = (G, tye_idx)
wenzelm@40282
   487
      | build_graph G ((T, U) :: cs) tye_idx =
wenzelm@40281
   488
        if T = U then build_graph G cs tye_idx
wenzelm@40281
   489
        else
wenzelm@40281
   490
          let
wenzelm@40281
   491
            val G' = maybe_new_typnodes [T, U] G;
traytel@45059
   492
            val (G'', tye_idx') = (Typ_Graph.add_edge_acyclic (T, U) G', tye_idx)
wenzelm@40281
   493
              handle Typ_Graph.CYCLES cycles =>
wenzelm@40281
   494
                let
wenzelm@42383
   495
                  val (tye, idx) =
wenzelm@42383
   496
                    fold
traytel@40836
   497
                      (fn cycle => fn tye_idx' => (unify_list cycle tye_idx'
wenzelm@42383
   498
                        handle NO_UNIFIER (msg, _) =>
wenzelm@42383
   499
                          err_bound ctxt
traytel@40836
   500
                            (gen_msg err ("constraint cycle not unifiable" ^ msg)) (fst tye_idx)
traytel@40836
   501
                            (find_cycle_packs cycle)))
traytel@40836
   502
                      cycles tye_idx
wenzelm@40281
   503
                in
traytel@40836
   504
                  collapse (tye, idx) cycles G
traytel@40836
   505
                end
wenzelm@40281
   506
          in
wenzelm@40281
   507
            build_graph G'' cs tye_idx'
wenzelm@40281
   508
          end
traytel@40836
   509
    and collapse (tye, idx) cycles G = (*nodes non-empty list*)
wenzelm@40281
   510
      let
traytel@40836
   511
        (*all cycles collapse to one node,
traytel@40836
   512
          because all of them share at least the nodes x and y*)
traytel@40836
   513
        val nodes = (distinct (op =) (flat cycles));
traytel@40836
   514
        val T = Type_Infer.deref tye (hd nodes);
wenzelm@40281
   515
        val P = new_imm_preds G nodes;
wenzelm@40281
   516
        val S = new_imm_succs G nodes;
wenzelm@40281
   517
        val G' = Typ_Graph.del_nodes (tl nodes) G;
traytel@40836
   518
        fun check_and_gen super T' =
traytel@40836
   519
          let val U = Type_Infer.deref tye T';
traytel@40836
   520
          in
traytel@40836
   521
            if not (is_typeT T) orelse not (is_typeT U) orelse T = U
traytel@40836
   522
            then if super then (hd nodes, T') else (T', hd nodes)
wenzelm@42383
   523
            else
wenzelm@42383
   524
              if super andalso
traytel@40836
   525
                Graph.is_edge coes_graph (nameT T, nameT U) then (hd nodes, T')
wenzelm@42383
   526
              else if not super andalso
traytel@40836
   527
                Graph.is_edge coes_graph (nameT U, nameT T) then (T', hd nodes)
traytel@40836
   528
              else err_bound ctxt (gen_msg err "cycle elimination produces inconsistent graph")
wenzelm@42383
   529
                    (fst tye_idx)
traytel@40836
   530
                    (maps find_cycle_packs cycles @ find_error_pack super T')
traytel@40836
   531
          end;
wenzelm@40281
   532
      in
traytel@40836
   533
        build_graph G' (map (check_and_gen false) P @ map (check_and_gen true) S) (tye, idx)
wenzelm@40281
   534
      end;
wenzelm@40281
   535
wenzelm@40281
   536
    fun assign_bound lower G key (tye_idx as (tye, _)) =
wenzelm@40286
   537
      if Type_Infer.is_paramT (Type_Infer.deref tye key) then
wenzelm@40281
   538
        let
wenzelm@40286
   539
          val TVar (xi, S) = Type_Infer.deref tye key;
wenzelm@40281
   540
          val get_bound = if lower then get_preds else get_succs;
wenzelm@40281
   541
          val raw_bound = get_bound G key;
wenzelm@40286
   542
          val bound = map (Type_Infer.deref tye) raw_bound;
wenzelm@40286
   543
          val not_params = filter_out Type_Infer.is_paramT bound;
wenzelm@40282
   544
          fun to_fulfil T =
wenzelm@40281
   545
            (case sort_of T of
wenzelm@40281
   546
              NONE => NONE
wenzelm@40282
   547
            | SOME S =>
wenzelm@40286
   548
                SOME
wenzelm@40286
   549
                  (map nameT
wenzelm@42405
   550
                    (filter_out Type_Infer.is_paramT
wenzelm@42405
   551
                      (map (Type_Infer.deref tye) (get_bound G T))), S));
wenzelm@40281
   552
          val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
wenzelm@40281
   553
          val assignment =
wenzelm@40281
   554
            if null bound orelse null not_params then NONE
wenzelm@40281
   555
            else SOME (tightest lower S styps_and_sorts (map nameT not_params)
wenzelm@42383
   556
                handle BOUND_ERROR msg =>
traytel@40836
   557
                  err_bound ctxt (gen_msg err msg) tye (find_error_pack lower key))
wenzelm@40281
   558
        in
wenzelm@40281
   559
          (case assignment of
wenzelm@40281
   560
            NONE => tye_idx
wenzelm@40281
   561
          | SOME T =>
wenzelm@40286
   562
              if Type_Infer.is_paramT T then tye_idx
wenzelm@40281
   563
              else if lower then (*upper bound check*)
wenzelm@40281
   564
                let
wenzelm@40286
   565
                  val other_bound = map (Type_Infer.deref tye) (get_succs G key);
wenzelm@40281
   566
                  val s = nameT T;
wenzelm@40281
   567
                in
wenzelm@40281
   568
                  if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
wenzelm@40281
   569
                  then apfst (Vartab.update (xi, T)) tye_idx
traytel@45060
   570
                  else err_bound ctxt (gen_msg err ("assigned base type " ^
traytel@45060
   571
                    quote (Syntax.string_of_typ ctxt T) ^
wenzelm@40281
   572
                    " clashes with the upper bound of variable " ^
traytel@40836
   573
                    Syntax.string_of_typ ctxt (TVar(xi, S)))) tye (find_error_pack (not lower) key)
wenzelm@40281
   574
                end
wenzelm@40281
   575
              else apfst (Vartab.update (xi, T)) tye_idx)
wenzelm@40281
   576
        end
wenzelm@40281
   577
      else tye_idx;
wenzelm@40281
   578
wenzelm@40281
   579
    val assign_lb = assign_bound true;
wenzelm@40281
   580
    val assign_ub = assign_bound false;
wenzelm@40281
   581
wenzelm@40281
   582
    fun assign_alternating ts' ts G tye_idx =
wenzelm@40281
   583
      if ts' = ts then tye_idx
wenzelm@40281
   584
      else
wenzelm@40281
   585
        let
wenzelm@40281
   586
          val (tye_idx' as (tye, _)) = fold (assign_lb G) ts tye_idx
wenzelm@40281
   587
            |> fold (assign_ub G) ts;
wenzelm@40281
   588
        in
wenzelm@42383
   589
          assign_alternating ts
traytel@40836
   590
            (filter (Type_Infer.is_paramT o Type_Infer.deref tye) ts) G tye_idx'
wenzelm@40281
   591
        end;
wenzelm@40281
   592
wenzelm@40281
   593
    (*Unify all weakly connected components of the constraint forest,
wenzelm@40282
   594
      that contain only params. These are the only WCCs that contain
wenzelm@40281
   595
      params anyway.*)
wenzelm@40281
   596
    fun unify_params G (tye_idx as (tye, _)) =
wenzelm@40281
   597
      let
wenzelm@40286
   598
        val max_params =
wenzelm@40286
   599
          filter (Type_Infer.is_paramT o Type_Infer.deref tye) (Typ_Graph.maximals G);
wenzelm@40281
   600
        val to_unify = map (fn T => T :: get_preds G T) max_params;
wenzelm@40281
   601
      in
wenzelm@42383
   602
        fold
traytel@40836
   603
          (fn Ts => fn tye_idx' => unify_list Ts tye_idx'
traytel@41353
   604
            handle NO_UNIFIER (msg, _) => err_list ctxt (gen_msg err msg) (fst tye_idx) Ts)
traytel@40836
   605
          to_unify tye_idx
wenzelm@40281
   606
      end;
wenzelm@40281
   607
wenzelm@40281
   608
    fun solve_constraints G tye_idx = tye_idx
wenzelm@40281
   609
      |> assign_alternating [] (Typ_Graph.keys G) G
wenzelm@40281
   610
      |> unify_params G;
wenzelm@40281
   611
  in
wenzelm@40281
   612
    build_graph Typ_Graph.empty (map fst cs') tye_idx'
wenzelm@40281
   613
      |-> solve_constraints
wenzelm@40281
   614
  end;
wenzelm@40281
   615
wenzelm@40281
   616
wenzelm@40281
   617
wenzelm@40281
   618
(** coercion insertion **)
wenzelm@40281
   619
traytel@45060
   620
(*Generate a coercion term from the first to the second type of TU, after
  dereferencing both through the type environment tye.  Identity coercions
  are built as  Abs (Name.uu, T, Bound 0).
  err is a message thunk (it is invoked as  err ());  err ++> s  presumably
  extends it with s -- ++> is defined outside this chunk.
  Raises COERCION_GEN_ERROR when no coercion can be generated; a missing
  coercion between two different type constructors is reported with a plain
  error instead, because the global algorithm cannot repair it.*)
fun gen_coercion ctxt err tye TU =
traytel@45060
   621
  let
traytel@45060
   622
    (*the case patterns rebind T1 and T2 to the dereferenced types*)
    fun gen (T1, T2) = (case pairself (Type_Infer.deref tye) (T1, T2) of
traytel@45060
   623
        (*base types: identity if equal, otherwise the declared coercion from coes_of*)
        (T1 as (Type (a, [])), T2 as (Type (b, []))) =>
traytel@45060
   624
            if a = b
traytel@45060
   625
            then Abs (Name.uu, Type (a, []), Bound 0)
traytel@45060
   626
            else
traytel@45060
   627
              (case Symreltab.lookup (coes_of ctxt) (a, b) of
traytel@45060
   628
                NONE => raise COERCION_GEN_ERROR (err ++> quote (Syntax.string_of_typ ctxt T1) ^
traytel@45060
   629
                  " is not a subtype of " ^ quote (Syntax.string_of_typ ctxt T2))
traytel@45060
   630
              | SOME (co, _) => co)
traytel@45102
   631
      (*compound types: a declared complex coercion if the constructors differ,
        otherwise the registered map function of the constructor*)
      | (T1 as Type (a, Ts), T2 as Type (b, Us)) =>
traytel@45060
   632
            if a <> b
traytel@45060
   633
            then
traytel@45060
   634
              (case Symreltab.lookup (coes_of ctxt) (a, b) of
traytel@45060
   635
                (*immediate error - cannot fix complex coercion with the global algorithm*)
traytel@45060
   636
                NONE => error (err () ^ "No coercion known for type constructors: " ^
traytel@45060
   637
                  quote a ^ " and " ^ quote b)
traytel@45060
   638
              | SOME (co, ((Ts', Us'), _)) =>
traytel@45060
   639
                  let
traytel@45060
   640
                    (*build  co_after o co' o co_before :  co_before adapts T1 to the
                      declared source type Type (a, Ts'), co' is co with the type
                      instantiation from inst_collect applied, co_after adapts the
                      instantiated target type to T2; identity steps are dropped below*)
                    val co_before = gen (T1, Type (a, Ts'));
traytel@45060
   641
                    val coT = range_type (fastype_of co_before);
traytel@45060
   642
                    val insts = inst_collect tye (err ++> "Could not insert complex coercion")
traytel@45060
   643
                      (domain_type (fastype_of co)) coT;
traytel@45060
   644
                    val co' = Term.subst_TVars insts co;
traytel@45060
   645
                    val co_after = gen (Type (b, (map (typ_subst_TVars insts) Us')), T2);
traytel@45060
   646
                  in
traytel@45060
   647
                    Abs (Name.uu, T1, Library.foldr (op $)
traytel@45060
   648
                      (filter (not o is_identity) [co_after, co', co_before], Bound 0))
traytel@45060
   649
                  end)
traytel@45060
   650
            else
traytel@45060
   651
              let
traytel@45060
   652
                (*coercion for one argument position of the map function: COVARIANT in the
                  same direction, CONTRAVARIANT reversed, INVARIANT_TO takes no function.
                  There is no INVARIANT clause; gen_arg_var in add_type_map only produces
                  COVARIANT, CONTRAVARIANT and INVARIANT_TO.*)
                fun sub_co (COVARIANT, TU) = SOME (gen TU)
traytel@45060
   653
                  | sub_co (CONTRAVARIANT, TU) = SOME (gen (swap TU))
traytel@45060
   654
                  | sub_co (INVARIANT_TO _, _) = NONE;
traytel@45060
   655
                (*flatten the coercions' function types into [dom1, ran1, dom2, ran2, ...];
                  presumably the argument order expected by instantiate, which is
                  defined outside this chunk*)
                fun ts_of [] = []
traytel@45060
   656
                  | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
traytel@45060
   657
              in
traytel@45060
   658
                (case Symtab.lookup (tmaps_of ctxt) a of
traytel@45102
   659
                  (*no map function registered: identity if the types could unify anyway*)
                  NONE =>
traytel@45102
   660
                    if Type.could_unify (T1, T2)
traytel@45102
   661
                    then Abs (Name.uu, T1, Bound 0)
traytel@45102
   662
                    else raise COERCION_GEN_ERROR
traytel@45102
   663
                      (err ++> "No map function for " ^ quote a ^ " known")
traytel@45060
   664
                | SOME tmap =>
traytel@45060
   665
                    let
traytel@45060
   666
                      val used_coes = map_filter sub_co ((snd tmap) ~~ (Ts ~~ Us));
traytel@45060
   667
                    in
traytel@45060
   668
                      (*all argument coercions are identities: the identity suffices*)
                      if null (filter (not o is_identity) used_coes)
traytel@45060
   669
                      then Abs (Name.uu, Type (a, Ts), Bound 0)
traytel@45060
   670
                      else Term.list_comb
traytel@45060
   671
                        (instantiate (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
traytel@45060
   672
                    end)
traytel@45060
   673
              end
traytel@45060
   674
      (*any other combination (e.g. type variables): identity if the types could unify*)
      | (T, U) =>
traytel@45060
   675
            if Type.could_unify (T, U)
traytel@45060
   676
            then Abs (Name.uu, T, Bound 0)
traytel@45060
   677
            else raise COERCION_GEN_ERROR (err ++> "Cannot generate coercion from " ^
traytel@45060
   678
              quote (Syntax.string_of_typ ctxt T) ^ " to " ^
traytel@45060
   679
              quote (Syntax.string_of_typ ctxt U)));
traytel@45060
   680
  in
traytel@45060
   681
    gen TU
traytel@45060
   682
  end;
traytel@40836
   683
traytel@45060
   684
(*Coerce a term of type T (dereferenced through tye) into a function, using
  the declared complex coercion from T's type constructor C to "fun".  The
  type arguments of T are first adapted to the coercion's declared source
  arguments with gen_coercion.  The result is an abstraction over T that
  applies the non-identity steps; a missing coercion is a plain error.*)
fun function_of ctxt err tye T =
  let
    fun no_coercion from = error (err () ^ "No complex coercion from " ^ from ^ " to fun");

    fun through (C, Ts) (co, ((Ts', _), _)) =
      let
        val src = Type (C, Ts);
        val adapt = gen_coercion ctxt err tye (src, Type (C, Ts'));
        val adapted_T = range_type (fastype_of adapt);
        val insts = inst_collect tye (err ++> "Could not insert complex coercion")
          (domain_type (fastype_of co)) adapted_T;
        val steps = filter (not o is_identity) [Term.subst_TVars insts co, adapt];
      in
        Abs (Name.uu, src, Library.foldr (op $) (steps, Bound 0))
      end;
  in
    (case Type_Infer.deref tye T of
      Type (C, Ts) =>
        (case Symreltab.lookup (coes_of ctxt) (C, "fun") of
          SOME declared => through (C, Ts) declared
        | NONE => no_coercion (quote C))
    | other => no_coercion (quote (Syntax.string_of_typ ctxt other)))
  end;
traytel@45060
   702
traytel@45060
   703
(*Rebuild every term of ts bottom-up, inserting a coercion (from gen_coercion,
  with an empty error prefix) at each application whose argument type does
  not strongly unify with the operator's domain type under tye.  Used by
  coercion_infer_types after global constraint solving.  The index component
  of the second argument is not used.*)
fun insert_coercions ctxt (tye, _) ts =
  let
    fun fits (dom, arg_T) = can (fn TU => strong_unify ctxt TU (tye, 0)) (dom, arg_T);

    fun adjust (dom, arg_T) arg =
      if fits (dom, arg_T) then arg
      else gen_coercion ctxt (K "") tye (arg_T, dom) $ arg;

    (*env lists the binder types, innermost first*)
    fun walk env (Bound i) = (Bound i, nth env i handle General.Subscript => err_loose i)
      | walk env (Abs (x, T, body)) =
          let val (body', body_T) = walk (T :: env) body
          in (Abs (x, T, body'), T --> body_T) end
      | walk env (f $ arg) =
          let
            val (f', f_T) = walk env f;
            val Type ("fun", [dom, ran]) = Type_Infer.deref tye f_T;
            val (arg', arg_T) = walk env arg;
          in
            (f' $ adjust (dom, arg_T) arg', ran)
          end
      | walk _ atom = (atom, fastype_of atom);  (*Const, Free, Var*)
  in
    map (fn t => fst (walk [] t)) ts
  end;
wenzelm@40281
   726
wenzelm@40281
   727
wenzelm@40281
   728
wenzelm@40281
   729
(** assembling the pipeline **)
wenzelm@40281
   730
wenzelm@42398
   731
(*Type inference with coercion insertion for raw_ts.
  A local pass (inf) first tries to repair every ill-typed application on the
  spot, by coercing the operator to a function and/or the operand to the
  expected argument type.  If that raises COERCION_GEN_ERROR, all terms are
  processed again from scratch by the global algorithm (generate_constraints,
  process_constraints, insert_coercions).  The resulting type environment is
  applied with Type_Infer.finish.*)
fun coercion_infer_types ctxt raw_ts =
  let
    val (idx, ts) = Type_Infer_Context.prepare ctxt raw_ts;

    (*Local pass: inf bs t (tye, idx) returns (t', T, (tye', idx')), where t' is
      the rebuilt term and T its type (possibly an inference parameter); bs holds
      the (name, type) pairs of the enclosing binders, innermost first.*)
    fun inf _ (t as (Const (_, T))) tye_idx = (t, T, tye_idx)
      | inf _ (t as (Free (_, T))) tye_idx = (t, T, tye_idx)
      | inf _ (t as (Var (_, T))) tye_idx = (t, T, tye_idx)
      | inf bs (t as (Bound i)) tye_idx =
          (t, snd (nth bs i handle General.Subscript => err_loose i), tye_idx)
      | inf bs (Abs (x, T, t)) tye_idx =
          let val (t', U, tye_idx') = inf ((x, T) :: bs) t tye_idx
          in (Abs (x, T, t'), T --> U, tye_idx') end
      | inf bs (t $ u) tye_idx =
          let
            val (t', T, tye_idx') = inf bs t tye_idx;
            val (u', U, (tye, idx)) = inf bs u tye_idx';
            (*fresh result type V; the parameter counter is advanced past it*)
            val V = Type_Infer.mk_param idx [];
            val (tu, tye_idx'') = (t' $ u', strong_unify ctxt (U --> V, T) (tye, idx + 1))
              (*operator and operand do not fit: try local coercions, operator first*)
              handle NO_UNIFIER (msg, tye') =>
                let
                  val err = err_appl_msg ctxt msg tye' bs t T u U;
                  (*fresh domain type W that the (possibly coerced) operator must accept*)
                  val W = Type_Infer.mk_param (idx + 1) [];
                  (*make the operator a function W --> V: by unification if possible,
                    otherwise via function_of; failure in this step is a plain error*)
                  val (t'', (tye', idx')) =
                    (t', strong_unify ctxt (W --> V, T) (tye, idx + 2))
                      handle NO_UNIFIER _ =>
                        let
                          val err' =
                            err ++> "\nLocal coercion insertion on the operator failed:\n";
                          val co = function_of ctxt err' tye T;
                          val (t'', T'', tye_idx'') = inf bs (co $ t') (tye, idx + 2);
                        in
                          (t'', strong_unify ctxt (W --> V, T'') tye_idx''
                             handle NO_UNIFIER (msg, _) => error (err' () ^ msg))
                        end;
                  (*mention the operator coercion (if one was inserted) before reporting
                    on the operand step; it is the operator t' that became a function*)
                  val err' = err ++> (if t' aconv t'' then ""
                    else "\nSuccessfully coerced the operator to a function of type:\n" ^
                      Syntax.string_of_typ ctxt
                        (the_single (snd (prep_output ctxt tye' bs [] [W --> V]))) ^ "\n") ^
                      "\nLocal coercion insertion on the operand failed:\n";
                  (*coerce the operand from U to W; a failing unification afterwards
                    raises COERCION_GEN_ERROR, which triggers the global fallback below*)
                  val co = gen_coercion ctxt err' tye' (U, W);
                  val (u'', U', tye_idx') =
                    inf bs (if is_identity co then u else co $ u) (tye', idx');
                in
                  (t'' $ u'', strong_unify ctxt (U', W) tye_idx'
                    handle NO_UNIFIER (msg, _) => raise COERCION_GEN_ERROR (err' ++> msg))
                end;
          in (tu, V, tye_idx'') end;

    fun infer_single t tye_idx =
      let val (t, _, tye_idx') = inf [] t tye_idx
      in (t, tye_idx') end;

    (*local pass on all terms; on COERCION_GEN_ERROR restart from the prepared
      terms with the global constraint-based algorithm*)
    val (ts', (tye, _)) = (fold_map infer_single ts (Vartab.empty, idx)
      handle COERCION_GEN_ERROR err =>
        let
          fun gen_single t (tye_idx, constraints) =
            let val (_, tye_idx', constraints') =
              generate_constraints ctxt (err ++> "\n") t tye_idx
            in (tye_idx', constraints' @ constraints) end;

          val (tye_idx, constraints) = fold gen_single ts ((Vartab.empty, idx), []);
          val (tye, idx) = process_constraints ctxt (err ++> "\n") constraints tye_idx;
        in
          (insert_coercions ctxt (tye, idx) ts, (tye, idx))
        end);

    val (_, ts'') = Type_Infer.finish ctxt tye ([], ts');
  in ts'' end;
wenzelm@40281
   799
wenzelm@40281
   800
wenzelm@40281
   801
wenzelm@40281
   802
(** installation **)
wenzelm@40281
   803
wenzelm@40283
   804
(* term check *)
wenzelm@40283
   805
wenzelm@42616
   806
(*Configuration option "coercion_enabled" (default false) that switches the
  coercion term check on.*)
val coercion_enabled = Attrib.setup_config_bool @{binding coercion_enabled} (K false);

(*Term check named "coercions" (registered with argument ~100): when
  coercion_enabled holds in the given context the terms are re-typed by
  coercion_infer_types, otherwise they are passed through unchanged.*)
val add_term_check =
  let
    fun check ctxt ts =
      if Config.get ctxt coercion_enabled then coercion_infer_types ctxt ts else ts;
  in
    Syntax_Phases.term_check ~100 "coercions" check
  end;
wenzelm@40281
   811
wenzelm@40281
   812
wenzelm@40283
   813
(* declarations *)
wenzelm@40281
   814
wenzelm@40284
   815
(*Declare raw_t (generalized) as the map function for a type constructor C.
  Expected shape:  f1 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]
  where each fi is (xi => yi) for a covariant or (yi => xi) for a
  contravariant argument position; positions with xi = yi must be base types
  and take no fi.  For a map over the constructor "fun" itself, strip_type
  also splits the mapped function's type, which check_map_fun undoes via its
  retry flag.  The term and the derived variance list are stored under C in
  tmaps; ill-shaped functions are rejected with an error.*)
fun add_type_map raw_t context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) raw_t;

    fun err_str t = "\n\nThe provided function has the type:\n" ^
      Syntax.string_of_typ ctxt (fastype_of t) ^
      "\n\nThe general type signature of a map function is:" ^
      "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]" ^
      "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi).";

    val ((fis, T1), T2) = apfst split_last (strip_type (fastype_of t))
      handle Empty => error ("Not a proper map function:" ^ err_str t);

    (*Match the argument types of the fi against the argument pairs (xi, yi) of C,
      producing one variance per argument position of C.  Positions with equal
      types are INVARIANT_TO and consume no fi; trailing invariant positions are
      allowed, left-over fi are not.*)
    fun gen_arg_var ([], []) = []
      | gen_arg_var ((T, T') :: Ts, (U, U') :: Us) =
          if U = U' then
            if is_stypeT U then INVARIANT_TO U :: gen_arg_var ((T, T') :: Ts, Us)
            else error ("Invariant xi and yi should be base types:" ^ err_str t)
          else if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
          else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us)
          else error ("Functions do not apply to arguments correctly:" ^ err_str t)
      | gen_arg_var ([], Ts) =
          if forall (op = andf is_stypeT o fst) Ts
          then map (INVARIANT_TO o fst) Ts
          else error ("Different numbers of functions and variant arguments\n" ^ err_str t)
      | gen_arg_var (_, []) =
          (*more fi than argument positions of C: the extra functions would get no
            variance entry, so the registered map could not be applied consistently*)
          error ("Different numbers of functions and variant arguments\n" ^ err_str t);

    (*retry flag needed to adjust the type lists, when given a map over type constructor fun:
      the last element of fis is then the type of the mapped object and T1 --> T2 the
      result type.  The retry clause requires at least one element in fis, so a function
      without map arguments gets the error below rather than an uncaught Empty.*)
    fun check_map_fun fis (Type (C1, Ts)) (Type (C2, Us)) _ =
          if C1 = C2 andalso not (null fis) andalso forall is_funtype fis
          then ((map dest_funT fis, Ts ~~ Us), C1)
          else error ("Not a proper map function:" ^ err_str t)
      | check_map_fun (fis as _ :: _) T1 T2 true =
          let val (fis', T') = split_last fis
          in check_map_fun fis' T' (T1 --> T2) false end
      | check_map_fun _ _ _ _ = error ("Not a proper map function:" ^ err_str t);

    val res = check_map_fun fis T1 T2 true;
    val res_av = gen_arg_var (fst res);
  in
    map_tmaps (Symtab.update (snd res, (t, res_av))) context
  end;
wenzelm@40281
   857
traytel@45060
   858
(*Coercion from a to b obtained by composing the coercions stored in tab along
  the first irreducible path from a to b in G.  Starting from the identity
  Abs (Name.uu, dummyT, Bound 0), each coercion on the path (its type variables
  turned into inference parameters) is applied to the body, and the result is
  re-typed with coercion_infer_types.  Returns the generalized composite, the
  type arguments of its source and target types, and the path's coercions.*)
fun transitive_coercion ctxt tab G (a, b) =
  let
    fun compose_with co (Abs (x, T, body)) =
      let val co' = map_types Type_Infer.paramify_vars co
      in singleton (coercion_infer_types ctxt) (Abs (x, T, co' $ body)) end;

    val path = hd (Graph.irreducible_paths G (a, b));
    val edges = fst (split_last path) ~~ tl path;
    val coercions = map (fst o the o Symreltab.lookup tab) edges;
    val start = Abs (Name.uu, dummyT, Bound 0);
    val composed = fold compose_with coercions start;
    val trans_co = singleton (Variable.polymorphic ctxt) composed;
    val (source_T, target_T) = Term.dest_funT (type_of trans_co);
    val Ts = snd (Term.dest_Type source_T);
    val Us = snd (Term.dest_Type target_T);
  in
    (trans_co, ((Ts, Us), coercions))
  end;
traytel@45059
   875
wenzelm@40284
   876
(*Declare raw_t (generalized) as a coercion between the type constructors of
  its argument and result types.  The coercion graph is extended transitively,
  and cycles are rejected.  For every edge newly implied by transitivity, a
  composite coercion is stored, computed by transitive_coercion.*)
fun add_coercion raw_t context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) raw_t;
    val string_of_typ = Syntax.string_of_typ ctxt;

    fun err_coercion () = error ("Bad type for a coercion:\n" ^
      Syntax.string_of_term ctxt t ^ " :: " ^ string_of_typ (fastype_of t));

    (*apply a destructor, turning its TYPE exception into the coercion error*)
    fun checked dest x = dest x handle TYPE _ => err_coercion ();

    val (T1, T2) = checked (Term.dest_funT o fastype_of) t;
    val (a, Ts) = checked Term.dest_Type T1;
    val (b, Us) = checked Term.dest_Type T2;

    fun coercion_data_update (tab, G, _) =
      let
        val G' = maybe_new_nodes [(a, length Ts), (b, length Us)] G;
        val G'' = Graph.add_edge_trans_acyclic (a, b) G'
          handle Graph.CYCLES _ => error (
            string_of_typ T2 ^ " is already a subtype of " ^
            string_of_typ T1 ^ "!\n\nCannot add coercion of type: " ^
            string_of_typ (T1 --> T2));
        (*edges of the closure G'' that are not yet present in G'*)
        fun is_new (x, y) = not (Graph.is_edge G' (x, y));
        val new_edges = maps (fn (x, ys) => filter is_new (map (pair x) ys)) (Graph.dest G'');
        val G_and_new = Graph.add_edge (a, b) G';
        val derived = filter (fn e => e <> (a, b)) new_edges;
        fun add_derived e tb =
          Symreltab.update (e, transitive_coercion ctxt tb G_and_new e) tb;
        val tab' =
          tab
          |> Symreltab.update ((a, b), (t, ((Ts, Us), [])))
          |> fold add_derived derived;
      in
        (tab', G'', restrict_graph G'')
      end;
  in
    map_coes_and_graphs coercion_data_update context
  end;
wenzelm@40281
   918
traytel@45059
   919
(*Delete the coercion raw_t (generalized) between the type constructors of
  its argument and result types.  Only an entry with an empty derivation list
  whose term is raw_t can be deleted.  Every entry derived from it is removed
  too, and removed edges that are still connected by a remaining path are
  re-added with recomputed transitive coercions.  Fails if raw_t is not a
  coercion, is not the stored one, or is an automatically derived one.*)
fun delete_coercion raw_t context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) raw_t;

    fun err_coercion declared = error ("Not" ^
      (if declared then " the defined " else " a ") ^ "coercion:\n" ^
      Syntax.string_of_term ctxt t ^ " :: " ^
      Syntax.string_of_typ ctxt (fastype_of t));

    (*apply a destructor, turning its TYPE exception into the coercion error*)
    fun checked dest x = dest x handle TYPE _ => err_coercion false;

    val (T1, T2) = checked (Term.dest_funT o fastype_of) t;
    val (a, _) = checked dest_Type T1;
    val (b, _) = checked dest_Type T2;

    fun delete_and_insert tab G =
      let
        (*(a, b) itself plus every entry whose derivation mentions t*)
        fun collect ((x, y), (_, (_, ts))) acc =
          if member (op aconv) ts t then (x, y) :: acc else acc;
        val affected = Symreltab.fold collect tab [(a, b)];
        fun remove e (g, tb) = (Graph.del_edge e g, Symreltab.delete_safe e tb);
        val (G', tab') = fold remove affected (G, tab);
        (*re-add an edge when its endpoints are still connected by some path*)
        fun reinsert e (g, acc) =
          if null (Graph.irreducible_paths g e) then (g, acc)
          else (Graph.add_edge e g, (e, transitive_coercion ctxt tab' G' e) :: acc);
        val (G'', restored) = fold reinsert affected (G', []);
      in
        (fold Symreltab.update restored tab', G'', restrict_graph G'')
      end;

    fun show_term u = Pretty.block [Syntax.pretty_term ctxt u,
      Pretty.str " :: ", Syntax.pretty_typ ctxt (fastype_of u)];

    fun err_derived ts = error ("Cannot delete the automatically derived coercion:\n" ^
      Syntax.string_of_term ctxt t ^ " :: " ^
      Syntax.string_of_typ ctxt (fastype_of t) ^
      Pretty.string_of (Pretty.big_list "\n\nDeleting one of the coercions:"
        (map show_term ts)) ^
      "\nwill also remove the transitive coercion.");

    fun coercion_data_update (tab, G, _) =
      (case Symreltab.lookup tab (a, b) of
        NONE => err_coercion false
      | SOME (t', (_, ts)) =>
          if not (t aconv t') then err_coercion true
          else if null ts then delete_and_insert tab G
          else err_derived ts);
  in
    map_coes_and_graphs coercion_data_update context
  end;
traytel@45059
   973
traytel@45059
   974
(*Print all declared coercions in two groups: those between base types (no
  type arguments on either side) and all others.*)
fun print_coercions ctxt =
  let
    fun between_base_types (_, (_, ((Ts, Us), _))) = null Ts andalso null Us;
    val (simple, complex) =
      List.partition between_base_types (Symreltab.dest (coes_of ctxt));
    fun pretty_typ_of (name, args) = Syntax.pretty_typ ctxt (Type (name, args));
    fun show_coercion ((a, b), (t, ((Ts, Us), _))) =
      Pretty.block
       [pretty_typ_of (a, Ts), Pretty.brk 1, Pretty.str "<:", Pretty.brk 1,
        pretty_typ_of (b, Us), Pretty.brk 3,
        Pretty.block [Pretty.str "using", Pretty.brk 1,
          Pretty.quote (Syntax.pretty_term ctxt t)]];
  in
    Pretty.writeln (Pretty.big_list "Coercions:"
      [Pretty.big_list "between base types:" (map show_coercion simple),
       Pretty.big_list "other:" (map show_coercion complex)])
  end;
traytel@45059
   993
traytel@45059
   994
(*Print every registered map function, keyed by its type constructor name.*)
fun print_coercion_maps ctxt =
  let
    fun pretty_entry (constructor, (map_fun, _)) =
      Pretty.block [Pretty.str constructor, Pretty.str ":", Pretty.brk 1,
        Pretty.quote (Syntax.pretty_term ctxt map_fun)];
    val entries = Symtab.dest (tmaps_of ctxt);
  in
    Pretty.writeln (Pretty.big_list "Coercion maps:" (map pretty_entry entries))
  end;
traytel@45059
  1003
wenzelm@40283
  1004
wenzelm@40283
  1005
(* theory setup *)
wenzelm@40283
  1006
wenzelm@40283
  1007
(*Install the coercion term check and the attributes "coercion",
  "coercion_delete" and "coercion_map"; each attribute parses a term and
  applies the corresponding declaration function to the context.*)
val setup =
  let
    fun declaration add = Args.term >> (fn t => Thm.declaration_attribute (K (add t)));
  in
    Context.theory_map add_term_check #>
    Attrib.setup @{binding coercion} (declaration add_coercion)
      "declaration of new coercions" #>
    Attrib.setup @{binding coercion_delete} (declaration delete_coercion)
      "deletion of coercions" #>
    Attrib.setup @{binding coercion_map} (declaration add_type_map)
      "declaration of new map functions"
  end;
wenzelm@40281
  1018
traytel@45059
  1019
traytel@45059
  1020
(* outer syntax commands *)
traytel@45059
  1021
traytel@45059
  1022
(*Diagnostic commands print_coercions and print_coercion_maps, each running
  its printer on the current toplevel context.*)
local
  fun diag_print name descr print =
    Outer_Syntax.improper_command name descr Keyword.diag
      (Scan.succeed (Toplevel.keep (print o Toplevel.context_of)));
in
  val _ = diag_print "print_coercions" "print all coercions" print_coercions;
  val _ = diag_print "print_coercion_maps" "print all coercion maps" print_coercion_maps;
end;
traytel@45059
  1028
wenzelm@40281
  1029
end;