src/Pure/Isar/overloading.ML
author wenzelm
Mon May 03 14:25:56 2010 +0200 (2010-05-03)
changeset 36610 bafd82950e24
parent 36354 bbd742107f56
child 38342 09d4a04d5c2e
permissions -rw-r--r--
renamed ProofContext.init to ProofContext.init_global to emphasize that this is not the real thing;
     1 (*  Title:      Pure/Isar/overloading.ML
     2     Author:     Florian Haftmann, TU Muenchen
     3 
     4 Overloaded definitions without any discipline.
     5 *)
     6 
     7 signature OVERLOADING =
     8 sig
     9   val init: (string * (string * typ) * bool) list -> theory -> Proof.context
    10   val conclude: local_theory -> local_theory
    11   val declare: string * typ -> theory -> term * theory
    12   val confirm: binding -> local_theory -> local_theory
    13   val define: bool -> binding -> string * term -> theory -> thm * theory
    14   val operation: Proof.context -> binding -> (string * bool) option
    15   val pretty: Proof.context -> Pretty.T
    16 
    17   type improvable_syntax
    18   val add_improvable_syntax: Proof.context -> Proof.context
    19   val map_improvable_syntax: (improvable_syntax -> improvable_syntax)
    20     -> Proof.context -> Proof.context
    21   val set_primary_constraints: Proof.context -> Proof.context
    22 end;
    23 
    24 structure Overloading: OVERLOADING =
    25 struct
    26 
    27 (** generic check/uncheck combinators for improvable constants **)
    28 
    29 type improvable_syntax = ((((string * typ) list * (string * typ) list) *
    30   ((((string * typ -> (typ * typ) option) * (string * typ -> (typ * term) option)) * bool) *
    31     (term * term) list)) * bool);
    32 
    33 structure ImprovableSyntax = Proof_Data
    34 (
    35   type T = {
    36     primary_constraints: (string * typ) list,
    37     secondary_constraints: (string * typ) list,
    38     improve: string * typ -> (typ * typ) option,
    39     subst: string * typ -> (typ * term) option,
    40     consider_abbrevs: bool,
    41     unchecks: (term * term) list,
    42     passed: bool
    43   };
    44   fun init _ = {
    45     primary_constraints = [],
    46     secondary_constraints = [],
    47     improve = K NONE,
    48     subst = K NONE,
    49     consider_abbrevs = false,
    50     unchecks = [],
    51     passed = true
    52   };
    53 );
    54 
    55 fun map_improvable_syntax f = ImprovableSyntax.map (fn { primary_constraints,
    56   secondary_constraints, improve, subst, consider_abbrevs, unchecks, passed } => let
    57     val (((primary_constraints', secondary_constraints'),
    58       (((improve', subst'), consider_abbrevs'), unchecks')), passed')
    59         = f (((primary_constraints, secondary_constraints),
    60             (((improve, subst), consider_abbrevs), unchecks)), passed)
    61   in { primary_constraints = primary_constraints', secondary_constraints = secondary_constraints',
    62     improve = improve', subst = subst', consider_abbrevs = consider_abbrevs',
    63     unchecks = unchecks', passed = passed'
    64   } end);
    65 
    66 val mark_passed = (map_improvable_syntax o apsnd) (K true);
    67 
    68 fun improve_term_check ts ctxt =
    69   let
    70     val { secondary_constraints, improve, subst, consider_abbrevs, passed, ... } =
    71       ImprovableSyntax.get ctxt;
    72     val tsig = (Sign.tsig_of o ProofContext.theory_of) ctxt;
    73     val is_abbrev = consider_abbrevs andalso ProofContext.abbrev_mode ctxt;
    74     val passed_or_abbrev = passed orelse is_abbrev;
    75     fun accumulate_improvements (Const (c, ty)) = (case improve (c, ty)
    76          of SOME ty_ty' => Type.typ_match tsig ty_ty'
    77           | _ => I)
    78       | accumulate_improvements _ = I;
    79     val improvements = (fold o fold_aterms) accumulate_improvements ts Vartab.empty;
    80     val ts' = (map o map_types) (Envir.subst_type improvements) ts;
    81     fun apply_subst t = Envir.expand_term (fn Const (c, ty) => (case subst (c, ty)
    82          of SOME (ty', t') =>
    83               if Type.typ_instance tsig (ty, ty')
    84               then SOME (ty', apply_subst t') else NONE
    85           | NONE => NONE)
    86         | _ => NONE) t;
    87     val ts'' = if is_abbrev then ts' else map apply_subst ts';
    88   in if eq_list (op aconv) (ts, ts'') andalso passed_or_abbrev then NONE else
    89     if passed_or_abbrev then SOME (ts'', ctxt)
    90     else SOME (ts'', ctxt
    91       |> fold (ProofContext.add_const_constraint o apsnd SOME) secondary_constraints
    92       |> mark_passed)
    93   end;
    94 
    95 fun rewrite_liberal thy unchecks t =
    96   case try (Pattern.rewrite_term thy unchecks []) t
    97    of NONE => NONE
    98     | SOME t' => if t aconv t' then NONE else SOME t';
    99 
   100 fun improve_term_uncheck ts ctxt =
   101   let
   102     val thy = ProofContext.theory_of ctxt;
   103     val unchecks = (#unchecks o ImprovableSyntax.get) ctxt;
   104     val ts' = map (rewrite_liberal thy unchecks) ts;
   105   in if exists is_some ts' then SOME (map2 the_default ts ts', ctxt) else NONE end;
   106 
   107 fun set_primary_constraints ctxt =
   108   let
   109     val { primary_constraints, ... } = ImprovableSyntax.get ctxt;
   110   in fold (ProofContext.add_const_constraint o apsnd SOME) primary_constraints ctxt end;
   111 
   112 val add_improvable_syntax =
   113   Context.proof_map
   114     (Syntax.add_term_check 0 "improvement" improve_term_check
   115     #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck)
   116   #> set_primary_constraints;
   117 
   118 
   119 (** overloading target **)
   120 
   121 (* bookkeeping *)
   122 
   123 structure OverloadingData = Proof_Data
   124 (
   125   type T = ((string * typ) * (string * bool)) list;
   126   fun init _ = [];
   127 );
   128 
   129 val get_overloading = OverloadingData.get o Local_Theory.target_of;
   130 val map_overloading = Local_Theory.target o OverloadingData.map;
   131 
   132 fun operation lthy b = get_overloading lthy
   133   |> get_first (fn ((c, _), (v, checked)) =>
   134       if Binding.name_of b = v then SOME (c, checked) else NONE);
   135 
   136 
   137 (* target *)
   138 
   139 fun synchronize_syntax ctxt =
   140   let
   141     val overloading = OverloadingData.get ctxt;
   142     fun subst (c, ty) = case AList.lookup (op =) overloading (c, ty)
   143      of SOME (v, _) => SOME (ty, Free (v, ty))
   144       | NONE => NONE;
   145     val unchecks =
   146       map (fn (c_ty as (_, ty), (v, _)) => (Free (v, ty), Const c_ty)) overloading;
   147   in 
   148     ctxt
   149     |> map_improvable_syntax (K ((([], []), (((K NONE, subst), false), unchecks)), false))
   150   end
   151 
   152 fun init raw_overloading thy =
   153   let
   154     val _ = if null raw_overloading then error "At least one parameter must be given" else ();
   155     val overloading = map (fn (v, c_ty, checked) => (c_ty, (v, checked))) raw_overloading;
   156   in
   157     thy
   158     |> Theory.checkpoint
   159     |> ProofContext.init_global
   160     |> OverloadingData.put overloading
   161     |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
   162     |> add_improvable_syntax
   163     |> synchronize_syntax
   164   end;
   165 
   166 fun declare c_ty = pair (Const c_ty);
   167 
   168 fun define checked b (c, t) =
   169   Thm.add_def (not checked) true (b, Logic.mk_equals (Const (c, Term.fastype_of t), t))
   170   #>> snd;
   171 
   172 fun confirm b = map_overloading (filter_out (fn (_, (c', _)) => c' = Binding.name_of b))
   173   #> Local_Theory.target synchronize_syntax
   174 
   175 fun conclude lthy =
   176   let
   177     val overloading = get_overloading lthy;
   178     val _ = if null overloading then () else
   179       error ("Missing definition(s) for parameter(s) " ^ commas (map (quote
   180         o Syntax.string_of_term lthy o Const o fst) overloading));
   181   in
   182     lthy
   183   end;
   184 
   185 fun pretty lthy =
   186   let
   187     val thy = ProofContext.theory_of lthy;
   188     val overloading = get_overloading lthy;
   189     fun pr_operation ((c, ty), (v, _)) =
   190       (Pretty.block o Pretty.breaks) [Pretty.str v, Pretty.str "==",
   191         Pretty.str (Sign.extern_const thy c), Pretty.str "::", Syntax.pretty_typ lthy ty];
   192   in
   193     (Pretty.block o Pretty.fbreaks)
   194       (Pretty.str "overloading" :: map pr_operation overloading)
   195   end;
   196 
   197 end;