src/Pure/Tools/class_package.ML
changeset 19280 5091dc43817b
parent 19253 f3ce97b5661a
child 19282 89949d8652c3
--- a/src/Pure/Tools/class_package.ML	Fri Mar 17 10:04:27 2006 +0100
+++ b/src/Pure/Tools/class_package.ML	Fri Mar 17 14:19:24 2006 +0100
@@ -65,7 +65,8 @@
   name_axclass: string,
   intro: thm option,
   var: string,
-  consts: (string * typ) list
+  consts: (string * (string * typ)) list
+    (*locale parameter ~> toplevel const*)
 };
 
 structure ClassData = TheoryDataFun (
@@ -95,7 +96,7 @@
             Pretty.str ("class variable: " ^ var),
             (Pretty.block o Pretty.fbreaks) (
               Pretty.str "constants: "
-              :: map (fn (c, ty) => Pretty.str (c ^ " :: " ^ Sign.string_of_typ thy ty)) consts
+              :: map (fn (_, (c, ty)) => Pretty.str (c ^ " :: " ^ Sign.string_of_typ thy ty)) consts
             )
           ]
       in
@@ -120,7 +121,9 @@
     of NONE => error ("undeclared operational class " ^ quote class)
      | SOME data => data;
 
-fun is_class thy cls =
+val is_class = is_some oo lookup_class_data;
+
+fun is_operational_class thy cls =
   lookup_class_data thy cls
   |> Option.map (not o null o #consts)
   |> the_default false;
@@ -129,7 +132,7 @@
   let
     val classes = Sign.classes_of thy;
     fun get_sort class =
-      if is_class thy class
+      if is_operational_class thy class
       then [class]
       else operational_sort_of thy (Sorts.superclasses classes class);
   in
@@ -144,14 +147,14 @@
     Sorts.superclasses (Sign.classes_of thy) class
     |> operational_sort_of thy
   else
-    error ("no syntactic class: " ^ class);
+    error ("no class: " ^ class);
 
 fun get_superclass_derivation thy (subclass, superclass) =
   if subclass = superclass
     then SOME [subclass]
     else case Graph.find_paths ((fst o fst o ClassData.get) thy) (subclass, superclass)
       of [] => NONE
-       | (p::_) => (SOME o filter (is_class thy)) p;
+       | (p::_) => (SOME o filter (is_operational_class thy)) p;
 
 fun the_ancestry thy classes =
   let
@@ -170,14 +173,19 @@
   map_type_tfree (fn u as (w, _) =>
     if w = v then ty_subst else TFree u);
 
+fun the_parm_map thy class =
+  let
+    val data = the_class_data thy class
+  in (#consts data) end;
+
 fun the_consts_sign thy class =
   let
     val data = the_class_data thy class
-  in (#var data, #consts data) end;
+  in (#var data, (map snd o #consts) data) end;
 
 fun the_inst_sign thy (class, tyco) =
   let
-    val _ = if is_class thy class then () else error ("no syntactic class: " ^ class);
+    val _ = if is_operational_class thy class then () else error ("no operational class: " ^ class);
     val arity =
       Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class];
     val clsvar = (#var o the_class_data thy) class;
@@ -216,13 +224,16 @@
     tab
     |> Symtab.update (class, [])),
     consttab
-    |> fold (fn (c, _) => Symtab.update (c, class)) consts
+    |> fold (fn (_, (c, _)) => Symtab.update (c, class)) consts
   ));
 
 fun add_inst_data (class, inst) =
   ClassData.map (fn ((gr, tab), consttab) =>
-     ((gr, tab |>
-    (Symtab.map_entry class (AList.update (op =) inst))), consttab));
+    let
+      val undef_supclasses = class :: (filter (Symtab.defined tab) (Graph.all_succs gr [class]));
+    in
+     ((gr, tab |> fold (fn class => Symtab.map_entry class (AList.update (op =) inst)) undef_supclasses), consttab)
+    end);
 
 
 (* name handling *)
