src/HOL/Library/adhoc_overloading.ML
changeset 37789 93f6dcf9ec02
equal deleted inserted replaced
37788:261c61fabc98 37789:93f6dcf9ec02
       
     1 (* Author: Alexander Krauss, TU Muenchen
       
     2    Author: Christian Sternagel, University of Innsbruck
       
     3 
       
     4 Ad-hoc overloading of constants based on their types.
       
     5 *)
       
     6 
       
     signature ADHOC_OVERLOADING =
     sig
       (*theory setup: installs the adhoc-overloading term check/uncheck phases*)
       val setup: theory -> theory

       (*declare an overloaded constant, then register variants for it*)
       val add_overloaded: string -> theory -> theory
       val add_variant: string -> string -> theory -> theory

       (*when set, the uncheck phase leaves variant constants as they are*)
       val show_variants: bool Unsynchronized.ref
     end
       
    17 
       
    18 
       
    structure Adhoc_Overloading: ADHOC_OVERLOADING =
    struct

    (*When set, "uncheck" below returns NONE, so terms are displayed with their
      variant constants instead of being mapped back to the overloaded name.*)
    val show_variants = Unsynchronized.ref false;


    (* errors *)

    fun duplicate_variant_err int_name ext_name =
      error ("Constant " ^ quote int_name ^ " is already a variant of " ^ quote ext_name);

    fun not_overloaded_err name =
      error ("Constant " ^ quote name ^ " is not declared as overloaded");

    fun already_overloaded_err name =
      error ("Constant " ^ quote name ^ " is already declared as overloaded");

    (*Report the constant c :: T occurring in term t that could not be
      resolved; reason is a short explanation such as "no instances".*)
    fun unresolved_err ctxt (c, T) t reason =
      error ("Unresolved overloading of " ^ quote c ^ " :: " ^
        quote (Syntax.string_of_typ ctxt T) ^ " in " ^
        quote (Syntax.string_of_term ctxt t) ^ " (" ^ reason ^ ")");


    (* theory data *)

    (*internalize: overloaded constant name |-> its registered variants, each
        paired with the variant's declared type;
      externalize: variant name |-> the overloaded constant it belongs to.*)
    structure Overload_Data = Theory_Data
    (
      type T =
        { internalize : (string * typ) list Symtab.table,
          externalize : string Symtab.table };
      val empty = {internalize=Symtab.empty, externalize=Symtab.empty};
      val extend = I;

      (*a variant must belong to the same overloaded constant in both theories*)
      fun merge_ext int_name (ext_name1, ext_name2) =
        if ext_name1 = ext_name2 then ext_name1
        else duplicate_variant_err int_name ext_name1;

      fun merge ({internalize=int1, externalize=ext1},
          {internalize=int2, externalize=ext2}) =
        {internalize=Symtab.join (K (Library.merge (op =))) (int1, int2),
         externalize=Symtab.join merge_ext (ext1, ext2)};
    );

    (*update the internalize table with f and the externalize table with g*)
    fun map_tables f g =
      Overload_Data.map (fn {internalize=int, externalize=ext} =>
        {internalize=f int, externalize=g ext});

    val is_overloaded = Symtab.defined o #internalize o Overload_Data.get;
    val get_variants = Symtab.lookup o #internalize o Overload_Data.get;
    val get_external = Symtab.lookup o #externalize o Overload_Data.get;

    (*Declare ext_name as overloaded, initially with no variants.
      Fails if ext_name is already declared overloaded.*)
    fun add_overloaded ext_name thy =
      let val _ = not (is_overloaded thy ext_name) orelse already_overloaded_err ext_name;
      in map_tables (Symtab.update (ext_name, [])) I thy end;

    (*Register constant name, with its declared type, as a variant of the
      overloaded constant ext_name.  Fails if ext_name is not overloaded or if
      name is already a variant of some overloaded constant.
      NOTE(review): nothing here rejects a name that is itself declared
      overloaded; the check phase might then rewrite the resolved constant
      again -- TODO confirm this cannot loop.*)
    fun add_variant ext_name name thy =
      let
        val _ = is_overloaded thy ext_name orelse not_overloaded_err ext_name;
        val _ =
          (case get_external thy name of
            NONE => ()
          | SOME gen' => duplicate_variant_err name gen');
        val T = Sign.the_const_type thy name;
      in
        map_tables (Symtab.cons_list (ext_name, (name, T)))
          (Symtab.update (name, ext_name)) thy
      end;


    (* check / uncheck *)

    (*SOME c if type T1 unifies with T2, NONE otherwise.  The schematic type
      variables of T2 are first shifted above maxidx of T1, so the two types
      are renamed apart before unification.*)
    fun unifiable_with ctxt T1 (c, T2) =
      let
        val thy = ProofContext.theory_of ctxt;
        val maxidx1 = Term.maxidx_of_typ T1;
        val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
        val maxidx2 = Int.max (maxidx1, Term.maxidx_of_typ T2');
      in
        (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME c)
          handle Type.TUNIFY => NONE
      end;

    (*Replace an overloaded constant by its unique variant whose declared type
      unifies with the constant's type.  The variant is emitted with dummyT.
      - No matching variant is an error.
      - Several matching variants leave the atom unchanged (Same.SAME); such
        leftovers are rejected later by reject_unresolved (stage 1 in setup).
      - Atoms that are not overloaded constants are left unchanged.*)
    fun insert_internal_same ctxt t (Const (c, T)) =
          (case map_filter (unifiable_with ctxt T)
              (Same.function (get_variants (ProofContext.theory_of ctxt)) c) of
            [] => unresolved_err ctxt (c, T) t "no instances"
          | [c'] => Const (c', dummyT)
          | _ => raise Same.SAME)
      | insert_internal_same _ _ _ = raise Same.SAME;

    (*Map a variant constant back to its overloaded name, keeping its type;
      any other atom is left unchanged.*)
    fun insert_external_same ctxt _ (Const (c, T)) =
          Const (Same.function (get_external (ProofContext.theory_of ctxt)) c, T)
      | insert_external_same _ _ _ = raise Same.SAME;

    (*Apply replace to every atom of every term in ts.  Returns NONE if nothing
      changed, otherwise SOME of the new terms paired with the unchanged ctxt.*)
    fun gen_check_uncheck replace ts ctxt =
      Same.capture (Same.map (fn t => Term_Subst.map_aterms_same (replace ctxt t) t)) ts
      |> Option.map (rpair ctxt);

    val check = gen_check_uncheck insert_internal_same;

    (*disabled while show_variants is set, so variants stay visible*)
    fun uncheck ts ctxt =
      if ! show_variants then NONE
      else gen_check_uncheck insert_external_same ts ctxt;

    (*Fail on the first term (in list order) that still contains an overloaded
      constant.  Otherwise return NONE, i.e. the terms are left unchanged.*)
    fun reject_unresolved ts ctxt =
      let
        val thy = ProofContext.theory_of ctxt;
        fun check_unresolved t =
          (case filter (is_overloaded thy o fst) (Term.add_consts t []) of
            [] => ()
          | (c, T) :: _ => unresolved_err ctxt (c, T) t "multiple instances");
        (*check_unresolved is run only for its effect: app, not map*)
        val () = List.app check_unresolved ts;
      in NONE end;


    (* setup *)

    (*Stage 0 resolves overloaded constants (check) and folds variants back
      (uncheck).  Stage 1 rejects anything left unresolved after stage 0.*)
    val setup = Context.theory_map
      (Syntax.add_term_check 0 "adhoc_overloading" check
       #> Syntax.add_term_check 1 "adhoc_overloading_unresolved_check" reject_unresolved
       #> Syntax.add_term_uncheck 0 "adhoc_overloading" uncheck);

    end
       
    69 
       
    70 fun add_overloaded ext_name thy =
       
    71   let val _ = not (is_overloaded thy ext_name) orelse already_overloaded_err ext_name;
       
    72   in map_tables (Symtab.update (ext_name, [])) I thy end;
       
    73 
       
    74 fun add_variant ext_name name thy =
       
    75   let
       
    76     val _ = is_overloaded thy ext_name orelse not_overloaded_err ext_name;
       
    77     val _ = case get_external thy name of
       
    78               NONE => ()
       
    79             | SOME gen' => duplicate_variant_err name gen';
       
    80     val T = Sign.the_const_type thy name;
       
    81   in
       
    82     map_tables (Symtab.cons_list (ext_name, (name, T)))
       
    83       (Symtab.update (name, ext_name)) thy    
       
    84   end
       
    85 
       
    86 
       
    87 (* check / uncheck *)
       
    88 
       
    89 fun unifiable_with ctxt T1 (c, T2) =
       
    90   let
       
    91     val thy = ProofContext.theory_of ctxt;
       
    92     val maxidx1 = Term.maxidx_of_typ T1;
       
    93     val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
       
    94     val maxidx2 = Int.max (maxidx1, Term.maxidx_of_typ T2');
       
    95   in
       
    96     (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME c)
       
    97     handle Type.TUNIFY => NONE
       
    98   end;
       
    99 
       
   100 fun insert_internal_same ctxt t (Const (c, T)) =
       
   101   (case map_filter (unifiable_with ctxt T) 
       
   102      (Same.function (get_variants (ProofContext.theory_of ctxt)) c) of
       
   103       [] => unresolved_err ctxt (c, T) t "no instances"
       
   104     | [c'] => Const (c', dummyT)
       
   105     | _ => raise Same.SAME)
       
   106   | insert_internal_same _ _ _ = raise Same.SAME;
       
   107 
       
   108 fun insert_external_same ctxt _ (Const (c, T)) =
       
   109     Const (Same.function (get_external (ProofContext.theory_of ctxt)) c, T)
       
   110   | insert_external_same _ _ _ = raise Same.SAME;
       
   111 
       
   112 fun gen_check_uncheck replace ts ctxt =
       
   113   Same.capture (Same.map (fn t => Term_Subst.map_aterms_same (replace ctxt t) t)) ts
       
   114   |> Option.map (rpair ctxt);
       
   115 
       
   116 val check = gen_check_uncheck insert_internal_same;
       
   117 fun uncheck ts ctxt = 
       
   118   if !show_variants then NONE
       
   119   else gen_check_uncheck insert_external_same ts ctxt;
       
   120 
       
   121 fun reject_unresolved ts ctxt =
       
   122   let
       
   123     val thy = ProofContext.theory_of ctxt;
       
   124     fun check_unresolved t =
       
   125       case filter (is_overloaded thy o fst) (Term.add_consts t []) of
       
   126           [] => ()
       
   127         | ((c, T) :: _) => unresolved_err ctxt (c, T) t "multiple instances";
       
   128 
       
   129     val _ = map check_unresolved ts;
       
   130   in NONE end;
       
   131 
       
   132 
       
   133 (* setup *)
       
   134 
       
   135 val setup = Context.theory_map 
       
   136   (Syntax.add_term_check 0 "adhoc_overloading" check
       
   137    #> Syntax.add_term_check 1 "adhoc_overloading_unresolved_check" reject_unresolved
       
   138    #> Syntax.add_term_uncheck 0 "adhoc_overloading" uncheck);
       
   139 
       
   140 end