src/Pure/Isar/overloading.ML
author wenzelm
Sun Mar 01 23:36:12 2009 +0100 (2009-03-01)
changeset 30190 479806475f3c
parent 29606 fedb8be05f24
child 30519 c05c0199826f
permissions -rw-r--r--
use long names for old-style fold combinators;
haftmann@25519
     1
(*  Title:      Pure/Isar/overloading.ML
haftmann@25519
     2
    Author:     Florian Haftmann, TU Muenchen
haftmann@25519
     3
haftmann@25519
     4
Overloaded definitions without any discipline.
haftmann@25519
     5
*)
haftmann@25519
     6
haftmann@25519
     7
(*Interface of the overloading target, plus the generic improvable-syntax
  check/uncheck machinery it is built on.*)
signature OVERLOADING =
sig
  (* overloading target *)

  (*enter a target from (variable name, (constant, type), checked) triples*)
  val init: (string * (string * typ) * bool) list -> theory -> local_theory
  (*leave the target; fails while parameters are still pending*)
  val conclude: local_theory -> local_theory
  (*the constant term for a name and type; the theory is returned unchanged*)
  val declare: string * typ -> theory -> term * theory
  (*drop the pending parameter with the given variable name*)
  val confirm: string -> local_theory -> local_theory
  (*add the definition named by the string argument for the given constant and right-hand side*)
  val define: bool -> string -> string * term -> theory -> thm * theory
  (*constant name and checked flag registered for a variable name, if any*)
  val operation: Proof.context -> string -> (string * bool) option
  (*display the pending parameters of the target*)
  val pretty: Proof.context -> Pretty.T

  (* improvable syntax for constants *)

  type improvable_syntax
  (*install the improvement check/uncheck phases into a context*)
  val add_improvable_syntax: Proof.context -> Proof.context
  (*modify the improvable-syntax state through its tuple view*)
  val map_improvable_syntax: (improvable_syntax -> improvable_syntax) -> Proof.context -> Proof.context
  (*install the primary constraints of the current state*)
  val set_primary_constraints: Proof.context -> Proof.context
end;
haftmann@25519
    23
haftmann@25519
    24
structure Overloading: OVERLOADING =
struct

(** generic check/uncheck combinators for improvable constants **)

(*Tuple view of the improvable-syntax record, in the shape exchanged with
  map_improvable_syntax: first the primary and secondary constraints, then
  improve, subst, consider_abbrevs and unchecks, and finally the passed flag.*)
type improvable_syntax = ((((string * typ) list * (string * typ) list) *
  ((((string * typ -> (typ * typ) option) * (string * typ -> (typ * term) option)) * bool) *
    (term * term) list)) * bool);

(*Per-context improvable-syntax state, read by the check and uncheck phases below.*)
structure ImprovableSyntax = ProofDataFun(
  type T = {
    primary_constraints: (string * typ) list,
    secondary_constraints: (string * typ) list,
    improve: string * typ -> (typ * typ) option,
    subst: string * typ -> (typ * term) option,
    consider_abbrevs: bool,
    unchecks: (term * term) list,
    passed: bool
  };
  fun init _ = {
    primary_constraints = [],
    secondary_constraints = [],
    improve = K NONE,
    subst = K NONE,
    consider_abbrevs = false,
    unchecks = [],
    passed = true
  };
);

