refinement in instance command
authorhaftmann
Fri, 30 Jun 2006 12:03:36 +0200
changeset 19966 88bbe97ed0b0
parent 19965 75a15223e21f
child 19967 33da452f0abe
refinement in instance command
src/Pure/Tools/class_package.ML
--- a/src/Pure/Tools/class_package.ML	Fri Jun 30 12:03:21 2006 +0200
+++ b/src/Pure/Tools/class_package.ML	Fri Jun 30 12:03:36 2006 +0200
@@ -24,8 +24,6 @@
   val instance_sort_i: class * sort -> theory -> Proof.state
   val prove_instance_sort: tactic -> class * sort -> theory -> theory
 
-  val use_cp_instance: bool ref;
-
   val intern_class: theory -> xstring -> class
   val intern_sort: theory -> sort -> sort
   val extern_class: theory -> class -> xstring
@@ -48,6 +46,9 @@
   val sortcontext_of_typ: theory -> typ -> sortcontext
   val sortlookup: theory -> sort * typ -> classlookup list
   val sortlookups_const: theory -> string * typ -> classlookup list list
+
+  val use_instance2: bool ref;
+  val the_propnames: theory -> class -> string list
 end;
 
 structure ClassPackage: CLASS_PACKAGE =
@@ -76,29 +77,27 @@
   name_locale: string,
   name_axclass: string,
   var: string,
-  consts: (string * (string * typ)) list
+  consts: (string * (string * typ)) list,
     (*locale parameter ~> toplevel constant*)
-};
+  propnames: string list
+} * thm list Symtab.table;
 
 fun rep_classdata (ClassData c) = c;
 
 structure ClassData = TheoryDataFun (
   struct
     val name = "Pure/classes";
-    type T = class_data Graph.T;
-    val empty = Graph.empty;
+    type T = class_data Symtab.table;
+    val empty = Symtab.empty;
     val copy = I;
     val extend = I;
-    fun merge _ = Graph.merge (K true);
-    fun print thy gr =
+    fun merge _ = Symtab.join (fn _ => fn (ClassData (classd, instd1), ClassData (_, instd2)) =>
+      (ClassData (classd, Symtab.merge (K true) (instd1, instd2))));
+    fun print thy data =
       let
-        fun pretty_class gr (name, ClassData {name_locale, name_axclass, var, consts}) =
+        fun pretty_class (name, ClassData ({name_locale, name_axclass, var, consts, ...}, _)) =
           (Pretty.block o Pretty.fbreaks) [
             Pretty.str ("class " ^ name ^ ":"),
-            (Pretty.block o Pretty.fbreaks) (
-              Pretty.str "superclasses: "
-              :: (map Pretty.str o Graph.imm_succs gr) name
-            ),
             Pretty.str ("locale: " ^ name_locale),
             Pretty.str ("axclass: " ^ name_axclass),
             Pretty.str ("class variable: " ^ var),
@@ -108,8 +107,7 @@
             )
           ]
       in
-        (Pretty.writeln o Pretty.chunks o map (pretty_class gr)
-          o AList.make (Graph.get_node gr) o flat o Graph.strong_conn) gr
+        (Pretty.writeln o Pretty.chunks o map pretty_class o Symtab.dest) data
       end;
   end
 );
@@ -120,7 +118,7 @@
 
 (* queries *)
 
-val lookup_class_data = Option.map rep_classdata oo try o Graph.get_node o ClassData.get;
+val lookup_class_data = Option.map rep_classdata oo Symtab.lookup o ClassData.get;
 
 fun the_class_data thy class =
   case lookup_class_data thy class
@@ -131,7 +129,7 @@
 
 fun is_operational_class thy cls =
   lookup_class_data thy cls
