src/Pure/Isar/overloading.ML
author wenzelm
Mon Nov 14 16:52:19 2011 +0100 (2011-11-14 ago)
changeset 45488 6d71d9e52369
parent 45444 ac069060e08a
child 46916 e7ea35b41e2d
permissions -rw-r--r--
pass positions for named targets, for formal links in the document model;
haftmann@25519
     1
(*  Title:      Pure/Isar/overloading.ML
haftmann@25519
     2
    Author:     Florian Haftmann, TU Muenchen
haftmann@25519
     3
haftmann@25519
     4
Overloaded definitions without any discipline.
haftmann@25519
     5
*)
haftmann@25519
     6
haftmann@25519
     7
signature OVERLOADING =
sig
  (*generic check/uncheck machinery for improvable constants*)
  type improvable_syntax
  val activate_improvable_syntax: Proof.context -> Proof.context
  val map_improvable_syntax:
    (improvable_syntax -> improvable_syntax) -> Proof.context -> Proof.context
  val set_primary_constraints: Proof.context -> Proof.context

  (*overloading target: each entry is (variable name, constant, checked flag);
    the constant is either given as name with type, or as a string to be read*)
  val overloading: (string * (string * typ) * bool) list -> theory -> local_theory
  val overloading_cmd: (string * string * bool) list -> theory -> local_theory
end;
haftmann@25519
    18
haftmann@25519
    19
structure Overloading: OVERLOADING =
haftmann@25519
    20
struct
haftmann@25519
    21
wenzelm@42404
    22
(* generic check/uncheck combinators for improvable constants *)
haftmann@26238
    23
haftmann@26249
    24
(*Nested-pair view of the improvable syntax data, in the layout exchanged
  with map_improvable_syntax:

    (((primary_constraints, secondary_constraints),
      (((improve, subst), consider_abbrevs), unchecks)), passed)
*)
type improvable_syntax =
  (((string * typ) list * (string * typ) list) *
   ((((string * typ -> (typ * typ) option) *
      (string * typ -> (typ * term) option)) *
     bool) *
    (term * term) list)) *
  bool;
haftmann@25519
    27
wenzelm@42404
    28
structure Improvable_Syntax = Proof_Data
(
  type T = {
    (*constraints installed by set_primary_constraints*)
    primary_constraints: (string * typ) list,
    (*constraints installed by improve_term_check while "passed" is still false*)
    secondary_constraints: (string * typ) list,
    (*per-constant type pair, fed to Sign.typ_match by improve_term_check*)
    improve: string * typ -> (typ * typ) option,
    (*per-constant replacement: type pattern and substituted term*)
    subst: string * typ -> (typ * term) option,
    (*if set, abbreviation mode of the context suppresses the substitution*)
    consider_abbrevs: bool,
    (*rewrite rules handed to Pattern.rewrite_term by improve_term_uncheck*)
    unchecks: (term * term) list,
    (*set by mark_passed after the secondary constraints have been installed*)
    passed: bool
  };

  (*initial data: nothing to constrain, improve, substitute or uncheck*)
  fun init _ = {
    primary_constraints = [],
    secondary_constraints = [],
    improve = fn _ => NONE,
    subst = fn _ => NONE,
    consider_abbrevs = false,
    unchecks = [],
    passed = true
  };
);
haftmann@25536
    49
wenzelm@42404
    50
(*Apply f to the improvable syntax data of a context.  The record is unpacked
  into the nested-pair layout of type improvable_syntax, transformed by f, and
  packed back into record form.*)