(*Apply f to the state, converting the record to and from its tuple view.*)
fun map_improvable_syntax f = ImprovableSyntax.map (fn { primary_constraints,
  secondary_constraints, improve, subst, consider_abbrevs, unchecks, passed } => let
    val (((primary_constraints', secondary_constraints'),
      (((improve', subst'), consider_abbrevs'), unchecks')), passed')
        = f (((primary_constraints, secondary_constraints),
            (((improve, subst), consider_abbrevs), unchecks)), passed)
  in { primary_constraints = primary_constraints', secondary_constraints = secondary_constraints',
    improve = improve', subst = subst', consider_abbrevs = consider_abbrevs',
    unchecks = unchecks', passed = passed'
  } end);

(*Set the passed flag; improve_term_check installs the secondary constraints
  only while this flag is false.*)
val mark_passed = (map_improvable_syntax o apsnd) (K true);

(*Add every (constant, type) pair as a constant constraint to the context.
  Shared by the check phase and set_primary_constraints.*)
fun add_const_constraints cs =
  fold (ProofContext.add_const_constraint o apsnd SOME) cs;

(*Check phase.
  - Instantiate schematic types as suggested by improve.
  - Unless in abbreviation mode with consider_abbrevs set, expand constants
    that have a registered substitution.
  - On the first pass, install the secondary constraints and set passed.
  Returns NONE when no pass is pending and the terms are unchanged.*)
fun improve_term_check ts ctxt =
  let
    val { secondary_constraints, improve, subst, consider_abbrevs, passed, ... } =
      ImprovableSyntax.get ctxt;
    val tsig = (Sign.tsig_of o ProofContext.theory_of) ctxt;
    val is_abbrev = consider_abbrevs andalso ProofContext.abbrev_mode ctxt;
    val passed_or_abbrev = passed orelse is_abbrev;
    (*match each constant's type pair reported by improve, accumulating an instantiation*)
    fun accumulate_improvements (Const (c, ty)) = (case improve (c, ty)
         of SOME ty_ty' => Type.typ_match tsig ty_ty'
          | _ => I)
      | accumulate_improvements _ = I;
    val improvements = (fold o fold_aterms) accumulate_improvements ts Vartab.empty;
    val ts' = (map o map_types) (Envir.typ_subst_TVars improvements) ts;
    (*expand a constant only if its type is an instance of the substitution's type;
      the replacement itself is substituted recursively*)
    fun apply_subst t = Envir.expand_term (fn Const (c, ty) => (case subst (c, ty)
         of SOME (ty', t') =>
              if Type.typ_instance tsig (ty, ty')
              then SOME (ty', apply_subst t') else NONE
          | NONE => NONE)
        | _ => NONE) t;
    val ts'' = if is_abbrev then ts' else map apply_subst ts';
  in
    if passed_or_abbrev then
      if eq_list (op aconv) (ts, ts'') then NONE else SOME (ts'', ctxt)
    else
      SOME (ts'', ctxt
        |> add_const_constraints secondary_constraints
        |> mark_passed)
  end;

(*Uncheck phase: rewrite the terms with the unchecks rules; NONE if nothing changed.*)
fun improve_term_uncheck ts ctxt =
  let
    val thy = ProofContext.theory_of ctxt;
    val unchecks = (#unchecks o ImprovableSyntax.get) ctxt;
    val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;

(*Install the primary constraints recorded in the current state.*)
fun set_primary_constraints ctxt =
  add_const_constraints (#primary_constraints (ImprovableSyntax.get ctxt)) ctxt;

(*Register the "improvement" check/uncheck phases at priority 0, then install
  the primary constraints.*)
val add_improvable_syntax =
  Context.proof_map
    (Syntax.add_term_check 0 "improvement" improve_term_check
    #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck)
  #> set_primary_constraints;


(** overloading target **)

(* bookkeeping *)

(*Pending parameters of the target: ((constant, type), (variable name, checked)).*)
structure OverloadingData = ProofDataFun
(
  type T = ((string * typ) * (string * bool)) list;
  fun init _ = [];
);

(*Access the pending parameters stored in the target context of a local theory.*)
val get_overloading = OverloadingData.get o LocalTheory.target_of;
val map_overloading = LocalTheory.target o OverloadingData.map;

(*Constant name and checked flag of the parameter with variable name v, if any.*)
fun operation lthy v = get_overloading lthy
  |> get_first (fn ((c, _), (v', checked)) => if v = v' then SOME (c, checked) else NONE);

(*Mark the parameter with variable name v as handled by removing it from the
  pending list.*)
fun confirm v = map_overloading (filter_out (fn (_, (v', _)) => v' = v));


(* overloaded declarations and definitions *)

(*Declaring a parameter just yields its constant; the theory is unchanged.*)
fun declare c_ty = pair (Const c_ty);

(*Add the definition "name: c == t", with c typed after t.  The checked flag is
  passed negated as the first argument of Thm.add_def (presumably its
  "unchecked" switch; confirm against Thm.add_def).*)
fun define checked name (c, t) = Thm.add_def (not checked) true (Binding.name name,
  Logic.mk_equals (Const (c, Term.fastype_of t), t));


(* target *)

(*Enter the overloading target.  Each parameter (v, (c, ty), checked) stands for
  the constant c at type ty under the name v.
  - Checking replaces Const (c, ty) by Free (v, ty).
  - Unchecking rewrites Free (v, ty) back to Const (c, ty).
  Fails on an empty parameter list.*)
fun init raw_overloading thy =
  let
    val _ = if null raw_overloading then error "At least one parameter must be given" else ();
    val overloading = map (fn (v, c_ty, checked) => (c_ty, (v, checked))) raw_overloading;
    fun subst (c, ty) = case AList.lookup (op =) overloading (c, ty)
     of SOME (v, _) => SOME (ty, Free (v, ty))
      | NONE => NONE;
    val unchecks =
      map (fn (c_ty as (_, ty), (v, _)) => (Free (v, ty), Const c_ty)) overloading;
  in
    thy
    |> ProofContext.init
    |> OverloadingData.put overloading
    |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
    |> map_improvable_syntax (K ((([], []), (((K NONE, subst), false), unchecks)), false))
    |> add_improvable_syntax
  end;

(*Leave the target.  Fails, naming the constants, if any parameter is still pending.*)
fun conclude lthy =
  let
    val overloading = get_overloading lthy;
    val _ = if null overloading then () else
      error ("Missing definition(s) for parameter(s) " ^ commas (map (quote
        o Syntax.string_of_term lthy o Const o fst) overloading));
  in
    lthy
  end;

(*Display the pending parameters, one "v == c :: ty" entry per line.*)
fun pretty lthy =
  let
    val thy = ProofContext.theory_of lthy;
    val overloading = get_overloading lthy;
    fun pr_operation ((c, ty), (v, _)) =
      (Pretty.block o Pretty.breaks) [Pretty.str v, Pretty.str "==",
        Pretty.str (Sign.extern_const thy c), Pretty.str "::", Syntax.pretty_typ lthy ty];
  in
    (Pretty.block o Pretty.fbreaks)
      (Pretty.str "overloading" :: map pr_operation overloading)
  end;

end;