wenzelm@40281: (* Title: Tools/subtyping.ML wenzelm@40281: Author: Dmitriy Traytel, TU Muenchen wenzelm@40281: wenzelm@40281: Coercive subtyping via subtype constraints. wenzelm@40281: *) wenzelm@40281: wenzelm@40281: signature SUBTYPING = wenzelm@40281: sig wenzelm@40939: val coercion_enabled: bool Config.T wenzelm@40284: val add_type_map: term -> Context.generic -> Context.generic wenzelm@40284: val add_coercion: term -> Context.generic -> Context.generic traytel@45059: val print_coercions: Proof.context -> unit wenzelm@40283: val setup: theory -> theory wenzelm@40281: end; wenzelm@40281: wenzelm@40283: structure Subtyping: SUBTYPING = wenzelm@40281: struct wenzelm@40281: wenzelm@40281: (** coercions data **) wenzelm@40281: traytel@41353: datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ; traytel@51327: datatype coerce_arg = PERMIT | FORBID | LEAVE wenzelm@40281: wenzelm@40281: datatype data = Data of traytel@45060: {coes: (term * ((typ list * typ list) * term list)) Symreltab.table, (*coercions table*) traytel@45060: (*full coercions graph - only used at coercion declaration/deletion*) traytel@45060: full_graph: int Graph.T, wenzelm@52432: (*coercions graph restricted to base types - for efficiency reasons stored in the context*) traytel@45060: coes_graph: int Graph.T, traytel@51319: tmaps: (term * variance list) Symtab.table, (*map functions*) traytel@51327: coerce_args: coerce_arg list Symtab.table (*special constants with non-coercible arguments*)}; wenzelm@40281: traytel@51319: fun make_data (coes, full_graph, coes_graph, tmaps, coerce_args) = traytel@51319: Data {coes = coes, full_graph = full_graph, coes_graph = coes_graph, traytel@51319: tmaps = tmaps, coerce_args = coerce_args}; wenzelm@40281: traytel@45935: fun merge_error_coes (a, b) = traytel@45935: error ("Cannot merge coercion tables.\nConflicting declarations for coercions from " ^ traytel@45935: quote a ^ " to " ^ quote b ^ "."); traytel@45935: traytel@45935: fun 
merge_error_tmaps C = traytel@45935: error ("Cannot merge coercion map tables.\nConflicting declarations for the constructor " ^ traytel@45935: quote C ^ "."); traytel@45935: traytel@51319: fun merge_error_coerce_args C = wenzelm@55303: error ("Cannot merge tables for constants with coercion-invariant arguments.\n" ^ wenzelm@55303: "Conflicting declarations for the constant " ^ quote C ^ "."); traytel@51319: wenzelm@40281: structure Data = Generic_Data wenzelm@40281: ( wenzelm@40281: type T = data; traytel@51319: val empty = make_data (Symreltab.empty, Graph.empty, Graph.empty, Symtab.empty, Symtab.empty); wenzelm@40281: val extend = I; wenzelm@40281: fun merge traytel@51319: (Data {coes = coes1, full_graph = full_graph1, coes_graph = coes_graph1, traytel@51319: tmaps = tmaps1, coerce_args = coerce_args1}, traytel@51319: Data {coes = coes2, full_graph = full_graph2, coes_graph = coes_graph2, traytel@51319: tmaps = tmaps2, coerce_args = coerce_args2}) = traytel@45060: make_data (Symreltab.merge (eq_pair (op aconv) traytel@45060: (eq_pair (eq_pair (eq_list (op =)) (eq_list (op =))) (eq_list (op aconv)))) traytel@45935: (coes1, coes2) handle Symreltab.DUP key => merge_error_coes key, traytel@45060: Graph.merge (op =) (full_graph1, full_graph2), wenzelm@40281: Graph.merge (op =) (coes_graph1, coes_graph2), traytel@45935: Symtab.merge (eq_pair (op aconv) (op =)) (tmaps1, tmaps2) traytel@51319: handle Symtab.DUP key => merge_error_tmaps key, traytel@51319: Symtab.merge (eq_list (op =)) (coerce_args1, coerce_args2) traytel@51319: handle Symtab.DUP key => merge_error_coerce_args key); wenzelm@40281: ); wenzelm@40281: wenzelm@40281: fun map_data f = traytel@51319: Data.map (fn Data {coes, full_graph, coes_graph, tmaps, coerce_args} => traytel@51319: make_data (f (coes, full_graph, coes_graph, tmaps, coerce_args))); wenzelm@40281: wenzelm@40281: fun map_coes f = traytel@51319: map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) => traytel@51319: (f coes, 
full_graph, coes_graph, tmaps, coerce_args)); wenzelm@40281: wenzelm@40281: fun map_coes_graph f = traytel@51319: map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) => traytel@51319: (coes, full_graph, f coes_graph, tmaps, coerce_args)); wenzelm@40281: traytel@45060: fun map_coes_and_graphs f = traytel@51319: map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) => traytel@45060: let val (coes', full_graph', coes_graph') = f (coes, full_graph, coes_graph); traytel@51319: in (coes', full_graph', coes_graph', tmaps, coerce_args) end); wenzelm@40281: wenzelm@40281: fun map_tmaps f = traytel@51319: map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) => traytel@51319: (coes, full_graph, coes_graph, f tmaps, coerce_args)); traytel@51319: traytel@51319: fun map_coerce_args f = traytel@51319: map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) => traytel@51319: (coes, full_graph, coes_graph, tmaps, f coerce_args)); wenzelm@40281: wenzelm@40285: val rep_data = (fn Data args => args) o Data.get o Context.Proof; wenzelm@40281: wenzelm@40281: val coes_of = #coes o rep_data; wenzelm@40281: val coes_graph_of = #coes_graph o rep_data; wenzelm@40281: val tmaps_of = #tmaps o rep_data; traytel@51319: val coerce_args_of = #coerce_args o rep_data; wenzelm@40281: wenzelm@40281: wenzelm@40281: wenzelm@40281: (** utils **) wenzelm@40281: wenzelm@46614: fun restrict_graph G = Graph.restrict (fn x => Graph.get_node G x = 0) G; traytel@45060: wenzelm@40281: fun nameT (Type (s, [])) = s; wenzelm@40281: fun t_of s = Type (s, []); wenzelm@40286: wenzelm@40281: fun sort_of (TFree (_, S)) = SOME S wenzelm@40281: | sort_of (TVar (_, S)) = SOME S wenzelm@40281: | sort_of _ = NONE; wenzelm@40281: wenzelm@40281: val is_typeT = fn (Type _) => true | _ => false; traytel@41353: val is_stypeT = fn (Type (_, [])) => true | _ => false; wenzelm@40282: val is_compT = fn (Type (_, _ :: _)) => true | _ => false; wenzelm@40281: val is_freeT = fn (TFree _) => true 
| _ => false; wenzelm@40286: val is_fixedvarT = fn (TVar (xi, _)) => not (Type_Infer.is_param xi) | _ => false; traytel@41353: val is_funtype = fn (Type ("fun", [_, _])) => true | _ => false; traytel@51335: traytel@51335: fun mk_identity T = Abs (Name.uu, T, Bound 0); traytel@43591: val is_identity = fn (Abs (_, _, Bound 0)) => true | _ => false; wenzelm@40281: traytel@45060: fun instantiate t Ts = Term.subst_TVars traytel@45060: ((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts) t; traytel@45060: wenzelm@55303: exception COERCION_GEN_ERROR of unit -> string * Buffer.T; traytel@54584: wenzelm@55303: infixr ++> (*composition with deferred error message*) wenzelm@55303: fun (err : unit -> string * Buffer.T) ++> s = wenzelm@55303: err #> apsnd (Buffer.add s); traytel@54584: wenzelm@55303: fun eval_err err = wenzelm@55303: let val (s, buf) = err () wenzelm@55303: in s ^ Markup.markup Markup.text_fold (Buffer.content buf) end; traytel@54584: wenzelm@55303: fun eval_error err = error (eval_err err); traytel@45060: traytel@45060: fun inst_collect tye err T U = traytel@45060: (case (T, Type_Infer.deref tye U) of traytel@54584: (TVar (xi, _), U) => [(xi, U)] traytel@45060: | (Type (a, Ts), Type (b, Us)) => wenzelm@55303: if a <> b then eval_error err else inst_collects tye err Ts Us wenzelm@55303: | (_, U') => if T <> U' then eval_error err else []) traytel@45060: and inst_collects tye err Ts Us = traytel@45060: fold2 (fn T => fn U => fn is => inst_collect tye err T U @ is) Ts Us []; traytel@45060: wenzelm@40281: traytel@40836: (* unification *) wenzelm@40281: wenzelm@40281: exception NO_UNIFIER of string * typ Vartab.table; wenzelm@40281: wenzelm@40281: fun unify weak ctxt = wenzelm@40281: let wenzelm@42361: val thy = Proof_Context.theory_of ctxt; wenzelm@42386: val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy); wenzelm@40281: wenzelm@40282: wenzelm@40281: (* adjust sorts of parameters *) wenzelm@40281: wenzelm@40281: fun not_of_sort x S' S = 
wenzelm@40281: "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^ wenzelm@40281: Syntax.string_of_sort ctxt S; wenzelm@40281: wenzelm@40281: fun meet (_, []) tye_idx = tye_idx wenzelm@40281: | meet (Type (a, Ts), S) (tye_idx as (tye, _)) = wenzelm@40281: meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx wenzelm@40281: | meet (TFree (x, S'), S) (tye_idx as (tye, _)) = wenzelm@40281: if Sign.subsort thy (S', S) then tye_idx wenzelm@40281: else raise NO_UNIFIER (not_of_sort x S' S, tye) wenzelm@40281: | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) = wenzelm@40281: if Sign.subsort thy (S', S) then tye_idx wenzelm@40281: else if Type_Infer.is_param xi then wenzelm@40286: (Vartab.update_new wenzelm@40286: (xi, Type_Infer.mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1) wenzelm@40281: else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye) wenzelm@40281: and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) = wenzelm@40286: meets (Ts, Ss) (meet (Type_Infer.deref tye T, S) tye_idx) wenzelm@40281: | meets _ tye_idx = tye_idx; wenzelm@40281: wenzelm@55301: val weak_meet = if weak then fn _ => I else meet; wenzelm@40281: wenzelm@40281: wenzelm@40281: (* occurs check and assignment *) wenzelm@40281: wenzelm@40281: fun occurs_check tye xi (TVar (xi', _)) = wenzelm@40281: if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye) wenzelm@40281: else wenzelm@40281: (case Vartab.lookup tye xi' of wenzelm@40281: NONE => () wenzelm@40281: | SOME T => occurs_check tye xi T) wenzelm@40281: | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts wenzelm@40281: | occurs_check _ _ _ = (); wenzelm@40281: wenzelm@40281: fun assign xi (T as TVar (xi', _)) S env = wenzelm@40281: if xi = xi' then env wenzelm@40281: else env |> weak_meet (T, S) |>> Vartab.update_new (xi, T) wenzelm@40281: | assign xi T S (env as (tye, _)) = wenzelm@40281: (occurs_check tye xi T; env |> weak_meet (T, S) |>> 
Vartab.update_new (xi, T)); wenzelm@40281: wenzelm@40281: wenzelm@40281: (* unification *) wenzelm@40281: wenzelm@40281: fun show_tycon (a, Ts) = wenzelm@40281: quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT))); wenzelm@40281: wenzelm@40281: fun unif (T1, T2) (env as (tye, _)) = wenzelm@40286: (case pairself (`Type_Infer.is_paramT o Type_Infer.deref tye) (T1, T2) of wenzelm@40281: ((true, TVar (xi, S)), (_, T)) => assign xi T S env wenzelm@40281: | ((_, T), (true, TVar (xi, S))) => assign xi T S env wenzelm@40281: | ((_, Type (a, Ts)), (_, Type (b, Us))) => wenzelm@40281: if weak andalso null Ts andalso null Us then env wenzelm@40281: else if a <> b then wenzelm@40281: raise NO_UNIFIER wenzelm@40281: ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye) wenzelm@40281: else fold unif (Ts ~~ Us) env wenzelm@40281: | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye)); wenzelm@40281: wenzelm@40281: in unif end; wenzelm@40281: wenzelm@40281: val weak_unify = unify true; wenzelm@40281: val strong_unify = unify false; wenzelm@40281: wenzelm@40281: wenzelm@40281: (* Typ_Graph shortcuts *) wenzelm@40281: wenzelm@40281: fun get_preds G T = Typ_Graph.all_preds G [T]; wenzelm@40281: fun get_succs G T = Typ_Graph.all_succs G [T]; wenzelm@40281: fun maybe_new_typnode T G = perhaps (try (Typ_Graph.new_node (T, ()))) G; wenzelm@40281: fun maybe_new_typnodes Ts G = fold maybe_new_typnode Ts G; wenzelm@44338: fun new_imm_preds G Ts = (* FIXME inefficient *) wenzelm@44338: subtract (op =) Ts (distinct (op =) (maps (Typ_Graph.immediate_preds G) Ts)); wenzelm@44338: fun new_imm_succs G Ts = (* FIXME inefficient *) wenzelm@44338: subtract (op =) Ts (distinct (op =) (maps (Typ_Graph.immediate_succs G) Ts)); wenzelm@40281: wenzelm@40281: wenzelm@40281: (* Graph shortcuts *) wenzelm@40281: wenzelm@55301: fun maybe_new_node s G = perhaps (try (Graph.new_node s)) G; wenzelm@55301: fun maybe_new_nodes ss G = fold 
maybe_new_node ss G; wenzelm@40281: wenzelm@40281: wenzelm@40281: wenzelm@40281: (** error messages **) wenzelm@40281: traytel@54584: fun gen_err err msg = wenzelm@55303: err ++> ("\nNow trying to infer coercions globally.\n\nCoercion inference failed" ^ wenzelm@55303: (if msg = "" then "" else ":\n" ^ msg) ^ "\n"); traytel@45060: traytel@54584: val gen_msg = eval_err oo gen_err traytel@40836: wenzelm@40281: fun prep_output ctxt tye bs ts Ts = wenzelm@40281: let wenzelm@40281: val (Ts_bTs', ts') = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts); wenzelm@40281: val (Ts', Ts'') = chop (length Ts) Ts_bTs'; wenzelm@40281: fun prep t = wenzelm@40281: let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts''))) wenzelm@49660: in Term.subst_bounds (map Syntax_Trans.mark_bound_abs xs, t) end; wenzelm@40281: in (map prep ts', Ts') end; wenzelm@40281: wenzelm@40281: fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i); wenzelm@42383: traytel@40836: fun unif_failed msg = traytel@40836: "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n"; wenzelm@42383: traytel@40836: fun err_appl_msg ctxt msg tye bs t T u U () = wenzelm@55301: let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U] in wenzelm@55301: (unif_failed msg ^ Type.appl_error ctxt t' T' u' U' ^ "\n\n", wenzelm@55303: Buffer.empty |> Buffer.add "Coercion Inference:\n\n") wenzelm@55301: end; wenzelm@40281: traytel@54584: fun err_list ctxt err tye Ts = wenzelm@55303: let val (_, Ts') = prep_output ctxt tye [] [] Ts in wenzelm@55303: eval_error (err ++> wenzelm@55303: ("\nCannot unify a list of types that should be the same:\n" ^ wenzelm@55303: Pretty.string_of (Pretty.list "[" "]" (map (Syntax.pretty_typ ctxt) Ts')))) wenzelm@40281: end; wenzelm@40281: traytel@54584: fun err_bound ctxt err tye packs = wenzelm@40281: let wenzelm@40281: val (ts, Ts) = fold wenzelm@40281: (fn (bs, t $ u, U, _, U') => fn (ts, Ts) => traytel@40836: let val (t', T') = prep_output 
ctxt tye bs [t, u] [U', U] wenzelm@40282: in (t' :: ts, T' :: Ts) end) wenzelm@40281: packs ([], []); wenzelm@55303: val msg = wenzelm@55303: Pretty.string_of (Pretty.big_list "Cannot fulfil subtype constraints:" traytel@45060: (map2 (fn [t, u] => fn [T, U] => wenzelm@40281: Pretty.block [ wenzelm@42383: Syntax.pretty_typ ctxt T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2, wenzelm@42383: Syntax.pretty_typ ctxt U, Pretty.brk 3, wenzelm@42383: Pretty.str "from function application", Pretty.brk 2, traytel@45060: Pretty.block [Syntax.pretty_term ctxt (t $ u)]]) wenzelm@55303: ts Ts)); wenzelm@55303: in eval_error (err ++> ("\n" ^ msg)) end; wenzelm@40281: wenzelm@40281: wenzelm@40281: wenzelm@40281: (** constraint generation **) wenzelm@40281: traytel@51319: fun update_coerce_arg ctxt old t = traytel@51319: let traytel@51319: val mk_coerce_args = the_default [] o Symtab.lookup (coerce_args_of ctxt); traytel@51319: fun update _ [] = old traytel@51327: | update 0 (coerce :: _) = (case coerce of LEAVE => old | PERMIT => true | FORBID => false) traytel@51319: | update n (_ :: cs) = update (n - 1) cs; traytel@51319: val (f, n) = Term.strip_comb (Type.strip_constraints t) ||> length; traytel@51319: in traytel@51319: update n (case f of Const (name, _) => mk_coerce_args name | _ => []) traytel@51319: end; traytel@51319: traytel@40836: fun generate_constraints ctxt err = wenzelm@40281: let traytel@51319: fun gen _ cs _ (Const (_, T)) tye_idx = (T, tye_idx, cs) traytel@51319: | gen _ cs _ (Free (_, T)) tye_idx = (T, tye_idx, cs) traytel@51319: | gen _ cs _ (Var (_, T)) tye_idx = (T, tye_idx, cs) traytel@51319: | gen _ cs bs (Bound i) tye_idx = wenzelm@43278: (snd (nth bs i handle General.Subscript => err_loose i), tye_idx, cs) traytel@51319: | gen coerce cs bs (Abs (x, T, t)) tye_idx = traytel@51319: let val (U, tye_idx', cs') = gen coerce cs ((x, T) :: bs) t tye_idx wenzelm@40281: in (T --> U, tye_idx', cs') end traytel@51319: | gen coerce cs bs (t $ u) tye_idx = 
wenzelm@40281: let traytel@51319: val (T, tye_idx', cs') = gen coerce cs bs t tye_idx; traytel@51319: val coerce' = update_coerce_arg ctxt coerce t; traytel@51319: val (U', (tye, idx), cs'') = gen coerce' cs' bs u tye_idx'; wenzelm@40286: val U = Type_Infer.mk_param idx []; wenzelm@40286: val V = Type_Infer.mk_param (idx + 1) []; traytel@45060: val tye_idx'' = strong_unify ctxt (U --> V, T) (tye, idx + 2) traytel@41353: handle NO_UNIFIER (msg, _) => error (gen_msg err msg); wenzelm@40281: val error_pack = (bs, t $ u, U, V, U'); wenzelm@52432: in traytel@51319: if coerce' traytel@51319: then (V, tye_idx'', ((U', U), error_pack) :: cs'') traytel@51319: else (V, traytel@51319: strong_unify ctxt (U, U') tye_idx'' traytel@51319: handle NO_UNIFIER (msg, _) => error (gen_msg err msg), traytel@51319: cs'') traytel@51319: end; wenzelm@40281: in traytel@51319: gen true [] [] wenzelm@40281: end; wenzelm@40281: wenzelm@40281: wenzelm@40281: wenzelm@40281: (** constraint resolution **) wenzelm@40281: wenzelm@40281: exception BOUND_ERROR of string; wenzelm@40281: traytel@40836: fun process_constraints ctxt err cs tye_idx = wenzelm@40281: let wenzelm@42388: val thy = Proof_Context.theory_of ctxt; wenzelm@42388: wenzelm@40285: val coes_graph = coes_graph_of ctxt; wenzelm@40285: val tmaps = tmaps_of ctxt; wenzelm@42388: val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy); wenzelm@40281: wenzelm@40281: fun split_cs _ [] = ([], []) wenzelm@40282: | split_cs f (c :: cs) = wenzelm@40281: (case pairself f (fst c) of wenzelm@40281: (false, false) => apsnd (cons c) (split_cs f cs) wenzelm@40281: | _ => apfst (cons c) (split_cs f cs)); wenzelm@42383: traytel@41353: fun unify_list (T :: Ts) tye_idx = wenzelm@42383: fold (fn U => fn tye_idx' => strong_unify ctxt (T, U) tye_idx') Ts tye_idx; wenzelm@40281: wenzelm@40282: wenzelm@40281: (* check whether constraint simplification will terminate using weak unification *) wenzelm@40282: traytel@41353: val _ = fold (fn (TU, 
_) => fn tye_idx => traytel@41353: weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, _) => traytel@40836: error (gen_msg err ("weak unification of subtype constraints fails\n" ^ msg))) cs tye_idx; wenzelm@40281: wenzelm@40281: wenzelm@40281: (* simplify constraints *) wenzelm@40282: wenzelm@40281: fun simplify_constraints cs tye_idx = wenzelm@40281: let wenzelm@40281: fun contract a Ts Us error_pack done todo tye idx = wenzelm@40281: let wenzelm@40281: val arg_var = wenzelm@40281: (case Symtab.lookup tmaps a of wenzelm@40281: (*everything is invariant for unknown constructors*) wenzelm@40281: NONE => replicate (length Ts) INVARIANT wenzelm@40281: | SOME av => snd av); wenzelm@40281: fun new_constraints (variance, constraint) (cs, tye_idx) = wenzelm@40281: (case variance of wenzelm@40281: COVARIANT => (constraint :: cs, tye_idx) wenzelm@40281: | CONTRAVARIANT => (swap constraint :: cs, tye_idx) traytel@41353: | INVARIANT_TO T => (cs, unify_list [T, fst constraint, snd constraint] tye_idx wenzelm@42383: handle NO_UNIFIER (msg, _) => traytel@54584: err_list ctxt (gen_err err traytel@54584: ("failed to unify invariant arguments w.r.t. 
to the known map function\n" ^ traytel@54584: msg)) traytel@45060: (fst tye_idx) (T :: Ts)) wenzelm@40281: | INVARIANT => (cs, strong_unify ctxt constraint tye_idx wenzelm@42383: handle NO_UNIFIER (msg, _) => traytel@51248: error (gen_msg err ("failed to unify invariant arguments\n" ^ msg)))); wenzelm@40281: val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack)) wenzelm@40281: (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx))); traytel@49142: val test_update = is_typeT orf is_freeT orf is_fixedvarT; wenzelm@40281: val (ch, done') = traytel@51246: done traytel@51246: |> map (apfst (pairself (Type_Infer.deref tye'))) traytel@51246: |> (if not (null new) then rpair [] else split_cs test_update); wenzelm@40281: val todo' = ch @ todo; wenzelm@40281: in wenzelm@40281: simplify done' (new @ todo') (tye', idx') wenzelm@40281: end wenzelm@40281: (*xi is definitely a parameter*) wenzelm@40281: and expand varleq xi S a Ts error_pack done todo tye idx = wenzelm@40281: let wenzelm@40281: val n = length Ts; wenzelm@40286: val args = map2 Type_Infer.mk_param (idx upto idx + n - 1) (arity_sorts a S); wenzelm@40281: val tye' = Vartab.update_new (xi, Type(a, args)) tye; wenzelm@40286: val (ch, done') = split_cs (is_compT o Type_Infer.deref tye') done; wenzelm@40281: val todo' = ch @ todo; wenzelm@40281: val new = wenzelm@40281: if varleq then (Type(a, args), Type (a, Ts)) wenzelm@40286: else (Type (a, Ts), Type (a, args)); wenzelm@40281: in wenzelm@40281: simplify done' ((new, error_pack) :: todo') (tye', idx + n) wenzelm@40281: end wenzelm@40281: (*TU is a pair of a parameter and a free/fixed variable*) traytel@41353: and eliminate TU done todo tye idx = wenzelm@40281: let wenzelm@40286: val [TVar (xi, S)] = filter Type_Infer.is_paramT TU; wenzelm@40286: val [T] = filter_out Type_Infer.is_paramT TU; wenzelm@40281: val SOME S' = sort_of T; wenzelm@40281: val test_update = if is_freeT T then is_freeT else is_fixedvarT; wenzelm@40281: val tye' 
= Vartab.update_new (xi, T) tye; wenzelm@40286: val (ch, done') = split_cs (test_update o Type_Infer.deref tye') done; wenzelm@40281: val todo' = ch @ todo; wenzelm@40281: in wenzelm@42388: if Sign.subsort thy (S', S) (*TODO check this*) wenzelm@40281: then simplify done' todo' (tye', idx) traytel@40836: else error (gen_msg err "sort mismatch") wenzelm@40281: end wenzelm@40281: and simplify done [] tye_idx = (done, tye_idx) wenzelm@40281: | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) = wenzelm@40286: (case (Type_Infer.deref tye T, Type_Infer.deref tye U) of traytel@45060: (T1 as Type (a, []), T2 as Type (b, [])) => wenzelm@40281: if a = b then simplify done todo tye_idx wenzelm@40281: else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx wenzelm@55303: else wenzelm@55303: error (gen_msg err (quote (Syntax.string_of_typ ctxt T1) ^ wenzelm@55303: " is not a subtype of " ^ quote (Syntax.string_of_typ ctxt T2))) wenzelm@40281: | (Type (a, Ts), Type (b, Us)) => wenzelm@55303: if a <> b then wenzelm@55303: error (gen_msg err "different constructors") (fst tye_idx) error_pack wenzelm@40281: else contract a Ts Us error_pack done todo tye idx wenzelm@40282: | (TVar (xi, S), Type (a, Ts as (_ :: _))) => wenzelm@40281: expand true xi S a Ts error_pack done todo tye idx wenzelm@40282: | (Type (a, Ts as (_ :: _)), TVar (xi, S)) => wenzelm@40281: expand false xi S a Ts error_pack done todo tye idx wenzelm@40281: | (T, U) => wenzelm@40281: if T = U then simplify done todo tye_idx wenzelm@40282: else if exists (is_freeT orf is_fixedvarT) [T, U] andalso wenzelm@40286: exists Type_Infer.is_paramT [T, U] traytel@41353: then eliminate [T, U] done todo tye idx wenzelm@40281: else if exists (is_freeT orf is_fixedvarT) [T, U] traytel@40836: then error (gen_msg err "not eliminated free/fixed variables") wenzelm@40282: else simplify (((T, U), error_pack) :: done) todo tye_idx); wenzelm@40281: in wenzelm@40281: simplify [] cs tye_idx wenzelm@40281: 
end; wenzelm@40281: wenzelm@40281: wenzelm@40281: (* do simplification *) wenzelm@40282: wenzelm@40281: val (cs', tye_idx') = simplify_constraints cs tye_idx; wenzelm@42383: wenzelm@42383: fun find_error_pack lower T' = map_filter traytel@40836: (fn ((T, U), pack) => if if lower then T' = U else T' = T then SOME pack else NONE) cs'; wenzelm@42383: wenzelm@42383: fun find_cycle_packs nodes = traytel@40836: let traytel@40836: val (but_last, last) = split_last nodes traytel@40836: val pairs = (last, hd nodes) :: (but_last ~~ tl nodes); traytel@40836: in traytel@40836: map_filter wenzelm@40838: (fn (TU, pack) => if member (op =) pairs TU then SOME pack else NONE) traytel@40836: cs' traytel@40836: end; wenzelm@40281: wenzelm@40281: (*styps stands either for supertypes or for subtypes of a type T wenzelm@40281: in terms of the subtype-relation (excluding T itself)*) wenzelm@40282: fun styps super T = wenzelm@44338: (if super then Graph.immediate_succs else Graph.immediate_preds) coes_graph T wenzelm@40281: handle Graph.UNDEF _ => []; wenzelm@40281: wenzelm@40282: fun minmax sup (T :: Ts) = wenzelm@40281: let wenzelm@40281: fun adjust T U = if sup then (T, U) else (U, T); wenzelm@40281: fun extract T [] = T wenzelm@40282: | extract T (U :: Us) = wenzelm@40281: if Graph.is_edge coes_graph (adjust T U) then extract T Us wenzelm@40281: else if Graph.is_edge coes_graph (adjust U T) then extract U Us traytel@40836: else raise BOUND_ERROR "uncomparable types in type list"; wenzelm@40281: in wenzelm@40281: t_of (extract T Ts) wenzelm@40281: end; wenzelm@40281: wenzelm@40282: fun ex_styp_of_sort super T styps_and_sorts = wenzelm@40281: let wenzelm@40281: fun adjust T U = if super then (T, U) else (U, T); wenzelm@40282: fun styp_test U Ts = forall wenzelm@40281: (fn T => T = U orelse Graph.is_edge coes_graph (adjust U T)) Ts; wenzelm@55301: fun fitting Ts S U = Sign.of_sort thy (t_of U, S) andalso styp_test U Ts; wenzelm@40281: in wenzelm@40281: forall (fn (Ts, S) => exists 
(fitting Ts S) (T :: styps super T)) styps_and_sorts wenzelm@40281: end; wenzelm@40281: wenzelm@40281: (* computes the tightest possible, correct assignment for 'a::S wenzelm@40281: e.g. in the supremum case (sup = true): wenzelm@40281: ------- 'a::S--- wenzelm@40281: / / \ \ wenzelm@40281: / / \ \ wenzelm@40281: 'b::C1 'c::C2 ... T1 T2 ... wenzelm@40281: wenzelm@40281: sorts - list of sorts [C1, C2, ...] wenzelm@40281: T::Ts - non-empty list of base types [T1, T2, ...] wenzelm@40281: *) wenzelm@40282: fun tightest sup S styps_and_sorts (T :: Ts) = wenzelm@40281: let wenzelm@42388: fun restriction T = Sign.of_sort thy (t_of T, S) wenzelm@40281: andalso ex_styp_of_sort (not sup) T styps_and_sorts; wenzelm@40281: fun candidates T = inter (op =) (filter restriction (T :: styps sup T)); wenzelm@40281: in wenzelm@40281: (case fold candidates Ts (filter restriction (T :: styps sup T)) of traytel@40836: [] => raise BOUND_ERROR ("no " ^ (if sup then "supremum" else "infimum")) wenzelm@40281: | [T] => t_of T wenzelm@40281: | Ts => minmax sup Ts) wenzelm@40281: end; wenzelm@40281: wenzelm@40281: fun build_graph G [] tye_idx = (G, tye_idx) wenzelm@40282: | build_graph G ((T, U) :: cs) tye_idx = wenzelm@40281: if T = U then build_graph G cs tye_idx wenzelm@40281: else wenzelm@40281: let wenzelm@40281: val G' = maybe_new_typnodes [T, U] G; traytel@45059: val (G'', tye_idx') = (Typ_Graph.add_edge_acyclic (T, U) G', tye_idx) wenzelm@40281: handle Typ_Graph.CYCLES cycles => wenzelm@40281: let wenzelm@42383: val (tye, idx) = wenzelm@42383: fold traytel@40836: (fn cycle => fn tye_idx' => (unify_list cycle tye_idx' wenzelm@42383: handle NO_UNIFIER (msg, _) => wenzelm@42383: err_bound ctxt traytel@54584: (gen_err err ("constraint cycle not unifiable\n" ^ msg)) (fst tye_idx) traytel@40836: (find_cycle_packs cycle))) traytel@40836: cycles tye_idx wenzelm@40281: in traytel@40836: collapse (tye, idx) cycles G traytel@40836: end wenzelm@40281: in wenzelm@40281: build_graph G'' cs tye_idx' 
wenzelm@40281: end traytel@40836: and collapse (tye, idx) cycles G = (*nodes non-empty list*) wenzelm@40281: let traytel@40836: (*all cycles collapse to one node, traytel@40836: because all of them share at least the nodes x and y*) traytel@40836: val nodes = (distinct (op =) (flat cycles)); traytel@40836: val T = Type_Infer.deref tye (hd nodes); wenzelm@40281: val P = new_imm_preds G nodes; wenzelm@40281: val S = new_imm_succs G nodes; wenzelm@46665: val G' = fold Typ_Graph.del_node (tl nodes) G; traytel@40836: fun check_and_gen super T' = traytel@40836: let val U = Type_Infer.deref tye T'; traytel@40836: in traytel@40836: if not (is_typeT T) orelse not (is_typeT U) orelse T = U traytel@40836: then if super then (hd nodes, T') else (T', hd nodes) wenzelm@42383: else wenzelm@42383: if super andalso traytel@40836: Graph.is_edge coes_graph (nameT T, nameT U) then (hd nodes, T') wenzelm@42383: else if not super andalso traytel@40836: Graph.is_edge coes_graph (nameT U, nameT T) then (T', hd nodes) wenzelm@55303: else wenzelm@55303: err_bound ctxt (gen_err err "cycle elimination produces inconsistent graph") wenzelm@55303: (fst tye_idx) wenzelm@55303: (maps find_cycle_packs cycles @ find_error_pack super T') traytel@40836: end; wenzelm@40281: in traytel@40836: build_graph G' (map (check_and_gen false) P @ map (check_and_gen true) S) (tye, idx) wenzelm@40281: end; wenzelm@40281: wenzelm@40281: fun assign_bound lower G key (tye_idx as (tye, _)) = wenzelm@40286: if Type_Infer.is_paramT (Type_Infer.deref tye key) then wenzelm@40281: let wenzelm@40286: val TVar (xi, S) = Type_Infer.deref tye key; wenzelm@40281: val get_bound = if lower then get_preds else get_succs; wenzelm@40281: val raw_bound = get_bound G key; wenzelm@40286: val bound = map (Type_Infer.deref tye) raw_bound; wenzelm@40286: val not_params = filter_out Type_Infer.is_paramT bound; wenzelm@40282: fun to_fulfil T = wenzelm@40281: (case sort_of T of wenzelm@40281: NONE => NONE wenzelm@40282: | SOME S => 
wenzelm@40286: SOME wenzelm@40286: (map nameT wenzelm@42405: (filter_out Type_Infer.is_paramT wenzelm@42405: (map (Type_Infer.deref tye) (get_bound G T))), S)); wenzelm@40281: val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound); wenzelm@40281: val assignment = wenzelm@40281: if null bound orelse null not_params then NONE wenzelm@40281: else SOME (tightest lower S styps_and_sorts (map nameT not_params) traytel@54584: handle BOUND_ERROR msg => err_bound ctxt (gen_err err msg) tye wenzelm@55301: (maps (find_error_pack (not lower)) raw_bound)); wenzelm@40281: in wenzelm@40281: (case assignment of wenzelm@40281: NONE => tye_idx wenzelm@40281: | SOME T => wenzelm@40286: if Type_Infer.is_paramT T then tye_idx wenzelm@40281: else if lower then (*upper bound check*) wenzelm@40281: let wenzelm@40286: val other_bound = map (Type_Infer.deref tye) (get_succs G key); wenzelm@40281: val s = nameT T; wenzelm@40281: in wenzelm@40281: if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s) wenzelm@40281: then apfst (Vartab.update (xi, T)) tye_idx wenzelm@55303: else wenzelm@55303: err_bound ctxt wenzelm@55303: (gen_err err wenzelm@55303: (Pretty.string_of (Pretty.block wenzelm@55303: [Pretty.str "assigned base type", Pretty.brk 1, wenzelm@55303: Pretty.quote (Syntax.pretty_typ ctxt T), Pretty.brk 1, wenzelm@55303: Pretty.str "clashes with the upper bound of variable", Pretty.brk 1, wenzelm@55303: Syntax.pretty_typ ctxt (TVar (xi, S))]))) wenzelm@55303: tye wenzelm@55303: (maps (find_error_pack lower) other_bound) wenzelm@40281: end wenzelm@40281: else apfst (Vartab.update (xi, T)) tye_idx) wenzelm@40281: end wenzelm@40281: else tye_idx; wenzelm@40281: wenzelm@40281: val assign_lb = assign_bound true; wenzelm@40281: val assign_ub = assign_bound false; wenzelm@40281: wenzelm@40281: fun assign_alternating ts' ts G tye_idx = wenzelm@40281: if ts' = ts then tye_idx wenzelm@40281: else wenzelm@40281: let wenzelm@40281: val (tye_idx' as (tye, _)) 
= fold (assign_lb G) ts tye_idx wenzelm@40281: |> fold (assign_ub G) ts; wenzelm@40281: in wenzelm@42383: assign_alternating ts traytel@40836: (filter (Type_Infer.is_paramT o Type_Infer.deref tye) ts) G tye_idx' wenzelm@40281: end; wenzelm@40281: wenzelm@40281: (*Unify all weakly connected components of the constraint forest, wenzelm@40282: that contain only params. These are the only WCCs that contain wenzelm@40281: params anyway.*) wenzelm@40281: fun unify_params G (tye_idx as (tye, _)) = wenzelm@40281: let wenzelm@40286: val max_params = wenzelm@40286: filter (Type_Infer.is_paramT o Type_Infer.deref tye) (Typ_Graph.maximals G); wenzelm@40281: val to_unify = map (fn T => T :: get_preds G T) max_params; wenzelm@40281: in wenzelm@42383: fold traytel@40836: (fn Ts => fn tye_idx' => unify_list Ts tye_idx' traytel@54584: handle NO_UNIFIER (msg, _) => err_list ctxt (gen_err err msg) (fst tye_idx) Ts) traytel@40836: to_unify tye_idx wenzelm@40281: end; wenzelm@40281: wenzelm@40281: fun solve_constraints G tye_idx = tye_idx wenzelm@40281: |> assign_alternating [] (Typ_Graph.keys G) G wenzelm@40281: |> unify_params G; wenzelm@40281: in wenzelm@40281: build_graph Typ_Graph.empty (map fst cs') tye_idx' wenzelm@40281: |-> solve_constraints wenzelm@40281: end; wenzelm@40281: wenzelm@40281: wenzelm@40281: wenzelm@40281: (** coercion insertion **) wenzelm@40281: traytel@45060: fun gen_coercion ctxt err tye TU = traytel@45060: let traytel@45060: fun gen (T1, T2) = (case pairself (Type_Infer.deref tye) (T1, T2) of traytel@45060: (T1 as (Type (a, [])), T2 as (Type (b, []))) => traytel@45060: if a = b traytel@51335: then mk_identity T1 traytel@45060: else traytel@45060: (case Symreltab.lookup (coes_of ctxt) (a, b) of wenzelm@55303: NONE => wenzelm@55303: raise COERCION_GEN_ERROR (err ++> wenzelm@55303: (Pretty.string_of o Pretty.block) wenzelm@55303: [Pretty.quote (Syntax.pretty_typ ctxt T1), Pretty.brk 1, wenzelm@55303: Pretty.str "is not a subtype of", Pretty.brk 1, 
wenzelm@55303: Pretty.quote (Syntax.pretty_typ ctxt T2)]) traytel@45060: | SOME (co, _) => co) traytel@45102: | (T1 as Type (a, Ts), T2 as Type (b, Us)) => traytel@45060: if a <> b traytel@45060: then traytel@45060: (case Symreltab.lookup (coes_of ctxt) (a, b) of traytel@45060: (*immediate error - cannot fix complex coercion with the global algorithm*) wenzelm@55303: NONE => wenzelm@55303: eval_error (err ++> wenzelm@55304: ("No coercion known for type constructors: " ^ wenzelm@55304: quote (Proof_Context.markup_type ctxt a) ^ " and " ^ wenzelm@55304: quote (Proof_Context.markup_type ctxt b))) traytel@45060: | SOME (co, ((Ts', Us'), _)) => traytel@45060: let traytel@45060: val co_before = gen (T1, Type (a, Ts')); traytel@45060: val coT = range_type (fastype_of co_before); wenzelm@55303: val insts = wenzelm@55303: inst_collect tye (err ++> "Could not insert complex coercion") wenzelm@55303: (domain_type (fastype_of co)) coT; traytel@45060: val co' = Term.subst_TVars insts co; traytel@45060: val co_after = gen (Type (b, (map (typ_subst_TVars insts) Us')), T2); traytel@45060: in traytel@45060: Abs (Name.uu, T1, Library.foldr (op $) traytel@45060: (filter (not o is_identity) [co_after, co', co_before], Bound 0)) traytel@45060: end) traytel@45060: else traytel@45060: let traytel@51335: fun sub_co (COVARIANT, TU) = (SOME (gen TU), NONE) traytel@51335: | sub_co (CONTRAVARIANT, TU) = (SOME (gen (swap TU)), NONE) traytel@51335: | sub_co (INVARIANT, (T, _)) = (NONE, SOME T) traytel@51335: | sub_co (INVARIANT_TO T, _) = (NONE, NONE); traytel@45060: fun ts_of [] = [] traytel@45060: | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs); traytel@45060: in traytel@45060: (case Symtab.lookup (tmaps_of ctxt) a of traytel@45102: NONE => traytel@45102: if Type.could_unify (T1, T2) traytel@51335: then mk_identity T1 wenzelm@55303: else wenzelm@55303: raise COERCION_GEN_ERROR wenzelm@55304: (err ++> wenzelm@55304: ("No map function for " ^ quote (Proof_Context.markup_type 
ctxt a) wenzelm@55304: ^ " known")) traytel@51335: | SOME (tmap, variances) => traytel@45060: let traytel@51335: val (used_coes, invarTs) = traytel@51335: map_split sub_co (variances ~~ (Ts ~~ Us)) traytel@51335: |>> map_filter I traytel@51335: ||> map_filter I; traytel@51335: val Tinsts = ts_of (map fastype_of used_coes) @ invarTs; traytel@45060: in traytel@45060: if null (filter (not o is_identity) used_coes) traytel@51335: then mk_identity (Type (a, Ts)) traytel@51335: else Term.list_comb (instantiate tmap Tinsts, used_coes) traytel@45060: end) traytel@45060: end traytel@45060: | (T, U) => traytel@45060: if Type.could_unify (T, U) traytel@51335: then mk_identity T wenzelm@55303: else raise COERCION_GEN_ERROR (err ++> wenzelm@55303: (Pretty.string_of o Pretty.block) wenzelm@55303: [Pretty.str "Cannot generate coercion from", Pretty.brk 1, wenzelm@55303: Pretty.quote (Syntax.pretty_typ ctxt T), Pretty.brk 1, wenzelm@55303: Pretty.str "to", Pretty.brk 1, wenzelm@55303: Pretty.quote (Syntax.pretty_typ ctxt U)])); traytel@45060: in traytel@45060: gen TU traytel@45060: end; traytel@40836: traytel@45060: fun function_of ctxt err tye T = traytel@45060: (case Type_Infer.deref tye T of traytel@45060: Type (C, Ts) => traytel@45060: (case Symreltab.lookup (coes_of ctxt) (C, "fun") of wenzelm@55304: NONE => wenzelm@55304: eval_error (err ++> ("No complex coercion from " ^ wenzelm@55304: quote (Proof_Context.markup_type ctxt C) ^ " to " ^ wenzelm@55304: quote (Proof_Context.markup_type ctxt "fun"))) traytel@45060: | SOME (co, ((Ts', _), _)) => traytel@45060: let traytel@45060: val co_before = gen_coercion ctxt err tye (Type (C, Ts), Type (C, Ts')); traytel@45060: val coT = range_type (fastype_of co_before); wenzelm@55303: val insts = wenzelm@55303: inst_collect tye (err ++> "Could not insert complex coercion") wenzelm@55303: (domain_type (fastype_of co)) coT; traytel@45060: val co' = Term.subst_TVars insts co; traytel@45060: in traytel@45060: Abs (Name.uu, Type (C, Ts), 
Library.foldr (op $) traytel@45060: (filter (not o is_identity) [co', co_before], Bound 0)) traytel@45060: end) wenzelm@55303: | T' => wenzelm@55303: eval_error (err ++> wenzelm@55303: (Pretty.string_of o Pretty.block) wenzelm@55303: [Pretty.str "No complex coercion from", Pretty.brk 1, wenzelm@55304: Pretty.quote (Syntax.pretty_typ ctxt T'), Pretty.brk 1, wenzelm@55304: Pretty.str "to", Pretty.brk 1, Proof_Context.pretty_type ctxt "fun"])); traytel@45060: traytel@45060: fun insert_coercions ctxt (tye, idx) ts = wenzelm@40281: let traytel@45060: fun insert _ (Const (c, T)) = (Const (c, T), T) traytel@45060: | insert _ (Free (x, T)) = (Free (x, T), T) traytel@45060: | insert _ (Var (xi, T)) = (Var (xi, T), T) wenzelm@40281: | insert bs (Bound i) = wenzelm@43278: let val T = nth bs i handle General.Subscript => err_loose i; wenzelm@40281: in (Bound i, T) end wenzelm@40281: | insert bs (Abs (x, T, t)) = traytel@45060: let val (t', T') = insert (T :: bs) t; traytel@45060: in (Abs (x, T, t'), T --> T') end wenzelm@40281: | insert bs (t $ u) = wenzelm@40281: let traytel@45060: val (t', Type ("fun", [U, T])) = apsnd (Type_Infer.deref tye) (insert bs t); wenzelm@40281: val (u', U') = insert bs u; wenzelm@40281: in traytel@40836: if can (fn TU => strong_unify ctxt TU (tye, 0)) (U, U') traytel@40836: then (t' $ u', T) wenzelm@55303: else (t' $ (gen_coercion ctxt (K ("", Buffer.empty)) tye (U', U) $ u'), T) wenzelm@40281: end wenzelm@40281: in wenzelm@40281: map (fst o insert []) ts wenzelm@40281: end; wenzelm@40281: wenzelm@40281: wenzelm@40281: wenzelm@40281: (** assembling the pipeline **) wenzelm@40281: wenzelm@42398: fun coercion_infer_types ctxt raw_ts = wenzelm@40281: let wenzelm@42405: val (idx, ts) = Type_Infer_Context.prepare ctxt raw_ts; wenzelm@40281: traytel@51319: fun inf _ _ (t as (Const (_, T))) tye_idx = (t, T, tye_idx) traytel@51319: | inf _ _ (t as (Free (_, T))) tye_idx = (t, T, tye_idx) traytel@51319: | inf _ _ (t as (Var (_, T))) tye_idx = (t, T, 
tye_idx) traytel@51319: | inf _ bs (t as (Bound i)) tye_idx = wenzelm@43278: (t, snd (nth bs i handle General.Subscript => err_loose i), tye_idx) traytel@51319: | inf coerce bs (Abs (x, T, t)) tye_idx = traytel@51319: let val (t', U, tye_idx') = inf coerce ((x, T) :: bs) t tye_idx traytel@40836: in (Abs (x, T, t'), T --> U, tye_idx') end traytel@51319: | inf coerce bs (t $ u) tye_idx = traytel@40836: let traytel@51319: val (t', T, tye_idx') = inf coerce bs t tye_idx; traytel@51319: val coerce' = update_coerce_arg ctxt coerce t; traytel@51319: val (u', U, (tye, idx)) = inf coerce' bs u tye_idx'; traytel@40836: val V = Type_Infer.mk_param idx []; traytel@40836: val (tu, tye_idx'') = (t' $ u', strong_unify ctxt (U --> V, T) (tye, idx + 1)) wenzelm@42383: handle NO_UNIFIER (msg, tye') => traytel@45060: let traytel@45060: val err = err_appl_msg ctxt msg tye' bs t T u U; traytel@45060: val W = Type_Infer.mk_param (idx + 1) []; traytel@45060: val (t'', (tye', idx')) = traytel@45060: (t', strong_unify ctxt (W --> V, T) (tye, idx + 2)) traytel@45060: handle NO_UNIFIER _ => traytel@45060: let wenzelm@55303: val err' = err ++> "Local coercion insertion on the operator failed:\n"; traytel@45060: val co = function_of ctxt err' tye T; traytel@51319: val (t'', T'', tye_idx'') = inf coerce bs (co $ t') (tye, idx + 2); traytel@45060: in traytel@45060: (t'', strong_unify ctxt (W --> V, T'') tye_idx'' wenzelm@55303: handle NO_UNIFIER (msg, _) => eval_error (err' ++> msg)) traytel@45060: end; wenzelm@55303: val err' = err ++> traytel@54584: ((if t' aconv t'' then "" wenzelm@55303: else "Successfully coerced the operator to a function of type:\n" ^ wenzelm@55303: Syntax.string_of_typ ctxt wenzelm@55303: (the_single (snd (prep_output ctxt tye' bs [] [W --> V]))) ^ "\n") ^ wenzelm@55303: (if coerce' then "Local coercion insertion on the operand failed:\n" wenzelm@55303: else "Local coercion insertion on the operand disallowed:\n")); traytel@45060: val (u'', U', tye_idx') = wenzelm@52432: 
if coerce' then traytel@51319: let val co = gen_coercion ctxt err' tye' (U, W); traytel@51319: in inf coerce' bs (if is_identity co then u else co $ u) (tye', idx') end traytel@51319: else (u, U, (tye', idx')); traytel@45060: in traytel@45060: (t'' $ u'', strong_unify ctxt (U', W) tye_idx' wenzelm@55303: handle NO_UNIFIER (msg, _) => raise COERCION_GEN_ERROR (err' ++> msg)) traytel@45060: end; traytel@40836: in (tu, V, tye_idx'') end; wenzelm@40281: wenzelm@42383: fun infer_single t tye_idx = traytel@51319: let val (t, _, tye_idx') = inf true [] t tye_idx traytel@40938: in (t, tye_idx') end; wenzelm@42383: traytel@40938: val (ts', (tye, _)) = (fold_map infer_single ts (Vartab.empty, idx) traytel@45060: handle COERCION_GEN_ERROR err => traytel@40836: let traytel@40836: fun gen_single t (tye_idx, constraints) = traytel@45060: let val (_, tye_idx', constraints') = wenzelm@55303: generate_constraints ctxt (err ++> "\n") t tye_idx traytel@40836: in (tye_idx', constraints' @ constraints) end; wenzelm@42383: traytel@40836: val (tye_idx, constraints) = fold gen_single ts ((Vartab.empty, idx), []); wenzelm@55303: val (tye, idx) = process_constraints ctxt (err ++> "\n") constraints tye_idx; wenzelm@42383: in traytel@45060: (insert_coercions ctxt (tye, idx) ts, (tye, idx)) traytel@40836: end); wenzelm@40281: wenzelm@40281: val (_, ts'') = Type_Infer.finish ctxt tye ([], ts'); wenzelm@40281: in ts'' end; wenzelm@40281: wenzelm@40281: wenzelm@40281: wenzelm@40281: (** installation **) wenzelm@40281: wenzelm@40283: (* term check *) wenzelm@40283: wenzelm@42616: val coercion_enabled = Attrib.setup_config_bool @{binding coercion_enabled} (K false); wenzelm@40939: wenzelm@40283: val add_term_check = wenzelm@45429: Syntax_Phases.term_check ~100 "coercions" wenzelm@42402: (fn ctxt => Config.get ctxt coercion_enabled ? 
coercion_infer_types ctxt); wenzelm@40281: wenzelm@40281: wenzelm@40283: (* declarations *) wenzelm@40281: wenzelm@40284: fun add_type_map raw_t context = wenzelm@40281: let wenzelm@40281: val ctxt = Context.proof_of context; wenzelm@40284: val t = singleton (Variable.polymorphic ctxt) raw_t; wenzelm@40281: traytel@45059: fun err_str t = "\n\nThe provided function has the type:\n" ^ wenzelm@42383: Syntax.string_of_typ ctxt (fastype_of t) ^ traytel@45059: "\n\nThe general type signature of a map function is:" ^ traytel@41353: "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]" ^ traytel@45059: "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)."; wenzelm@42383: traytel@41353: val ((fis, T1), T2) = apfst split_last (strip_type (fastype_of t)) wenzelm@47060: handle List.Empty => error ("Not a proper map function:" ^ err_str t); wenzelm@42383: wenzelm@40281: fun gen_arg_var ([], []) = [] traytel@51335: | gen_arg_var (Ts, (U, U') :: Us) = traytel@41353: if U = U' then traytel@51335: if null (Term.add_tvarsT U []) then INVARIANT_TO U :: gen_arg_var (Ts, Us) traytel@51335: else if Term.is_TVar U then INVARIANT :: gen_arg_var (Ts, Us) traytel@51335: else error ("Invariant xi and yi should be variables or variable-free:" ^ err_str t) traytel@51335: else traytel@51335: (case Ts of traytel@51335: [] => error ("Different numbers of functions and variant arguments\n" ^ err_str t) traytel@51335: | (T, T') :: Ts => traytel@51335: if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us) traytel@51335: else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us) traytel@51335: else error ("Functions do not apply to arguments correctly:" ^ err_str t)); wenzelm@40281: traytel@41353: (*retry flag needed to adjust the type lists, when given a map over type constructor fun*) traytel@41353: fun check_map_fun fis (Type (C1, Ts)) (Type (C2, Us)) retry = traytel@41353: if C1 = C2 andalso not (null fis) andalso forall is_funtype fis 
traytel@41353: then ((map dest_funT fis, Ts ~~ Us), C1) traytel@41353: else error ("Not a proper map function:" ^ err_str t) traytel@41353: | check_map_fun fis T1 T2 true = traytel@41353: let val (fis', T') = split_last fis traytel@41353: in check_map_fun fis' T' (T1 --> T2) false end traytel@41353: | check_map_fun _ _ _ _ = error ("Not a proper map function:" ^ err_str t); wenzelm@40281: traytel@41353: val res = check_map_fun fis T1 T2 true; wenzelm@40281: val res_av = gen_arg_var (fst res); wenzelm@40281: in wenzelm@40281: map_tmaps (Symtab.update (snd res, (t, res_av))) context wenzelm@40281: end; wenzelm@40281: traytel@45060: fun transitive_coercion ctxt tab G (a, b) = traytel@45059: let traytel@45060: fun safe_app t (Abs (x, T', u)) = traytel@45060: let traytel@45060: val t' = map_types Type_Infer.paramify_vars t; traytel@45060: in traytel@45060: singleton (coercion_infer_types ctxt) (Abs(x, T', (t' $ u))) traytel@45060: end; traytel@45059: val path = hd (Graph.irreducible_paths G (a, b)); traytel@45059: val path' = fst (split_last path) ~~ tl path; traytel@45059: val coercions = map (fst o the o Symreltab.lookup tab) path'; traytel@45060: val trans_co = singleton (Variable.polymorphic ctxt) traytel@51335: (fold safe_app coercions (mk_identity dummyT)); traytel@45060: val (Ts, Us) = pairself (snd o Term.dest_Type) (Term.dest_funT (type_of trans_co)) traytel@45060: in traytel@45060: (trans_co, ((Ts, Us), coercions)) traytel@45059: end; traytel@45059: wenzelm@40284: fun add_coercion raw_t context = wenzelm@40281: let wenzelm@40281: val ctxt = Context.proof_of context; wenzelm@40284: val t = singleton (Variable.polymorphic ctxt) raw_t; wenzelm@40281: wenzelm@55303: fun err_coercion () = wenzelm@55303: error ("Bad type for a coercion:\n" ^ traytel@45059: Syntax.string_of_term ctxt t ^ " :: " ^ wenzelm@40281: Syntax.string_of_typ ctxt (fastype_of t)); wenzelm@40281: wenzelm@40840: val (T1, T2) = Term.dest_funT (fastype_of t) wenzelm@40840: handle TYPE _ => 
err_coercion (); wenzelm@40281: traytel@45060: val (a, Ts) = Term.dest_Type T1 traytel@45060: handle TYPE _ => err_coercion (); wenzelm@40281: traytel@45060: val (b, Us) = Term.dest_Type T2 traytel@45060: handle TYPE _ => err_coercion (); wenzelm@40281: traytel@45060: fun coercion_data_update (tab, G, _) = wenzelm@40281: let traytel@45060: val G' = maybe_new_nodes [(a, length Ts), (b, length Us)] G wenzelm@40281: val G'' = Graph.add_edge_trans_acyclic (a, b) G' traytel@45059: handle Graph.CYCLES _ => error ( traytel@45060: Syntax.string_of_typ ctxt T2 ^ " is already a subtype of " ^ traytel@45060: Syntax.string_of_typ ctxt T1 ^ "!\n\nCannot add coercion of type: " ^ traytel@45059: Syntax.string_of_typ ctxt (T1 --> T2)); wenzelm@40281: val new_edges = wenzelm@49560: flat (Graph.dest G'' |> map (fn ((x, _), ys) => ys |> map_filter (fn y => wenzelm@40281: if Graph.is_edge G' (x, y) then NONE else SOME (x, y)))); wenzelm@40281: val G_and_new = Graph.add_edge (a, b) G'; wenzelm@40281: wenzelm@40281: val tab' = fold traytel@45059: (fn pair => fn tab => traytel@45060: Symreltab.update (pair, transitive_coercion ctxt tab G_and_new pair) tab) wenzelm@40281: (filter (fn pair => pair <> (a, b)) new_edges) traytel@45060: (Symreltab.update ((a, b), (t, ((Ts, Us), []))) tab); wenzelm@40281: in traytel@45060: (tab', G'', restrict_graph G'') wenzelm@40281: end; wenzelm@40281: in traytel@45060: map_coes_and_graphs coercion_data_update context wenzelm@40281: end; wenzelm@40281: traytel@45059: fun delete_coercion raw_t context = traytel@45059: let traytel@45059: val ctxt = Context.proof_of context; traytel@45059: val t = singleton (Variable.polymorphic ctxt) raw_t; traytel@45059: wenzelm@55303: fun err_coercion the = wenzelm@55303: error ("Not" ^ traytel@45059: (if the then " the defined " else " a ") ^ "coercion:\n" ^ traytel@45059: Syntax.string_of_term ctxt t ^ " :: " ^ traytel@45059: Syntax.string_of_typ ctxt (fastype_of t)); traytel@45059: traytel@45059: val (T1, T2) = 
Term.dest_funT (fastype_of t) traytel@45059: handle TYPE _ => err_coercion false; traytel@45059: traytel@54584: val (a, _) = dest_Type T1 traytel@45060: handle TYPE _ => err_coercion false; traytel@45059: traytel@54584: val (b, _) = dest_Type T2 traytel@45060: handle TYPE _ => err_coercion false; traytel@45059: traytel@45059: fun delete_and_insert tab G = traytel@45059: let traytel@45059: val pairs = traytel@45060: Symreltab.fold (fn ((a, b), (_, (_, ts))) => fn pairs => traytel@45059: if member (op aconv) ts t then (a, b) :: pairs else pairs) tab [(a, b)]; traytel@45059: fun delete pair (G, tab) = (Graph.del_edge pair G, Symreltab.delete_safe pair tab); traytel@45059: val (G', tab') = fold delete pairs (G, tab); wenzelm@49564: fun reinsert pair (G, xs) = wenzelm@49564: (case Graph.irreducible_paths G pair of wenzelm@49564: [] => (G, xs) wenzelm@49564: | _ => (Graph.add_edge pair G, (pair, transitive_coercion ctxt tab' G' pair) :: xs)); traytel@45059: val (G'', ins) = fold reinsert pairs (G', []); traytel@45059: in traytel@45060: (fold Symreltab.update ins tab', G'', restrict_graph G'') traytel@45059: end traytel@45059: wenzelm@55303: fun show_term t = wenzelm@55303: Pretty.block [Syntax.pretty_term ctxt t, wenzelm@55303: Pretty.str " :: ", Syntax.pretty_typ ctxt (fastype_of t)]; traytel@45059: traytel@45060: fun coercion_data_update (tab, G, _) = wenzelm@55303: (case Symreltab.lookup tab (a, b) of wenzelm@55303: NONE => err_coercion false wenzelm@55303: | SOME (t', (_, [])) => wenzelm@55303: if t aconv t' wenzelm@55303: then delete_and_insert tab G wenzelm@55303: else err_coercion true wenzelm@55303: | SOME (t', (_, ts)) => wenzelm@55303: if t aconv t' then wenzelm@55303: error ("Cannot delete the automatically derived coercion:\n" ^ traytel@45059: Syntax.string_of_term ctxt t ^ " :: " ^ wenzelm@55303: Syntax.string_of_typ ctxt (fastype_of t) ^ "\n\n" ^ wenzelm@55303: Pretty.string_of wenzelm@55303: (Pretty.big_list "Deleting one of the coercions:" (map show_term 
ts)) ^ traytel@45059: "\nwill also remove the transitive coercion.") wenzelm@55303: else err_coercion true); traytel@45059: in traytel@45060: map_coes_and_graphs coercion_data_update context traytel@45059: end; traytel@45059: traytel@45059: fun print_coercions ctxt = traytel@45059: let traytel@45060: fun separate _ [] = ([], []) wenzelm@52432: | separate P (x :: xs) = (if P x then apfst else apsnd) (cons x) (separate P xs); traytel@45060: val (simple, complex) = traytel@45060: separate (fn (_, (_, ((Ts, Us), _))) => null Ts andalso null Us) traytel@45060: (Symreltab.dest (coes_of ctxt)); wenzelm@52432: fun show_coercion ((a, b), (t, ((Ts, Us), _))) = wenzelm@52432: Pretty.item [Pretty.block wenzelm@52432: [Syntax.pretty_typ ctxt (Type (a, Ts)), Pretty.brk 1, wenzelm@52432: Pretty.str "<:", Pretty.brk 1, wenzelm@52432: Syntax.pretty_typ ctxt (Type (b, Us)), Pretty.brk 3, wenzelm@52432: Pretty.block wenzelm@52432: [Pretty.keyword "using", Pretty.brk 1, wenzelm@52432: Pretty.quote (Syntax.pretty_term ctxt t)]]]; wenzelm@52432: wenzelm@52432: val type_space = Proof_Context.type_space ctxt; wenzelm@52432: val tmaps = wenzelm@52432: sort (Name_Space.extern_ord ctxt type_space o pairself #1) wenzelm@52432: (Symtab.dest (tmaps_of ctxt)); wenzelm@53539: fun show_map (c, (t, _)) = wenzelm@52432: Pretty.block wenzelm@53539: [Name_Space.pretty ctxt type_space c, Pretty.str ":", wenzelm@52432: Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt t)]; traytel@45059: in wenzelm@52432: [Pretty.big_list "coercions between base types:" (map show_coercion simple), wenzelm@52432: Pretty.big_list "other coercions:" (map show_coercion complex), wenzelm@52432: Pretty.big_list "coercion maps:" (map show_map tmaps)] wenzelm@52432: end |> Pretty.chunks |> Pretty.writeln; traytel@45059: traytel@45059: traytel@51319: (* theory setup *) wenzelm@40283: traytel@51319: val parse_coerce_args = traytel@51327: Args.$$$ "+" >> K PERMIT || Args.$$$ "-" >> K FORBID || Args.$$$ "0" >> K LEAVE 
(* Theory setup: installs the coercion term-check phase (add_term_check) and
   the attributes "coercion", "coercion_delete", "coercion_map" and
   "coercion_args". *)
val setup =
  let
    (*attribute parser: read one term and turn the given context update,
      applied to that term, into a declaration attribute*)
    fun term_declaration update =
      Args.term >> (fn t => Thm.declaration_attribute (K (update t)));

    (*attribute parser for "coercion_args": a constant followed by one or more
      argument markers ("+", "-" or "0", see parse_coerce_args); the pair is
      stored in the coerce_args table*)
    val coerce_args_declaration =
      Args.const false -- Scan.lift (Scan.repeat1 parse_coerce_args) >>
        (fn spec => Thm.declaration_attribute (K (map_coerce_args (Symtab.update spec))));
  in
    Context.theory_map add_term_check #>
    Attrib.setup @{binding coercion} (term_declaration add_coercion)
      "declaration of new coercions" #>
    Attrib.setup @{binding coercion_delete} (term_declaration delete_coercion)
      "deletion of coercions" #>
    Attrib.setup @{binding coercion_map} (term_declaration add_type_map)
      "declaration of new map functions" #>
    Attrib.setup @{binding coercion_args} coerce_args_declaration
      "declaration of new constants with coercion-invariant arguments"
  end;


(* outer syntax commands *)

(*"print_coercions": print the registered coercions and coercion map
  functions of the context of the current toplevel state*)
val _ =
  Outer_Syntax.improper_command @{command_spec "print_coercions"}
    "print information about coercions"
    (Scan.succeed (Toplevel.keep (fn state => print_coercions (Toplevel.context_of state))));

end;