(*  Title:      Tools/subtyping.ML
    Author:     Dmitriy Traytel, TU Muenchen

Coercive subtyping via subtype constraints.

Pipeline (see infer_types):
  1. generate_constraints: traverse the terms, unify every operator type with
     a fresh function type and record a constraint "actual argument type <:
     expected argument type" for every application;
  2. process_constraints: simplify the constraints, build a constraint graph
     and assign the remaining type parameters;
  3. insert_coercions: dereference all types and wrap arguments whose type
     differs from the expected one into a coercion term.
*)

signature SUBTYPING =
sig
  datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT
  val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
    term list -> term list
end;

structure Subtyping =
struct

(** coercions data **)

datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT

datatype data = Data of
  {coes: term Symreltab.table,  (*coercions table*)
   coes_graph: unit Graph.T,  (*coercions graph*)
   tmaps: (term * variance list) Symtab.table};  (*map functions*)

(*build the data record from a triple*)
fun make_data (coes, coes_graph, tmaps) =
  Data {coes = coes, coes_graph = coes_graph, tmaps = tmaps};

structure Data = Generic_Data
(
  type T = data;
  val empty = make_data (Symreltab.empty, Graph.empty, Symtab.empty);
  val extend = I;
  (*component-wise merge; coercion terms are compared up to aconv*)
  fun merge
    (Data {coes = coes1, coes_graph = coes_graph1, tmaps = tmaps1},
      Data {coes = coes2, coes_graph = coes_graph2, tmaps = tmaps2}) =
    make_data (Symreltab.merge (op aconv) (coes1, coes2),
      Graph.merge (op =) (coes_graph1, coes_graph2),
      Symtab.merge (eq_pair (op aconv) (op =)) (tmaps1, tmaps2));
);

(*apply f to the triple (coes, coes_graph, tmaps) stored in the context*)
fun map_data f =
  Data.map (fn Data {coes, coes_graph, tmaps} =>
    make_data (f (coes, coes_graph, tmaps)));

(*update only the coercions table*)
fun map_coes f =
  map_data (fn (coes, coes_graph, tmaps) =>
    (f coes, coes_graph, tmaps));

(*update only the coercions graph*)
fun map_coes_graph f =
  map_data (fn (coes, coes_graph, tmaps) =>
    (coes, f coes_graph, tmaps));

(*update coercions table and graph together*)
fun map_coes_and_graph f =
  map_data (fn (coes, coes_graph, tmaps) =>
    let val (coes', coes_graph') = f (coes, coes_graph);
    in (coes', coes_graph', tmaps) end);

(*update only the table of map functions*)
fun map_tmaps f =
  map_data (fn (coes, coes_graph, tmaps) =>
    (coes, coes_graph, f tmaps));

(*access the record stored in a generic context*)
fun rep_data context = Data.get context |> (fn Data args => args);

val coes_of = #coes o rep_data;
val coes_graph_of = #coes_graph o rep_data;
val tmaps_of = #tmaps o rep_data;



(** utils **)

val is_param = Type_Infer.is_param
val is_paramT = Type_Infer.is_paramT
val deref = Type_Infer.deref
fun mk_param i S = TVar (("?'a", i), S); (* TODO dup? see src/Pure/type_infer.ML *)

(*name of a nullary type constructor; deliberately partial: raises Match otherwise*)
fun nameT (Type (s, [])) = s;
(*nullary type with the given name*)
fun t_of s = Type (s, []);
(*sort of a type variable, NONE for type constructor applications*)
fun sort_of (TFree (_, S)) = SOME S
  | sort_of (TVar (_, S)) = SOME S
  | sort_of _ = NONE;

val is_typeT = fn (Type _) => true | _ => false;
(*type constructor application with at least one argument*)
val is_compT = fn (Type (_, _ :: _)) => true | _ => false;
val is_freeT = fn (TFree _) => true | _ => false;
(*schematic type variable that is not an inference parameter*)
val is_fixedvarT = fn (TVar (xi, _)) => not (is_param xi) | _ => false;


(* unification *)  (* TODO dup? needed for weak unification *)

exception NO_UNIFIER of string * typ Vartab.table;

