--- a/src/Pure/Tools/adhoc_overloading.ML Tue Jan 28 11:20:53 2025 +0100
+++ b/src/Pure/Tools/adhoc_overloading.ML Tue Jan 28 11:29:42 2025 +0100
@@ -60,7 +60,7 @@
type T =
{variants : (term * typ) list Symtab.table,
oconsts : string Termtab.table};
- val empty = {variants = Symtab.empty, oconsts = Termtab.empty};
+ val empty : T = {variants = Symtab.empty, oconsts = Termtab.empty};
fun merge
({variants = vtab1, oconsts = otab1},
@@ -75,9 +75,10 @@
end;
);
-fun map_tables f g =
- Data.map (fn {variants = vtab, oconsts = otab} =>
- {variants = f vtab, oconsts = g otab});
+fun map_data f =
+ Data.map (fn {variants, oconsts} =>
+ let val (variants', oconsts') = f (variants, oconsts)
+ in {variants = variants', oconsts = oconsts'} end);
val is_overloaded = Symtab.defined o #variants o Data.get;
val get_variants = Symtab.lookup o #variants o Data.get;
@@ -85,7 +86,7 @@
fun generic_add_overloaded oconst context =
if is_overloaded context oconst then context
- else map_tables (Symtab.update (oconst, [])) I context;
+ else (map_data o apfst) (Symtab.update (oconst, [])) context;
(*If the list of variants is empty at the end of "generic_remove_variant", then
"generic_remove_overloaded" is called implicitly.*)
@@ -97,7 +98,10 @@
(case get_variants context oconst of
NONE => I
| SOME vs => fold (Termtab.remove (op =) o rpair oconst o fst) vs);
- in map_tables (Symtab.delete_safe oconst) remove_variants context end;
+ in
+ context |> map_data (fn (variants, oconsts) =>
+ (Symtab.delete_safe oconst variants, remove_variants oconsts))
+ end;
in
if is_overloaded context oconst then remove_oconst_and_variants context oconst
else err_not_overloaded oconst
@@ -117,20 +121,22 @@
NONE => ()
| SOME oconst' => err_duplicate_variant oconst');
in
- map_tables (Symtab.cons_list (oconst, (t', T))) (Termtab.update (t', oconst)) context
+ context |> map_data (fn (variants, oconsts) =>
+ (Symtab.cons_list (oconst, (t', T)) variants, Termtab.update (t', oconst) oconsts))
end
else
let
val _ =
if member variants_eq (the (get_variants context oconst)) (t', T) then ()
else err_not_a_variant oconst;
+ val context' =
+ context |> map_data (fn (variants, oconsts) =>
+ (Symtab.map_entry oconst (remove1 variants_eq (t', T)) variants,
+ Termtab.delete_safe t' oconsts));
in
- map_tables (Symtab.map_entry oconst (remove1 variants_eq (t', T)))
- (Termtab.delete_safe t') context
- |> (fn context =>
- (case get_variants context oconst of
- SOME [] => generic_remove_overloaded oconst context
- | _ => context))
+ (case get_variants context' oconst of
+ SOME [] => generic_remove_overloaded oconst context'
+ | _ => context')
end
end;
in