src/Pure/type_infer.ML
changeset 15570 8d8c70b41bab
parent 15531 08c8dad8e399
child 16195 0eb3c15298cd
equal deleted inserted replaced
15569:1b3115d1a8df 15570:8d8c70b41bab
   112   let
   112   let
   113     fun add_parms (ps, TVar (xi as (x, _), S)) =
   113     fun add_parms (ps, TVar (xi as (x, _), S)) =
   114           if is_param xi andalso is_none (assoc (ps, xi))
   114           if is_param xi andalso is_none (assoc (ps, xi))
   115           then (xi, mk_param S) :: ps else ps
   115           then (xi, mk_param S) :: ps else ps
   116       | add_parms (ps, TFree _) = ps
   116       | add_parms (ps, TFree _) = ps
   117       | add_parms (ps, Type (_, Ts)) = foldl add_parms (ps, Ts);
   117       | add_parms (ps, Type (_, Ts)) = Library.foldl add_parms (ps, Ts);
   118 
   118 
   119     val params' = add_parms (params, typ);
   119     val params' = add_parms (params, typ);
   120 
   120 
   121     fun pre_of (TVar (v as (xi, _))) =
   121     fun pre_of (TVar (v as (xi, _))) =
   122           (case assoc (params', xi) of
   122           (case assoc (params', xi) of
   147       | add_vparms (ps, Abs (_, _, t)) = add_vparms (ps, t)
   147       | add_vparms (ps, Abs (_, _, t)) = add_vparms (ps, t)
   148       | add_vparms (ps, t $ u) = add_vparms (add_vparms (ps, t), u)
   148       | add_vparms (ps, t $ u) = add_vparms (add_vparms (ps, t), u)
   149       | add_vparms (ps, _) = ps;
   149       | add_vparms (ps, _) = ps;
   150 
   150 
   151     val vparams' = add_vparms (vparams, tm);
   151     val vparams' = add_vparms (vparams, tm);
   152     fun var_param xi = the (assoc (vparams', xi));
   152     fun var_param xi = valOf (assoc (vparams', xi));
   153 
   153 
   154 
   154 
   155     val preT_of = pretyp_of is_param;
   155     val preT_of = pretyp_of is_param;
   156 
   156 
   157     fun constrain (ps, t) T =
   157     fun constrain (ps, t) T =
   190 
   190 
   191 (** pretyps/terms to typs/terms **)
   191 (** pretyps/terms to typs/terms **)
   192 
   192 
   193 (* add_parms *)
   193 (* add_parms *)
   194 
   194 
   195 fun add_parmsT (rs, PType (_, Ts)) = foldl add_parmsT (rs, Ts)
   195 fun add_parmsT (rs, PType (_, Ts)) = Library.foldl add_parmsT (rs, Ts)
   196   | add_parmsT (rs, Link (r as ref (Param _))) = r ins rs
   196   | add_parmsT (rs, Link (r as ref (Param _))) = r ins rs
   197   | add_parmsT (rs, Link (ref T)) = add_parmsT (rs, T)
   197   | add_parmsT (rs, Link (ref T)) = add_parmsT (rs, T)
   198   | add_parmsT (rs, _) = rs;
   198   | add_parmsT (rs, _) = rs;
   199 
   199 
   200 val add_parms = foldl_pretyps add_parmsT;
   200 val add_parms = foldl_pretyps add_parmsT;
   201 
   201 
   202 
   202 
   203 (* add_names *)
   203 (* add_names *)
   204 
   204 
   205 fun add_namesT (xs, PType (_, Ts)) = foldl add_namesT (xs, Ts)
   205 fun add_namesT (xs, PType (_, Ts)) = Library.foldl add_namesT (xs, Ts)
   206   | add_namesT (xs, PTFree (x, _)) = x ins xs
   206   | add_namesT (xs, PTFree (x, _)) = x ins xs
   207   | add_namesT (xs, PTVar ((x, _), _)) = x ins xs
   207   | add_namesT (xs, PTVar ((x, _), _)) = x ins xs
   208   | add_namesT (xs, Link (ref T)) = add_namesT (xs, T)
   208   | add_namesT (xs, Link (ref T)) = add_namesT (xs, T)
   209   | add_namesT (xs, Param _) = xs;
   209   | add_namesT (xs, Param _) = xs;
   210 
   210 
   235 fun typs_terms_of used mk_var prfx (Ts, ts) =
   235 fun typs_terms_of used mk_var prfx (Ts, ts) =
   236   let
   236   let
   237     fun elim (r as ref (Param S), x) = r := mk_var (x, S)
   237     fun elim (r as ref (Param S), x) = r := mk_var (x, S)
   238       | elim _ = ();
   238       | elim _ = ();
   239 
   239 
   240     val used' = foldl add_names (foldl add_namesT (used, Ts), ts);
   240     val used' = Library.foldl add_names (Library.foldl add_namesT (used, Ts), ts);
   241     val parms = rev (foldl add_parms (foldl add_parmsT ([], Ts), ts));
   241     val parms = rev (Library.foldl add_parms (Library.foldl add_parmsT ([], Ts), ts));
   242     val names = Term.invent_names used' (prfx ^ "'a") (length parms);
   242     val names = Term.invent_names used' (prfx ^ "'a") (length parms);
   243   in
   243   in
   244     seq2 elim (parms, names);
   244     seq2 elim (parms, names);
   245     (map simple_typ_of Ts, map simple_term_of ts)
   245     (map simple_typ_of Ts, map simple_term_of ts)
   246   end;
   246   end;
   283     (* occurs check and assigment *)
   283     (* occurs check and assigment *)
   284 
   284 
   285     fun occurs_check r (Link (r' as ref T)) =
   285     fun occurs_check r (Link (r' as ref T)) =
   286           if r = r' then raise NO_UNIFIER "Occurs check!"
   286           if r = r' then raise NO_UNIFIER "Occurs check!"
   287           else occurs_check r T
   287           else occurs_check r T
   288       | occurs_check r (PType (_, Ts)) = seq (occurs_check r) Ts
   288       | occurs_check r (PType (_, Ts)) = List.app (occurs_check r) Ts
   289       | occurs_check _ _ = ();
   289       | occurs_check _ _ = ();
   290 
   290 
   291     fun assign r T S =
   291     fun assign r T S =
   292       (case deref T of
   292       (case deref T of
   293         T' as Link (r' as ref (Param _)) =>
   293         T' as Link (r' as ref (Param _)) =>
   374     val unif = unify pp classes arities;
   374     val unif = unify pp classes arities;
   375 
   375 
   376     fun inf _ (PConst (_, T)) = T
   376     fun inf _ (PConst (_, T)) = T
   377       | inf _ (PFree (_, T)) = T
   377       | inf _ (PFree (_, T)) = T
   378       | inf _ (PVar (_, T)) = T
   378       | inf _ (PVar (_, T)) = T
   379       | inf bs (PBound i) = snd (nth_elem (i, bs) handle LIST _ => err_loose i)
   379       | inf bs (PBound i) = snd (List.nth (bs, i) handle Subscript => err_loose i)
   380       | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t])
   380       | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t])
   381       | inf bs (PAppl (t, u)) =
   381       | inf bs (PAppl (t, u)) =
   382           let
   382           let
   383             val T = inf bs t;
   383             val T = inf bs t;
   384             val U = inf bs u;
   384             val U = inf bs u;
   403     val (Tps, Ts') = pretyps_of (K true) ([], Ts);
   403     val (Tps, Ts') = pretyps_of (K true) ([], Ts);
   404     val ((vps, ps), ts') = preterms_of const_type is_param (([], Tps), ts);
   404     val ((vps, ps), ts') = preterms_of const_type is_param (([], Tps), ts);
   405 
   405 
   406     (*run type inference*)
   406     (*run type inference*)
   407     val tTs' = ListPair.map Constraint (ts', Ts');
   407     val tTs' = ListPair.map Constraint (ts', Ts');
   408     val _ = seq (fn t => (infer pp classes arities t; ())) tTs';
   408     val _ = List.app (fn t => (infer pp classes arities t; ())) tTs';
   409 
   409 
   410     (*collect result unifier*)
   410     (*collect result unifier*)
   411     fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); NONE)
   411     fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); NONE)
   412       | ch_var xi_T = SOME xi_T;
   412       | ch_var xi_T = SOME xi_T;
   413     val env = mapfilter ch_var Tps;
   413     val env = List.mapPartial ch_var Tps;
   414 
   414 
   415     (*convert back to terms/typs*)
   415     (*convert back to terms/typs*)
   416     val mk_var =
   416     val mk_var =
   417       if freeze then PTFree
   417       if freeze then PTFree
   418       else (fn (x, S) => PTVar ((x, 0), S));
   418       else (fn (x, S) => PTVar ((x, 0), S));
   470 
   470 
   471 (* decode_types -- transform parse tree into raw term *)
   471 (* decode_types -- transform parse tree into raw term *)
   472 
   472 
   473 fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm =
   473 fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm =
   474   let
   474   let
   475     fun get_type xi = if_none (def_type xi) dummyT;
   475     fun get_type xi = getOpt (def_type xi, dummyT);
   476     fun is_free x = is_some (def_type (x, ~1));
   476     fun is_free x = isSome (def_type (x, ~1));
   477     val raw_env = Syntax.raw_term_sorts tm;
   477     val raw_env = Syntax.raw_term_sorts tm;
   478     val sort_of = get_sort tsig def_sort map_sort raw_env;
   478     val sort_of = get_sort tsig def_sort map_sort raw_env;
   479 
   479 
   480     val certT = Type.cert_typ tsig o map_type;
   480     val certT = Type.cert_typ tsig o map_type;
   481     fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t);
   481     fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t);
   517 fun infer_types pp tsig const_type def_type def_sort
   517 fun infer_types pp tsig const_type def_type def_sort
   518     map_const map_type map_sort used freeze pat_Ts raw_ts =
   518     map_const map_type map_sort used freeze pat_Ts raw_ts =
   519   let
   519   let
   520     val {classes, arities, ...} = Type.rep_tsig tsig;
   520     val {classes, arities, ...} = Type.rep_tsig tsig;
   521     val pat_Ts' = map (Type.cert_typ tsig) pat_Ts;
   521     val pat_Ts' = map (Type.cert_typ tsig) pat_Ts;
   522     val is_const = is_some o const_type;
   522     val is_const = isSome o const_type;
   523     val raw_ts' =
   523     val raw_ts' =
   524       map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts;
   524       map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts;
   525     val (ts, Ts, unifier) = basic_infer_types pp const_type
   525     val (ts, Ts, unifier) = basic_infer_types pp const_type
   526       classes arities used freeze is_param raw_ts' pat_Ts';
   526       classes arities used freeze is_param raw_ts' pat_Ts';
   527   in (ts, unifier) end;
   527   in (ts, unifier) end;