substantial additions using locales
authorhaftmann
Wed, 04 Jan 2006 17:04:11 +0100
changeset 18575 9ccfd1d1e874
parent 18574 46ed84a64cf6
child 18576 8d98b7711e47
substantial additions using locales
src/Pure/Tools/class_package.ML
--- a/src/Pure/Tools/class_package.ML	Wed Jan 04 17:03:43 2006 +0100
+++ b/src/Pure/Tools/class_package.ML	Wed Jan 04 17:04:11 2006 +0100
@@ -11,9 +11,16 @@
     -> ProofContext.context * theory
   val add_class_i: bstring -> Locale.expr -> Element.context_i list -> theory
     -> ProofContext.context * theory
+  val add_instance_arity: (xstring * string list) * string
+    -> ((bstring * string) * Attrib.src list) list
+    -> theory -> Proof.state
+  val add_instance_arity_i: (string * sort list) * sort
+    -> ((bstring * term) * theory attribute list) list
+    -> theory -> Proof.state
   val add_classentry: class -> xstring list -> xstring list -> theory -> theory
   val the_consts: theory -> class -> string list
   val the_tycos: theory -> class -> (string * string) list
+  val print_classes: theory -> unit
 
   val syntactic_sort_of: theory -> sort -> sort
   val get_arities: theory -> sort -> string -> sort list
@@ -37,13 +44,15 @@
 (* data kind 'Pure/classes' *)
 
 type class_data = {
-  locale_name: string,
-  axclass_name: string,
-  consts: string list,
-  tycos: (string * string) list
+  superclasses: class list,
+  name_locale: string,
+  name_axclass: string,
+  var: string,
+  consts: (string * typ) list,
+  insts: (string * string) list
 };
 
