completed class parameter handling in axclass.ML
authorhaftmann
Wed, 22 Nov 2006 10:22:04 +0100
changeset 21463 42dd50268c8b
parent 21462 74ddf3a522f8
child 21464 abaf43b011ee
completed class parameter handling in axclass.ML
src/Pure/Tools/class_package.ML
src/Pure/Tools/codegen_consts.ML
src/Pure/Tools/codegen_funcgr.ML
src/Pure/Tools/codegen_package.ML
src/Pure/Tools/codegen_serializer.ML
src/Pure/axclass.ML
--- a/src/Pure/Tools/class_package.ML	Wed Nov 22 10:21:17 2006 +0100
+++ b/src/Pure/Tools/class_package.ML	Wed Nov 22 10:22:04 2006 +0100
@@ -28,9 +28,6 @@
   val certify_sort: theory -> sort -> sort
   val read_class: theory -> xstring -> class
   val read_sort: theory -> string -> sort
-  val operational_algebra: theory -> (sort -> sort) * Sorts.algebra
-  val the_consts_sign: theory -> class -> string * (string * typ) list
-  val the_inst_sign: theory -> class * string -> (string * sort) list * (string * typ) list
   val assume_arities_of_sort: theory -> ((string * sort list) * sort) list -> typ * sort -> bool
   val assume_arities_thy: theory -> ((string * sort list) * sort) list -> (theory -> 'a) -> 'a
     (*'a must not keep any reference to theory*)
@@ -52,8 +49,6 @@
   var: string,
   consts: (string * (string * typ)) list
     (*locale parameter ~> toplevel theory constant*),
-  operational: bool (* == at least one class operation,
-    or at least two operational superclasses *),
   propnames: string list
 } * thm list Symtab.table;
 
@@ -123,14 +118,13 @@
 
 (* updaters *)
 
-fun add_class_data (class, (name_locale, name_axclass, var, consts, operational, propnames)) =
+fun add_class_data (class, (name_locale, name_axclass, var, consts, propnames)) =
   ClassData.map (
     Symtab.update_new (class, ClassData ({
       name_locale = name_locale,
       name_axclass = name_axclass,
       var = var,
       consts = consts,
-      operational = operational,
       propnames = propnames}, Symtab.empty))
   );
 
@@ -314,9 +308,6 @@
       |> Sign.add_const_constraint_i (c, SOME (subst_clsvar v (TFree (v, [class])) ty));
     fun mk_const thy class v (c, ty) =
       Const (c, subst_clsvar v (TFree (v, [class])) ty);
-    fun is_operational thy mapp_this =
-      length mapp_this > 0 orelse
-        length (filter (#operational o fst o the o lookup_class_data thy) supclasses) > 1;
   in
     thy
     |> add_locale bname expr elems
@@ -330,9 +321,8 @@
           add_axclass_i (bname, supsort) (map (fst o snd) mapp_this) loc_axioms
     #-> (fn (name_axclass, (_, ax_axioms)) =>
           fold (add_global_constraint v name_axclass) mapp_this
-    #> `(fn thy => is_operational thy mapp_this)
-    #-> (fn operational => add_class_data (name_locale, (name_locale, name_axclass, v, mapp_this,
-          operational, map (fst o fst) loc_axioms)))
+    #> add_class_data (name_locale, (name_locale, name_axclass, v, mapp_this,
+          map (fst o fst) loc_axioms))
     #> prove_interpretation_i (bname, [])
           (Locale.Locale name_locale) (map (SOME o mk_const thy name_axclass v) (map snd (mapp_sup @ mapp_this)))
           ((ALLGOALS o ProofContext.fact_tac) ax_axioms)
@@ -384,13 +374,12 @@
         (Defs.specifications_of (Theory.defs_of theory) c));
     fun get_consts_class tyco ty class =
       let
