reduced code, better instance command
authorhaftmann
Wed, 28 Jun 2006 14:36:09 +0200
changeset 19957 91ba241a1678
parent 19956 f992e507020e
child 19958 fc4ac94f03e0
reduced code, better instance command
src/Pure/Tools/class_package.ML
--- a/src/Pure/Tools/class_package.ML	Wed Jun 28 14:35:51 2006 +0200
+++ b/src/Pure/Tools/class_package.ML	Wed Jun 28 14:36:09 2006 +0200
@@ -36,10 +36,7 @@
   val operational_sort_of: theory -> sort -> sort
   val the_superclasses: theory -> class -> class list
   val the_consts_sign: theory -> class -> string * (string * typ) list
-  val lookup_const_class: theory -> string -> class option
-  val the_instances: theory -> class -> (string * ((sort list) * string)) list
   val the_inst_sign: theory -> class * string -> (string * sort) list * (string * typ) list
-  val get_classtab: theory -> (string * string) list Symtab.table
 
   val print_classes: theory -> unit
   val intro_classes_tac: thm list -> tactic
@@ -57,6 +54,22 @@
 struct
 
 
+(* auxiliary *)
+
+fun instantiations_of thy (ty, ty') =
+  let
+    val vartab = typ_tvars ty;
+    fun prep_vartab (v, (_, ty)) =
+      case (the o AList.lookup (op =) vartab) v
+       of [] => NONE
+        | sort => SOME ((v, sort), ty);
+  in case try (Sign.typ_match thy (ty, ty')) Vartab.empty
+   of NONE => NONE
+    | SOME vartab =>
+        SOME ((map_filter prep_vartab o Vartab.dest) vartab)
+  end;
+
+
 (* theory data *)
 
 datatype class_data = ClassData of {
@@ -64,7 +77,7 @@
   name_axclass: string,
   var: string,
   consts: (string * (string * typ)) list
-    (*locale parameter ~> toplevel const*)
+    (*locale parameter ~> toplevel constant*)
 };
 
 fun rep_classdata (ClassData c) = c;
@@ -72,17 +85,12 @@
 structure ClassData = TheoryDataFun (
   struct
     val name = "Pure/classes";
-    type T = (class_data Graph.T
-      * (string * (sort list * string)) list Symtab.table)
-        (*class ~> tyco ~> (arity, thyname)*)
-      * class Symtab.table;
-    val empty = ((Graph.empty, Symtab.empty), Symtab.empty);
+    type T = class_data Graph.T;
+    val empty = Graph.empty;
     val copy = I;
     val extend = I;
-    fun merge _ (((g1, c1), f1) : T, ((g2, c2), f2)) =
-      ((Graph.merge (K true) (g1, g2), Symtab.join (fn _ => AList.merge (op =) (op =)) (c1, c2)),
-       Symtab.merge (op =) (f1, f2));
-    fun print thy ((gr, _), _) =
+    fun merge _ = Graph.merge (K true);
+    fun print thy gr =
       let
         fun pretty_class gr (name, ClassData {name_locale, name_axclass, var, consts}) =
           (Pretty.block o Pretty.fbreaks) [
@@ -112,9 +120,7 @@
 
 (* queries *)
 
-val lookup_class_data = Option.map rep_classdata oo try o Graph.get_node o fst o fst o ClassData.get;
-val the_instances = these oo Symtab.lookup o snd o fst o ClassData.get;
-val lookup_const_class = Symtab.lookup o snd o ClassData.get;
+val lookup_class_data = Option.map rep_classdata oo try o Graph.get_node o ClassData.get;
 
 fun the_class_data thy class =
   case lookup_class_data thy class
@@ -144,13 +150,6 @@
   else
     error ("no class: " ^ class);
 
-fun get_superclass_derivation thy (subclass, superclass) =
-  if subclass = superclass
-    then SOME [subclass]
-    else case Graph.irreducible_paths ((fst o fst o ClassData.get) thy) (subclass, superclass)
-      of [] => NONE
-       | (p::_) => (SOME o filter (is_operational_class thy)) p;
-
 fun the_ancestry thy classes =
   let
     fun ancestry class anc =
@@ -176,7 +175,7 @@
 fun the_inst_sign thy (class, tyco) =
   let
     val _ = if is_operational_class thy class then () else error ("no operational class: " ^ class);
-    val arity = Sign.arity_sorts thy tyco [class];
+    val asorts = Sign.arity_sorts thy tyco [class];
     val clsvar = (#var o the_class_data thy) class;
     val const_sign = (snd o the_consts_sign thy) class;
     fun add_var sort used =
@@ -187,41 +186,23 @@
       [clsvar]
       |> fold (fn (_, ty) => curry (gen_union (op =))
            ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) const_sign
-      |> fold_map add_var arity;
+      |> fold_map add_var asorts;
     val ty_inst = Type (tyco, map TFree vsorts);
     val inst_signs = map (apsnd (subst_clsvar clsvar ty_inst)) const_sign;
   in (vsorts, inst_signs) end;
 
-fun get_classtab thy =
-  (Symtab.map o map)
-    (fn (tyco, (_, thyname)) => (tyco, thyname)) ((snd o fst o ClassData.get) thy);
-
 
 (* updaters *)
 
 fun add_class_data (class, (superclasses, name_locale, name_axclass, var, consts)) =
-  ClassData.map (fn ((gr, tab), consttab) => ((
-    gr
-    |> Graph.new_node (class, ClassData {
-         name_locale = name_locale,
-         name_axclass = name_axclass,
-         var = var,
-         consts = consts
-       })
-    |> fold (curry Graph.add_edge_acyclic class) superclasses,
-    tab
-    |> Symtab.update (class, [])),
-    consttab
-    |> fold (fn (_, (c, _)) => Symtab.update (c, class)) consts
-  ));
-
-fun add_inst_data (class, inst) =
-  ClassData.map (fn ((gr, tab), consttab) =>
-    let
-      val undef_supclasses = class :: (filter (Symtab.defined tab) (Graph.all_succs gr [class]));
-    in
-     ((gr, tab |> fold (fn class => Symtab.map_entry class (AList.update (op =) inst)) undef_supclasses), consttab)
-    end);
+  ClassData.map (
+    Graph.new_node (class, ClassData {
+      name_locale = name_locale,
+      name_axclass = name_axclass,
+      var = var,
+      consts = consts })
+    #> fold (curry Graph.add_edge_acyclic class) superclasses
+  );
 
 
 (* name handling *)
@@ -233,7 +214,7 @@
   map (fn class => (the_class_data thy class; class)) (Sign.certify_sort thy sort);
 
 fun intern_class thy =
-certify_class thy o Sign.intern_class thy;
+  certify_class thy o Sign.intern_class thy;
 
 fun intern_sort thy =
   certify_sort thy o Sign.intern_sort thy;
@@ -295,10 +276,10 @@
 
 local
 
-fun add_axclass_i (name, supsort) axs thy =
+fun add_axclass_i (name, supsort) params axs thy =
   let
     val (c, thy') = thy
-      |> AxClass.define_class_i (name, supsort) [] axs;
+      |> AxClass.define_class_i (name, supsort) params axs;
     val {intro, axioms, ...} = AxClass.get_definition thy' c;
   in ((c, (intro, axioms)), thy') end;
 
@@ -309,13 +290,11 @@
           let
             val p = setmp show_types true (setmp show_sorts true (setmp print_mode [] (Sign.pretty_typ thy))) ty;
             val s = c ^ "::" ^ Pretty.output p;
-            val _ = writeln s;
           in SOME s end
       | ad_hoc_term (SOME t) =
           let
             val p = setmp show_types true (setmp show_sorts true (setmp print_mode [] (Sign.pretty_term thy))) t;
             val s = Pretty.output p;
-            val _ = writeln s;
           in SOME s end;
   in
     thy
@@ -368,7 +347,7 @@
           map_aterms (fn Free (c, ty) => Const ((fst o the o AList.lookup (op =) cs_mapp) c, ty)
                        | t => t)
         fun prep_asm ((name, atts), ts) =
-          ((NameSpace.base name |> print, map (Attrib.attribute thy) atts), map subst_assume ts)
+          ((NameSpace.base name, map (Attrib.attribute thy) atts), map subst_assume ts)
       in
         (map prep_asm o Locale.local_asms_of thy) name_locale
       end;
@@ -387,7 +366,7 @@
     #-> (fn mapp_this =>
           `(fn thy => extract_assumes thy name_locale (mapp_sup @ mapp_this))
     #-> (fn loc_axioms =>
-          add_axclass_i (bname, supsort) loc_axioms
+          add_axclass_i (bname, supsort) (map (fst o snd) mapp_this) loc_axioms
     #-> (fn (name_axclass, (_, ax_axioms)) =>
           fold (add_global_constraint v name_axclass) mapp_this
     #> add_class_data (name_locale, (supclasses, name_locale, name_axclass, v, mapp_this))
@@ -407,30 +386,20 @@
 
 local
 
-fun gen_add_defs_overloaded prep_att tap_def add_defs tyco raw_defs thy =
+fun gen_read_def thy prep_att read_def tyco ((raw_name, raw_atts), raw_t) =
   let
-    fun invent_name raw_t =
-      let
-        val t = tap_def thy raw_t;
-        val c = (fst o dest_Const o fst o strip_comb o fst o Logic.dest_equals) t;
-      in
-        Thm.def_name (NameSpace.base c ^ "_" ^ NameSpace.base tyco)
-      end;
-    fun prep_def (_, (("", a), t)) =
-          let
-            val n = invent_name t
-          in ((n, t), map (prep_att thy) a) end
-      | prep_def (_, ((n, a), t)) =
-          ((n, t), map (prep_att thy) a);
-  in
-    thy
-    |> add_defs true (map prep_def raw_defs)
-  end;
+    val (_, t) = read_def thy (raw_name, raw_t);
+    val ((c, ty), _) = Sign.cert_def (Sign.pp thy) t;
+    val atts = map (prep_att thy) raw_atts;
+    val name = case raw_name
+     of "" => Thm.def_name (NameSpace.base c ^ "_" ^ NameSpace.base tyco)
+      | _ => raw_name;
+  in (c, (Logic.varifyT ty, ((name, t), atts))) end;
 
-val add_defs_overloaded = gen_add_defs_overloaded Attrib.attribute Sign.read_term PureThy.add_defs;
-val add_defs_overloaded_i = gen_add_defs_overloaded (K I) (K I) PureThy.add_defs_i;
+fun read_def thy = gen_read_def thy Attrib.attribute read_axm;
+fun read_def_i thy = gen_read_def thy (K I) (K I);
 
-fun gen_instance_arity prep_arity prep_att add_defs tap_def do_proof raw_arity (raw_name, raw_atts) raw_defs theory =
+fun gen_instance_arity prep_arity prep_att read_def do_proof raw_arity (raw_name, raw_atts) raw_defs theory =
   let
     val pp = Sign.pp theory;
     val arity as (tyco, asorts, sort) = prep_arity theory ((fn ((x, y), z) => (x, y, z)) raw_arity);
@@ -439,24 +408,37 @@
      of "" => Thm.def_name ((space_implode "_" o map NameSpace.base) sort ^ "_" ^ NameSpace.base tyco)
       | _ => raw_name;
     val atts = map (prep_att theory) raw_atts;
-    fun get_classes thy tyco sort =
-      let
-        fun get class classes =
-          if AList.defined (op =) ((the_instances thy) class) tyco
-            then classes
-            else classes
-              |> cons class
-              |> fold get (the_superclasses thy class)
-      in fold get sort [] end;
-    val classes = get_classes theory tyco sort;
-    val _ = if null classes then error ("already instantiated") else ();
     fun get_consts class =
       let
         val data = the_class_data theory class;
-        val subst_ty = map_type_tfree (fn (var as (v, _)) =>
-          if #var data = v then ty_inst else TFree var)
-      in (map (apsnd subst_ty o snd) o #consts) data end;
-    val cs = (flat o map get_consts) classes;
+        fun defined c =
+          is_some (find_first (fn (_, { lhs = [ty], ...}) =>
+              Sign.typ_instance theory (ty, ty_inst) orelse Sign.typ_instance theory (ty_inst, ty))
+            (Defs.specifications_of (Theory.defs_of theory) c))
+        val subst_ty = map_type_tfree (fn (v, sort) =>
+          if #var data = v then ty_inst else TVar ((v, 0), sort));
+      in
+        (map_filter (fn (_, (c, ty)) =>
+          if defined c then NONE else SOME (c, subst_ty ty)) o #consts) data
+      end;
+    val cs = (maps get_consts o the_ancestry theory) sort;
+    fun read_defs defs cs =
+      let
+        val thy_read = (Sign.primitive_arity (tyco, asorts, sort) o Theory.copy) theory;
+        fun read raw_def cs =
+          let
+            val (c, (ty, def)) = read_def thy_read tyco raw_def;
+            val def' = case AList.lookup (op =) cs c
+             of NONE => error ("superfluous definition for constant " ^ quote c)
+              | SOME ty' => case instantiations_of thy_read (ty, ty')
+                 of NONE => error ("superfluous definition for constant " ^
+                      quote c ^ "::" ^ Sign.string_of_typ thy_read ty)
+                  | SOME insttab =>
+                      (apfst o apsnd o map_term_types)
+                        (Logic.unvarifyT o Term.instantiateT insttab o Logic.varifyT) def
+          in (def', AList.delete (op =) c cs) end;
+      in fold_map read defs cs end;
+    val (defs, _) = read_defs raw_defs cs;
     fun get_remove_contraint c thy =
       let
         val ty = Sign.the_const_constraint thy c;
@@ -465,65 +447,31 @@
         |> Sign.add_const_constraint_i (c, NONE)
         |> pair (c, Logic.legacy_unvarifyT ty)
       end;
-    fun check_defs0 thy raw_defs c_req =
-      let
-        fun get_c raw_def =
-          (fst o Sign.cert_def pp o tap_def thy o snd) raw_def;
-        val c_given = map get_c raw_defs;
-        fun eq_c ((c1 : string, ty1), (c2, ty2)) =
-          let
-            val ty1' = Logic.legacy_varifyT ty1;
-            val ty2' = Logic.legacy_varifyT ty2;
-          in
-            c1 = c2
-            andalso Sign.typ_instance thy (ty1', ty2')
-            andalso Sign.typ_instance thy (ty2', ty1')
-          end;
-        val _ = case subtract eq_c c_req c_given
-         of [] => ()
-          | cs => error ("superfluous definition(s) given for "
-                    ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs);
-        (*val _ = case subtract eq_c c_given c_req
-         of [] => ()
-          | cs => error ("no definition(s) given for "
-                    ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs);*)
-      in () end;
-    fun check_defs1 raw_defs c_req thy =
-      let
-        val thy' = (Sign.primitive_arity (tyco, asorts, sort) o Theory.copy) thy
-      in (check_defs0 thy' raw_defs c_req; thy) end;
-    fun mangle_alldef_name tyco sort =
-      Thm.def_name ((space_implode "_" o map NameSpace.base) sort ^ "_" ^ NameSpace.base tyco);
-    fun note_all tyco sort thms thy =
+    fun note_all thms thy =
       thy
       |> PureThy.note_thmss_i PureThy.internalK [((name, atts), [(thms, [])])]
       |> snd;
     fun after_qed cs thy =
       thy
-      |> fold (fn class =>
-        add_inst_data (class, (tyco,
-          (map (operational_sort_of thy) asorts, Context.theory_name thy)))) sort
       |> fold Sign.add_const_constraint_i (map (apsnd SOME) cs);
   in
     theory
-    |> check_defs1 raw_defs cs
     |> fold_map get_remove_contraint (map fst cs)
-    ||>> add_defs tyco (map (pair NONE) raw_defs)
-    |-> (fn (cs, defnames) => note_all tyco sort defnames #> pair cs)
+    ||>> PureThy.add_defs_i true defs
+    |-> (fn (cs, thms) => note_all thms #> pair cs)
     |-> (fn cs => do_proof (after_qed cs) arity)
   end;
 
-fun instance_arity' do_proof = gen_instance_arity Sign.read_arity Attrib.attribute add_defs_overloaded
-  (fn thy => fn t => (snd o read_axm thy) ("", t)) do_proof;
-fun instance_arity_i' do_proof = gen_instance_arity Sign.cert_arity (K I) add_defs_overloaded_i
-  (K I) do_proof;
-val setup_proof = axclass_instance_arity_i;
+fun instance_arity' do_proof = gen_instance_arity Sign.read_arity Attrib.attribute
+  read_def do_proof;
+fun instance_arity_i' do_proof = gen_instance_arity Sign.cert_arity (K I)
+  read_def_i do_proof;
 fun tactic_proof tac after_qed arity = AxClass.prove_arity arity tac #> after_qed;
 
 in
 
-val instance_arity = instance_arity' setup_proof;
-val instance_arity_i = instance_arity_i' setup_proof;
+val instance_arity = instance_arity' axclass_instance_arity_i;
+val instance_arity_i = instance_arity_i' axclass_instance_arity_i;
 val prove_instance_arity = instance_arity_i' o tactic_proof;
 
 end; (* local *)
@@ -590,6 +538,7 @@
 
 end; (* local *)
 
+
 (* extracting dictionary obligations from types *)
 
 type sortcontext = (string * sort) list;
@@ -635,22 +584,15 @@
 
 fun sortlookups_const thy (c, typ_ctxt) =
   let
-    val typ_decl = case lookup_const_class thy c
+    val typ_decl = case AxClass.class_of thy c
      of NONE => Sign.the_const_type thy c
       | SOME class => case the_consts_sign thy class of (v, cs) =>
           (Logic.legacy_varifyT o subst_clsvar v (TFree (v, [class])))
           ((the o AList.lookup (op =) cs) c)
-    val vartab = typ_tvars typ_decl;
-    fun prep_vartab (v, (_, ty)) =
-      case (the o AList.lookup (op =) vartab) v
-       of [] => NONE
-        | sort => SOME (sort, ty);
   in
-    Vartab.empty
-    |> Sign.typ_match thy (typ_decl, typ_ctxt)
-    |> Vartab.dest
-    |> map_filter prep_vartab
-    |> map (sortlookup thy)
+    instantiations_of thy (typ_decl, typ_ctxt)
+    |> the
+    |> map (fn ((_, sort), ty) => sortlookup thy (sort, ty))
     |> filter_out null
   end;