src/Pure/Isar/overloading.ML
author haftmann
Fri Mar 07 13:53:07 2008 +0100 (2008-03-07)
changeset 26238 c30bb8182da2
parent 25861 494d9301cc75
child 26249 59ecf1ce8222
permissions -rw-r--r--
generic improvable syntax for targets
haftmann@25519
     1
(*  Title:      Pure/Isar/overloading.ML
haftmann@25519
     2
    ID:         $Id$
haftmann@25519
     3
    Author:     Florian Haftmann, TU Muenchen
haftmann@25519
     4
haftmann@25519
     5
Overloaded definitions without any discipline.
haftmann@25519
     6
*)
haftmann@25519
     7
haftmann@25519
     8
signature OVERLOADING =
sig
  (* entering, leaving and printing an overloading target *)

  (*each triple is (local variable name, constant name and type, checked flag);
    an empty list is rejected*)
  val init: (string * (string * typ) * bool) list -> theory -> local_theory
  (*fails while any registered operation has not been confirmed*)
  val conclude: local_theory -> local_theory
  val pretty: Proof.context -> Pretty.T

  (* overloaded operations *)

  (*yields the constant itself; the theory is passed through unchanged*)
  val declare: string * typ -> theory -> term * theory
  (*define checked name (c, t) adds the definition  c == t  via Thm.add_def*)
  val define: bool -> string -> string * term -> theory -> thm * theory
  (*remove the bookkeeping entry for the given local variable name*)
  val confirm: string -> local_theory -> local_theory
  (*constant name and checked flag registered for a local variable name*)
  val operation: Proof.context -> string -> (string * bool) option

  (* generic improvable syntax (check/uncheck phases for constants) *)

  type improvable_syntax
  val add_improvable_syntax: Proof.context -> Proof.context
  val map_improvable_syntax: (improvable_syntax -> improvable_syntax) -> Proof.context -> Proof.context
end;
haftmann@25519
    23
haftmann@25519
    24
structure Overloading: OVERLOADING =
struct

(* bookkeeping *)

(*Pending overloaded operations of a target.  Each entry pairs an overloaded
  constant (name and type) with the local variable name standing for it and the
  checked flag given to init.  confirm removes entries; conclude fails while any
  entry is left.*)
structure OverloadingData = ProofDataFun
(
  type T = ((string * typ) * (string * bool)) list;
  fun init _ = [];
);

(*the table lives in the target context of the local theory*)
fun get_overloading lthy = OverloadingData.get (LocalTheory.target_of lthy);
fun map_overloading f = LocalTheory.target (OverloadingData.map f);

