src/HOL/Library/adhoc_overloading.ML
author haftmann
Wed Jul 18 20:51:21 2018 +0200 (12 months ago)
changeset 68658 16cc1161ad7f
parent 68061 81d90f830f99
child 69593 3dda49e08b9d
permissions -rw-r--r--
tuned equation
(*  Author:     Alexander Krauss, TU Muenchen
    Author:     Christian Sternagel, University of Innsbruck

Adhoc overloading of constants based on their types.
*)

signature ADHOC_OVERLOADING =
sig
  (*Is the given constant name registered as an adhoc-overloaded constant?*)
  val is_overloaded: Proof.context -> string -> bool
  (*Register a constant as overloaded (with no variants yet); no-op if it
  already is.*)
  val generic_add_overloaded: string -> Context.generic -> Context.generic
  (*Unregister an overloaded constant together with all of its variants;
  fails if the constant is not declared as overloaded.*)
  val generic_remove_overloaded: string -> Context.generic -> Context.generic
  (*Add a variant term for an overloaded constant; fails if the constant is not
  declared as overloaded or the term is already a variant of some constant.*)
  val generic_add_variant: string -> term -> Context.generic -> Context.generic
  (*If the list of variants is empty at the end of "generic_remove_variant", then
  "generic_remove_overloaded" is called implicitly.*)
  val generic_remove_variant: string -> term -> Context.generic -> Context.generic
  (*Configuration option (default: false); when enabled, terms are printed with
  their variants instead of the overloaded constants.*)
  val show_variants: bool Config.T
end
krauss@37789
    18
krauss@37789
    19
structure Adhoc_Overloading: ADHOC_OVERLOADING =
struct

(*When true, the uncheck phase below leaves variants as they are instead of
  folding them back into their overloaded constants.*)
val show_variants = Attrib.setup_config_bool @{binding show_variants} (K false);


(* errors *)

fun err_duplicate_variant oconst =
  error ("Duplicate variant of " ^ quote oconst);

fun err_not_a_variant oconst =
  error ("Not a variant of " ^ quote oconst);

fun err_not_overloaded oconst =
  error ("Constant " ^ quote oconst ^ " is not declared as overloaded");

