two-staged architecture for subtyping;
authortraytel
Mon, 29 Nov 2010 16:53:08 +0100
changeset 40836 a81d66d72e70
parent 40835 fc750e794458
child 40837 dedb893dc692
two-staged architecture for subtyping; improved error messages of subtyping (using the new architecture); bugfix: constraint graph consistency check after cycle elimination;
src/Tools/subtyping.ML
--- a/src/Tools/subtyping.ML	Tue Nov 30 20:02:01 2010 -0800
+++ b/src/Tools/subtyping.ML	Mon Nov 29 16:53:08 2010 +0100
@@ -11,6 +11,7 @@
     term list -> term list
   val add_type_map: term -> Context.generic -> Context.generic
   val add_coercion: term -> Context.generic -> Context.generic
+  val gen_coercion: Proof.context -> typ Vartab.table -> (typ * typ) -> term
   val setup: theory -> theory
 end;
 
@@ -86,8 +87,9 @@
 val is_fixedvarT = fn (TVar (xi, _)) => not (Type_Infer.is_param xi) | _ => false;
 
 
-(* unification *)  (* TODO dup? needed for weak unification *)
+(* unification *)
 
+exception TYPE_INFERENCE_ERROR of unit -> string;
 exception NO_UNIFIER of string * typ Vartab.table;
 
 fun unify weak ctxt =
@@ -185,6 +187,10 @@
 
 (** error messages **)
 
+fun gen_msg err msg = 
+  err () ^ "\nNow trying to infer coercions:\n\nCoercion inference failed" ^ 
+  (if msg = "" then "" else ": " ^ msg) ^ "\n";
+
 fun prep_output ctxt tye bs ts Ts =
   let
     val (Ts_bTs', ts') = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts);
@@ -195,23 +201,23 @@
   in (map prep ts', Ts') end;
 
 fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);
-
-fun inf_failed msg =
-  "Subtype inference failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
+  
+fun unif_failed msg =
+  "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
 
-fun err_appl ctxt msg tye bs t T u U =
+fun subtyping_err_appl_msg ctxt msg tye bs t T u U () =
   let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U]
