merged
authorwenzelm
Wed, 10 Jan 2024 22:25:34 +0100
changeset 79472 27279c76a068
parent 79445 8e3e9e6ca538 (current diff)
parent 79471 593fdddc6d98 (diff)
child 79473 e1b2595d678a
merged
--- a/src/Pure/Concurrent/task_queue.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/Concurrent/task_queue.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -250,7 +250,7 @@
     gs Tasks.empty
   |> Tasks.dest;
 
-fun known_task (Queue {jobs, ...}) task = can (Task_Graph.get_entry jobs) task;
+fun known_task (Queue {jobs, ...}) task = Task_Graph.defined jobs task;
 
 
 (* job status *)
--- a/src/Pure/General/graph.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/General/graph.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -25,6 +25,7 @@
   val keys: 'a T -> key list
   val get_first: (key * ('a * (Keys.T * Keys.T)) -> 'b option) -> 'a T -> 'b option
   val fold: (key * ('a * (Keys.T * Keys.T)) -> 'b -> 'b) -> 'a T -> 'b -> 'b
+  val defined: 'a T -> key -> bool
   val get_entry: 'a T -> key -> key * ('a * (Keys.T * Keys.T))        (*exception UNDEF*)
   val get_node: 'a T -> key -> 'a                                     (*exception UNDEF*)
   val map_node: key -> ('a -> 'a) -> 'a T -> 'a T
@@ -120,6 +121,8 @@
 fun get_first f (Graph tab) = Table.get_first f tab;
 fun fold_graph f (Graph tab) = Table.fold f tab;
 
+fun defined (Graph tab) = Table.defined tab;
+
 fun get_entry (Graph tab) x =
   (case Table.lookup_key tab x of
     SOME entry => entry
--- a/src/Pure/Isar/proof_context.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/Isar/proof_context.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -495,7 +495,7 @@
 
 
 fun cert_typ_mode mode ctxt T =
-  Type.cert_typ_mode mode (tsig_of ctxt) T
+  Type.certify_typ mode (tsig_of ctxt) T
     handle TYPE (msg, _, _) => error msg;
 
 val cert_typ = cert_typ_mode Type.mode_default;
@@ -533,13 +533,10 @@
   else
     let
       val ((d, reports), decl) = Type.check_decl (Context.Proof ctxt) (tsig_of ctxt) (c, pos);
-      fun err () = error ("Bad type name: " ^ quote d ^ Position.here pos);
-      val args =
-        (case decl of
-          Type.LogicalType n => n
-        | Type.Abbreviation (vs, _, _) => if strict then err () else length vs
-        | Type.Nonterminal => if strict then err () else 0);
-    in (Type (d, replicate args dummyT), reports) end;
+      val _ =
+        if strict andalso not (Type.decl_logical decl)
+        then error ("Bad type name: " ^ quote d ^ Position.here pos) else ();
+    in (Type (d, replicate (Type.decl_args decl) dummyT), reports) end;
 
 fun read_type_name flags ctxt text =
   let
@@ -584,7 +581,7 @@
     val _ =
       (case (strict, t) of
         (true, Const (d, _)) =>
-          (ignore (Consts.the_const consts d) handle TYPE (msg, _, _) => err msg)
+          (ignore (Consts.the_const_type consts d) handle TYPE (msg, _, _) => err msg)
       | _ => ());
   in (t, reports) end;
 
@@ -629,8 +626,8 @@
 local
 
 fun certify_consts ctxt =
-  Consts.certify (Context.Proof ctxt) (tsig_of ctxt)
-    (not (abbrev_mode ctxt)) (consts_of ctxt);
+  Consts.certify {normalize = not (abbrev_mode ctxt)}
+    (Context.Proof ctxt) (tsig_of ctxt) (consts_of ctxt);
 
 fun expand_binds ctxt =
   let
@@ -809,17 +806,16 @@
 
 local
 
-fun gen_cert prop ctxt t =
-  t
-  |> expand_abbrevs ctxt
-  |> (fn t' =>
-      #1 (Sign.certify' prop (Context.Proof ctxt) false (consts_of ctxt) (theory_of ctxt) t')
-        handle TYPE (msg, _, _) => error msg | TERM (msg, _) => error msg);
+fun cert_flags flags ctxt t =
+  let val t' = expand_abbrevs ctxt t in
+    #1 (Sign.certify_flags flags (Context.Proof ctxt) (consts_of ctxt) (theory_of ctxt) t')
+      handle TYPE (msg, _, _) => error msg | TERM (msg, _) => error msg
+  end;
 
 in
 
-val cert_term = gen_cert false;
-val cert_prop = gen_cert true;
+val cert_term = cert_flags {prop = false, normalize = false};
+val cert_prop = cert_flags {prop = true, normalize = false};
 
 end;
 
--- a/src/Pure/Isar/proof_display.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/Isar/proof_display.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -94,7 +94,7 @@
       [Pretty.str "default sort:", Pretty.brk 1, prt_sort S];
 
     val tfrees = map (fn v => TFree (v, []));
-    fun pretty_type syn (t, Type.LogicalType n) =
+    fun pretty_type syn (t, Type.Logical_Type n) =
           if syn then NONE
           else SOME (prt_typ (Type (t, tfrees (Name.invent Name.context Name.aT n))))
       | pretty_type syn (t, Type.Abbreviation (vs, U, syn')) =
--- a/src/Pure/ML/ml_antiquotations.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/ML/ml_antiquotations.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -127,7 +127,7 @@
 
 val _ = Theory.setup
  (ML_Antiquotation.inline_embedded \<^binding>\<open>type_name\<close>
-    (type_name "logical type" (fn (c, Type.LogicalType _) => c)) #>
+    (type_name "logical type" (fn (c, Type.Logical_Type _) => c)) #>
   ML_Antiquotation.inline_embedded \<^binding>\<open>type_abbrev\<close>
     (type_name "type abbreviation" (fn (c, Type.Abbreviation _) => c)) #>
   ML_Antiquotation.inline_embedded \<^binding>\<open>nonterminal\<close>
@@ -150,7 +150,7 @@
 
 val _ = Theory.setup
  (ML_Antiquotation.inline_embedded \<^binding>\<open>const_name\<close>
-    (const_name (fn (consts, c) => (Consts.the_const consts c; c))) #>
+    (const_name (fn (consts, c) => (Consts.the_const_type consts c; c))) #>
   ML_Antiquotation.inline_embedded \<^binding>\<open>const_abbrev\<close>
     (const_name (fn (consts, c) => (Consts.the_abbreviation consts c; c))) #>
   ML_Antiquotation.inline_embedded \<^binding>\<open>const_syntax\<close>
--- a/src/Pure/Thy/export_theory.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/Thy/export_theory.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -219,7 +219,7 @@
     val _ =
       export_entities "types" Sign.type_space (Name_Space.dest_table (#types rep_tsig))
         (fn c =>
-          (fn Type.LogicalType n =>
+          (fn Type.Logical_Type n =>
                 SOME (fn () =>
                   encode_type (get_syntax_type thy_ctxt c, Name.invent Name.context Name.aT n, NONE))
             | Type.Abbreviation (args, U, false) =>
--- a/src/Pure/consts.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/consts.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -16,7 +16,8 @@
    {const_space: Name_Space.T,
     constants: (string * (typ * term option)) list,
     constraints: (string * typ) list}
-  val the_const: T -> string -> string * typ                   (*exception TYPE*)
+  val get_const_name: T -> string -> string option
+  val the_const_type: T -> string -> typ                       (*exception TYPE*)
   val the_abbreviation: T -> string -> typ * term              (*exception TYPE*)
   val type_scheme: T -> string -> typ                          (*exception TYPE*)
   val type_arguments: T -> string -> int list list             (*exception TYPE*)
@@ -28,7 +29,7 @@
   val intern: T -> xstring -> string
   val intern_syntax: T -> xstring -> string
   val check_const: Context.generic -> T -> xstring * Position.T list -> term * Position.report list
-  val certify: Context.generic -> Type.tsig -> bool -> T -> term -> term  (*exception TYPE*)
+  val certify: {normalize: bool} -> Context.generic -> Type.tsig -> T -> term -> term  (*exception TYPE*)
   val typargs: T -> string * typ -> typ list
   val instance: T -> string * typ list -> typ
   val dummy_types: T -> term -> term
@@ -49,7 +50,7 @@
 (* datatype T *)
 
 type decl = {T: typ, typargs: int list list};
-type abbrev = {rhs: term, normal_rhs: term, force_expand: bool};
+type abbrev = {rhs: term, normal_rhs: term, internal: bool};
 
 datatype T = Consts of
  {decls: (decl * abbrev option) Name_Space.table,
@@ -108,22 +109,26 @@
 
 (* lookup consts *)
 
-fun the_entry (Consts {decls, ...}) c =
-  (case Name_Space.lookup_key decls c of
+fun get_const_name (Consts {decls, ...}) = Name_Space.lookup_key decls #> Option.map #1;
+
+fun get_entry (Consts {decls, ...}) = Name_Space.lookup decls;
+
+fun the_entry consts c =
+  (case get_entry consts c of
     SOME entry => entry
   | NONE => raise TYPE ("Unknown constant: " ^ quote c, [], []));
 
-fun the_const consts c =
+fun the_const_type consts c =
   (case the_entry consts c of
-    (c', ({T, ...}, NONE)) => (c', T)
+    ({T, ...}, NONE) => T
   | _ => raise TYPE ("Not a logical constant: " ^ quote c, [], []));
 
 fun the_abbreviation consts c =
   (case the_entry consts c of
-    (_, ({T, ...}, SOME {rhs, ...})) => (T, rhs)
+    ({T, ...}, SOME {rhs, ...}) => (T, rhs)
   | _ => raise TYPE ("Not an abbreviated constant: " ^ quote c, [], []));
 
-fun the_decl consts = #1 o #2 o the_entry consts;
+fun the_decl consts = #1 o the_entry consts;
 val type_scheme = #T oo the_decl;
 val type_arguments = #typargs oo the_decl;
 
@@ -164,41 +169,71 @@
 
 (* certify *)
 
-fun certify context tsig do_expand consts =
+fun certify {normalize} context tsig consts =
   let
     fun err msg (c, T) =
       raise TYPE (msg ^ " " ^ quote c ^ " :: " ^
         Syntax.string_of_typ (Syntax.init_pretty context) T, [], []);
-    val certT = Type.cert_typ tsig;
-    fun cert tm =
+
+    fun err_const const = err "Illegal type for constant" const;
+
+    val need_expand =
+      Term.exists_Const (fn (c, _) =>
+        (case get_entry consts c of
+          SOME (_, SOME {internal, ...}) => normalize orelse internal
+        | _ => false));
+
+    val expand_typ = Type.certify_typ Type.mode_default tsig;
+    fun expand_term tm =
       let
         val (head, args) = Term.strip_comb tm;
-        val args' = map cert args;
+        val args' = map expand_term args;
         fun comb head' = Term.list_comb (head', args');
       in
         (case head of
-          Abs (x, T, t) => comb (Abs (x, certT T, cert t))
-        | Const (c, T) =>
+          Const (c, T) =>
             let
-              val T' = certT T;
-              val (_, ({T = U, ...}, abbr)) = the_entry consts c;
+              val T' = expand_typ T;
+              val ({T = U, ...}, abbr) = the_entry consts c;
               fun expand u =
                 Term.betapplys (Envir.expand_atom T' (U, u) handle TYPE _ =>
                   err "Illegal type for abbreviation" (c, T), args');
             in
-              if not (Type.raw_instance (T', U)) then
-                err "Illegal type for constant" (c, T)
+              if not (Type.raw_instance (T', U)) then err_const (c, T)
               else
                 (case abbr of
-                  SOME {rhs, normal_rhs, force_expand} =>
-                    if do_expand then expand normal_rhs
-                    else if force_expand then expand rhs
+                  SOME {rhs, normal_rhs, internal} =>
+                    if normalize then expand normal_rhs
+                    else if internal then expand rhs
                     else comb head
                 | _ => comb head)
             end
+        | Abs (x, T, t) => comb (Abs (x, expand_typ T, expand_term t))
+        | Free (x, T) => comb (Free (x, expand_typ T))
+        | Var (xi, T) => comb (Var (xi, expand_typ T))
         | _ => comb head)
       end;
-  in cert end;
+
+    val typ = Type.certify_typ_same Type.mode_default tsig;
+    fun term (Const (c, T)) =
+          let
+            val (T', same) = Same.commit_id typ T;
+            val U = type_scheme consts c;
+          in
+            if not (Type.raw_instance (T', U)) then err_const (c, T)
+            else if same then raise Same.SAME else Const (c, T')
+          end
+      | term (Free (x, T)) = Free (x, typ T)
+      | term (Var (xi, T)) = Var (xi, typ T)
+      | term (Bound _) = raise Same.SAME
+      | term (Abs (x, T, t)) =
+          (Abs (x, typ T, Same.commit term t)
+            handle Same.SAME => Abs (x, T, term t))
+      | term (t $ u) =
+          (term t $ Same.commit term u
+            handle Same.SAME => t $ term u);
+
+  in fn tm => if need_expand tm then expand_term tm else Same.commit term tm end;
 
 
 (* typargs -- view actual const type as instance of declaration *)
@@ -293,17 +328,14 @@
 
 fun abbreviate context tsig mode (b, raw_rhs) consts =
   let
-    val cert_term = certify context tsig false consts;
-    val expand_term = certify context tsig true consts;
-    val force_expand = mode = Print_Mode.internal;
+    val cert_term = certify {normalize = false} context tsig consts;
+    val expand_term = certify {normalize = true} context tsig consts;
+    val internal = mode = Print_Mode.internal;
 
     val _ = Term.exists_subterm Term.is_Var raw_rhs andalso
       error ("Illegal schematic variables on rhs of abbreviation " ^ Binding.print b);
 
-    val rhs = raw_rhs
-      |> Term.map_types (Type.cert_typ tsig)
-      |> cert_term
-      |> Term.close_schematic_term;
+    val rhs = raw_rhs |> cert_term |> Term.close_schematic_term;
     val normal_rhs = expand_term rhs;
     val T = Term.fastype_of rhs;
     val lhs = Const (Name_Space.full_name (Name_Space.naming_of context) b, T);
@@ -311,7 +343,7 @@
     consts |> map_consts (fn (decls, constraints, rev_abbrevs) =>
       let
         val decl = {T = T, typargs = typargs_of T};
-        val abbr = {rhs = rhs, normal_rhs = normal_rhs, force_expand = force_expand};
+        val abbr = {rhs = rhs, normal_rhs = normal_rhs, internal = internal};
         val _ = Binding.check b;
         val (_, decls') = decls
           |> Name_Space.define context true (b, (decl, SOME abbr));
--- a/src/Pure/sign.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/sign.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -64,8 +64,9 @@
   val certify_sort: theory -> sort -> sort
   val certify_typ: theory -> typ -> typ
   val certify_typ_mode: Type.mode -> theory -> typ -> typ
-  val certify': bool -> Context.generic -> bool -> Consts.T -> theory -> term -> term * typ * int
-  val certify_term: theory -> term -> term * typ * int
+  val certify_flags: {prop: bool, normalize: bool} -> Context.generic -> Consts.T -> theory ->
+    term -> term * typ
+  val certify_term: theory -> term -> term * typ
   val cert_term: theory -> term -> term
   val cert_prop: theory -> term -> term
   val no_frees: Proof.context -> term -> term
@@ -206,7 +207,7 @@
 
 val consts_of = #consts o rep_sg;
 val the_const_constraint = Consts.the_constraint o consts_of;
-val the_const_type = #2 oo (Consts.the_const o consts_of);
+val the_const_type = Consts.the_const_type o consts_of;
 val const_type = try o the_const_type;
 val const_monomorphic = Consts.is_monomorphic o consts_of;
 val const_typargs = Consts.typargs o consts_of;
@@ -259,8 +260,8 @@
 
 val certify_class = Type.cert_class o tsig_of;
 val certify_sort = Type.cert_sort o tsig_of;
-val certify_typ = Type.cert_typ o tsig_of;
-fun certify_typ_mode mode = Type.cert_typ_mode mode o tsig_of;
+fun certify_typ_mode mode = Type.certify_typ mode o tsig_of;
+val certify_typ = certify_typ_mode Type.mode_default;
 
 
 (* certify term/prop *)
@@ -304,21 +305,32 @@
 
 in
 
-fun certify' prop context do_expand consts thy tm =
+fun certify_flags {prop, normalize} context consts thy tm =
   let
-    val _ = check_vars tm;
-    val tm' = Term.map_types (certify_typ thy) tm;
-    val T = type_check context tm';
-    val _ = if prop andalso T <> propT then err "Term not of type prop" else ();
-    val tm'' = tm'
-      |> Consts.certify context (tsig_of thy) do_expand consts
-      |> Soft_Type_System.global_purge thy;
-  in (if tm = tm'' then tm else tm'', T, Term.maxidx_of_term tm'') end;
+    val tsig = tsig_of thy;
+    fun check_term t =
+      let
+        val _ = check_vars t;
+        val t' = Type.certify_types Type.mode_default tsig t;
+        val T = type_check context t';
+        val t'' = Consts.certify {normalize = normalize} context tsig consts t';
+      in if prop andalso T <> propT then err "Term not of type prop" else (t'', T) end;
 
-fun certify_term thy = certify' false (Context.Theory thy) true (consts_of thy) thy;
-fun cert_term_abbrev thy = #1 o certify' false (Context.Theory thy) false (consts_of thy) thy;
+    val (tm1, ty1) = check_term tm;
+    val tm' = Soft_Type_System.global_purge thy tm1;
+    val (tm2, ty2) = if tm1 = tm' then (tm1, ty1) else check_term tm';
+  in if tm = tm2 then (tm, ty2) else (tm2, ty2) end;
+
+fun certify_term thy =
+  certify_flags {prop = false, normalize = true} (Context.Theory thy) (consts_of thy) thy;
+
+fun cert_term_abbrev thy =
+  #1 o certify_flags {prop = false, normalize = false} (Context.Theory thy) (consts_of thy) thy;
+
 val cert_term = #1 oo certify_term;
-fun cert_prop thy = #1 o certify' true (Context.Theory thy) true (consts_of thy) thy;
+
+fun cert_prop thy =
+  #1 o certify_flags {prop = true, normalize = true} (Context.Theory thy) (consts_of thy) thy;
 
 end;
 
--- a/src/Pure/sorts.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/sorts.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -30,6 +30,8 @@
   val arities_of: algebra -> (class * sort list) list Symtab.table
   val all_classes: algebra -> class list
   val super_classes: algebra -> class -> class list
+  val cert_class: algebra -> class -> class
+  val cert_sort: algebra -> sort -> sort
   val class_less: algebra -> class * class -> bool
   val class_le: algebra -> class * class -> bool
   val sort_eq: algebra -> sort * sort -> bool
@@ -135,6 +137,12 @@
 
 val super_classes = Graph.immediate_succs o classes_of;
 
+fun cert_class (Algebra {classes, ...}) c =
+  if Graph.defined classes c then c
+  else raise TYPE ("Undeclared class: " ^ quote c, [], []);
+
+fun cert_sort algebra S = (List.app (ignore o cert_class algebra) S; S);
+
 
 (* class relations *)
 
--- a/src/Pure/term.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/term.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -480,8 +480,8 @@
 
 fun map_types_same f =
   let
-    fun term (Const (a, T)) = Const (a, f T)
-      | term (Free (a, T)) = Free (a, f T)
+    fun term (Const (c, T)) = Const (c, f T)
+      | term (Free (x, T)) = Free (x, f T)
       | term (Var (xi, T)) = Var (xi, f T)
       | term (Bound _) = raise Same.SAME
       | term (Abs (x, T, t)) =
--- a/src/Pure/term_sharing.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/term_sharing.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -23,7 +23,7 @@
 
     val class = perhaps (try (#1 o Graph.get_entry (Sorts.classes_of algebra)));
     val tycon = perhaps (Option.map #1 o Name_Space.lookup_key types);
-    val const = perhaps (try (#1 o Consts.the_const (Sign.consts_of thy)));
+    val const = perhaps (Consts.get_const_name (Sign.consts_of thy));
 
     val typs = Unsynchronized.ref (Typtab.empty: Typtab.set);
     val terms = Unsynchronized.ref (Syntax_Termtab.empty: Syntax_Termtab.set);
--- a/src/Pure/thm.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/thm.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -269,7 +269,8 @@
 
 fun global_cterm_of thy tm =
   let
-    val (t, T, maxidx) = Sign.certify_term thy tm;
+    val (t, T) = Sign.certify_term thy tm;
+    val maxidx = Term.maxidx_of_term t;
     val sorts = Sorts.insert_term t [];
   in Cterm {cert = Context.Certificate thy, t = t, T = T, maxidx = maxidx, sorts = sorts} end;
 
--- a/src/Pure/type.ML	Wed Jan 10 15:30:13 2024 +0100
+++ b/src/Pure/type.ML	Wed Jan 10 22:25:34 2024 +0100
@@ -15,9 +15,11 @@
   val appl_error: Proof.context -> term -> typ -> term -> typ -> string
   (*type signatures and certified types*)
   datatype decl =
-    LogicalType of int |
+    Logical_Type of int |
     Abbreviation of string list * typ * bool |
     Nonterminal
+  val decl_args: decl -> int
+  val decl_logical: decl -> bool
   type tsig
   val eq_tsig: tsig * tsig -> bool
   val rep_tsig: tsig ->
@@ -53,8 +55,9 @@
   val check_decl: Context.generic -> tsig ->
     xstring * Position.T -> (string * Position.report list) * decl
   val the_decl: tsig -> string * Position.T -> decl
-  val cert_typ_mode: mode -> tsig -> typ -> typ
-  val cert_typ: tsig -> typ -> typ
+  val certify_typ_same: mode -> tsig -> typ Same.operation
+  val certify_typ: mode -> tsig -> typ -> typ
+  val certify_types: mode -> tsig -> term -> term
   val arity_number: tsig -> string -> int
   val arity_sorts: Context.generic -> tsig -> string -> sort -> sort list
 
@@ -152,10 +155,17 @@
 (* type declarations *)
 
 datatype decl =
-  LogicalType of int |
+  Logical_Type of int |
   Abbreviation of string list * typ * bool |
   Nonterminal;
 
+fun decl_args (Logical_Type n) = n
+  | decl_args (Abbreviation (vs, _, _)) = length vs
+  | decl_args Nonterminal = 0;
+
+fun decl_logical (Logical_Type _) = true
+  | decl_logical _ = false;
+
 
 (* type tsig *)
 
@@ -187,7 +197,7 @@
 fun build_tsig (classes, default, types) =
   let
     val log_types =
-      Name_Space.fold_table (fn (c, LogicalType n) => cons (c, n) | _ => I) types []
+      Name_Space.fold_table (fn (c, Logical_Type n) => cons (c, n) | _ => I) types []
       |> Library.sort (int_ord o apply2 snd) |> map fst;
   in make_tsig (classes, default, types, log_types) end;
 
@@ -211,11 +221,8 @@
 fun of_sort (TSig {classes, ...}) = Sorts.of_sort (#2 classes);
 fun inter_sort (TSig {classes, ...}) = Sorts.inter_sort (#2 classes);
 
-fun cert_class (TSig {classes = (_, algebra), ...}) c =
-  if can (Graph.get_entry (Sorts.classes_of algebra)) c then c
-  else raise TYPE ("Undeclared class: " ^ quote c, [], []);
-
-val cert_sort = map o cert_class;
+fun cert_class (TSig {classes, ...}) = Sorts.cert_class (#2 classes);
+fun cert_sort (TSig {classes, ...}) = Sorts.cert_sort (#2 classes);
 
 fun minimize_sort (TSig {classes, ...}) = Sorts.minimize_sort (#2 classes);
 
@@ -263,59 +270,46 @@
   | SOME decl => decl);
 
 
-(* certified types *)
+(* certify types *)
 
 fun bad_nargs t = "Bad number of arguments for type constructor: " ^ quote t;
 
-local
-
-fun inst_typ env (Type (c, Ts)) = Type (c, map (inst_typ env) Ts)
-  | inst_typ env (T as TFree (x, _)) = the_default T (AList.lookup (op =) env x)
-  | inst_typ _ T = T;
-
-in
-
-fun cert_typ_mode (Mode {normalize, logical}) tsig ty =
+fun certify_typ_same (Mode {normalize, logical}) tsig =
   let
-    fun err msg = raise TYPE (msg, [ty], []);
-
-    val check_logical =
-      if logical then fn c => err ("Illegal occurrence of syntactic type: " ^ quote c)
-      else fn _ => ();
+    fun err T msg = raise TYPE (msg, [T], []);
+    fun err_syntactic T c = err T ("Illegal occurrence of syntactic type: " ^ quote c);
 
-    fun cert (T as Type (c, Ts)) =
-          let
-            val Ts' = map cert Ts;
-            fun nargs n = if length Ts <> n then err (bad_nargs c) else ();
-          in
-            (case the_decl tsig (c, Position.none) of
-              LogicalType n => (nargs n; Type (c, Ts'))
-            | Abbreviation (vs, U, syn) =>
-               (nargs (length vs);
-                if syn then check_logical c else ();
-                if normalize then inst_typ (vs ~~ Ts') U
-                else Type (c, Ts'))
-            | Nonterminal => (nargs 0; check_logical c; T))
+    fun sort S = (cert_sort tsig S; raise Same.SAME);
+    fun typ (T as Type (c, args)) =
+          let val decl = the_decl tsig (c, Position.none) in
+            if length args <> decl_args decl then err T (bad_nargs c)
+            else
+              (case decl of
+                Logical_Type _ => Type (c, Same.map typ args)
+              | Abbreviation (vs, U, syntactic) =>
+                  if syntactic andalso logical then err_syntactic T c
+                  else if normalize then inst_typ vs args U
+                  else Type (c, Same.map typ args)
+              | Nonterminal => if logical then err_syntactic T c else raise Same.SAME)
           end
-      | cert (TFree (x, S)) = TFree (x, cert_sort tsig S)
-      | cert (TVar (xi as (_, i), S)) =
-          if i < 0 then
-            err ("Malformed type variable: " ^ quote (Term.string_of_vname xi))
-          else TVar (xi, cert_sort tsig S);
+      | typ (TFree (_, S)) = sort S
+      | typ (T as TVar ((x, i), S)) =
+          if i < 0 then err T ("Malformed type variable: " ^ quote (Term.string_of_vname (x, i)))
+          else sort S
+    and inst_typ vs args =
+      Term_Subst.instantiateT_frees
+        (TFrees.build (fold2 (fn v => fn T => TFrees.add ((v, []), Same.commit typ T)) vs args));
+  in typ end;
 
-    val ty' = cert ty;
-  in if ty = ty' then ty else ty' end;  (*avoid copying of already normal type*)
-
-val cert_typ = cert_typ_mode mode_default;
-
-end;
+val certify_typ = Same.commit oo certify_typ_same;
+val certify_types = Term.map_types oo certify_typ_same;
 
 
 (* type arities *)
 
 fun arity_number tsig a =
   (case lookup_type tsig a of
-    SOME (LogicalType n) => n
+    SOME (Logical_Type n) => n
   | _ => error (undecl_type a));
 
 fun arity_sorts _ tsig a [] = replicate (arity_number tsig a) []
@@ -337,9 +331,11 @@
 (* no_tvars *)
 
 fun no_tvars T =
-  (case Term.add_tvarsT T [] of [] => T
-  | vs => raise TYPE ("Illegal schematic type variable(s): " ^
-      commas_quote (map (Term.string_of_vname o #1) (rev vs)), [T], []));
+  (case Term.add_tvarsT T [] of
+    [] => T
+  | vs =>
+      raise TYPE ("Illegal schematic type variable(s): " ^
+        commas_quote (map (Term.string_of_vname o #1) (rev vs)), [T], []));
 
 
 (* varify_global *)
@@ -622,8 +618,7 @@
 fun add_class context (c, cs) tsig =
   tsig |> map_tsig (fn ((space, classes), default, types) =>
     let
-      val cs' = map (cert_class tsig) cs
-        handle TYPE (msg, _, _) => error msg;
+      val cs' = cert_sort tsig cs handle TYPE (msg, _, _) => error msg;
       val _ = Binding.check c;
       val (c', space') = space |> Name_Space.declare context true c;
       val classes' = classes |> Sorts.add_class context (c', cs');
@@ -639,7 +634,7 @@
   let
     val _ =
       (case lookup_type tsig t of
-        SOME (LogicalType n) => if length Ss <> n then error (bad_nargs t) else ()
+        SOME (Logical_Type n) => if length Ss <> n then error (bad_nargs t) else ()
       | SOME _ => error ("Logical type constructor expected: " ^ quote t)
       | NONE => error (undecl_type t));
     val (Ss', S') = (map (cert_sort tsig) Ss, cert_sort tsig S)
@@ -690,13 +685,13 @@
 
 fun add_type context (c, n) =
   if n < 0 then error ("Bad type constructor declaration " ^ Binding.print c)
-  else map_types (new_decl context (c, LogicalType n));
+  else map_types (new_decl context (c, Logical_Type n));
 
 fun add_abbrev context (a, vs, rhs) tsig = tsig |> map_types (fn types =>
   let
     fun err msg =
       cat_error msg ("The error(s) above occurred in type abbreviation " ^ Binding.print a);
-    val rhs' = Term.strip_sortsT (no_tvars (cert_typ_mode mode_syntax tsig rhs))
+    val rhs' = Term.strip_sortsT (no_tvars (certify_typ mode_syntax tsig rhs))
       handle TYPE (msg, _, _) => err msg;
     val _ =
       (case duplicates (op =) vs of