(*Fail because the overloaded constant c :: T inside term t could not be
  resolved to exactly one variant; "instances" are the remaining candidates
  (empty when none matched). show_variants is switched on for printing, so
  this module's uncheck phase does not rewrite the printed terms.*)
fun err_unresolved_overloading ctxt0 (c, T) t instances =
  let
    val ctxt = Config.put show_variants true ctxt0
    val const_space = Proof_Context.const_space ctxt
    val prt_const =
      Pretty.block [Name_Space.pretty ctxt const_space c, Pretty.str " ::", Pretty.brk 1,
        Pretty.quote (Syntax.pretty_typ ctxt T)]
  in
    error (Pretty.string_of (Pretty.chunks [
      Pretty.block [
        Pretty.str "Unresolved adhoc overloading of constant", Pretty.brk 1,
        prt_const, Pretty.brk 1, Pretty.str "in term", Pretty.brk 1,
        Pretty.block [Pretty.quote (Syntax.pretty_term ctxt t)]],
      Pretty.block (
        (if null instances then [Pretty.str "no instances"]
        else Pretty.fbreaks (
          Pretty.str "multiple instances:" ::
          map (Pretty.mark Markup.item o Syntax.pretty_term ctxt) instances)))]))
  end;


(* generic data *)

(*Two recorded variants are equal if their terms agree up to type annotations
  and their recorded types are identical.*)
fun variants_eq ((v1, T1), (v2, T2)) =
  Term.aconv_untyped (v1, v2) andalso T1 = T2;

(*variants: overloaded constant name -> list of (variant term with all types
    replaced by dummyT, original type of the variant term);
  oconsts: reverse map from a type-erased variant term to the overloaded
    constant it belongs to.*)
structure Overload_Data = Generic_Data
(
  type T =
    {variants : (term * typ) list Symtab.table,
     oconsts : string Termtab.table};
  val empty = {variants = Symtab.empty, oconsts = Termtab.empty};
  val extend = I;

  fun merge
    ({variants = vtab1, oconsts = otab1},
     {variants = vtab2, oconsts = otab2}) : T =
    let
      (*one variant term must not be attached to two different overloaded constants*)
      fun merge_oconsts _ (oconst1, oconst2) =
        if oconst1 = oconst2 then oconst1
        else err_duplicate_variant oconst1;
    in
      {variants = Symtab.merge_list variants_eq (vtab1, vtab2),
       oconsts = Termtab.join merge_oconsts (otab1, otab2)}
    end;
);

(*Apply f to the variants table and g to the reverse (oconsts) table.*)
fun map_tables f g =
  Overload_Data.map (fn {variants = vtab, oconsts = otab} =>
    {variants = f vtab, oconsts = g otab});

val is_overloaded = Symtab.defined o #variants o Overload_Data.get o Context.Proof;
val get_variants = Symtab.lookup o #variants o Overload_Data.get o Context.Proof;
val get_overloaded = Termtab.lookup o #oconsts o Overload_Data.get o Context.Proof;

fun generic_add_overloaded oconst context =
  if is_overloaded (Context.proof_of context) oconst then context
  else map_tables (Symtab.update (oconst, [])) I context;

(*Remove oconst from the variants table and drop every reverse entry
  (variant term -> oconst) belonging to it.*)
fun generic_remove_overloaded oconst context =
  let
    fun remove_oconst_and_variants context oconst =
      let
        val remove_variants =
          (case get_variants (Context.proof_of context) oconst of
            NONE => I
          | SOME vs => fold (Termtab.remove (op =) o rpair oconst o fst) vs);
      in map_tables (Symtab.delete_safe oconst) remove_variants context end;
  in
    if is_overloaded (Context.proof_of context) oconst then remove_oconst_and_variants context oconst
    else err_not_overloaded oconst
  end;

local
  (*Shared implementation of adding (add = true) and removing (add = false) a
    variant t of oconst. The variant is stored with its types replaced by dummyT,
    paired with its original type.*)
  fun generic_variant add oconst t context =
    let
      val ctxt = Context.proof_of context;
      val _ = if is_overloaded ctxt oconst then () else err_not_overloaded oconst;
      val T = t |> fastype_of;
      val t' = Term.map_types (K dummyT) t;
    in
      if add then
        let
          (*a variant term may belong to at most one overloaded constant*)
          val _ =
            (case get_overloaded ctxt t' of
              NONE => ()
            | SOME oconst' => err_duplicate_variant oconst');
        in
          map_tables (Symtab.cons_list (oconst, (t', T))) (Termtab.update (t', oconst)) context
        end
      else
        let
          (*"the" is safe: oconst was checked to be overloaded above*)
          val _ =
            if member variants_eq (the (get_variants ctxt oconst)) (t', T) then ()
            else err_not_a_variant oconst;
        in
          map_tables (Symtab.map_entry oconst (remove1 variants_eq (t', T)))
            (Termtab.delete_safe t') context
          |> (fn context =>
            (*unregister the constant once its last variant is gone*)
            (case get_variants (Context.proof_of context) oconst of
              SOME [] => generic_remove_overloaded oconst context
            | _ => context))
        end
    end;
in
  val generic_add_variant = generic_variant true;
  val generic_remove_variant = generic_variant false;
end;


(* check / uncheck *)

(*Do T1 and T2 unify once the schematic type variables of T2 are renamed apart
  from those of T1 (by shifting their indices past maxidx of T1)?*)
fun unifiable_with thy T1 T2 =
  let
    val maxidx1 = Term.maxidx_of_typ T1;
    val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
    val maxidx2 = Term.maxidx_typ T2' maxidx1;
  in can (Sign.typ_unify thy (T1, T2')) (Vartab.empty, maxidx2) end;

(*NONE if c is not overloaded; otherwise the (type-erased) variant terms whose
  recorded type is unifiable with T.*)
fun get_candidates ctxt (c, T) =
  get_variants ctxt c
  |> Option.map (map_filter (fn (t, T') =>
    if unifiable_with (Proof_Context.theory_of ctxt) T T' then SOME t
    else NONE));

(*Resolve one atom of term t: an overloaded constant with no matching variant
  is an error, a unique matching variant replaces it, and anything else
  (several candidates, or not overloaded at all) is left unchanged. t itself
  is only used in the error message.*)
fun insert_variants ctxt t (oconst as Const (c, T)) =
      (case get_candidates ctxt (c, T) of
        SOME [] => err_unresolved_overloading ctxt (c, T) t []
      | SOME [variant] => variant
      | _ => oconst)
  | insert_variants _ _ oconst = oconst;

(*Rewrite every subterm that, with types erased, is a registered variant back
  into its overloaded constant, at the type of that subterm.*)
fun insert_overloaded ctxt =
  let
    fun proc t =
      Term.map_types (K dummyT) t
      |> get_overloaded ctxt
      |> Option.map (Const o rpair (Term.type_of t));
  in
    Pattern.rewrite_term_top (Proof_Context.theory_of ctxt) [] [proc]
  end;

fun check ctxt =
  map (fn t => Term.map_aterms (insert_variants ctxt t) t);

(*Fold variants back into overloaded constants for printing, unless
  show_variants is set or some term is not well-typed (Term.type_of fails).*)
fun uncheck ctxt ts =
  if Config.get ctxt show_variants orelse exists (is_none o try Term.type_of) ts then ts
  else map (insert_overloaded ctxt) ts;

(*Fail on the first overloaded constant still present in a term, listing its
  remaining candidates.*)
fun reject_unresolved ctxt =
  let
    val the_candidates = the o get_candidates ctxt;
    fun check_unresolved t =
      (case filter (is_overloaded ctxt o fst) (Term.add_consts t []) of
        [] => t
      | (cT :: _) => err_unresolved_overloading ctxt cT t (the_candidates cT));
  in map check_unresolved end;


(* setup *)

(*check runs at stage 0, reject_unresolved at stage 1 (after it), and uncheck
  at stage 0 of the uncheck phases.*)
val _ = Context.>>
  (Syntax_Phases.term_check 0 "adhoc_overloading" check
   #> Syntax_Phases.term_check 1 "adhoc_overloading_unresolved_check" reject_unresolved
   #> Syntax_Phases.term_uncheck 0 "adhoc_overloading" uncheck);


(* commands *)

(*add = true: declare each constant as overloaded and add its variants;
  add = false: remove the listed variants (a constant is unregistered once
  it has no variants left).*)
fun generic_adhoc_overloading_cmd add =
  if add then
    fold (fn (oconst, ts) =>
      generic_add_overloaded oconst
      #> fold (generic_add_variant oconst) ts)
  else
    fold (fn (oconst, ts) =>
      fold (generic_remove_variant oconst) ts);

(*Declaration body: keep only those variant terms that the morphism phi maps to
  a term equal to the original up to types (presumably so that variants which
  are changed by the morphism are not registered -- confirm).*)
fun adhoc_overloading_cmd' add args phi =
  let val args' = args
    |> map (apsnd (map_filter (fn t =>
         let val t' = Morphism.term phi t;
         in if Term.aconv_untyped (t, t') then SOME t' else NONE end)));
  in generic_adhoc_overloading_cmd add args' end;

(*Read the constant names and variant terms in lthy and register them as a
  (syntax) declaration.*)
fun adhoc_overloading_cmd add raw_args lthy =
  let
    fun const_name ctxt =
      fst o dest_Const o Proof_Context.read_const {proper = false, strict = false} ctxt;  (* FIXME {proper = true, strict = true} (!?) *)
    fun read_term ctxt = singleton (Variable.polymorphic ctxt) o Syntax.read_term ctxt;
    val args =
      raw_args
      |> map (apfst (const_name lthy))
      |> map (apsnd (map (read_term lthy)));
  in
    Local_Theory.declaration {syntax = true, pervasive = false}
      (adhoc_overloading_cmd' add args) lthy
  end;

val _ =
  Outer_Syntax.local_theory @{command_keyword adhoc_overloading}
    "add adhoc overloading for constants / fixed variables"
    (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd true);

val _ =
  Outer_Syntax.local_theory @{command_keyword no_adhoc_overloading}
    "delete adhoc overloading for constants / fixed variables"
    (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd false);

end;
wenzelm@52622
   246