src/Pure/Isar/overloading.ML
author haftmann
Wed Aug 11 12:30:48 2010 +0200 (2010-08-11)
changeset 38342 09d4a04d5c2e
parent 36610 bafd82950e24
child 38377 2dfd8b7b8274
permissions -rw-r--r--
moved overloading target formally to overloading.ML
haftmann@25519
     1
(*  Title:      Pure/Isar/overloading.ML
haftmann@25519
     2
    Author:     Florian Haftmann, TU Muenchen
haftmann@25519
     3
haftmann@25519
     4
Overloaded definitions without any discipline.
haftmann@25519
     5
*)
haftmann@25519
     6
haftmann@25519
     7
(*Interface of the overloading target together with the generic
  check/uncheck machinery for improvable constants.*)
signature OVERLOADING =
sig
  (* overloading target: basic operations *)
  val init: (string * (string * typ) * bool) list -> theory -> Proof.context
  val conclude: local_theory -> local_theory
  val declare: string * typ -> theory -> term * theory
  val confirm: binding -> local_theory -> local_theory
  val define: bool -> binding -> string * term -> theory -> thm * theory
  val operation: Proof.context -> binding -> (string * bool) option
  val pretty: Proof.context -> Pretty.T

  (* improvable syntax *)
  type improvable_syntax
  val add_improvable_syntax: Proof.context -> Proof.context
  val map_improvable_syntax:
    (improvable_syntax -> improvable_syntax) -> Proof.context -> Proof.context
  val set_primary_constraints: Proof.context -> Proof.context

  (* entry points: typed constants vs. unparsed constant strings *)
  val overloading: (string * (string * typ) * bool) list -> theory -> local_theory
  val overloading_cmd: (string * string * bool) list -> theory -> local_theory
end;
haftmann@25519
    26
haftmann@25519
    27
structure Overloading: OVERLOADING =
haftmann@25519
    28
struct
haftmann@25519
    29
haftmann@26259
    30
(** generic check/uncheck combinators for improvable constants **)
haftmann@26238
    31
haftmann@26249
    32
(*Nested-tuple view of the improvable syntax data, as passed to the argument
  of map_improvable_syntax:
    (((primary constraints, secondary constraints),
      (((improve, subst), consider_abbrevs), unchecks)),
     passed)*)
type improvable_syntax =
  (((string * typ) list * (string * typ) list) *
   ((((string * typ -> (typ * typ) option) *
      (string * typ -> (typ * term) option)) * bool) *
    (term * term) list)) * bool;
haftmann@25519
    35
wenzelm@33519
    36
(*Proof data holding the improvable syntax configuration as a record.*)
structure ImprovableSyntax = Proof_Data
(
  type T =
   {primary_constraints: (string * typ) list,
    secondary_constraints: (string * typ) list,
    improve: string * typ -> (typ * typ) option,
    subst: string * typ -> (typ * term) option,
    consider_abbrevs: bool,
    unchecks: (term * term) list,
    passed: bool};
  (*initial value: no constraints, no improvement, no substitution,
    no uncheck rules; the passed flag starts out as true*)
  fun init _ : T =
   {primary_constraints = [], secondary_constraints = [],
    improve = K NONE, subst = K NONE,
    consider_abbrevs = false, unchecks = [], passed = true};
);
haftmann@25536
    57
haftmann@26520
    58
(*Update the improvable syntax data of a context: the record is converted to
  its nested-tuple form, transformed by f, and converted back.*)
fun map_improvable_syntax f =
  let
    fun to_tuple {primary_constraints, secondary_constraints, improve, subst,
        consider_abbrevs, unchecks, passed} =
      (((primary_constraints, secondary_constraints),
        (((improve, subst), consider_abbrevs), unchecks)), passed);
    fun of_tuple (((primary, secondary), (((impr, sub), abbrevs), unchks)), pass) =
      {primary_constraints = primary, secondary_constraints = secondary,
        improve = impr, subst = sub, consider_abbrevs = abbrevs,
        unchecks = unchks, passed = pass};
  in ImprovableSyntax.map (of_tuple o f o to_tuple) end;
haftmann@26238
    68
haftmann@26249
    69
