tuned certify_typ;
authorwenzelm
Mon, 21 Jun 2004 16:40:55 +0200
changeset 14989 5a5d076a9863
parent 14988 973ced82812d
child 14990 582b655da757
tuned certify_typ;
src/Pure/type.ML
--- a/src/Pure/type.ML	Mon Jun 21 16:40:44 2004 +0200
+++ b/src/Pure/type.ML	Mon Jun 21 16:40:55 2004 +0200
@@ -11,7 +11,7 @@
   (*type signatures and certified types*)
   datatype decl =
     LogicalType of int |
-    Abbreviation of string list * typ |
+    Abbreviation of string list * typ * bool |
     Nonterminal
   type tsig
   val rep_tsig: tsig ->
@@ -32,9 +32,9 @@
   val cert_class: tsig -> class -> class
   val cert_sort: tsig -> sort -> sort
   val witness_sorts: tsig -> sort list -> sort list -> (typ * sort) list
-  val cert_typ: tsig -> typ -> typ
-  val cert_typ_syntax: tsig -> typ -> typ
-  val cert_typ_raw: tsig -> typ -> typ
+  val cert_typ: tsig -> typ -> typ * int
+  val cert_typ_syntax: tsig -> typ -> typ * int
+  val cert_typ_raw: tsig -> typ -> typ * int
 
   (*special treatment of type vars*)
   val strip_sorts: typ -> typ
@@ -75,7 +75,7 @@
 
 datatype decl =
   LogicalType of int |
-  Abbreviation of string list * typ |
+  Abbreviation of string list * typ * bool |
   Nonterminal;
 
 fun str_of_decl (LogicalType _) = "logical type constructor"
@@ -131,13 +131,9 @@
 fun eq_sort (TSig {classes, ...}) = Sorts.sort_eq classes;
 fun subsort (TSig {classes, ...}) = Sorts.sort_le classes;
 fun of_sort (TSig {classes, arities, ...}) = Sorts.of_sort (classes, arities);
-fun norm_sort (TSig {classes, ...}) = Sorts.norm_sort classes;
 
-fun cert_class (TSig {classes, ...}) c =
-  if can (Graph.get_node classes) c then c
-  else raise TYPE ("Undeclared class: " ^ quote c, [], []);
-
-fun cert_sort tsig = norm_sort tsig o map (cert_class tsig);
+fun cert_class (TSig {classes, ...}) = Sorts.certify_class classes;
+fun cert_sort (TSig {classes, ...}) = Sorts.certify_sort classes;
 
 fun witness_sorts (tsig as TSig {classes, arities, log_types, ...}) =
   Sorts.witness_sorts (classes, arities) log_types;
@@ -150,61 +146,46 @@
 
 local
 
-fun inst_typ tye =
-  let
-    fun inst (var as (v, _)) =
-      (case assoc_string_int (tye, v) of
-        Some U => inst_typ tye U
-      | None => TVar var);
-  in map_type_tvar inst end;
-
-fun norm_typ (tsig as TSig {types, ...}) ty =
-  let
-    val idx = Term.maxidx_of_typ ty + 1;
-
-    fun norm (Type (a, Ts)) =
-          (case Symtab.lookup (types, a) of
-            Some (Abbreviation (vs, U), _) =>
-              norm (inst_typ (map (rpair idx) vs ~~ Ts) (incr_tvar idx U))
-          | _ => Type (a, map norm Ts))
-      | norm (TFree (x, S)) = TFree (x, norm_sort tsig S)
-      | norm (TVar (xi, S)) = TVar (xi, norm_sort tsig S);
-
-    val ty' = norm ty;
-  in if ty = ty' then ty else ty' end;  (*avoid copying of already normal type*)
+fun inst_typ env = Term.map_type_tvar (fn (var as (v, _)) =>
+  (case Library.assoc_string_int (env, v) of
+    Some U => inst_typ env U
+  | None => TVar var));
 
 fun certify_typ normalize syntax tsig ty =
   let
-    val TSig {types, ...} = tsig;
+    val TSig {classes, types, ...} = tsig;
     fun err msg = raise TYPE (msg, [ty], []);
 
