src/Tools/subtyping.ML
changeset 40286 b928e3960446
parent 40285 80136c4240cc
child 40297 c753e3f8b4d6
equal deleted inserted replaced
40285:80136c4240cc 40286:b928e3960446
    71 
    71 
    72 
    72 
    73 
    73 
    74 (** utils **)
    74 (** utils **)
    75 
    75 
    76 val is_param = Type_Infer.is_param
       
    77 val is_paramT = Type_Infer.is_paramT
       
    78 val deref = Type_Infer.deref
       
    79 fun mk_param i S = TVar (("?'a", i), S); (* TODO dup? see src/Pure/type_infer.ML *)
       
    80 
       
    81 fun nameT (Type (s, [])) = s;
    76 fun nameT (Type (s, [])) = s;
    82 fun t_of s = Type (s, []);
    77 fun t_of s = Type (s, []);
       
    78 
    83 fun sort_of (TFree (_, S)) = SOME S
    79 fun sort_of (TFree (_, S)) = SOME S
    84   | sort_of (TVar (_, S)) = SOME S
    80   | sort_of (TVar (_, S)) = SOME S
    85   | sort_of _ = NONE;
    81   | sort_of _ = NONE;
    86 
    82 
    87 val is_typeT = fn (Type _) => true | _ => false;
    83 val is_typeT = fn (Type _) => true | _ => false;
    88 val is_compT = fn (Type (_, _ :: _)) => true | _ => false;
    84 val is_compT = fn (Type (_, _ :: _)) => true | _ => false;
    89 val is_freeT = fn (TFree _) => true | _ => false;
    85 val is_freeT = fn (TFree _) => true | _ => false;
    90 val is_fixedvarT = fn (TVar (xi, _)) => not (is_param xi) | _ => false;
    86 val is_fixedvarT = fn (TVar (xi, _)) => not (Type_Infer.is_param xi) | _ => false;
    91 
    87 
    92 
    88 
    93 (* unification *)  (* TODO dup? needed for weak unification *)
    89 (* unification *)  (* TODO dup? needed for weak unification *)
    94 
    90 
    95 exception NO_UNIFIER of string * typ Vartab.table;
    91 exception NO_UNIFIER of string * typ Vartab.table;
   114           if Sign.subsort thy (S', S) then tye_idx
   110           if Sign.subsort thy (S', S) then tye_idx
   115           else raise NO_UNIFIER (not_of_sort x S' S, tye)
   111           else raise NO_UNIFIER (not_of_sort x S' S, tye)
   116       | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
   112       | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
   117           if Sign.subsort thy (S', S) then tye_idx
   113           if Sign.subsort thy (S', S) then tye_idx
   118           else if Type_Infer.is_param xi then
   114           else if Type_Infer.is_param xi then
   119             (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
   115             (Vartab.update_new
       
   116               (xi, Type_Infer.mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
   120           else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
   117           else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
   121     and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
   118     and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
   122           meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
   119           meets (Ts, Ss) (meet (Type_Infer.deref tye T, S) tye_idx)
   123       | meets _ tye_idx = tye_idx;
   120       | meets _ tye_idx = tye_idx;
   124 
   121 
   125     val weak_meet = if weak then fn _ => I else meet
   122     val weak_meet = if weak then fn _ => I else meet
   126 
   123 
   127 
   124 
   147 
   144 
   148     fun show_tycon (a, Ts) =
   145     fun show_tycon (a, Ts) =
   149       quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
   146       quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
   150 
   147 
   151     fun unif (T1, T2) (env as (tye, _)) =
   148     fun unif (T1, T2) (env as (tye, _)) =
   152       (case pairself (`is_paramT o deref tye) (T1, T2) of
   149       (case pairself (`Type_Infer.is_paramT o Type_Infer.deref tye) (T1, T2) of
   153         ((true, TVar (xi, S)), (_, T)) => assign xi T S env
   150         ((true, TVar (xi, S)), (_, T)) => assign xi T S env
   154       | ((_, T), (true, TVar (xi, S))) => assign xi T S env
   151       | ((_, T), (true, TVar (xi, S))) => assign xi T S env
   155       | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
   152       | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
   156           if weak andalso null Ts andalso null Us then env
   153           if weak andalso null Ts andalso null Us then env
   157           else if a <> b then
   154           else if a <> b then
   255           in (T --> U, tye_idx', cs') end
   252           in (T --> U, tye_idx', cs') end
   256       | gen cs bs (t $ u) tye_idx =
   253       | gen cs bs (t $ u) tye_idx =
   257           let
   254           let
   258             val (T, tye_idx', cs') = gen cs bs t tye_idx;
   255             val (T, tye_idx', cs') = gen cs bs t tye_idx;
   259             val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
   256             val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
   260             val U = mk_param idx [];
   257             val U = Type_Infer.mk_param idx [];
   261             val V = mk_param (idx + 1) [];
   258             val V = Type_Infer.mk_param (idx + 1) [];
   262             val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
   259             val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
   263               handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs t T u U;
   260               handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs t T u U;
   264             val error_pack = (bs, t $ u, U, V, U');
   261             val error_pack = (bs, t $ u, U, V, U');
   265           in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
   262           in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
   266   in
   263   in
   316                   handle NO_UNIFIER (msg, tye) => err_subtype ctxt msg tye error_pack));
   313                   handle NO_UNIFIER (msg, tye) => err_subtype ctxt msg tye error_pack));
   317             val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
   314             val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
   318               (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
   315               (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
   319             val test_update = is_compT orf is_freeT orf is_fixedvarT;
   316             val test_update = is_compT orf is_freeT orf is_fixedvarT;
   320             val (ch, done') =
   317             val (ch, done') =
   321               if not (null new) then ([],   done)
   318               if not (null new) then ([], done)
   322               else split_cs (test_update o deref tye') done;
   319               else split_cs (test_update o Type_Infer.deref tye') done;
   323             val todo' = ch @ todo;
   320             val todo' = ch @ todo;
   324           in
   321           in
   325             simplify done' (new @ todo') (tye', idx')
   322             simplify done' (new @ todo') (tye', idx')
   326           end
   323           end
   327         (*xi is definitely a parameter*)
   324         (*xi is definitely a parameter*)
   328         and expand varleq xi S a Ts error_pack done todo tye idx =
   325         and expand varleq xi S a Ts error_pack done todo tye idx =
   329           let
   326           let
   330             val n = length Ts;
   327             val n = length Ts;
   331             val args = map2 mk_param (idx upto idx + n - 1) (arity_sorts a S);
   328             val args = map2 Type_Infer.mk_param (idx upto idx + n - 1) (arity_sorts a S);
   332             val tye' = Vartab.update_new (xi, Type(a, args)) tye;
   329             val tye' = Vartab.update_new (xi, Type(a, args)) tye;
   333             val (ch, done') = split_cs (is_compT o deref tye') done;
   330             val (ch, done') = split_cs (is_compT o Type_Infer.deref tye') done;
   334             val todo' = ch @ todo;
   331             val todo' = ch @ todo;
   335             val new =
   332             val new =
   336               if varleq then (Type(a, args), Type (a, Ts))
   333               if varleq then (Type(a, args), Type (a, Ts))
   337               else (Type (a, Ts), Type(a, args));
   334               else (Type (a, Ts), Type (a, args));
   338           in
   335           in
   339             simplify done' ((new, error_pack) :: todo') (tye', idx + n)
   336             simplify done' ((new, error_pack) :: todo') (tye', idx + n)
   340           end
   337           end
   341         (*TU is a pair of a parameter and a free/fixed variable*)
   338         (*TU is a pair of a parameter and a free/fixed variable*)
   342         and eliminate TU error_pack done todo tye idx =
   339         and eliminate TU error_pack done todo tye idx =
   343           let
   340           let
   344             val [TVar (xi, S)] = filter is_paramT TU;
   341             val [TVar (xi, S)] = filter Type_Infer.is_paramT TU;
   345             val [T] = filter_out is_paramT TU;
   342             val [T] = filter_out Type_Infer.is_paramT TU;
   346             val SOME S' = sort_of T;
   343             val SOME S' = sort_of T;
   347             val test_update = if is_freeT T then is_freeT else is_fixedvarT;
   344             val test_update = if is_freeT T then is_freeT else is_fixedvarT;
   348             val tye' = Vartab.update_new (xi, T) tye;
   345             val tye' = Vartab.update_new (xi, T) tye;
   349             val (ch, done') = split_cs (test_update o deref tye') done;
   346             val (ch, done') = split_cs (test_update o Type_Infer.deref tye') done;
   350             val todo' = ch @ todo;
   347             val todo' = ch @ todo;
   351           in
   348           in
   352             if subsort (S', S) (*TODO check this*)
   349             if subsort (S', S) (*TODO check this*)
   353             then simplify done' todo' (tye', idx)
   350             then simplify done' todo' (tye', idx)
   354             else err_subtype ctxt "Sort mismatch" tye error_pack
   351             else err_subtype ctxt "Sort mismatch" tye error_pack
   355           end
   352           end
   356         and simplify done [] tye_idx = (done, tye_idx)
   353         and simplify done [] tye_idx = (done, tye_idx)
   357           | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
   354           | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
   358               (case (deref tye T, deref tye U) of
   355               (case (Type_Infer.deref tye T, Type_Infer.deref tye U) of
   359                 (Type (a, []), Type (b, [])) =>
   356                 (Type (a, []), Type (b, [])) =>
   360                   if a = b then simplify done todo tye_idx
   357                   if a = b then simplify done todo tye_idx
   361                   else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
   358                   else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
   362                   else err_subtype ctxt (a ^" is not a subtype of " ^ b) (fst tye_idx) error_pack
   359                   else err_subtype ctxt (a ^ " is not a subtype of " ^ b) (fst tye_idx) error_pack
   363               | (Type (a, Ts), Type (b, Us)) =>
   360               | (Type (a, Ts), Type (b, Us)) =>
   364                   if a<>b then err_subtype ctxt "Different constructors" (fst tye_idx) error_pack
   361                   if a <> b then err_subtype ctxt "Different constructors" (fst tye_idx) error_pack
   365                   else contract a Ts Us error_pack done todo tye idx
   362                   else contract a Ts Us error_pack done todo tye idx
   366               | (TVar (xi, S), Type (a, Ts as (_ :: _))) =>
   363               | (TVar (xi, S), Type (a, Ts as (_ :: _))) =>
   367                   expand true xi S a Ts error_pack done todo tye idx
   364                   expand true xi S a Ts error_pack done todo tye idx
   368               | (Type (a, Ts as (_ :: _)), TVar (xi, S)) =>
   365               | (Type (a, Ts as (_ :: _)), TVar (xi, S)) =>
   369                   expand false xi S a Ts error_pack done todo tye idx
   366                   expand false xi S a Ts error_pack done todo tye idx
   370               | (T, U) =>
   367               | (T, U) =>
   371                   if T = U then simplify done todo tye_idx
   368                   if T = U then simplify done todo tye_idx
   372                   else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
   369                   else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
   373                     exists is_paramT [T, U]
   370                     exists Type_Infer.is_paramT [T, U]
   374                   then eliminate [T, U] error_pack done todo tye idx
   371                   then eliminate [T, U] error_pack done todo tye idx
   375                   else if exists (is_freeT orf is_fixedvarT) [T, U]
   372                   else if exists (is_freeT orf is_fixedvarT) [T, U]
   376                   then err_subtype ctxt "Not eliminated free/fixed variables"
   373                   then err_subtype ctxt "Not eliminated free/fixed variables"
   377                         (fst tye_idx) error_pack
   374                         (fst tye_idx) error_pack
   378                   else simplify (((T, U), error_pack) :: done) todo tye_idx);
   375                   else simplify (((T, U), error_pack) :: done) todo tye_idx);
   470       in
   467       in
   471         build_graph G' (map (fn x => (x, T)) P @ map (fn x => (T, x)) S) (tye, idx)
   468         build_graph G' (map (fn x => (x, T)) P @ map (fn x => (T, x)) S) (tye, idx)
   472       end;
   469       end;
   473 
   470 
   474     fun assign_bound lower G key (tye_idx as (tye, _)) =
   471     fun assign_bound lower G key (tye_idx as (tye, _)) =
   475       if is_paramT (deref tye key) then
   472       if Type_Infer.is_paramT (Type_Infer.deref tye key) then
   476         let
   473         let
   477           val TVar (xi, S) = deref tye key;
   474           val TVar (xi, S) = Type_Infer.deref tye key;
   478           val get_bound = if lower then get_preds else get_succs;
   475           val get_bound = if lower then get_preds else get_succs;
   479           val raw_bound = get_bound G key;
   476           val raw_bound = get_bound G key;
   480           val bound = map (deref tye) raw_bound;
   477           val bound = map (Type_Infer.deref tye) raw_bound;
   481           val not_params = filter_out is_paramT bound;
   478           val not_params = filter_out Type_Infer.is_paramT bound;
   482           fun to_fulfil T =
   479           fun to_fulfil T =
   483             (case sort_of T of
   480             (case sort_of T of
   484               NONE => NONE
   481               NONE => NONE
   485             | SOME S =>
   482             | SOME S =>
   486                 SOME (map nameT (filter_out is_paramT (map (deref tye) (get_bound G T))), S));
   483                 SOME
       
   484                   (map nameT
       
   485                     (filter_out Type_Infer.is_paramT (map (Type_Infer.deref tye) (get_bound G T))),
       
   486                       S));
   487           val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
   487           val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
   488           val assignment =
   488           val assignment =
   489             if null bound orelse null not_params then NONE
   489             if null bound orelse null not_params then NONE
   490             else SOME (tightest lower S styps_and_sorts (map nameT not_params)
   490             else SOME (tightest lower S styps_and_sorts (map nameT not_params)
   491                 handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key))
   491                 handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key))
   492         in
   492         in
   493           (case assignment of
   493           (case assignment of
   494             NONE => tye_idx
   494             NONE => tye_idx
   495           | SOME T =>
   495           | SOME T =>
   496               if is_paramT T then tye_idx
   496               if Type_Infer.is_paramT T then tye_idx
   497               else if lower then (*upper bound check*)
   497               else if lower then (*upper bound check*)
   498                 let
   498                 let
   499                   val other_bound = map (deref tye) (get_succs G key);
   499                   val other_bound = map (Type_Infer.deref tye) (get_succs G key);
   500                   val s = nameT T;
   500                   val s = nameT T;
   501                 in
   501                 in
   502                   if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
   502                   if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
   503                   then apfst (Vartab.update (xi, T)) tye_idx
   503                   then apfst (Vartab.update (xi, T)) tye_idx
   504                   else err_bound ctxt ("Assigned simple type " ^ s ^
   504                   else err_bound ctxt ("Assigned simple type " ^ s ^
   517       else
   517       else
   518         let
   518         let
   519           val (tye_idx' as (tye, _)) = fold (assign_lb G) ts tye_idx
   519           val (tye_idx' as (tye, _)) = fold (assign_lb G) ts tye_idx
   520             |> fold (assign_ub G) ts;
   520             |> fold (assign_ub G) ts;
   521         in
   521         in
   522           assign_alternating ts (filter (is_paramT o deref tye) ts) G tye_idx'
   522           assign_alternating ts (filter (Type_Infer.is_paramT o Type_Infer.deref tye) ts) G tye_idx'
   523         end;
   523         end;
   524 
   524 
   525     (*Unify all weakly connected components of the constraint forest,
   525     (*Unify all weakly connected components of the constraint forest,
   526       that contain only params. These are the only WCCs that contain
   526       that contain only params. These are the only WCCs that contain
   527       params anyway.*)
   527       params anyway.*)
   528     fun unify_params G (tye_idx as (tye, _)) =
   528     fun unify_params G (tye_idx as (tye, _)) =
   529       let
   529       let
   530         val max_params = filter (is_paramT o deref tye) (Typ_Graph.maximals G);
   530         val max_params =
       
   531           filter (Type_Infer.is_paramT o Type_Infer.deref tye) (Typ_Graph.maximals G);
   531         val to_unify = map (fn T => T :: get_preds G T) max_params;
   532         val to_unify = map (fn T => T :: get_preds G T) max_params;
   532       in
   533       in
   533         fold unify_list to_unify tye_idx
   534         fold unify_list to_unify tye_idx
   534       end;
   535       end;
   535 
   536 
   546 (** coercion insertion **)
   547 (** coercion insertion **)
   547 
   548 
   548 fun insert_coercions ctxt tye ts =
   549 fun insert_coercions ctxt tye ts =
   549   let
   550   let
   550     fun deep_deref T =
   551     fun deep_deref T =
   551       (case deref tye T of
   552       (case Type_Infer.deref tye T of
   552         Type (a, Ts) => Type (a, map deep_deref Ts)
   553         Type (a, Ts) => Type (a, map deep_deref Ts)
   553       | U => U);
   554       | U => U);
   554 
   555 
   555     fun gen_coercion ((Type (a, [])), (Type (b, []))) =
   556     fun gen_coercion ((Type (a, [])), (Type (b, []))) =
   556           if a = b
   557           if a = b