explicit constants for overloaded definitions
authorhaftmann
Fri, 17 Aug 2007 13:58:58 +0200
changeset 24304 69d40a562ba4
parent 24303 32b67bdf2c3a
child 24305 b1df9e31cda1
explicit constants for overloaded definitions
src/HOL/Relation_Power.thy
src/Pure/Isar/class.ML
--- a/src/HOL/Relation_Power.thy	Fri Aug 17 13:58:57 2007 +0200
+++ b/src/HOL/Relation_Power.thy	Fri Aug 17 13:58:58 2007 +0200
@@ -15,7 +15,7 @@
       --{* only type @{typ "('a * 'a) set"} should be in class @{text power}!*}
 
 (*R^n = R O ... O R, the n-fold composition of R*)
-primrec (relpow)
+primrec (unchecked relpow)
   "R^0 = Id"
   "R^(Suc n) = R O (R^n)"
 
@@ -25,7 +25,7 @@
       --{* only type @{typ "'a => 'a"} should be in class @{text power}!*}
 
 (*f^n = f o ... o f, the n-fold composition of f*)
-primrec (funpow)
+primrec (unchecked funpow)
   "f^0 = id"
   "f^(Suc n) = f o (f^n)"
 
--- a/src/Pure/Isar/class.ML	Fri Aug 17 13:58:57 2007 +0200
+++ b/src/Pure/Isar/class.ML	Fri Aug 17 13:58:58 2007 +0200
@@ -33,6 +33,10 @@
   val add_const_in_class: string -> (string * term) * Syntax.mixfix
     -> theory -> theory
 
+  val unoverload: theory -> thm -> thm
+  val overload: theory -> thm -> thm
+  val inst_const: theory -> string * string -> string
+
   val print_classes: theory -> unit
   val intro_classes_tac: thm list -> tactic
   val default_intro_classes_tac: thm list -> tactic
@@ -122,6 +126,67 @@
   end;
 
 
+(* explicit constants for overloaded definitions *)
+
+structure InstData = TheoryDataFun
+(
+  type T = (string * thm) Symtab.table Symtab.table;
+    (*constant name ~> type constructor ~> (constant name, equation)*)
+  val empty = Symtab.empty;
+  val copy = I;
+  val extend = I;
+  fun merge _ = Symtab.join (K (Symtab.merge (K true)));
+);
+
+fun inst_thms f thy =
+  Symtab.fold (Symtab.fold (cons o f o snd o snd) o snd) (InstData.get thy) [];
+fun add_inst (c, tyco) inst = (InstData.map o Symtab.map_default (c, Symtab.empty))
+  (Symtab.update_new (tyco, inst));
+
+fun unoverload thy thm = MetaSimplifier.rewrite_rule (inst_thms I thy) thm;
+fun overload thy thm = MetaSimplifier.rewrite_rule (inst_thms symmetric thy) thm;
+
+fun inst_const thy (c, tyco) =
+  (fst o the o Symtab.lookup ((the o Symtab.lookup (InstData.get thy)) c)) tyco;
+
+fun add_inst_def (class, tyco) (c, ty) thy =
+  let
+    val tyco_base = NameSpace.base tyco;
+    val name_inst = NameSpace.base class ^ "_" ^ tyco_base ^ "_inst";
+    val c_inst_base = NameSpace.base c ^ "_" ^ tyco_base;
+  in
+    thy
+    |> Sign.sticky_prefix name_inst
+    |> Sign.add_consts_i [(c_inst_base, ty, Syntax.NoSyn)]
+    |> `(fn thy => Sign.full_name thy c_inst_base)
+    |-> (fn c_inst => PureThy.add_defs_i true
+          [((Thm.def_name c_inst_base, Logic.mk_equals (Const (c_inst, ty), Const (c, ty))), [])]
+    #-> (fn [def] => add_inst (c, tyco) (c_inst, symmetric def))
+    #> Sign.restore_naming thy)
+  end;
+
+fun add_inst_def' (class, tyco) (c, ty) thy =
+  if case Symtab.lookup (InstData.get thy) c
+   of NONE => true
+    | SOME tab => is_none (Symtab.lookup tab tyco)
+  then add_inst_def (class, tyco) (c, Logic.unvarifyT ty) thy
+  else thy;
+
+fun add_def ((class, tyco), ((name, prop), atts)) thy =
+  let
+    val ((lhs as Const (c, ty), args), rhs) = (apfst Term.strip_comb o Logic.dest_equals) prop;
+    fun add_inst' def ([], (Const (c_inst, ty))) =
+          if forall (fn TFree_ => true | _ => false) (Sign.const_typargs thy (c_inst, ty))
+          then add_inst (c, tyco) (c_inst, def)
+          else add_inst_def (class, tyco) (c, ty)
+      | add_inst' _ t = add_inst_def (class, tyco) (c, ty);
+  in
+    thy
+    |> PureThy.add_defs_i true [((name, prop), map (Attrib.attribute thy) atts)]
+    |-> (fn [def] => add_inst' def (args, rhs) #> pair def)
+  end;
+
+
 (* instances with implicit parameter handling *)
 
 local
@@ -154,7 +219,7 @@
         val cs = (these o Option.map snd o try (AxClass.params_of_class theory)) class;
         val subst_ty = map_type_tfree (K ty);
       in
-        map (fn (c, ty) => (c, ((tyco, class), subst_ty ty))) cs
+        map (fn (c, ty) => (c, ((class, tyco), subst_ty ty))) cs
       end;
     fun get_consts_sort (tyco, asorts, sort) =
       let
@@ -171,7 +236,7 @@
           let
             val (c, (inst, ((name_opt, t), atts))) = read_def thy_read raw_def;
             val ty = Consts.instance (Sign.consts_of thy_read) (c, inst);
-            val ((tyco, class), ty') = case AList.lookup (op =) cs c
+            val ((class, tyco), ty') = case AList.lookup (op =) cs c
              of NONE => error ("illegal definition for constant " ^ quote c)
               | SOME class_ty => class_ty;
             val name = case name_opt
@@ -184,7 +249,7 @@
               | SOME norm => map_types norm t
           in (((class, tyco), ((name, t'), atts)), AList.delete (op =) c cs) end;
       in fold_map read defs cs end;
-    val (defs, _) = read_defs raw_defs cs
+    val (defs, other_cs) = read_defs raw_defs cs
       (fold Sign.primitive_arity arities (Theory.copy theory));
     fun get_remove_contraint c thy =
       let
@@ -194,18 +259,14 @@
         |> Sign.add_const_constraint_i (c, NONE)
         |> pair (c, Logic.unvarifyT ty)
       end;
-    fun add_defs defs thy =
-      thy
-      |> PureThy.add_defs_i true (map ((apsnd o map) (Attrib.attribute thy) o snd) defs)
-      |-> (fn thms => pair (map fst defs ~~ thms));
-    fun after_qed cs defs thy =
-      thy
-      |> fold Sign.add_const_constraint_i (map (apsnd SOME) cs)
-      |> fold (Code.add_func false o snd) defs;
+    fun after_qed cs defs =
+      fold Sign.add_const_constraint_i (map (apsnd SOME) cs)
+      #> fold (Code.add_func false) defs;
   in
     theory
     |> fold_map get_remove_contraint (map fst cs |> distinct (op =))
-    ||>> add_defs defs
+    ||>> fold_map add_def defs
+    ||> fold (fn (c, ((class, tyco), ty)) => add_inst_def' (class, tyco) (c, ty)) other_cs
     |-> (fn (cs, defs) => do_proof (after_qed cs defs) arities)
   end;