src/Tools/subtyping.ML
changeset 51319 4a92178011e7
parent 51248 029de23bb5e8
child 51327 62c033d7f3d8
--- a/src/Tools/subtyping.ML	Thu Feb 28 21:11:07 2013 +0100
+++ b/src/Tools/subtyping.ML	Fri Mar 01 22:15:31 2013 +0100
@@ -20,6 +20,7 @@
 (** coercions data **)
 
 datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;
+datatype coerce_arg = PERMIT | FORBID
 
 datatype data = Data of
   {coes: (term * ((typ list * typ list) * term list)) Symreltab.table,  (*coercions table*)
@@ -27,10 +28,12 @@
    full_graph: int Graph.T,
    (*coercions graph restricted to base types - for efficiency reasons strored in the context*)
    coes_graph: int Graph.T,
-   tmaps: (term * variance list) Symtab.table};  (*map functions*)
+   tmaps: (term * variance list) Symtab.table,  (*map functions*)
+   coerce_args: coerce_arg option list Symtab.table  (*special constants with non-coercible arguments*)};
 
-fun make_data (coes, full_graph, coes_graph, tmaps) =
-  Data {coes = coes, full_graph = full_graph, coes_graph = coes_graph, tmaps = tmaps};
+fun make_data (coes, full_graph, coes_graph, tmaps, coerce_args) =
+  Data {coes = coes, full_graph = full_graph, coes_graph = coes_graph,
+    tmaps = tmaps, coerce_args = coerce_args};
 
 fun merge_error_coes (a, b) =
   error ("Cannot merge coercion tables.\nConflicting declarations for coercions from " ^
@@ -40,49 +43,62 @@
   error ("Cannot merge coercion map tables.\nConflicting declarations for the constructor " ^
     quote C ^ ".");
 
+fun merge_error_coerce_args C =
+  error ("Cannot merge tables for constants with coercion-invariant arguments.\n"
+    ^ "Conflicting declarations for the constant " ^ quote C ^ ".");
+
 structure Data = Generic_Data
 (
   type T = data;
-  val empty = make_data (Symreltab.empty, Graph.empty, Graph.empty, Symtab.empty);
+  val empty = make_data (Symreltab.empty, Graph.empty, Graph.empty, Symtab.empty, Symtab.empty);
   val extend = I;
   fun merge
-    (Data {coes = coes1, full_graph = full_graph1, coes_graph = coes_graph1, tmaps = tmaps1},
-      Data {coes = coes2, full_graph = full_graph2, coes_graph = coes_graph2, tmaps = tmaps2}) =
+    (Data {coes = coes1, full_graph = full_graph1, coes_graph = coes_graph1,
+      tmaps = tmaps1, coerce_args = coerce_args1},
+      Data {coes = coes2, full_graph = full_graph2, coes_graph = coes_graph2,
+        tmaps = tmaps2, coerce_args = coerce_args2}) =
     make_data (Symreltab.merge (eq_pair (op aconv)
         (eq_pair (eq_pair (eq_list (op =)) (eq_list (op =))) (eq_list (op aconv))))
         (coes1, coes2) handle Symreltab.DUP key => merge_error_coes key,
       Graph.merge (op =) (full_graph1, full_graph2),
       Graph.merge (op =) (coes_graph1, coes_graph2),
       Symtab.merge (eq_pair (op aconv) (op =)) (tmaps1, tmaps2)
-        handle Symtab.DUP key => merge_error_tmaps key);
+        handle Symtab.DUP key => merge_error_tmaps key,
+      Symtab.merge (eq_list (op =)) (coerce_args1, coerce_args2)
+        handle Symtab.DUP key => merge_error_coerce_args key);
 );
 
 fun map_data f =
-  Data.map (fn Data {coes, full_graph, coes_graph, tmaps} =>
-    make_data (f (coes, full_graph, coes_graph, tmaps)));
+  Data.map (fn Data {coes, full_graph, coes_graph, tmaps, coerce_args} =>
+    make_data (f (coes, full_graph, coes_graph, tmaps, coerce_args)));
 
 fun map_coes f =
-  map_data (fn (coes, full_graph, coes_graph, tmaps) =>
-    (f coes, full_graph, coes_graph, tmaps));
+  map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) =>
+    (f coes, full_graph, coes_graph, tmaps, coerce_args));
 
 fun map_coes_graph f =
-  map_data (fn (coes, full_graph, coes_graph, tmaps) =>
-    (coes, full_graph, f coes_graph, tmaps));
+  map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) =>
+    (coes, full_graph, f coes_graph, tmaps, coerce_args));
 
 fun map_coes_and_graphs f =
