src/Pure/Tools/class_package.ML
changeset 18702 7dc7dcd63224
parent 18670 c3f445b92aff
child 18708 4b3dadb4fe33
--- a/src/Pure/Tools/class_package.ML	Tue Jan 17 10:26:50 2006 +0100
+++ b/src/Pure/Tools/class_package.ML	Tue Jan 17 16:36:57 2006 +0100
@@ -18,23 +18,21 @@
     -> ((bstring * term) * theory attribute list) list
     -> theory -> Proof.state
   val add_classentry: class -> xstring list -> xstring list -> theory -> theory
-  val the_consts: theory -> class -> string list
-  val the_tycos: theory -> class -> (string * string) list
-  val print_classes: theory -> unit
 
   val syntactic_sort_of: theory -> sort -> sort
-  val get_arities: theory -> sort -> string -> sort list
-  val get_superclasses: theory -> class -> class list
-  val get_const_sign: theory -> string -> string -> typ
-  val get_inst_consts_sign: theory -> string * class -> (string * typ) list
+  val the_superclasses: theory -> class -> class list
+  val the_consts_sign: theory -> class -> string * (string * typ) list
   val lookup_const_class: theory -> string -> class option
+  val the_instances: theory -> class -> (string * string) list
+  val the_inst_sign: theory -> class * string -> (string * sort) list * (string * typ) list
   val get_classtab: theory -> (string list * (string * string) list) Symtab.table
+  val print_classes: theory -> unit
 
   type sortcontext = (string * sort) list
   datatype sortlookup = Instance of (class * string) * sortlookup list list
                       | Lookup of class list * (string * int)
   val extract_sortctxt: theory -> typ -> sortcontext
-  val extract_sortlookup: theory -> typ * typ -> sortlookup list list
+  val extract_sortlookup: theory -> string * typ -> sortlookup list list
 end;
 
 structure ClassPackage: CLASS_PACKAGE =
@@ -126,21 +124,19 @@
            insts = insts @ [inst]
           });
 
-val the_consts = map fst o #consts oo get_class_data;
-val the_tycos = #insts oo get_class_data;
-
 
 (* classes and instances *)
 
+fun subst_clsvar v ty_subst =
+  map_type_tfree (fn u as (w, _) =>
+    if w = v then ty_subst else TFree u);
+
 local
 
 open Element
 
 fun gen_add_class add_locale bname raw_import raw_body thy =
   let
-    fun subst_clsvar v ty_subst =
-      map_type_tfree (fn u as (w, _) =>
-        if w = v then ty_subst else TFree u);
     fun extract_assumes c_adds elems =
       let
         fun subst_free ts =
@@ -240,7 +236,9 @@
     fun get_c_given thy = map (fst o dest_def o snd o tap_def thy o fst) raw_defs;
     fun check_defs c_given c_req thy =
       let
