improved type check error messages;
authorwenzelm
Thu, 17 Apr 1997 18:46:58 +0200
changeset 2979 db6941221197
parent 2978 83a4c4f79dcd
child 2980 98ad57d99427
improved type check error messages;
src/Pure/sign.ML
src/Pure/theory.ML
src/Pure/thm.ML
src/Pure/type.ML
src/Pure/type_infer.ML
--- a/src/Pure/sign.ML	Thu Apr 17 18:45:43 1997 +0200
+++ b/src/Pure/sign.ML	Thu Apr 17 18:46:58 1997 +0200
@@ -36,7 +36,6 @@
   val certify_typ: sg -> typ -> typ
   val certify_term: sg -> term -> term * typ * int
   val read_typ: sg * (indexname -> sort option) -> string -> typ
-  val exn_type_msg: sg -> string * typ list * term list -> string
   val infer_types: sg -> (indexname -> typ option) ->
     (indexname -> sort option) -> string list -> bool
     -> term list * typ -> int * term * (indexname * typ) list
@@ -77,6 +76,7 @@
 structure Sign : SIGN =
 struct
 
+
 (** datatype sg **)
 
 (*the "ref" in stamps ensures that no two signatures are identical -- it is
@@ -93,9 +93,8 @@
 val tsig_of = #tsig o rep_sg;
 
 
-(* stamps *)
+(* inclusion and equality *)
 
-(*inclusion, equality*)
 local
   (*avoiding polymorphic equality: factor 10 speedup*)
   fun mem_stamp (_:string ref, []) = false
@@ -136,8 +135,7 @@
 
 (* consts *)
 
-fun const_type (Sg {const_tab, ...}) c =
-  Symtab.lookup (const_tab, c);
+fun const_type (Sg {const_tab, ...}) c = Symtab.lookup (const_tab, c);
 
 
 (* classes and sorts *)
@@ -148,6 +146,7 @@
 val norm_sort = Type.norm_sort o tsig_of;
 val nonempty_sort = Type.nonempty_sort o tsig_of;
 
+(* FIXME move to Sorts? *)
 fun pretty_sort [c] = Pretty.str c
   | pretty_sort cs = Pretty.str_list "{" "}" cs;
 
@@ -193,15 +192,14 @@
   in
     Pretty.writeln (Pretty.strs ("stamps:" :: stamp_names stamps));
     Pretty.writeln (Pretty.strs ("classes:" :: classes));
-    Pretty.writeln (Pretty.big_list "classrel:"
-                      (map pretty_classrel classrel));
+    Pretty.writeln (Pretty.big_list "classrel:" (map pretty_classrel classrel));
     Pretty.writeln (pretty_default default);
     Pretty.writeln (Pretty.big_list "types:" (map pretty_ty tycons));
     Pretty.writeln (Pretty.big_list "abbrs:" (map (pretty_abbr syn) abbrs));
     Pretty.writeln (Pretty.big_list "arities:"
-                      (List.concat (map pretty_arities arities)));
+      (List.concat (map pretty_arities arities)));
     Pretty.writeln (Pretty.big_list "consts:"
-                      (map (pretty_const syn) (Symtab.dest const_tab)))
+      (map (pretty_const syn) (Symtab.dest const_tab)))
   end;
 
 
@@ -251,44 +249,51 @@
 
 fun certify_typ (Sg {tsig, ...}) ty = Type.cert_typ tsig ty;
 
