src/Pure/axclass.ML
changeset 25597 34860182b250
parent 25486 b944ef973109
child 25605 35a5f7f4b97b
--- a/src/Pure/axclass.ML	Mon Dec 10 11:24:14 2007 +0100
+++ b/src/Pure/axclass.ML	Mon Dec 10 11:24:15 2007 +0100
@@ -27,6 +27,16 @@
   val axiomatize_arity: arity -> theory -> theory
   val axiomatize_arity_cmd: xstring * string list * string -> theory -> theory
   val instance_name: string * class -> string
+  val declare_overloaded: string * typ -> theory -> term * theory
+  val define_overloaded: string -> string * term -> theory -> thm * theory
+  val inst_tyco_of: theory -> string * typ -> string option
+  val unoverload: theory -> thm -> thm
+  val overload: theory -> thm -> thm
+  val unoverload_conv: theory -> conv
+  val overload_conv: theory -> conv
+  val unoverload_const: theory -> string * typ -> string
+  val param_of_inst: theory -> string * string -> string
+  val inst_of_param: theory -> string -> (string * string) option
   type cache
   val of_sort: theory -> typ * sort -> cache -> thm list * cache  (*exception Sorts.CLASS_ERROR*)
   val cache: cache
@@ -88,23 +98,36 @@
   Symtab.join (K (merge (eq_fst op =))) (arities1, arities2));
 
 
+(* instance parameters *)
+
+type inst_params =
+  (string * thm) Symtab.table Symtab.table
+    (*constant name ~> type constructor ~> (constant name, equation)*)
+  * (string * string) Symtab.table; (*constant name ~> (constant name, type constructor)*)
+
+fun merge_inst_params ((const_param1, param_const1), (const_param2, param_const2)) =
+  (Symtab.join  (K (Symtab.merge (K true))) (const_param1, const_param2),
+    Symtab.merge (K true) (param_const1, param_const2));
+
+
 (* setup data *)
 
 structure AxClassData = TheoryDataFun
 (
-  type T = axclasses * instances;
-  val empty = ((Symtab.empty, []), ([], Symtab.empty));
+  type T = axclasses * (instances * inst_params);
+  val empty = ((Symtab.empty, []), (([], Symtab.empty), (Symtab.empty, Symtab.empty)));
   val copy = I;
   val extend = I;
-  fun merge pp ((axclasses1, instances1), (axclasses2, instances2)) =
-    (merge_axclasses pp (axclasses1, axclasses2), (merge_instances (instances1, instances2)));
+  fun merge pp ((axclasses1, (instances1, inst_params1)), (axclasses2, (instances2, inst_params2))) =
+    (merge_axclasses pp (axclasses1, axclasses2),
+      (merge_instances (instances1, instances2), merge_inst_params (inst_params1, inst_params2)));
 );
 
 
 (* maintain axclasses *)
 
 val get_axclasses = #1 o AxClassData.get;
-fun map_axclasses f = AxClassData.map (apfst f);
+val map_axclasses = AxClassData.map o apfst;
 
 val lookup_def = Symtab.lookup o #1 o get_axclasses;
 
@@ -135,8 +158,8 @@
 
 fun instance_name (a, c) = NameSpace.base c ^ "_" ^ NameSpace.base a;
 
-val get_instances = #2 o AxClassData.get;
-fun map_instances f = AxClassData.map (apsnd f);
+val get_instances = #1 o #2 o AxClassData.get;
+val map_instances = AxClassData.map o apsnd o apfst;
 
 
 fun the_classrel thy (c1, c2) =
@@ -159,6 +182,39 @@
   (classrel, arities |> Symtab.insert_list (eq_fst op =) (t, ((c, Ss), th))));
 
 
