src/Pure/Isar/overloading.ML
author haftmann
Mon Dec 03 16:04:17 2007 +0100 (2007-12-03)
changeset 25519 8570745cb40b
child 25536 01753a944433
permissions -rw-r--r--
overloading target
haftmann@25519
     1
(*  Title:      Pure/Isar/overloading.ML
haftmann@25519
     2
    ID:         $Id$
haftmann@25519
     3
    Author:     Florian Haftmann, TU Muenchen
haftmann@25519
     4
haftmann@25519
     5
Overloaded definitions without any discipline.
haftmann@25519
     6
*)
haftmann@25519
     7
haftmann@25519
     8
signature OVERLOADING =
sig
  (*entering and leaving the overloading target*)
  val init: ((string * typ) * (string * bool)) list -> theory -> local_theory
  val conclude: local_theory -> local_theory
  (*lookup and removal of local parameters*)
  val operation: Proof.context -> string -> (string * bool) option
  val confirm: string -> local_theory -> local_theory
  (*global declaration and definition of the overloaded constants*)
  val declare: string * typ -> theory -> term * theory
  val define: bool -> string -> string * term -> theory -> thm * theory
end;
haftmann@25519
    17
haftmann@25519
    18
structure Overloading: OVERLOADING =
struct

(** bookkeeping **)

(*Context data: a list of entries ((c, T), (v, checked)) associating an
  overloaded constant c at type T with a local parameter name v and a
  boolean flag (named "checked" by operation).*)
structure OverloadingData = ProofDataFun
(
  type T = ((string * typ) * (string * bool)) list;
  fun init _ = [];
);

(*The table is kept in the target context of the local theory.*)
fun get_overloading lthy = OverloadingData.get (LocalTheory.target_of lthy);
fun map_overloading f = LocalTheory.target (OverloadingData.map f);

(*For local parameter name v, the constant name it stands for together
  with its flag; NONE if v is not registered.*)
fun operation lthy v =
  let
    fun match ((c, _), (v', checked)) =
      if v' = v then SOME (c, checked) else NONE;
  in get_first match (get_overloading lthy) end;

(*Drop every entry whose local parameter name is v.*)
fun confirm v =
  map_overloading (filter_out (fn (_, (v', _)) => v' = v));


(** overloaded declarations and definitions **)

(*Yields the constant term for (c, T); the theory is passed through
  unchanged.*)
fun declare c_ty thy = (Const c_ty, thy);

(*Add the definition c == t under the given name, where c receives the
  type of t.  The negated flag and the literal true are passed as the
  first two arguments of Thm.add_def (presumably its "unchecked" and
  "overloaded" switches -- confirm against Thm.add_def).*)
fun define checked name (c, t) =
  let
    val lhs = Const (c, Term.fastype_of t);
    val eqn = Logic.mk_equals (lhs, t);
  in Thm.add_def (not checked) true (name, eqn) end;


(** syntax **)

(*Apply subst to every atom of every term in ts; NONE if the result is
  equal to the input up to aconv, otherwise the new terms and lthy.*)
fun subst_atoms subst ts lthy =
  let
    val ts' = map (map_aterms subst) ts;
  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;

(*Check phase: a registered constant (c, T) becomes the free variable
  (v, T) of its local parameter.*)
fun term_check ts lthy =
  let
    val table = get_overloading lthy;
    fun to_param (t as Const (c, T)) =
          (case AList.lookup (op =) table (c, T) of
            SOME (v, _) => Free (v, T)
          | NONE => t)
      | to_param t = t;
  in subst_atoms to_param ts lthy end;

(*Uncheck phase: a free variable named like a local parameter becomes the
  corresponding constant, keeping the variable's type.*)
fun term_uncheck ts lthy =
  let
    val table = get_overloading lthy;
    fun const_of v =
      get_first (fn ((c, _), (v', _)) => if v = v' then SOME c else NONE) table;
    fun to_const (t as Free (v, T)) =
          (case const_of v of
            SOME c => Const (c, T)
          | NONE => t)
      | to_const t = t;
  in subst_atoms to_const ts lthy end;


(** target **)

(*Start the overloading target from theory thy.  Fails on an empty entry
  list.  Stores the entries as context data, declares the type of every
  listed constant (via Logic.mk_type), and installs term_check /
  term_uncheck at stage 0 under the name "overloading".*)
fun init overloading thy =
  if null overloading then error "At least one parameter must be given"
  else
    let
      val declare_types =
        fold (Variable.declare_term o Logic.mk_type o snd o fst) overloading;
      val setup_syntax =
        Syntax.add_term_check 0 "overloading" term_check
        #> Syntax.add_term_uncheck 0 "overloading" term_uncheck;
    in
      ProofContext.init thy
      |> OverloadingData.put overloading
      |> declare_types
      |> Context.proof_map setup_syntax
    end;

(*Finish the target: succeeds (returning lthy unchanged) only when no
  entries remain, i.e. all have been removed via confirm; otherwise
  reports the remaining constants.*)
fun conclude lthy =
  (case get_overloading lthy of
    [] => lthy
  | pending =>
      error ("Missing definition(s) for parameters " ^ commas (map (quote
        o Syntax.string_of_term lthy o Const o fst) pending)));

end;