src/Tools/adhoc_overloading.ML
changeset 52622 e0ff1625e96d
parent 50768 2172f82de515
child 52687 72cda5eb5a39
--- a/src/Tools/adhoc_overloading.ML	Fri Jul 12 15:51:25 2013 +0200
+++ b/src/Tools/adhoc_overloading.ML	Fri Jul 12 16:19:05 2013 +0200
@@ -6,11 +6,14 @@
 
 signature ADHOC_OVERLOADING =
 sig
-  val add_overloaded: string -> theory -> theory
-  val add_variant: string -> string -> theory -> theory
-
+  val is_overloaded: Proof.context -> string -> bool
+  val generic_add_overloaded: string -> Context.generic -> Context.generic
+  val generic_remove_overloaded: string -> Context.generic -> Context.generic
+  val generic_add_variant: string -> term -> Context.generic -> Context.generic
+  (*If the list of variants is empty at the end of "generic_remove_variant", then
+  "generic_remove_overloaded" is called implicitly.*)
+  val generic_remove_variant: string -> term -> Context.generic -> Context.generic
   val show_variants: bool Config.T
-  val setup: theory -> theory
 end
 
 structure Adhoc_Overloading: ADHOC_OVERLOADING =
@@ -18,122 +21,208 @@
 
 val show_variants = Attrib.setup_config_bool @{binding show_variants} (K false);
 
-
 (* errors *)
 
-fun duplicate_variant_err int_name ext_name =
-  error ("Constant " ^ quote int_name ^ " is already a variant of " ^ quote ext_name);
+fun duplicate_variant_error oconst =
+  error ("Duplicate variant of " ^ quote oconst);
 
-fun not_overloaded_err name =
-  error ("Constant " ^ quote name ^ " is not declared as overloaded");
+fun not_a_variant_error oconst =
+  error ("Not a variant of " ^  quote oconst);
 
-fun already_overloaded_err name =
-  error ("Constant " ^ quote name ^ " is already declared as overloaded");
+fun not_overloaded_error oconst =
+  error ("Constant " ^ quote oconst ^ " is not declared as overloaded");
 
