src/Tools/subtyping.ML
author wenzelm
Mon Apr 18 13:52:23 2011 +0200 (2011-04-18)
changeset 42388 a44b0fdaa6c2
parent 42386 50ea65e84d98
child 42398 919e17c0358e
permissions -rw-r--r--
standardized aliases of operations on tsig;
wenzelm@40281
     1
(*  Title:      Tools/subtyping.ML
wenzelm@40281
     2
    Author:     Dmitriy Traytel, TU Muenchen
wenzelm@40281
     3
wenzelm@40281
     4
Coercive subtyping via subtype constraints.
wenzelm@40281
     5
*)
wenzelm@40281
     6
wenzelm@40281
     7
(*Exported interface of structure Subtyping.*)
signature SUBTYPING =
sig
  (*variance annotation for the arguments of a type constructor*)
  datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;

  val coercion_enabled: bool Config.T

  val infer_types:
    Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
      term list -> term list

  (*declarations updating the generic context*)
  val add_type_map: term -> Context.generic -> Context.generic
  val add_coercion: term -> Context.generic -> Context.generic

  val gen_coercion: Proof.context -> typ Vartab.table -> (typ * typ) -> term

  val setup: theory -> theory
end;
wenzelm@40281
    18
wenzelm@40283
    19
structure Subtyping: SUBTYPING =
wenzelm@40281
    20
struct
wenzelm@40281
    21
wenzelm@40281
    22
(** coercions data **)
wenzelm@40281
    23
traytel@41353
    24
(*Variance of one argument position of a type constructor, as stored with its map function.
  Constraint simplification (contract) keeps a COVARIANT argument pair as a constraint,
  swaps a CONTRAVARIANT pair, strongly unifies an INVARIANT pair, and for INVARIANT_TO T
  unifies both arguments with T.*)
datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;
wenzelm@40281
    25
wenzelm@40281
    26
(*Coercion declarations stored as generic context data.*)
datatype data =
  Data of {
    coes: term Symreltab.table,                   (*coercions table*)
    coes_graph: unit Graph.T,                     (*coercions graph*)
    tmaps: (term * variance list) Symtab.table};  (*map functions*)
wenzelm@40281
    30
wenzelm@40281
    31
(*Build a Data record from the triple (coercion table, coercion graph, map functions).*)
fun make_data (coes_tab, graph, tmap_tab) =
  Data {coes = coes_tab, coes_graph = graph, tmaps = tmap_tab};
wenzelm@40281
    33
wenzelm@40281
    34