-        fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2 andalso Sign.typ_instance thy (ty1, ty2)
+        fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2
+          andalso Sign.typ_instance thy (ty1, ty2)
+          andalso Sign.typ_instance thy (ty2, ty1)
         val _ = case fold (remove eq_c) c_given c_req
          of [] => ()
           | cs => error ("no definition(s) given for"
@@ -263,9 +261,12 @@
 val add_instance_arity_i = fn x => gen_instance_arity (AxClass.cert_arity) IsarThy.add_defs_i (K I) x;
 
 
-(* class queries *)
+(* queries *)
 
-fun is_class thy cls = lookup_class_data thy cls |> Option.map (not o null o #consts) |> the_default false;
+fun is_class thy cls =
+  lookup_class_data thy cls
+  |> Option.map (not o null o #consts)
+  |> the_default false;
 
 fun syntactic_sort_of thy sort =
   let
@@ -280,11 +281,7 @@
     |> Sorts.norm_sort classes
   end;
 
-fun get_arities thy sort tycon =
-  Sorts.mg_domain (Sign.classes_arities_of thy) tycon (syntactic_sort_of thy sort)
-  |> map (syntactic_sort_of thy);
-
-fun get_superclasses thy class =
+fun the_superclasses thy class =
   if is_class thy class
   then
     Sorts.superclasses (Sign.classes_of thy) class
@@ -292,49 +289,43 @@
   else
     error ("no syntactic class: " ^ class);
 
-
-(* instance queries *)
-
-fun mk_const_sign thy class tvar ty =
+fun the_consts_sign thy class =
   let
-    val (ty', thaw) = Type.freeze_thaw_type ty;
-    val tvars_used = Term.add_tfreesT ty' [];
-    val tvar_rename = hd (Term.invent_names (map fst tvars_used) tvar 1);
-  in
-    ty'
-    |> map_type_tfree (fn (tvar', sort) =>
-          if Sorts.sort_eq (Sign.classes_of thy) ([class], sort)
-          then TFree (tvar, [])
-          else if tvar' = tvar
-          then TVar ((tvar_rename, 0), sort)
-          else TFree (tvar', sort))
-    |> thaw
-  end;
+    val data = (the oo Symtab.lookup) ((fst o ClassData.get) thy) class
+  in (#var data, #consts data) end;
+
+fun lookup_const_class thy =
+  Symtab.lookup ((snd o ClassData.get) thy);
+
+fun the_instances thy class =
+  (#insts o the o Symtab.lookup ((fst o ClassData.get) thy)) class;
 
-fun get_const_sign thy tvar const =
-  let
-    val class = (the o lookup_const_class thy) const;
-    val ty = Sign.the_const_constraint thy const;
-  in mk_const_sign thy class tvar ty end;
-
-fun get_inst_consts_sign thy (tyco, class) =
+fun the_inst_sign thy (class, tyco) =
   let
-    val consts = the_consts thy class;
-    val arities = get_arities thy [class] tyco;
-    val const_signs = map (get_const_sign thy "'a") consts;
-    val vars_used = fold (fn ty => curry (gen_union (op =))
-      (map fst (typ_tfrees ty) |> remove (op =) "'a")) const_signs [];
-    val vars_new = Term.invent_names vars_used "'a" (length arities);
-    val typ_arity = Type (tyco, map2 (curry TFree) vars_new arities);
-    val instmem_signs =
-      map (typ_subst_TVars [(("'a", 0), typ_arity)]) const_signs;
-  in consts ~~ instmem_signs end;
+    val _ = if is_class thy class then () else error ("no syntactic class: " ^ class);
+    val arity = 
+      Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class]
+      |> map (syntactic_sort_of thy);
+    val clsvar = (#var o the o Symtab.lookup ((fst o ClassData.get) thy)) class;
+    val const_sign = (snd o the_consts_sign thy) class;
+    fun add_var sort used =
+      let
+        val v = hd (Term.invent_names used "'a" 1)
+      in ((v, sort), v::used) end;
+    val (vsorts, _) =
+      []
+      |> fold (fn (_, ty) => curry (gen_union (op =))
+           ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) const_sign
+      |> fold_map add_var arity;
+    val ty_inst = Type (tyco, map (fn (v, sort) => TVar ((v, 0), sort)) vsorts);
+    val inst_signs = map (apsnd (subst_clsvar clsvar ty_inst)) const_sign;
+  in (vsorts, inst_signs) end;
 
 fun get_classtab thy =
   Symtab.fold
     (fn (class, { consts = consts, insts = insts, ... }) =>
       Symtab.update_new (class, (map fst consts, insts)))
-       (fst (ClassData.get thy)) Symtab.empty;
+       ((fst o ClassData.get) thy) Symtab.empty;
 
 
 (* extracting dictionary obligations from types *)
@@ -342,15 +333,16 @@
 type sortcontext = (string * sort) list;
 
 fun extract_sortctxt thy ty =
-  (typ_tfrees o Type.no_tvars) ty
+  (typ_tfrees o fst o Type.freeze_thaw_type) ty
   |> map (apsnd (syntactic_sort_of thy))
   |> filter (not o null o snd);
 
 datatype sortlookup = Instance of (class * string) * sortlookup list list
                     | Lookup of class list * (string * int)
 
-fun extract_sortlookup thy (raw_typ_def, raw_typ_use) =
+fun extract_sortlookup thy (c, raw_typ_use) =
   let
+    val raw_typ_def = Sign.the_const_constraint thy c;
     val typ_def = Type.varifyT raw_typ_def;
     val typ_use = Type.varifyT raw_typ_use;
     val match_tab = Sign.typ_match thy (typ_def, typ_use) Vartab.empty;
@@ -374,8 +366,22 @@
               let val (deriv, classindex) = mk_class_deriv thy (syntactic_sort_of thy sort_use) class
               in Lookup (deriv, (vname, classindex)) end;
           in map mk_look sort_def end;
+    fun reorder_sortctxt ctxt =
+      case lookup_const_class thy c
+       of NONE => ctxt
+        | SOME class =>
+            let
+              val data = (the o Symtab.lookup ((fst o ClassData.get) thy)) class;
+              val sign = (Type.varifyT o the o AList.lookup (op =) (#consts data)) c;
+              val match_tab = Sign.typ_match thy (sign, typ_def) Vartab.empty;
+              val v : string = case Vartab.lookup match_tab (#var data, 0)
+                of SOME (_, TVar ((v, _), _)) => v;
+            in
+              (v, (the o AList.lookup (op =) ctxt) v) :: AList.delete (op =) v ctxt
+            end;
   in
     extract_sortctxt thy ((fst o Type.freeze_thaw_type) raw_typ_def)
+    |> reorder_sortctxt
     |> map (tab_lookup o fst)
     |> map (apfst (syntactic_sort_of thy))
     |> filter (not o null o fst)
@@ -388,11 +394,26 @@
 fun add_classentry raw_class raw_cs raw_insts thy =
   let
     val class = Sign.intern_class thy raw_class;
-    val cs = raw_cs |> map (Sign.intern_const thy);
+    val cs_proto =
+      raw_cs
+      |> map (Sign.intern_const thy)
+      |> map (fn c => (c, Sign.the_const_constraint thy c));
+    val used = 
+      []
+      |> fold (fn (_, ty) => curry (gen_union (op =))
+           ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) cs_proto
+    val v = hd (Term.invent_names used "'a" 1);
+    val cs =
+      cs_proto
+      |> map (fn (c, ty) => (c, map_type_tvar (fn var as ((tvar', _), sort) =>
+          if Sorts.sort_eq (Sign.classes_of thy) ([class], sort)
+          then TFree (v, [])
+          else TVar var
+         ) ty));
     val insts = map (rpair (Context.theory_name thy) o Sign.intern_type thy) raw_insts;
   in
     thy
-    |> add_class_data (class, ([], "", class, "", map (rpair dummyT) cs))
+    |> add_class_data (class, ([], "", class, v, cs))
     |> fold (curry add_inst_data class) insts
   end;