Enabled non fully polymorphic map functions in subtyping
authortraytel
Tue, 21 Dec 2010 01:12:14 +0100
changeset 41353 684003dbda54
parent 41352 87adb55fb0fb
child 41356 04ecd79827f2
Enabled non fully polymorphic map functions in subtyping
src/Tools/subtyping.ML
--- a/src/Tools/subtyping.ML	Tue Dec 21 11:54:35 2010 +0100
+++ b/src/Tools/subtyping.ML	Tue Dec 21 01:12:14 2010 +0100
@@ -6,7 +6,7 @@
 
 signature SUBTYPING =
 sig
-  datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT
+  datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;
   val coercion_enabled: bool Config.T
   val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
     term list -> term list
@@ -21,7 +21,7 @@
 
 (** coercions data **)
 
-datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT
+datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;
 
 datatype data = Data of
   {coes: term Symreltab.table,  (*coercions table*)
@@ -83,9 +83,11 @@
   | sort_of _ = NONE;
 
 val is_typeT = fn (Type _) => true | _ => false;
+val is_stypeT = fn (Type (_, [])) => true | _ => false;
 val is_compT = fn (Type (_, _ :: _)) => true | _ => false;
 val is_freeT = fn (TFree _) => true | _ => false;
 val is_fixedvarT = fn (TVar (xi, _)) => not (Type_Infer.is_param xi) | _ => false;
+val is_funtype = fn (Type ("fun", [_, _])) => true | _ => false;
 
 
 (* unification *)
@@ -205,10 +207,6 @@
   
 fun unif_failed msg =
   "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
-
-fun subtyping_err_appl_msg ctxt msg tye bs t T u U () =
-  let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U]
-  in msg ^ Type.appl_error (Syntax.pp ctxt) t' T' u' U' ^ "\n" end;
   
 fun err_appl_msg ctxt msg tye bs t T u U () =
   let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U]
@@ -264,7 +262,7 @@
             val U = Type_Infer.mk_param idx [];
             val V = Type_Infer.mk_param (idx + 1) [];
             val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
-              handle NO_UNIFIER (msg, tye') => error (gen_msg err msg);
+              handle NO_UNIFIER (msg, _) => error (gen_msg err msg);
             val error_pack = (bs, t $ u, U, V, U');
           in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
   in
@@ -291,12 +289,15 @@
           (case pairself f (fst c) of
             (false, false) => apsnd (cons c) (split_cs f cs)
           | _ => apfst (cons c) (split_cs f cs));
+          
+    fun unify_list (T :: Ts) tye_idx =
+      fold (fn U => fn tye_idx' => strong_unify ctxt (T, U) tye_idx') Ts tye_idx;      
 
 
     (* check whether constraint simplification will terminate using weak unification *)
 
-    val _ = fold (fn (TU, error_pack) => fn tye_idx =>
-      weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, tye) =>
+    val _ = fold (fn (TU, _) => fn tye_idx =>
+      weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, _) =>
         error (gen_msg err ("weak unification of subtype constraints fails\n" ^ msg))) cs tye_idx;
 
 
@@ -315,9 +316,14 @@
               (case variance of
                 COVARIANT => (constraint :: cs, tye_idx)
               | CONTRAVARIANT => (swap constraint :: cs, tye_idx)
+              | INVARIANT_TO T => (cs, unify_list [T, fst constraint, snd constraint] tye_idx
+                  handle NO_UNIFIER (msg, _) => 
+                    err_list ctxt (gen_msg err 
+                      "failed to unify invariant arguments w.r.t. to the known map function") 
+                      (fst tye_idx) Ts)
               | INVARIANT => (cs, strong_unify ctxt constraint tye_idx
-                  handle NO_UNIFIER (msg, tye) => 
-                    error (gen_msg err ("failed to unify invariant arguments\n" ^ msg))));
+                  handle NO_UNIFIER (msg, _) => 
+                    error (gen_msg err ("failed to unify invariant arguments" ^ msg))));
             val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
               (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
             val test_update = is_compT orf is_freeT orf is_fixedvarT;
@@ -343,7 +349,7 @@
             simplify done' ((new, error_pack) :: todo') (tye', idx + n)
           end
         (*TU is a pair of a parameter and a free/fixed variable*)
-        and eliminate TU error_pack done todo tye idx =
+        and eliminate TU done todo tye idx =
           let
             val [TVar (xi, S)] = filter Type_Infer.is_paramT TU;
             val [T] = filter_out Type_Infer.is_paramT TU;
@@ -376,7 +382,7 @@
                   if T = U then simplify done todo tye_idx
                   else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
                     exists Type_Infer.is_paramT [T, U]
-                  then eliminate [T, U] error_pack done todo tye idx
+                  then eliminate [T, U] done todo tye idx
                   else if exists (is_freeT orf is_fixedvarT) [T, U]
                   then error (gen_msg err "not eliminated free/fixed variables")
                   else simplify (((T, U), error_pack) :: done) todo tye_idx);
@@ -402,9 +408,6 @@
           cs'
       end;
 
-    fun unify_list (T :: Ts) tye_idx =
-      fold (fn U => fn tye_idx' => strong_unify ctxt (T, U) tye_idx') Ts tye_idx;
-
     (*styps stands either for supertypes or for subtypes of a type T
       in terms of the subtype-relation (excluding T itself)*)
     fun styps super T =