-  map_data (fn (coes, full_graph, coes_graph, tmaps) =>
+  map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) =>
     let val (coes', full_graph', coes_graph') = f (coes, full_graph, coes_graph);
-    in (coes', full_graph', coes_graph', tmaps) end);
+    in (coes', full_graph', coes_graph', tmaps, coerce_args) end);
 
 fun map_tmaps f =
-  map_data (fn (coes, full_graph, coes_graph, tmaps) =>
-    (coes, full_graph, coes_graph, f tmaps));
+  map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) =>
+    (coes, full_graph, coes_graph, f tmaps, coerce_args));
+
+fun map_coerce_args f =
+  map_data (fn (coes, full_graph, coes_graph, tmaps, coerce_args) =>
+    (coes, full_graph, coes_graph, tmaps, f coerce_args));
 
 val rep_data = (fn Data args => args) o Data.get o Context.Proof;
 
 val coes_of = #coes o rep_data;
 val coes_graph_of = #coes_graph o rep_data;
 val tmaps_of = #tmaps o rep_data;
+val coerce_args_of = #coerce_args o rep_data;
 
 
 
@@ -277,29 +293,48 @@
 
 (** constraint generation **)
 
+fun update_coerce_arg ctxt old t =
+  let
+    val mk_coerce_args = the_default [] o Symtab.lookup (coerce_args_of ctxt);
+    fun update _ [] = old
+      | update 0 (coerce :: _) =
+        (case coerce of NONE => old | SOME PERMIT => true | SOME FORBID => false)
+      | update n (_ :: cs) = update (n - 1) cs;
+    val (f, n) = Term.strip_comb (Type.strip_constraints t) ||> length;
+  in
+    update n (case f of Const (name, _) => mk_coerce_args name | _ => [])
+  end;
+
 fun generate_constraints ctxt err =
   let
-    fun gen cs _ (Const (_, T)) tye_idx = (T, tye_idx, cs)
-      | gen cs _ (Free (_, T)) tye_idx = (T, tye_idx, cs)
-      | gen cs _ (Var (_, T)) tye_idx = (T, tye_idx, cs)
-      | gen cs bs (Bound i) tye_idx =
+    fun gen _ cs _ (Const (_, T)) tye_idx = (T, tye_idx, cs)
+      | gen _ cs _ (Free (_, T)) tye_idx = (T, tye_idx, cs)
+      | gen _ cs _ (Var (_, T)) tye_idx = (T, tye_idx, cs)
+      | gen _ cs bs (Bound i) tye_idx =
           (snd (nth bs i handle General.Subscript => err_loose i), tye_idx, cs)
-      | gen cs bs (Abs (x, T, t)) tye_idx =
-          let val (U, tye_idx', cs') = gen cs ((x, T) :: bs) t tye_idx
+      | gen coerce cs bs (Abs (x, T, t)) tye_idx =
+          let val (U, tye_idx', cs') = gen coerce cs ((x, T) :: bs) t tye_idx
           in (T --> U, tye_idx', cs') end
-      | gen cs bs (t $ u) tye_idx =
+      | gen coerce cs bs (t $ u) tye_idx =
           let
-            val (T, tye_idx', cs') = gen cs bs t tye_idx;
-            val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
+            val (T, tye_idx', cs') = gen coerce cs bs t tye_idx;
+            val coerce' = update_coerce_arg ctxt coerce t;
+            val (U', (tye, idx), cs'') = gen coerce' cs' bs u tye_idx';
             val U = Type_Infer.mk_param idx [];
             val V = Type_Infer.mk_param (idx + 1) [];
             val tye_idx'' = strong_unify ctxt (U --> V, T) (tye, idx + 2)
               handle NO_UNIFIER (msg, _) => error (gen_msg err msg);
             val error_pack = (bs, t $ u, U, V, U');
-          in (V, (tye_idx''),
-            ((U', U), error_pack) :: cs'') end;
+          in 
+            if coerce'
+            then (V, tye_idx'', ((U', U), error_pack) :: cs'')
+            else (V,
+              strong_unify ctxt (U, U') tye_idx''
+                handle NO_UNIFIER (msg, _) => error (gen_msg err msg),
+              cs'')
+          end;
   in
-    gen [] []
+    gen true [] []
   end;
 
 
@@ -741,18 +776,19 @@
   let
     val (idx, ts) = Type_Infer_Context.prepare ctxt raw_ts;
 
-    fun inf _ (t as (Const (_, T))) tye_idx = (t, T, tye_idx)
-      | inf _ (t as (Free (_, T))) tye_idx = (t, T, tye_idx)
-      | inf _ (t as (Var (_, T))) tye_idx = (t, T, tye_idx)
-      | inf bs (t as (Bound i)) tye_idx =
+    fun inf _ _ (t as (Const (_, T))) tye_idx = (t, T, tye_idx)
+      | inf _ _ (t as (Free (_, T))) tye_idx = (t, T, tye_idx)
+      | inf _ _ (t as (Var (_, T))) tye_idx = (t, T, tye_idx)
+      | inf _ bs (t as (Bound i)) tye_idx =
           (t, snd (nth bs i handle General.Subscript => err_loose i), tye_idx)
-      | inf bs (Abs (x, T, t)) tye_idx =
-          let val (t', U, tye_idx') = inf ((x, T) :: bs) t tye_idx
+      | inf coerce bs (Abs (x, T, t)) tye_idx =
+          let val (t', U, tye_idx') = inf coerce ((x, T) :: bs) t tye_idx
           in (Abs (x, T, t'), T --> U, tye_idx') end
-      | inf bs (t $ u) tye_idx =
+      | inf coerce bs (t $ u) tye_idx =
           let
-            val (t', T, tye_idx') = inf bs t tye_idx;
-            val (u', U, (tye, idx)) = inf bs u tye_idx';
+            val (t', T, tye_idx') = inf coerce bs t tye_idx;
+            val coerce' = update_coerce_arg ctxt coerce t;
+            val (u', U, (tye, idx)) = inf coerce' bs u tye_idx';
             val V = Type_Infer.mk_param idx [];
             val (tu, tye_idx'') = (t' $ u', strong_unify ctxt (U --> V, T) (tye, idx + 1))
               handle NO_UNIFIER (msg, tye') =>
@@ -766,19 +802,23 @@
                           val err' =
                             err ++> "\nLocal coercion insertion on the operator failed:\n";
                           val co = function_of ctxt err' tye T;
-                          val (t'', T'', tye_idx'') = inf bs (co $ t') (tye, idx + 2);
+                          val (t'', T'', tye_idx'') = inf coerce bs (co $ t') (tye, idx + 2);
                         in
                           (t'', strong_unify ctxt (W --> V, T'') tye_idx''
                              handle NO_UNIFIER (msg, _) => error (err' () ^ msg))
                         end;
-                  val err' = err ++> (if t' aconv t'' then ""
-                    else "\nSuccessfully coerced the operand to a function of type:\n" ^
+                  val err' = err ++>
+                    (if t' aconv t'' then ""
+                    else "\nSuccessfully coerced the operator to a function of type:\n" ^
                       Syntax.string_of_typ ctxt
                         (the_single (snd (prep_output ctxt tye' bs [] [W --> V]))) ^ "\n") ^
-                      "\nLocal coercion insertion on the operand failed:\n";
-                  val co = gen_coercion ctxt err' tye' (U, W);
+                    (if coerce' then "\nLocal coercion insertion on the operand failed:\n"
+                    else "\nLocal coercion insertion on the operand disallowed:\n");
                   val (u'', U', tye_idx') =
-                    inf bs (if is_identity co then u else co $ u) (tye', idx');
+                    if coerce' then 
+                      let val co = gen_coercion ctxt err' tye' (U, W);
+                      in inf coerce' bs (if is_identity co then u else co $ u) (tye', idx') end
+                    else (u, U, (tye', idx'));
                 in
                   (t'' $ u'', strong_unify ctxt (U', W) tye_idx'
                     handle NO_UNIFIER (msg, _) => raise COERCION_GEN_ERROR (err' ++> msg))
@@ -786,7 +826,7 @@
           in (tu, V, tye_idx'') end;
 
     fun infer_single t tye_idx =
-      let val (t, _, tye_idx') = inf [] t tye_idx
+      let val (t, _, tye_idx') = inf true [] t tye_idx
       in (t, tye_idx') end;
 
     val (ts', (tye, _)) = (fold_map infer_single ts (Vartab.empty, idx)
@@ -1011,8 +1051,12 @@
     |> Pretty.writeln
   end;
 
+(* theory setup *)
 
-(* theory setup *)
+val parse_coerce_args =
+  Args.$$$ "+" >> K (SOME PERMIT) ||
+  Args.$$$ "-" >> K (SOME FORBID) ||
+  Args.$$$ "0" >> K NONE
 
 val setup =
   Context.theory_map add_term_check #>
@@ -1024,7 +1068,11 @@
     "deletion of coercions" #>
   Attrib.setup @{binding coercion_map}
     (Args.term >> (fn t => Thm.declaration_attribute (K (add_type_map t))))
-    "declaration of new map functions";
+    "declaration of new map functions" #>
+  Attrib.setup @{binding coercion_args}
+    (Args.const false -- Scan.lift (Scan.repeat1 parse_coerce_args) >>
+      (fn spec => Thm.declaration_attribute (K (map_coerce_args (Symtab.update spec)))))
+    "declaration of new constants with coercion-invariant arguments";
 
 
 (* outer syntax commands *)