src/Pure/Isar/class.ML
changeset 25038 522abf8a5f87
parent 25024 0615bb9955dd
child 25060 17c313217998
--- a/src/Pure/Isar/class.ML	Mon Oct 15 15:29:41 2007 +0200
+++ b/src/Pure/Isar/class.ML	Mon Oct 15 15:29:43 2007 +0200
@@ -19,11 +19,12 @@
   val class_cmd: bstring -> xstring list -> Element.context Locale.element list
     -> xstring list -> theory -> string * Proof.context
   val init: class -> Proof.context -> Proof.context;
-  val add_const_in_class: string -> (string * mixfix) * term -> theory -> string * theory
-  val add_abbrev_in_class: string -> Syntax.mode -> (string * mixfix) * term -> theory ->
-    string * theory
+  val add_const_in_class: string -> (string * mixfix) * (string * term)
+    -> theory -> string * theory
+  val add_abbrev_in_class: string -> Syntax.mode -> (string * mixfix) * (string * term)
+    -> theory -> string * theory
   val remove_constraint: class -> string -> Proof.context -> Proof.context
-  val is_class: theory -> string -> bool
+  val is_class: theory -> class -> bool
   val these_params: theory -> sort -> (string * (string * typ)) list
   val intro_classes_tac: thm list -> tactic
   val default_intro_classes_tac: thm list -> tactic
@@ -83,6 +84,15 @@
       | NONE => thm;
   in strip end;
 
+fun get_remove_global_constraint c thy =
+  let
+    val ty = Sign.the_const_constraint thy c;
+  in
+    thy
+    |> Sign.add_const_constraint (c, NONE)
+    |> pair (c, Logic.unvarifyT ty)
+  end;
+
 
 (** axclass command **)
 
@@ -232,14 +242,6 @@
      of [] => ()
       | dupl_tycos => error ("Type constructors occur more than once in arities: "
           ^ commas_quote dupl_tycos);