@@ -234,7 +245,7 @@
   map (fn class => (the_class_data thy class; class)) (Sign.certify_sort thy sort);
 
 fun intern_class thy =
-  certify_class thy o Sign.intern_class thy;
+certify_class thy o Sign.intern_class thy;
 
 fun intern_sort thy =
   certify_sort thy o Sign.intern_sort thy;
@@ -356,18 +367,19 @@
       |> map (#name_axclass o the_class_data thy)
       |> Sorts.certify_sort (Sign.classes_of thy)
       |> null ? K (Sign.defaultS thy);
-    val supcs = (Library.flat o map (snd o the_consts_sign thy) o the_ancestry thy)
-      supclasses;
     val expr = if null supclasses
       then Locale.empty
       else
        (Locale.Merge o map (Locale.Locale o #name_locale o the_class_data thy)) supclasses;
+    val mapp_sup = AList.make
+      (the o AList.lookup (op =) ((Library.flat o map (the_parm_map thy) o the_ancestry thy) supclasses))
+      ((map (fst o fst) o Locale.parameters_of_expr thy) expr);
     fun extract_tyvar_consts thy name_locale =
       let
         fun extract_tyvar_name thy tys =
           fold (curry add_typ_tfrees) tys []
           |> (fn [(v, sort)] =>
-                    if Sorts.sort_le (Sign.classes_of thy) (swap (sort, supsort))
+              if Sorts.sort_le (Sign.classes_of thy) (swap (sort, supsort))
                     then v
                     else error ("illegal sort constraint on class type variable: " ^ Sign.string_of_sort thy sort)
                | [] => error ("no class type variable")
@@ -377,10 +389,9 @@
           |> map (apsnd Syntax.unlocalize_mixfix)
         val v = (extract_tyvar_name thy o map (snd o fst)) consts1;
         val consts2 = map ((apfst o apsnd) (subst_clsvar v (TFree (v, [])))) consts1;
-      in (v, chop (length supcs) consts2) end;
+      in (v, chop (length mapp_sup) consts2) end;
     fun add_consts v raw_cs_sup raw_cs_this thy =
       let
-        val mapp_sub = map2 (fn ((c, _), _) => pair c) raw_cs_sup supcs
         fun add_global_const ((c, ty), syn) thy =
           thy
           |> Sign.add_consts_i [(c, ty |> subst_clsvar v (TFree (v, Sign.defaultS thy)), syn)]
@@ -388,7 +399,6 @@
       in
         thy
         |> fold_map add_global_const raw_cs_this
-        |-> (fn mapp_this => pair (mapp_sub @ mapp_this, map snd mapp_this))
       end;
     fun extract_assumes thy name_locale cs_mapp =
       let
@@ -400,7 +410,7 @@
       in
         (map prep_asm o Locale.local_asms_of thy) name_locale
       end;
-    fun add_global_constraint v class (c, ty) thy =
+    fun add_global_constraint v class (_, (c, ty)) thy =
       thy
       |> Sign.add_const_constraint_i (c, SOME (subst_clsvar v (TFree (v, [class])) ty));
     fun mk_const thy class v (c, ty) =
@@ -412,15 +422,15 @@
           `(fn thy => extract_tyvar_consts thy name_locale)
     #-> (fn (v, (raw_cs_sup, raw_cs_this)) =>
           add_consts v raw_cs_sup raw_cs_this
-    #-> (fn (cs_map, cs_this) =>
-          `(fn thy => extract_assumes thy name_locale cs_map)
+    #-> (fn mapp_this =>
+          `(fn thy => extract_assumes thy name_locale (mapp_sup @ mapp_this))
     #-> (fn loc_axioms =>
           add_axclass_i (bname, supsort) loc_axioms
     #-> (fn (name_axclass, (_, ax_axioms)) =>
-          fold (add_global_constraint v name_axclass) cs_this
-    #> add_class_data (name_locale, (supclasses, name_locale, name_axclass, intro, v, cs_this))
+          fold (add_global_constraint v name_axclass) mapp_this
+    #> add_class_data (name_locale, (supclasses, name_locale, name_axclass, intro, v, mapp_this))
     #> prove_interpretation_i (NameSpace.base name_locale, [])
-          (Locale.Locale name_locale) (map (SOME o mk_const thy name_axclass v) (supcs @ cs_this))
+          (Locale.Locale name_locale) (map (SOME o mk_const thy name_axclass v) (map snd (mapp_sup @ mapp_this)))
           ((ALLGOALS o resolve_tac) ax_axioms)
     #> pair ctxt
     )))))
@@ -490,7 +500,7 @@
         val data = the_class_data theory class;
         val subst_ty = map_type_tfree (fn (var as (v, _)) =>
           if #var data = v then ty_inst else TFree var)
-      in (map (apsnd subst_ty) o #consts) data end;
+      in (map (apsnd subst_ty o snd) o #consts) data end;
     val cs = (Library.flat o map get_consts) classes;
     fun get_remove_contraint c thy =
       let
@@ -570,7 +580,7 @@
     val _ = writeln ("sub " ^ name)
     val suplocales = (fn Locale.Merge es => map (fn Locale.Locale n => n) es) expr;
     val _ = writeln ("super " ^ commas suplocales)
-    fun get_c name = 
+    fun get_c name =
       (map (NameSpace.base o fst o fst) o Locale.parameters_of thy) name;
     fun get_a name =
       (map (NameSpace.base o fst o fst) o Locale.local_asms_of thy) name;
@@ -663,7 +673,7 @@
     fun mk_lookup (sort_def, (Type (tyco, tys))) =
           map (fn class => Instance ((class, tyco),
             map2 (curry mk_lookup)
-              ((fst o the o AList.lookup (op =) (the_instances thy class)) tyco)
+              (map (operational_sort_of thy) (Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class]))
               tys)
           ) sort_def
       | mk_lookup (sort_def, TVar ((vname, _), sort_use)) =
@@ -673,7 +683,7 @@
               in Lookup (deriv, (vname, classindex)) end;
           in map mk_look sort_def end;
   in
-    sortctxt
+ sortctxt
     |> map (tab_lookup o fst)
     |> map (apfst (operational_sort_of thy))
     |> filter (not o null o fst)
@@ -690,7 +700,7 @@
         | SOME class =>
             let
               val data = the_class_data thy class;
-              val sign = (Type.varifyT o the o AList.lookup (op =) (#consts data)) c;
+              val sign = (Type.varifyT o the o AList.lookup (op =) ((map snd o #consts) data)) c;
               val match_tab = Sign.typ_match thy (sign, typ_def) Vartab.empty;
               val v : string = case Vartab.lookup match_tab (#var data, 0)
                 of SOME (_, TVar ((v, _), _)) => v;
@@ -751,13 +761,18 @@
     Scan.optional (P.$$$ "+" |-- P.!!! (Scan.repeat1 P.context_element)) [] ||
   Scan.repeat1 P.context_element >> pair Locale.empty);
 
+val class_subP = P.name -- Scan.repeat (P.$$$ "+" |-- P.name) >> (op ::);
+val class_bodyP = P.!!! (Scan.repeat1 P.context_element);
+
 val classP =
   OuterSyntax.command classK "operational type classes" K.thy_decl (
     P.name --| P.$$$ "="
-    -- Scan.optional (Scan.repeat1 (P.name --| P.$$$ "+")) []
-    -- Scan.optional (P.!!! (Scan.repeat1 P.context_element)) []
-      >> (Toplevel.theory_context
-          o (fn ((bname, supclasses), elems) => class bname supclasses elems)));
+    -- (
+      class_subP --| P.$$$ "+" -- class_bodyP
+      || class_subP >> rpair []
+      || class_bodyP >> pair []
+    ) >> (Toplevel.theory_context
+          o (fn (bname, (supclasses, elems)) => class bname supclasses elems)));
 
 val instanceP =
   OuterSyntax.command instanceK "prove type arity or subclass relation" K.thy_goal ((