-  in error (inf_failed msg ^ Type.appl_error (Syntax.pp ctxt) t' T' u' U' ^ "\n") end;
-
-fun err_subtype ctxt msg tye (bs, t $ u, U, V, U') =
-  err_appl ctxt msg tye bs t (U --> V) u U';
+  in msg ^ Type.appl_error (Syntax.pp ctxt) t' T' u' U' ^ "\n" end;
+  
+fun err_appl_msg ctxt msg tye bs t T u U () =
+  let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U]
+  in unif_failed msg ^ Type.appl_error (Syntax.pp ctxt) t' T' u' U' ^ "\n" end;
 
 fun err_list ctxt msg tye Ts =
   let
     val (_, Ts') = prep_output ctxt tye [] [] Ts;
-    val text = cat_lines ([inf_failed msg,
-      "Cannot unify a list of types that should be the same,",
-      "according to suptype dependencies:",
+    val text = cat_lines ([msg,
+      "Cannot unify a list of types that should be the same:",
       (Pretty.string_of (Pretty.list "[" "]" (map (Pretty.typ (Syntax.pp ctxt)) Ts')))]);
   in
     error text
@@ -222,15 +228,15 @@
     val pp = Syntax.pp ctxt;
     val (ts, Ts) = fold
       (fn (bs, t $ u, U, _, U') => fn (ts, Ts) =>
-        let val (t', T') = prep_output ctxt tye bs [t, u] [U, U']
+        let val (t', T') = prep_output ctxt tye bs [t, u] [U', U]
         in (t' :: ts, T' :: Ts) end)
       packs ([], []);
-    val text = cat_lines ([inf_failed msg, "Cannot fullfill subtype constraints:"] @
+    val text = cat_lines ([msg, "Cannot fulfil subtype constraints:"] @
         (map2 (fn [t, u] => fn [T, U] => Pretty.string_of (
           Pretty.block [
             Pretty.typ pp T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2, Pretty.typ pp U,
             Pretty.brk 3, Pretty.str "from function application", Pretty.brk 2,
-            Pretty.block [Pretty.term pp t, Pretty.brk 1, Pretty.term pp u]]))
+            Pretty.block [Pretty.term pp (t $ u)]]))
         ts Ts))
   in
     error text
@@ -240,7 +246,7 @@
 
 (** constraint generation **)
 
-fun generate_constraints ctxt =
+fun generate_constraints ctxt err =
   let
     fun gen cs _ (Const (_, T)) tye_idx = (T, tye_idx, cs)
       | gen cs _ (Free (_, T)) tye_idx = (T, tye_idx, cs)
@@ -257,7 +263,7 @@
             val U = Type_Infer.mk_param idx [];
             val V = Type_Infer.mk_param (idx + 1) [];
             val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
-              handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs t T u U;
+              handle NO_UNIFIER (msg, tye') => error (gen_msg err msg);
             val error_pack = (bs, t $ u, U, V, U');
           in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
   in
@@ -270,7 +276,7 @@
 
 exception BOUND_ERROR of string;
 
-fun process_constraints ctxt cs tye_idx =
+fun process_constraints ctxt err cs tye_idx =
   let
     val coes_graph = coes_graph_of ctxt;
     val tmaps = tmaps_of ctxt;
@@ -289,9 +295,8 @@
     (* check whether constraint simplification will terminate using weak unification *)
 
     val _ = fold (fn (TU, error_pack) => fn tye_idx =>
-      (weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, tye) =>
-        err_subtype ctxt ("Weak unification of subtype constraints fails:\n" ^ msg)
-          tye error_pack)) cs tye_idx;
+      weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, tye) =>
+        error (gen_msg err ("weak unification of subtype constraints fails\n" ^ msg))) cs tye_idx;
 
 
     (* simplify constraints *)
@@ -310,7 +315,8 @@
                 COVARIANT => (constraint :: cs, tye_idx)
               | CONTRAVARIANT => (swap constraint :: cs, tye_idx)
               | INVARIANT => (cs, strong_unify ctxt constraint tye_idx
-                  handle NO_UNIFIER (msg, tye) => err_subtype ctxt msg tye error_pack));
+                  handle NO_UNIFIER (msg, tye) => 
+                    error (gen_msg err ("failed to unify invariant arguments\n" ^ msg))));
             val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
               (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
             val test_update = is_compT orf is_freeT orf is_fixedvarT;
@@ -348,7 +354,7 @@
           in
             if subsort (S', S) (*TODO check this*)
             then simplify done' todo' (tye', idx)
-            else err_subtype ctxt "Sort mismatch" tye error_pack
+            else error (gen_msg err "sort mismatch")
           end
         and simplify done [] tye_idx = (done, tye_idx)
           | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
@@ -356,9 +362,10 @@
                 (Type (a, []), Type (b, [])) =>
                   if a = b then simplify done todo tye_idx
                   else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
-                  else err_subtype ctxt (a ^ " is not a subtype of " ^ b) (fst tye_idx) error_pack
+                  else error (gen_msg err (a ^ " is not a subtype of " ^ b))
               | (Type (a, Ts), Type (b, Us)) =>
-                  if a <> b then err_subtype ctxt "Different constructors" (fst tye_idx) error_pack
+                  if a <> b then error (gen_msg err "different constructors")
+                    (fst tye_idx) error_pack
                   else contract a Ts Us error_pack done todo tye idx
               | (TVar (xi, S), Type (a, Ts as (_ :: _))) =>
                   expand true xi S a Ts error_pack done todo tye idx
@@ -370,8 +377,7 @@
                     exists Type_Infer.is_paramT [T, U]
                   then eliminate [T, U] error_pack done todo tye idx
                   else if exists (is_freeT orf is_fixedvarT) [T, U]
-                  then err_subtype ctxt "Not eliminated free/fixed variables"
-                        (fst tye_idx) error_pack
+                  then error (gen_msg err "not eliminated free/fixed variables")
                   else simplify (((T, U), error_pack) :: done) todo tye_idx);
       in
         simplify [] cs tye_idx
@@ -381,14 +387,22 @@
     (* do simplification *)
 
     val (cs', tye_idx') = simplify_constraints cs tye_idx;
-
-    fun find_error_pack lower T' =
-      map snd (filter (fn ((T, U), _) => if lower then T' = U else T' = T) cs');
+    
+    fun find_error_pack lower T' = map_filter 
+      (fn ((T, U), pack) => if if lower then T' = U else T' = T then SOME pack else NONE) cs';
+      
+    fun find_cycle_packs nodes = 
+      let
+        val (but_last, last) = split_last nodes
+        val pairs = (last, hd nodes) :: (but_last ~~ tl nodes);
+      in
+        map_filter
+          (fn (TU, pack) => if member (eq_pair (op =) (op =)) pairs TU then SOME pack else NONE) 
+          cs'
+      end;
 
     fun unify_list (T :: Ts) tye_idx =
-      fold (fn U => fn tye_idx => strong_unify ctxt (T, U) tye_idx
-        handle NO_UNIFIER (msg, tye) => err_list ctxt msg tye (T :: Ts))
-      Ts tye_idx;
+      fold (fn U => fn tye_idx' => strong_unify ctxt (T, U) tye_idx') Ts tye_idx;
 
     (*styps stands either for supertypes or for subtypes of a type T
       in terms of the subtype-relation (excluding T itself)*)
@@ -403,7 +417,7 @@
           | extract T (U :: Us) =
               if Graph.is_edge coes_graph (adjust T U) then extract T Us
               else if Graph.is_edge coes_graph (adjust U T) then extract U Us
-              else raise BOUND_ERROR "Uncomparable types in type list";
+              else raise BOUND_ERROR "uncomparable types in type list";
       in
         t_of (extract T Ts)
       end;
@@ -435,7 +449,7 @@
         fun candidates T = inter (op =) (filter restriction (T :: styps sup T));
       in
         (case fold candidates Ts (filter restriction (T :: styps sup T)) of
-          [] => raise BOUND_ERROR ("No " ^ (if sup then "supremum" else "infimum"))
+          [] => raise BOUND_ERROR ("no " ^ (if sup then "supremum" else "infimum"))
         | [T] => t_of T
         | Ts => minmax sup Ts)
       end;
@@ -449,23 +463,45 @@
             val (G'', tye_idx') = (add_edge (T, U) G', tye_idx)
               handle Typ_Graph.CYCLES cycles =>
                 let
-                  val (tye, idx) = fold unify_list cycles tye_idx
+                  val (tye, idx) = 
+                    fold 
+                      (fn cycle => fn tye_idx' => (unify_list cycle tye_idx'
+                        handle NO_UNIFIER (msg, tye) => 
+                          err_bound ctxt 
+                            (gen_msg err ("constraint cycle not unifiable" ^ msg)) (fst tye_idx)
+                            (find_cycle_packs cycle)))
+                      cycles tye_idx
                 in
-                  (*all cycles collapse to one node,
-                    because all of them share at least the nodes x and y*)
-                  collapse (tye, idx) (distinct (op =) (flat cycles)) G
-                end;
+                  collapse (tye, idx) cycles G
+                end
           in
             build_graph G'' cs tye_idx'
           end
-    and collapse (tye, idx) nodes G = (*nodes non-empty list*)
+    and collapse (tye, idx) cycles G = (*nodes non-empty list*)
       let
-        val T = hd nodes;
+        (*all cycles collapse to one node,
+          because all of them share at least the nodes x and y*)
+        val nodes = (distinct (op =) (flat cycles));
+        val T = Type_Infer.deref tye (hd nodes);
         val P = new_imm_preds G nodes;
         val S = new_imm_succs G nodes;
         val G' = Typ_Graph.del_nodes (tl nodes) G;
+        fun check_and_gen super T' =
+          let val U = Type_Infer.deref tye T';
+          in
+            if not (is_typeT T) orelse not (is_typeT U) orelse T = U
+            then if super then (hd nodes, T') else (T', hd nodes)
+            else 
+              if super andalso 
+                Graph.is_edge coes_graph (nameT T, nameT U) then (hd nodes, T')
+              else if not super andalso 
+                Graph.is_edge coes_graph (nameT U, nameT T) then (T', hd nodes)
+              else err_bound ctxt (gen_msg err "cycle elimination produces inconsistent graph")
+                    (fst tye_idx) 
+                    (maps find_cycle_packs cycles @ find_error_pack super T')
+          end;
       in
-        build_graph G' (map (fn x => (x, T)) P @ map (fn x => (T, x)) S) (tye, idx)
+        build_graph G' (map (check_and_gen false) P @ map (check_and_gen true) S) (tye, idx)
       end;
 
     fun assign_bound lower G key (tye_idx as (tye, _)) =
@@ -488,7 +524,8 @@
           val assignment =
             if null bound orelse null not_params then NONE
             else SOME (tightest lower S styps_and_sorts (map nameT not_params)
-                handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key))
+                handle BOUND_ERROR msg => 
+                  err_bound ctxt (gen_msg err msg) tye (find_error_pack lower key))
         in
           (case assignment of
             NONE => tye_idx
@@ -501,9 +538,9 @@
                 in
                   if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
                   then apfst (Vartab.update (xi, T)) tye_idx
-                  else err_bound ctxt ("Assigned simple type " ^ s ^
+                  else err_bound ctxt (gen_msg err ("assigned simple type " ^ s ^
                     " clashes with the upper bound of variable " ^
-                    Syntax.string_of_typ ctxt (TVar(xi, S))) tye (find_error_pack (not lower) key)
+                    Syntax.string_of_typ ctxt (TVar(xi, S)))) tye (find_error_pack (not lower) key)
                 end
               else apfst (Vartab.update (xi, T)) tye_idx)
         end
@@ -519,7 +556,8 @@
           val (tye_idx' as (tye, _)) = fold (assign_lb G) ts tye_idx
             |> fold (assign_ub G) ts;
         in
-          assign_alternating ts (filter (Type_Infer.is_paramT o Type_Infer.deref tye) ts) G tye_idx'
+          assign_alternating ts 
+            (filter (Type_Infer.is_paramT o Type_Infer.deref tye) ts) G tye_idx'
         end;
 
     (*Unify all weakly connected components of the constraint forest,
@@ -531,7 +569,10 @@
           filter (Type_Infer.is_paramT o Type_Infer.deref tye) (Typ_Graph.maximals G);
         val to_unify = map (fn T => T :: get_preds G T) max_params;
       in
-        fold unify_list to_unify tye_idx
+        fold 
+          (fn Ts => fn tye_idx' => unify_list Ts tye_idx'
+            handle NO_UNIFIER (msg, tye) => err_list ctxt (gen_msg err msg) (fst tye_idx) Ts)
+          to_unify tye_idx
       end;
 
     fun solve_constraints G tye_idx = tye_idx
@@ -546,77 +587,73 @@
 
 (** coercion insertion **)
 
+fun gen_coercion ctxt tye (T1, T2) =
+  (case pairself (Type_Infer.deref tye) (T1, T2) of
+    ((Type (a, [])), (Type (b, []))) =>
+        if a = b
+        then Abs (Name.uu, Type (a, []), Bound 0)
+        else
+          (case Symreltab.lookup (coes_of ctxt) (a, b) of
+            NONE => raise Fail (a ^ " is not a subtype of " ^ b)
+          | SOME co => co)
+  | ((Type (a, Ts)), (Type (b, Us))) =>
+        if a <> b
+        then raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
+        else
+          let
+            fun inst t Ts =
+              Term.subst_vars
+                (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
+            fun sub_co (COVARIANT, TU) = gen_coercion ctxt tye TU
+              | sub_co (CONTRAVARIANT, TU) = gen_coercion ctxt tye (swap TU);
+            fun ts_of [] = []
+              | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
+          in
+            (case Symtab.lookup (tmaps_of ctxt) a of
+              NONE => raise Fail ("No map function for " ^ a ^ " known")
+            | SOME tmap =>
+                let
+                  val used_coes = map sub_co ((snd tmap) ~~ (Ts ~~ Us));
+                in
+                  Term.list_comb
+                    (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
+                end)
+          end
+  | (T, U) =>
+        if Type.could_unify (T, U)
+        then Abs (Name.uu, T, Bound 0)
+        else raise Fail ("Cannot generate coercion from "
+          ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U));
+
 fun insert_coercions ctxt tye ts =
   let
-    fun deep_deref T =
-      (case Type_Infer.deref tye T of
-        Type (a, Ts) => Type (a, map deep_deref Ts)
-      | U => U);
-
-    fun gen_coercion ((Type (a, [])), (Type (b, []))) =
-          if a = b
-          then Abs (Name.uu, Type (a, []), Bound 0)
-          else
-            (case Symreltab.lookup (coes_of ctxt) (a, b) of
-              NONE => raise Fail (a ^ " is not a subtype of " ^ b)
-            | SOME co => co)
-      | gen_coercion ((Type (a, Ts)), (Type (b, Us))) =
-          if a <> b
-          then raise raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
-          else
-            let
-              fun inst t Ts =
-                Term.subst_vars
-                  (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
-              fun sub_co (COVARIANT, TU) = gen_coercion TU
-                | sub_co (CONTRAVARIANT, TU) = gen_coercion (swap TU);
-              fun ts_of [] = []
-                | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
-            in
-              (case Symtab.lookup (tmaps_of ctxt) a of
-                NONE => raise Fail ("No map function for " ^ a ^ " known")
-              | SOME tmap =>
-                  let
-                    val used_coes = map sub_co ((snd tmap) ~~ (Ts ~~ Us));
-                  in
-                    Term.list_comb
-                      (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
-                  end)
-            end
-      | gen_coercion (T, U) =
-          if Type.could_unify (T, U)
-          then Abs (Name.uu, T, Bound 0)
-          else raise Fail ("Cannot generate coercion from "
-            ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U);
-
     fun insert _ (Const (c, T)) =
-          let val T' = deep_deref T;
+          let val T' = T;
           in (Const (c, T'), T') end
       | insert _ (Free (x, T)) =
-          let val T' = deep_deref T;
+          let val T' = T;
           in (Free (x, T'), T') end
       | insert _ (Var (xi, T)) =
-          let val T' = deep_deref T;
+          let val T' = T;
           in (Var (xi, T'), T') end
       | insert bs (Bound i) =
-          let val T = nth bs i handle Subscript =>
-            raise TYPE ("Loose bound variable: B." ^ string_of_int i, [], []);
+          let val T = nth bs i handle Subscript => err_loose i;
           in (Bound i, T) end
       | insert bs (Abs (x, T, t)) =
           let
-            val T' = deep_deref T;
+            val T' = T;
             val (t', T'') = insert (T' :: bs) t;
           in
             (Abs (x, T', t'), T' --> T'')
           end
       | insert bs (t $ u) =
           let
-            val (t', Type ("fun", [U, T])) = insert bs t;
+            val (t', Type ("fun", [U, T])) = apsnd (Type_Infer.deref tye) (insert bs t);
             val (u', U') = insert bs u;
           in
-            if U <> U'
-            then (t' $ (gen_coercion (U', U) $ u'), T)
-            else (t' $ u', T)
+            if can (fn TU => strong_unify ctxt TU (tye, 0)) (U, U')
+            then (t' $ u', T)
+            else (t' $ (gen_coercion ctxt tye (U', U) $ u'), T)
           end
   in
     map (fst o insert []) ts
@@ -630,14 +667,40 @@
   let
     val (idx, ts) = Type_Infer.prepare ctxt const_type var_type raw_ts;
 
-    fun gen_all t (tye_idx, constraints) =
-      let
-        val (_, tye_idx', constraints') = generate_constraints ctxt t tye_idx
-      in (tye_idx', constraints' @ constraints) end;
+    fun inf _ (t as (Const (_, T))) tye_idx = (t, T, tye_idx)
+      | inf _ (t as (Free (_, T))) tye_idx = (t, T, tye_idx)
+      | inf _ (t as (Var (_, T))) tye_idx = (t, T, tye_idx)
+      | inf bs (t as (Bound i)) tye_idx =
+          (t, snd (nth bs i handle Subscript => err_loose i), tye_idx)
+      | inf bs (Abs (x, T, t)) tye_idx =
+          let val (t', U, tye_idx') = inf ((x, T) :: bs) t tye_idx
+          in (Abs (x, T, t'), T --> U, tye_idx') end
+      | inf bs (t $ u) tye_idx =
+          let
+            val (t', T, tye_idx') = inf bs t tye_idx;
+            val (u', U, (tye, idx)) = inf bs u tye_idx';
+            val V = Type_Infer.mk_param idx [];
+            val (tu, tye_idx'') = (t' $ u', strong_unify ctxt (U --> V, T) (tye, idx + 1))
+              handle NO_UNIFIER (msg, tye') => 
+                raise TYPE_INFERENCE_ERROR (err_appl_msg ctxt msg tye' bs t T u U);
+          in (tu, V, tye_idx'') end;
 
-    val (tye_idx, constraints) = fold gen_all ts ((Vartab.empty, idx), []);
-    val (tye, _) = process_constraints ctxt constraints tye_idx;
-    val ts' = insert_coercions ctxt tye ts;
+    fun infer_single t (ts, tye_idx) = 
+      let val (t, _, tye_idx') = inf [] t tye_idx;
+      in (ts @ [t], tye_idx') end;
+      
+    val (ts', (tye, _)) = (fold infer_single ts ([], (Vartab.empty, idx))
+      handle TYPE_INFERENCE_ERROR err =>     
+        let
+          fun gen_single t (tye_idx, constraints) =
+            let val (_, tye_idx', constraints') = generate_constraints ctxt err t tye_idx
+            in (tye_idx', constraints' @ constraints) end;
+      
+          val (tye_idx, constraints) = fold gen_single ts ((Vartab.empty, idx), []);
+          val (tye, idx) = process_constraints ctxt err constraints tye_idx;
+        in 
+          (insert_coercions ctxt tye ts, (tye, idx))
+        end);
 
     val (_, ts'') = Type_Infer.finish ctxt tye ([], ts');
   in ts'' end;
@@ -738,7 +801,7 @@
         fun complex_coercion tab G (a, b) =
           let
             val path = hd (Graph.irreducible_paths G (a, b))
-            val path' = (fst (split_last path)) ~~ tl path
+            val path' = fst (split_last path) ~~ tl path
           in Abs (Name.uu, Type (a, []),
               fold (fn t => fn u => t $ u) (map (the o Symreltab.lookup tab) path') (Bound 0))
           end;