merged
authorwenzelm
Tue, 28 Jan 2025 14:53:36 +0100
changeset 82009 e04cdf27fdae
parent 81999 513f8fa74c82 (current diff)
parent 82008 7301923ad1e9 (diff)
child 82010 dfde9a8296f5
merged
--- a/src/Pure/Tools/adhoc_overloading.ML	Tue Jan 28 13:02:42 2025 +0100
+++ b/src/Pure/Tools/adhoc_overloading.ML	Tue Jan 28 14:53:36 2025 +0100
@@ -55,83 +55,89 @@
 fun variants_eq ((v1, T1), (v2, T2)) =
   Term.aconv_untyped (v1, v2) andalso Type.raw_equiv (T1, T2);
 
-structure Overload_Data = Generic_Data
+structure Data = Generic_Data
 (
   type T =
     {variants : (term * typ) list Symtab.table,
      oconsts : string Termtab.table};
-  val empty = {variants = Symtab.empty, oconsts = Termtab.empty};
+  val empty : T = {variants = Symtab.empty, oconsts = Termtab.empty};
 
   fun merge
     ({variants = vtab1, oconsts = otab1},
      {variants = vtab2, oconsts = otab2}) : T =
     let
-      fun merge_oconsts _ (oconst1, oconst2) =
+      fun join (oconst1, oconst2) =
         if oconst1 = oconst2 then oconst1
         else err_duplicate_variant oconst1;
     in
       {variants = Symtab.merge_list variants_eq (vtab1, vtab2),
-       oconsts = Termtab.join merge_oconsts (otab1, otab2)}
+       oconsts = Termtab.join (K join) (otab1, otab2)}
     end;
 );
 
-fun map_tables f g =
-  Overload_Data.map (fn {variants = vtab, oconsts = otab} =>
-    {variants = f vtab, oconsts = g otab});
+fun map_data f =
+  Data.map (fn {variants, oconsts} =>
+    let val (variants', oconsts') = f (variants, oconsts)
+    in {variants = variants', oconsts = oconsts'} end);
 
-val is_overloaded = Symtab.defined o #variants o Overload_Data.get o Context.Proof;
-val get_variants = Symtab.lookup o #variants o Overload_Data.get o Context.Proof;
-val get_overloaded = Termtab.lookup o #oconsts o Overload_Data.get o Context.Proof;
+val no_variants = Symtab.is_empty o #variants o Data.get;
+val has_variants = Symtab.defined o #variants o Data.get;
+val get_variants = Symtab.lookup o #variants o Data.get;
+val get_overloaded = Termtab.lookup o #oconsts o Data.get;
 
 fun generic_add_overloaded oconst context =
-  if is_overloaded (Context.proof_of context) oconst then context
-  else map_tables (Symtab.update (oconst, [])) I context;
+  if has_variants context oconst then context
+  else (map_data o apfst) (Symtab.update (oconst, [])) context;
 
 (*If the list of variants is empty at the end of "generic_remove_variant", then
-"generic_remove_overloaded" is called implicitly.*)
+  "generic_remove_overloaded" is called implicitly.*)
 fun generic_remove_overloaded oconst context =
   let
     fun remove_oconst_and_variants context oconst =
       let
         val remove_variants =