@@ -467,7 +470,7 @@
                   val (tye, idx) = 
                     fold 
                       (fn cycle => fn tye_idx' => (unify_list cycle tye_idx'
-                        handle NO_UNIFIER (msg, tye) => 
+                        handle NO_UNIFIER (msg, _) => 
                           err_bound ctxt 
                             (gen_msg err ("constraint cycle not unifiable" ^ msg)) (fst tye_idx)
                             (find_cycle_packs cycle)))
@@ -572,7 +575,7 @@
       in
         fold 
           (fn Ts => fn tye_idx' => unify_list Ts tye_idx'
-            handle NO_UNIFIER (msg, tye) => err_list ctxt (gen_msg err msg) (fst tye_idx) Ts)
+            handle NO_UNIFIER (msg, _) => err_list ctxt (gen_msg err msg) (fst tye_idx) Ts)
           to_unify tye_idx
       end;
 
@@ -605,8 +608,9 @@
             fun inst t Ts =
               Term.subst_vars
                 (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
-            fun sub_co (COVARIANT, TU) = gen_coercion ctxt tye TU
-              | sub_co (CONTRAVARIANT, TU) = gen_coercion ctxt tye (swap TU);
+            fun sub_co (COVARIANT, TU) = SOME (gen_coercion ctxt tye TU)
+              | sub_co (CONTRAVARIANT, TU) = SOME (gen_coercion ctxt tye (swap TU))
+              | sub_co (INVARIANT_TO T, _) = NONE;
             fun ts_of [] = []
               | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
           in
@@ -614,7 +618,7 @@
               NONE => raise Fail ("No map function for " ^ a ^ " known")
             | SOME tmap =>
                 let
-                  val used_coes = map sub_co ((snd tmap) ~~ (Ts ~~ Us));
+                  val used_coes = map_filter sub_co ((snd tmap) ~~ (Ts ~~ Us));
                 in
                   Term.list_comb
                     (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
@@ -735,36 +739,39 @@
     val ctxt = Context.proof_of context;
     val t = singleton (Variable.polymorphic ctxt) raw_t;
 
-    fun err_str () = "\n\nthe general type signature for a map function is" ^
-      "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [x1, ..., xn]" ^
+    fun err_str t = "\n\nThe provided function has the type\n" ^ 
+      Syntax.string_of_typ ctxt (fastype_of t) ^ 
+      "\n\nThe general type signature of a map function is" ^
+      "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]" ^
       "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)";
-
+      
+    val ((fis, T1), T2) = apfst split_last (strip_type (fastype_of t))
+      handle Empty => error ("Not a proper map function:" ^ err_str t);
+    
     fun gen_arg_var ([], []) = []
       | gen_arg_var ((T, T') :: Ts, (U, U') :: Us) =
-          if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
+          if U = U' then
+            if is_stypeT U then INVARIANT_TO U :: gen_arg_var ((T, T') :: Ts, Us)
+            else error ("Invariant xi and yi should be base types:" ^ err_str t)
+          else if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
           else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us)
-          else error ("Functions do not apply to arguments correctly:" ^ err_str ())
-      | gen_arg_var (_, _) =
-          error ("Different numbers of functions and arguments\n" ^ err_str ());
+          else error ("Functions do not apply to arguments correctly:" ^ err_str t)
+      | gen_arg_var (_, Ts) = 
+          if forall (op = andf is_stypeT o fst) Ts 
+          then map (INVARIANT_TO o fst) Ts
+          else error ("Different numbers of functions and variant arguments\n" ^ err_str t);
 
-    (* TODO: This function is only needed to introde the fun type map
-      function: "% f g h . g o h o f". There must be a better solution. *)
-    fun balanced (Type (_, [])) (Type (_, [])) = true
-      | balanced (Type (a, Ts)) (Type (b, Us)) =
-          a = b andalso forall I (map2 balanced Ts Us)
-      | balanced (TFree _) (TFree _) = true
-      | balanced (TVar _) (TVar _) = true
-      | balanced _ _ = false;
+    (*retry flag needed to adjust the type lists, when given a map over type constructor fun*)
+    fun check_map_fun fis (Type (C1, Ts)) (Type (C2, Us)) retry =
+          if C1 = C2 andalso not (null fis) andalso forall is_funtype fis
+          then ((map dest_funT fis, Ts ~~ Us), C1)
+          else error ("Not a proper map function:" ^ err_str t)
+      | check_map_fun fis T1 T2 true =
+          let val (fis', T') = split_last fis
+          in check_map_fun fis' T' (T1 --> T2) false end
+      | check_map_fun _ _ _ _ = error ("Not a proper map function:" ^ err_str t);
 
-    fun check_map_fun (pairs, []) (Type ("fun", [T as Type (C, Ts), U as Type (_, Us)])) =
-          if balanced T U
-          then ((pairs, Ts ~~ Us), C)
-          else if C = "fun"
-            then check_map_fun (pairs @ [(hd Ts, hd (tl Ts))], []) U
-            else error ("Not a proper map function:" ^ err_str ())
-      | check_map_fun _ _ = error ("Not a proper map function:" ^ err_str ());
-
-    val res = check_map_fun ([], []) (fastype_of t);
+    val res = check_map_fun fis T1 T2 true;
     val res_av = gen_arg_var (fst res);
   in
     map_tmaps (Symtab.update (snd res, (t, res_av))) context