-  |> Option.map (not o null o #consts)
+  |> Option.map (not o null o #consts o fst)
   |> the_default false;
 
 fun operational_sort_of thy =
@@ -162,22 +160,18 @@
   map_type_tfree (fn u as (w, _) =>
     if w = v then ty_subst else TFree u);
 
-fun the_parm_map thy class =
-  let
-    val data = the_class_data thy class
-  in (#consts data) end;
+val the_parm_map = #consts o fst oo the_class_data;
 
 fun the_consts_sign thy class =
   let
-    val data = the_class_data thy class
+    val data = (fst o the_class_data thy) class
   in (#var data, (map snd o #consts) data) end;
 
 fun the_inst_sign thy (class, tyco) =
   let
     val _ = if is_operational_class thy class then () else error ("no operational class: " ^ class);
     val asorts = Sign.arity_sorts thy tyco [class];
-    val clsvar = (#var o the_class_data thy) class;
-    val const_sign = (snd o the_consts_sign thy) class;
+    val (clsvar, const_sign) = the_consts_sign thy class;
     fun add_var sort used =
       let
         val v = hd (Term.invent_names used "'a" 1)
@@ -191,19 +185,26 @@
     val inst_signs = map (apsnd (subst_clsvar clsvar ty_inst)) const_sign;
   in (vsorts, inst_signs) end;
 
+val the_propnames = #propnames o fst oo the_class_data;
+
 
 (* updaters *)
 
-fun add_class_data (class, (superclasses, name_locale, name_axclass, var, consts)) =
+fun add_class_data (class, (name_locale, name_axclass, var, consts, propnames)) =
   ClassData.map (
-    Graph.new_node (class, ClassData {
+    Symtab.update_new (class, ClassData ({
       name_locale = name_locale,
       name_axclass = name_axclass,
       var = var,
-      consts = consts })
-    #> fold (curry Graph.add_edge_acyclic class) superclasses
+      consts = consts,
+      propnames = propnames}, Symtab.empty))
   );
 
+fun add_inst_def ((class, tyco), thm) =
+  ClassData.map (
+    Symtab.map_entry class (fn ClassData (classd, instd) =>
+      ClassData (classd, Symtab.insert_list eq_thm (tyco, thm) instd))
+  );
 
 (* name handling *)
 
@@ -262,7 +263,7 @@
 
 in
 
-val axclass_instance_subclass =
+val axclass_instance_sort =
   gen_instance (single oo (Logic.mk_classrel oo AxClass.read_classrel)) AxClass.add_classrel I;
 val axclass_instance_arity =
   gen_instance (Logic.mk_arities oo Sign.read_arity) AxClass.add_arity;
@@ -308,10 +309,10 @@
     val supclasses = map (prep_class thy) raw_supclasses;
     val supsort =
       supclasses
-      |> map (#name_axclass o the_class_data thy)
+      |> map (#name_axclass o fst o the_class_data thy)
       |> Sign.certify_sort thy
       |> null ? K (Sign.defaultS thy);
-    val expr = (Locale.Merge o map (Locale.Locale o #name_locale o the_class_data thy)) supclasses;
+    val expr = (Locale.Merge o map (Locale.Locale o #name_locale o fst o the_class_data thy)) supclasses;
     val mapp_sup = AList.make
       (the o AList.lookup (op =) ((flat o map (the_parm_map thy) o the_ancestry thy) supclasses))
       ((map (fst o fst) o Locale.parameters_of_expr thy) expr);
@@ -369,7 +370,8 @@
           add_axclass_i (bname, supsort) (map (fst o snd) mapp_this) loc_axioms
     #-> (fn (name_axclass, (_, ax_axioms)) =>
           fold (add_global_constraint v name_axclass) mapp_this
-    #> add_class_data (name_locale, (supclasses, name_locale, name_axclass, v, mapp_this))
+    #> add_class_data (name_locale, (name_locale, name_axclass, v, mapp_this,
+         map (fst o fst) loc_axioms))
     #> prove_interpretation_i (NameSpace.base name_locale, [])
           (Locale.Locale name_locale) (map (SOME o mk_const thy name_axclass v) (map snd (mapp_sup @ mapp_this)))
           ((ALLGOALS o ProofContext.fact_tac) ax_axioms)
@@ -410,7 +412,7 @@
     val atts = map (prep_att theory) raw_atts;
     fun get_consts class =
       let
-        val data = the_class_data theory class;
+        val data = (fst o the_class_data theory) class;
         fun defined c =
           is_some (find_first (fn (_, { lhs = [ty], ...}) =>
               Sign.typ_instance theory (ty, ty_inst) orelse Sign.typ_instance theory (ty_inst, ty))
@@ -419,7 +421,7 @@
           if #var data = v then ty_inst else TVar ((v, 0), sort));
       in
         (map_filter (fn (_, (c, ty)) =>
-          if defined c then NONE else SOME (c, subst_ty ty)) o #consts) data
+          if defined c then NONE else SOME ((c, (class, subst_ty ty)))) o #consts) data
       end;
     val cs = (maps get_consts o the_ancestry theory) sort;
     fun read_defs defs cs =
@@ -428,15 +430,16 @@
         fun read raw_def cs =
           let
             val (c, (ty, def)) = read_def thy_read tyco raw_def;
-            val def' = case AList.lookup (op =) cs c
+            val (class, ty') = case AList.lookup (op =) cs c
              of NONE => error ("superfluous definition for constant " ^ quote c)
-              | SOME ty' => case instantiations_of thy_read (ty, ty')
-                 of NONE => error ("superfluous definition for constant " ^
-                      quote c ^ "::" ^ Sign.string_of_typ thy_read ty)
-                  | SOME insttab =>
-                      (apfst o apsnd o map_term_types)
-                        (Logic.unvarifyT o Term.instantiateT insttab o Logic.varifyT) def
-          in (def', AList.delete (op =) c cs) end;
+              | SOME class_ty => class_ty;
+            val def' = case instantiations_of thy_read (ty, ty')
+             of NONE => error ("superfluous definition for constant " ^
+                  quote c ^ "::" ^ Sign.string_of_typ thy_read ty)
+              | SOME insttab =>
+                  (apfst o apsnd o map_term_types)
+                    (Logic.unvarifyT o Term.instantiateT insttab o Logic.varifyT) def
+          in ((class, def'), AList.delete (op =) c cs) end;
       in fold_map read defs cs end;
     val (defs, _) = read_defs raw_defs cs;
     fun get_remove_contraint c thy =
@@ -447,19 +450,33 @@
         |> Sign.add_const_constraint_i (c, NONE)
         |> pair (c, Logic.legacy_unvarifyT ty)
       end;
-    fun note_all thms thy =
+    fun add_defs defs thy =
+      thy
+      |> PureThy.add_defs_i true (map snd defs)
+      |-> (fn thms => pair (map fst defs ~~ thms));
+    fun register_def (class, thm) thy =
       thy
-      |> PureThy.note_thmss_i PureThy.internalK [((name, atts), [(thms, [])])]
-      |> snd;
+      |> add_inst_def ((class, tyco), thm);
+    fun note_all thy =
+      let
+        val thms = maps (fn class => Symtab.lookup_list
+          ((snd o the_class_data thy) class) tyco) (the_ancestry thy sort);
+      in
+        thy
+        |> PureThy.note_thmss_i PureThy.internalK [((name, atts), [(thms, [])])]
+        |> snd
+      end;
     fun after_qed cs thy =
       thy
       |> fold Sign.add_const_constraint_i (map (apsnd SOME) cs);
   in
     theory
     |> fold_map get_remove_contraint (map fst cs)
-    ||>> PureThy.add_defs_i true defs
-    |-> (fn (cs, thms) => note_all thms #> pair cs)
-    |-> (fn cs => do_proof (after_qed cs) arity)
+    ||>> add_defs defs
+    |-> (fn (cs, def_thms) => 
+       fold register_def def_thms
+    #> note_all
+    #> do_proof (after_qed cs) arity)
   end;
 
 fun instance_arity' do_proof = gen_instance_arity Sign.read_arity Attrib.attribute
@@ -478,48 +495,50 @@
 
 local
 
-fun fish_thms (name, expr) after_qed thy =
-  let
-    val _ = writeln ("sub " ^ name)
-    val suplocales = (fn Locale.Merge es => map (fn Locale.Locale n => n) es) expr;
-    val _ = writeln ("super " ^ commas suplocales)
-    fun get_c name =
-      (map (NameSpace.base o fst o fst) o Locale.parameters_of thy) name;
-    fun get_a name =
-      (map (NameSpace.base o fst o fst) o Locale.local_asms_of thy) name;
-    fun get_t supname =
-      map (NameSpace.append (NameSpace.append name ((space_implode "_" o get_c) supname)) o NameSpace.base)
-        (get_a name);
-    val names = map get_t suplocales;
-    val _ = writeln ("fishing for " ^ (commas o map commas) names);
-  in
-    thy
-    |> after_qed ((map o map) (Drule.standard o get_thm thy o Name) names)
-  end;
-
-fun add_interpretation_in (after_qed : thm list list -> theory -> theory) (name, expr) thy =
+fun add_interpretation_in (after_qed : theory -> theory) (name, expr) thy =
   thy
   |> Locale.interpretation_in_locale (name, expr);
 
-fun prove_interpretation_in tac (after_qed : thm list list -> theory -> theory) (name, expr) thy =
+fun prove_interpretation_in tac (after_qed : theory -> theory) (name, expr) thy =
   thy
   |> Locale.interpretation_in_locale (name, expr)
   |> Proof.global_terminal_proof (Method.Basic (fn _ => Method.SIMPLE_METHOD tac), NONE)
-  |-> (fn _ => I);
+  |> snd
+  |> after_qed;
 
 fun gen_instance_sort prep_class prep_sort do_proof (raw_class, raw_sort) theory =
   let
     val class = prep_class theory raw_class;
     val sort = prep_sort theory raw_sort;
-    val loc_name = (#name_locale o the_class_data theory) class;
+    val loc_name = (#name_locale o fst o the_class_data theory) class;
     val loc_expr =
-      (Locale.Merge o map (Locale.Locale o #name_locale o the_class_data theory)) sort;
-    fun after_qed thmss thy =
-      (writeln "---"; (Pretty.writeln o Display.pretty_thms o flat) thmss; writeln "---"; fold (fn supclass =>
+      (Locale.Merge o map (Locale.Locale o #name_locale o fst o the_class_data theory)) sort;
+    val const_names = (map (NameSpace.base o fst o snd)
+      o maps (#consts o fst o the_class_data theory)
+      o the_ancestry theory) [class];
+    val prop_tab = AList.make (the_propnames theory)
+      (the_ancestry theory sort);
+    fun mk_thm_names (superclass, prop_names) =
+      let
+        val thm_name_base = NameSpace.append "local" (space_implode "_" const_names);
+        val export_name = class ^ "_" ^ superclass;
+      in (export_name, map (Name o NameSpace.append thm_name_base) prop_names) end;
+    val notes_tab_proto = map mk_thm_names prop_tab;
+    fun test_note thy thmref =
+      can (Locale.note_thmss PureThy.corollaryK loc_name 
+        [(("", []), [(thmref, [])])]) (Theory.copy thy);
+    val notes_tab = map_filter (fn (export_name, thm_names) => case filter (test_note theory) thm_names
+     of [] => NONE
+      | thm_names' => SOME (export_name, thm_names')) notes_tab_proto;
+    val _ = writeln ("fishing for ");
+    val _ = print notes_tab;
+    fun after_qed thy = thy;
+    fun after_qed''' thy =
+      fold (fn supclass =>
         AxClass.prove_classrel (class, supclass)
           (ALLGOALS (K (intro_classes_tac [])) THEN
-            (ALLGOALS o resolve_tac o flat) thmss)
-      ) sort thy)
+            (ALLGOALS o resolve_tac o flat) [])
+      ) sort thy;
   in
     theory
     |> do_proof after_qed (loc_name, loc_expr)
@@ -608,19 +627,15 @@
 
 val (classK, instanceK) = ("class", "instance")
 
-val use_cp_instance = ref false;
+val use_instance2 = ref false;
 
-fun wrap_add_instance_subclass (class, sort) thy =
-  case Sign.read_sort thy sort
-   of [class'] =>
-      if ! use_cp_instance
-        andalso (is_some o lookup_class_data thy o Sign.intern_class thy) class
-        andalso (is_some o lookup_class_data thy o Sign.intern_class thy) class'
-      then
-        instance_sort (class, sort) thy
-      else
-        axclass_instance_subclass (class, sort) thy
-    | _ => instance_sort (class, sort) thy;
+fun wrap_add_instance_sort (class, sort) thy =
+  if ! use_instance2
+    andalso forall (is_some o lookup_class_data thy) (Sign.read_sort thy sort)
+  then
+    instance_sort (class, sort) thy
+  else
+    axclass_instance_sort (class, sort) thy
 
 val parse_inst =
   (Scan.optional (P.$$$ "(" |-- P.!!! (P.list1 P.sort --| P.$$$ ")")) [] -- P.xname --| P.$$$ "::" -- P.sort)
@@ -648,7 +663,7 @@
 
 val instanceP =
   OuterSyntax.command instanceK "prove type arity or subclass relation" K.thy_goal ((
-      P.xname -- ((P.$$$ "\\<subseteq>" || P.$$$ "<") |-- P.!!! P.xname) >> wrap_add_instance_subclass
+      P.xname -- ((P.$$$ "\\<subseteq>" || P.$$$ "<") |-- P.!!! P.xname) >> wrap_add_instance_sort
       || P.opt_thm_name ":" -- (parse_inst -- Scan.repeat (P.opt_thm_name ":" -- P.prop))
            >> (fn (("", []), (((tyco, asorts), sort), [])) => axclass_instance_arity I (tyco, asorts, sort)
                 | (natts, (inst, defs)) => instance_arity inst natts defs)