-    fun get_remove_constraint c thy =
-      let
-        val ty = Sign.the_const_constraint thy c;
-      in
-        thy
-        |> Sign.add_const_constraint (c, NONE)
-        |> pair (c, Logic.unvarifyT ty)
-      end;
     fun get_consts_class tyco ty class =
       let
         val cs = (these o try (#params o AxClass.get_info theory)) class;
@@ -283,7 +285,7 @@
       #> after_qed defs;
   in
     theory
-    |> fold_map get_remove_constraint (map fst cs |> distinct (op =))
+    |> fold_map get_remove_global_constraint (map fst cs |> distinct (op =))
     ||>> fold_map add_def defs
     ||> fold (fn (c, ((class, tyco), ty)) => add_inst_def' (class, tyco) (c, ty)) other_cs
     |-> (fn (cs, defs) => do_proof (after_qed' cs defs) arities defs)
@@ -311,12 +313,11 @@
 (** class data **)
 
 datatype class_data = ClassData of {
-  locale: string,
   consts: (string * string) list
     (*locale parameter ~> constant name*),
   local_sort: sort,
-  inst: typ Symtab.table * term Symtab.table
-    (*canonical interpretation*),
+  inst: (typ option list * term option list) * term Symtab.table
+    (*canonical interpretation FIXME*),
   intro: thm,
   defs: thm list,
   operations: (string * (term * (typ * int))) list
@@ -326,50 +327,46 @@
 };
 
 fun rep_class_data (ClassData d) = d;
-fun mk_class_data ((locale, consts, local_sort, inst, intro),
+fun mk_class_data ((consts, local_sort, inst, intro),
     (defs, operations, operations_rev)) =
-  ClassData { locale = locale, consts = consts, local_sort = local_sort, inst = inst,
+  ClassData { consts = consts, local_sort = local_sort, inst = inst,
     intro = intro, defs = defs, operations = operations,
     operations_rev = operations_rev };
-fun map_class_data f (ClassData { locale, consts, local_sort, inst, intro,
+fun map_class_data f (ClassData { consts, local_sort, inst, intro,
     defs, operations, operations_rev }) =
-  mk_class_data (f ((locale, consts, local_sort, inst, intro),
+  mk_class_data (f ((consts, local_sort, inst, intro),
     (defs, operations, operations_rev)));
-fun merge_class_data _ (ClassData { locale = locale, consts = consts,
+fun merge_class_data _ (ClassData { consts = consts,
     local_sort = local_sort, inst = inst, intro = intro,
     defs = defs1, operations = operations1, operations_rev = operations_rev1 },
-  ClassData { locale = _, consts = _, local_sort = _, inst = _, intro = _,
+  ClassData { consts = _, local_sort = _, inst = _, intro = _,
     defs = defs2, operations = operations2, operations_rev = operations_rev2 }) =
-  mk_class_data ((locale, consts, local_sort, inst, intro),
+  mk_class_data ((consts, local_sort, inst, intro),
     (Thm.merge_thms (defs1, defs2),
       AList.merge (op =) (K true) (operations1, operations2),
       AList.merge (op =) (K true) (operations_rev1, operations_rev2)));
 
-fun merge_pair f1 f2 ((x1, y1), (x2, y2)) = (f1 (x1, x2), f2 (y1, y2));
-
 structure ClassData = TheoryDataFun
 (
-  type T = class_data Graph.T * class Symtab.table
-    (*locale name ~> class name*);
-  val empty = (Graph.empty, Symtab.empty);
+  type T = class_data Graph.T
+  val empty = Graph.empty;
   val copy = I;
   val extend = I;
-  fun merge _ = merge_pair (Graph.join merge_class_data) (Symtab.merge (K true));
+  fun merge _ = Graph.join merge_class_data;
 );
 
 
 (* queries *)
 
-val is_class = Symtab.defined o snd o ClassData.get;
-
-val lookup_class_data = Option.map rep_class_data oo try o Graph.get_node
-  o fst o ClassData.get;
+val lookup_class_data = Option.map rep_class_data oo try o Graph.get_node o ClassData.get;
 
 fun the_class_data thy class = case lookup_class_data thy class
  of NONE => error ("Undeclared class " ^ quote class)
   | SOME data => data;
 
-val ancestry = Graph.all_succs o fst o ClassData.get;
+val is_class = is_some oo lookup_class_data;
+
+val ancestry = Graph.all_succs o ClassData.get;
 
 fun these_params thy =
   let
@@ -386,7 +383,7 @@
 
 fun these_intros thy =
   Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o rep_class_data) data))
-    ((fst o ClassData.get) thy) [];
+    (ClassData.get thy) [];
 
 fun these_operations thy =
   maps (#operations o the_class_data thy) o ancestry thy;
@@ -428,7 +425,7 @@
       (SOME o Pretty.str) ("class " ^ class ^ ":"),
       (SOME o Pretty.block) [Pretty.str "supersort: ",
         (Syntax.pretty_sort ctxt o Sign.minimize_sort thy o Sign.super_classes thy) class],
-      Option.map (Pretty.str o prefix "locale: " o #locale) (lookup_class_data thy class),
+      if is_class thy class then (SOME o Pretty.str) ("locale: " ^ class) else NONE,
       ((fn [] => NONE | ps => (SOME o Pretty.block o Pretty.fbreaks) (Pretty.str "parameters:" :: ps)) o map mk_param
         o these o Option.map #params o try (AxClass.get_info thy)) class,
       (SOME o Pretty.block o Pretty.breaks) [
@@ -444,21 +441,29 @@
 
 (* updaters *)
 
-fun add_class_data ((class, superclasses), (locale, cs, local_sort, inst, intro)) =
+fun add_class_data ((class, superclasses), (cs, local_sort, inst, intro)) thy =
   let
+    (*FIXME*)
+    val is_empty = null (fold (fn ((_, ty), _) => fold_atyps cons ty) cs [])
+      andalso null ((fold o fold_types o fold_atyps) cons
+        (maps snd (Locale.global_asms_of thy class)) []);
+    (*FIXME*)
+    val inst_params = map
+      (SOME o the o Symtab.lookup inst o fst o fst)
+        (Locale.parameters_of_expr thy (Locale.Locale class));
+    val instT = if is_empty then [] else [SOME (TFree (Name.aT, [class]))];
+    val inst' = ((instT, inst_params), inst);
     val operations = map (fn (v_ty, (c, ty)) => (c, (Free v_ty, (ty, 0)))) cs;
     val cs = (map o pairself) fst cs;
-    val add_class = Graph.new_node (class, mk_class_data ((locale,
-          cs, local_sort, inst, intro),
-            ([], operations, [])))
+    val add_class = Graph.new_node (class,
+        mk_class_data ((cs, local_sort, inst', intro), ([], operations, [])))
       #> fold (curry Graph.add_edge class) superclasses;
-    val add_locale = Symtab.update (locale, class);
   in
-    ClassData.map (fn (gr, tab) => (add_class gr, add_locale tab))
+    ClassData.map add_class thy
   end;
 
 fun register_operation class (entry, some_def) =
-  (ClassData.map o apfst o Graph.map_node class o map_class_data o apsnd)
+  (ClassData.map o Graph.map_node class o map_class_data o apsnd)
     (fn (defs, operations, operations_rev) =>
       (case some_def of NONE => defs | SOME def => def :: defs,
         entry :: operations, (*FIXME*)operations_rev));
@@ -468,13 +473,13 @@
 
 val class_prefix = Logic.const_of_class o Sign.base_name;
 
-fun class_intro thy locale class sups =
+fun class_intro thy class sups =
   let
     fun class_elim class =
       case (#axioms o AxClass.get_info thy) class
        of [thm] => SOME (Drule.unconstrainTs thm)
         | [] => NONE;
-    val pred_intro = case Locale.intros thy locale
+    val pred_intro = case Locale.intros thy class
      of ([ax_intro], [intro]) => intro |> OF_LAST ax_intro |> SOME
       | ([intro], []) => SOME intro
       | ([], [intro]) => SOME intro
@@ -499,30 +504,17 @@
 
 fun class_interpretation class facts defs thy =
   let
-    val inst = #inst (the_class_data thy class);
+    val params = these_params thy [class];
+    val { inst = ((_, inst), _), ... } = the_class_data thy class;
+    (*val _ = tracing ("interpreting with " ^ cat_lines (map (setmp show_sorts true makestring)
+      (map_filter I inst)));*)
     val tac = ALLGOALS (ProofContext.fact_tac facts);
+    val prfx = class_prefix class;
   in
-    prove_interpretation tac ((false, class_prefix class), []) (Locale.Locale class)
-      (inst, defs) thy
-  end;
-
-fun interpretation_in_rule thy (class1, class2) =
-  let
-    val ctxt = ProofContext.init thy;
-    fun mk_axioms class =
-      let
-        val { locale, inst = (_, insttab), ... } = the_class_data thy class;
-      in
-        Locale.global_asms_of thy locale
-        |> maps snd
-        |> (map o map_aterms) (fn Free (s, _) => (the o Symtab.lookup insttab) s | t => t)
-        |> (map o map_types o map_atyps) (fn TFree _ => TFree (Name.aT, [class1]) | T => T)
-        |> map (ObjectLogic.ensure_propT thy)
-      end;
-    val (prems, concls) = pairself mk_axioms (class1, class2);
-  in
-    Goal.prove_global thy [] prems (Logic.mk_conjunction_list concls)
-      (Locale.intro_locales_tac true ctxt)
+    thy
+    |> fold_map (get_remove_global_constraint o fst o snd) params
+    ||> prove_interpretation tac ((false, prfx), []) (Locale.Locale class) (inst, defs)
+    |-> (fn cs => fold (Sign.add_const_constraint o apsnd SOME) cs)
   end;
 
 fun intro_classes_tac facts st =
@@ -643,7 +635,7 @@
     val supclasses = map (prep_class thy) raw_supclasses;
     val (sups, local_sort) = sups_local_sort thy supclasses;
     val supsort = Sign.minimize_sort thy supclasses;
-    val suplocales = map (Locale.Locale o #locale o the_class_data thy) sups;
+    val suplocales = map Locale.Locale sups;
     val (raw_elems, includes) = fold_rev (fn Locale.Elem e => apfst (cons e)
       | Locale.Expr i => apsnd (cons (prep_expr thy i))) raw_includes_elems ([], []);
     val supexpr = Locale.Merge suplocales;
@@ -695,18 +687,17 @@
 fun gen_class prep_spec prep_param bname
     raw_supclasses raw_includes_elems raw_other_consts thy =
   let
+    val class = Sign.full_name thy bname;
     val (((sups, supconsts), (supsort, local_sort, mergeexpr)), elems) =
       prep_spec thy raw_supclasses raw_includes_elems;
     val other_consts = map (tap (Sign.the_const_type thy) o prep_param thy) raw_other_consts;
-    fun mk_instT class = Symtab.empty
-      |> Symtab.update (Name.aT, TFree (Name.aT, [class]));
     fun mk_inst class param_names cs =
       Symtab.empty
       |> fold2 (fn v => fn (c, ty) => Symtab.update (v, Const
            (c, Term.map_type_tfree (fn (v, _) => TFree (v, [class])) ty))) param_names cs;
-    fun extract_params thy name_locale =
+    fun extract_params thy =
       let
-        val params = Locale.parameters_of thy name_locale;
+        val params = Locale.parameters_of thy class;
         val _ = if Sign.subsort thy (supsort, local_sort) then () else error
           ("Sort " ^ Sign.string_of_sort thy local_sort
             ^ " is less general than permitted least general sort "
@@ -718,7 +709,7 @@
         |> chop (length supconsts)
         |> snd)
       end;
-    fun extract_assumes name_locale params thy cs =
+    fun extract_assumes params thy cs =
       let
         val consts = supconsts @ (map (fst o fst) params ~~ cs);
         fun subst (Free (c, ty)) =
@@ -728,32 +719,32 @@
           ((Sign.base_name name, map (Attrib.attribute_i thy) atts),
             (map o map_aterms) subst ts);
       in
-        Locale.global_asms_of thy name_locale
+        Locale.global_asms_of thy class
         |> map prep_asm
       end;
-    fun note_intro name_axclass class_intro =
-      PureThy.note_thmss_qualified "" (class_prefix name_axclass)
+    fun note_intro class_intro =
+      PureThy.note_thmss_qualified "" (class_prefix class)
         [(("intro", []), [([class_intro], [])])]
       #> snd
   in
     thy
     |> Locale.add_locale_i (SOME "") bname mergeexpr elems
-    |-> (fn name_locale => ProofContext.theory_result (
-      `(fn thy => extract_params thy name_locale)
+    |> snd
+    |> ProofContext.theory (`extract_params
       #-> (fn (globals, params) =>
         define_class_params (bname, supsort) params
-          (extract_assumes name_locale params) other_consts
-      #-> (fn (name_axclass, (consts, axioms)) =>
-        `(fn thy => class_intro thy name_locale name_axclass sups)
+          (extract_assumes params) other_consts
+      #-> (fn (_, (consts, axioms)) =>
+        `(fn thy => class_intro thy class sups)
       #-> (fn class_intro =>
-        add_class_data ((name_axclass, sups),
-          (name_locale, map fst params ~~ consts, local_sort,
-            (mk_instT name_axclass, mk_inst name_axclass (map fst globals)
-              (map snd supconsts @ consts)), class_intro))
-      #> note_intro name_axclass class_intro
-      #> class_interpretation name_axclass axioms []
-      #> pair name_axclass
-      )))))
+        add_class_data ((class, sups),
+          (map fst params ~~ consts, local_sort,
+            mk_inst class (map fst globals) (map snd supconsts @ consts),
+              class_intro))
+      #> note_intro class_intro
+      #> class_interpretation class axioms []
+      ))))
+    |> pair class
   end;
 
 in
@@ -782,8 +773,9 @@
     val typidx = find_index (fn TFree (w, _) => Name.aT = w | _ => false) typargs;
   in (c, (rhs, (ty, typidx))) end;
 
-fun add_const_in_class class ((c, mx), rhs) thy =
+fun add_const_in_class class ((c, mx), (c_loc, rhs)) thy =
   let
+    val _ = tracing c_loc;
     val prfx = class_prefix class;
     val thy' = thy |> Sign.add_path prfx;
 
@@ -793,6 +785,7 @@
     val ty' = Term.fastype_of rhs';
     val ty'' = subst_typ ty';
     val c' = Sign.full_name thy' c;
+    val c'' = NameSpace.full (Sign.naming_of thy' |> NameSpace.add_path prfx) c;
     val def = (c, Logic.mk_equals (Const (c', ty'), rhs'));
     val (mx', _) = fork_mixfix true true mx;
     fun interpret def thy =
@@ -808,6 +801,7 @@
       end;
   in
     thy'
+    |> Sign.hide_consts_i false [c'']
     |> Sign.declare_const [] (c, ty', mx') |> snd
     |> Sign.parent_path
     |> Sign.sticky_prefix prfx
@@ -820,10 +814,12 @@
 
 (* abbreviation in class target *)
 
-fun add_abbrev_in_class class prmode ((c, mx), rhs) thy =
+fun add_abbrev_in_class class prmode ((c, mx), (c_loc, rhs)) thy =
   let
+    val _ = tracing c_loc;
     val prfx = class_prefix class;
-    val naming = Sign.naming_of thy |> NameSpace.add_path prfx |> NameSpace.add_path prfx; (* FIXME !? *)
+    val naming = Sign.naming_of thy |> NameSpace.add_path prfx |> NameSpace.add_path prfx;
+      (*FIXME*)
     val c' = NameSpace.full naming c;
     val rhs' = export_fixes thy class rhs;
     val ty' = Term.fastype_of rhs';