src/Pure/Isar/overloading.ML
changeset 26259 d30f4a509361
parent 26249 59ecf1ce8222
child 26520 9e7b7c478cb1
--- a/src/Pure/Isar/overloading.ML	Wed Mar 12 08:45:51 2008 +0100
+++ b/src/Pure/Isar/overloading.ML	Wed Mar 12 08:47:35 2008 +0100
@@ -19,37 +19,13 @@
   val add_improvable_syntax: Proof.context -> Proof.context
   val map_improvable_syntax: (improvable_syntax -> improvable_syntax)
     -> Proof.context -> Proof.context
+  val set_local_constraints: Proof.context -> Proof.context
 end;
 
 structure Overloading: OVERLOADING =
 struct
 
-(* bookkeeping *)
-
-structure OverloadingData = ProofDataFun
-(
-  type T = ((string * typ) * (string * bool)) list;
-  fun init _ = [];
-);
-
-val get_overloading = OverloadingData.get o LocalTheory.target_of;
-val map_overloading = LocalTheory.target o OverloadingData.map;
-
-fun operation lthy v = get_overloading lthy
-  |> get_first (fn ((c, _), (v', checked)) => if v = v' then SOME (c, checked) else NONE);
-
-fun confirm c = map_overloading (filter_out (fn (_, (c', _)) => c' = c));
-
-
-(* overloaded declarations and definitions *)
-
-fun declare c_ty = pair (Const c_ty);
-
-fun define checked name (c, t) =
-  Thm.add_def (not checked) true (name, Logic.mk_equals (Const (c, Term.fastype_of t), t));
-
-
-(* generic check/uncheck combinators for improvable constants *)
+(** generic check/uncheck combinators for improvable constants **)
 
 type improvable_syntax = ((((string * typ) list * (string * typ) list) *
   (((string * typ -> (typ * typ) option) * (string * typ -> (typ * term) option)) *
@@ -96,11 +72,11 @@
     val improvements = (fold o fold_aterms) accumulate_improvements ts Vartab.empty;
     val ts' = (map o map_types) (Envir.typ_subst_TVars improvements) ts;
     fun apply_subst t = Envir.expand_term (fn Const (c, ty) => (case subst (c, ty)
-         of SOME (ty', t') =>   
+         of SOME (ty', t') =>
               if Type.typ_instance tsig (ty, ty')
               then SOME (ty', apply_subst t') else NONE
           | NONE => NONE)
-      | _ => NONE) t;
+        | _ => NONE) t;
     val ts'' = map apply_subst ts';
   in if eq_list (op aconv) (ts, ts'') andalso passed then NONE else
     if passed then SOME (ts'', ctxt)
@@ -116,12 +92,43 @@
     val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
   in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
 
-fun add_improvable_syntax ctxt = ctxt
-  |> Context.proof_map
+fun set_local_constraints ctxt =
+  let
+    val { local_constraints, ... } = ImprovableSyntax.get ctxt;
+  in fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints ctxt end;
+
+val add_improvable_syntax =
+  Context.proof_map
     (Syntax.add_term_check 0 "improvement" improve_term_check
     #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck)
-  |> fold (ProofContext.add_const_constraint o apsnd SOME)
-       ((#local_constraints o ImprovableSyntax.get) ctxt);
+  #> set_local_constraints;
+
+
+(** overloading target **)
+
+(* bookkeeping *)
+
+structure OverloadingData = ProofDataFun
+(
+  type T = ((string * typ) * (string * bool)) list;
+  fun init _ = [];
+);
+
+val get_overloading = OverloadingData.get o LocalTheory.target_of;
+val map_overloading = LocalTheory.target o OverloadingData.map;
+
+fun operation lthy v = get_overloading lthy
+  |> get_first (fn ((c, _), (v', checked)) => if v = v' then SOME (c, checked) else NONE);
+
+fun confirm c = map_overloading (filter_out (fn (_, (c', _)) => c' = c));
+
+
+(* overloaded declarations and definitions *)
+
+fun declare c_ty = pair (Const c_ty);
+
+fun define checked name (c, t) =
+  Thm.add_def (not checked) true (name, Logic.mk_equals (Const (c, Term.fastype_of t), t));
 
 
 (* target *)
@@ -139,7 +146,7 @@
     thy
     |> ProofContext.init
     |> OverloadingData.put overloading
-    |> fold (fn (v, (_, ty), _) => Variable.declare_term (Free (v, ty))) raw_overloading
+    |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
     |> map_improvable_syntax (K ((([], []), ((K NONE, subst), unchecks)), false))
     |> add_improvable_syntax
   end;
@@ -148,7 +155,7 @@
   let
     val overloading = get_overloading lthy;
     val _ = if null overloading then () else
-      error ("Missing definition(s) for parameters " ^ commas (map (quote
+      error ("Missing definition(s) for parameter(s) " ^ commas (map (quote
         o Syntax.string_of_term lthy o Const o fst) overloading));
   in
     lthy