-structure ClassesData = TheoryDataFun (
+structure ClassData = TheoryDataFun (
   struct
     val name = "Pure/classes";
     type T = class_data Symtab.table * class Symtab.table;
@@ -53,36 +62,75 @@
     fun merge _ ((t1, r1), (t2, r2))=
       (Symtab.merge (op =) (t1, t2),
        Symtab.merge (op =) (r1, r2));
-    fun print _ (tab, _) = (Pretty.writeln o Pretty.chunks) (map Pretty.str (Symtab.keys tab));
+    fun print thy (tab, _) =
+      let
+        fun pretty_class (name, {superclasses, name_locale, name_axclass, var, consts, insts}) =
+          (Pretty.block o Pretty.fbreaks) [
+            Pretty.str ("class " ^ name ^ ":"),
+            (Pretty.block o Pretty.fbreaks) (
+              Pretty.str "superclasses: "
+              :: map Pretty.str superclasses
+            ),
+            Pretty.str ("locale: " ^ name_locale),
+            Pretty.str ("axclass: " ^ name_axclass),
+            Pretty.str ("class variable: " ^ var),
+            (Pretty.block o Pretty.fbreaks) (
+              Pretty.str "constants: "
+              :: map (fn (c, ty) => Pretty.str (c ^ " :: " ^ Sign.string_of_typ thy ty)) consts
+            ),
+            (Pretty.block o Pretty.fbreaks) (
+              Pretty.str "instances: "
+              :: map (fn (tyco, thyname) => Pretty.str (tyco ^ ", in theory " ^ thyname)) insts
+            )
+          ]
+      in
+        (Pretty.writeln o Pretty.chunks o map pretty_class o Symtab.dest) tab
+      end;
   end
 );
 
-val lookup_class_data = Symtab.lookup o fst o ClassesData.get;
-val lookup_const_class = Symtab.lookup o snd o ClassesData.get;
+val print_classes = ClassData.print;
+
+val lookup_class_data = Symtab.lookup o fst o ClassData.get;
+val lookup_const_class = Symtab.lookup o snd o ClassData.get;
 
 fun get_class_data thy class =
   case lookup_class_data thy class
     of NONE => error ("undeclared class " ^ quote class)
      | SOME data => data;
 
-fun put_class_data class data =
-  ClassesData.map (apfst (Symtab.update (class, data)));
-fun add_const class const =
-  ClassesData.map (apsnd (Symtab.update (const, class)));
-val the_consts = #consts oo get_class_data;
-val the_tycos = #tycos oo get_class_data;
+fun add_class_data (class, (superclasses, name_locale, name_axclass, classvar, consts)) =
+  ClassData.map (fn (classtab, consttab) => (
+    classtab 
+    |> Symtab.update (class, {
+         superclasses = superclasses,
+         name_locale = name_locale,
+         name_axclass = name_axclass,
+         var = classvar,
+         consts = consts,
+         insts = []
+       }),
+    consttab
+    |> fold (fn (c, _) => Symtab.update (c, class)) consts
+  ));
+
+fun add_inst_data (class, inst) =
+  (ClassData.map o apfst o Symtab.map_entry class)
+    (fn {superclasses, name_locale, name_axclass, var, consts, insts}
+      => {
+           superclasses = superclasses,
+           name_locale = name_locale,
+           name_axclass = name_axclass,
+           var = var,
+           consts = consts,
+           insts = insts @ [inst]
+          });
+
+val the_consts = map fst o #consts oo get_class_data;
+val the_tycos = #insts oo get_class_data;
 
 
-(* name mangling *)
-
-fun get_locale_for_class thy class =
-  #locale_name (get_class_data thy class);
-
-fun get_axclass_for_class thy class =
-  #axclass_name (get_class_data thy class);
-
-
-(* classes *)
+(* classes and instances *)
 
 local
 
@@ -90,25 +138,31 @@
 
 fun gen_add_class add_locale bname raw_import raw_body thy =
   let
-    fun extract_notes_consts thy elems =
-      elems
-      |> Library.flat
-      |> List.mapPartial
-           (fn (Notes notes) => SOME notes
-             | _ => NONE)
-      |> Library.flat
-      |> map (fn (_, facts) => map fst facts)
-      |> Library.flat o Library.flat
-      |> map prop_of
-      |> (fn ts => fold (curry add_term_consts) ts [])
-      |> tap (writeln o commas);
+    fun subst_clsvar v ty_subst =
+      map_type_tfree (fn u as (w, _) =>
+        if w = v then ty_subst else TFree u);
+    fun extract_assumes c_adds elems =
+      let
+        fun subst_free ts =
+          let
+            val get_ty = the o AList.lookup (op =) (fold Term.add_frees ts []);
+            val subst_map = map (fn (c, (c', _)) =>
+              (Free (c, get_ty c), Const (c', get_ty c))) c_adds;
+          in map (subst_atomic subst_map) ts end;
+      in
+        elems
+        |> (map o List.mapPartial)
+            (fn (Assumes asms) => (SOME o map (map fst o snd)) asms
+              | _ => NONE)
+        |> Library.flat o Library.flat o Library.flat
+        |> subst_free
+      end;
     fun extract_tyvar_name thy tys =
       fold (curry add_typ_tfrees) tys []
-      |> (fn [(v, [])] => v
-           | [(v, sort)] =>
+      |> (fn [(v, sort)] =>
                 if Sorts.sort_eq (Sign.classes_of thy) (Sign.defaultS thy, sort)
                 then v 
-                else error ("sort constraint on class type variable: " ^ Sign.string_of_sort thy sort)
+                else error ("illegal sort constraint on class type variable: " ^ Sign.string_of_sort thy sort)
            | [] => error ("no class type variable")
            | vs => error ("more than one type variable: " ^ (commas o map (Sign.string_of_typ thy o TFree)) vs))
     fun extract_tyvar_consts thy elems =
@@ -118,48 +172,41 @@
            (fn (Fixes consts) => SOME consts
              | _ => NONE)
       |> Library.flat
-      |> map (fn (c, ty, syn) => ((c, the ty), the syn))
-      |> `(fn consts => extract_tyvar_name thy (map (snd o fst) consts));
-    (* fun remove_local_syntax ((c, ty), _) thy =
-      thy
-      |> Sign.add_syntax_i [(c, ty, Syntax.NoSyn)]; *)
-    fun add_global_const ((c, ty), syn) thy =
+      |> map (fn (c, ty, syn) =>
+           ((c, the ty), (Syntax.unlocalize_mixfix o Syntax.fix_mixfix c o the) syn))
+      |> `(fn consts => extract_tyvar_name thy (map (snd o fst) consts))
+      |-> (fn v => map ((apfst o apsnd) (subst_clsvar v (TFree (v, []))))
+         #> pair v);
+    fun add_global_const v ((c, ty), syn) thy =
       thy
-      |> Sign.add_consts_i [(c, ty, syn)]
-      |> `(fn thy => Sign.intern_const thy c)
-    fun add_axclass bname_axiom locale_pred cs thy =
+      |> Sign.add_consts_i [(c, ty |> subst_clsvar v (TFree (v, Sign.defaultS thy)), syn)]
+      |> `(fn thy => (c, (Sign.intern_const thy c, ty)))
+    fun add_global_constraint v class (_, (c, ty)) thy =
       thy
-      |> AxClass.add_axclass_i (bname, Sign.defaultS thy)
-           [Thm.no_attributes (bname_axiom,
-              Const (ObjectLogic.judgment_name thy, dummyT) $
-                list_comb (Const (locale_pred, dummyT), map (fn c => Const (c, dummyT)) cs)
-              |> curry (inferT_axm thy) "locale_pred" |> snd)]
-      |-> (fn _ => `(fn thy => Sign.intern_class thy bname))
+      |> Sign.add_const_constraint_i (c, subst_clsvar v (TVar ((v, 0), [class])) ty);
     fun print_ctxt ctxt elem = 
       map Pretty.writeln (Element.pretty_ctxt ctxt elem)
   in
     thy
     |> add_locale bname raw_import raw_body
-    |-> (fn ((_, elems : context_i list list), ctxt) =>
-       tap (fn _ => (map o map) (print_ctxt ctxt) elems)
-    #> tap (fn thy => extract_notes_consts thy elems)
-    #> `(fn thy => Locale.intern thy bname)
+    |-> (fn ((import_elems, body_elems), ctxt) =>
+       `(fn thy => Locale.intern thy bname)
     #-> (fn name_locale =>
-       `(fn thy => extract_tyvar_consts thy elems)
-    #-> (fn (v, consts) =>
-       fold_map add_global_const consts
-    #-> (fn cs =>
-       add_axclass (bname ^ "_intro") name_locale cs
+          `(fn thy => extract_tyvar_consts thy body_elems)
+    #-> (fn (v, c_defs) =>
+          fold_map (add_global_const v) c_defs
+    #-> (fn c_adds =>
+          AxClass.add_axclass_i (bname, Sign.defaultS thy)
+            (map (Thm.no_attributes o pair "") (extract_assumes c_adds (import_elems @ body_elems)))
+    #-> (fn _ =>
+          `(fn thy => Sign.intern_class thy bname)
     #-> (fn name_axclass =>
-       put_class_data name_locale {
-          locale_name = name_locale,
-          axclass_name = name_axclass,
-          consts = cs,
-          tycos = []
-        })
-    #> fold (add_const name_locale) cs
+          fold (add_global_constraint v name_axclass) c_adds
+    #> add_class_data (name_locale, ([], name_locale, name_axclass, v, map snd c_adds))
+    #> tap (fn _ => (map o map) (print_ctxt ctxt) import_elems)
+    #> tap (fn _ => (map o map) (print_ctxt ctxt) body_elems)
     #> pair ctxt
-    ))))
+    ))))))
   end;
 
 in
