src/Pure/type_infer.ML
changeset 2979 db6941221197
parent 2957 d35fca99b3be
child 2980 98ad57d99427
equal deleted inserted replaced
2978:83a4c4f79dcd 2979:db6941221197
     5 Type inference.
     5 Type inference.
     6 *)
     6 *)
     7 
     7 
     8 signature TYPE_INFER =
     8 signature TYPE_INFER =
     9 sig
     9 sig
    10   val infer_types: (string -> typ option) -> Sorts.classrel -> Sorts.arities
    10   val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T)
       
    11     -> (string -> typ option) -> Sorts.classrel -> Sorts.arities
    11     -> string list -> bool -> (indexname -> bool) -> term list -> typ list
    12     -> string list -> bool -> (indexname -> bool) -> term list -> typ list
    12     -> term list * typ list * (indexname * typ) list
    13     -> term list * typ list * (indexname * typ) list
    13 end;
    14 end;
    14 
    15 
    15 structure TypeInfer: TYPE_INFER =
    16 structure TypeInfer: TYPE_INFER =
   255 
   256 
   256     (* adjust sorts of parameters *)
   257     (* adjust sorts of parameters *)
   257 
   258 
   258     fun not_in_sort x S' S =
   259     fun not_in_sort x S' S =
   259       "Type variable " ^ x ^ "::" ^ Sorts.str_of_sort S' ^ " not in sort " ^
   260       "Type variable " ^ x ^ "::" ^ Sorts.str_of_sort S' ^ " not in sort " ^
   260         Sorts.str_of_sort S;
   261         Sorts.str_of_sort S ^ ".";
   261 
   262 
   262     fun meet _ [] = ()
   263     fun meet _ [] = ()
   263       | meet (Link (r as (ref (Param S')))) S =
   264       | meet (Link (r as (ref (Param S')))) S =
   264           if Sorts.sort_le classrel (S', S) then ()
   265           if Sorts.sort_le classrel (S', S) then ()
   265           else r := mk_param (Sorts.inter_sort classrel (S', S))
   266           else r := mk_param (Sorts.inter_sort classrel (S', S))
   296     fun unif (Link (r as ref (Param S))) T = assign r T S
   297     fun unif (Link (r as ref (Param S))) T = assign r T S
   297       | unif T (Link (r as ref (Param S))) = assign r T S
   298       | unif T (Link (r as ref (Param S))) = assign r T S
   298       | unif (Link (ref T)) U = unif T U
   299       | unif (Link (ref T)) U = unif T U
   299       | unif T (Link (ref U)) = unif T U
   300       | unif T (Link (ref U)) = unif T U
   300       | unif (PType (a, Ts)) (PType (b, Us)) =
   301       | unif (PType (a, Ts)) (PType (b, Us)) =
   301           if a <> b then raise NO_UNIFIER ("Clash of " ^ a ^ ", " ^ b ^ "!")
   302           if a <> b then
       
   303             raise NO_UNIFIER ("Clash of types " ^ quote a ^ " and " ^ quote b ^ ".")
   302           else seq2 unif Ts Us
   304           else seq2 unif Ts Us
   303       | unif T U = if T = U then () else raise NO_UNIFIER "Unification failed!";
   305       | unif T U = if T = U then () else raise NO_UNIFIER "";
   304 
   306 
   305   in unif end;
   307   in unif end;
   306 
   308 
   307 
   309 
   308 
   310 
   309 (** type inference **)
   311 (** type inference **)
   310 
   312 
   311 (* infer *)                                     (*DESTRUCTIVE*)
   313 (* infer *)                                     (*DESTRUCTIVE*)
   312 
   314 
   313 fun infer classrel arities =
   315 fun infer prt prT classrel arities =
   314   let
   316   let
   315     val unif = unify classrel arities;
   317     (* errors *)
   316 
   318 
   317     fun err msg1 msg2 bs ts Ts =
   319     fun unif_failed msg =
       
   320       "Type unification failed" ^ (if msg = "" then "." else ": " ^ msg) ^ "\n";
       
   321 
       
   322     val str_of = Pretty.string_of;
       
   323 
       
   324     fun prep_output bs ts Ts =
   318       let
   325       let
   319         val (Ts_bTs', ts') = typs_terms_of [] PTFree "??" (Ts @ map snd bs, ts);
   326         val (Ts_bTs', ts') = typs_terms_of [] PTFree "??" (Ts @ map snd bs, ts);
   320         val len = length Ts;
   327         val len = length Ts;
   321         val Ts' = take (len, Ts_bTs');
   328         val Ts' = take (len, Ts_bTs');
   322         val xs = map Free (map fst bs ~~ drop (len, Ts_bTs'));
   329         val xs = map Free (map fst bs ~~ drop (len, Ts_bTs'));
   323         val ts'' = map (fn t => subst_bounds (xs, t)) ts';
   330         val ts'' = map (fn t => subst_bounds (xs, t)) ts';
   324       in
   331       in (ts'', Ts') end;
   325         raise_type (msg1 ^ " " ^ msg2) Ts' ts''
   332 
   326       end;
   333     fun err_loose i =
       
   334       raise_type ("Loose bound variable: B." ^ string_of_int i) [] [];
       
   335 
       
   336     fun err_appl msg bs t T U_to_V u U =
       
   337       let
       
   338         val ([t', u'], [T', U_to_V', U']) = prep_output bs [t, u] [T, U_to_V, U];
       
   339         val text = cat_lines
       
   340          [unif_failed msg,
       
   341           "Type error in application:",
       
   342           "",
       
   343           str_of (Pretty.block [Pretty.str "operator:     ", Pretty.brk 1, prt t',
       
   344             Pretty.str " :: ", prT T']),
       
   345           str_of (Pretty.block [Pretty.str "expected type:", Pretty.brk 1, prT U_to_V']),
       
   346           "",
       
   347           str_of (Pretty.block [Pretty.str "operand:      ", Pretty.brk 1, prt u',
       
   348             Pretty.str " :: ", prT U']), ""];
       
   349       in raise_type text [T', U_to_V', U'] [t', u'] end;
       
   350 
       
   351     fun err_constraint msg bs t T U =
       
   352       let
       
   353         val ([t'], [T', U']) = prep_output bs [t] [T, U];
       
   354         val text = cat_lines
       
   355          [unif_failed msg,
       
   356           "Cannot meet type constraint:",
       
   357           "",
       
   358           str_of (Pretty.block [Pretty.str "term:          ", Pretty.brk 1, prt t',
       
   359             Pretty.str " :: ", prT T']),
       
   360           str_of (Pretty.block [Pretty.str "expected type: ", Pretty.brk 1, prT U']), ""];
       
   361       in raise_type text [T', U'] [t'] end;
       
   362 
       
   363 
       
   364     (* main *)
       
   365 
       
   366     val unif = unify classrel arities;
   327 
   367 
   328     fun inf _ (PConst (_, T)) = T
   368     fun inf _ (PConst (_, T)) = T
   329       | inf _ (PFree (_, T)) = T
   369       | inf _ (PFree (_, T)) = T
   330       | inf _ (PVar (_, T)) = T
   370       | inf _ (PVar (_, T)) = T
   331       | inf bs (PBound i) = snd (nth_elem (i, bs)
   371       | inf bs (PBound i) = snd (nth_elem (i, bs) handle LIST _ => err_loose i)
   332           handle LIST _ => raise_type "Loose bound variable" [] [Bound i])
       
   333       | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t])
   372       | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t])
   334       | inf bs (PAppl (t, u)) =
   373       | inf bs (PAppl (t, u)) =
   335           let
   374           let
   336             val T = inf bs t;
   375             val T = inf bs t;
   337             val U = inf bs u;
   376             val U = inf bs u;
   338             val V = mk_param [];
   377             val V = mk_param [];
   339             val U_to_V = PType ("fun", [U, V]);
   378             val U_to_V = PType ("fun", [U, V]);
   340             val _ = unif U_to_V T handle NO_UNIFIER msg =>
   379             val _ = unif U_to_V T handle NO_UNIFIER msg =>
   341               err msg "Bad function application." bs [PAppl (t, u)] [U_to_V, U];
   380               err_appl msg bs t T U_to_V u U;
   342           in V end
   381           in V end
   343       | inf bs (Constraint (t, U)) =
   382       | inf bs (Constraint (t, U)) =
   344           let val T = inf bs t in
   383           let val T = inf bs t in
   345             unif T U handle NO_UNIFIER msg =>
   384             unif T U handle NO_UNIFIER msg => err_constraint msg bs t T U;
   346               err msg "Cannot meet type constraint." bs [t] [T, U];
       
   347             T
   385             T
   348           end;
   386           end;
   349 
   387 
   350   in inf [] end;
   388   in inf [] end;
   351 
   389 
   352 
   390 
   353 (* infer_types *)
   391 (* infer_types *)
   354 
   392 
   355 fun infer_types const_type classrel arities used freeze is_param ts Ts =
   393 fun infer_types prt prT const_type classrel arities used freeze is_param ts Ts =
   356   let
   394   let
   357     (*convert to preterms/typs*)
   395     (*convert to preterms/typs*)
   358     val (Tps, Ts') = pretyps_of (K true) ([], Ts);
   396     val (Tps, Ts') = pretyps_of (K true) ([], Ts);
   359     val ((vps, ps), ts') = preterms_of const_type is_param (([], Tps), ts);
   397     val ((vps, ps), ts') = preterms_of const_type is_param (([], Tps), ts);
   360 
   398 
   361     (*run type inference*)
   399     (*run type inference*)
   362     val tTs' = ListPair.map Constraint (ts', Ts');
   400     val tTs' = ListPair.map Constraint (ts', Ts');
   363     val _ = seq (fn t => (infer classrel arities t; ())) tTs';
   401     val _ = seq (fn t => (infer prt prT classrel arities t; ())) tTs';
   364 
   402 
   365     (*collect result unifier*)
   403     (*collect result unifier*)
   366     fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); None)
   404     fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); None)
   367       | ch_var xi_T = Some xi_T;
   405       | ch_var xi_T = Some xi_T;
   368     val env = mapfilter ch_var Tps;
   406     val env = mapfilter ch_var Tps;