-(* check for duplicate TVars with distinct sorts *)
-fun nodup_TVars(tvars,T) = (case T of
-      Type(_,Ts) => nodup_TVars_list (tvars,Ts)
-    | TFree _ => tvars
-    | TVar(v as (a,S)) =>
-        (case assoc_string_int(tvars,a) of
-           Some(S') => if S=S' then tvars
-                       else raise_type
-                            ("Type variable "^Syntax.string_of_vname a^
-                             " has two distinct sorts") [TVar(a,S'),T] []
-         | None => v::tvars))
-and (*equivalent to foldl nodup_TVars_list, but 3X faster under Poly/ML*)
-    nodup_TVars_list (tvars,[]) = tvars
-  | nodup_TVars_list (tvars,T::Ts) = nodup_TVars_list(nodup_TVars(tvars,T), 
-						      Ts);
+(*check for duplicate TVars with distinct sorts*)
+fun nodup_TVars (tvars, T) =
+  (case T of
+    Type (_, Ts) => nodup_TVars_list (tvars, Ts)
+  | TFree _ => tvars
+  | TVar (v as (a, S)) =>
+      (case assoc_string_int (tvars, a) of
+        Some S' =>
+          if S = S' then tvars
+          else raise_type ("Type variable " ^ Syntax.string_of_vname a ^
+            " has two distinct sorts") [TVar (a, S'), T] []
+      | None => v :: tvars))
+(*equivalent to foldl nodup_TVars_list, but 3X faster under Poly/ML*)
+and nodup_TVars_list (tvars, []) = tvars
+  | nodup_TVars_list (tvars, T :: Ts) =
+      nodup_TVars_list (nodup_TVars (tvars, T), Ts);
 
-(* check for duplicate Vars with distinct types *)
+(*check for duplicate Vars with distinct types*)
 fun nodup_Vars tm =
-let fun nodups vars tvars tm = (case tm of
-          Const(c,T) => (vars, nodup_TVars(tvars,T))
-        | Free(a,T) => (vars, nodup_TVars(tvars,T))
-        | Var(v as (ixn,T)) =>
-            (case assoc_string_int(vars,ixn) of
-               Some(T') => if T=T' then (vars,nodup_TVars(tvars,T))
-                           else raise_type
-                             ("Variable "^Syntax.string_of_vname ixn^
-                              " has two distinct types") [T',T] []
-             | None => (v::vars,tvars))
-        | Bound _ => (vars,tvars)
-        | Abs(_,T,t) => nodups vars (nodup_TVars(tvars,T)) t
-        | s$t => let val (vars',tvars') = nodups vars tvars s
-                 in nodups vars' tvars' t end);
-in nodups [] [] tm; () end;
+  let
+    fun nodups vars tvars tm =
+      (case tm of
+        Const (c, T) => (vars, nodup_TVars (tvars, T))
+      | Free (a, T) => (vars, nodup_TVars (tvars, T))
+      | Var (v as (ixn, T)) =>
+          (case assoc_string_int (vars, ixn) of
+            Some T' =>
+              if T = T' then (vars, nodup_TVars (tvars, T))
+              else raise_type ("Variable " ^ Syntax.string_of_vname ixn ^
+                " has two distinct types") [T', T] []
+          | None => (v :: vars, tvars))
+      | Bound _ => (vars, tvars)
+      | Abs(_, T, t) => nodups vars (nodup_TVars (tvars, T)) t
+      | s $ t =>
+          let val (vars',tvars') = nodups vars tvars s in
+            nodups vars' tvars' t
+          end);
+  in nodups [] [] tm; () end;
+
 
 fun mapfilt_atoms f (Abs (_, _, t)) = mapfilt_atoms f t
   | mapfilt_atoms f (t $ u) = mapfilt_atoms f t @ mapfilt_atoms f u
   | mapfilt_atoms f a = (case f a of Some y => [y] | None => []);
 
+
 fun certify_term (sg as Sg {tsig, ...}) tm =
   let
     fun valid_const a T =
@@ -316,82 +321,67 @@
   end;
 
 
-(*package error messages from type checking*)
-fun exn_type_msg sg (msg, Ts, ts) =
-  let
-    val show_typ = string_of_typ sg;
-    val show_term = set_ap Syntax.show_brackets true (string_of_term sg);
-
-    fun term_err [] = ""
-      | term_err [t] = "\n\nInvolving this term:\n" ^ show_term t
-      | term_err ts =
-          "\n\nInvolving these terms:\n" ^ cat_lines (map show_term ts);
-  in
-    "\nType checking error: " ^ msg ^ "\n" ^
-      cat_lines (map show_typ Ts) ^ term_err ts ^ "\n"
-  end;
-
-
 
 (** infer_types **)         (*exception ERROR*)
 
-(*ts is the list of alternative parses; only one is hoped to be type-correct.
-  T is the expected type for the correct term.
-  Other standard arguments:
-    types is a partial map from indexnames to types (constrains Free, Var).
-    sorts is a partial map from indexnames to sorts (constrains TFree, TVar).
-    used is the list of already used type variables.
-    If freeze then internal TVars are turned into TFrees, else TVars.*)
-fun infer_types sg types sorts used freeze (ts, T) =
+(*
+  ts: list of alternative parses (hopefully only one is type-correct)
+  T: expected type
+
+  def_type: partial map from indexnames to types (constrains Frees, Vars)
+  def_sort: partial map from indexnames to sorts (constrains TFrees, TVars)
+  used: list of already used type variables
+  freeze: if true then generated parameters are turned into TFrees, else TVars
+*)
+
+fun infer_types sg def_type def_sort used freeze (ts, T) =
   let
     val Sg {tsig, ...} = sg;
+    val prt = setmp Syntax.show_brackets true (pretty_term sg);
+    val prT = pretty_typ sg;
+    val infer = Type.infer_types prt prT tsig (const_type sg)
+      def_type def_sort used freeze;
 
-    val T' = certify_typ sg T handle TYPE arg => error (exn_type_msg sg arg);
-
-    val ct = const_type sg;
+    val T' = certify_typ sg T handle TYPE (msg, _, _) => error msg;
 
-    fun warn() =
-      if length ts > 1 andalso length ts <= !Syntax.ambiguity_level
+    fun warn () =
+      if length ts > 1 andalso length ts <= ! Syntax.ambiguity_level
       then (*no warning shown yet*)
-           warning "Currently parsed input \
-                   \produces more than one parse tree.\n\
-                   \For more information lower Syntax.ambiguity_level."
+        warning "Currently parsed input \
+          \produces more than one parse tree.\n\
+          \For more information lower Syntax.ambiguity_level."
       else ();
 
-    datatype result = One of int * term * (indexname * typ) list
-                    | Errs of (string * typ list * term list)list
-                    | Ambigs of term list;
-
-    fun process_term(res,(t,i)) =
-       let val ([u],tye) = 
-	       Type.infer_types(tsig,ct,types,sorts,used,freeze,[T'],[t])
-       in case res of
-            One(_,t0,_) => Ambigs([u,t0])
-          | Errs _ => One(i,u,tye)
-          | Ambigs(us) => Ambigs(u::us)
-       end
-       handle TYPE arg => (case res of Errs(errs) => Errs(arg::errs)
-                                     | _ => res);
+    datatype result =
+      One of int * term * (indexname * typ) list |
+      Errs of string list |
+      Ambigs of term list;
 
-  in case foldl process_term (Errs[], ts ~~ (0 upto (length ts - 1))) of
-       One(res) =>
-         (if length ts > !Syntax.ambiguity_level
-          then writeln "Fortunately, only one parse tree is type correct.\n\
+    fun process_term (res, (t, i)) =
+      let val ([u], tye) = infer [T'] [t] in
+        (case res of
+          One (_, t0, _) => Ambigs ([u, t0])
+        | Errs _ => One (i, u, tye)
+        | Ambigs us => Ambigs (u :: us))
+      end handle TYPE (msg, _, _) =>
+        (case res of
+          Errs errs => Errs (msg :: errs)
+        | _ => res);
+  in
+    (case foldl process_term (Errs [], ts ~~ (0 upto (length ts - 1))) of
+      One res =>
+       (if length ts > ! Syntax.ambiguity_level then
+          writeln "Fortunately, only one parse tree is type correct.\n\
             \It helps (speed!) if you disambiguate your grammar or your input."
-          else ();
-          res)
-     | Errs(errs) => (warn(); error(cat_lines(map (exn_type_msg sg) errs)))
-     | Ambigs(us) =>
-         (warn();
-          let val old_show_brackets = !show_brackets
-              val dummy = show_brackets := true;
-              val errs = cat_lines(map (string_of_term sg) us)
-          in show_brackets := old_show_brackets;
-             error("Error: More than one term is type correct:\n" ^ errs)
-          end)
+        else (); res)
+    | Errs errs => (warn (); error (cat_lines errs))
+    | Ambigs us =>
+        (warn (); error ("Error: More than one term is type correct:\n" ^
+          (cat_lines (map (Pretty.string_of o prt) us)))))
   end;
 
 
+
 (** extend signature **)    (*exception ERROR*)
 
 (** signature extension functions **)  (*exception ERROR*)
@@ -461,7 +451,7 @@
 
 fun ext_cnsts rd_const syn_only prmode (syn, tsig, ctab) raw_consts =
   let
-    fun prep_const (c, ty, mx) = 
+    fun prep_const (c, ty, mx) =
      (c, compress_type (Type.varifyT (Type.cert_typ tsig (Type.no_tvars ty))), mx)
        handle TYPE (msg, _, _)
          => (writeln msg; err_in_const (Syntax.const_name c mx));
--- a/src/Pure/theory.ML	Thu Apr 17 18:45:43 1997 +0200
+++ b/src/Pure/theory.ML	Thu Apr 17 18:46:58 1997 +0200
@@ -169,7 +169,7 @@
 fun cert_axm sg (name, raw_tm) =
   let
     val (t, T, _) = Sign.certify_term sg raw_tm
-      handle TYPE arg => error (Sign.exn_type_msg sg arg)
+      handle TYPE (msg, _, _) => error msg
 	   | TERM (msg, _) => error msg;
   in
     assert (T = propT) "Term not of type prop";
--- a/src/Pure/thm.ML	Thu Apr 17 18:45:43 1997 +0200
+++ b/src/Pure/thm.ML	Thu Apr 17 18:46:58 1997 +0200
@@ -260,7 +260,7 @@
     val (_, t', tye) =
           Sign.infer_types sign types sorts used freeze (ts, T');
     val ct = cterm_of sign t'
-      handle TYPE arg => error (Sign.exn_type_msg sign arg)
+      handle TYPE (msg, _, _) => error msg
            | TERM (msg, _) => error msg;
   in (ct, tye) end;
 
@@ -271,18 +271,20 @@
   not practical.*)
 fun read_cterms sign (bs, Ts) =
   let
-    val {tsig, syn, ...} = Sign.rep_sg sign
-    fun read (b,T) =
-        case Syntax.read syn T b of
-            [t] => t
-          | _   => error("Error or ambiguity in parsing of " ^ b)
-    val (us,_) = Type.infer_types(tsig, Sign.const_type sign, 
-                                  K None, K None, 
-                                  [], true, 
-                                  map (Sign.certify_typ sign) Ts, 
-                                  ListPair.map read (bs,Ts))
-  in  map (cterm_of sign) us  end
-  handle TYPE arg => error (Sign.exn_type_msg sign arg)
+    val {tsig, syn, ...} = Sign.rep_sg sign;
+    fun read (b, T) =
+      (case Syntax.read syn T b of
+        [t] => t
+      | _  => error ("Error or ambiguity in parsing of " ^ b));
+
+    val prt = setmp Syntax.show_brackets true (Sign.pretty_term sign);
+    val prT = Sign.pretty_typ sign;
+    val (us, _) =
+      Type.infer_types prt prT tsig (Sign.const_type sign)
+        (K None) (K None) [] true (map (Sign.certify_typ sign) Ts)
+        (ListPair.map read (bs, Ts));
+  in map (cterm_of sign) us end
+  handle TYPE (msg, _, _) => error msg
        | TERM (msg, _) => error msg;
 
 
--- a/src/Pure/type.ML	Thu Apr 17 18:45:43 1997 +0200
+++ b/src/Pure/type.ML	Thu Apr 17 18:46:58 1997 +0200
@@ -65,8 +65,9 @@
   val get_sort: type_sig -> (indexname -> sort option) -> (indexname * sort) list
     -> indexname -> sort
   val constrain: term -> typ -> term
-  val infer_types: type_sig * (string -> typ option) * (indexname -> typ option)
-    * (indexname -> sort option) * string list * bool * typ list * term list
+  val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T)
+    -> type_sig -> (string -> typ option) -> (indexname -> typ option)
+    -> (indexname -> sort option) -> string list -> bool -> typ list -> term list
     -> term list * (indexname * typ) list
 end;
 
@@ -871,14 +872,14 @@
     "?" :: _ => true
   | _ => false);
 
-fun infer_types (tsig, const_type, def_type, def_sort, used, freeze, pat_Ts, raw_ts) =
+fun infer_types prt prT tsig const_type def_type def_sort used freeze pat_Ts raw_ts =
   let
     val TySg {classrel, arities, ...} = tsig;
     val pat_Ts' = map (cert_typ tsig) pat_Ts;
     val raw_ts' =
       map (decode_types tsig (is_some o const_type) def_type def_sort) raw_ts;
     val (ts, Ts, unifier) =
-      TypeInfer.infer_types const_type classrel arities used freeze
+      TypeInfer.infer_types prt prT const_type classrel arities used freeze
         q_is_param raw_ts' pat_Ts';
   in
     (ts, unifier)
--- a/src/Pure/type_infer.ML	Thu Apr 17 18:45:43 1997 +0200
+++ b/src/Pure/type_infer.ML	Thu Apr 17 18:46:58 1997 +0200
@@ -7,7 +7,8 @@
 
 signature TYPE_INFER =
 sig
-  val infer_types: (string -> typ option) -> Sorts.classrel -> Sorts.arities
+  val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T)
+    -> (string -> typ option) -> Sorts.classrel -> Sorts.arities
     -> string list -> bool -> (indexname -> bool) -> term list -> typ list
     -> term list * typ list * (indexname * typ) list
 end;
@@ -257,7 +258,7 @@
 
     fun not_in_sort x S' S =
       "Type variable " ^ x ^ "::" ^ Sorts.str_of_sort S' ^ " not in sort " ^
-        Sorts.str_of_sort S;
+        Sorts.str_of_sort S ^ ".";
 
     fun meet _ [] = ()
       | meet (Link (r as (ref (Param S')))) S =
@@ -298,9 +299,10 @@
       | unif (Link (ref T)) U = unif T U
       | unif T (Link (ref U)) = unif T U
       | unif (PType (a, Ts)) (PType (b, Us)) =
-          if a <> b then raise NO_UNIFIER ("Clash of " ^ a ^ ", " ^ b ^ "!")
+          if a <> b then
+            raise NO_UNIFIER ("Clash of types " ^ quote a ^ " and " ^ quote b ^ ".")
           else seq2 unif Ts Us
-      | unif T U = if T = U then () else raise NO_UNIFIER "Unification failed!";
+      | unif T U = if T = U then () else raise NO_UNIFIER "";
 
   in unif end;
 
@@ -310,26 +312,63 @@
 
 (* infer *)                                     (*DESTRUCTIVE*)
 
-fun infer classrel arities =
+fun infer prt prT classrel arities =
   let
-    val unif = unify classrel arities;
+    (* errors *)
 
-    fun err msg1 msg2 bs ts Ts =
+    fun unif_failed msg =
+      "Type unification failed" ^ (if msg = "" then "." else ": " ^ msg) ^ "\n";
+
+    val str_of = Pretty.string_of;
+
+    fun prep_output bs ts Ts =
       let
         val (Ts_bTs', ts') = typs_terms_of [] PTFree "??" (Ts @ map snd bs, ts);
         val len = length Ts;
         val Ts' = take (len, Ts_bTs');
         val xs = map Free (map fst bs ~~ drop (len, Ts_bTs'));
         val ts'' = map (fn t => subst_bounds (xs, t)) ts';
-      in
-        raise_type (msg1 ^ " " ^ msg2) Ts' ts''
-      end;
+      in (ts'', Ts') end;
+
+    fun err_loose i =
+      raise_type ("Loose bound variable: B." ^ string_of_int i) [] [];
+
+    fun err_appl msg bs t T U_to_V u U =
+      let
+        val ([t', u'], [T', U_to_V', U']) = prep_output bs [t, u] [T, U_to_V, U];
+        val text = cat_lines
+         [unif_failed msg,
+          "Type error in application:",
+          "",
+          str_of (Pretty.block [Pretty.str "operator:     ", Pretty.brk 1, prt t',
+            Pretty.str " :: ", prT T']),
+          str_of (Pretty.block [Pretty.str "expected type:", Pretty.brk 1, prT U_to_V']),
+          "",
+          str_of (Pretty.block [Pretty.str "operand:      ", Pretty.brk 1, prt u',
+            Pretty.str " :: ", prT U']), ""];
+      in raise_type text [T', U_to_V', U'] [t', u'] end;
+
+    fun err_constraint msg bs t T U =
+      let
+        val ([t'], [T', U']) = prep_output bs [t] [T, U];
+        val text = cat_lines
+         [unif_failed msg,
+          "Cannot meet type constraint:",
+          "",
+          str_of (Pretty.block [Pretty.str "term:          ", Pretty.brk 1, prt t',
+            Pretty.str " :: ", prT T']),
+          str_of (Pretty.block [Pretty.str "expected type: ", Pretty.brk 1, prT U']), ""];
+      in raise_type text [T', U'] [t'] end;
+
+
+    (* main *)
+
+    val unif = unify classrel arities;
 
     fun inf _ (PConst (_, T)) = T
       | inf _ (PFree (_, T)) = T
       | inf _ (PVar (_, T)) = T
-      | inf bs (PBound i) = snd (nth_elem (i, bs)
-          handle LIST _ => raise_type "Loose bound variable" [] [Bound i])
+      | inf bs (PBound i) = snd (nth_elem (i, bs) handle LIST _ => err_loose i)
       | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t])
       | inf bs (PAppl (t, u)) =
           let
@@ -338,12 +377,11 @@
             val V = mk_param [];
             val U_to_V = PType ("fun", [U, V]);
             val _ = unif U_to_V T handle NO_UNIFIER msg =>
-              err msg "Bad function application." bs [PAppl (t, u)] [U_to_V, U];
+              err_appl msg bs t T U_to_V u U;
           in V end
       | inf bs (Constraint (t, U)) =
           let val T = inf bs t in
-            unif T U handle NO_UNIFIER msg =>
-              err msg "Cannot meet type constraint." bs [t] [T, U];
+            unif T U handle NO_UNIFIER msg => err_constraint msg bs t T U;
             T
           end;
 
@@ -352,7 +390,7 @@
 
 (* infer_types *)
 
-fun infer_types const_type classrel arities used freeze is_param ts Ts =
+fun infer_types prt prT const_type classrel arities used freeze is_param ts Ts =
   let
     (*convert to preterms/typs*)
     val (Tps, Ts') = pretyps_of (K true) ([], Ts);
@@ -360,7 +398,7 @@
 
     (*run type inference*)
     val tTs' = ListPair.map Constraint (ts', Ts');
-    val _ = seq (fn t => (infer classrel arities t; ())) tTs';
+    val _ = seq (fn t => (infer prt prT classrel arities t; ())) tTs';
 
     (*collect result unifier*)
     fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); None)