-          (case get_variants (Context.proof_of context) oconst of
+          (case get_variants context oconst of
             NONE => I
           | SOME vs => fold (Termtab.remove (op =) o rpair oconst o fst) vs);
-      in map_tables (Symtab.delete_safe oconst) remove_variants context end;
+      in
+        context |> map_data (fn (variants, oconsts) =>
+          (Symtab.delete_safe oconst variants, remove_variants oconsts))
+      end;
   in
-    if is_overloaded (Context.proof_of context) oconst then remove_oconst_and_variants context oconst
+    if has_variants context oconst then remove_oconst_and_variants context oconst
     else err_not_overloaded oconst
   end;
 
 local
   fun generic_variant add oconst t context =
     let
-      val ctxt = Context.proof_of context;
-      val _ = if is_overloaded ctxt oconst then () else err_not_overloaded oconst;
-      val T = t |> fastype_of;
+      val _ = if has_variants context oconst then () else err_not_overloaded oconst;
+      val T = fastype_of t;
       val t' = Term.map_types (K dummyT) t;
     in
       if add then
         let
           val _ =
-            (case get_overloaded ctxt t' of
+            (case get_overloaded context t' of
               NONE => ()
             | SOME oconst' => err_duplicate_variant oconst');
         in
-          map_tables (Symtab.cons_list (oconst, (t', T))) (Termtab.update (t', oconst)) context
+          context |> map_data (fn (variants, oconsts) =>
+            (Symtab.cons_list (oconst, (t', T)) variants, Termtab.update (t', oconst) oconsts))
         end
       else
         let
           val _ =
-            if member variants_eq (the (get_variants ctxt oconst)) (t', T) then ()
+            if member variants_eq (the (get_variants context oconst)) (t', T) then ()
             else err_not_a_variant oconst;
+          val context' =
+            context |> map_data (fn (variants, oconsts) =>
+              (Symtab.map_entry oconst (remove1 variants_eq (t', T)) variants,
+               Termtab.delete_safe t' oconsts));
         in
-          map_tables (Symtab.map_entry oconst (remove1 variants_eq (t', T)))
-            (Termtab.delete_safe t') context
-          |> (fn context =>
-            (case get_variants (Context.proof_of context) oconst of
-              SOME [] => generic_remove_overloaded oconst context
-            | _ => context))
+          (case get_variants context' oconst of
+            SOME [] => generic_remove_overloaded oconst context'
+          | _ => context')
         end
     end;
 in
@@ -142,54 +148,72 @@
 
 (* check / uncheck *)
 
-fun unifiable_with thy T1 T2 =
+local
+
+fun unifiable_types ctxt (T1, T2) =
   let
+    val thy = Proof_Context.theory_of ctxt;
     val maxidx1 = Term.maxidx_of_typ T1;
     val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
     val maxidx2 = Term.maxidx_typ T2' maxidx1;
   in can (Sign.typ_unify thy (T1, T2')) (Vartab.empty, maxidx2) end;
 
 fun get_candidates ctxt (c, T) =
-  get_variants ctxt c
+  get_variants (Context.Proof ctxt) c
   |> Option.map (map_filter (fn (t, T') =>
-    if unifiable_with (Proof_Context.theory_of ctxt) T T'
+    if unifiable_types ctxt (T, T')
     (*keep the type constraint for the type-inference check phase*)
     then SOME (Type.constraint T t)
     else NONE));
 
-fun insert_variants ctxt t (oconst as Const (c, T)) =
-      (case get_candidates ctxt (c, T) of
-        SOME [] => err_unresolved_overloading ctxt (c, T) t []
+val the_candidates = the oo get_candidates;
+
+fun insert_variants_same ctxt t : term Same.operation =
+  (fn Const const =>
+      (case get_candidates ctxt const of
+        SOME [] => err_unresolved_overloading ctxt const t []
       | SOME [variant] => variant
-      | _ => oconst)
-  | insert_variants _ _ oconst = oconst;
+      | _ => raise Same.SAME)
+    | _ => raise Same.SAME);
 
 fun insert_overloaded ctxt =
   let
+    val thy = Proof_Context.theory_of ctxt;
     fun proc t =
       Term.map_types (K dummyT) t
-      |> get_overloaded ctxt
+      |> get_overloaded (Context.Proof ctxt)
       |> Option.map (Const o rpair (Term.type_of t));
   in
-    Pattern.rewrite_term_yoyo (Proof_Context.theory_of ctxt) [] [proc]
+    Pattern.rewrite_term_yoyo thy [] [proc]
   end;
 
+fun overloaded_term_consts ctxt =
+  let
+    val context = Context.Proof ctxt;
+    val overloaded = has_variants context;
+    val add = fn Const (c, T) => if overloaded c then insert (op =) (c, T) else I | _ => I;
+  in fn t => if no_variants context then [] else fold_aterms add t [] end;
+
+in
+
 fun check ctxt =
-  map (fn t => Term.map_aterms (insert_variants ctxt t) t);
+  if no_variants (Context.Proof ctxt) then I
+  else map (fn t => t |> Term.map_aterms (insert_variants_same ctxt t));
 
 fun uncheck ctxt ts =
-  if Config.get ctxt show_variants orelse exists (is_none o try Term.type_of) ts then ts
+  if Config.get ctxt show_variants orelse exists (not o can Term.type_of) ts then ts
   else map (insert_overloaded ctxt) ts;
 
 fun reject_unresolved ctxt =
   let
-    val the_candidates = the o get_candidates ctxt;
     fun check_unresolved t =
-      (case filter (is_overloaded ctxt o fst) (Term.add_consts t []) of
+      (case overloaded_term_consts ctxt t of
         [] => t
-      | (cT :: _) => err_unresolved_overloading ctxt cT t (the_candidates cT));
+      | const :: _ => err_unresolved_overloading ctxt const t (the_candidates ctxt const));
   in map check_unresolved end;
 
+end;
+
 
 (* setup *)