(*unify weak ctxt (T1, T2) (tye, idx): extend the type environment tye such
  that T1 and T2 become equal; idx is the next unused parameter index.
  With weak = true sort constraints are ignored and any two nullary type
  constructors are accepted as "equal".
  Raises NO_UNIFIER (message, environment at the point of failure).*)
fun unify weak ctxt =
  let
    val thy = ProofContext.theory_of ctxt;
    val pp = Syntax.pp ctxt;
    val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);


    (* adjust sorts of parameters *)

    fun not_of_sort x S' S =
      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
        Syntax.string_of_sort ctxt S;

    (*make sure that the given type is of sort S; a parameter of a too
      general sort is redirected to a fresh parameter of the intersection sort*)
    fun meet (_, []) tye_idx = tye_idx
      | meet (Type (a, Ts), S) (tye_idx as (tye, _)) =
          meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
      | meet (TFree (x, S'), S) (tye_idx as (tye, _)) =
          if Sign.subsort thy (S', S) then tye_idx
          else raise NO_UNIFIER (not_of_sort x S' S, tye)
      | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
          if Sign.subsort thy (S', S) then tye_idx
          else if Type_Infer.is_param xi then
            (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
          else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
    and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
          meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
      | meets _ tye_idx = tye_idx;

    (*weak unification skips all sort adjustments*)
    val weak_meet = if weak then fn _ => I else meet


    (* occurs check and assignment *)

    (*follows bindings of tye while searching for xi*)
    fun occurs_check tye xi (TVar (xi', _)) =
          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
          else
            (case Vartab.lookup tye xi' of
              NONE => ()
            | SOME T => occurs_check tye xi T)
      | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
      | occurs_check _ _ _ = ();

    (*bind parameter xi (of sort S) to T; binding a variable to itself is a no-op*)
    fun assign xi (T as TVar (xi', _)) S env =
          if xi = xi' then env
          else env |> weak_meet (T, S) |>> Vartab.update_new (xi, T)
      | assign xi T S (env as (tye, _)) =
          (occurs_check tye xi T; env |> weak_meet (T, S) |>> Vartab.update_new (xi, T));


    (* unification *)

    (*type constructor with dummy arguments, for error messages*)
    fun show_tycon (a, Ts) =
      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));

    fun unif (T1, T2) (env as (tye, _)) =
      (case pairself (`is_paramT o deref tye) (T1, T2) of
        ((true, TVar (xi, S)), (_, T)) => assign xi T S env
      | ((_, T), (true, TVar (xi, S))) => assign xi T S env
      | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
          (*weak unification does not distinguish base types*)
          if weak andalso null Ts andalso null Us then env
          else if a <> b then
            raise NO_UNIFIER
              ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
          else fold unif (Ts ~~ Us) env
      | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye));

  in unif end;

val weak_unify = unify true;
val strong_unify = unify false;


(* Typ_Graph shortcuts *)

val add_edge = Typ_Graph.add_edge_acyclic;
fun get_preds G T = Typ_Graph.all_preds G [T];
fun get_succs G T = Typ_Graph.all_succs G [T];
(*add node T unless Typ_Graph.new_node fails (e.g. because it is already present)*)
fun maybe_new_typnode T G = perhaps (try (Typ_Graph.new_node (T, ()))) G;
fun maybe_new_typnodes Ts G = fold maybe_new_typnode Ts G;
(*immediate predecessors of the nodes Ts that are not themselves in Ts*)
fun new_imm_preds G Ts =
  subtract (op =) Ts (distinct (op =) (maps (Typ_Graph.imm_preds G) Ts));
(*immediate successors of the nodes Ts that are not themselves in Ts*)
fun new_imm_succs G Ts =
  subtract (op =) Ts (distinct (op =) (maps (Typ_Graph.imm_succs G) Ts));


(* Graph shortcuts *)

fun maybe_new_node s G = perhaps (try (Graph.new_node (s, ()))) G
fun maybe_new_nodes ss G = fold maybe_new_node ss G



(** error messages **)

(*finish terms ts and types Ts wrt. environment tye and replace the loose
  bounds of ts by marked free variables named after the binders bs*)
fun prep_output ctxt tye bs ts Ts =
  let
    val (Ts_bTs', ts') = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts);
    val (Ts', Ts'') = chop (length Ts) Ts_bTs';
    fun prep t =
      let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
      in Term.subst_bounds (map Syntax.mark_boundT xs, t) end;
  in (map prep ts', Ts') end;

fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);

(*common prefix of all error messages of this module*)
fun inf_failed msg =
  "Subtype inference failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";

(*error for application t $ u, where t has type T and u has type U*)
fun err_appl ctxt msg tye bs t T u U =
  let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U]
  in error (inf_failed msg ^ Type.appl_error (Syntax.pp ctxt) t' T' u' U' ^ "\n") end;

(*error for the application stored in an error pack
  (binders, t $ u, expected argument type, result type, actual argument type)*)
fun err_subtype ctxt msg tye (bs, t $ u, U, V, U') =
  err_appl ctxt msg tye bs t (U --> V) u U';

(*error for a list of types that could not be unified*)
fun err_list ctxt msg tye Ts =
  let
    val (_, Ts') = prep_output ctxt tye [] [] Ts;
    val text = cat_lines ([inf_failed msg,
      "Cannot unify a list of types that should be the same,",
      "according to subtype dependencies:",
      (Pretty.string_of (Pretty.list "[" "]" (map (Pretty.typ (Syntax.pp ctxt)) Ts')))]);
  in
    error text
  end;

(*error listing the subtype constraints of the given error packs*)
fun err_bound ctxt msg tye packs =
  let
    val pp = Syntax.pp ctxt;
    val (ts, Ts) = fold
      (fn (bs, t $ u, U, _, U') => fn (ts, Ts) =>
        (*the constraint generated for t $ u is U' <: U, i.e. actual argument
          type <: expected argument type; print it in this direction*)
        let val (t', T') = prep_output ctxt tye bs [t, u] [U', U]
        in (t' :: ts, T' :: Ts) end)
      packs ([], []);
    val text = cat_lines ([inf_failed msg, "Cannot fulfil subtype constraints:"] @
      (map2 (fn [t, u] => fn [T, U] => Pretty.string_of (
        Pretty.block [
          Pretty.typ pp T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2, Pretty.typ pp U,
          Pretty.brk 3, Pretty.str "from function application", Pretty.brk 2,
          Pretty.block [Pretty.term pp t, Pretty.brk 1, Pretty.term pp u]]))
        ts Ts))
  in
    error text
  end;



(** constraint generation **)

(*generate_constraints ctxt t (tye, idx) yields (type of t, (tye', idx'), cs)
  where cs is a list of ((actual argument type, expected argument type),
  error pack), one entry per application in t*)
fun generate_constraints ctxt =
  let
    fun gen cs _ (Const (_, T)) tye_idx = (T, tye_idx, cs)
      | gen cs _ (Free (_, T)) tye_idx = (T, tye_idx, cs)
      | gen cs _ (Var (_, T)) tye_idx = (T, tye_idx, cs)
      | gen cs bs (Bound i) tye_idx =
          (snd (nth bs i handle Subscript => err_loose i), tye_idx, cs)
      | gen cs bs (Abs (x, T, t)) tye_idx =
          let val (U, tye_idx', cs') = gen cs ((x, T) :: bs) t tye_idx
          in (T --> U, tye_idx', cs') end
      | gen cs bs (t $ u) tye_idx =
          let
            val (T, tye_idx', cs') = gen cs bs t tye_idx;
            val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
            (*fresh parameters for argument and result type of the operator*)
            val U = mk_param idx [];
            val V = mk_param (idx + 1) [];
            val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
              handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs t T u U;
            val error_pack = (bs, t $ u, U, V, U');
          in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
  in
    gen [] []
  end;



(** constraint resolution **)

exception BOUND_ERROR of string;

(*process_constraints ctxt cs (tye, idx): solve the subtype constraints cs
  (as produced by generate_constraints) and return the extended type
  environment together with the next unused parameter index*)
fun process_constraints ctxt cs tye_idx =
  let
    val coes_graph = coes_graph_of (Context.Proof ctxt);
    val tmaps = tmaps_of (Context.Proof ctxt);
    val tsig = Sign.tsig_of (ProofContext.theory_of ctxt);
    val pp = Syntax.pp ctxt;
    val arity_sorts = Type.arity_sorts pp tsig;
    val subsort = Type.subsort tsig;

    (*split constraints into those where f holds for at least one side (first
      component) and those where f holds for neither side (second component)*)
    fun split_cs _ [] = ([], [])
      | split_cs f (c :: cs) =
          (case pairself f (fst c) of
            (false, false) => apsnd (cons c) (split_cs f cs)
          | _ => apfst (cons c) (split_cs f cs));


    (* check whether constraint simplification will terminate using weak unification *)

    val _ = fold (fn (TU, error_pack) => fn tye_idx =>
      (weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, tye) =>
        err_subtype ctxt ("Weak unification of subtype constraints fails:\n" ^ msg)
          tye error_pack)) cs tye_idx;


    (* simplify constraints *)

    (*yields (done, tye_idx) where done holds the constraints that were
      neither decomposed nor eliminated*)
    fun simplify_constraints cs tye_idx =
      let
        (*constraint between two applications of the same constructor a:
          replace it by constraints on the arguments according to the
          variances of the map function registered for a*)
        fun contract a Ts Us error_pack done todo tye idx =
          let
            val arg_var =
              (case Symtab.lookup tmaps a of
                (*everything is invariant for unknown constructors*)
                NONE => replicate (length Ts) INVARIANT
              | SOME av => snd av);
            fun new_constraints (variance, constraint) (cs, tye_idx) =
              (case variance of
                COVARIANT => (constraint :: cs, tye_idx)
              | CONTRAVARIANT => (swap constraint :: cs, tye_idx)
              | INVARIANT => (cs, strong_unify ctxt constraint tye_idx
                  handle NO_UNIFIER (msg, tye) => err_subtype ctxt msg tye error_pack));
            val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
              (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
            val test_update = is_compT orf is_freeT orf is_fixedvarT;
            (*if nothing but unifications happened, constraints from done
              may have changed and have to be processed again*)
            val (ch, done') =
              if not (null new) then ([], done)
              else split_cs (test_update o deref tye') done;
            val todo' = ch @ todo;
          in
            simplify done' (new @ todo') (tye', idx')
          end
        (*xi is definitely a parameter*)
        and expand varleq xi S a Ts error_pack done todo tye idx =
          let
            val n = length Ts;
            (*fresh parameters as arguments of the constructor a*)
            val args = map2 mk_param (idx upto idx + n - 1) (arity_sorts a S);
            val tye' = Vartab.update_new (xi, Type(a, args)) tye;
            val (ch, done') = split_cs (is_compT o deref tye') done;
            val todo' = ch @ todo;
            (*varleq tells whether the parameter was on the left-hand side*)
            val new =
              if varleq then (Type(a, args), Type (a, Ts))
              else (Type (a, Ts), Type(a, args));
          in
            simplify done' ((new, error_pack) :: todo') (tye', idx + n)
          end
        (*TU is a pair of a parameter and a free/fixed variable*)
        and eliminate TU error_pack done todo tye idx =
          let
            val [TVar (xi, S)] = filter is_paramT TU;
            val [T] = filter_out is_paramT TU;
            val SOME S' = sort_of T;
            val test_update = if is_freeT T then is_freeT else is_fixedvarT;
            val tye' = Vartab.update_new (xi, T) tye;
            val (ch, done') = split_cs (test_update o deref tye') done;
            val todo' = ch @ todo;
          in
            if subsort (S', S) (*TODO check this*)
            then simplify done' todo' (tye', idx)
            else err_subtype ctxt "Sort mismatch" tye error_pack
          end
        and simplify done [] tye_idx = (done, tye_idx)
          | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
              (case (deref tye T, deref tye U) of
                (Type (a, []), Type (b, [])) =>
                  if a = b then simplify done todo tye_idx
                  else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
                  else err_subtype ctxt (a ^" is not a subtype of " ^ b) (fst tye_idx) error_pack
              | (Type (a, Ts), Type (b, Us)) =>
                  if a<>b then err_subtype ctxt "Different constructors" (fst tye_idx) error_pack
                  else contract a Ts Us error_pack done todo tye idx
              | (TVar (xi, S), Type (a, Ts as (_ :: _))) =>
                  expand true xi S a Ts error_pack done todo tye idx
              | (Type (a, Ts as (_ :: _)), TVar (xi, S)) =>
                  expand false xi S a Ts error_pack done todo tye idx
              | (T, U) =>
                  if T = U then simplify done todo tye_idx
                  else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
                    exists is_paramT [T, U]
                  then eliminate [T, U] error_pack done todo tye idx
                  else if exists (is_freeT orf is_fixedvarT) [T, U]
                  then err_subtype ctxt "Not eliminated free/fixed variables"
                        (fst tye_idx) error_pack
                  else simplify (((T, U), error_pack) :: done) todo tye_idx);
      in
        simplify [] cs tye_idx
      end;


    (* do simplification *)

    val (cs', tye_idx') = simplify_constraints cs tye_idx;

    (*error packs of the simplified constraints that have T' as right-hand
      side (lower = true) or as left-hand side (lower = false)*)
    fun find_error_pack lower T' =
      map snd (filter (fn ((T, U), _) => if lower then T' = U else T' = T) cs');

    (*unify all types of a non-empty list with its head*)
    fun unify_list (T :: Ts) tye_idx =
      fold (fn U => fn tye_idx => strong_unify ctxt (T, U) tye_idx
        handle NO_UNIFIER (msg, tye) => err_list ctxt msg tye (T :: Ts))
      Ts tye_idx;

    (*styps stands either for supertypes or for subtypes of a type T
      in terms of the subtype-relation (excluding T itself)*)
    fun styps super T =
      (if super then Graph.imm_succs else Graph.imm_preds) coes_graph T
        handle Graph.UNDEF _ => [];

    (*select from a non-empty list of base type names the one that is related
      by a coercion edge to every other candidate examined; the direction of
      the edge depends on sup*)
    fun minmax sup (T :: Ts) =
      let
        fun adjust T U = if sup then (T, U) else (U, T);
        fun extract T [] = T
          | extract T (U :: Us) =
              if Graph.is_edge coes_graph (adjust T U) then extract T Us
              else if Graph.is_edge coes_graph (adjust U T) then extract U Us
              else raise BOUND_ERROR "Uncomparable types in type list";
      in
        t_of (extract T Ts)
      end;

    (*for every (Ts, S) of styps_and_sorts there is some U among T and its
      super/subtypes such that U is of sort S and is related to all of Ts*)
    fun ex_styp_of_sort super T styps_and_sorts =
      let
        fun adjust T U = if super then (T, U) else (U, T);
        fun styp_test U Ts = forall
          (fn T => T = U orelse Graph.is_edge coes_graph (adjust U T)) Ts;
        fun fitting Ts S U = Type.of_sort tsig (t_of U, S) andalso styp_test U Ts
      in
        forall (fn (Ts, S) => exists (fitting Ts S) (T :: styps super T)) styps_and_sorts
      end;

    (* computes the tightest possible, correct assignment for 'a::S
       e.g. in the supremum case (sup = true):
               ------- 'a::S---
              /        /    \  \
             /        /      \  \
        'b::C1   'c::C2 ...  T1 T2 ...

       sorts - list of sorts [C1, C2, ...]
       T::Ts - non-empty list of base types [T1, T2, ...]
    *)
    fun tightest sup S styps_and_sorts (T :: Ts) =
      let
        fun restriction T = Type.of_sort tsig (t_of T, S)
          andalso ex_styp_of_sort (not sup) T styps_and_sorts;
        fun candidates T = inter (op =) (filter restriction (T :: styps sup T));
      in
        (case fold candidates Ts (filter restriction (T :: styps sup T)) of
          [] => raise BOUND_ERROR ("No " ^ (if sup then "supremum" else "infimum"))
        | [T] => t_of T
        | Ts => minmax sup Ts)
      end;

    (*add an edge to the constraint graph for each constraint; cycles are
      unified and collapsed into a single node*)
    fun build_graph G [] tye_idx = (G, tye_idx)
      | build_graph G ((T, U) :: cs) tye_idx =
        if T = U then build_graph G cs tye_idx
        else
          let
            val G' = maybe_new_typnodes [T, U] G;
            val (G'', tye_idx') = (add_edge (T, U) G', tye_idx)
              handle Typ_Graph.CYCLES cycles =>
                let
                  val (tye, idx) = fold unify_list cycles tye_idx
                in
                  (*all cycles collapse to one node,
                    because all of them share at least the nodes x and y*)
                  collapse (tye, idx) (distinct (op =) (flat cycles)) G
                end;
          in
            build_graph G'' cs tye_idx'
          end
    and collapse (tye, idx) nodes G = (*nodes non-empty list*)
      let
        (*the head of nodes represents the collapsed nodes*)
        val T = hd nodes;
        val P = new_imm_preds G nodes;
        val S = new_imm_succs G nodes;
        val G' = Typ_Graph.del_nodes (tl nodes) G;
      in
        build_graph G' (map (fn x => (x, T)) P @ map (fn x => (T, x)) S) (tye, idx)
      end;

    (*assign to the parameter at node key a base type computed from the base
      types among its predecessors (lower = true) or successors (lower = false)*)
    fun assign_bound lower G key (tye_idx as (tye, _)) =
      if is_paramT (deref tye key) then
        let
          val TVar (xi, S) = deref tye key;
          val get_bound = if lower then get_preds else get_succs;
          val raw_bound = get_bound G key;
          val bound = map (deref tye) raw_bound;
          val not_params = filter_out is_paramT bound;
          fun to_fulfil T =
            (case sort_of T of
              NONE => NONE
            | SOME S =>
                SOME (map nameT (filter_out is_paramT (map (deref tye) (get_bound G T))), S));
          val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
          val assignment =
            if null bound orelse null not_params then NONE
            else SOME (tightest lower S styps_and_sorts (map nameT not_params)
              handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key))
        in
          (case assignment of
            NONE => tye_idx
          | SOME T =>
              if is_paramT T then tye_idx
              else if lower then (*upper bound check*)
                let
                  val other_bound = map (deref tye) (get_succs G key);
                  val s = nameT T;
                in
                  if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
                  then apfst (Vartab.update (xi, T)) tye_idx
                  else err_bound ctxt ("Assigned simple type " ^ s ^
                    " clashes with the upper bound of variable " ^
                    Syntax.string_of_typ ctxt (TVar(xi, S))) tye (find_error_pack (not lower) key)
                end
              else apfst (Vartab.update (xi, T)) tye_idx)
        end
      else tye_idx;

    val assign_lb = assign_bound true;
    val assign_ub = assign_bound false;

    (*alternate lower and upper bound assignment until the list of nodes
      that are still parameters does not change any more*)
    fun assign_alternating ts' ts G tye_idx =
      if ts' = ts then tye_idx
      else
        let
          val (tye_idx' as (tye, _)) = fold (assign_lb G) ts tye_idx
            |> fold (assign_ub G) ts;
        in
          assign_alternating ts (filter (is_paramT o deref tye) ts) G tye_idx'
        end;

    (*Unify all weakly connected components of the constraint forest,
      that contain only params. These are the only WCCs that contain
      params anyway.*)
    fun unify_params G (tye_idx as (tye, _)) =
      let
        val max_params = filter (is_paramT o deref tye) (Typ_Graph.maximals G);
        val to_unify = map (fn T => T :: get_preds G T) max_params;
      in
        fold unify_list to_unify tye_idx
      end;

    fun solve_constraints G tye_idx = tye_idx
      |> assign_alternating [] (Typ_Graph.keys G) G
      |> unify_params G;
  in
    build_graph Typ_Graph.empty (map fst cs') tye_idx'
      |-> solve_constraints
  end;



(** coercion insertion **)

(*insert_coercions ctxt tye ts: dereference all types in ts wrt. tye and wrap
  every argument whose type differs from the expected one into a coercion*)
fun insert_coercions ctxt tye ts =
  let
    (*dereference parameters at all positions of a type*)
    fun deep_deref T =
      (case deref tye T of
        Type (a, Ts) => Type (a, map deep_deref Ts)
      | U => U);

    (*term that coerces from the first to the second type; raises Fail if
      no such term can be built from the registered coercions and maps*)
    fun gen_coercion ((Type (a, [])), (Type (b, []))) =
          if a = b
          then Abs (Name.uu, Type (a, []), Bound 0)
          else
            (case Symreltab.lookup (coes_of (Context.Proof ctxt)) (a, b) of
              NONE => raise Fail (a ^ " is not a subtype of " ^ b)
            | SOME co => co)
      | gen_coercion ((Type (a, Ts)), (Type (b, Us))) =
          if a <> b
          then raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
          else
            let
              (*instantiate the type variables of the map function t*)
              fun inst t Ts =
                Term.subst_vars
                  (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
              fun sub_co (COVARIANT, TU) = gen_coercion TU
                | sub_co (CONTRAVARIANT, TU) = gen_coercion (swap TU)
                (*add_type_map only stores COVARIANT and CONTRAVARIANT;
                  fail with a proper message instead of a bare Match*)
                | sub_co (INVARIANT, _) =
                    raise Fail ("Invariant argument of map function for " ^ a);
              (*argument and result types of a list of function types*)
              fun ts_of [] = []
                | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
            in
              (case Symtab.lookup (tmaps_of (Context.Proof ctxt)) a of
                NONE => raise Fail ("No map function for " ^ a ^ " known")
              | SOME tmap =>
                  let
                    val used_coes = map sub_co ((snd tmap) ~~ (Ts ~~ Us));
                  in
                    Term.list_comb
                      (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
                  end)
            end
      | gen_coercion (T, U) =
          if Type.could_unify (T, U)
          then Abs (Name.uu, T, Bound 0)
          else raise Fail ("Cannot generate coercion from "
            ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U);

    (*insert bs t yields (t with coercions, type of t); bs holds the types of
      the enclosing binders*)
    fun insert _ (Const (c, T)) =
          let val T' = deep_deref T;
          in (Const (c, T'), T') end
      | insert _ (Free (x, T)) =
          let val T' = deep_deref T;
          in (Free (x, T'), T') end
      | insert _ (Var (xi, T)) =
          let val T' = deep_deref T;
          in (Var (xi, T'), T') end
      | insert bs (Bound i) =
          let val T = nth bs i handle Subscript =>
            raise TYPE ("Loose bound variable: B." ^ string_of_int i, [], []);
          in (Bound i, T) end
      | insert bs (Abs (x, T, t)) =
          let
            val T' = deep_deref T;
            val (t', T'') = insert (T' :: bs) t;
          in
            (Abs (x, T', t'), T' --> T'')
          end
      | insert bs (t $ u) =
          let
            val (t', Type ("fun", [U, T])) = insert bs t;
            val (u', U') = insert bs u;
          in
            (*coerce the argument if its type is not the expected one*)
            if U <> U'
            then (t' $ (gen_coercion (U', U) $ u'), T)
            else (t' $ u', T)
          end
  in
    map (fst o insert []) ts
  end;



(** assembling the pipeline **)

(*prepare the raw terms, generate and solve the subtype constraints, insert
  coercions and finish the resulting terms*)
fun infer_types ctxt const_type var_type raw_ts =
  let
    val (idx, ts) = Type_Infer.prepare ctxt const_type var_type raw_ts;

    fun gen_all t (tye_idx, constraints) =
      let
        val (_, tye_idx', constraints') = generate_constraints ctxt t tye_idx
      in (tye_idx', constraints' @ constraints) end;

    val (tye_idx, constraints) = fold gen_all ts ((Vartab.empty, idx), []);
    val (tye, _) = process_constraints ctxt constraints tye_idx;
    val ts' = insert_coercions ctxt tye ts;

    val (_, ts'') = Type_Infer.finish ctxt tye ([], ts');
  in ts'' end;



(** installation **)

(*infer_types with constant constraints and default types taken from ctxt*)
fun coercion_infer_types ctxt =
  infer_types ctxt
    (try (Consts.the_constraint (ProofContext.consts_of ctxt)))
    (ProofContext.def_type ctxt);

local

(*register check f via what; the check reports NONE if f leaves xs unchanged wrt. eq*)
fun add eq what f = Context.>> (what (fn xs => fn ctxt =>
  let val xs' = f ctxt xs in if eq_list eq (xs, xs') then NONE else SOME (xs', ctxt) end));

in

val _ = add (op aconv) (Syntax.add_term_check ~100 "coercions") coercion_infer_types;

end;


(* interface *)

(*register the term read from map_fun as map function of the type
  constructor found in its type; the variances of the constructor arguments
  are derived from the types of the function arguments*)
fun add_type_map map_fun context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt map_fun);

    fun err_str () = "\n\nthe general type signature for a map function is" ^
      "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [x1, ..., xn]" ^
      "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)";

    (*compare (domain, range) of each function argument with the pair of
      corresponding constructor arguments*)
    fun gen_arg_var ([], []) = []
      | gen_arg_var ((T, T') :: Ts, (U, U') :: Us) =
          if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
          else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us)
          else error ("Functions do not apply to arguments correctly:" ^ err_str ())
      | gen_arg_var (_, _) =
          error ("Different numbers of functions and arguments\n" ^ err_str ());

    (* TODO: This function is only needed to introduce the fun type map
      function: "% f g h . g o h o f". There must be a better solution. *)
    fun balanced (Type (_, [])) (Type (_, [])) = true
      | balanced (Type (a, Ts)) (Type (b, Us)) =
          a = b andalso forall I (map2 balanced Ts Us)
      | balanced (TFree _) (TFree _) = true
      | balanced (TVar _) (TVar _) = true
      | balanced _ _ = false;

    (*collect (domain, range) of the leading function arguments until the
      remaining type has the form C [x1, ..., xn] => C [y1, ..., yn]*)
    fun check_map_fun (pairs, []) (Type ("fun", [T as Type (C, Ts), U as Type (_, Us)])) =
          if balanced T U
          then ((pairs, Ts ~~ Us), C)
          else if C = "fun"
            then check_map_fun (pairs @ [(hd Ts, hd (tl Ts))], []) U
            else error ("Not a proper map function:" ^ err_str ())
      | check_map_fun _ _ = error ("Not a proper map function:" ^ err_str ());

    val res = check_map_fun ([], []) (fastype_of t);
    val res_av = gen_arg_var (fst res);
  in
    map_tmaps (Symtab.update (snd res, (t, res_av))) context
  end;

(*register the term read from coercion, which must have a type a => b with
  nullary type constructors a and b; coercions for all edges of the
  transitive closure that become new are composed from registered ones*)
fun add_coercion coercion context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt coercion);

    fun err_coercion () = error ("Bad type for coercion " ^
        Syntax.string_of_term ctxt t ^ ":\n" ^
        Syntax.string_of_typ ctxt (fastype_of t));

    (*explicit case instead of a val pattern: a "handle Bind" attached to the
      right-hand side of a val binding never sees the Bind of the pattern*)
    val (T1, T2) =
      (case fastype_of t of
        Type ("fun", [T1, T2]) => (T1, T2)
      | _ => err_coercion ());

    val a =
      (case T1 of
        Type (x, []) => x
      | _ => err_coercion ());

    val b =
      (case T2 of
        Type (x, []) => x
      | _ => err_coercion ());

    fun coercion_data_update (tab, G) =
      let
        val G' = maybe_new_nodes [a, b] G
        (*a cycle on adding the edge (a, b) means that there already is a
          path from b to a, i.e. b is already a subtype of a*)
        val G'' = Graph.add_edge_trans_acyclic (a, b) G'
          handle Graph.CYCLES _ => error (b ^ " is already a subtype of " ^ a ^
            "!\n\nCannot add coercion of type: " ^ a ^ " => " ^ b);
        (*edges of G'' that are not present in G'*)
        val new_edges =
          flat (Graph.dest G'' |> map (fn (x, ys) => ys |> map_filter (fn y =>
            if Graph.is_edge G' (x, y) then NONE else SOME (x, y))));
        val G_and_new = Graph.add_edge (a, b) G';

        (*compose the coercions along the first irreducible path from a to b*)
        fun complex_coercion tab G (a, b) =
          let
            val path = hd (Graph.irreducible_paths G (a, b))
            val path' = (fst (split_last path)) ~~ tl path
          in Abs (Name.uu, Type (a, []),
              fold (fn t => fn u => t $ u) (map (the o Symreltab.lookup tab) path') (Bound 0))
          end;

        val tab' = fold
          (fn pair => fn tab => Symreltab.update (pair, complex_coercion tab G_and_new pair) tab)
          (filter (fn pair => pair <> (a, b)) new_edges)
          (Symreltab.update ((a, b), t) tab);
      in
        (tab', G'')
      end;
  in
    map_coes_and_graph coercion_data_update context
  end;

val _ = Context.>> (Context.map_theory
  (Attrib.setup (Binding.name "coercion") (Scan.lift Parse.term >>
    (fn t => fn (context, thm) => (add_coercion t context, thm)))
    "declaration of new coercions" #>
  Attrib.setup (Binding.name "map_function") (Scan.lift Parse.term >>
    (fn t => fn (context, thm) => (add_type_map t context, thm)))
    "declaration of new map functions"));

end;