--- a/src/Pure/Tools/adhoc_overloading.ML Tue Jan 28 13:02:42 2025 +0100
+++ b/src/Pure/Tools/adhoc_overloading.ML Tue Jan 28 14:53:36 2025 +0100
@@ -55,83 +55,89 @@
fun variants_eq ((v1, T1), (v2, T2)) =
Term.aconv_untyped (v1, v2) andalso Type.raw_equiv (T1, T2);
-structure Overload_Data = Generic_Data
+structure Data = Generic_Data
(
type T =
{variants : (term * typ) list Symtab.table,
oconsts : string Termtab.table};
- val empty = {variants = Symtab.empty, oconsts = Termtab.empty};
+ val empty : T = {variants = Symtab.empty, oconsts = Termtab.empty};
fun merge
({variants = vtab1, oconsts = otab1},
{variants = vtab2, oconsts = otab2}) : T =
let
- fun merge_oconsts _ (oconst1, oconst2) =
+ fun join (oconst1, oconst2) =
if oconst1 = oconst2 then oconst1
else err_duplicate_variant oconst1;
in
{variants = Symtab.merge_list variants_eq (vtab1, vtab2),
- oconsts = Termtab.join merge_oconsts (otab1, otab2)}
+ oconsts = Termtab.join (K join) (otab1, otab2)}
end;
);
-fun map_tables f g =
- Overload_Data.map (fn {variants = vtab, oconsts = otab} =>
- {variants = f vtab, oconsts = g otab});
+fun map_data f =
+ Data.map (fn {variants, oconsts} =>
+ let val (variants', oconsts') = f (variants, oconsts)
+ in {variants = variants', oconsts = oconsts'} end);
-val is_overloaded = Symtab.defined o #variants o Overload_Data.get o Context.Proof;
-val get_variants = Symtab.lookup o #variants o Overload_Data.get o Context.Proof;
-val get_overloaded = Termtab.lookup o #oconsts o Overload_Data.get o Context.Proof;
+val no_variants = Symtab.is_empty o #variants o Data.get;
+val has_variants = Symtab.defined o #variants o Data.get;
+val get_variants = Symtab.lookup o #variants o Data.get;
+val get_overloaded = Termtab.lookup o #oconsts o Data.get;
fun generic_add_overloaded oconst context =
- if is_overloaded (Context.proof_of context) oconst then context
- else map_tables (Symtab.update (oconst, [])) I context;
+ if has_variants context oconst then context
+ else (map_data o apfst) (Symtab.update (oconst, [])) context;
(*If the list of variants is empty at the end of "generic_remove_variant", then
-"generic_remove_overloaded" is called implicitly.*)
+ "generic_remove_overloaded" is called implicitly.*)
fun generic_remove_overloaded oconst context =
let
fun remove_oconst_and_variants context oconst =
let
val remove_variants =
- (case get_variants (Context.proof_of context) oconst of
+ (case get_variants context oconst of
NONE => I
| SOME vs => fold (Termtab.remove (op =) o rpair oconst o fst) vs);
- in map_tables (Symtab.delete_safe oconst) remove_variants context end;
+ in
+ context |> map_data (fn (variants, oconsts) =>
+ (Symtab.delete_safe oconst variants, remove_variants oconsts))
+ end;
in
- if is_overloaded (Context.proof_of context) oconst then remove_oconst_and_variants context oconst
+ if has_variants context oconst then remove_oconst_and_variants context oconst
else err_not_overloaded oconst
end;
local
fun generic_variant add oconst t context =
let
- val ctxt = Context.proof_of context;
- val _ = if is_overloaded ctxt oconst then () else err_not_overloaded oconst;
- val T = t |> fastype_of;
+ val _ = if has_variants context oconst then () else err_not_overloaded oconst;
+ val T = fastype_of t;
val t' = Term.map_types (K dummyT) t;
in
if add then
let
val _ =
- (case get_overloaded ctxt t' of
+ (case get_overloaded context t' of
NONE => ()
| SOME oconst' => err_duplicate_variant oconst');
in
- map_tables (Symtab.cons_list (oconst, (t', T))) (Termtab.update (t', oconst)) context
+ context |> map_data (fn (variants, oconsts) =>
+ (Symtab.cons_list (oconst, (t', T)) variants, Termtab.update (t', oconst) oconsts))
end
else
let
val _ =
- if member variants_eq (the (get_variants ctxt oconst)) (t', T) then ()
+ if member variants_eq (the (get_variants context oconst)) (t', T) then ()
else err_not_a_variant oconst;
+ val context' =
+ context |> map_data (fn (variants, oconsts) =>
+ (Symtab.map_entry oconst (remove1 variants_eq (t', T)) variants,
+ Termtab.delete_safe t' oconsts));
in
- map_tables (Symtab.map_entry oconst (remove1 variants_eq (t', T)))
- (Termtab.delete_safe t') context
- |> (fn context =>
- (case get_variants (Context.proof_of context) oconst of
- SOME [] => generic_remove_overloaded oconst context
- | _ => context))
+ (case get_variants context' oconst of
+ SOME [] => generic_remove_overloaded oconst context'
+ | _ => context')
end
end;
in
@@ -142,54 +148,72 @@
(* check / uncheck *)
-fun unifiable_with thy T1 T2 =
+local
+
+fun unifiable_types ctxt (T1, T2) =
let
+ val thy = Proof_Context.theory_of ctxt;
val maxidx1 = Term.maxidx_of_typ T1;
val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
val maxidx2 = Term.maxidx_typ T2' maxidx1;
in can (Sign.typ_unify thy (T1, T2')) (Vartab.empty, maxidx2) end;
fun get_candidates ctxt (c, T) =
- get_variants ctxt c
+ get_variants (Context.Proof ctxt) c
|> Option.map (map_filter (fn (t, T') =>
- if unifiable_with (Proof_Context.theory_of ctxt) T T'
+ if unifiable_types ctxt (T, T')
(*keep the type constraint for the type-inference check phase*)
then SOME (Type.constraint T t)
else NONE));
-fun insert_variants ctxt t (oconst as Const (c, T)) =
- (case get_candidates ctxt (c, T) of
- SOME [] => err_unresolved_overloading ctxt (c, T) t []
+val the_candidates = the oo get_candidates;
+
+fun insert_variants_same ctxt t : term Same.operation =
+ (fn Const const =>
+ (case get_candidates ctxt const of
+ SOME [] => err_unresolved_overloading ctxt const t []
| SOME [variant] => variant
- | _ => oconst)
- | insert_variants _ _ oconst = oconst;
+ | _ => raise Same.SAME)
+ | _ => raise Same.SAME);
fun insert_overloaded ctxt =
let
+ val thy = Proof_Context.theory_of ctxt;
fun proc t =
Term.map_types (K dummyT) t
- |> get_overloaded ctxt
+ |> get_overloaded (Context.Proof ctxt)
|> Option.map (Const o rpair (Term.type_of t));
in
- Pattern.rewrite_term_yoyo (Proof_Context.theory_of ctxt) [] [proc]
+ Pattern.rewrite_term_yoyo thy [] [proc]
end;
+fun overloaded_term_consts ctxt =
+ let
+ val context = Context.Proof ctxt;
+ val overloaded = has_variants context;
+ val add = fn Const (c, T) => if overloaded c then insert (op =) (c, T) else I | _ => I;
+ in fn t => if no_variants context then [] else fold_aterms add t [] end;
+
+in
+
fun check ctxt =
- map (fn t => Term.map_aterms (insert_variants ctxt t) t);
+ if no_variants (Context.Proof ctxt) then I
+ else map (fn t => t |> Term.map_aterms (insert_variants_same ctxt t));
fun uncheck ctxt ts =
- if Config.get ctxt show_variants orelse exists (is_none o try Term.type_of) ts then ts
+ if Config.get ctxt show_variants orelse exists (not o can Term.type_of) ts then ts
else map (insert_overloaded ctxt) ts;
fun reject_unresolved ctxt =
let
- val the_candidates = the o get_candidates ctxt;
fun check_unresolved t =
- (case filter (is_overloaded ctxt o fst) (Term.add_consts t []) of
+ (case overloaded_term_consts ctxt t of
[] => t
- | (cT :: _) => err_unresolved_overloading ctxt cT t (the_candidates cT));
+ | const :: _ => err_unresolved_overloading ctxt const t (the_candidates ctxt const));
in map check_unresolved end;
+end;
+
(* setup *)