(*Generic context slot for the coercion data; merging is done component-wise.*)
structure Data = Generic_Data
(
  type T = data;
  val empty = make_data (Symreltab.empty, Graph.empty, Symtab.empty);
  val extend = I;
  fun merge (Data d1, Data d2) =
    let
      (*coercion terms are compared up to alpha-conversion*)
      val coes = Symreltab.merge (op aconv) (#coes d1, #coes d2);
      val coes_graph = Graph.merge (op =) (#coes_graph d1, #coes_graph d2);
      val tmaps = Symtab.merge (eq_pair (op aconv) (op =)) (#tmaps d1, #tmaps d2);
    in make_data (coes, coes_graph, tmaps) end;
);
wenzelm@40281
    46
wenzelm@40281
    47
(*Transform the stored data by applying f to its component triple.*)
fun map_data f =
  let
    fun update (Data {coes, coes_graph, tmaps}) = make_data (f (coes, coes_graph, tmaps));
  in Data.map update end;
wenzelm@40281
    50
wenzelm@40281
    51
(*Apply f to the coercions table only.*)
fun map_coes f = map_data (fn (tab, graph, maps) => (f tab, graph, maps));
wenzelm@40281
    54
wenzelm@40281
    55
(*Apply f to the coercions graph only.*)
fun map_coes_graph f = map_data (fn (tab, graph, maps) => (tab, f graph, maps));
wenzelm@40281
    58
wenzelm@40281
    59
(*Apply f to the pair (coercions table, coercions graph), leaving the map functions alone.*)
fun map_coes_and_graph f =
  map_data (fn (tab, graph, maps) =>
    (case f (tab, graph) of (tab', graph') => (tab', graph', maps)));
wenzelm@40281
    63
wenzelm@40281
    64
(*Apply f to the table of map functions only.*)
fun map_tmaps f = map_data (fn (tab, graph, maps) => (tab, graph, f maps));
wenzelm@40281
    67
wenzelm@40285
    68
(*Record of coercion data belonging to a proof context.*)
fun rep_data ctxt =
  (case Data.get (Context.Proof ctxt) of Data args => args);
wenzelm@40281
    69
wenzelm@40281
    70
(*coercions table of a proof context*)
fun coes_of ctxt = #coes (rep_data ctxt);
wenzelm@40281
    71
(*coercions graph of a proof context*)
fun coes_graph_of ctxt = #coes_graph (rep_data ctxt);
wenzelm@40281
    72
(*map functions table of a proof context*)
fun tmaps_of ctxt = #tmaps (rep_data ctxt);
wenzelm@40281
    73
wenzelm@40281
    74
wenzelm@40281
    75
wenzelm@40281
    76
(** utils **)
wenzelm@40281
    77
wenzelm@40281
    78
(*Name of a nullary type constructor application.
  Any other type is a caller error: it is reported as exception TYPE carrying the offending
  type, instead of the anonymous Match of a non-exhaustive pattern.*)
fun nameT (Type (s, [])) = s
  | nameT T = raise TYPE ("nameT: nullary type constructor expected", [T], []);
wenzelm@40281
    79
(*nullary type with the given constructor name*)
fun t_of name = Type (name, []);
wenzelm@40286
    80
wenzelm@40281
    81
(*Sort annotation of a type variable (free or schematic); NONE for constructor types.*)
fun sort_of T =
  (case T of
    TFree (_, S) => SOME S
  | TVar (_, S) => SOME S
  | _ => NONE);
wenzelm@40281
    84
wenzelm@40281
    85
(*any type constructor application*)
fun is_typeT (Type _) = true
  | is_typeT _ = false;
traytel@41353
    86
(*type constructor without arguments*)
fun is_stypeT (Type (_, [])) = true
  | is_stypeT _ = false;
wenzelm@40282
    87
(*type constructor with at least one argument*)
fun is_compT (Type (_, _ :: _)) = true
  | is_compT _ = false;
wenzelm@40281
    88
(*free type variable*)
fun is_freeT (TFree _) = true
  | is_freeT _ = false;
wenzelm@40286
    89
(*schematic type variable that is not an inference parameter*)
fun is_fixedvarT (TVar (xi, _)) = not (Type_Infer.is_param xi)
  | is_fixedvarT _ = false;
traytel@41353
    90
(*binary application of the type constructor "fun"*)
fun is_funtype (Type ("fun", [_, _])) = true
  | is_funtype _ = false;
wenzelm@40281
    91
wenzelm@40281
    92
traytel@40836
    93
(* unification *)
wenzelm@40281
    94
traytel@40836
    95
(*Carries a delayed error message (a thunk producing the text).
  NOTE(review): neither raised nor handled in the part of the file visible here --
  confirm usage further down.*)
exception TYPE_INFERENCE_ERROR of unit -> string;
wenzelm@40281
    96
(*Unification failure: diagnostic text (possibly empty) together with the type environment
  at the point of failure; raised by unify.*)
exception NO_UNIFIER of string * typ Vartab.table;
wenzelm@40281
    97
wenzelm@40281
    98
(*Unification of two types relative to an environment (tab, next): tab is the table of
  type variable assignments, next the index used for freshly created parameters.
  Result is the updated environment; failure is signalled by NO_UNIFIER.

  weak = true: sorts are not adjusted and two nullary type constructors always unify
  (even if their names differ).
  weak = false: sorts of parameters are adjusted via meet and different constructor
  names clash.*)
fun unify weak ctxt =
  let
    val thy = Proof_Context.theory_of ctxt;
    val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy);


    (* adjust sorts of parameters *)

    fun not_of_sort x S' S =
      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
        Syntax.string_of_sort ctxt S;

    (*make type T conform to sort S; the empty sort imposes nothing*)
    fun meet (T, S) (env as (tab, next)) =
      if null S then env
      else
        (case T of
          Type (a, Ts) =>
            meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tab)) env
        | TFree (x, S') =>
            if Sign.subsort thy (S', S) then env
            else raise NO_UNIFIER (not_of_sort x S' S, tab)
        | TVar (xi, S') =>
            if Sign.subsort thy (S', S) then env
            else if Type_Infer.is_param xi then
              (*replace the parameter by a fresh one carrying the intersection sort*)
              (Vartab.update_new
                (xi, Type_Infer.mk_param next (Sign.inter_sort thy (S', S))) tab, next + 1)
            else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tab))
    and meets (T :: Ts, S :: Ss) (env as (tab, _)) =
          meets (Ts, Ss) (meet (Type_Infer.deref tab T, S) env)
      | meets _ env = env;

    fun weak_meet TS = if weak then I else meet TS;


    (* occurs check and assignment *)

    fun occurs_check tab xi T =
      (case T of
        TVar (xi', _) =>
          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tab)
          else
            (case Vartab.lookup tab xi' of
              NONE => ()
            | SOME U => occurs_check tab xi U)
      | Type (_, Ts) => List.app (occurs_check tab xi) Ts
      | _ => ());

    (*bind parameter xi (of sort S) to T*)
    fun assign xi T S (env as (tab, _)) =
      let
        fun bind () = env |> weak_meet (T, S) |>> Vartab.update_new (xi, T);
      in
        (case T of
          TVar (xi', _) => if xi = xi' then env else bind ()
        | _ => (occurs_check tab xi T; bind ()))
      end;


    (* unification *)

    fun show_tycon (a, Ts) =
      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));

    fun unif (T1, T2) (env as (tab, _)) =
      let
        (*dereferenced type paired with the flag "is an inference parameter"*)
        fun classify T =
          let val T' = Type_Infer.deref tab T
          in (Type_Infer.is_paramT T', T') end;
        val left = classify T1;
        val right = classify T2;
      in
        (case (left, right) of
          ((true, TVar (xi, S)), (_, T)) => assign xi T S env
        | ((_, T), (true, TVar (xi, S))) => assign xi T S env
        | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
            if weak andalso null Ts andalso null Us then env
            else if a = b then fold unif (Ts ~~ Us) env
            else
              raise NO_UNIFIER
                ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tab)
        | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tab))
      end;

  in unif end;
wenzelm@40281
   165
wenzelm@40281
   166
(*unification in weak mode*)
fun weak_unify ctxt = unify true ctxt;
wenzelm@40281
   167
(*unification in strong (non-weak) mode*)
fun strong_unify ctxt = unify false ctxt;
wenzelm@40281
   168
wenzelm@40281
   169
wenzelm@40281
   170
(* Typ_Graph shortcuts *)
wenzelm@40281
   171
wenzelm@40281
   172
(*edge insertion that raises Typ_Graph.CYCLES instead of closing a cycle*)
fun add_edge edge G = Typ_Graph.add_edge_acyclic edge G;
wenzelm@40281
   173
(*result of Typ_Graph.all_preds for the single node*)
fun get_preds graph node = Typ_Graph.all_preds graph [node];
wenzelm@40281
   174
(*result of Typ_Graph.all_succs for the single node*)
fun get_succs graph node = Typ_Graph.all_succs graph [node];
wenzelm@40281
   175
(*add node T; the graph is returned unchanged if the insertion fails*)
fun maybe_new_typnode T G =
  (case try (Typ_Graph.new_node (T, ())) G of
    SOME G' => G'
  | NONE => G);
wenzelm@40281
   176
(*add all nodes Ts via maybe_new_typnode*)
fun maybe_new_typnodes Ts G = G |> fold maybe_new_typnode Ts;
wenzelm@40282
   177
(*immediate predecessors of the nodes Ts that are not themselves in Ts (without duplicates)*)
fun new_imm_preds G Ts =
  let val preds = distinct (op =) (maps (Typ_Graph.imm_preds G) Ts)
  in subtract (op =) Ts preds end;
wenzelm@40282
   179
(*immediate successors of the nodes Ts that are not themselves in Ts (without duplicates)*)
fun new_imm_succs G Ts =
  let val succs = distinct (op =) (maps (Typ_Graph.imm_succs G) Ts)
  in subtract (op =) Ts succs end;
wenzelm@40281
   181
wenzelm@40281
   182
wenzelm@40281
   183
(* Graph shortcuts *)
wenzelm@40281
   184
wenzelm@40281
   185
(*add node s; the graph is returned unchanged if the insertion fails*)
fun maybe_new_node s G =
  (case try (Graph.new_node (s, ())) G of
    SOME G' => G'
  | NONE => G);
wenzelm@40281
   186
(*add all nodes ss via maybe_new_node*)
fun maybe_new_nodes ss G = G |> fold maybe_new_node ss;
wenzelm@40281
   187
wenzelm@40281
   188
wenzelm@40281
   189
wenzelm@40281
   190
(** error messages **)
wenzelm@40281
   191
wenzelm@42383
   192
(*Error text for a failed coercion inference: the delayed message err () of the preceding
  failure, followed by the coercion inference failure and, if msg is non-empty, ": " and msg.*)
fun gen_msg err msg =
  let
    val previous = err ();
    val detail = if msg = "" then "" else ": " ^ msg;
  in
    String.concat
      [previous, "\nNow trying to infer coercions:\n\nCoercion inference failed", detail, "\n"]
  end;
traytel@40836
   195
wenzelm@40281
   196
(*Prepare terms ts and types Ts for printing: finish type inference relative to tye
  (including the types of the bound variables bs), then replace loose bounds in every
  term by marked variants of the corresponding bound variable names.
  Result: (prepared terms, finished versions of Ts).*)
fun prep_output ctxt tye bs ts Ts =
  let
    val bound_names = map fst bs;
    val bound_typs = map snd bs;
    val (finished_Ts, finished_ts) = Type_Infer.finish ctxt tye (Ts @ bound_typs, ts);
    (*the first length Ts results belong to Ts, the remainder to the bound variables*)
    val (Ts', bound_Ts') = chop (length Ts) finished_Ts;
    fun mark t =
      let val frees = Term.variant_frees t (rev (bound_names ~~ bound_Ts'))
      in Term.subst_bounds (map Syntax_Trans.mark_boundT (rev frees), t) end;
  in (map mark finished_ts, Ts') end;
wenzelm@40281
   204
wenzelm@40281
   205
(*error: loose bound variable with index idx*)
fun err_loose idx =
  let val text = "Loose bound variable: B." ^ string_of_int idx
  in error text end;
wenzelm@42383
   206
traytel@40836
   207
(*headline for a unification failure; a non-empty msg is appended after ": "*)
fun unif_failed msg =
  let val detail = if msg = "" then "" else ": " ^ msg
  in String.concat ["Type unification failed", detail, "\n\n"] end;
wenzelm@42383
   209
traytel@40836
   210
(*Delayed message (note the final unit argument) for an ill-typed application of t: T
  to u: U, built from unif_failed and Type.appl_error on the prepared terms and types.*)
fun err_appl_msg ctxt msg tye bs t T u U () =
  let
    val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U];
    val application = Type.appl_error ctxt t' T' u' U';
  in unif_failed msg ^ application ^ "\n" end;
wenzelm@40281
   213
wenzelm@40281
   214
(*Raise an error for a list of types Ts that failed to unify; the types are printed
  after finishing them relative to tye.*)
fun err_list ctxt msg tye Ts =
  let
    val (_, Ts') = prep_output ctxt tye [] [] Ts;
    val shown = Pretty.string_of (Pretty.list "[" "]" (map (Syntax.pretty_typ ctxt) Ts'));
  in
    error (cat_lines [msg, "Cannot unify a list of types that should be the same:", shown])
  end;
wenzelm@40281
   223
wenzelm@40281
   224
(*Raise an error listing the subtype constraints behind the given error packs.
  Each pack (bs, t $ u, U, _, U') is printed as "U' <: U from function application t u",
  with terms and types prepared relative to tye.*)
fun err_bound ctxt msg tye packs =
  let
    (*accumulates by consing, i.e. the output is in reverse order of packs*)
    fun collect (bs, t $ u, U, _, U') (ts, Ts) =
      let val (t', T') = prep_output ctxt tye bs [t, u] [U', U]
      in (t' :: ts, T' :: Ts) end;

    fun show [t, u] [T, U] =
      Pretty.string_of
        (Pretty.block
          [Syntax.pretty_typ ctxt T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2,
            Syntax.pretty_typ ctxt U, Pretty.brk 3,
            Pretty.str "from function application", Pretty.brk 2,
            Pretty.block [Syntax.pretty_term ctxt (t $ u)]]);

    val (ts, Ts) = fold collect packs ([], []);
  in
    error (cat_lines (msg :: "Cannot fulfil subtype constraints:" :: map2 show ts Ts))
  end;
wenzelm@40281
   242
wenzelm@40281
   243
wenzelm@40281
   244
wenzelm@40281
   245
(** constraint generation **)
wenzelm@40281
   246
traytel@40836
   247
(*Traverse a term and collect one subtype constraint per application.
  Result of the traversal: (type of the term, updated (tye, idx), constraints), where every
  constraint is ((argument type, expected argument type), error pack).
  Ordinary unification failure at an application is reported via gen_msg err.*)
fun generate_constraints ctxt err =
  let
    fun gen cs bs tm tye_idx =
      (case tm of
        Const (_, T) => (T, tye_idx, cs)
      | Free (_, T) => (T, tye_idx, cs)
      | Var (_, T) => (T, tye_idx, cs)
      | Bound i => (snd (nth bs i handle Subscript => err_loose i), tye_idx, cs)
      | Abs (x, T, body) =>
          let val (U, tye_idx', cs') = gen cs ((x, T) :: bs) body tye_idx
          in (T --> U, tye_idx', cs') end
      | t $ u =>
          let
            val (T, tye_idx', cs') = gen cs bs t tye_idx;
            val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
            (*fresh parameters for the argument and result type of t*)
            val U = Type_Infer.mk_param idx [];
            val V = Type_Infer.mk_param (idx + 1) [];
            val tye_idx'' = strong_unify ctxt (U --> V, T) (tye, idx + 2)
              handle NO_UNIFIER (msg, _) => error (gen_msg err msg);
            val error_pack = (bs, t $ u, U, V, U');
          in (V, tye_idx'', ((U', U), error_pack) :: cs'') end);
  in
    gen [] []
  end;
wenzelm@40281
   270
wenzelm@40281
   271
wenzelm@40281
   272
wenzelm@40281
   273
(** constraint resolution **)
wenzelm@40281
   274
wenzelm@40281
   275
(*Failure to compute a bound for a parameter (see minmax and tightest); the message is
  turned into a proper error by the handler in assign_bound.*)
exception BOUND_ERROR of string;
wenzelm@40281
   276
traytel@40836
   277
fun process_constraints ctxt err cs tye_idx =
wenzelm@40281
   278
  let
wenzelm@42388
   279
    val thy = Proof_Context.theory_of ctxt;
wenzelm@42388
   280
wenzelm@40285
   281
    val coes_graph = coes_graph_of ctxt;
wenzelm@40285
   282
    val tmaps = tmaps_of ctxt;
wenzelm@42388
   283
    val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy);
wenzelm@40281
   284
wenzelm@40281
   285
    fun split_cs _ [] = ([], [])
wenzelm@40282
   286
      | split_cs f (c :: cs) =
wenzelm@40281
   287
          (case pairself f (fst c) of
wenzelm@40281
   288
            (false, false) => apsnd (cons c) (split_cs f cs)
wenzelm@40281
   289
          | _ => apfst (cons c) (split_cs f cs));
wenzelm@42383
   290
traytel@41353
   291
    (*Strongly unify every element of the list with its head.  The empty list imposes
      no constraint and returns tye_idx unchanged (instead of failing with Match).
      Raises NO_UNIFIER if some element does not unify with the head.*)
    fun unify_list [] tye_idx = tye_idx
      | unify_list (T :: Ts) tye_idx =
          fold (fn U => fn tye_idx' => strong_unify ctxt (T, U) tye_idx') Ts tye_idx;
wenzelm@40281
   293
wenzelm@40282
   294
wenzelm@40281
   295
    (* check whether constraint simplification will terminate using weak unification *)
wenzelm@40282
   296
traytel@41353
   297
    val _ = fold (fn (TU, _) => fn tye_idx =>
traytel@41353
   298
      weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, _) =>
traytel@40836
   299
        error (gen_msg err ("weak unification of subtype constraints fails\n" ^ msg))) cs tye_idx;
wenzelm@40281
   300
wenzelm@40281
   301
wenzelm@40281
   302
    (* simplify constraints *)
wenzelm@40282
   303
wenzelm@40281
   304
    (*Simplify the subtype constraints cs relative to tye_idx.
      Result: (remaining constraints, updated (tye, idx)); the remaining constraints are those
      that reach the last case of simplify and are neither trivial nor eliminated there.

      Changes w.r.t. the previous version (error reporting only):
        - the unifier's diagnostic is put on its own line instead of being glued to
          "failed to unify invariant arguments";
        - the unifier's diagnostic is no longer dropped in the INVARIANT_TO case;
        - the "different constructors" error is no longer applied to leftover arguments.*)
    fun simplify_constraints cs tye_idx =
      let
        (*append a non-empty unifier diagnostic on a separate line*)
        fun with_detail text msg = if msg = "" then text else text ^ "\n" ^ msg;

        (*both sides are applications of the same constructor a: reduce to constraints
          on the arguments according to the declared variances*)
        fun contract a Ts Us error_pack done todo tye idx =
          let
            val arg_var =
              (case Symtab.lookup tmaps a of
                (*everything is invariant for unknown constructors*)
                NONE => replicate (length Ts) INVARIANT
              | SOME av => snd av);
            fun new_constraints (variance, constraint) (cs, tye_idx) =
              (case variance of
                COVARIANT => (constraint :: cs, tye_idx)
              | CONTRAVARIANT => (swap constraint :: cs, tye_idx)
              | INVARIANT_TO T => (cs, unify_list [T, fst constraint, snd constraint] tye_idx
                  handle NO_UNIFIER (msg, _) =>
                    err_list ctxt (gen_msg err
                      (with_detail
                        "failed to unify invariant arguments w.r.t. to the known map function"
                        msg))
                      (fst tye_idx) Ts)
              | INVARIANT => (cs, strong_unify ctxt constraint tye_idx
                  handle NO_UNIFIER (msg, _) =>
                    error (gen_msg err (with_detail "failed to unify invariant arguments" msg))));
            val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
              (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
            val test_update = is_compT orf is_freeT orf is_fixedvarT;
            (*without new constraints only unification happened: constraints already done
              whose sides became compound/free/fixed have to be processed again*)
            val (ch, done') =
              if not (null new) then ([], done)
              else split_cs (test_update o Type_Infer.deref tye') done;
            val todo' = ch @ todo;
          in
            simplify done' (new @ todo') (tye', idx')
          end
        (*xi is definitely a parameter*)
        and expand varleq xi S a Ts error_pack done todo tye idx =
          let
            val n = length Ts;
            (*instantiate xi by constructor a applied to fresh parameters*)
            val args = map2 Type_Infer.mk_param (idx upto idx + n - 1) (arity_sorts a S);
            val tye' = Vartab.update_new (xi, Type(a, args)) tye;
            val (ch, done') = split_cs (is_compT o Type_Infer.deref tye') done;
            val todo' = ch @ todo;
            val new =
              if varleq then (Type(a, args), Type (a, Ts))
              else (Type (a, Ts), Type (a, args));
          in
            simplify done' ((new, error_pack) :: todo') (tye', idx + n)
          end
        (*TU is a pair of a parameter and a free/fixed variable*)
        and eliminate TU done todo tye idx =
          let
            val [TVar (xi, S)] = filter Type_Infer.is_paramT TU;
            val [T] = filter_out Type_Infer.is_paramT TU;
            val SOME S' = sort_of T;
            val test_update = if is_freeT T then is_freeT else is_fixedvarT;
            val tye' = Vartab.update_new (xi, T) tye;
            val (ch, done') = split_cs (test_update o Type_Infer.deref tye') done;
            val todo' = ch @ todo;
          in
            if Sign.subsort thy (S', S) (*TODO check this*)
            then simplify done' todo' (tye', idx)
            else error (gen_msg err "sort mismatch")
          end
        and simplify done [] tye_idx = (done, tye_idx)
          | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
              (case (Type_Infer.deref tye T, Type_Infer.deref tye U) of
                (Type (a, []), Type (b, [])) =>
                  if a = b then simplify done todo tye_idx
                  else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
                  else error (gen_msg err (a ^ " is not a subtype of " ^ b))
              | (Type (a, Ts), Type (b, Us)) =>
                  if a <> b then error (gen_msg err "different constructors")
                  else contract a Ts Us error_pack done todo tye idx
              | (TVar (xi, S), Type (a, Ts as (_ :: _))) =>
                  expand true xi S a Ts error_pack done todo tye idx
              | (Type (a, Ts as (_ :: _)), TVar (xi, S)) =>
                  expand false xi S a Ts error_pack done todo tye idx
              | (T, U) =>
                  if T = U then simplify done todo tye_idx
                  else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
                    exists Type_Infer.is_paramT [T, U]
                  then eliminate [T, U] done todo tye idx
                  else if exists (is_freeT orf is_fixedvarT) [T, U]
                  then error (gen_msg err "not eliminated free/fixed variables")
                  else simplify (((T, U), error_pack) :: done) todo tye_idx);
      in
        simplify [] cs tye_idx
      end;
wenzelm@40281
   390
wenzelm@40281
   391
wenzelm@40281
   392
    (* do simplification *)
wenzelm@40282
   393
wenzelm@40281
   394
    val (cs', tye_idx') = simplify_constraints cs tye_idx;
wenzelm@42383
   395
wenzelm@42383
   396
    fun find_error_pack lower T' = map_filter
traytel@40836
   397
      (fn ((T, U), pack) => if if lower then T' = U else T' = T then SOME pack else NONE) cs';
wenzelm@42383
   398
wenzelm@42383
   399
    fun find_cycle_packs nodes =
traytel@40836
   400
      let
traytel@40836
   401
        val (but_last, last) = split_last nodes
traytel@40836
   402
        val pairs = (last, hd nodes) :: (but_last ~~ tl nodes);
traytel@40836
   403
      in
traytel@40836
   404
        map_filter
wenzelm@40838
   405
          (fn (TU, pack) => if member (op =) pairs TU then SOME pack else NONE)
traytel@40836
   406
          cs'
traytel@40836
   407
      end;
wenzelm@40281
   408
wenzelm@40281
   409
    (*styps stands either for supertypes or for subtypes of a type T
wenzelm@40281
   410
      in terms of the subtype-relation (excluding T itself)*)
wenzelm@40282
   411
    fun styps super T =
wenzelm@40281
   412
      (if super then Graph.imm_succs else Graph.imm_preds) coes_graph T
wenzelm@40281
   413
        handle Graph.UNDEF _ => [];
wenzelm@40281
   414
wenzelm@40282
   415
    fun minmax sup (T :: Ts) =
wenzelm@40281
   416
      let
wenzelm@40281
   417
        fun adjust T U = if sup then (T, U) else (U, T);
wenzelm@40281
   418
        fun extract T [] = T
wenzelm@40282
   419
          | extract T (U :: Us) =
wenzelm@40281
   420
              if Graph.is_edge coes_graph (adjust T U) then extract T Us
wenzelm@40281
   421
              else if Graph.is_edge coes_graph (adjust U T) then extract U Us
traytel@40836
   422
              else raise BOUND_ERROR "uncomparable types in type list";
wenzelm@40281
   423
      in
wenzelm@40281
   424
        t_of (extract T Ts)
wenzelm@40281
   425
      end;
wenzelm@40281
   426
wenzelm@40282
   427
    fun ex_styp_of_sort super T styps_and_sorts =
wenzelm@40281
   428
      let
wenzelm@40281
   429
        fun adjust T U = if super then (T, U) else (U, T);
wenzelm@40282
   430
        fun styp_test U Ts = forall
wenzelm@40281
   431
          (fn T => T = U orelse Graph.is_edge coes_graph (adjust U T)) Ts;
wenzelm@42388
   432
        fun fitting Ts S U = Sign.of_sort thy (t_of U, S) andalso styp_test U Ts
wenzelm@40281
   433
      in
wenzelm@40281
   434
        forall (fn (Ts, S) => exists (fitting Ts S) (T :: styps super T)) styps_and_sorts
wenzelm@40281
   435
      end;
wenzelm@40281
   436
wenzelm@40281
   437
    (* computes the tightest possible, correct assignment for 'a::S
wenzelm@40281
   438
       e.g. in the supremum case (sup = true):
wenzelm@40281
   439
               ------- 'a::S---
wenzelm@40281
   440
              /        /    \  \
wenzelm@40281
   441
             /        /      \  \
wenzelm@40281
   442
        'b::C1   'c::C2 ...  T1 T2 ...
wenzelm@40281
   443
wenzelm@40281
   444
       sorts - list of sorts [C1, C2, ...]
wenzelm@40281
   445
       T::Ts - non-empty list of base types [T1, T2, ...]
wenzelm@40281
   446
    *)
wenzelm@40282
   447
    fun tightest sup S styps_and_sorts (T :: Ts) =
wenzelm@40281
   448
      let
wenzelm@42388
   449
        fun restriction T = Sign.of_sort thy (t_of T, S)
wenzelm@40281
   450
          andalso ex_styp_of_sort (not sup) T styps_and_sorts;
wenzelm@40281
   451
        fun candidates T = inter (op =) (filter restriction (T :: styps sup T));
wenzelm@40281
   452
      in
wenzelm@40281
   453
        (case fold candidates Ts (filter restriction (T :: styps sup T)) of
traytel@40836
   454
          [] => raise BOUND_ERROR ("no " ^ (if sup then "supremum" else "infimum"))
wenzelm@40281
   455
        | [T] => t_of T
wenzelm@40281
   456
        | Ts => minmax sup Ts)
wenzelm@40281
   457
      end;
wenzelm@40281
   458
wenzelm@40281
   459
    (*Insert the constraints (T, U) as edges T -> U into the acyclic graph G.
      If an edge would close cycles, the types on every cycle are unified and the cycles are
      collapsed into a single node (see collapse).  Result: (graph, updated (tye, idx)).

      Changes w.r.t. the previous version (error reporting only):
        - the unifier's diagnostic follows "constraint cycle not unifiable" on its own line;
        - collapse reports its error relative to its own environment tye; before, the
          expression fst tye_idx referred to the parameter of the enclosing
          process_constraints, i.e. the environment from before constraint simplification.*)
    fun build_graph G [] tye_idx = (G, tye_idx)
      | build_graph G ((T, U) :: cs) tye_idx =
        if T = U then build_graph G cs tye_idx
        else
          let
            val G' = maybe_new_typnodes [T, U] G;
            val (G'', tye_idx') = (add_edge (T, U) G', tye_idx)
              handle Typ_Graph.CYCLES cycles =>
                let
                  val (tye, idx) =
                    fold
                      (fn cycle => fn tye_idx' => (unify_list cycle tye_idx'
                        handle NO_UNIFIER (msg, _) =>
                          err_bound ctxt
                            (gen_msg err
                              ("constraint cycle not unifiable" ^
                                (if msg = "" then "" else "\n" ^ msg)))
                            (fst tye_idx)
                            (find_cycle_packs cycle)))
                      cycles tye_idx
                in
                  collapse (tye, idx) cycles G
                end
          in
            build_graph G'' cs tye_idx'
          end
    and collapse (tye, idx) cycles G = (*assumes cycles contains at least one node*)
      let
        (*all cycles collapse to one node,
          because all of them share at least the nodes x and y*)
        val nodes = (distinct (op =) (flat cycles));
        val T = Type_Infer.deref tye (hd nodes);
        val P = new_imm_preds G nodes;
        val S = new_imm_succs G nodes;
        (*hd nodes survives as the representative of the collapsed cycles*)
        val G' = Typ_Graph.del_nodes (tl nodes) G;
        (*constraint between the representative and a former neighbour T' of the cycles;
          two distinct constructor types must be related by the coercion graph*)
        fun check_and_gen super T' =
          let val U = Type_Infer.deref tye T';
          in
            if not (is_typeT T) orelse not (is_typeT U) orelse T = U
            then if super then (hd nodes, T') else (T', hd nodes)
            else
              if super andalso
                Graph.is_edge coes_graph (nameT T, nameT U) then (hd nodes, T')
              else if not super andalso
                Graph.is_edge coes_graph (nameT U, nameT T) then (T', hd nodes)
              else err_bound ctxt (gen_msg err "cycle elimination produces inconsistent graph")
                    tye
                    (maps find_cycle_packs cycles @ find_error_pack super T')
          end;
      in
        build_graph G' (map (check_and_gen false) P @ map (check_and_gen true) S) (tye, idx)
      end;
wenzelm@40281
   508
wenzelm@40281
   509
    (*Try to instantiate the type parameter at graph node key from its bound in
      G: the predecessors of key if lower is true, its successors otherwise.
      The pair tye_idx is returned unchanged when key does not dereference to a
      parameter, when no non-parameter bound exists, or when the computed
      assignment is itself a parameter.*)
    fun assign_bound lower G key (tye_idx as (tye, _)) =
      let
        val deref = Type_Infer.deref tye;
      in
        if not (Type_Infer.is_paramT (deref key)) then tye_idx
        else
          let
            val TVar (xi, S) = deref key;
            val bound_of = if lower then get_preds else get_succs;
            fun concrete_bound T = filter_out Type_Infer.is_paramT (map deref (bound_of G T));
            val raw_bound = bound_of G key;
            val bound = map deref raw_bound;
            val not_params = filter_out Type_Infer.is_paramT bound;
            (*for a node with a sort: names of its own non-parameter bounds, paired with that sort*)
            fun to_fulfil T =
              Option.map (fn sort => (map nameT (concrete_bound T), sort)) (sort_of T);
            val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
            (*record T as the instantiation of the parameter xi*)
            fun assign T = apfst (Vartab.update (xi, T)) tye_idx;
            val assignment =
              if null bound orelse null not_params then NONE
              else
                SOME (tightest lower S styps_and_sorts (map nameT not_params)
                  handle BOUND_ERROR msg =>
                    err_bound ctxt (gen_msg err msg) tye (find_error_pack lower key));
          in
            (case assignment of
              NONE => tye_idx
            | SOME T =>
                if Type_Infer.is_paramT T then tye_idx
                else if not lower then assign T
                else (*upper bound check*)
                  let
                    val other_bound = map deref (get_succs G key);
                    val s = nameT T;
                  in
                    if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
                    then assign T
                    else
                      err_bound ctxt
                        (gen_msg err ("assigned simple type " ^ s ^
                          " clashes with the upper bound of variable " ^
                          Syntax.string_of_typ ctxt (TVar (xi, S))))
                        tye (find_error_pack (not lower) key)
                  end)
          end
      end;
wenzelm@40281
   550
wenzelm@40281
   551
    (*assign_lb bounds a node by its predecessors (lower = true),
      assign_ub by its successors (lower = false)*)
    fun assign_lb G key tye_idx = assign_bound true G key tye_idx;
    fun assign_ub G key tye_idx = assign_bound false G key tye_idx;
wenzelm@40281
   553
wenzelm@40281
   554
    (*Run a round of lower bound assignments followed by a round of upper bound
      assignments over the nodes ts, then recurse on the nodes that still
      dereference to parameters.  Stops as soon as the node list equals the
      one of the previous round (ts').*)
    fun assign_alternating ts' ts G tye_idx =
      if ts' = ts then tye_idx
      else
        let
          val after_lb = fold (assign_lb G) ts tye_idx;
          val (after_ub as (tye, _)) = fold (assign_ub G) ts after_lb;
          fun still_param T = Type_Infer.is_paramT (Type_Infer.deref tye T);
        in
          assign_alternating ts (filter still_param ts) G after_ub
        end;
wenzelm@40281
   564
wenzelm@40281
   565
    (*Unify all weakly connected components of the constraint forest
      that contain only params.  These are the only WCCs that contain
      params anyway.  Each component is taken as a maximal node of G that
      dereferences to a parameter, together with its predecessors.*)
    fun unify_params G (tye_idx as (tye, _)) =
      let
        fun is_param T = Type_Infer.is_paramT (Type_Infer.deref tye T);
        fun component T = T :: get_preds G T;
        (*errors are reported w.r.t. the environment passed to unify_params*)
        fun unify_component Ts env =
          unify_list Ts env
            handle NO_UNIFIER (msg, _) => err_list ctxt (gen_msg err msg) tye Ts;
        val components = map component (filter is_param (Typ_Graph.maximals G));
      in
        fold unify_component components tye_idx
      end;
wenzelm@40281
   579
wenzelm@40281
   580
    (*assign bounds over all nodes of G first, then run unify_params on the result*)
    fun solve_constraints G tye_idx =
      unify_params G (assign_alternating [] (Typ_Graph.keys G) G tye_idx);
wenzelm@40281
   583
  in
wenzelm@40281
   584
    build_graph Typ_Graph.empty (map fst cs') tye_idx'
wenzelm@40281
   585
      |-> solve_constraints
wenzelm@40281
   586
  end;
wenzelm@40281
   587
wenzelm@40281
   588
wenzelm@40281
   589
wenzelm@40281
   590
(** coercion insertion **)
wenzelm@40281
   591
traytel@40836
   592
(*Build a coercion term from T1 to T2, both dereferenced w.r.t. tye:
   - two nullary type constructors: the identity abstraction if they are
     equal, otherwise the entry of the coercion table (Fail if there is none);
   - the same type constructor with arguments: the registered map function
     (Fail if there is none), instantiated and applied to the coercions
     generated for its arguments according to their variance;
   - anything else: the identity abstraction if the two types could unify,
     otherwise Fail.*)
fun gen_coercion ctxt tye (T1, T2) =
  (case pairself (Type_Infer.deref tye) (T1, T2) of
    ((Type (a, [])), (Type (b, []))) =>
        if a = b
        then Abs (Name.uu, Type (a, []), Bound 0)
        else
          (case Symreltab.lookup (coes_of ctxt) (a, b) of
            NONE => raise Fail (a ^ " is not a subtype of " ^ b)
          | SOME co => co)
  | ((Type (a, Ts)), (Type (b, Us))) =>
        if a <> b
        then raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
        else
          let
            (*instantiate the type variables of t by the given types*)
            fun inst t Ts =
              Term.subst_vars
                (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
            (*coercion for one argument position; invariant positions take no
              function argument and thus contribute no coercion*)
            fun sub_co (COVARIANT, TU) = SOME (gen_coercion ctxt tye TU)
              | sub_co (CONTRAVARIANT, TU) = SOME (gen_coercion ctxt tye (swap TU))
              | sub_co (INVARIANT, _) = NONE
              | sub_co (INVARIANT_TO _, _) = NONE;
            (*flatten function types into their domain and range types*)
            fun ts_of [] = []
              | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
          in
            (case Symtab.lookup (tmaps_of ctxt) a of
              NONE => raise Fail ("No map function for " ^ a ^ " known")
            | SOME tmap =>
                let
                  val used_coes = map_filter sub_co ((snd tmap) ~~ (Ts ~~ Us));
                in
                  Term.list_comb
                    (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
                end)
          end
  | (T, U) =>
        if Type.could_unify (T, U)
        then Abs (Name.uu, T, Bound 0)
        else raise Fail ("Cannot generate coercion from "
          ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U));
traytel@40836
   630
wenzelm@40281
   631
(*Traverse each term of ts and wrap an argument into a generated coercion
  wherever the argument type does not strongly unify with the domain type of
  the function it is applied to.*)
fun insert_coercions ctxt tye ts =
  let
    (*yields the transformed term together with its type; bs holds the types of
      the enclosing abstractions, innermost first*)
    fun insert _ (t as Const (_, T)) = (t, T)
      | insert _ (t as Free (_, T)) = (t, T)
      | insert _ (t as Var (_, T)) = (t, T)
      | insert bs (t as Bound i) = (t, (nth bs i handle Subscript => err_loose i))
      | insert bs (Abs (x, T, body)) =
          let val (body', U) = insert (T :: bs) body
          in (Abs (x, T, body'), T --> U) end
      | insert bs (t $ u) =
          let
            val (t', Type ("fun", [U, T])) = apsnd (Type_Infer.deref tye) (insert bs t);
            val (u', U') = insert bs u;
            val arg =
              if can (fn TU => strong_unify ctxt TU (tye, 0)) (U, U') then u'
              else gen_coercion ctxt tye (U', U) $ u';
          in (t' $ arg, T) end
  in
    map (fst o insert []) ts
  end;
wenzelm@40281
   664
wenzelm@40281
   665
wenzelm@40281
   666
wenzelm@40281
   667
(** assembling the pipeline **)
wenzelm@40281
   668
wenzelm@40281
   669
(*Type inference entry point: first attempt inference by strong unification
  alone; if that raises TYPE_INFERENCE_ERROR, generate and process subtype
  constraints and insert coercions into the prepared terms.*)
fun infer_types ctxt const_type var_type raw_ts =
  let
    val (idx, ts) = Type_Infer.prepare ctxt const_type var_type raw_ts;

    (*inference without coercions; bs holds the enclosing abstractions*)
    fun inf _ (t as Const (_, T)) env = (t, T, env)
      | inf _ (t as Free (_, T)) env = (t, T, env)
      | inf _ (t as Var (_, T)) env = (t, T, env)
      | inf bs (t as Bound i) env =
          (t, snd (nth bs i handle Subscript => err_loose i), env)
      | inf bs (Abs (x, T, body)) env =
          let val (body', U, env') = inf ((x, T) :: bs) body env
          in (Abs (x, T, body'), T --> U, env') end
      | inf bs (t $ u) env =
          let
            val (t', T, env1) = inf bs t env;
            val (u', U, (tye, idx)) = inf bs u env1;
            (*fresh parameter standing for the result type of the application*)
            val V = Type_Infer.mk_param idx [];
            val env2 = strong_unify ctxt (U --> V, T) (tye, idx + 1)
              handle NO_UNIFIER (msg, tye') =>
                raise TYPE_INFERENCE_ERROR (err_appl_msg ctxt msg tye' bs t T u U);
          in (t' $ u', V, env2) end;

    fun infer_single t env =
      let val (t', _, env') = inf [] t env
      in (t', env') end;

    (*fallback taken after a failed plain inference; err is the payload of the
      caught TYPE_INFERENCE_ERROR*)
    fun infer_with_coercions err =
      let
        fun gen_single t (env, constraints) =
          let val (_, env', constraints') = generate_constraints ctxt err t env
          in (env', constraints' @ constraints) end;

        val (env, constraints) = fold gen_single ts ((Vartab.empty, idx), []);
        val (tye, idx') = process_constraints ctxt err constraints env;
      in
        (insert_coercions ctxt tye ts, (tye, idx'))
      end;

    val (ts', (tye, _)) =
      fold_map infer_single ts (Vartab.empty, idx)
        handle TYPE_INFERENCE_ERROR err => infer_with_coercions err;
  in
    snd (Type_Infer.finish ctxt tye ([], ts'))
  end;
wenzelm@40281
   710
wenzelm@40281
   711
wenzelm@40281
   712
wenzelm@40281
   713
(** installation **)
wenzelm@40281
   714
wenzelm@40283
   715
(* term check *)
wenzelm@40283
   716
wenzelm@40281
   717
(*infer_types specialised to the constant constraints and default variable
  types of the given proof context*)
fun coercion_infer_types ctxt =
  let
    val const_type = try (Consts.the_constraint (Proof_Context.consts_of ctxt));
    val var_type = Proof_Context.def_type ctxt;
  in
    infer_types ctxt const_type var_type
  end;
wenzelm@40281
   721
wenzelm@40939
   722
(*boolean configuration option "coercion_enabled" with default false, together with its setup function*)
val (coercion_enabled, coercion_enabled_setup) = Attrib.config_bool "coercion_enabled" (K false);
wenzelm@40939
   723
wenzelm@40283
   724
(*Term check "coercions": when coercion_enabled is set in the context, run
  coercion_infer_types on the terms and report the result only if it differs
  (modulo aconv) from the input.*)
val add_term_check =
  let
    fun check ts ctxt =
      if not (Config.get ctxt coercion_enabled) then NONE
      else
        let val ts' = coercion_infer_types ctxt ts in
          if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt)
        end;
  in
    Syntax.add_term_check ~100 "coercions" check
  end;
wenzelm@40281
   731
wenzelm@40281
   732
wenzelm@40283
   733
(* declarations *)
wenzelm@40281
   734
wenzelm@40284
   735
(*Register raw_t (after Variable.polymorphic) as the map function of a type
  constructor C.  Its type is expected to have the shape
    f1 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]
  and the variance of each argument position of C is derived from the
  direction of the corresponding function argument.  Ill-formed map
  functions are rejected via error.*)
fun add_type_map raw_t context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) raw_t;

    fun err_str t = "\n\nThe provided function has the type\n" ^
      Syntax.string_of_typ ctxt (fastype_of t) ^
      "\n\nThe general type signature of a map function is" ^
      "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]" ^
      "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)";

    (*fis: argument types except the last one; T1: last argument type; T2: result type*)
    val ((fis, T1), T2) = apfst split_last (strip_type (fastype_of t))
      handle Empty => error ("Not a proper map function:" ^ err_str t);

    (*Match the (domain, range) pairs of the function arguments against the
      (xi, yi) pairs of type arguments and yield one variance per type
      argument.  A type argument with xi = yi consumes no function.*)
    fun gen_arg_var ([], []) = []
      | gen_arg_var ((T, T') :: Ts, (U, U') :: Us) =
          if U = U' then
            if is_stypeT U then INVARIANT_TO U :: gen_arg_var ((T, T') :: Ts, Us)
            else error ("Invariant xi and yi should be base types:" ^ err_str t)
          else if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
          else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us)
          else error ("Functions do not apply to arguments correctly:" ^ err_str t)
      | gen_arg_var ([], Ts) =
          (*no functions left: the remaining type arguments must all be invariant*)
          if forall (op = andf is_stypeT o fst) Ts
          then map (INVARIANT_TO o fst) Ts
          else error ("Different numbers of functions and variant arguments\n" ^ err_str t)
      | gen_arg_var (_, []) =
          (*functions left over without a type argument to act on: reject
            instead of silently dropping them*)
          error ("Different numbers of functions and variant arguments\n" ^ err_str t);

    (*retry flag needed to adjust the type lists, when given a map over type constructor fun*)
    fun check_map_fun fis (Type (C1, Ts)) (Type (C2, Us)) retry =
          if C1 = C2 andalso not (null fis) andalso forall is_funtype fis
          then ((map dest_funT fis, Ts ~~ Us), C1)
          else error ("Not a proper map function:" ^ err_str t)
      | check_map_fun fis T1 T2 true =
          let val (fis', T') = split_last fis
          in check_map_fun fis' T' (T1 --> T2) false end
      | check_map_fun _ _ _ _ = error ("Not a proper map function:" ^ err_str t);

    val res = check_map_fun fis T1 T2 true;
    val res_av = gen_arg_var (fst res);
  in
    map_tmaps (Symtab.update (snd res, (t, res_av))) context
  end;
wenzelm@40281
   777
wenzelm@40284
   778
(*Register raw_t (after Variable.polymorphic) as coercion between two nullary
  type constructors a and b, where a => b is the type of the term.  The
  coercion graph is extended by the edge (a, b) and closed transitively; for
  every other edge that becomes new, a composed coercion is stored in the
  table.  A term of any other type, or a coercion that would create a cycle
  in the graph, is rejected via error.*)
fun add_coercion raw_t context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) raw_t;

    fun err_coercion () = error ("Bad type for coercion " ^
        Syntax.string_of_term ctxt t ^ ":\n" ^
        Syntax.string_of_typ ctxt (fastype_of t));

    val (T1, T2) = Term.dest_funT (fastype_of t)
      handle TYPE _ => err_coercion ();

    val a =
      (case T1 of
        Type (x, []) => x
      | _ => err_coercion ());

    val b =
      (case T2 of
        Type (x, []) => x
      | _ => err_coercion ());

    fun coercion_data_update (tab, G) =
      let
        val G' = maybe_new_nodes [a, b] G
        (*a cycle arises exactly when b already reaches a in the graph, i.e.
          when b is already a subtype of a*)
        val G'' = Graph.add_edge_trans_acyclic (a, b) G'
          handle Graph.CYCLES _ => error (b ^ " is already a subtype of " ^ a ^
            "!\n\nCannot add coercion of type: " ^ a ^ " => " ^ b);
        (*edges of the transitive closure that were not present before*)
        val new_edges =
          flat (Graph.dest G'' |> map (fn (x, ys) => ys |> map_filter (fn y =>
            if Graph.is_edge G' (x, y) then NONE else SOME (x, y))));
        val G_and_new = Graph.add_edge (a, b) G';

        (*compose the coercions along one path from a to b in G*)
        fun complex_coercion tab G (a, b) =
          let
            val path = hd (Graph.irreducible_paths G (a, b))
            val path' = fst (split_last path) ~~ tl path
          in Abs (Name.uu, Type (a, []),
              fold (fn t => fn u => t $ u) (map (the o Symreltab.lookup tab) path') (Bound 0))
          end;

        val tab' = fold
          (fn pair => fn tab => Symreltab.update (pair, complex_coercion tab G_and_new pair) tab)
          (filter (fn pair => pair <> (a, b)) new_edges)
          (Symreltab.update ((a, b), t) tab);
      in
        (tab', G'')
      end;
  in
    map_coes_and_graph coercion_data_update context
  end;
wenzelm@40281
   829
wenzelm@40283
   830
wenzelm@40283
   831
(* theory setup *)
wenzelm@40283
   832
wenzelm@40283
   833
(*Theory setup: the configuration option, the term check, and the attributes
  "coercion" and "coercion_map".*)
val setup =
  let
    (*attribute parser: read a term and turn it into a declaration via f*)
    fun declaration f = Args.term >> (fn t => Thm.declaration_attribute (K (f t)));
  in
    coercion_enabled_setup #>
    Context.theory_map add_term_check #>
    Attrib.setup @{binding coercion} (declaration add_coercion)
      "declaration of new coercions" #>
    Attrib.setup @{binding coercion_map} (declaration add_type_map)
      "declaration of new map functions"
  end;
wenzelm@40281
   842
wenzelm@40281
   843
end;