-        val data = (fst o the_class_data theory) class;
-        val subst_ty = map_type_tfree (fn (v, sort) =>
-          if #var data = v then ty else TVar ((v, 0), sort));
+        val (_, cs) = AxClass.params_of_class theory class;
+        val subst_ty = map_type_tfree (K ty);
       in
-        (map_filter (fn (_, (c, ty)) =>
+        map_filter (fn (c, ty) =>
           if already_defined (c, ty)
-          then NONE else SOME ((c, ((tyco, class), subst_ty ty)))) o #consts) data
+          then NONE else SOME ((c, ((tyco, class), subst_ty ty)))) cs
       end;
     fun get_consts_sort (tyco, asorts, sort) =
       let
@@ -438,7 +427,8 @@
         val bind = bind_always orelse not (can (PureThy.get_thms thy) (Name name));
         val thms = maps (fn (tyco, _, sort) => maps (fn class =>
           Symtab.lookup_list
-            ((snd o the_class_data thy) class) tyco) (the_ancestry thy sort)) arities;
+            ((the_default Symtab.empty o Option.map snd o try (the_class_data thy)) class) tyco)
+            (the_ancestry thy sort)) arities;
       in if bind then
         thy
         |> PureThy.note_thmss_i (*qualified*) Thm.internalK [((name, atts), [(thms, [])])]
@@ -528,42 +518,6 @@
 end; (* local *)
 
 