@@ -169,6 +216,52 @@
 
 end; (* local *)
 
+fun gen_instance_arity prep_arity add_defs tap_def raw_arity raw_defs thy = 
+  let
+    val dest_def = Theory.dest_def (Sign.pp thy) handle TERM (msg, _) => error msg;
+    val arity as (tyco, asorts, sort) = prep_arity thy ((fn ((x, y), z) => (x, y, z)) raw_arity);
+    val ty_inst = Type (tyco, map2 (curry TVar o rpair 0) (Term.invent_names [] "'a" (length asorts)) asorts)
+    fun get_c_req class =
+      let
+        val data = get_class_data thy class;
+        val subst_ty = map_type_tfree (fn (var as (v, _)) =>
+          if #var data = v then ty_inst else TFree var)
+      in (map (apsnd subst_ty) o #consts) data end;
+    val c_req = (Library.flat o map get_c_req) sort;
+    fun get_remove_contraint c thy =
+      let
+        val ty1 = Sign.the_const_constraint thy c;
+        val ty2 = Sign.the_const_type thy c;
+      in
+        thy
+        |> Sign.add_const_constraint_i (c, ty2)
+        |> pair (c, ty1)
+      end;
+    fun get_c_given thy = map (fst o dest_def o snd o tap_def thy o fst) raw_defs;
+    fun check_defs c_given c_req thy =
+      let
+        fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2 andalso Sign.typ_instance thy (ty1, ty2)
+        val _ = case fold (remove eq_c) c_given c_req
+         of [] => ()
+          | cs => error ("no definition(s) given for"
+                    ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs);
+        val _ = case fold (remove eq_c) c_req c_given
+         of [] => ()
+          | cs => error ("superfluous definition(s) given for"
+                    ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs);
+      in thy end;
+  in
+    thy
+    |> fold_map get_remove_contraint (map fst c_req)
+    ||> tap (fn thy => check_defs (get_c_given thy) c_req)
+    ||> add_defs (true, raw_defs)
+    |-> (fn cs => fold Sign.add_const_constraint_i cs)
+    |> AxClass.instance_arity_i arity
+  end;
+
+val add_instance_arity = fn x => gen_instance_arity (AxClass.read_arity) IsarThy.add_defs read_axm x;
+val add_instance_arity_i = fn x => gen_instance_arity (AxClass.cert_arity) IsarThy.add_defs_i (K I) x;
+
 
 (* class queries *)
 
@@ -202,14 +295,13 @@
 
 (* instance queries *)
 
-fun get_const_sign thy tvar const =
+fun mk_const_sign thy class tvar ty =
   let
-    val class = (the o lookup_const_class thy) const;
-    val (ty, thaw) = (Type.freeze_thaw_type o Sign.the_const_constraint thy) const;
-    val tvars_used = Term.add_tfreesT ty [];
+    val (ty', thaw) = Type.freeze_thaw_type ty;
+    val tvars_used = Term.add_tfreesT ty' [];
     val tvar_rename = hd (Term.invent_names (map fst tvars_used) tvar 1);
   in
-    ty
+    ty'
     |> map_type_tfree (fn (tvar', sort) =>
           if Sorts.sort_eq (Sign.classes_of thy) ([class], sort)
           then TFree (tvar, [])
@@ -219,6 +311,12 @@
     |> thaw
   end;
 
+fun get_const_sign thy tvar const =
+  let
+    val class = (the o lookup_const_class thy) const;
+    val ty = Sign.the_const_constraint thy const;
+  in mk_const_sign thy class tvar ty end;
+
 fun get_inst_consts_sign thy (tyco, class) =
   let
     val consts = the_consts thy class;
@@ -234,9 +332,9 @@
 
 fun get_classtab thy =
   Symtab.fold
-    (fn (class, { consts = consts, tycos = tycos, ... }) =>
-      Symtab.update_new (class, (consts, tycos)))
-       (fst (ClassesData.get thy)) Symtab.empty;
+    (fn (class, { consts = consts, insts = insts, ... }) =>
+      Symtab.update_new (class, (map fst consts, insts)))
+       (fst (ClassData.get thy)) Symtab.empty;
 
 
 (* extracting dictionary obligations from types *)
@@ -287,20 +385,15 @@
 
 (* intermediate auxiliary *)
 
-fun add_classentry raw_class raw_cs raw_tycos thy =
+fun add_classentry raw_class raw_cs raw_insts thy =
   let
     val class = Sign.intern_class thy raw_class;
-    val cs = map (Sign.intern_const thy) raw_cs;
-    val tycos = map (rpair (Context.theory_name thy) o Sign.intern_type thy) raw_tycos;
+    val cs = raw_cs |> map (Sign.intern_const thy);
+    val insts = map (rpair (Context.theory_name thy) o Sign.intern_type thy) raw_insts;
   in
     thy
-    |> put_class_data class {
-         locale_name = "",
-         axclass_name = class,
-         consts = cs,
-         tycos = tycos
-       }
-    |> fold (add_const class) cs
+    |> add_class_data (class, ([], "", class, "", map (rpair dummyT) cs))
+    |> fold (curry add_inst_data class) insts
   end;
 
 
@@ -313,7 +406,7 @@
 
 in
 
-val classK = "class"
+val (classK, instanceK) = ("class", "class_instance")
 
 val locale_val =
   (P.locale_expr --
@@ -324,15 +417,22 @@
   OuterSyntax.command classK "operational type classes" K.thy_decl
     (P.name -- Scan.optional (P.$$$ "=" |-- P.!!! locale_val) (Locale.empty, [])
       >> (Toplevel.theory_context
-          o (fn f => swap o f) o (fn (name, (expr, elems)) => add_class name expr elems)));
+          o (fn f => swap o f) o (fn (bname, (expr, elems)) => add_class bname expr elems)));
 
-val _ = OuterSyntax.add_parsers [classP];
+val instanceP =
+  OuterSyntax.command instanceK "" K.thy_goal
+    (P.xname -- (P.$$$ "::" |-- P.!!! P.arity)
+      -- Scan.repeat1 P.spec_name
+      >> (Toplevel.theory_theory_to_proof
+          o (fn ((tyco, (asorts, sort)), defs) => add_instance_arity ((tyco, asorts), sort) defs)));
+
+val _ = OuterSyntax.add_parsers [classP, instanceP];
 
 end; (* local *)
 
 
 (* setup *)
 
-val _ = Context.add_setup [ClassesData.init];
+val _ = Context.add_setup [ClassData.init];
 
 end; (* struct *)