71 |
71 |
72 |
72 |
73 |
73 |
74 (** utils **) |
74 (** utils **) |
75 |
75 |
76 val is_param = Type_Infer.is_param |
|
77 val is_paramT = Type_Infer.is_paramT |
|
78 val deref = Type_Infer.deref |
|
79 fun mk_param i S = TVar (("?'a", i), S); (* TODO dup? see src/Pure/type_infer.ML *) |
|
80 |
|
81 fun nameT (Type (s, [])) = s; |
76 fun nameT (Type (s, [])) = s; |
82 fun t_of s = Type (s, []); |
77 fun t_of s = Type (s, []); |
|
78 |
83 fun sort_of (TFree (_, S)) = SOME S |
79 fun sort_of (TFree (_, S)) = SOME S |
84 | sort_of (TVar (_, S)) = SOME S |
80 | sort_of (TVar (_, S)) = SOME S |
85 | sort_of _ = NONE; |
81 | sort_of _ = NONE; |
86 |
82 |
87 val is_typeT = fn (Type _) => true | _ => false; |
83 val is_typeT = fn (Type _) => true | _ => false; |
88 val is_compT = fn (Type (_, _ :: _)) => true | _ => false; |
84 val is_compT = fn (Type (_, _ :: _)) => true | _ => false; |
89 val is_freeT = fn (TFree _) => true | _ => false; |
85 val is_freeT = fn (TFree _) => true | _ => false; |
90 val is_fixedvarT = fn (TVar (xi, _)) => not (is_param xi) | _ => false; |
86 val is_fixedvarT = fn (TVar (xi, _)) => not (Type_Infer.is_param xi) | _ => false; |
91 |
87 |
92 |
88 |
93 (* unification *) (* TODO dup? needed for weak unification *) |
89 (* unification *) (* TODO dup? needed for weak unification *) |
94 |
90 |
95 exception NO_UNIFIER of string * typ Vartab.table; |
91 exception NO_UNIFIER of string * typ Vartab.table; |
114 if Sign.subsort thy (S', S) then tye_idx |
110 if Sign.subsort thy (S', S) then tye_idx |
115 else raise NO_UNIFIER (not_of_sort x S' S, tye) |
111 else raise NO_UNIFIER (not_of_sort x S' S, tye) |
116 | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) = |
112 | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) = |
117 if Sign.subsort thy (S', S) then tye_idx |
113 if Sign.subsort thy (S', S) then tye_idx |
118 else if Type_Infer.is_param xi then |
114 else if Type_Infer.is_param xi then |
119 (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1) |
115 (Vartab.update_new |
|
116 (xi, Type_Infer.mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1) |
120 else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye) |
117 else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye) |
121 and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) = |
118 and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) = |
122 meets (Ts, Ss) (meet (deref tye T, S) tye_idx) |
119 meets (Ts, Ss) (meet (Type_Infer.deref tye T, S) tye_idx) |
123 | meets _ tye_idx = tye_idx; |
120 | meets _ tye_idx = tye_idx; |
124 |
121 |
125 val weak_meet = if weak then fn _ => I else meet |
122 val weak_meet = if weak then fn _ => I else meet |
126 |
123 |
127 |
124 |
147 |
144 |
148 fun show_tycon (a, Ts) = |
145 fun show_tycon (a, Ts) = |
149 quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT))); |
146 quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT))); |
150 |
147 |
151 fun unif (T1, T2) (env as (tye, _)) = |
148 fun unif (T1, T2) (env as (tye, _)) = |
152 (case pairself (`is_paramT o deref tye) (T1, T2) of |
149 (case pairself (`Type_Infer.is_paramT o Type_Infer.deref tye) (T1, T2) of |
153 ((true, TVar (xi, S)), (_, T)) => assign xi T S env |
150 ((true, TVar (xi, S)), (_, T)) => assign xi T S env |
154 | ((_, T), (true, TVar (xi, S))) => assign xi T S env |
151 | ((_, T), (true, TVar (xi, S))) => assign xi T S env |
155 | ((_, Type (a, Ts)), (_, Type (b, Us))) => |
152 | ((_, Type (a, Ts)), (_, Type (b, Us))) => |
156 if weak andalso null Ts andalso null Us then env |
153 if weak andalso null Ts andalso null Us then env |
157 else if a <> b then |
154 else if a <> b then |
255 in (T --> U, tye_idx', cs') end |
252 in (T --> U, tye_idx', cs') end |
256 | gen cs bs (t $ u) tye_idx = |
253 | gen cs bs (t $ u) tye_idx = |
257 let |
254 let |
258 val (T, tye_idx', cs') = gen cs bs t tye_idx; |
255 val (T, tye_idx', cs') = gen cs bs t tye_idx; |
259 val (U', (tye, idx), cs'') = gen cs' bs u tye_idx'; |
256 val (U', (tye, idx), cs'') = gen cs' bs u tye_idx'; |
260 val U = mk_param idx []; |
257 val U = Type_Infer.mk_param idx []; |
261 val V = mk_param (idx + 1) []; |
258 val V = Type_Infer.mk_param (idx + 1) []; |
262 val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2) |
259 val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2) |
263 handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs t T u U; |
260 handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs t T u U; |
264 val error_pack = (bs, t $ u, U, V, U'); |
261 val error_pack = (bs, t $ u, U, V, U'); |
265 in (V, tye_idx'', ((U', U), error_pack) :: cs'') end; |
262 in (V, tye_idx'', ((U', U), error_pack) :: cs'') end; |
266 in |
263 in |
316 handle NO_UNIFIER (msg, tye) => err_subtype ctxt msg tye error_pack)); |
313 handle NO_UNIFIER (msg, tye) => err_subtype ctxt msg tye error_pack)); |
317 val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack)) |
314 val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack)) |
318 (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx))); |
315 (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx))); |
319 val test_update = is_compT orf is_freeT orf is_fixedvarT; |
316 val test_update = is_compT orf is_freeT orf is_fixedvarT; |
320 val (ch, done') = |
317 val (ch, done') = |
321 if not (null new) then ([], done) |
318 if not (null new) then ([], done) |
322 else split_cs (test_update o deref tye') done; |
319 else split_cs (test_update o Type_Infer.deref tye') done; |
323 val todo' = ch @ todo; |
320 val todo' = ch @ todo; |
324 in |
321 in |
325 simplify done' (new @ todo') (tye', idx') |
322 simplify done' (new @ todo') (tye', idx') |
326 end |
323 end |
327 (*xi is definitely a parameter*) |
324 (*xi is definitely a parameter*) |
328 and expand varleq xi S a Ts error_pack done todo tye idx = |
325 and expand varleq xi S a Ts error_pack done todo tye idx = |
329 let |
326 let |
330 val n = length Ts; |
327 val n = length Ts; |
331 val args = map2 mk_param (idx upto idx + n - 1) (arity_sorts a S); |
328 val args = map2 Type_Infer.mk_param (idx upto idx + n - 1) (arity_sorts a S); |
332 val tye' = Vartab.update_new (xi, Type(a, args)) tye; |
329 val tye' = Vartab.update_new (xi, Type(a, args)) tye; |
333 val (ch, done') = split_cs (is_compT o deref tye') done; |
330 val (ch, done') = split_cs (is_compT o Type_Infer.deref tye') done; |
334 val todo' = ch @ todo; |
331 val todo' = ch @ todo; |
335 val new = |
332 val new = |
336 if varleq then (Type(a, args), Type (a, Ts)) |
333 if varleq then (Type(a, args), Type (a, Ts)) |
337 else (Type (a, Ts), Type(a, args)); |
334 else (Type (a, Ts), Type (a, args)); |
338 in |
335 in |
339 simplify done' ((new, error_pack) :: todo') (tye', idx + n) |
336 simplify done' ((new, error_pack) :: todo') (tye', idx + n) |
340 end |
337 end |
341 (*TU is a pair of a parameter and a free/fixed variable*) |
338 (*TU is a pair of a parameter and a free/fixed variable*) |
342 and eliminate TU error_pack done todo tye idx = |
339 and eliminate TU error_pack done todo tye idx = |
343 let |
340 let |
344 val [TVar (xi, S)] = filter is_paramT TU; |
341 val [TVar (xi, S)] = filter Type_Infer.is_paramT TU; |
345 val [T] = filter_out is_paramT TU; |
342 val [T] = filter_out Type_Infer.is_paramT TU; |
346 val SOME S' = sort_of T; |
343 val SOME S' = sort_of T; |
347 val test_update = if is_freeT T then is_freeT else is_fixedvarT; |
344 val test_update = if is_freeT T then is_freeT else is_fixedvarT; |
348 val tye' = Vartab.update_new (xi, T) tye; |
345 val tye' = Vartab.update_new (xi, T) tye; |
349 val (ch, done') = split_cs (test_update o deref tye') done; |
346 val (ch, done') = split_cs (test_update o Type_Infer.deref tye') done; |
350 val todo' = ch @ todo; |
347 val todo' = ch @ todo; |
351 in |
348 in |
352 if subsort (S', S) (*TODO check this*) |
349 if subsort (S', S) (*TODO check this*) |
353 then simplify done' todo' (tye', idx) |
350 then simplify done' todo' (tye', idx) |
354 else err_subtype ctxt "Sort mismatch" tye error_pack |
351 else err_subtype ctxt "Sort mismatch" tye error_pack |
355 end |
352 end |
356 and simplify done [] tye_idx = (done, tye_idx) |
353 and simplify done [] tye_idx = (done, tye_idx) |
357 | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) = |
354 | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) = |
358 (case (deref tye T, deref tye U) of |
355 (case (Type_Infer.deref tye T, Type_Infer.deref tye U) of |
359 (Type (a, []), Type (b, [])) => |
356 (Type (a, []), Type (b, [])) => |
360 if a = b then simplify done todo tye_idx |
357 if a = b then simplify done todo tye_idx |
361 else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx |
358 else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx |
362 else err_subtype ctxt (a ^" is not a subtype of " ^ b) (fst tye_idx) error_pack |
359 else err_subtype ctxt (a ^ " is not a subtype of " ^ b) (fst tye_idx) error_pack |
363 | (Type (a, Ts), Type (b, Us)) => |
360 | (Type (a, Ts), Type (b, Us)) => |
364 if a<>b then err_subtype ctxt "Different constructors" (fst tye_idx) error_pack |
361 if a <> b then err_subtype ctxt "Different constructors" (fst tye_idx) error_pack |
365 else contract a Ts Us error_pack done todo tye idx |
362 else contract a Ts Us error_pack done todo tye idx |
366 | (TVar (xi, S), Type (a, Ts as (_ :: _))) => |
363 | (TVar (xi, S), Type (a, Ts as (_ :: _))) => |
367 expand true xi S a Ts error_pack done todo tye idx |
364 expand true xi S a Ts error_pack done todo tye idx |
368 | (Type (a, Ts as (_ :: _)), TVar (xi, S)) => |
365 | (Type (a, Ts as (_ :: _)), TVar (xi, S)) => |
369 expand false xi S a Ts error_pack done todo tye idx |
366 expand false xi S a Ts error_pack done todo tye idx |
370 | (T, U) => |
367 | (T, U) => |
371 if T = U then simplify done todo tye_idx |
368 if T = U then simplify done todo tye_idx |
372 else if exists (is_freeT orf is_fixedvarT) [T, U] andalso |
369 else if exists (is_freeT orf is_fixedvarT) [T, U] andalso |
373 exists is_paramT [T, U] |
370 exists Type_Infer.is_paramT [T, U] |
374 then eliminate [T, U] error_pack done todo tye idx |
371 then eliminate [T, U] error_pack done todo tye idx |
375 else if exists (is_freeT orf is_fixedvarT) [T, U] |
372 else if exists (is_freeT orf is_fixedvarT) [T, U] |
376 then err_subtype ctxt "Not eliminated free/fixed variables" |
373 then err_subtype ctxt "Not eliminated free/fixed variables" |
377 (fst tye_idx) error_pack |
374 (fst tye_idx) error_pack |
378 else simplify (((T, U), error_pack) :: done) todo tye_idx); |
375 else simplify (((T, U), error_pack) :: done) todo tye_idx); |
470 in |
467 in |
471 build_graph G' (map (fn x => (x, T)) P @ map (fn x => (T, x)) S) (tye, idx) |
468 build_graph G' (map (fn x => (x, T)) P @ map (fn x => (T, x)) S) (tye, idx) |
472 end; |
469 end; |
473 |
470 |
474 fun assign_bound lower G key (tye_idx as (tye, _)) = |
471 fun assign_bound lower G key (tye_idx as (tye, _)) = |
475 if is_paramT (deref tye key) then |
472 if Type_Infer.is_paramT (Type_Infer.deref tye key) then |
476 let |
473 let |
477 val TVar (xi, S) = deref tye key; |
474 val TVar (xi, S) = Type_Infer.deref tye key; |
478 val get_bound = if lower then get_preds else get_succs; |
475 val get_bound = if lower then get_preds else get_succs; |
479 val raw_bound = get_bound G key; |
476 val raw_bound = get_bound G key; |
480 val bound = map (deref tye) raw_bound; |
477 val bound = map (Type_Infer.deref tye) raw_bound; |
481 val not_params = filter_out is_paramT bound; |
478 val not_params = filter_out Type_Infer.is_paramT bound; |
482 fun to_fulfil T = |
479 fun to_fulfil T = |
483 (case sort_of T of |
480 (case sort_of T of |
484 NONE => NONE |
481 NONE => NONE |
485 | SOME S => |
482 | SOME S => |
486 SOME (map nameT (filter_out is_paramT (map (deref tye) (get_bound G T))), S)); |
483 SOME |
|
484 (map nameT |
|
485 (filter_out Type_Infer.is_paramT (map (Type_Infer.deref tye) (get_bound G T))), |
|
486 S)); |
487 val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound); |
487 val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound); |
488 val assignment = |
488 val assignment = |
489 if null bound orelse null not_params then NONE |
489 if null bound orelse null not_params then NONE |
490 else SOME (tightest lower S styps_and_sorts (map nameT not_params) |
490 else SOME (tightest lower S styps_and_sorts (map nameT not_params) |
491 handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key)) |
491 handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key)) |
492 in |
492 in |
493 (case assignment of |
493 (case assignment of |
494 NONE => tye_idx |
494 NONE => tye_idx |
495 | SOME T => |
495 | SOME T => |
496 if is_paramT T then tye_idx |
496 if Type_Infer.is_paramT T then tye_idx |
497 else if lower then (*upper bound check*) |
497 else if lower then (*upper bound check*) |
498 let |
498 let |
499 val other_bound = map (deref tye) (get_succs G key); |
499 val other_bound = map (Type_Infer.deref tye) (get_succs G key); |
500 val s = nameT T; |
500 val s = nameT T; |
501 in |
501 in |
502 if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s) |
502 if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s) |
503 then apfst (Vartab.update (xi, T)) tye_idx |
503 then apfst (Vartab.update (xi, T)) tye_idx |
504 else err_bound ctxt ("Assigned simple type " ^ s ^ |
504 else err_bound ctxt ("Assigned simple type " ^ s ^ |