better improvement in instantiation target
authorhaftmann
Wed, 12 Mar 2008 08:47:35 +0100
changeset 26259 d30f4a509361
parent 26258 20dfaa28e5e5
child 26260 23ce0d32de11
better improvement in instantiation target
src/HOL/Library/Option_ord.thy
src/HOL/Library/Parity.thy
src/Pure/Isar/class.ML
src/Pure/Isar/overloading.ML
--- a/src/HOL/Library/Option_ord.thy	Wed Mar 12 08:45:51 2008 +0100
+++ b/src/HOL/Library/Option_ord.thy	Wed Mar 12 08:47:35 2008 +0100
@@ -3,7 +3,7 @@
     Author:     Florian Haftmann, TU Muenchen
 *)
 
-header {* Canonical order on option type *}
+header {* Canonical order on @{text option} type *}
 
 theory Option_ord
 imports ATP_Linkup
--- a/src/HOL/Library/Parity.thy	Wed Mar 12 08:45:51 2008 +0100
+++ b/src/HOL/Library/Parity.thy	Wed Mar 12 08:47:35 2008 +0100
@@ -16,19 +16,12 @@
   odd :: "'a\<Colon>even_odd \<Rightarrow> bool" where
   "odd x \<equiv> \<not> even x"
 
-instantiation int  :: even_odd
+instantiation nat and int  :: even_odd
 begin
 
 definition
   even_def [presburger]: "even x \<longleftrightarrow> (x\<Colon>int) mod 2 = 0"
 
-instance ..
-
-end
-
-instantiation nat  :: even_odd
-begin
-
 definition
   even_nat_def [presburger]: "even x \<longleftrightarrow> even (int x)"
 
--- a/src/Pure/Isar/class.ML	Wed Mar 12 08:45:51 2008 +0100
+++ b/src/Pure/Isar/class.ML	Wed Mar 12 08:47:35 2008 +0100
@@ -40,6 +40,7 @@
   val instantiation_param: local_theory -> string -> string option
   val confirm_declaration: string -> local_theory -> local_theory
   val pretty_instantiation: local_theory -> Pretty.T
+  val type_name: string -> string
 
   (*old axclass layer*)
   val axclass_cmd: bstring * xstring list
@@ -433,9 +434,9 @@
   in
     ctxt
     |> fold declare_const local_constraints
-    |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
     |> Overloading.map_improvable_syntax (K (((local_constraints, global_constraints),
         ((improve, subst), unchecks)), false))
+    |> Overloading.set_local_constraints
   end;
 
 fun refresh_syntax class ctxt =
@@ -661,22 +662,15 @@
 
 fun synchronize_inst_syntax ctxt =
   let
-    val Instantiation { arities = (_, _, sorts), params = params } = Instantiation.get ctxt;
+    val Instantiation { arities = (_, _, sort), params = params } = Instantiation.get ctxt;
     val thy = ProofContext.theory_of ctxt;