fun map_improvable_syntax f =
  let
    fun update {primary_constraints = prim, secondary_constraints = sec, improve = imp,
        subst = sub, consider_abbrevs = abbrevs, unchecks = unch, passed = pass} =
      (case f (((prim, sec), (((imp, sub), abbrevs), unch)), pass) of
        (((prim', sec'), (((imp', sub'), abbrevs'), unch')), pass') =>
          {primary_constraints = prim', secondary_constraints = sec', improve = imp',
            subst = sub', consider_abbrevs = abbrevs', unchecks = unch', passed = pass'});
  in Improvable_Syntax.map update end;
haftmann@26238
    62
haftmann@26249
    63
(*set the "passed" flag, leaving all other syntax data untouched*)
val mark_passed = map_improvable_syntax (fn (syntax, _) => (syntax, true));
haftmann@26238
    64
haftmann@26238
    65
(*Term check phase for improvable constants.

  1. For every constant in ts, "improve" may yield a pair of types; these are
     matched (Sign.typ_match) into one common type environment, which is then
     applied to all types occurring in ts.
  2. Unless abbreviations are to be considered and the context is in
     abbreviation mode, constants are expanded according to "subst", provided
     the constant's type is an instance of the type given by "subst".

  Result: NONE if nothing changed and the phase has already passed (or counts
  as passed due to abbreviation mode); otherwise the new terms.  On the first
  run with "passed" unset, the secondary constraints are added to the context
  and the phase is marked as passed.*)
fun improve_term_check ts ctxt =
  let
    val thy = Proof_Context.theory_of ctxt;
    val {secondary_constraints, improve, subst, consider_abbrevs, passed, ...} =
      Improvable_Syntax.get ctxt;

    val is_abbrev = consider_abbrevs andalso Proof_Context.abbrev_mode ctxt;
    val passed_or_abbrev = passed orelse is_abbrev;

    (*type improvement*)
    fun collect (Const c_ty) tyenv =
          (case improve c_ty of
            NONE => tyenv
          | SOME ty_ty' => Sign.typ_match thy ty_ty' tyenv)
      | collect _ tyenv = tyenv;
    val improvements = fold (fold_aterms collect) ts Vartab.empty;
    val improved = map (map_types (Envir.subst_type improvements)) ts;

    (*constant substitution; the substituted term is expanded recursively*)
    fun expand_const (Const (c, ty)) =
          (case subst (c, ty) of
            NONE => NONE
          | SOME (ty', t') =>
              if Sign.typ_instance thy (ty, ty') then SOME (ty', apply_subst t') else NONE)
      | expand_const _ = NONE
    and apply_subst t = Envir.expand_term expand_const t;

    val result = if is_abbrev then improved else map apply_subst improved;

    fun constrain (c, ty) = Proof_Context.add_const_constraint (c, SOME ty);
  in
    if not passed_or_abbrev then
      SOME (result, mark_passed (fold constrain secondary_constraints ctxt))
    else if eq_list (op aconv) (ts, result) then NONE
    else SOME (result, ctxt)
  end;
haftmann@25519
    98
haftmann@31698
    99
(*Rewrite t with the given uncheck rules.  Yields NONE if rewriting raises an
  exception (caught by "try") or if the result is alpha-convertible to t.*)
fun rewrite_liberal thy unchecks t =
  let val rewritten = try (Pattern.rewrite_term thy unchecks []) t in
    (case rewritten of
      SOME t' => if t aconv t' then NONE else rewritten
    | NONE => NONE)
  end;
haftmann@31698
   103
haftmann@26238
   104
(*Term uncheck phase: rewrite each term with the "unchecks" rules of the
  context.  Yields NONE if no term was changed; otherwise all terms, with the
  unchanged ones kept as they are.  The context is returned unmodified.*)
fun improve_term_uncheck ts ctxt =
  let
    val thy = Proof_Context.theory_of ctxt;
    val unchecks = #unchecks (Improvable_Syntax.get ctxt);
    val rewrites = map (rewrite_liberal thy unchecks) ts;
    fun choose t NONE = t
      | choose _ (SOME t') = t';
  in
    (case exists is_some rewrites of
      false => NONE
    | true => SOME (map2 choose ts rewrites, ctxt))
  end;
haftmann@26238
   110
haftmann@26520
   111
(*add every primary constraint of the syntax data as const constraint of ctxt*)
fun set_primary_constraints ctxt =
  fold (fn (c, ty) => Proof_Context.add_const_constraint (c, SOME ty))
    (#primary_constraints (Improvable_Syntax.get ctxt)) ctxt;
haftmann@26259
   114
haftmann@39378
   115
(*Register the "improvement" check and uncheck phases (both at stage 0) in a
  proof context, then install the primary constraints.  Both installers are
  built once, when this value is defined.*)
val activate_improvable_syntax =
  let
    val add_check = Syntax_Phases.term_check' 0 "improvement" improve_term_check;
    val add_uncheck = Syntax_Phases.term_uncheck' 0 "improvement" improve_term_uncheck;
  in Context.proof_map (add_check #> add_uncheck) #> set_primary_constraints end;
haftmann@26259
   120
haftmann@26259
   121
wenzelm@42404
   122
(* overloading target *)
haftmann@26259
   123
haftmann@38382
   124
structure Data = Proof_Data
(
  (*pending overloaded operations: ((constant, type), (variable name, checked))*)
  type T = ((string * typ) * (string * bool)) list;
  fun init _ : T = [];
);
haftmann@26259
   129
haftmann@38382
   130
(*pending overloaded operations, as stored in the target context of lthy*)
fun get_overloading lthy = Data.get (Local_Theory.target_of lthy);
haftmann@38382
   131
(*transform the pending overloaded operations within the target context*)
fun map_overloading f = Local_Theory.target (Data.map f);
haftmann@26259
   132
wenzelm@42404
   133
(*Look up the first pending operation whose variable name equals the name of
  binding b; yields the constant name together with (variable, checked).*)
fun operation lthy b =
  let
    val name = Binding.name_of b;
    fun select ((c, _), (v, checked)) =
      if name = v then SOME (c, (v, checked)) else NONE;
  in get_first select (get_overloading lthy) end;
haftmann@26259
   137
haftmann@32343
   138
(*Rebuild the improvable syntax data from the pending operations in ctxt:
  on check, each pending constant (at exactly its registered type) is replaced
  by its variable; on uncheck, the variable is turned back into the constant.
  Constraints and type improvement are emptied, and "passed" is reset.*)
fun synchronize_syntax ctxt =
  let
    val overloading = Data.get ctxt;
    fun subst (c_ty as (_, ty)) =
      (case AList.lookup (op =) overloading c_ty of
        NONE => NONE
      | SOME (v, _) => SOME (ty, Free (v, ty)));
    fun uncheck_rule (c_ty as (_, ty), (v, _)) = (Free (v, ty), Const c_ty);
    val unchecks = map uncheck_rule overloading;
  in
    map_improvable_syntax
      (fn _ => ((([], []), (((fn _ => NONE, subst), false), unchecks)), false)) ctxt
  end
haftmann@26259
   151
haftmann@38382
   152
(*Issue the global definition "c == rhs" in the background theory, with c at
  the type of rhs, under binding b_def.  Then drop the operation with variable
  name v from the pending list and synchronize the syntax of the target
  context.  Yields the constant at type U together with the definition.*)
fun define_overloaded (c, U) (v, checked) (b_def, rhs) lthy =
  let
    val def_eq = Logic.mk_equals (Const (c, Term.fastype_of rhs), rhs);
    val ((_, def), lthy') =
      Local_Theory.background_theory_result
        (Thm.add_def_global (not checked) true (b_def, def_eq)) lthy;
    val lthy'' = lthy'
      |> map_overloading (filter_out (fn (_, (v', _)) => v' = v))
      |> Local_Theory.target synchronize_syntax;
  in ((Const (c, U), def), lthy'') end
haftmann@26259
   159
haftmann@38382
   160
(*Foundation of the overloading target: a binding that names a pending
  operation becomes an overloaded definition (mixfix syntax is rejected);
  anything else is passed on to the plain theory foundation.*)
fun foundation (decl as ((b, U), mx), defn) params lthy =
  (case operation lthy b of
    NONE => Generic_Target.theory_foundation (decl, defn) params lthy
  | SOME (c, v_checked) =>
      if mx = NoSyn then define_overloaded (c, U) v_checked defn lthy
      else error ("Illegal mixfix syntax for overloaded constant " ^ quote c));
haftmann@25519
   168
haftmann@25606
   169
(*Pretty-print the target: the keyword "overloading", followed by one block
  "v == c :: ty" for each pending operation.*)
fun pretty lthy =
  let
    fun pretty_entry ((c, ty), (v, _)) =
      let
        val const_name = Pretty.str (Proof_Context.extern_const lthy c);
        val const_type = Syntax.pretty_typ lthy ty;
      in
        Pretty.block (Pretty.breaks
          [Pretty.str v, Pretty.str "==", const_name, Pretty.str "::", const_type])
      end;
    val header = Pretty.str "overloading";
  in header :: map pretty_entry (get_overloading lthy) end;
haftmann@38342
   177
haftmann@38382
   178
(*Final check of the target: fail if any operation is still pending, naming
  the corresponding constants; otherwise return lthy unchanged.*)
fun conclude lthy =
  (case get_overloading lthy of
    [] => lthy
  | missing =>
      error ("Missing definition(s) for parameter(s) " ^
        commas_quote (map (fn (c_ty, _) => Syntax.string_of_term lthy (Const c_ty)) missing)));
haftmann@38342
   187
haftmann@38382
   188
(*Enter the overloading target.  Each raw entry (v, const, checked) is turned
  into a pending operation by preparing const with prep_const and taking the
  resulting constant apart; an empty list of entries is an error.  The pending
  operations are stored in a fresh global proof context, their variables are
  declared, the improvable syntax is activated and synchronized, and the
  result is wrapped up as local theory with the operations given below.*)
fun gen_overloading prep_const raw_overloading thy =
  let
    val ctxt = Proof_Context.init_global thy;
    val _ =
      (case raw_overloading of
        [] => error "At least one parameter must be given"
      | _ => ());
    fun prepare (v, const, checked) =
      (Term.dest_Const (prep_const ctxt const), (v, checked));
    val overloading = map prepare raw_overloading;
    fun declare ((_, ty), (v, _)) = Variable.declare_names (Free (v, ty));
  in
    thy
    |> Theory.checkpoint
    |> Proof_Context.init_global
    |> Data.put overloading
    |> fold declare overloading
    |> activate_improvable_syntax
    |> synchronize_syntax
    |> Local_Theory.init NONE ""
       {define = Generic_Target.define foundation,
        notes =
          Generic_Target.notes
            (fn kind => fn facts => fn _ => Generic_Target.theory_notes kind facts),
        abbrev =
          Generic_Target.abbrev
            (fn mode => fn (b, mx) => fn (t, _) => fn _ =>
              Generic_Target.theory_abbrev mode ((b, mx), t)),
        declaration = fn _ => Generic_Target.theory_declaration,
        pretty = pretty,
        exit = fn lthy => Local_Theory.target_of (conclude lthy)}
  end;
haftmann@38342
   213
haftmann@38342
   214
(*variant for constants given as (name, type): checked as Const term*)
val overloading =
  gen_overloading (fn ctxt => fn c_ty => Syntax.check_term ctxt (Const c_ty));
haftmann@38342
   215
(*variant for constants given as strings: read via Syntax.read_term*)
val overloading_cmd =
  gen_overloading (fn ctxt => fn raw_const => Syntax.read_term ctxt raw_const);
haftmann@38342
   216
haftmann@25519
   217
end;