(*Set the passed flag of the improvable syntax data, keeping everything else.*)
val mark_passed = map_improvable_syntax (fn (syntax, _) => (syntax, true));
haftmann@26238
    70
haftmann@26238
    71
(*Term check phase for improvable constants.

  First, the type matchings suggested by improve for the constants occurring
  in ts are accumulated and applied to all types in ts.  Second, unless in
  abbreviation mode (and consider_abbrevs is set), constants are expanded via
  Envir.expand_term according to subst, provided the constant's type is an
  instance of the type given by subst; replacements are expanded recursively.

  Result is NONE if nothing changed and the check was already passed (or
  abbreviation mode applies).  On the first non-abbreviation pass the
  secondary constraints are added to the context and it is marked as passed.*)
fun improve_term_check ts ctxt =
  let
    val {secondary_constraints, improve, subst, consider_abbrevs, passed, ...} =
      ImprovableSyntax.get ctxt;
    val tsig = Sign.tsig_of (ProofContext.theory_of ctxt);
    val is_abbrev = consider_abbrevs andalso ProofContext.abbrev_mode ctxt;
    val passed_or_abbrev = passed orelse is_abbrev;

    (*type improvement*)
    fun improvement_of (Const (c, ty)) =
          (case improve (c, ty) of
            NONE => I
          | SOME ty_ty' => Type.typ_match tsig ty_ty')
      | improvement_of _ = I;
    val tyenv = fold (fold_aterms improvement_of) ts Vartab.empty;
    val improved = map (map_types (Envir.subst_type tyenv)) ts;

    (*constant substitution*)
    fun expand t =
      Envir.expand_term
        (fn Const (c, ty) =>
              (case subst (c, ty) of
                NONE => NONE
              | SOME (ty', t') =>
                  if Type.typ_instance tsig (ty, ty')
                  then SOME (ty', expand t') else NONE)
          | _ => NONE) t;
    val result = if is_abbrev then improved else map expand improved;
  in
    if passed_or_abbrev then
      (if eq_list (op aconv) (ts, result) then NONE else SOME (result, ctxt))
    else
      SOME (result, ctxt
        |> fold (ProofContext.add_const_constraint o apsnd SOME) secondary_constraints
        |> mark_passed)
  end;
haftmann@25519
    97
haftmann@31698
    98
(*Rewrite t with the uncheck rules; NONE if rewriting raises an exception
  (caught by try) or yields a term that is aconv to t.*)
fun rewrite_liberal thy unchecks t =
  let val attempt = try (Pattern.rewrite_term thy unchecks []) t in
    (case attempt of
      SOME t' => if t aconv t' then NONE else SOME t'
    | NONE => NONE)
  end;
haftmann@31698
   102
haftmann@26238
   103
(*Term uncheck phase: rewrite each term with the uncheck rules stored in the
  context.  NONE if no term was changed; otherwise the terms, with the
  rewritten ones replaced, paired with the unchanged context.*)
fun improve_term_uncheck ts ctxt =
  let
    val {unchecks, ...} = ImprovableSyntax.get ctxt;
    val rewrite = rewrite_liberal (ProofContext.theory_of ctxt) unchecks;
    val rewritten = map rewrite ts;
  in
    if forall is_none rewritten then NONE
    else SOME (map2 the_default ts rewritten, ctxt)
  end;
haftmann@26238
   109
haftmann@26520
   110
(*Add every stored primary constraint as a constant constraint of the context.*)
fun set_primary_constraints ctxt =
  let
    fun constrain (c, ty) = ProofContext.add_const_constraint (c, SOME ty);
  in fold constrain (#primary_constraints (ImprovableSyntax.get ctxt)) ctxt end;
haftmann@26259
   114
haftmann@26259
   115
(*Register the "improvement" term check and uncheck at stage 0, then
  install the primary constraints.*)
fun add_improvable_syntax ctxt =
  let
    val install_checks =
      Syntax.add_term_check 0 "improvement" improve_term_check
      #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck;
  in
    ctxt
    |> Context.proof_map install_checks
    |> set_primary_constraints
  end;
haftmann@26259
   120
haftmann@26259
   121
haftmann@26259
   122
(** overloading target **)
haftmann@26259
   123
haftmann@26259
   124
(* bookkeeping *)
haftmann@26259
   125
wenzelm@33519
   126
(*Proof data: list of overloaded operations, each entry pairing a constant
  (name, type) with (variable name, checked flag); initially empty.*)
structure OverloadingData = Proof_Data
(
  type T = ((string * typ) * (string * bool)) list;
  fun init _ : T = [];
);
haftmann@26259
   131
wenzelm@33671
   132
(*read the overloading data from the target context of a local theory*)
fun get_overloading lthy = OverloadingData.get (Local_Theory.target_of lthy);

(*apply f to the overloading data within the target context*)
fun map_overloading f = Local_Theory.target (OverloadingData.map f);
haftmann@26259
   134
haftmann@30519
   135
(*Look up the first overloading entry whose variable name equals the name of
  binding b; yields the constant name and checked flag, or NONE.*)
fun operation lthy b =
  let
    val name = Binding.name_of b;
    fun select ((c, _), (v, checked)) = if v = name then SOME (c, checked) else NONE;
  in get_first select (get_overloading lthy) end;
haftmann@26259
   138
haftmann@32343
   139
haftmann@32343
   140
(* target *)
haftmann@26259
   141
haftmann@32343
   142
(*Rebuild the improvable syntax data from the current overloading data:
  subst maps each listed constant (name, type) to the free variable of the
  same type, and unchecks holds the converse (Free, Const) pairs.  The
  constraints are emptied, improve is trivial, consider_abbrevs and passed
  are both reset to false.*)
fun synchronize_syntax ctxt =
  let
    val pending = OverloadingData.get ctxt;
    fun subst (c_ty as (_, ty)) =
      Option.map (fn (v, _) => (ty, Free (v, ty))) (AList.lookup (op =) pending c_ty);
    fun uncheck_rule (c_ty as (_, ty), (v, _)) = (Free (v, ty), Const c_ty);
  in
    map_improvable_syntax
      (K ((([], []), (((K NONE, subst), false), map uncheck_rule pending)), false)) ctxt
  end;
haftmann@26259
   154
haftmann@32343
   155
(*Initialize the overloading context for a theory.

  raw_overloading: list of (variable name, constant (name, type), checked).
  Raises ERROR if the list is empty or if a variable name occurs more than
  once.

  Duplicate variable names are rejected because entries are looked up and
  removed by variable name only (see operation and confirm): with duplicates,
  defining one parameter would discard all entries of that name, so the
  remaining parameters would never be defined and conclude would not report
  them as missing.

  The resulting context stores the overloading data, declares the variable
  names as frees, and has the improvable syntax installed and synchronized.*)
fun init raw_overloading thy =
  let
    val _ = if null raw_overloading then error "At least one parameter must be given" else ();
    val _ =
      (case duplicates (op =) (map (fn (v, _, _) => v) raw_overloading) of
        [] => ()
      | dups => error ("Duplicate variable name(s) " ^ commas_quote dups));
    val overloading = map (fn (v, c_ty, checked) => (c_ty, (v, checked))) raw_overloading;
  in
    thy
    |> Theory.checkpoint
    |> ProofContext.init_global
    |> OverloadingData.put overloading
    |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
    |> add_improvable_syntax
    |> synchronize_syntax
  end;
haftmann@26259
   168
haftmann@26259
   169
(*Return the constant term for c_ty together with the unchanged theory.*)
fun declare c_ty thy = (Const c_ty, thy);
haftmann@26259
   170
wenzelm@36106
   171
(*Add the definition  c == t  under binding b via Thm.add_def, where c is
  given the type of t; the first flag passed to Thm.add_def is the negation
  of checked.  Only the second component of the Thm.add_def result is kept.*)
fun define checked b (c, t) =
  let
    val eq = Logic.mk_equals (Const (c, Term.fastype_of t), t);
  in Thm.add_def (not checked) true (b, eq) #> apfst snd end;
haftmann@25519
   174
haftmann@32343
   175
(*Drop all overloading entries whose variable name equals the name of
  binding b, then re-synchronize the syntax of the target context.*)
fun confirm b =
  let
    val name = Binding.name_of b;
    fun still_pending (_, (v, _)) = v <> name;
  in
    map_overloading (filter still_pending)
    #> Local_Theory.target synchronize_syntax
  end;
haftmann@25519
   177
haftmann@25519
   178
(*Finish the overloading target: returns lthy unchanged if no overloading
  entries remain, otherwise raises ERROR listing the remaining constants.*)
fun conclude lthy =
  (case get_overloading lthy of
    [] => lthy
  | pending =>
      let
        fun show_param (c_ty, _) = quote (Syntax.string_of_term lthy (Const c_ty));
      in
        error ("Missing definition(s) for parameter(s) " ^ commas (map show_param pending))
      end);
haftmann@25519
   187
haftmann@25606
   188
(*Pretty-print the overloading data: the keyword "overloading" followed by
  one line  v == c :: ty  per entry, with c printed in external form.*)
fun pretty lthy =
  let
    val thy = ProofContext.theory_of lthy;
    fun pretty_entry ((c, ty), (v, _)) =
      Pretty.block (Pretty.breaks
        (map Pretty.str [v, "==", Sign.extern_const thy c, "::"] @
          [Syntax.pretty_typ lthy ty]));
    val header = Pretty.str "overloading";
    val entries = map pretty_entry (get_overloading lthy);
  in Pretty.block (Pretty.fbreaks (header :: entries)) end;
haftmann@25606
   199
haftmann@38342
   200
(*Raise ERROR: mixfix syntax was given for the overloaded constant c.*)
fun syntax_error c =
  let val msg = "Illegal mixfix syntax for overloaded constant " ^ quote c
  in error msg end;
haftmann@38342
   201
haftmann@38342
   202
(*Foundation of a definition in the overloading target.

  If binding b names one of the overloaded operations, the corresponding
  constant is declared and defined in the background theory and the entry is
  confirmed; mixfix syntax other than NoSyn is an error in this case.
  Otherwise the definition is passed on to Generic_Target.theory_foundation.*)
fun overloading_foundation (spec as (((b, U), mx), (b_def, rhs))) params lthy =
  (case operation lthy b of
    NONE => Generic_Target.theory_foundation spec params lthy
  | SOME (c, checked) =>
      if mx = NoSyn then
        lthy
        |> Local_Theory.theory_result (declare (c, U) ##>> define checked b_def (c, rhs))
        |> apsnd (confirm b)
      else syntax_error c);
haftmann@38342
   210
haftmann@38342
   211
(*Set up the overloading target as a local theory.

  prep_const turns each raw constant specification into a term, which must be
  a constant (Term.dest_Const); raw_ops is a list of
  (variable name, raw constant, checked).  Re-initialization repeats the whole
  setup on the theory of the given context; exit runs conclude and returns
  the target context.*)
fun gen_overloading prep_const raw_ops thy =
  let
    val ctxt = ProofContext.init_global thy;
    fun prep_op (name, const, checked) =
      (name, Term.dest_Const (prep_const ctxt const), checked);
    val ops = map prep_op raw_ops;

    fun theory_notes kind global_facts _ = Generic_Target.theory_notes kind global_facts;
    fun theory_abbrev prmode (b, mx) (t, _) _ =
      Generic_Target.theory_abbrev prmode ((b, mx), t);
    val operations =
     {define = Generic_Target.define overloading_foundation,
      notes = Generic_Target.notes theory_notes,
      abbrev = Generic_Target.abbrev theory_abbrev,
      declaration = K Generic_Target.theory_declaration,
      syntax_declaration = K Generic_Target.theory_declaration,
      pretty = single o pretty,
      reinit = gen_overloading prep_const raw_ops o ProofContext.theory_of,
      exit = Local_Theory.target_of o conclude};
  in
    thy
    |> init ops
    |> Local_Theory.init NONE "" operations
  end;
haftmann@38342
   232
haftmann@38342
   233
(*overloading target for constants given as (name, type), checked as terms*)
fun overloading ops thy =
  gen_overloading (fn ctxt => fn c_ty => Syntax.check_term ctxt (Const c_ty)) ops thy;

(*overloading target for constants given as strings, read via Syntax.read_term*)
fun overloading_cmd ops thy = gen_overloading Syntax.read_term ops thy;
haftmann@38342
   235
haftmann@25519
   236
end;