src/Pure/Tools/class_package.ML
author wenzelm
Sat Jan 21 23:02:14 2006 +0100 (2006-01-21)
changeset 18728 6790126ab5f6
parent 18708 4b3dadb4fe33
child 18755 eb3733779aa8
permissions -rw-r--r--
simplified type attribute;
     1 (*  Title:      Pure/Tools/class_package.ML
     2     ID:         $Id$
     3     Author:     Florian Haftmann, TU Muenchen
     4 
     5 Haskell98-like operational view on type classes.
     6 *)
     7 
     8 signature CLASS_PACKAGE =
     9 sig
    10   val add_class: bstring -> Locale.expr -> Element.context list -> theory
    11     -> ProofContext.context * theory
    12   val add_class_i: bstring -> Locale.expr -> Element.context_i list -> theory
    13     -> ProofContext.context * theory
    14   val add_instance_arity: (xstring * string list) * string
    15     -> ((bstring * string) * Attrib.src list) list
    16     -> theory -> Proof.state
    17   val add_instance_arity_i: (string * sort list) * sort
    18     -> ((bstring * term) * attribute list) list
    19     -> theory -> Proof.state
    20   val add_classentry: class -> xstring list -> xstring list -> theory -> theory
    21 
    22   val syntactic_sort_of: theory -> sort -> sort
    23   val the_superclasses: theory -> class -> class list
    24   val the_consts_sign: theory -> class -> string * (string * typ) list
    25   val lookup_const_class: theory -> string -> class option
    26   val the_instances: theory -> class -> (string * string) list
    27   val the_inst_sign: theory -> class * string -> (string * sort) list * (string * typ) list
    28   val get_classtab: theory -> (string list * (string * string) list) Symtab.table
    29   val print_classes: theory -> unit
    30 
    31   type sortcontext = (string * sort) list
    32   datatype sortlookup = Instance of (class * string) * sortlookup list list
    33                       | Lookup of class list * (string * int)
    34   val extract_sortctxt: theory -> typ -> sortcontext
    35   val extract_sortlookup: theory -> string * typ -> sortlookup list list
    36 end;
    37 
    38 structure ClassPackage: CLASS_PACKAGE =
    39 struct
    40 
    41 
    42 (* theory data *)
    43 
    44 type class_data = {
    45   superclasses: class list,
    46   name_locale: string,
    47   name_axclass: string,
    48   var: string,
    49   consts: (string * typ) list,
    50   insts: (string * string) list
    51 };
    52 
    53 structure ClassData = TheoryDataFun (
    54   struct
    55     val name = "Pure/classes";
    56     type T = class_data Symtab.table * class Symtab.table;
    57     val empty = (Symtab.empty, Symtab.empty);
    58     val copy = I;
    59     val extend = I;
    60     fun merge _ ((t1, r1), (t2, r2))=
    61       (Symtab.merge (op =) (t1, t2),
    62        Symtab.merge (op =) (r1, r2));
    63     fun print thy (tab, _) =
    64       let
    65         fun pretty_class (name, {superclasses, name_locale, name_axclass, var, consts, insts}) =
    66           (Pretty.block o Pretty.fbreaks) [
    67             Pretty.str ("class " ^ name ^ ":"),
    68             (Pretty.block o Pretty.fbreaks) (
    69               Pretty.str "superclasses: "
    70               :: map Pretty.str superclasses
    71             ),
    72             Pretty.str ("locale: " ^ name_locale),
    73             Pretty.str ("axclass: " ^ name_axclass),
    74             Pretty.str ("class variable: " ^ var),
    75             (Pretty.block o Pretty.fbreaks) (
    76               Pretty.str "constants: "
    77               :: map (fn (c, ty) => Pretty.str (c ^ " :: " ^ Sign.string_of_typ thy ty)) consts
    78             ),
    79             (Pretty.block o Pretty.fbreaks) (
    80               Pretty.str "instances: "
    81               :: map (fn (tyco, thyname) => Pretty.str (tyco ^ ", in theory " ^ thyname)) insts
    82             )
    83           ]
    84       in
    85         (Pretty.writeln o Pretty.chunks o map pretty_class o Symtab.dest) tab
    86       end;
    87   end
    88 );
    89 
    90 val _ = Context.add_setup ClassData.init;
    91 val print_classes = ClassData.print;
    92 
    93 val lookup_class_data = Symtab.lookup o fst o ClassData.get;
    94 val lookup_const_class = Symtab.lookup o snd o ClassData.get;
    95 
    96 fun get_class_data thy class =
    97   case lookup_class_data thy class
    98     of NONE => error ("undeclared class " ^ quote class)
    99      | SOME data => data;
   100 
   101 fun add_class_data (class, (superclasses, name_locale, name_axclass, classvar, consts)) =
   102   ClassData.map (fn (classtab, consttab) => (
   103     classtab 
   104     |> Symtab.update (class, {
   105          superclasses = superclasses,
   106          name_locale = name_locale,
   107          name_axclass = name_axclass,
   108          var = classvar,
   109          consts = consts,
   110          insts = []
   111        }),
   112     consttab
   113     |> fold (fn (c, _) => Symtab.update (c, class)) consts
   114   ));
   115 
   116 fun add_inst_data (class, inst) =
   117   (ClassData.map o apfst o Symtab.map_entry class)
   118     (fn {superclasses, name_locale, name_axclass, var, consts, insts}
   119       => {
   120            superclasses = superclasses,
   121            name_locale = name_locale,
   122            name_axclass = name_axclass,
   123            var = var,
   124            consts = consts,
   125            insts = insts @ [inst]
   126           });
   127 
   128 
   129 (* classes and instances *)
   130 
   131 fun subst_clsvar v ty_subst =
   132   map_type_tfree (fn u as (w, _) =>
   133     if w = v then ty_subst else TFree u);
   134 
   135 local
   136 
   137 open Element
   138 
   139 fun gen_add_class add_locale bname raw_import raw_body thy =
   140   let
   141     fun extract_assumes c_adds elems =
   142       let
   143         fun subst_free ts =
   144           let
   145             val get_ty = the o AList.lookup (op =) (fold Term.add_frees ts []);
   146             val subst_map = map (fn (c, (c', _)) =>
   147               (Free (c, get_ty c), Const (c', get_ty c))) c_adds;
   148           in map (subst_atomic subst_map) ts end;
   149       in
   150         elems
   151         |> (map o List.mapPartial)
   152             (fn (Assumes asms) => (SOME o map (map fst o snd)) asms
   153               | _ => NONE)
   154         |> Library.flat o Library.flat o Library.flat
   155         |> subst_free
   156       end;
   157     fun extract_tyvar_name thy tys =
   158       fold (curry add_typ_tfrees) tys []
   159       |> (fn [(v, sort)] =>
   160                 if Sorts.sort_eq (Sign.classes_of thy) (Sign.defaultS thy, sort)
   161                 then v 
   162                 else error ("illegal sort constraint on class type variable: " ^ Sign.string_of_sort thy sort)
   163            | [] => error ("no class type variable")
   164            | vs => error ("more than one type variable: " ^ (commas o map (Sign.string_of_typ thy o TFree)) vs))
   165     fun extract_tyvar_consts thy elems =
   166       elems
   167       |> Library.flat
   168       |> List.mapPartial
   169            (fn (Fixes consts) => SOME consts
   170              | _ => NONE)
   171       |> Library.flat
   172       |> map (fn (c, ty, syn) =>
   173            ((c, the ty), (Syntax.unlocalize_mixfix o Syntax.fix_mixfix c) syn))
   174       |> `(fn consts => extract_tyvar_name thy (map (snd o fst) consts))
   175       |-> (fn v => map ((apfst o apsnd) (subst_clsvar v (TFree (v, []))))
   176          #> pair v);
   177     fun add_global_const v ((c, ty), syn) thy =
   178       thy
   179       |> Sign.add_consts_i [(c, ty |> subst_clsvar v (TFree (v, Sign.defaultS thy)), syn)]
   180       |> `(fn thy => (c, (Sign.intern_const thy c, ty)))
   181     fun add_global_constraint v class (_, (c, ty)) thy =
   182       thy
   183       |> Sign.add_const_constraint_i (c, subst_clsvar v (TVar ((v, 0), [class])) ty);
   184     fun print_ctxt ctxt elem = 
   185       map Pretty.writeln (Element.pretty_ctxt ctxt elem)
   186   in
   187     thy
   188     |> add_locale bname raw_import raw_body
   189     |-> (fn ((import_elems, body_elems), ctxt) =>
   190        `(fn thy => Locale.intern thy bname)
   191     #-> (fn name_locale =>
   192           `(fn thy => extract_tyvar_consts thy body_elems)
   193     #-> (fn (v, c_defs) =>
   194           fold_map (add_global_const v) c_defs
   195     #-> (fn c_adds =>
   196           AxClass.add_axclass_i (bname, Sign.defaultS thy)
   197             (map (Thm.no_attributes o pair "") (extract_assumes c_adds (import_elems @ body_elems)))
   198     #-> (fn _ =>
   199           `(fn thy => Sign.intern_class thy bname)
   200     #-> (fn name_axclass =>
   201           fold (add_global_constraint v name_axclass) c_adds
   202     #> add_class_data (name_locale, ([], name_locale, name_axclass, v, map snd c_adds))
   203     #> tap (fn _ => (map o map) (print_ctxt ctxt) import_elems)
   204     #> tap (fn _ => (map o map) (print_ctxt ctxt) body_elems)
   205     #> pair ctxt
   206     ))))))
   207   end;
   208 
   209 in
   210 
   211 val add_class = gen_add_class (Locale.add_locale_context true);
   212 val add_class_i = gen_add_class (Locale.add_locale_context_i true);
   213 
   214 end; (* local *)
   215 
   216 fun gen_instance_arity prep_arity add_defs tap_def raw_arity raw_defs thy = 
   217   let
   218     val dest_def = Theory.dest_def (Sign.pp thy) handle TERM (msg, _) => error msg;
   219     val arity as (tyco, asorts, sort) = prep_arity thy ((fn ((x, y), z) => (x, y, z)) raw_arity);
   220     val ty_inst = Type (tyco, map2 (curry TVar o rpair 0) (Term.invent_names [] "'a" (length asorts)) asorts)
   221     fun get_c_req class =
   222       let
   223         val data = get_class_data thy class;
   224         val subst_ty = map_type_tfree (fn (var as (v, _)) =>
   225           if #var data = v then ty_inst else TFree var)
   226       in (map (apsnd subst_ty) o #consts) data end;
   227     val c_req = (Library.flat o map get_c_req) sort;
   228     fun get_remove_contraint c thy =
   229       let
   230         val ty1 = Sign.the_const_constraint thy c;
   231         val ty2 = Sign.the_const_type thy c;
   232       in
   233         thy
   234         |> Sign.add_const_constraint_i (c, ty2)
   235         |> pair (c, ty1)
   236       end;
   237     fun get_c_given thy = map (fst o dest_def o snd o tap_def thy o fst) raw_defs;
   238     fun check_defs c_given c_req thy =
   239       let
   240         fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2
   241           andalso Sign.typ_instance thy (ty1, ty2)
   242           andalso Sign.typ_instance thy (ty2, ty1)
   243         val _ = case fold (remove eq_c) c_given c_req
   244          of [] => ()
   245           | cs => error ("no definition(s) given for"
   246                     ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs);
   247         val _ = case fold (remove eq_c) c_req c_given
   248          of [] => ()
   249           | cs => error ("superfluous definition(s) given for"
   250                     ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs);
   251       in thy end;
   252   in
   253     thy
   254     |> fold_map get_remove_contraint (map fst c_req)
   255     ||> tap (fn thy => check_defs (get_c_given thy) c_req)
   256     ||> add_defs (true, raw_defs)
   257     |-> (fn cs => fold Sign.add_const_constraint_i cs)
   258     |> AxClass.instance_arity_i arity
   259   end;
   260 
   261 val add_instance_arity = fn x => gen_instance_arity (AxClass.read_arity) IsarThy.add_defs read_axm x;
   262 val add_instance_arity_i = fn x => gen_instance_arity (AxClass.cert_arity) IsarThy.add_defs_i (K I) x;
   263 
   264 
   265 (* queries *)
   266 
   267 fun is_class thy cls =
   268   lookup_class_data thy cls
   269   |> Option.map (not o null o #consts)
   270   |> the_default false;
   271 
   272 fun syntactic_sort_of thy sort =
   273   let
   274     val classes = Sign.classes_of thy;
   275     fun get_sort cls =
   276       if is_class thy cls
   277       then [cls]
   278       else syntactic_sort_of thy (Sorts.superclasses classes cls);
   279   in
   280     map get_sort sort
   281     |> Library.flat
   282     |> Sorts.norm_sort classes
   283   end;
   284 
   285 fun the_superclasses thy class =
   286   if is_class thy class
   287   then
   288     Sorts.superclasses (Sign.classes_of thy) class
   289     |> syntactic_sort_of thy
   290   else
   291     error ("no syntactic class: " ^ class);
   292 
   293 fun the_consts_sign thy class =
   294   let
   295     val data = (the oo Symtab.lookup) ((fst o ClassData.get) thy) class
   296   in (#var data, #consts data) end;
   297 
   298 fun lookup_const_class thy =
   299   Symtab.lookup ((snd o ClassData.get) thy);
   300 
   301 fun the_instances thy class =
   302   (#insts o the o Symtab.lookup ((fst o ClassData.get) thy)) class;
   303 
   304 fun the_inst_sign thy (class, tyco) =
   305   let
   306     val _ = if is_class thy class then () else error ("no syntactic class: " ^ class);
   307     val arity = 
   308       Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class]
   309       |> map (syntactic_sort_of thy);
   310     val clsvar = (#var o the o Symtab.lookup ((fst o ClassData.get) thy)) class;
   311     val const_sign = (snd o the_consts_sign thy) class;
   312     fun add_var sort used =
   313       let
   314         val v = hd (Term.invent_names used "'a" 1)
   315       in ((v, sort), v::used) end;
   316     val (vsorts, _) =
   317       []
   318       |> fold (fn (_, ty) => curry (gen_union (op =))
   319            ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) const_sign
   320       |> fold_map add_var arity;
   321     val ty_inst = Type (tyco, map (fn (v, sort) => TVar ((v, 0), sort)) vsorts);
   322     val inst_signs = map (apsnd (subst_clsvar clsvar ty_inst)) const_sign;
   323   in (vsorts, inst_signs) end;
   324 
   325 fun get_classtab thy =
   326   Symtab.fold
   327     (fn (class, { consts = consts, insts = insts, ... }) =>
   328       Symtab.update_new (class, (map fst consts, insts)))
   329        ((fst o ClassData.get) thy) Symtab.empty;
   330 
   331 
   332 (* extracting dictionary obligations from types *)
   333 
   334 type sortcontext = (string * sort) list;
   335 
   336 fun extract_sortctxt thy ty =
   337   (typ_tfrees o fst o Type.freeze_thaw_type) ty
   338   |> map (apsnd (syntactic_sort_of thy))
   339   |> filter (not o null o snd);
   340 
   341 datatype sortlookup = Instance of (class * string) * sortlookup list list
   342                     | Lookup of class list * (string * int)
   343 
   344 fun extract_sortlookup thy (c, raw_typ_use) =
   345   let
   346     val raw_typ_def = Sign.the_const_constraint thy c;
   347     val typ_def = Type.varifyT raw_typ_def;
   348     val typ_use = Type.varifyT raw_typ_use;
   349     val match_tab = Sign.typ_match thy (typ_def, typ_use) Vartab.empty;
   350     fun tab_lookup vname = (the o Vartab.lookup match_tab) (vname, 0);
   351     fun get_superclass_derivation (subclasses, superclass) =
   352       (the oo get_first) (fn subclass =>
   353         Sorts.class_le_path (Sign.classes_of thy) (subclass, superclass)
   354       ) subclasses;
   355     fun mk_class_deriv thy subclasses superclass =
   356       case get_superclass_derivation (subclasses, superclass)
   357       of (subclass::deriv) =>
   358         ((rev o filter (is_class thy)) deriv, find_index_eq subclass subclasses);
   359     fun mk_lookup (sort_def, (Type (tycon, tys))) =
   360           let
   361             val arity_lookup = map2 (curry mk_lookup)
   362               (map (syntactic_sort_of thy) (Sorts.mg_domain (Sign.classes_arities_of thy) tycon sort_def)) tys
   363           in map (fn class => Instance ((class, tycon), arity_lookup)) sort_def end
   364       | mk_lookup (sort_def, TVar ((vname, _), sort_use)) =
   365           let
   366             fun mk_look class =
   367               let val (deriv, classindex) = mk_class_deriv thy (syntactic_sort_of thy sort_use) class
   368               in Lookup (deriv, (vname, classindex)) end;
   369           in map mk_look sort_def end;
   370     fun reorder_sortctxt ctxt =
   371       case lookup_const_class thy c
   372        of NONE => ctxt
   373         | SOME class =>
   374             let
   375               val data = (the o Symtab.lookup ((fst o ClassData.get) thy)) class;
   376               val sign = (Type.varifyT o the o AList.lookup (op =) (#consts data)) c;
   377               val match_tab = Sign.typ_match thy (sign, typ_def) Vartab.empty;
   378               val v : string = case Vartab.lookup match_tab (#var data, 0)
   379                 of SOME (_, TVar ((v, _), _)) => v;
   380             in
   381               (v, (the o AList.lookup (op =) ctxt) v) :: AList.delete (op =) v ctxt
   382             end;
   383   in
   384     extract_sortctxt thy ((fst o Type.freeze_thaw_type) raw_typ_def)
   385     |> reorder_sortctxt
   386     |> map (tab_lookup o fst)
   387     |> map (apfst (syntactic_sort_of thy))
   388     |> filter (not o null o fst)
   389     |> map mk_lookup
   390   end;
   391 
   392 
   393 (* intermediate auxiliary *)
   394 
   395 fun add_classentry raw_class raw_cs raw_insts thy =
   396   let
   397     val class = Sign.intern_class thy raw_class;
   398     val cs_proto =
   399       raw_cs
   400       |> map (Sign.intern_const thy)
   401       |> map (fn c => (c, Sign.the_const_constraint thy c));
   402     val used = 
   403       []
   404       |> fold (fn (_, ty) => curry (gen_union (op =))
   405            ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) cs_proto
   406     val v = hd (Term.invent_names used "'a" 1);
   407     val cs =
   408       cs_proto
   409       |> map (fn (c, ty) => (c, map_type_tvar (fn var as ((tvar', _), sort) =>
   410           if Sorts.sort_eq (Sign.classes_of thy) ([class], sort)
   411           then TFree (v, [])
   412           else TVar var
   413          ) ty));
   414     val insts = map (rpair (Context.theory_name thy) o Sign.intern_type thy) raw_insts;
   415   in
   416     thy
   417     |> add_class_data (class, ([], "", class, v, cs))
   418     |> fold (curry add_inst_data class) insts
   419   end;
   420 
   421 
   422 (* toplevel interface *)
   423 
   424 local
   425 
   426 structure P = OuterParse
   427 and K = OuterKeyword
   428 
   429 in
   430 
   431 val (classK, instanceK) = ("class", "class_instance")
   432 
   433 val locale_val =
   434   (P.locale_expr --
   435     Scan.optional (P.$$$ "+" |-- P.!!! (Scan.repeat1 P.context_element)) [] ||
   436   Scan.repeat1 P.context_element >> pair Locale.empty);
   437 
   438 val classP =
   439   OuterSyntax.command classK "operational type classes" K.thy_decl
   440     (P.name -- Scan.optional (P.$$$ "=" |-- P.!!! locale_val) (Locale.empty, [])
   441       >> (Toplevel.theory_context
   442           o (fn f => swap o f) o (fn (bname, (expr, elems)) => add_class bname expr elems)));
   443 
   444 val instanceP =
   445   OuterSyntax.command instanceK "" K.thy_goal
   446     (P.xname -- (P.$$$ "::" |-- P.!!! P.arity)
   447       -- Scan.repeat1 P.spec_name
   448       >> (Toplevel.theory_to_proof
   449           o (fn ((tyco, (asorts, sort)), defs) => add_instance_arity ((tyco, asorts), sort) defs)));
   450 
   451 val _ = OuterSyntax.add_parsers [classP, instanceP];
   452 
   453 end; (* local *)
   454 
   455 end; (* struct *)