-fun unresolved_err ctxt (c, T) t reason =
-  error ("Unresolved overloading of  " ^ quote c ^ " :: " ^
+fun unresolved_overloading_error ctxt (c, T) t reason =
+  error ("Unresolved overloading of " ^ quote c ^ " :: " ^
     quote (Syntax.string_of_typ ctxt T) ^ " in " ^
     quote (Syntax.string_of_term ctxt t) ^ " (" ^ reason ^ ")");
 
+(* generic data *)
 
-(* theory data *)
+fun variants_eq ((v1, T1), (v2, T2)) =
+  Term.aconv_untyped (v1, v2) andalso T1 = T2;
 
-structure Overload_Data = Theory_Data
+structure Overload_Data = Generic_Data
 (
   type T =
-    { internalize : (string * typ) list Symtab.table,
-      externalize : string Symtab.table };
-  val empty = {internalize=Symtab.empty, externalize=Symtab.empty};
+    {variants : (term * typ) list Symtab.table,
+     oconsts : string Termtab.table};
+  val empty = {variants = Symtab.empty, oconsts = Termtab.empty};
   val extend = I;
 
-  fun merge_ext int_name (ext_name1, ext_name2) =
-    if ext_name1 = ext_name2 then ext_name1
-    else duplicate_variant_err int_name ext_name1;
-
   fun merge
-    ({internalize = int1, externalize = ext1},
-      {internalize = int2, externalize = ext2}) : T =
-    {internalize = Symtab.merge_list (op =) (int1, int2),
-      externalize = Symtab.join merge_ext (ext1, ext2)};
+    ({variants = vtab1, oconsts = otab1},
+     {variants = vtab2, oconsts = otab2}) : T =
+    let
+      fun merge_oconsts _ (oconst1, oconst2) =
+        if oconst1 = oconst2 then oconst1
+        else duplicate_variant_error oconst1;
+    in
+      {variants = Symtab.merge_list variants_eq (vtab1, vtab2),
+       oconsts = Termtab.join merge_oconsts (otab1, otab2)}
+    end;
 );
 
 fun map_tables f g =
-  Overload_Data.map (fn {internalize=int, externalize=ext} =>
-    {internalize=f int, externalize=g ext});
+  Overload_Data.map (fn {variants = vtab, oconsts = otab} =>
+    {variants = f vtab, oconsts = g otab});
+
+val is_overloaded = Symtab.defined o #variants o Overload_Data.get o Context.Proof;
+val get_variants = Symtab.lookup o #variants o Overload_Data.get o Context.Proof;
+val get_overloaded = Termtab.lookup o #oconsts o Overload_Data.get o Context.Proof;
+
+fun generic_add_overloaded oconst context =
+  if is_overloaded (Context.proof_of context) oconst then context
+  else map_tables (Symtab.update (oconst, [])) I context;
 
-val is_overloaded = Symtab.defined o #internalize o Overload_Data.get;
-val get_variants = Symtab.lookup o #internalize o Overload_Data.get;
-val get_external = Symtab.lookup o #externalize o Overload_Data.get;
-
-fun add_overloaded ext_name thy =
-  let val _ = not (is_overloaded thy ext_name) orelse already_overloaded_err ext_name;
-  in map_tables (Symtab.update (ext_name, [])) I thy end;
+fun generic_remove_overloaded oconst context =
+  let
+    fun remove_oconst_and_variants context oconst =
+      let
+        val remove_variants =
+          (case get_variants (Context.proof_of context) oconst of
+            NONE => I
+          | SOME vs => fold (Termtab.remove (op =) o rpair oconst o fst) vs);
+      in map_tables (Symtab.delete_safe oconst) remove_variants context end;
+  in
+    if is_overloaded (Context.proof_of context) oconst then remove_oconst_and_variants context oconst
+    else not_overloaded_error oconst
+  end;
 
-fun add_variant ext_name name thy =
-  let
-    val _ = is_overloaded thy ext_name orelse not_overloaded_err ext_name;
-    val _ =
-      (case get_external thy name of
-        NONE => ()
-      | SOME gen' => duplicate_variant_err name gen');
-    val T = Sign.the_const_type thy name;
-  in
-    map_tables (Symtab.cons_list (ext_name, (name, T)))
-      (Symtab.update (name, ext_name)) thy
-  end
-
+local
+  fun generic_variant add oconst t context =
+    let
+      val ctxt = Context.proof_of context;
+      val _ = if is_overloaded ctxt oconst then () else not_overloaded_error oconst;
+      val T = t |> singleton (Variable.polymorphic ctxt) |> fastype_of;
+      val t' = Term.map_types (K dummyT) t;
+    in
+      if add then
+        let
+          val _ =
+            (case get_overloaded ctxt t' of
+              NONE => ()
+            | SOME oconst' => duplicate_variant_error oconst');
+        in
+          map_tables (Symtab.cons_list (oconst, (t', T))) (Termtab.update (t', oconst)) context
+        end
+      else
+        let
+          val _ =
+            if member variants_eq (the (get_variants ctxt oconst)) (t', T) then ()
+            else not_a_variant_error oconst;
+        in
+          map_tables (Symtab.map_entry oconst (remove1 variants_eq (t', T)))
+            (Termtab.delete_safe t') context
+          |> (fn context =>
+            (case get_variants (Context.proof_of context) oconst of
+              SOME [] => generic_remove_overloaded oconst context
+            | _ => context))
+        end
+    end;
+in
+  val generic_add_variant = generic_variant true;
+  val generic_remove_variant = generic_variant false;
+end;
 
 (* check / uncheck *)
 
-fun unifiable_with ctxt T1 (c, T2) =
+fun unifiable_with thy T1 (t, T2) =
   let
-    val thy = Proof_Context.theory_of ctxt;
     val maxidx1 = Term.maxidx_of_typ T1;
     val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
-    val maxidx2 = Int.max (maxidx1, Term.maxidx_of_typ T2');
+    val maxidx2 = Term.maxidx_typ T2' maxidx1;
   in
-    (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME c)
+    (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME t)
     handle Type.TUNIFY => NONE
   end;
 
-fun insert_internal_same ctxt t (Const (c, T)) =
-      (case map_filter (unifiable_with ctxt T)
-         (Same.function (get_variants (Proof_Context.theory_of ctxt)) c) of
-        [] => unresolved_err ctxt (c, T) t "no instances"
-      | [c'] => Const (c', dummyT)
+fun insert_variants_same ctxt t (Const (c, T)) =
+      (case map_filter (unifiable_with (Proof_Context.theory_of ctxt) T)
+         (Same.function (get_variants ctxt) c) of
+        [] => unresolved_overloading_error ctxt (c, T) t "no instances"
+      | [variant] => variant
       | _ => raise Same.SAME)
-  | insert_internal_same _ _ _ = raise Same.SAME;
+  | insert_variants_same _ _ _ = raise Same.SAME;
 
-fun insert_external_same ctxt _ (Const (c, T)) =
-      Const (Same.function (get_external (Proof_Context.theory_of ctxt)) c, T)
-  | insert_external_same _ _ _ = raise Same.SAME;
+fun insert_overloaded_same ctxt variant =
+  let
+    val thy = Proof_Context.theory_of ctxt;
+    val t = Pattern.rewrite_term thy [] [fn t =>
+      Term.map_types (K dummyT) t
+      |> get_overloaded ctxt
+      |> Option.map (Const o rpair (fastype_of variant))] variant;
+  in
+    if Term.aconv_untyped (variant, t) then raise Same.SAME
+    else t
+  end;
 
 fun gen_check_uncheck replace ts ctxt =
-  Same.capture (Same.map (fn t => Term_Subst.map_aterms_same (replace ctxt t) t)) ts
+  Same.capture (Same.map replace) ts
   |> Option.map (rpair ctxt);
 
-val check = gen_check_uncheck insert_internal_same;
+fun check ts ctxt = gen_check_uncheck (fn t =>
+  Term_Subst.map_aterms_same (insert_variants_same ctxt t) t) ts ctxt;
 
 fun uncheck ts ctxt =
   if Config.get ctxt show_variants then NONE
-  else gen_check_uncheck insert_external_same ts ctxt;
+  else gen_check_uncheck (insert_overloaded_same ctxt) ts ctxt;
 
 fun reject_unresolved ts ctxt =
   let
-    val thy = Proof_Context.theory_of ctxt;
     fun check_unresolved t =
-      (case filter (is_overloaded thy o fst) (Term.add_consts t []) of
+      (case filter (is_overloaded ctxt o fst) (Term.add_consts t []) of
         [] => ()
-      | ((c, T) :: _) => unresolved_err ctxt (c, T) t "multiple instances");
+      | ((c, T) :: _) => unresolved_overloading_error ctxt (c, T) t "multiple instances");
     val _ = map check_unresolved ts;
   in NONE end;
 
-
 (* setup *)
 
-val setup = Context.theory_map
+val _ = Context.>>
   (Syntax_Phases.term_check' 0 "adhoc_overloading" check
    #> Syntax_Phases.term_check' 1 "adhoc_overloading_unresolved_check" reject_unresolved
    #> Syntax_Phases.term_uncheck' 0 "adhoc_overloading" uncheck);
 
+(* commands *)
+
+fun generic_adhoc_overloading_cmd add =
+  if add then
+    fold (fn (oconst, ts) =>
+      generic_add_overloaded oconst
+      #> fold (generic_add_variant oconst) ts)
+  else
+    fold (fn (oconst, ts) =>
+      fold (generic_remove_variant oconst) ts);
+
+fun adhoc_overloading_cmd' add args phi =
+  let val args' = args
+    |> map (apsnd (map_filter (fn t =>
+         let val t' = Morphism.term phi t;
+         in if Term.aconv_untyped (t, t') then SOME t' else NONE end)));
+  in generic_adhoc_overloading_cmd add args' end;
+
+fun adhoc_overloading_cmd add raw_args lthy =
+  let
+    fun const_name ctxt = fst o dest_Const o Proof_Context.read_const ctxt false dummyT;
+    val args =
+      raw_args
+      |> map (apfst (const_name lthy))
+      |> map (apsnd (map (Syntax.read_term lthy)));
+  in
+    Local_Theory.declaration {syntax = true, pervasive = false}
+      (adhoc_overloading_cmd' add args) lthy
+  end;
+
+val _ =
+  Outer_Syntax.local_theory @{command_spec "adhoc_overloading"}
+    "add ad-hoc overloading for constants / fixed variables"
+    (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd true);
+
+val _ =
+  Outer_Syntax.local_theory @{command_spec "no_adhoc_overloading"}
+    "add ad-hoc overloading for constants / fixed variables"
+    (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd false);
+
 end;
+