src/Pure/type_infer.ML
changeset 39290 44e4d8dfd6bf
parent 39289 92b50c8bb67b
child 39291 4b632bb847a8
     1.1 --- a/src/Pure/type_infer.ML	Sun Sep 12 19:55:45 2010 +0200
     1.2 +++ b/src/Pure/type_infer.ML	Sun Sep 12 20:47:47 2010 +0200
     1.3 @@ -7,14 +7,12 @@
     1.4  signature TYPE_INFER =
     1.5  sig
     1.6    val anyT: sort -> typ
     1.7 -  val polymorphicT: typ -> typ
     1.8    val is_param: indexname -> bool
     1.9    val param: int -> string * sort -> typ
    1.10    val paramify_vars: typ -> typ
    1.11    val paramify_dummies: typ -> int -> typ * int
    1.12    val fixate_params: Name.context -> term list -> term list
    1.13 -  val infer_types: Pretty.pp -> Type.tsig -> (typ list -> typ list) ->
    1.14 -    (string -> typ option) -> (indexname -> typ option) -> Name.context -> int ->
    1.15 +  val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
    1.16      term list -> term list
    1.17  end;
    1.18  
    1.19 @@ -26,9 +24,6 @@
    1.20  
    1.21  fun anyT S = TFree ("'_dummy_", S);
    1.22  
    1.23 -(*indicate polymorphic Vars*)
    1.24 -fun polymorphicT T = Type ("_polymorphic_", [T]);
    1.25 -
    1.26  
    1.27  (* type inference parameters -- may get instantiated *)
    1.28  
    1.29 @@ -235,12 +230,14 @@
    1.30  
    1.31  (* typs_terms_of *)
    1.32  
    1.33 -fun typs_terms_of tye used maxidx (Ts, ts) =
    1.34 +fun typs_terms_of ctxt tye (Ts, ts) =
    1.35    let
    1.36 -    val used' = fold (add_names tye) ts (fold (add_namesT tye) Ts used);
    1.37 +    val used = fold (add_names tye) ts (fold (add_namesT tye) Ts (Variable.names_of ctxt));
    1.38      val parms = rev (fold (add_parms tye) ts (fold (add_parmsT tye) Ts []));
    1.39 -    val names = Name.invents used' ("?" ^ Name.aT) (length parms);
    1.40 +    val names = Name.invents used ("?" ^ Name.aT) (length parms);
    1.41      val tab = Inttab.make (parms ~~ names);
    1.42 +
    1.43 +    val maxidx = Variable.maxidx_of ctxt;
    1.44      fun f i = (the (Inttab.lookup tab i), maxidx + 1);
    1.45    in (map (simple_typ_of tye f) Ts, map (simple_term_of tye f) ts) end;
    1.46  
    1.47 @@ -250,27 +247,31 @@
    1.48  
    1.49  exception NO_UNIFIER of string * pretyp Inttab.table;
    1.50  
    1.51 -fun unify pp tsig =
    1.52 +fun unify ctxt pp =
    1.53    let
    1.54 +    val thy = ProofContext.theory_of ctxt;
    1.55 +    val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);
    1.56 +
    1.57 +
    1.58      (* adjust sorts of parameters *)
    1.59  
    1.60      fun not_of_sort x S' S =
    1.61 -      "Variable " ^ x ^ "::" ^ Pretty.string_of_sort pp S' ^ " not of sort " ^
    1.62 -        Pretty.string_of_sort pp S;
    1.63 +      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
    1.64 +        Syntax.string_of_sort ctxt S;
    1.65  
    1.66      fun meet (_, []) tye_idx = tye_idx
    1.67        | meet (Param (i, S'), S) (tye_idx as (tye, idx)) =
    1.68 -          if Type.subsort tsig (S', S) then tye_idx
    1.69 +          if Sign.subsort thy (S', S) then tye_idx
    1.70            else (Inttab.update_new (i,
    1.71 -            Param (idx, Type.inter_sort tsig (S', S))) tye, idx + 1)
    1.72 +            Param (idx, Sign.inter_sort thy (S', S))) tye, idx + 1)
    1.73        | meet (PType (a, Ts), S) (tye_idx as (tye, _)) =
    1.74 -          meets (Ts, Type.arity_sorts pp tsig a S
    1.75 +          meets (Ts, arity_sorts a S
    1.76              handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
    1.77        | meet (PTFree (x, S'), S) (tye_idx as (tye, _)) =
    1.78 -          if Type.subsort tsig (S', S) then tye_idx
    1.79 +          if Sign.subsort thy (S', S) then tye_idx
    1.80            else raise NO_UNIFIER (not_of_sort x S' S, tye)
    1.81        | meet (PTVar (xi, S'), S) (tye_idx as (tye, _)) =
    1.82 -          if Type.subsort tsig (S', S) then tye_idx
    1.83 +          if Sign.subsort thy (S', S) then tye_idx
    1.84            else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
    1.85      and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
    1.86            meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
    1.87 @@ -298,7 +299,7 @@
    1.88      (* unification *)
    1.89  
    1.90      fun show_tycon (a, Ts) =
    1.91 -      quote (Pretty.string_of_typ pp (Type (a, replicate (length Ts) dummyT)));
    1.92 +      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
    1.93  
    1.94      fun unif (T1, T2) (tye_idx as (tye, idx)) =
    1.95        (case (deref tye T1, deref tye T2) of
    1.96 @@ -319,13 +320,16 @@
    1.97  
    1.98  (* infer *)
    1.99  
   1.100 -fun infer pp tsig =
   1.101 +fun infer ctxt =
   1.102    let
   1.103 +    val pp = Syntax.pp ctxt;
   1.104 +
   1.105 +
   1.106      (* errors *)
   1.107  
   1.108      fun prep_output tye bs ts Ts =
   1.109        let
   1.110 -        val (Ts_bTs', ts') = typs_terms_of tye Name.context ~1 (Ts @ map snd bs, ts);
   1.111 +        val (Ts_bTs', ts') = typs_terms_of ctxt tye (Ts @ map snd bs, ts);
   1.112          val (Ts', Ts'') = chop (length Ts) Ts_bTs';
   1.113          fun prep t =
   1.114            let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
   1.115 @@ -355,8 +359,6 @@
   1.116  
   1.117      (* main *)
   1.118  
   1.119 -    val unif = unify pp tsig;
   1.120 -
   1.121      fun inf _ (PConst (_, T)) tye_idx = (T, tye_idx)
   1.122        | inf _ (PFree (_, T)) tye_idx = (T, tye_idx)
   1.123        | inf _ (PVar (_, T)) tye_idx = (T, tye_idx)
   1.124 @@ -371,13 +373,13 @@
   1.125              val (U, (tye, idx)) = inf bs u tye_idx';
   1.126              val V = Param (idx, []);
   1.127              val U_to_V = PType ("fun", [U, V]);
   1.128 -            val tye_idx'' = unif (U_to_V, T) (tye, idx + 1)
   1.129 +            val tye_idx'' = unify ctxt pp (U_to_V, T) (tye, idx + 1)
   1.130                handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
   1.131            in (V, tye_idx'') end
   1.132        | inf bs (Constraint (t, U)) tye_idx =
   1.133            let val (T, tye_idx') = inf bs t tye_idx in
   1.134              (T,
   1.135 -             unif (T, U) tye_idx'
   1.136 +             unify ctxt pp (T, U) tye_idx'
   1.137                 handle NO_UNIFIER (msg, tye) => err_constraint msg tye bs t T U)
   1.138            end;
   1.139  
   1.140 @@ -386,7 +388,7 @@
   1.141  
   1.142  (* infer_types *)
   1.143  
   1.144 -fun infer_types pp tsig check_typs const_type var_type used maxidx raw_ts =
   1.145 +fun infer_types ctxt const_type var_type raw_ts =
   1.146    let
   1.147      (*constrain vars*)
   1.148      val get_type = the_default dummyT o var_type;
   1.149 @@ -396,13 +398,13 @@
   1.150          | t => t);
   1.151  
   1.152      (*convert to preterms*)
   1.153 -    val ts = burrow_types check_typs raw_ts;
   1.154 +    val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
   1.155      val (ts', (_, _, idx)) =
   1.156        fold_map (preterm_of const_type o constrain_vars) ts
   1.157        (Vartab.empty, Vartab.empty, 0);
   1.158  
   1.159      (*do type inference*)
   1.160 -    val (tye, _) = fold (snd oo infer pp tsig) ts' (Inttab.empty, idx);
   1.161 -  in #2 (typs_terms_of tye used maxidx ([], ts')) end;
   1.162 +    val (tye, _) = fold (snd oo infer ctxt) ts' (Inttab.empty, idx);
   1.163 +  in #2 (typs_terms_of ctxt tye ([], ts')) end;
   1.164  
   1.165  end;