-    fun check_sort S = (map (cert_class tsig) S; ());
+    val maxidx = Term.maxidx_of_typ ty;
+    val idx = maxidx + 1;
 
-    fun check_typ (Type (c, Ts)) =
-          let fun nargs n = if length Ts <> n then err (bad_nargs c) else () in
+    val check_syntax =
+      if syntax then K ()
+      else fn c => err ("Illegal occurrence of syntactic type: " ^ quote c);
+
+    fun cert (T as Type (c, Ts)) =
+          let
+            val Ts' = map cert Ts;
+            fun nargs n = if length Ts <> n then err (bad_nargs c) else ();
+          in
             (case Symtab.lookup (types, c) of
-              Some (LogicalType n, _) => nargs n
-            | Some (Abbreviation (vs, _), _) => nargs (length vs)
-            | Some (Nonterminal, _) => nargs 0
-            | None => err (undecl_type c));
-            seq check_typ Ts
+              Some (LogicalType n, _) => (nargs n; Type (c, Ts'))
+            | Some (Abbreviation (vs, U, syn), _) => (nargs (length vs);
+                if syn then check_syntax c else ();
+                if normalize then
+                  inst_typ (map (rpair idx) vs ~~ Ts') (Term.incr_tvar idx U)
+                else Type (c, Ts'))
+            | Some (Nonterminal, _) => (nargs 0; check_syntax c; T)
+            | None => err (undecl_type c))
           end
-    | check_typ (TFree (_, S)) = check_sort S
-    | check_typ (TVar ((x, i), S)) =
-        if i < 0 then err ("Malformed type variable: " ^ quote (Term.string_of_vname (x, i)))
-        else check_sort S;
+      | cert (TFree (x, S)) = TFree (x, Sorts.certify_sort classes S)
+      | cert (TVar (xi as (_, i), S)) =
+          if i < 0 then err ("Malformed type variable: " ^ quote (Term.string_of_vname xi))
+          else TVar (xi, Sorts.certify_sort classes S);
 
-    fun no_syntax (Type (c, Ts)) =
-          (case Symtab.lookup (types, c) of
-            Some (Nonterminal, _) =>
-              err ("Illegal occurrence of syntactic type: " ^ quote c)
-          | _ => seq no_syntax Ts)
-      | no_syntax _ = ();
-
-    val _ = check_typ ty;
-    val ty' = if normalize orelse not syntax then norm_typ tsig ty else ty;
-    val _ = if not syntax then no_syntax ty' else ();
-  in ty' end;
+    val ty' = cert ty;
+    val ty' = if ty = ty' then ty else ty';  (*avoid copying of already normal type*)
+  in (ty', maxidx) end;  
 
 in
 
@@ -562,11 +543,16 @@
 fun change_types f = change_tsig (fn (classes, default, types, arities) =>
   (classes, default, f types, arities));
 
+fun syntactic types (Type (c, Ts)) =
+      (case Symtab.lookup (types, c) of Some (Nonterminal, _) => true | _ => false)
+        orelse exists (syntactic types) Ts
+  | syntactic _ _ = false;
+
 fun add_abbr (a, vs, rhs) tsig = tsig |> change_types (fn types =>
   let
     fun err msg =
       error (msg ^ "\nThe error(s) above occurred in type abbreviation: " ^ quote a);
-    val rhs' = strip_sorts (varifyT (no_tvars (cert_typ_syntax tsig rhs)))
+    val rhs' = strip_sorts (varifyT (no_tvars (#1 (cert_typ_syntax tsig rhs))))
       handle TYPE (msg, _, _) => err msg;
   in
     (case duplicates vs of
@@ -575,7 +561,7 @@
     (case gen_rems (op =) (map (#1 o #1) (typ_tvars rhs'), vs) of
       [] => []
     | extras => err ("Extra variables on rhs: " ^ commas_quote extras));
-    types |> new_decl (a, Abbreviation (vs, rhs'))
+    types |> new_decl (a, Abbreviation (vs, rhs', syntactic types rhs'))
   end);
 
 in