-
-(** code generation view **)
-
-fun is_operational_class thy class =
-  the_default false ((Option.map (#operational o fst) o lookup_class_data thy) class);
-
-fun operational_algebra thy =
-  Sorts.project_algebra (Sign.pp thy)
-    (is_operational_class thy) (Sign.classes_of thy);
-
-fun the_consts_sign thy class =
-  let
-    val _ = if is_operational_class thy class then () else error ("no operational class: " ^ quote class);
-    val data = (fst o the_class_data thy) class
-  in (#var data, (map snd o #consts) data) end;
-
-fun the_inst_sign thy (class, tyco) =
-  let
-    val _ = if is_operational_class thy class then () else error ("no operational class: " ^ quote class);
-    val asorts = Sign.arity_sorts thy tyco [class];
-    val (clsvar, const_sign) = the_consts_sign thy class;
-    fun add_var sort used =
-      let val v = hd (Name.invents used "'a" 1);
-      in ((v, sort), Name.declare v used) end;
-    val (vsorts, _) =
-      Name.context
-      |> Name.declare clsvar
-      |> fold (fn (_, ty) => fold Name.declare
-           ((map (fst o fst) o typ_tvars) ty @ map fst (Term.add_tfreesT  ty []))) const_sign
-      |> fold_map add_var asorts;
-    val ty_inst = Type (tyco, map TFree vsorts);
-    val inst_signs = map (apsnd (subst_clsvar clsvar ty_inst)) const_sign;
-  in (vsorts, inst_signs) end;
-
-
-
 (** toplevel interface **)
 
 local
--- a/src/Pure/Tools/codegen_consts.ML	Wed Nov 22 10:21:17 2006 +0100
+++ b/src/Pure/Tools/codegen_consts.ML	Wed Nov 22 10:22:04 2006 +0100
@@ -19,6 +19,8 @@
   val norm_of_typ: theory -> string * typ -> const
   val find_def: theory -> const
     -> ((string (*theory name*) * thm) * typ list) option
+  val instance_dict: theory -> class * string
+    -> (string * sort) list * (string * typ) list
   val disc_typ_of_classop: theory -> const -> typ
   val disc_typ_of_const: theory -> (const -> typ) -> const -> typ
   val consts_of: theory -> term -> const list
@@ -45,7 +47,7 @@
   );
 
 
-(* type instantiations and overloading *)
+(* type instantiations, overloading, dictionary values *)
 
 fun inst_of_typ thy (c_ty as (c, ty)) =
   (c, Consts.typargs (Sign.consts_of thy) c_ty);
@@ -95,26 +97,35 @@
 fun norm_of_typ thy (c, ty) =
   norm thy (c, Consts.typargs (Sign.consts_of thy) (c, ty));
 
+fun instance_dict thy (class, tyco) =
+  let
+    val (var, cs) = AxClass.params_of_class thy class;
+    val sort_args = Name.names (Name.declare var Name.context) "'a"
+      (Sign.arity_sorts thy tyco [class]);
+    val ty_inst = Type (tyco, map TFree sort_args);
+    val inst_signs = (map o apsnd o map_type_tfree) (K ty_inst) cs;
+  in (sort_args, inst_signs) end;
+
 fun disc_typ_of_classop thy (c, [TVar _]) = 
       let
         val class = (the o AxClass.class_of_param thy) c;
-        val (v, cs) = ClassPackage.the_consts_sign thy class
+        val (v, cs) = AxClass.params_of_class thy class;
       in
-        (Logic.varifyT o map_type_tfree (fn u as (w, _) => if w = v then TFree (v, [class]) else TFree u))
+        (Logic.varifyT o map_type_tfree (K (TFree (v, [class]))))
           ((the o AList.lookup (op =) cs) c)
       end
   | disc_typ_of_classop thy (c, [TFree _]) = 
       let
         val class = (the o AxClass.class_of_param thy) c;
-        val (v, cs) = ClassPackage.the_consts_sign thy class
+        val (v, cs) = AxClass.params_of_class thy class;
       in
-        (Logic.varifyT o map_type_tfree (fn u as (w, _) => if w = v then TFree (v, [class]) else TFree u))
+        (Logic.varifyT o map_type_tfree (K (TFree (v, [class]))))
           ((the o AList.lookup (op =) cs) c)
       end
   | disc_typ_of_classop thy (c, [Type (tyco, _)]) =
       let
         val class = (the o AxClass.class_of_param thy) c;
-        val (_, cs) = ClassPackage.the_inst_sign thy (class, tyco);
+        val (_, cs) = instance_dict thy (class, tyco);
       in
         Logic.varifyT ((the o AList.lookup (op =) cs) c)
       end;
--- a/src/Pure/Tools/codegen_funcgr.ML	Wed Nov 22 10:21:17 2006 +0100
+++ b/src/Pure/Tools/codegen_funcgr.ML	Wed Nov 22 10:22:04 2006 +0100
@@ -41,7 +41,7 @@
   fun merge _ _ = Constgraph.empty;
   fun purge _ NONE _ = Constgraph.empty
     | purge _ (SOME cs) funcgr =
-        Constgraph.del_nodes ((Constgraph.all_succs funcgr 
+        Constgraph.del_nodes ((Constgraph.all_preds funcgr 
           o filter (can (Constgraph.get_node funcgr))) cs) funcgr;
 end);
 
@@ -186,8 +186,10 @@
   end;
 
 fun all_classops thy tyco class =
-  AxClass.params_of thy class
-  |> AList.make (fn c => CodegenConsts.disc_typ_of_classop thy (c, [Type (tyco, [])]))
+  try (AxClass.params_of_class thy) class
+  |> Option.map snd
+  |> these
+  |> map (fn (c, _) => (c, CodegenConsts.disc_typ_of_classop thy (c, [Type (tyco, [])])))
   |> map (CodegenConsts.norm_of_typ thy);
 
 fun instdefs_of thy insts =
--- a/src/Pure/Tools/codegen_package.ML	Wed Nov 22 10:21:17 2006 +0100
+++ b/src/Pure/Tools/codegen_package.ML	Wed Nov 22 10:22:04 2006 +0100
@@ -49,14 +49,11 @@
         let
           val cs_exisiting =
             map_filter (CodegenNames.const_rev thy) (Graph.keys code);
-        in
-          Graph.del_nodes
-            ((Graph.all_succs code
+          val dels = (Graph.all_preds code
               o map (CodegenNames.const thy)
               o filter (member CodegenConsts.eq_const cs_exisiting)
-              ) cs)
-            code
-        end;
+            ) cs;
+        in Graph.del_nodes dels code end;
 end);
 
 type appgen = theory -> ((sort -> sort) * Sorts.algebra) * Consts.T
@@ -105,7 +102,7 @@
 
 fun ensure_def_class thy (algbr as ((proj_sort, _), _)) funcgr strct class trns =
   let
-    val (v, cs) = (ClassPackage.the_consts_sign thy) class;
+    val (v, cs) = (AxClass.params_of_class thy) class;
     val superclasses = (proj_sort o Sign.super_classes thy) class;
     val classops' = map (CodegenNames.const thy o CodegenConsts.norm_of_typ thy) cs;
     val class' = CodegenNames.class thy class;
@@ -225,10 +222,10 @@
   end
 and ensure_def_inst thy (algbr as ((proj_sort, _), _)) funcgr strct (class, tyco) trns =
   let
-    val (vs, classop_defs) = ((apsnd o map) Const o ClassPackage.the_inst_sign thy)
+    val (vs, classop_defs) = ((apsnd o map) Const o CodegenConsts.instance_dict thy)
       (class, tyco);
     val classops = (map (CodegenConsts.norm_of_typ thy) o snd
-      o ClassPackage.the_consts_sign thy) class;
+      o AxClass.params_of_class thy) class;
     val arity_typ = Type (tyco, (map TFree vs));
     val superclasses = (proj_sort o Sign.super_classes thy) class
     fun gen_superarity superclass trns =
@@ -575,7 +572,8 @@
       (CodegenFuncgr.all funcgr);
     val funcgr' = CodegenFuncgr.make thy cs;
     val qnaming = NameSpace.qualified_names NameSpace.default_naming;
-    val algebr = ClassPackage.operational_algebra thy;
+    val algebr = Sorts.project_algebra (Sign.pp thy)
+      (the_default false o Option.map #operational o try (AxClass.get_definition thy)) (Sign.classes_of thy);
     val consttab = Consts.empty
       |> fold (fn c => Consts.declare qnaming
            ((CodegenNames.const thy c, CodegenFuncgr.typ funcgr' c), true))
--- a/src/Pure/Tools/codegen_serializer.ML	Wed Nov 22 10:21:17 2006 +0100
+++ b/src/Pure/Tools/codegen_serializer.ML	Wed Nov 22 10:22:04 2006 +0100
@@ -1406,6 +1406,12 @@
     thy
   end;*)
 
+fun read_class thy raw_class =
+  let
+    val class = Sign.intern_class thy raw_class;
+    val _ = AxClass.get_definition thy class;
+  in class end;
+
 fun read_type thy raw_tyco =
   let
     val tyco = Sign.intern_type thy raw_tyco;
@@ -1419,8 +1425,8 @@
     val cs'' = map (CodegenConsts.norm_of_typ thy) cs';
   in AList.make (CodegenNames.const thy) cs'' end;
 
-val add_syntax_class = gen_add_syntax_class ClassPackage.read_class CodegenConsts.read_const;
-val add_syntax_inst = gen_add_syntax_inst ClassPackage.read_class read_type;
+val add_syntax_class = gen_add_syntax_class read_class CodegenConsts.read_const;
+val add_syntax_inst = gen_add_syntax_inst read_class read_type;
 val add_syntax_tyco = gen_add_syntax_tyco read_type;
 val add_syntax_const = gen_add_syntax_const CodegenConsts.read_const;
 
@@ -1558,7 +1564,7 @@
 
 val _ = Context.add_setup (
   gen_add_syntax_tyco (K I) "SML" "fun" (SOME (2, fn fxy => fn pr_typ => fn [ty1, ty2] =>
-      (gen_brackify (case fxy of BR => false | _ => eval_fxy (INFX (1, R)) fxy) o Pretty.breaks) [
+      (gen_brackify (case fxy of NOBR => false | _ => eval_fxy (INFX (1, R)) fxy) o Pretty.breaks) [
         pr_typ (INFX (1, X)) ty1,
         str "->",
         pr_typ (INFX (1, R)) ty2
--- a/src/Pure/axclass.ML	Wed Nov 22 10:21:17 2006 +0100
+++ b/src/Pure/axclass.ML	Wed Nov 22 10:22:04 2006 +0100
@@ -8,11 +8,11 @@
 
 signature AX_CLASS =
 sig
-  val get_definition: theory -> class -> {def: thm, intro: thm, axioms: thm list}
+  val get_definition: theory -> class -> {def: thm, intro: thm, axioms: thm list,
+    params: (string * typ) list, operational: bool}
   val class_intros: theory -> thm list
-  val params_of: theory -> class -> string list
-  val all_params_of: theory -> sort -> string list
   val class_of_param: theory -> string -> class option
+  val params_of_class: theory -> class -> string * (string * typ) list
   val print_axclasses: theory -> unit
   val cert_classrel: theory -> class * class -> class * class
   val read_classrel: theory -> xstring * xstring -> class * class
@@ -62,15 +62,21 @@
 val superN = "super";
 val axiomsN = "axioms";
 
+val param_tyvarname = "'a";
+
 datatype axclass = AxClass of
  {def: thm,
   intro: thm,
-  axioms: thm list};
+  axioms: thm list,
+  params: (string * typ) list,
+  operational: bool (* == at least one class operation,
+    or at least two operational superclasses *)};
 
 type axclasses = axclass Symtab.table * param list;
 
-fun make_axclass (def, intro, axioms) =
-  AxClass {def = def, intro = intro, axioms = axioms};
+fun make_axclass ((def, intro, axioms), (params, operational)) =
+  AxClass {def = def, intro = intro, axioms = axioms, params = params,
+    operational = operational};
 
 fun merge_axclasses pp ((tab1, params1), (tab2, params2)) : axclasses =
   (Symtab.merge (K true) (tab1, tab2), merge_params pp (params1, params2));
@@ -137,6 +143,8 @@
 fun class_of_param thy =
   AList.lookup (op =) (#2 (get_axclasses thy));
 
+fun params_of_class thy class =
+  (param_tyvarname, (#params o get_definition thy) class);
 
 (* maintain instances *)
 
@@ -169,7 +177,7 @@
     val axclasses = #1 (get_axclasses thy);
     val ctxt = ProofContext.init thy;
 
-    fun pretty_axclass (class, AxClass {def, intro, axioms}) =
+    fun pretty_axclass (class, AxClass {def, intro, axioms, params, operational}) =
       Pretty.block (Pretty.fbreaks
        [Pretty.block
           [Pretty.str "class ", ProofContext.pretty_sort ctxt [class], Pretty.str ":"],
@@ -264,8 +272,16 @@
 
 local
 
-fun def_class prep_class prep_att prep_propp
-    (bclass, raw_super) params raw_specs thy =
+fun read_param thy raw_t =
+  let
+    val t = Sign.read_term thy raw_t
+  in case try dest_Const t
+   of SOME (c, _) => c
+    | NONE => error ("Not a constant: " ^ Sign.string_of_term thy t)
+  end;
+
+fun def_class prep_class prep_att prep_param prep_propp
+    (bclass, raw_super) raw_params raw_specs thy =
   let
     val ctxt = ProofContext.init thy;
     val pp = ProofContext.pp ctxt;
@@ -318,6 +334,20 @@
          ((superN, []), [(map Drule.standard raw_classrel, [])]),
          ((axiomsN, []), [(map (fn th => Drule.standard (class_triv RS th)) raw_axioms, [])])];
 
+    (* params *)
+
+    val params = map (prep_param thy) raw_params;
+    val params_typs = map (fn param =>
+      let
+        val ty = Sign.the_const_type thy param;
+        val var = case Term.typ_tvars ty
+         of [(v, _)] => v
+          | _ => error ("exactly one type variable required in parameter " ^ quote param);
+        val ty' = Term.typ_subst_TVars [(var, TFree (param_tyvarname, []))] ty;
+      in (param, ty') end) params;
+    val operational = length params_typs > 0 orelse
+      length (filter (the_default false o Option.map
+        (fn AxClass { operational, ... } => operational) o lookup_def thy) super) > 1;
 
     (* result *)
 
@@ -328,15 +358,15 @@
       |> PureThy.note_thmss_i "" (name_atts ~~ map Thm.simple_fact (unflat axiomss axioms)) |> snd
       |> Sign.restore_naming facts_thy
       |> map_axclasses (fn (axclasses, parameters) =>
-        (Symtab.update (class, make_axclass (def, intro, axioms)) axclasses,
+        (Symtab.update (class, make_axclass ((def, intro, axioms), (params_typs, operational))) axclasses,
           fold (fn x => add_param pp (x, class)) params parameters));
 
   in (class, result_thy) end;
 
 in
 
-val define_class = def_class Sign.read_class Attrib.attribute ProofContext.read_propp;
-val define_class_i = def_class Sign.certify_class (K I) ProofContext.cert_propp;
+val define_class = def_class Sign.read_class Attrib.attribute read_param ProofContext.read_propp;
+val define_class_i = def_class Sign.certify_class (K I) (K I) ProofContext.cert_propp;
 
 end;