src/HOL/Library/adhoc_overloading.ML
author wenzelm
Tue May 15 13:57:39 2018 +0200 (16 months ago)
changeset 68189 6163c90694ef
parent 68061 81d90f830f99
child 69593 3dda49e08b9d
permissions -rw-r--r--
tuned headers;
     1 (*  Author:     Alexander Krauss, TU Muenchen
     2     Author:     Christian Sternagel, University of Innsbruck
     3 
     4 Adhoc overloading of constants based on their types.
     5 *)
     6 
     7 signature ADHOC_OVERLOADING =
     8 sig
     9   val is_overloaded: Proof.context -> string -> bool
    10   val generic_add_overloaded: string -> Context.generic -> Context.generic
    11   val generic_remove_overloaded: string -> Context.generic -> Context.generic
    12   val generic_add_variant: string -> term -> Context.generic -> Context.generic
    13   (*If the list of variants is empty at the end of "generic_remove_variant", then
    14   "generic_remove_overloaded" is called implicitly.*)
    15   val generic_remove_variant: string -> term -> Context.generic -> Context.generic
    16   val show_variants: bool Config.T
    17 end
    18 
    19 structure Adhoc_Overloading: ADHOC_OVERLOADING =
    20 struct
    21 
    22 val show_variants = Attrib.setup_config_bool @{binding show_variants} (K false);
    23 
    24 
    25 (* errors *)
    26 
    27 fun err_duplicate_variant oconst =
    28   error ("Duplicate variant of " ^ quote oconst);
    29 
    30 fun err_not_a_variant oconst =
    31   error ("Not a variant of " ^  quote oconst);
    32 
    33 fun err_not_overloaded oconst =
    34   error ("Constant " ^ quote oconst ^ " is not declared as overloaded");
    35 
    36 fun err_unresolved_overloading ctxt0 (c, T) t instances =
    37   let
    38     val ctxt = Config.put show_variants true ctxt0
    39     val const_space = Proof_Context.const_space ctxt
    40     val prt_const =
    41       Pretty.block [Name_Space.pretty ctxt const_space c, Pretty.str " ::", Pretty.brk 1,
    42         Pretty.quote (Syntax.pretty_typ ctxt T)]
    43   in
    44     error (Pretty.string_of (Pretty.chunks [
    45       Pretty.block [
    46         Pretty.str "Unresolved adhoc overloading of constant", Pretty.brk 1,
    47         prt_const, Pretty.brk 1, Pretty.str "in term", Pretty.brk 1,
    48         Pretty.block [Pretty.quote (Syntax.pretty_term ctxt t)]],
    49       Pretty.block (
    50         (if null instances then [Pretty.str "no instances"]
    51         else Pretty.fbreaks (
    52           Pretty.str "multiple instances:" ::
    53           map (Pretty.mark Markup.item o Syntax.pretty_term ctxt) instances)))]))
    54   end;
    55 
    56 
    57 (* generic data *)
    58 
    59 fun variants_eq ((v1, T1), (v2, T2)) =
    60   Term.aconv_untyped (v1, v2) andalso T1 = T2;
    61 
    62 structure Overload_Data = Generic_Data
    63 (
    64   type T =
    65     {variants : (term * typ) list Symtab.table,
    66      oconsts : string Termtab.table};
    67   val empty = {variants = Symtab.empty, oconsts = Termtab.empty};
    68   val extend = I;
    69 
    70   fun merge
    71     ({variants = vtab1, oconsts = otab1},
    72      {variants = vtab2, oconsts = otab2}) : T =
    73     let
    74       fun merge_oconsts _ (oconst1, oconst2) =
    75         if oconst1 = oconst2 then oconst1
    76         else err_duplicate_variant oconst1;
    77     in
    78       {variants = Symtab.merge_list variants_eq (vtab1, vtab2),
    79        oconsts = Termtab.join merge_oconsts (otab1, otab2)}
    80     end;
    81 );
    82 
    83 fun map_tables f g =
    84   Overload_Data.map (fn {variants = vtab, oconsts = otab} =>
    85     {variants = f vtab, oconsts = g otab});
    86 
    87 val is_overloaded = Symtab.defined o #variants o Overload_Data.get o Context.Proof;
    88 val get_variants = Symtab.lookup o #variants o Overload_Data.get o Context.Proof;
    89 val get_overloaded = Termtab.lookup o #oconsts o Overload_Data.get o Context.Proof;
    90 
    91 fun generic_add_overloaded oconst context =
    92   if is_overloaded (Context.proof_of context) oconst then context
    93   else map_tables (Symtab.update (oconst, [])) I context;
    94 
    95 fun generic_remove_overloaded oconst context =
    96   let
    97     fun remove_oconst_and_variants context oconst =
    98       let
    99         val remove_variants =
   100           (case get_variants (Context.proof_of context) oconst of
   101             NONE => I
   102           | SOME vs => fold (Termtab.remove (op =) o rpair oconst o fst) vs);
   103       in map_tables (Symtab.delete_safe oconst) remove_variants context end;
   104   in
   105     if is_overloaded (Context.proof_of context) oconst then remove_oconst_and_variants context oconst
   106     else err_not_overloaded oconst
   107   end;
   108 
   109 local
   110   fun generic_variant add oconst t context =
   111     let
   112       val ctxt = Context.proof_of context;
   113       val _ = if is_overloaded ctxt oconst then () else err_not_overloaded oconst;
   114       val T = t |> fastype_of;
   115       val t' = Term.map_types (K dummyT) t;
   116     in
   117       if add then
   118         let
   119           val _ =
   120             (case get_overloaded ctxt t' of
   121               NONE => ()
   122             | SOME oconst' => err_duplicate_variant oconst');
   123         in
   124           map_tables (Symtab.cons_list (oconst, (t', T))) (Termtab.update (t', oconst)) context
   125         end
   126       else
   127         let
   128           val _ =
   129             if member variants_eq (the (get_variants ctxt oconst)) (t', T) then ()
   130             else err_not_a_variant oconst;
   131         in
   132           map_tables (Symtab.map_entry oconst (remove1 variants_eq (t', T)))
   133             (Termtab.delete_safe t') context
   134           |> (fn context =>
   135             (case get_variants (Context.proof_of context) oconst of
   136               SOME [] => generic_remove_overloaded oconst context
   137             | _ => context))
   138         end
   139     end;
   140 in
   141   val generic_add_variant = generic_variant true;
   142   val generic_remove_variant = generic_variant false;
   143 end;
   144 
   145 
   146 (* check / uncheck *)
   147 
   148 fun unifiable_with thy T1 T2 =
   149   let
   150     val maxidx1 = Term.maxidx_of_typ T1;
   151     val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
   152     val maxidx2 = Term.maxidx_typ T2' maxidx1;
   153   in can (Sign.typ_unify thy (T1, T2')) (Vartab.empty, maxidx2) end;
   154 
   155 fun get_candidates ctxt (c, T) =
   156   get_variants ctxt c
   157   |> Option.map (map_filter (fn (t, T') =>
   158     if unifiable_with (Proof_Context.theory_of ctxt) T T' then SOME t
   159     else NONE));
   160 
   161 fun insert_variants ctxt t (oconst as Const (c, T)) =
   162       (case get_candidates ctxt (c, T) of
   163         SOME [] => err_unresolved_overloading ctxt (c, T) t []
   164       | SOME [variant] => variant
   165       | _ => oconst)
   166   | insert_variants _ _ oconst = oconst;
   167 
   168 fun insert_overloaded ctxt =
   169   let
   170     fun proc t =
   171       Term.map_types (K dummyT) t
   172       |> get_overloaded ctxt
   173       |> Option.map (Const o rpair (Term.type_of t));
   174   in
   175     Pattern.rewrite_term_top (Proof_Context.theory_of ctxt) [] [proc]
   176   end;
   177 
   178 fun check ctxt =
   179   map (fn t => Term.map_aterms (insert_variants ctxt t) t);
   180 
   181 fun uncheck ctxt ts =
   182   if Config.get ctxt show_variants orelse exists (is_none o try Term.type_of) ts then ts
   183   else map (insert_overloaded ctxt) ts;
   184 
   185 fun reject_unresolved ctxt =
   186   let
   187     val the_candidates = the o get_candidates ctxt;
   188     fun check_unresolved t =
   189       (case filter (is_overloaded ctxt o fst) (Term.add_consts t []) of
   190         [] => t
   191       | (cT :: _) => err_unresolved_overloading ctxt cT t (the_candidates cT));
   192   in map check_unresolved end;
   193 
   194 
   195 (* setup *)
   196 
   197 val _ = Context.>>
   198   (Syntax_Phases.term_check 0 "adhoc_overloading" check
   199    #> Syntax_Phases.term_check 1 "adhoc_overloading_unresolved_check" reject_unresolved
   200    #> Syntax_Phases.term_uncheck 0 "adhoc_overloading" uncheck);
   201 
   202 
   203 (* commands *)
   204 
   205 fun generic_adhoc_overloading_cmd add =
   206   if add then
   207     fold (fn (oconst, ts) =>
   208       generic_add_overloaded oconst
   209       #> fold (generic_add_variant oconst) ts)
   210   else
   211     fold (fn (oconst, ts) =>
   212       fold (generic_remove_variant oconst) ts);
   213 
   214 fun adhoc_overloading_cmd' add args phi =
   215   let val args' = args
   216     |> map (apsnd (map_filter (fn t =>
   217          let val t' = Morphism.term phi t;
   218          in if Term.aconv_untyped (t, t') then SOME t' else NONE end)));
   219   in generic_adhoc_overloading_cmd add args' end;
   220 
   221 fun adhoc_overloading_cmd add raw_args lthy =
   222   let
   223     fun const_name ctxt =
   224       fst o dest_Const o Proof_Context.read_const {proper = false, strict = false} ctxt;  (* FIXME {proper = true, strict = true} (!?) *)
   225     fun read_term ctxt = singleton (Variable.polymorphic ctxt) o Syntax.read_term ctxt;
   226     val args =
   227       raw_args
   228       |> map (apfst (const_name lthy))
   229       |> map (apsnd (map (read_term lthy)));
   230   in
   231     Local_Theory.declaration {syntax = true, pervasive = false}
   232       (adhoc_overloading_cmd' add args) lthy
   233   end;
   234 
   235 val _ =
   236   Outer_Syntax.local_theory @{command_keyword adhoc_overloading}
   237     "add adhoc overloading for constants / fixed variables"
   238     (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd true);
   239 
   240 val _ =
   241   Outer_Syntax.local_theory @{command_keyword no_adhoc_overloading}
   242     "delete adhoc overloading for constants / fixed variables"
   243     (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd false);
   244 
   245 end;
   246