+(* maintain instance parameters *)
+
+val get_inst_params = #2 o #2 o AxClassData.get;
+val map_inst_params = AxClassData.map o apsnd o apsnd;
+
+fun get_inst_param thy (c, tyco) =
+  (the o Symtab.lookup ((the o Symtab.lookup (fst (get_inst_params thy))) c)) tyco;
+
+fun add_inst_param (c, tyco) inst = (map_inst_params o apfst
+      o Symtab.map_default (c, Symtab.empty)) (Symtab.update_new (tyco, inst))
+  #> (map_inst_params o apsnd) (Symtab.update_new (fst inst, (c, tyco)));
+
+val inst_of_param = Symtab.lookup o snd o get_inst_params;
+val param_of_inst = fst oo get_inst_param;
+
+fun inst_thms thy = (Symtab.fold (Symtab.fold (cons o snd o snd) o snd) o fst)
+  (get_inst_params thy) [];
+
+val inst_tyco_of = Option.map fst o try (dest_Type o the_single) oo Sign.const_typargs;
+
+fun unoverload thy = MetaSimplifier.simplify true (inst_thms thy);
+fun overload thy = MetaSimplifier.simplify true (map Thm.symmetric (inst_thms thy));
+
+fun unoverload_conv thy = MetaSimplifier.rewrite true (inst_thms thy);
+fun overload_conv thy = MetaSimplifier.rewrite true (map Thm.symmetric (inst_thms thy));
+
+fun unoverload_const thy (c_ty as (c, _)) =
+  case class_of_param thy c
+   of SOME class => (case inst_tyco_of thy c_ty
+       of SOME tyco => try (param_of_inst thy) (c, tyco) |> the_default c
+        | NONE => c)
+    | NONE => c;
+
 
 (** instances **)
 
@@ -200,10 +256,35 @@
     val prop = Thm.plain_prop_of (Thm.transfer thy th);
     val (t, Ss, c) = Logic.dest_arity prop handle TERM _ => err ();
     val _ = map (Sign.certify_sort thy) Ss = Ss orelse err ();
+    (*FIXME turn this into a mere check as soon as "attach" has disappeared*)
+    val missing_params = Sign.complete_sort thy [c]
+      |> maps (these o Option.map (fn AxClass { params, ... } => params) o lookup_def thy)
+      |> filter_out (fn (p, _) => can (get_inst_param thy) (p, t));
+    fun declare_missing (p, T0) thy =
+      let
+        val name_inst = instance_name (t, c) ^ "_inst";
+        val p' = NameSpace.base p ^ "_" ^ NameSpace.base t;
+        val vs = Name.names Name.context Name.aT (replicate (Sign.arity_number thy t) []);
+        val T = map_atyps (fn _ => Type (t, map TFree vs)) T0;
+      in
+        thy
+        |> Sign.sticky_prefix name_inst
+        |> Sign.no_base_names
+        |> Sign.declare_const [] (p', T, NoSyn)
+        |-> (fn const' as Const (p'', _) => Thm.add_def false true
+              (Thm.def_name p', Logic.mk_equals (const', Const (p, T)))
+        #>> Thm.varifyT
+        #-> (fn thm => add_inst_param (p, t) (p'', Thm.symmetric thm)
+        #> PureThy.note Thm.internalK (p', thm)
+        #> snd
+        #> Sign.restore_naming thy))
+      end;
+
   in
     thy
     |> Sign.primitive_arity (t, Ss, [c])
     |> put_arity ((t, Ss, c), Drule.unconstrainTs th)
+    |> fold declare_missing missing_params
   end;
 
 
@@ -240,6 +321,47 @@
   end;
 
 
+(* instance parameters and overloaded definitions *)
+
+(* declaration and definition of instances of overloaded constants *)
+
+fun declare_overloaded (c, T) thy =
+  let
+    val SOME class = class_of_param thy c;
+    val SOME tyco = inst_tyco_of thy (c, T);
+    val name_inst = instance_name (tyco, class) ^ "_inst";
+    val c' = NameSpace.base c ^ "_" ^ NameSpace.base tyco;
+    val T' = Type.strip_sorts T;
+  in
+    thy
+    |> Sign.sticky_prefix name_inst
+    |> Sign.no_base_names
+    |> Sign.declare_const [] (c', T', NoSyn)
+    |-> (fn const' as Const (c'', _) => Thm.add_def false true
+          (Thm.def_name c', Logic.mk_equals (Const (c, T'), const'))
+    #>> Thm.varifyT
+    #-> (fn thm => add_inst_param (c, tyco) (c'', thm)
+    #> PureThy.note Thm.internalK (c', thm)
+    #> snd
+    #> Sign.restore_naming thy
+    #> pair (Const (c, T))))
+  end;
+
+fun define_overloaded name (c, t) thy =
+  let
+    val T = Term.fastype_of t;
+    val SOME tyco = inst_tyco_of thy (c, T);
+    val (c', eq) = get_inst_param thy (c, tyco);
+    val prop = Logic.mk_equals (Const (c', T), t);
+    val name' = Thm.def_name_optional
+      (NameSpace.base c ^ "_" ^ NameSpace.base tyco) name;
+  in
+    thy
+    |> Thm.add_def false false (name', prop)
+    |>> (fn thm =>  Drule.transitive_thm OF [eq, thm])
+  end;
+
+
 
 (** class definitions **)