(*constant name and checked flag registered for local variable v, if any*)
fun operation lthy v =
  let
    fun lookup ((c, _), (v', checked)) =
      if v' = v then SOME (c, checked) else NONE;
  in get_first lookup (get_overloading lthy) end;

(*drop the entry whose local variable name is v*)
fun confirm v =
  map_overloading (filter_out (fn (_, (v', _)) => v' = v));


(* overloaded declarations and definitions *)

(*declaring an overloaded constant just yields the constant; theory unchanged*)
fun declare c_ty thy = (Const c_ty, thy);

(*add the theorem  c == t  (with c taken at the type of t) under the given name;
  Thm.add_def receives the negated checked flag as its first argument*)
fun define checked name (c, t) =
  let val eq = Logic.mk_equals (Const (c, Term.fastype_of t), t)
  in Thm.add_def (not checked) true (name, eq) end;


(* generic check/uncheck combinators for improvable constants *)

(*local_constraints: constant constraints added by add_improvable_syntax;
  global_constraints: constant constraints added on the first check pass;
  improve: optional pair of types to match for a constant occurrence;
  subst: optional (pattern type, replacement term) for a constant occurrence;
  unchecks: rewrite rules applied in the uncheck phase;
  passed: whether the first check pass has already happened*)
type improvable_syntax = {
  local_constraints: (string * typ) list,
  global_constraints: (string * typ) list,
  improve: string * typ -> (typ * typ) option,
  subst: string * typ -> (typ * term) option,
  unchecks: (term * term) list,
  passed: bool
};

(*default: no constraints, no improvements, no substitutions, already passed*)
structure ImprovableSyntax = ProofDataFun
(
  type T = improvable_syntax;
  fun init _ =
   {local_constraints = [], global_constraints = [], improve = K NONE,
    subst = K NONE, unchecks = [], passed = true};
);

val map_improvable_syntax = ImprovableSyntax.map;

(*set the passed flag, keeping every other field*)
fun mark_passed ctxt =
  map_improvable_syntax
    (fn {local_constraints, global_constraints, improve, subst, unchecks, passed = _} =>
      {local_constraints = local_constraints, global_constraints = global_constraints,
       improve = improve, subst = subst, unchecks = unchecks, passed = true})
    ctxt;

(*Check phase.  First instantiate schematic type variables as suggested by
  improve (failed matches are ignored), then expand constants according to subst
  whenever the occurrence type is an instance of the pattern type.  While not yet
  passed, the global constraints are installed and the context is marked passed,
  and a result is returned even if the terms did not change.*)
fun improve_term_check ts ctxt =
  let
    val {global_constraints, improve, subst, passed, ...} = ImprovableSyntax.get ctxt;
    val tsig = Sign.tsig_of (ProofContext.theory_of ctxt);

    fun add_improvement (Const (c, ty)) tyenv =
          (case improve (c, ty) of
            SOME ty_pair => perhaps (try (Type.typ_match tsig ty_pair)) tyenv
          | NONE => tyenv)
      | add_improvement _ tyenv = tyenv;
    val tyenv = fold (fold_aterms add_improvement) ts Vartab.empty;
    val ts_improved = map (map_types (Envir.typ_subst_TVars tyenv)) ts;

    fun subst_const (Const (c, ty)) =
          (case subst (c, ty) of
            SOME (ty', t') =>
              if Type.typ_instance tsig (ty, ty') then SOME (ty', expand t') else NONE
          | NONE => NONE)
      | subst_const _ = NONE
    and expand t = Envir.expand_term subst_const t;
    val ts_expanded = map expand ts_improved;

    val unchanged = eq_list (op aconv) (ts, ts_expanded);
  in
    if not passed then
      SOME (ts_expanded, ctxt
        |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints
        |> mark_passed)
    else if unchanged then NONE
    else SOME (ts_expanded, ctxt)
  end;

(*Uncheck phase: rewrite with the unchecks rule pairs (init installs pairs
  (Free v, Const c), presumably rewriting the variable back to the constant)*)
fun improve_term_uncheck ts ctxt =
  let
    val thy = ProofContext.theory_of ctxt;
    val {unchecks, ...} = ImprovableSyntax.get ctxt;
    val rewrite = Pattern.rewrite_term thy unchecks [];
    val ts' = map rewrite ts;
  in
    if eq_list (op aconv) (ts, ts') then NONE
    else SOME (ts', ctxt)
  end;

(*register the "improvement" check/uncheck phases (stage 0) and add the
  local constraints currently stored in the context*)
fun add_improvable_syntax ctxt =
  let
    val {local_constraints, ...} = ImprovableSyntax.get ctxt;
    val install_phases =
      Syntax.add_term_check 0 "improvement" improve_term_check
      #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck;
  in
    ctxt
    |> Context.proof_map install_phases
    |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
  end;


(* target *)

(*Enter an overloading target.  Each (v, (c, ty), checked) makes the local
  variable v stand for constant c at type ty: the check phase replaces c at ty by
  Free v, the uncheck phase maps Free v back to c.*)
fun init [] _ = error "At least one parameter must be given"
  | init raw_overloading thy =
      let
        val overloading =
          map (fn (v, c_ty, checked) => (c_ty, (v, checked))) raw_overloading;
        fun subst (c, ty) =
          Option.map (fn (v, _) => (ty, Free (v, ty)))
            (AList.lookup (op =) overloading (c, ty));
        val unchecks =
          map (fn ((c, ty), (v, _)) => (Free (v, ty), Const (c, ty))) overloading;
        val syntax =
         {local_constraints = [], global_constraints = [], improve = K NONE,
          subst = subst, unchecks = unchecks, passed = false};
        fun declare_var (v, (_, ty), _) = Variable.declare_term (Free (v, ty));
      in
        thy
        |> ProofContext.init
        |> OverloadingData.put overloading
        |> fold declare_var raw_overloading
        |> map_improvable_syntax (K syntax)
        |> add_improvable_syntax
      end;

(*leave the target: every registered operation must have been confirmed*)
fun conclude lthy =
  (case get_overloading lthy of
    [] => lthy
  | pending =>
      error ("Missing definition(s) for parameters " ^
        commas (map (quote o Syntax.string_of_term lthy o Const o fst) pending)));

(*one line per operation:  v == c :: ty*)
fun pretty lthy =
  let
    val thy = ProofContext.theory_of lthy;
    fun pretty_entry ((c, ty), (v, _)) =
      Pretty.block (Pretty.breaks
        [Pretty.str v, Pretty.str "==", Pretty.str (Sign.extern_const thy c),
          Pretty.str "::", Sign.pretty_typ thy ty]);
    val entries = map pretty_entry (get_overloading lthy);
  in Pretty.block (Pretty.fbreaks (Pretty.str "overloading" :: entries)) end;

end;