-    val operations = these_operations thy sorts;
     fun subst_class_typ sort = map_atyps
       (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
-    val local_constraints =
-      (map o apsnd) (subst_class_typ [] o fst o snd) operations;
-    val global_constraints = map_filter (fn (c, (class, (ty, _))) =>
+    val operations = these_operations thy sort;
+    val global_constraints = (*map_filter (fn (c, (class, (ty, _))) =>
       if exists (fn ((c', _), _) => c = c') params
         then SOME (c, subst_class_typ [class] ty)
-        else NONE) operations;
-    fun improve (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
-         of SOME tyco => (case AList.lookup (op =) params (c, tyco)
-             of SOME (_, ty') => SOME (ty, ty')
-              | NONE => NONE)
-          | NONE => NONE;
+        else NONE) operations;*)[];
           (*| NONE => (case map_filter
                (fn ((c', _), (_, ty')) => if c' = c then SOME ty' else NONE) params
              of [ty'] => (case Sign.const_typargs thy (c, ty)
@@ -692,9 +686,9 @@
       map (fn ((c, _), v_ty as (_, ty)) => (Free v_ty, Const (c, ty))) params;
   in
     ctxt
-    |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
-    |> Overloading.map_improvable_syntax (K (((local_constraints, global_constraints),
-        ((improve, subst), unchecks)), false))
+    |> Overloading.map_improvable_syntax
+         (fn (((local_constraints, _), ((improve, _), _)), _) =>
+            (((local_constraints, global_constraints), ((improve, subst), unchecks)), false))
   end;
 
 
@@ -715,25 +709,37 @@
     explode #> scan_valids #> implode
   end;
 
+fun type_name "*" = "prod"
+  | type_name "+" = "sum"
+  | type_name s = sanatize_name (NameSpace.base s); (*FIXME*)
+
 fun init_instantiation (tycos, vs, sort) thy =
   let
     val _ = if null tycos then error "At least one arity must be given" else ();
-    val _ = map (the_class_data thy) sort;
-    fun type_name "*" = "prod"
-      | type_name "+" = "sum"
-      | type_name s = sanatize_name (NameSpace.base s); (*FIXME*)
+    fun subst_class_typ sort = map_atyps
+      (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
     fun get_param tyco (param, (c, ty)) = if can (AxClass.param_of_inst thy) (c, tyco)
       then NONE else SOME ((c, tyco),
         (param ^ "_" ^ type_name tyco, map_atyps (K (Type (tyco, map TFree vs))) ty));
     val params = map_product get_param tycos (these_params thy sort) |> map_filter I;
+    val operations = these_operations thy sort;
+    val local_constraints = (map o apsnd) (subst_class_typ [] o fst o snd) operations;
+    fun improve (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
+     of SOME tyco => (case AList.lookup (op =) params (c, tyco)
+         of SOME (_, ty') => SOME (ty, ty')
+          | NONE => NONE)
+      | NONE => NONE;
   in
     thy
     |> ProofContext.init
     |> Instantiation.put (mk_instantiation ((tycos, vs, sort), params))
     |> fold (Variable.declare_term o Logic.mk_type o TFree) vs
     |> fold (Variable.declare_names o Free o snd) params
+    |> (Overloading.map_improvable_syntax o apfst)
+         (fn ((_, global_constraints), ((_, subst), unchecks)) =>
+            ((local_constraints, global_constraints), ((improve, subst), unchecks)))
+    |> Overloading.add_improvable_syntax
     |> synchronize_inst_syntax
-    |> Overloading.add_improvable_syntax
   end;
 
 fun confirm_declaration c = (map_instantiation o apsnd)
--- a/src/Pure/Isar/overloading.ML	Wed Mar 12 08:45:51 2008 +0100
+++ b/src/Pure/Isar/overloading.ML	Wed Mar 12 08:47:35 2008 +0100
@@ -19,37 +19,13 @@
   val add_improvable_syntax: Proof.context -> Proof.context
   val map_improvable_syntax: (improvable_syntax -> improvable_syntax)
     -> Proof.context -> Proof.context
+  val set_local_constraints: Proof.context -> Proof.context
 end;
 
 structure Overloading: OVERLOADING =
 struct
 
-(* bookkeeping *)
-
-structure OverloadingData = ProofDataFun
-(
-  type T = ((string * typ) * (string * bool)) list;
-  fun init _ = [];
-);
-
-val get_overloading = OverloadingData.get o LocalTheory.target_of;
-val map_overloading = LocalTheory.target o OverloadingData.map;
-
-fun operation lthy v = get_overloading lthy
-  |> get_first (fn ((c, _), (v', checked)) => if v = v' then SOME (c, checked) else NONE);
-
-fun confirm c = map_overloading (filter_out (fn (_, (c', _)) => c' = c));
-
-
-(* overloaded declarations and definitions *)
-
-fun declare c_ty = pair (Const c_ty);
-
-fun define checked name (c, t) =
-  Thm.add_def (not checked) true (name, Logic.mk_equals (Const (c, Term.fastype_of t), t));
-
-
-(* generic check/uncheck combinators for improvable constants *)
+(** generic check/uncheck combinators for improvable constants **)
 
 type improvable_syntax = ((((string * typ) list * (string * typ) list) *
   (((string * typ -> (typ * typ) option) * (string * typ -> (typ * term) option)) *
@@ -96,11 +72,11 @@
     val improvements = (fold o fold_aterms) accumulate_improvements ts Vartab.empty;
     val ts' = (map o map_types) (Envir.typ_subst_TVars improvements) ts;
     fun apply_subst t = Envir.expand_term (fn Const (c, ty) => (case subst (c, ty)
-         of SOME (ty', t') =>   
+         of SOME (ty', t') =>
               if Type.typ_instance tsig (ty, ty')
               then SOME (ty', apply_subst t') else NONE
           | NONE => NONE)
-      | _ => NONE) t;
+        | _ => NONE) t;
     val ts'' = map apply_subst ts';
   in if eq_list (op aconv) (ts, ts'') andalso passed then NONE else
     if passed then SOME (ts'', ctxt)
@@ -116,12 +92,43 @@
     val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
   in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
 
-fun add_improvable_syntax ctxt = ctxt
-  |> Context.proof_map
+fun set_local_constraints ctxt =
+  let
+    val { local_constraints, ... } = ImprovableSyntax.get ctxt;
+  in fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints ctxt end;
+
+val add_improvable_syntax =
+  Context.proof_map
     (Syntax.add_term_check 0 "improvement" improve_term_check
     #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck)
-  |> fold (ProofContext.add_const_constraint o apsnd SOME)
-       ((#local_constraints o ImprovableSyntax.get) ctxt);
+  #> set_local_constraints;
+
+
+(** overloading target **)
+
+(* bookkeeping *)
+
+structure OverloadingData = ProofDataFun
+(
+  type T = ((string * typ) * (string * bool)) list;
+  fun init _ = [];
+);
+
+val get_overloading = OverloadingData.get o LocalTheory.target_of;
+val map_overloading = LocalTheory.target o OverloadingData.map;
+
+fun operation lthy v = get_overloading lthy
+  |> get_first (fn ((c, _), (v', checked)) => if v = v' then SOME (c, checked) else NONE);
+
+fun confirm c = map_overloading (filter_out (fn (_, (c', _)) => c' = c));
+
+
+(* overloaded declarations and definitions *)
+
+fun declare c_ty = pair (Const c_ty);
+
+fun define checked name (c, t) =
+  Thm.add_def (not checked) true (name, Logic.mk_equals (Const (c, Term.fastype_of t), t));
 
 
 (* target *)
@@ -139,7 +146,7 @@
     thy
     |> ProofContext.init
     |> OverloadingData.put overloading
-    |> fold (fn (v, (_, ty), _) => Variable.declare_term (Free (v, ty))) raw_overloading
+    |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
     |> map_improvable_syntax (K ((([], []), ((K NONE, subst), unchecks)), false))
     |> add_improvable_syntax
   end;
@@ -148,7 +155,7 @@
   let
     val overloading = get_overloading lthy;
     val _ = if null overloading then () else
-      error ("Missing definition(s) for parameters " ^ commas (map (quote
+      error ("Missing definition(s) for parameter(s) " ^ commas (map (quote
         o Syntax.string_of_term lthy o Const o fst) overloading));
   in
     lthy