src/Pure/Isar/class.ML
author haftmann
Wed Dec 05 14:15:51 2007 +0100 (2007-12-05)
changeset 25536 01753a944433
parent 25518 00d5cc16e891
child 25574 016f677ad7b8
permissions -rw-r--r--
improved
     1 (*  Title:      Pure/Isar/class.ML
     2     ID:         $Id$
     3     Author:     Florian Haftmann, TU Muenchen
     4 
     5 Type classes derived from primitive axclasses and locales.
     6 *)
     7 
       (* Interface of structure Class: type classes derived from axclasses
          and locales, the class target, class instantiation, overloading
          bookkeeping, and the older axclass/instance command layer. *)
     8 signature CLASS =
     9 sig
    10   (*classes*)
    11   val class: bstring -> class list -> Element.context_i Locale.element list
    12     -> string list -> theory -> string * Proof.context
    13   val class_cmd: bstring -> xstring list -> Element.context Locale.element list
    14     -> xstring list -> theory -> string * Proof.context
    15 
    16   val init: class -> theory -> Proof.context
    17   val logical_const: string -> Markup.property list
    18     -> (string * mixfix) * term -> theory -> theory
    19   val syntactic_const: string -> Syntax.mode -> Markup.property list
    20     -> (string * mixfix) * term -> theory -> theory
    21   val refresh_syntax: class -> Proof.context -> Proof.context
    22 
    23   val intro_classes_tac: thm list -> tactic
    24   val default_intro_classes_tac: thm list -> tactic
    25   val prove_subclass: class * class -> thm list -> Proof.context
    26     -> theory -> theory
    27 
    28   val class_prefix: string -> string
    29   val is_class: theory -> class -> bool
    30   val these_params: theory -> sort -> (string * (string * typ)) list
    31   val print_classes: theory -> unit
    32 
    33   (*instances*)
    34   val init_instantiation: string list * sort list * sort -> theory -> local_theory
    35   val prep_spec: local_theory -> term -> term
    36   val instantiation_instance: (local_theory -> local_theory) -> local_theory -> Proof.state
    37   val prove_instantiation_instance: (Proof.context -> tactic) -> local_theory -> local_theory
    38   val conclude_instantiation: local_theory -> local_theory
    39 
       (*overloaded constants and their instance-specific constants*)
    40   val overloaded_const: string * typ -> theory -> term * theory
    41   val overloaded_def: string -> string * term -> theory -> thm * theory
    42   val instantiation_param: Proof.context -> string -> string option
    43   val confirm_declaration: string -> local_theory -> local_theory
    44 
       (*translation between overloaded constants and instance constants*)
    45   val unoverload: theory -> thm -> thm
    46   val overload: theory -> thm -> thm
    47   val unoverload_conv: theory -> conv
    48   val overload_conv: theory -> conv
    49   val unoverload_const: theory -> string * typ -> string
    50   val param_of_inst: theory -> string * string -> string
    51   val inst_of_param: theory -> string -> (string * string) option
    52 
    53   (*old axclass layer*)
    54   val axclass_cmd: bstring * xstring list
    55     -> ((bstring * Attrib.src list) * string list) list
    56     -> theory -> class * theory
    57   val classrel_cmd: xstring * xstring -> theory -> Proof.state
    58 
    59   (*old instance layer*)
    60   val instance_arity: (theory -> theory) -> arity -> theory -> Proof.state
    61   val instance_arity_cmd: bstring * xstring list * xstring -> theory -> Proof.state
    62 end;
    63 
    64 structure Class : CLASS =
    65 struct
    66 
    67 (** auxiliary **)
    68 
       (*name components used by gen_class when noting the class
         introduction rule (qualifier NameSpace.append class classN,
         fact name introN)*)
    69 val classN = "class";
    70 val introN = "intro";
    71 
       (* Interpret locale expression expr with instantiation inst (prefix and
          attributes prfx_atts) in the global theory.  The resulting proof
          obligation is closed by tac as a terminal proof, and the theory of the
          final context is returned. *)
    72 fun prove_interpretation tac prfx_atts expr inst =
    73   Locale.interpretation_i I prfx_atts expr inst
    74   #> Proof.global_terminal_proof
    75       (Method.Basic (K (Method.SIMPLE_METHOD tac), Position.none), NONE)
    76   #> ProofContext.theory_of;
    77 
       (* Like prove_interpretation, but interprets expr inside locale name
          (Locale.interpretation_in_locale).  after_qed is lifted to contexts
          via ProofContext.theory and passed to the interpretation. *)
    78 fun prove_interpretation_in tac after_qed (name, expr) =
    79   Locale.interpretation_in_locale
    80       (ProofContext.theory after_qed) (name, expr)
    81   #> Proof.global_terminal_proof
    82       (Method.Basic (K (Method.SIMPLE_METHOD tac), Position.none), NONE)
    83   #> ProofContext.theory_of;
    84 
    85 fun OF_LAST thm1 thm2 = thm1 RSN (Thm.nprems_of thm2, thm2);
    86 
       (* strip_all_ofclass thy sort thm: as long as the first premise of thm is
          an OFCLASS statement, discharge it.  The discharging proof comes from
          AxClass.of_sort for the fixed type ?'a :: sort.  Only the class named
          in the premise is used; its type is ignored.  Stops at the first
          premise that is not an OFCLASS statement (or when there is none). *)
    87 fun strip_all_ofclass thy sort =
    88   let
    89     val typ = TVar ((Name.aT, 0), sort);
    90     fun prem_inclass t =
    91       case Logic.strip_imp_prems t
    92        of ofcls :: _ => try Logic.dest_inclass ofcls
    93         | [] => NONE;
    94     fun strip_ofclass class thm =
    95       thm OF (fst o AxClass.of_sort thy (typ, [class])) AxClass.cache;
    96     fun strip thm = case (prem_inclass o Thm.prop_of) thm
    97      of SOME (_, class) => thm |> strip_ofclass class |> strip
    98       | NONE => thm;
    99   in strip end;
   100 
   101 fun get_remove_global_constraint c thy =
   102   let
   103     val ty = Sign.the_const_constraint thy c;
   104   in
   105     thy
   106     |> Sign.add_const_constraint (c, NONE)
   107     |> pair (c, Logic.unvarifyT ty)
   108   end;
   109 
   110 
   111 (** primitive axclass and instance commands **)
   112 
       (* Old-style axclass declaration from source text.  Superclass names,
          axiom attributes and axiom propositions are read, then
          AxClass.define_class is called with no class parameters ([]). *)
   113 fun axclass_cmd (class, raw_superclasses) raw_specs thy =
   114   let
   115     val ctxt = ProofContext.init thy;
   116     val superclasses = map (Sign.read_class thy) raw_superclasses;
   117     val name_atts = map ((apsnd o map) (Attrib.attribute thy) o fst)
   118       raw_specs;
           (*read all propositions in ctxt without patterns; keep only the terms*)
   119     val axiomss = ProofContext.read_propp (ctxt, map (map (rpair []) o snd)
   120           raw_specs)
   121       |> snd
   122       |> (map o map) fst;
   123   in
   124     AxClass.define_class (class, superclasses) []
   125       (name_atts ~~ axiomss) thy
   126   end;
   127 
   128 local
   129 
       (* Open a global goal for the propositions mk_prop thy insts, each as its
          own statement.  Once proven, every resulting theorem is registered
          with add_thm, and then after_qed is applied to the theory. *)
   130 fun gen_instance mk_prop add_thm after_qed insts thy =
   131   let
   132     fun after_qed' results =
   133       ProofContext.theory ((fold o fold) add_thm results #> after_qed);
   134   in
   135     thy
   136     |> ProofContext.init
   137     |> Proof.theorem_i NONE after_qed' ((map (fn t => [(t, [])])
   138         o mk_prop thy) insts)
   139   end;
   140 
   141 in
   142 
       (*instance_arity leaves after_qed to the caller; the other variants fix
         it to I.  classrel is not exported by signature CLASS.*)
   143 val instance_arity =
   144   gen_instance (Logic.mk_arities oo Sign.cert_arity) AxClass.add_arity;
   145 val instance_arity_cmd =
   146   gen_instance (Logic.mk_arities oo Sign.read_arity) AxClass.add_arity I;
   147 val classrel =
   148   gen_instance (single oo (Logic.mk_classrel oo AxClass.cert_classrel)) AxClass.add_classrel I;
   149 val classrel_cmd =
   150   gen_instance (single oo (Logic.mk_classrel oo AxClass.read_classrel)) AxClass.add_classrel I;
   151 
   152 end; (*local*)
   153 
   154 
   155 (** basic overloading **)
   156 
   157 (* bookkeeping *)
   158 
       (*theory data: registry of instance constants for overloaded class
         parameters, indexed in both directions*)
   159 structure InstData = TheoryDataFun
   160 (
   161   type T = (string * thm) Symtab.table Symtab.table * (string * string) Symtab.table;
   162     (*constant name ~> type constructor ~> (constant name, equation),
   163         constant name ~> (constant name, type constructor)*)
   164   val empty = (Symtab.empty, Symtab.empty);
   165   val copy = I;
   166   val extend = I;
         (*inner tables are merged per constant; the equality K true means
           colliding keys are never reported as conflicts*)
   167   fun merge _ ((taba1, tabb1), (taba2, tabb2)) =
   168     (Symtab.join (K (Symtab.merge (K true))) (taba1, taba2),
   169       Symtab.merge (K true) (tabb1, tabb2));
   170 );
   171 
       (*type constructor of the single type argument of constant instance
         (c, ty); NONE if there is not exactly one argument or it is not a
         type application*)
   172 val inst_tyco = Option.map fst o try (dest_Type o the_single) oo Sign.const_typargs;
   173 
       (*registered (instance constant, defining equation) for c at tyco;
         raises Option if nothing is registered*)
   174 fun inst thy (c, tyco) =
   175   (the o Symtab.lookup ((the o Symtab.lookup (fst (InstData.get thy))) c)) tyco;
   176 
       (*name of the instance constant for (c, tyco)*)
   177 val param_of_inst = fst oo inst;
   178 
       (*all registered defining equations*)
   179 fun inst_thms thy = (Symtab.fold (Symtab.fold (cons o snd o snd) o snd) o fst)
   180   (InstData.get thy) [];
   181 
       (*reverse lookup: instance constant ~> (overloaded constant, type constructor)*)
   182 val inst_of_param = Symtab.lookup o snd o InstData.get;
   183 
       (*register inst = (instance constant, equation) for (c, tyco) in both
         tables; Symtab.update_new rejects keys that are already present*)
   184 fun add_inst (c, tyco) inst = (InstData.map o apfst
   185       o Symtab.map_default (c, Symtab.empty)) (Symtab.update_new (tyco, inst))
   186   #> (InstData.map o apsnd) (Symtab.update_new (fst inst, (c, tyco)));
   187 
   188 fun unoverload thy = MetaSimplifier.simplify true (inst_thms thy);
   189 fun overload thy = MetaSimplifier.simplify true (map Thm.symmetric (inst_thms thy));
   190 
   191 fun unoverload_conv thy = MetaSimplifier.rewrite true (inst_thms thy);
   192 fun overload_conv thy = MetaSimplifier.rewrite true (map Thm.symmetric (inst_thms thy));
   193 
   194 fun unoverload_const thy (c_ty as (c, _)) =
   195   case AxClass.class_of_param thy c
   196    of SOME class => (case inst_tyco thy c_ty
   197        of SOME tyco => try (param_of_inst thy) (c, tyco) |> the_default c
   198         | NONE => c)
   199     | NONE => c;
   200 
   201 
   202 (* declaration and definition of instances of overloaded constants *)
   203 
       (*store thm under name with the given kind; yields the stored theorem*)
   204 fun primitive_note kind (name, thm) =
   205   PureThy.note_thmss_i kind [((name, []), [([thm], [])])]
   206   #>> (fn [(_, [thm])] => thm);
   207 
       (* Declare the instance constant for class parameter c at type ty:
          - the fresh constant c' = <base c>_<base tyco> is declared under the
            sticky prefix <instance name>_inst;
          - c (at ty with sorts stripped) is defined to equal it;
          - the (type-varified) equation is registered via add_inst and noted
            as an internal fact.
          The naming is then restored and Const (c, ty) is returned.  The val
          bindings raise Bind if c is not a class parameter or ty does not
          determine a type constructor. *)
   208 fun overloaded_const (c, ty) thy =
   209   let
   210     val SOME class = AxClass.class_of_param thy c;
   211     val SOME tyco = inst_tyco thy (c, ty);
   212     val name_inst = AxClass.instance_name (tyco, class) ^ "_inst";
   213     val c' = NameSpace.base c ^ "_" ^ NameSpace.base tyco;
   214     val ty' = Type.strip_sorts ty;
   215   in
   216     thy
   217     |> Sign.sticky_prefix name_inst
   218     |> Sign.no_base_names
   219     |> Sign.declare_const [] (c', ty', NoSyn)
   220     |-> (fn const' as Const (c'', _) => Thm.add_def false true
   221           (Thm.def_name c', Logic.mk_equals (Const (c, ty'), const'))
   222     #>> Thm.varifyT
   223     #-> (fn thm => add_inst (c, tyco) (c'', thm)
   224     #> primitive_note Thm.internalK (c', thm)
   225     #> snd
   226     #> Sign.restore_naming thy
   227     #> pair (Const (c, ty))))
   228   end;
   229 
       (* Define the instance of class parameter c for the type constructor of
          t's type.  The registered instance constant c' gets the definition
          c' == t; name' comes from Thm.def_name_optional, presumably defaulting
          to <base c>_<base tyco> when name is empty (TODO confirm).  Chaining
          with the registered equation c == c' by transitivity yields
          c == t. *)
   230 fun overloaded_def name (c, t) thy =
   231   let
   232     val ty = Term.fastype_of t;
   233     val SOME tyco = inst_tyco thy (c, ty);
   234     val (c', eq) = inst thy (c, tyco);
   235     val prop = Logic.mk_equals (Const (c', ty), t);
   236     val name' = Thm.def_name_optional
   237       (NameSpace.base c ^ "_" ^ NameSpace.base tyco) name;
   238   in
   239     thy
   240     |> Thm.add_def false false (name', prop)
   241     |>> (fn thm =>  Drule.transitive_thm OF [eq, thm])
   242   end;
   243 
   244 
   245 (** class data **)
   246 
       (* Per-class data.  consts, inst and morphism describe the canonical
          interpretation of the class locale.  defs and operations are extended
          whenever an operation is registered (register_operation). *)
   247 datatype class_data = ClassData of {
   248   consts: (string * string) list
   249     (*locale parameter ~> constant name*),
   250   base_sort: sort,
   251   inst: term option list
   252     (*canonical interpretation*),
   253   morphism: morphism,
   254     (*partial morphism of canonical interpretation*)
         (*class introduction rule*)
   255   intro: thm,
         (*definitions of registered operations*)
   256   defs: thm list,
         (*operation constant ~> (class, (type, term))*)
   257   operations: (string * (class * (typ * term))) list
   258 };
   259 
       (*record access and update, with fields split into a fixed part and the
         growing (defs, operations) part*)
   260 fun rep_class_data (ClassData d) = d;
   261 fun mk_class_data ((consts, base_sort, inst, morphism, intro),
   262     (defs, operations)) =
   263   ClassData { consts = consts, base_sort = base_sort, inst = inst,
   264     morphism = morphism, intro = intro, defs = defs,
   265     operations = operations };
   266 fun map_class_data f (ClassData { consts, base_sort, inst, morphism, intro,
   267     defs, operations }) =
   268   mk_class_data (f ((consts, base_sort, inst, morphism, intro),
   269     (defs, operations)));
       (*merge keeps the fixed fields of the first argument and merges defs and
         operations*)
   270 fun merge_class_data _ (ClassData { consts = consts,
   271     base_sort = base_sort, inst = inst, morphism = morphism, intro = intro,
   272     defs = defs1, operations = operations1 },
   273   ClassData { consts = _, base_sort = _, inst = _, morphism = _, intro = _,
   274     defs = defs2, operations = operations2 }) =
   275   mk_class_data ((consts, base_sort, inst, morphism, intro),
   276     (Thm.merge_thms (defs1, defs2),
   277       AList.merge (op =) (K true) (operations1, operations2)));
   278 
       (*theory data: graph of classes with class data; add_class_data adds
         edges from a class to its superclasses*)
   279 structure ClassData = TheoryDataFun
   280 (
   281   type T = class_data Graph.T
   282   val empty = Graph.empty;
   283   val copy = I;
   284   val extend = I;
   285   fun merge _ = Graph.join merge_class_data;
   286 );
   287 
   288 
   289 (* queries *)
   290 
   291 val lookup_class_data = Option.map rep_class_data oo try o Graph.get_node o ClassData.get;
   292 
   293 fun the_class_data thy class = case lookup_class_data thy class
   294  of NONE => error ("Undeclared class " ^ quote class)
   295   | SOME data => data;
   296 
   297 val is_class = is_some oo lookup_class_data;
   298 
       (*the given classes together with their transitive superclasses in the
         class graph*)
   299 val ancestry = Graph.all_succs o ClassData.get;
   300 
       (*(locale parameter, (constant name, constant type)) for every class in
         the ancestry of sort; types come from the axclass parameter info*)
   301 fun these_params thy =
   302   let
   303     fun params class =
   304       let
   305         val const_typs = (#params o AxClass.get_info thy) class;
   306         val const_names = (#consts o the_class_data thy) class;
   307       in
   308         (map o apsnd) (fn c => (c, (the o AList.lookup (op =) const_typs) c)) const_names
   309       end;
   310   in maps params o ancestry thy end;
   311 
       (*operation definitions of the ancestry; nodes without data contribute none*)
   312 fun these_defs thy = maps (these o Option.map #defs o lookup_class_data thy) o ancestry thy;
   313 
   314 fun morphism thy = #morphism o the_class_data thy;
   315 
       (*introduction rules of all classes in the graph, without duplicates*)
   316 fun these_intros thy =
   317   Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o rep_class_data) data))
   318     (ClassData.get thy) [];
   319 
       (*registered operations of all classes in the ancestry of sort*)
   320 fun these_operations thy =
   321   maps (#operations o the_class_data thy) o ancestry thy;
   322 
       (* Print one entry per class of the theory's class algebra, containing:
          - the class name and its minimized supersort;
          - the locale, if the class has class data;
          - the parameters, if axclass info is available;
          - the type constructors with an arity for the class. *)
   323 fun print_classes thy =
   324   let
   325     val ctxt = ProofContext.init thy;
   326     val algebra = Sign.classes_of thy;
           (*invert the arity table: class ~> type constructors*)
   327     val arities =
   328       Symtab.empty
   329       |> Symtab.fold (fn (tyco, arities) => fold (fn (class, _) =>
   330            Symtab.map_default (class, []) (insert (op =) tyco)) arities)
   331              ((#arities o Sorts.rep_algebra) algebra);
   332     val the_arities = these o Symtab.lookup arities;
   333     fun mk_arity class tyco =
   334       let
   335         val Ss = Sorts.mg_domain algebra tyco [class];
   336       in Syntax.pretty_arity ctxt (tyco, Ss, [class]) end;
   337     fun mk_param (c, ty) = Pretty.str (Sign.extern_const thy c ^ " :: "
   338       ^ setmp show_sorts false (Syntax.string_of_typ ctxt o Type.strip_sorts) ty);
   339     fun mk_entry class = (Pretty.block o Pretty.fbreaks o map_filter I) [
   340       (SOME o Pretty.str) ("class " ^ Sign.extern_class thy class ^ ":"),
   341       (SOME o Pretty.block) [Pretty.str "supersort: ",
   342         (Syntax.pretty_sort ctxt o Sign.minimize_sort thy o Sign.super_classes thy) class],
   343       if is_class thy class then (SOME o Pretty.str)
   344         ("locale: " ^ Locale.extern thy class) else NONE,
   345       ((fn [] => NONE | ps => (SOME o Pretty.block o Pretty.fbreaks)
   346           (Pretty.str "parameters:" :: ps)) o map mk_param
   347         o these o Option.map #params o try (AxClass.get_info thy)) class,
   348       (SOME o Pretty.block o Pretty.breaks) [
   349         Pretty.str "instances:",
   350         Pretty.list "" "" (map (mk_arity class) (the_arities class))
   351       ]
   352     ]
   353   in
   354     (Pretty.writeln o Pretty.chunks o separate (Pretty.str "")
   355       o map mk_entry o Sorts.all_classes) algebra
   356   end;
   357 
   358 
   359 (* updaters *)
   360 
       (* Add class as a new node of the class graph, with edges to its
          superclasses.  cs pairs each locale parameter (name, type) with its
          constant (name, _).  Each parameter initially becomes an operation
          c ~> (class, (type, Free parameter)).  inst constants are stored as
          SOME (Const _).  defs start empty. *)
   361 fun add_class_data ((class, superclasses), (cs, base_sort, inst, phi, intro)) thy =
   362   let
   363     val operations = map (fn (v_ty as (_, ty), (c, _)) =>
   364       (c, (class, (ty, Free v_ty)))) cs;
   365     val cs = (map o pairself) fst cs;
   366     val add_class = Graph.new_node (class,
   367         mk_class_data ((cs, base_sort, map (SOME o Const) inst, phi, intro), ([], operations)))
   368       #> fold (curry Graph.add_edge class) superclasses;
   369   in
   370     ClassData.map add_class thy
   371   end;
   372 
   373 fun register_operation class (c, (t, some_def)) thy =
   374   let
   375     val base_sort = (#base_sort o the_class_data thy) class;
   376     val prep_typ = map_atyps
   377       (fn TVar (vi as (v, _), sort) => if Name.aT = v
   378         then TFree (v, base_sort) else TVar (vi, sort));
   379     val t' = map_types prep_typ t;
   380     val ty' = Term.fastype_of t';
   381   in
   382     thy
   383     |> (ClassData.map o Graph.map_node class o map_class_data o apsnd)
   384       (fn (defs, operations) =>
   385         (fold cons (the_list some_def) defs,
   386           (c, (class, (ty', t'))) :: operations))
   387   end;
   388 
   389 
   390 (** rule calculation, tactics and methods **)
   391 
   392 val class_prefix = Logic.const_of_class o Sign.base_name;
   393 
   394 fun calculate_morphism class cs =
   395   let
   396     val subst_typ = Term.map_type_tfree (fn var as (v, sort) =>
   397       if v = Name.aT then TVar ((v, 0), [class]) else TVar ((v, 0), sort));
   398     fun subst_aterm (t as Free (v, ty)) = (case AList.lookup (op =) cs v
   399          of SOME (c, _) => Const (c, ty)
   400           | NONE => t)
   401       | subst_aterm t = t;
   402     val subst_term = map_aterms subst_aterm #> map_types subst_typ;
   403   in
   404     Morphism.term_morphism subst_term
   405     $> Morphism.typ_morphism subst_typ
   406   end;
   407 
       (* Introduction rule for class, built in three steps.
          1. The locale intro rules of class, applied to the axioms of the
             superclasses sups, are resolved into the last premise of the
             axclass intro rule.
          2. The result is instantiated to ?'a :: superclasses of class, and
             its OFCLASS premises are stripped.
          3. Superclass definitions are folded in by rewriting. *)
   408 fun class_intro thy class sups =
   409   let
           (*NOTE(review): no case for more than one axiom (would raise Match)*)
   410     fun class_elim class =
   411       case (#axioms o AxClass.get_info thy) class
   412        of [thm] => SOME (Drule.unconstrainTs thm)
   413         | [] => NONE;
   414     val pred_intro = case Locale.intros thy class
   415      of ([ax_intro], [intro]) => intro |> OF_LAST ax_intro |> SOME
   416       | ([intro], []) => SOME intro
   417       | ([], [intro]) => SOME intro
   418       | _ => NONE;
   419     val pred_intro' = pred_intro
   420       |> Option.map (fn intro => intro OF map_filter class_elim sups);
   421     val class_intro = (#intro o AxClass.get_info thy) class;
   422     val raw_intro = case pred_intro'
   423      of SOME pred_intro => class_intro |> OF_LAST pred_intro
   424       | NONE => class_intro;
   425     val sort = Sign.super_classes thy class;
   426     val typ = TVar ((Name.aT, 0), sort);
   427     val defs = these_defs thy sups;
   428   in
   429     raw_intro
   430     |> Drule.instantiate' [SOME (Thm.ctyp_of thy typ)] []
   431     |> strip_all_ofclass thy sort
   432     |> Thm.strip_shyps
   433     |> MetaSimplifier.rewrite_rule defs
   434     |> Drule.unconstrainTs
   435   end;
   436 
       (* Interpret the class locale with its canonical instantiation (inst)
          and the given defs, proving all goals from facts.  The global
          constraints of the class parameters are removed during the
          interpretation and re-installed afterwards. *)
   437 fun class_interpretation class facts defs thy =
   438   let
   439     val params = these_params thy [class];
   440     val inst = (#inst o the_class_data thy) class;
   441     val tac = ALLGOALS (ProofContext.fact_tac facts);
   442     val prfx = class_prefix class;
   443   in
   444     thy
   445     |> fold_map (get_remove_global_constraint o fst o snd) params
   446     ||> prove_interpretation tac ((false, prfx), []) (Locale.Locale class)
   447           (inst, map (fn def => (("", []), def)) defs)
   448     |-> (fn cs => fold (Sign.add_const_constraint o apsnd SOME) cs)
   449   end;
   450 
       (* Back-chain with the class_triv rules of all classes, the registered
          class intro rules and the axclass intro rules, using facts. *)
   451 fun intro_classes_tac facts st =
   452   let
   453     val thy = Thm.theory_of_thm st;
   454     val classes = Sign.all_classes thy;
   455     val class_trivs = map (Thm.class_triv thy) classes;
   456     val class_intros = these_intros thy;
   457     val axclass_intros = map_filter (try (#intro o AxClass.get_info thy)) classes;
   458   in
   459     Method.intros_tac (class_trivs @ class_intros @ axclass_intros) facts st
   460   end;
   461 
   462 fun default_intro_classes_tac [] = intro_classes_tac []
   463   | default_intro_classes_tac _ = no_tac;
   464 
   465 fun default_tac rules ctxt facts =
   466   HEADGOAL (Method.some_rule_tac rules ctxt facts) ORELSE
   467     default_intro_classes_tac facts;
   468 
       (*register the proof methods "intro_classes" and "default"*)
   469 val _ = Context.add_setup (Method.add_methods
   470  [("intro_classes", Method.no_args (Method.METHOD intro_classes_tac),
   471     "back-chain introduction rules of classes"),
   472   ("default", Method.thms_ctxt_args (Method.METHOD oo default_tac),
   473     "apply some intro/elim rule")]);
   474 
       (* Prove the assumptions of locale sup (as object-logic propositions) in
          the context of locale sub, then export them to the plain theory
          context. *)
   475 fun subclass_rule thy (sub, sup) =
   476   let
   477     val ctxt = Locale.init sub thy;
   478     val ctxt_thy = ProofContext.init thy;
   479     val props =
   480       Locale.global_asms_of thy sup
   481       |> maps snd
   482       |> map (ObjectLogic.ensure_propT thy);
   483     fun tac { prems, context } =
   484       Locale.intro_locales_tac true context prems
   485         ORELSE ALLGOALS assume_tac;
   486   in
   487     Goal.prove_multi ctxt [] [] props tac
   488     |> map (Assumption.export false ctxt ctxt_thy)
   489     |> Variable.export ctxt ctxt_thy
   490   end;
   491 
       (* Establish sub < sup in three ways:
          - as an axclass relation: sup's intro rule, instantiated to
            ?'a :: sub, is discharged with the exported conjunction of thms
            (proved in ctxt) applied to sub's axioms;
          - as a locale interpretation of sup in sub;
          - as an edge in the class graph. *)
   492 fun prove_single_subclass (sub, sup) thms ctxt thy =
   493   let
   494     val ctxt_thy = ProofContext.init thy;
   495     val subclass_rule = Conjunction.intr_balanced thms
   496       |> Assumption.export false ctxt ctxt_thy
   497       |> singleton (Variable.export ctxt ctxt_thy);
   498     val sub_inst = Thm.ctyp_of thy (TVar ((Name.aT, 0), [sub]));
   499     val sub_ax = #axioms (AxClass.get_info thy sub);
   500     val classrel =
   501       #intro (AxClass.get_info thy sup)
   502       |> Drule.instantiate' [SOME sub_inst] []
   503       |> OF_LAST (subclass_rule OF sub_ax)
   504       |> strip_all_ofclass thy (Sign.super_classes thy sup)
   505       |> Thm.strip_shyps
   506   in
   507     thy
   508     |> AxClass.add_classrel classrel
   509     |> prove_interpretation_in (ALLGOALS (ProofContext.fact_tac thms))
   510          I (sub, Locale.Locale sup)
   511     |> ClassData.map (Graph.add_edge (sub, sup))
   512   end;
   513 
       (* Prove sub < sup' for sup and each of its ancestors sup' that is not
          yet reachable from sub in the class graph.  thms are transported to
          each sup' via subclass_rule computed in the initial theory. *)
   514 fun prove_subclass (sub, sup) thms ctxt thy =
   515   let
   516     val classes = ClassData.get thy;
   517     val is_sup = not o null o curry (Graph.irreducible_paths classes) sub;
   518     val supclasses = Graph.all_succs classes [sup] |> filter_out is_sup;
   519     fun transform sup' = subclass_rule thy (sup, sup') |> map (fn thm => thm OF thms);
   520   in
   521     thy
   522     |> fold_rev (fn sup' => prove_single_subclass (sub, sup')
   523          (transform sup') ctxt) supclasses
   524  end;
   525 
   526 
   527 (** classes and class target **)
   528 
   529 (* class context syntax *)
   530 
       (* Proof data for class-context syntax, filled by synchronize_syntax:
          - local constraints use the base sort, global ones the defining class;
          - operations map an operation constant to (type, term);
          - unchecks are (term, constant) rewrite pairs;
          - passed records whether sort_term_check has already installed the
            global constraints. *)
   531 structure ClassSyntax = ProofDataFun(
   532   type T = {
   533     local_constraints: (string * typ) list,
   534     global_constraints: (string * typ) list,
   535     base_sort: sort,
   536     operations: (string * (typ * term)) list,
   537     unchecks: (term * term) list,
   538     passed: bool
   539   };
         (*NOTE(review): the doubled ';' after init is redundant but harmless*)
   540   fun init _ = {
   541     local_constraints = [],
   542     global_constraints = [],
   543     base_sort = [],
   544     operations = [],
   545     unchecks = [],
   546     passed = true
   547   };;
   548 );
   549 
       (* Install class syntax for the operations of sups (and their ancestors)
          in ctxt:
          - each operation's base name is declared as an alias when it interns
            to the operation constant;
          - local constraints (class type as ?'a :: base_sort) are added;
          - ClassSyntax data is stored with passed = false. *)
   550 fun synchronize_syntax sups base_sort ctxt =
   551   let
   552     val thy = ProofContext.theory_of ctxt;
           (*replaces every free type variable by ?'a of the given sort*)
   553     fun subst_class_typ sort = map_atyps
   554       (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
   555     val operations = these_operations thy sups;
   556     val local_constraints =
   557       (map o apsnd) (subst_class_typ base_sort o fst o snd) operations;
   558     val global_constraints =
   559       (map o apsnd) (fn (class, (ty, _)) => subst_class_typ [class] ty) operations;
   560     fun declare_const (c, _) =
   561       let val b = Sign.base_name c
   562       in Sign.intern_const thy b = c ? Variable.declare_const (b, c) end;
           (*rewrite operation terms back to their constants when unchecking*)
   563     val unchecks = map (fn (c, (_, (ty, t))) => (t, Const (c, ty))) operations;
   564   in
   565     ctxt
   566     |> fold declare_const local_constraints
   567     |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
   568     |> ClassSyntax.put {
   569         local_constraints = local_constraints,
   570         global_constraints = global_constraints,
   571         base_sort = base_sort,
   572         operations = (map o apsnd) snd operations,
   573         unchecks = unchecks,
   574         passed = false
   575       }
   576   end;
   577 
   578 fun refresh_syntax class ctxt =
   579   let
   580     val thy = ProofContext.theory_of ctxt;
   581     val base_sort = (#base_sort o the_class_data thy) class;
   582   in synchronize_syntax [class] base_sort ctxt end;
   583 
       (*set the passed flag of the class syntax data, keeping all other fields*)
   584 val mark_passed = ClassSyntax.map
   585   (fn { local_constraints, global_constraints, base_sort, operations, unchecks, passed } =>
   586     { local_constraints = local_constraints, global_constraints = global_constraints,
   587       base_sort = base_sort, operations = operations, unchecks = unchecks, passed = true });
   588 
       (* Term check for class contexts.
          (1) Type improvement: for a constant with a local constraint, if
              matching the constraint against its type binds ?'a to a
              type-inference parameter, that parameter is replaced by
              'a :: base_sort everywhere.
          (2) Operation expansion: constants from operations whose type is an
              instance of the recorded type are expanded to the recorded term,
              which is itself checked recursively.
          Returns NONE if nothing changed and the context is already marked
          passed.  Otherwise it installs the global constraints, marks the
          context passed and returns the new terms. *)
   589 fun sort_term_check ts ctxt =
   590   let
   591     val { local_constraints, global_constraints, base_sort, operations, passed, ... } =
   592       ClassSyntax.get ctxt;
   593     fun check_improve (Const (c, ty)) = (case AList.lookup (op =) local_constraints c
   594          of SOME ty0 => (case try (Type.raw_match (ty0, ty)) Vartab.empty
   595              of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0)
   596                  of SOME (_, TVar (tvar as (vi, _))) =>
   597                       if TypeInfer.is_param vi then cons tvar else I
   598                   | _ => I)
   599               | NONE => I)
   600           | NONE => I)
   601       | check_improve _ = I;
   602     val improvements = (fold o fold_aterms) check_improve ts [];
   603     val ts' = (map o map_types o map_atyps) (fn ty as TVar tvar =>
   604         if member (op =) improvements tvar
   605           then TFree (Name.aT, base_sort) else ty | ty => ty) ts;
   606     fun check t0 = Envir.expand_term (fn Const (c, ty) => (case AList.lookup (op =) operations c
   607          of SOME (ty0, t) =>
   608               if Type.typ_instance (ProofContext.tsig_of ctxt) (ty, ty0)
   609               then SOME (ty0, check t) else NONE
   610           | NONE => NONE)
   611       | _ => NONE) t0;
   612     val ts'' = map check ts';
   613   in if eq_list (op aconv) (ts, ts'') andalso passed then NONE
   614   else
   615     ctxt
   616     |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints
   617     |> mark_passed
   618     |> pair ts''
   619     |> SOME
   620   end;
   621 
       (* Term uncheck for class contexts: rewrite terms with the recorded
          unchecks (operation term ~> constant); NONE if nothing changed. *)
   622 fun sort_term_uncheck ts ctxt =
   623   let
   624     val thy = ProofContext.theory_of ctxt;
   625     val unchecks = (#unchecks o ClassSyntax.get) ctxt;
   626     val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
   627   in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
   628 
       (* Prepare ctxt as a class context:
          - declare the type 'a :: base_sort;
          - install the class syntax of sups;
          - register sort_term_check / sort_term_uncheck under the name
            "class" (with argument 0). *)
   629 fun init_ctxt sups base_sort ctxt =
   630   ctxt
   631   |> Variable.declare_term
   632       (Logic.mk_type (TFree (Name.aT, base_sort)))
   633   |> synchronize_syntax sups base_sort
   634   |> Context.proof_map (
   635       Syntax.add_term_check 0 "class" sort_term_check
   636       #> Syntax.add_term_uncheck 0 "class" sort_term_uncheck)
   637 
       (*context of class: its locale context plus class syntax for its base sort*)
   638 fun init class thy =
   639   thy
   640   |> Locale.init class
   641   |> init_ctxt [class] ((#base_sort o the_class_data thy) class);
   642 
   643 
   644 (* class definition *)
   645 
   646 local
   647 
   648 fun gen_class_spec prep_class prep_expr process_expr thy raw_supclasses raw_includes_elems =
   649   let
   650     val supclasses = map (prep_class thy) raw_supclasses;
   651     val sups = filter (is_class thy) supclasses;
   652     fun the_base_sort class = lookup_class_data thy class
   653       |> Option.map #base_sort
   654       |> the_default [class];
   655     val base_sort = Sign.minimize_sort thy (maps the_base_sort supclasses);
   656     val supsort = Sign.minimize_sort thy supclasses;
   657     val suplocales = map Locale.Locale sups;
   658     val (raw_elems, includes) = fold_rev (fn Locale.Elem e => apfst (cons e)
   659       | Locale.Expr i => apsnd (cons (prep_expr thy i))) raw_includes_elems ([], []);
   660     val supexpr = Locale.Merge suplocales;
   661     val supparams = (map fst o Locale.parameters_of_expr thy) supexpr;
   662     val supconsts = AList.make (the o AList.lookup (op =) (these_params thy sups))
   663       (map fst supparams);
   664     val mergeexpr = Locale.Merge (suplocales @ includes);
   665     val constrain = Element.Constrains ((map o apsnd o map_atyps)
   666       (fn TFree (_, sort) => TFree (Name.aT, sort)) supparams);
   667   in
   668     ProofContext.init thy
   669     |> Locale.cert_expr supexpr [constrain]
   670     |> snd
   671     |> init_ctxt sups base_sort
   672     |> process_expr Locale.empty raw_elems
   673     |> fst
   674     |> (fn elems => ((((sups, supconsts), (supsort, base_sort, mergeexpr)),
   675           (*FIXME*) if null includes then constrain :: elems else elems)))
   676   end;
   677 
   678 val read_class_spec = gen_class_spec Sign.intern_class Locale.intern_expr Locale.read_expr;
   679 val check_class_spec = gen_class_spec (K I) (K I) Locale.cert_expr;
   680 
   681 fun define_class_params (name, raw_superclasses) raw_consts raw_dep_axioms other_consts thy =
   682   let
   683     val superclasses = map (Sign.certify_class thy) raw_superclasses;
   684     val consts = (map o apfst o apsnd) (Sign.certify_typ thy) raw_consts;
   685     fun add_const ((c, ty), syn) =
   686       Sign.declare_const [] (c, Type.strip_sorts ty, syn) #>> Term.dest_Const;
   687     fun mk_axioms cs thy =
   688       raw_dep_axioms thy cs
   689       |> (map o apsnd o map) (Sign.cert_prop thy)
   690       |> rpair thy;
   691     fun constrain_typs class = (map o apsnd o Term.map_type_tfree)
   692       (fn (v, _) => TFree (v, [class]))
   693   in
   694     thy
   695     |> Sign.add_path (Logic.const_of_class name)
   696     |> fold_map add_const consts
   697     ||> Sign.restore_naming thy
   698     |-> (fn cs => mk_axioms cs
   699     #-> (fn axioms_prop => AxClass.define_class (name, superclasses)
   700            (map fst cs @ other_consts) axioms_prop
   701     #-> (fn class => `(fn _ => constrain_typs class cs)
   702     #-> (fn cs' => `(fn thy => AxClass.get_info thy class)
   703     #-> (fn {axioms, ...} => fold (Sign.add_const_constraint o apsnd SOME) cs'
   704     #> pair (class, (cs', axioms)))))))
   705   end;
   706 
(*Common implementation of class and class_cmd.  prep_spec elaborates the
  superclasses and locale elements, prep_param resolves each name in
  raw_other_consts.  The pipeline below adds the locale bname, defines the
  class parameters via define_class_params, notes <class>.<classN>.<introN>
  (the class introduction rule), stores the class data, sets up the class
  interpretation and finally returns the full class name paired with the
  context produced by init.*)
fun gen_class prep_spec prep_param bname
    raw_supclasses raw_includes_elems raw_other_consts thy =
  let
    val class = Sign.full_name thy bname;
    val (((sups, supconsts), (supsort, base_sort, mergeexpr)), elems_syn) =
      prep_spec thy raw_supclasses raw_includes_elems;
    (*the result of Sign.the_const_type is discarded (tap); presumably it only
      serves as a check that each name denotes a declared constant*)
    val other_consts = map (tap (Sign.the_const_type thy) o prep_param thy) raw_other_consts;
    (*give every type variable in the types of cs the sort [class]*)
    fun mk_inst class cs =
      (map o apsnd o Term.map_type_tfree) (fn (v, _) => TFree (v, [class])) cs;
    (*strip mixfix syntax from Fixes elements (replacing it by NoSyn) while
      accumulating (name, syntax) pairs; other elements pass unchanged*)
    fun fork_syntax (Element.Fixes xs) =
          fold_map (fn (c, ty, syn) => cons (c, syn) #> pair (c, ty, NoSyn)) xs
          #>> Element.Fixes
      | fork_syntax x = pair x;
    val (elems, global_syn) = fold_map fork_syntax elems_syn [];
    (*global form of a parameter: all type variables replaced by
      TFree (Name.aT, base_sort); mixfix looked up in global_syn*)
    fun globalize (c, ty) =
      ((c, Term.map_type_tfree (K (TFree (Name.aT, base_sort))) ty),
        (the_default NoSyn o AList.lookup (op =) global_syn) c);
    (*all locale parameters, plus the globalized ones remaining after
      dropping the first (length supconsts) -- presumably the parameters
      inherited from the superclasses*)
    fun extract_params thy =
      let
        val params = map fst (Locale.parameters_of thy class);
      in
        (params, (map globalize o snd o chop (length supconsts)) params)
      end;
    (*locale assumptions of class with every Free parameter replaced by the
      constant it is associated with in supconsts or (params ~~ cs)*)
    fun extract_assumes params thy cs =
      let
        val consts = supconsts @ (map (fst o fst) params ~~ cs);
        fun subst (Free (c, ty)) =
              Const ((fst o the o AList.lookup (op =) consts) c, ty)
          | subst t = t;
        fun prep_asm ((name, atts), ts) =
          ((Sign.base_name name, map (Attrib.attribute_i thy) atts),
            (map o map_aterms) subst ts);
      in
        Locale.global_asms_of thy class
        |> map prep_asm
      end;
  in
    thy
    (*1. the locale itself, built from the syntax-stripped elements*)
    |> Locale.add_locale_i (SOME "") bname mergeexpr elems
    |> snd
    |> ProofContext.theory_of
    (*2. global class parameters and axioms*)
    |> `extract_params
    |-> (fn (all_params, params) =>
        define_class_params (bname, supsort) params
          (extract_assumes params) other_consts
    (*3. introduction rule, noted under the class name space*)
      #-> (fn (_, (consts, axioms)) =>
        `(fn thy => class_intro thy class sups)
      #-> (fn class_intro =>
        PureThy.note_thmss_qualified "" (NameSpace.append class classN)
          [((introN, []), [([class_intro], [])])]
    (*4. class data and interpretation*)
      #-> (fn [(_, [class_intro])] =>
        add_class_data ((class, sups),
          (map fst params ~~ consts, base_sort,
            mk_inst class (map snd supconsts @ consts),
              calculate_morphism class (supconsts @ (map (fst o fst) params ~~ consts)), class_intro))
      #> class_interpretation class axioms []
      ))))
    |> init class
    |> pair class
  end;
   767 
   768 fun read_const thy = #1 o Term.dest_Const o ProofContext.read_const (ProofContext.init thy);
   769 
   770 in
   771 
(*outer-syntax entry point: specification and constant names are read from text*)
val class_cmd = gen_class read_class_spec read_const;
(*internal entry point: specification is checked, constant names are used as given (K I)*)
val class = gen_class check_class_spec (K I);
   774 
   775 end; (*local*)
   776 
   777 
   778 (* class target *)
   779 
(*Declare c, inside the name-space prefix of class, as a new constant
  defined by dict transported along the class morphism.  The definitional
  theorem is turned around (Thm.symmetric) and added to the class
  interpretation.  c is registered as a class operation together with dict'
  and that theorem.  After restoring the naming, the sort-carrying type ty'
  is installed as constraint for c.*)
fun logical_const class pos ((c, mx), dict) thy =
  let
    val prfx = class_prefix class;
    val thy' = thy |> Sign.add_path prfx;
    val phi = morphism thy' class;

    val c' = Sign.full_name thy' c;
    val dict' = Morphism.term phi dict;
    (*definition body: dict' with Logic.unvarifyT applied to all its types*)
    val dict_def = map_types Logic.unvarifyT dict';
    val ty' = Term.fastype_of dict_def;
    (*c is declared at the sort-stripped type; ty' becomes its constraint below*)
    val ty'' = Type.strip_sorts ty';
    val def_eq = Logic.mk_equals (Const (c', ty'), dict_def);
  in
    thy'
    |> Sign.declare_const pos (c, ty'', mx) |> snd
    |> Thm.add_def false false (c, def_eq)
    |>> Thm.symmetric
    |-> (fn def => class_interpretation class [def] [Thm.prop_of def]
          #> register_operation class (c', (dict', SOME (Thm.varifyT def))))
    |> Sign.restore_naming thy
    |> Sign.add_const_constraint (c', SOME ty')
  end;
   802 
(*Introduce c, inside the name-space prefix of class, as an abbreviation
  (mode taken from prmode) for rhs.  rhs is transported along the class
  morphism and rewritten with the equations of these_defs thy' [class].
  c gets the type of the result as constraint and mixfix mx as notation.
  It is registered as a class operation without a definition theorem
  (NONE).*)
fun syntactic_const class prmode pos ((c, mx), rhs) thy =
  let
    val prfx = class_prefix class;
    val thy' = thy |> Sign.add_path prfx;
    (*NOTE(review): logical_const takes the morphism from thy', this one from
      thy; the two differ only by Sign.add_path, so presumably equivalent --
      confirm*)
    val phi = morphism thy class;

    val c' = Sign.full_name thy' c;
    (*definitional equations of the class, used as rewrite rules*)
    val rews = map (Logic.dest_equals o Thm.prop_of) (these_defs thy' [class])
    val rhs' = (Pattern.rewrite_term thy rews [] o Morphism.term phi) rhs;
    val ty' = Logic.unvarifyT (Term.fastype_of rhs');
  in
    thy'
    |> Sign.add_abbrev (#1 prmode) pos (c, map_types Type.strip_sorts rhs') |> snd
    |> Sign.add_const_constraint (c', SOME ty')
    |> Sign.notation true prmode [(Const (c', ty'), mx)]
    |> register_operation class (c', (rhs', NONE))
    |> Sign.restore_naming thy
  end;
   821 
   822 
   823 (** instantiation target **)
   824 
   825 (* bookkeeping *)
   826 
(*State of an instantiation target, as set up by init_instantiation:
    arities -- type constructors, their argument sorts and the target sort;
    params  -- ((constant name as returned by unoverload_const, type constructor),
                (local parameter name, instantiated type)),
               for class parameters that had no instance for that type
               constructor when the target was set up*)
datatype instantiation = Instantiation of {
  arities: string list * sort list * sort,
  params: ((string * string) * (string * typ)) list
}
   831 
(*Instantiation state kept as proof-context data.  The initial value has
  empty arities, which the_instantiation rejects with
  "No instantiation target".*)
structure Instantiation = ProofDataFun
(
  type T = instantiation
  fun init _ = Instantiation { arities = ([], [], []), params = [] };
);
   837 
   838 fun mk_instantiation (arities, params) =
   839   Instantiation { arities = arities, params = params };
   840 fun get_instantiation lthy = case Instantiation.get (LocalTheory.target_of lthy)
   841  of Instantiation data => data;
   842 fun map_instantiation f = (LocalTheory.target o Instantiation.map)
   843   (fn Instantiation { arities, params } => mk_instantiation (f (arities, params)));
   844 
   845 fun the_instantiation lthy = case get_instantiation lthy
   846  of { arities = ([], [], []), ... } => error "No instantiation target"
   847   | data => data;
   848 
   849 val instantiation_params = #params o get_instantiation;
   850 
   851 fun instantiation_param lthy v = instantiation_params lthy
   852   |> find_first (fn (_, (v', _)) => v = v')
   853   |> Option.map (fst o fst);
   854 
   855 fun confirm_declaration c = (map_instantiation o apsnd)
   856   (filter_out (fn (_, (c', _)) => c' = c));
   857 
   858 
   859 (* syntax *)
   860 
   861 fun subst_param thy params = map_aterms (fn t as Const (c, ty) => (case inst_tyco thy (c, ty)
   862      of SOME tyco => (case AList.lookup (op =) params (c, tyco)
   863          of SOME v_ty => Free v_ty
   864           | NONE => t)
   865       | NONE => t)
   866   | t => t);
   867 
   868 fun prep_spec lthy =
   869   let
   870     val thy = ProofContext.theory_of lthy;
   871     val params = instantiation_params lthy;
   872   in subst_param thy params end;
   873 
   874 fun inst_term_check ts lthy =
   875   let
   876     val params = instantiation_params lthy;
   877     val tsig = ProofContext.tsig_of lthy;
   878     val thy = ProofContext.theory_of lthy;
   879 
   880     fun check_improve (Const (c, ty)) = (case inst_tyco thy (c, ty)
   881          of SOME tyco => (case AList.lookup (op =) params (c, tyco)
   882              of SOME (_, ty') => perhaps (try (Type.typ_match tsig (ty, ty')))
   883               | NONE => I)
   884           | NONE => I)
   885       | check_improve _ = I;
   886     val improvement = (fold o fold_aterms) check_improve ts Vartab.empty;
   887     val ts' = (map o map_types) (Envir.typ_subst_TVars improvement) ts;
   888     val ts'' = map (subst_param thy params) ts';
   889   in if eq_list (op aconv) (ts, ts'') then NONE else SOME (ts'', lthy) end;
   890 
   891 fun inst_term_uncheck ts lthy =
   892   let
   893     val params = instantiation_params lthy;
   894     val ts' = (map o map_aterms) (fn t as Free (v, ty) =>
   895        (case get_first (fn ((c, _), (v', _)) => if v = v' then SOME c else NONE) params
   896          of SOME c => Const (c, ty)
   897           | NONE => t)
   898       | t => t) ts;
   899   in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
   900 
   901 
   902 (* target *)
   903 
(*Turn a type constructor base name into a string usable inside a parameter
  name.  Regular symbols other than ASCII letters, digits and the
  apostrophe are skipped as junk.  If the first kept symbol is not an ASCII
  letter, "x" is put in front.  Input the combined scanner does not accept
  presumably makes Symbol.scanner fail with "Malformed input" -- confirm.*)
val sanatize_name = (*FIXME*)
  let
    fun is_valid s = Symbol.is_ascii_letter s orelse Symbol.is_ascii_digit s orelse s = "'";
    (*regular symbols that must not occur in the result*)
    val is_junk = not o is_valid andf Symbol.is_regular;
    val junk = Scan.many is_junk;
    (*first chunk (leading ASCII letter, defaulting to "x", plus valid
      symbols), then further chunks of valid symbols; junk after each chunk
      is dropped*)
    val scan_valids = Symbol.scanner "Malformed input"
      ((junk |--
        (Scan.optional (Scan.one Symbol.is_ascii_letter) "x" ^^ (Scan.many is_valid >> implode)
        --| junk))
      -- Scan.repeat ((Scan.many1 is_valid >> implode) --| junk) >> op ::);
  in
    explode #> scan_valids #> implode
  end;
   917 
   918 fun init_instantiation (tycos, sorts, sort) thy =
   919   let
   920     val _ = if null tycos then error "At least one arity must be given" else ();
   921     val _ = map (the_class_data thy) sort;
   922     val vs = map TFree (Name.names Name.context Name.aT sorts);
   923     fun type_name "*" = "prod"
   924       | type_name "+" = "sum"
   925       | type_name s = sanatize_name (NameSpace.base s); (*FIXME*)
   926     fun get_param tyco (param, (c, ty)) = if can (inst thy) (c, tyco)
   927       then NONE else SOME ((unoverload_const thy (c, ty), tyco),
   928         (param ^ "_" ^ type_name tyco, map_atyps (K (Type (tyco, vs))) ty));
   929     val params = map_product get_param tycos (these_params thy sort) |> map_filter I;
   930   in
   931     thy
   932     |> ProofContext.init
   933     |> Instantiation.put (mk_instantiation ((tycos, sorts, sort), params))
   934     |> fold (Variable.declare_term o Logic.mk_type) vs
   935     |> fold (fn tyco => ProofContext.add_arity (tyco, sorts, sort)) tycos
   936     |> Context.proof_map (
   937         Syntax.add_term_check 0 "instance" inst_term_check
   938         #> Syntax.add_term_uncheck 0 "instance" inst_term_uncheck)
   939   end;
   940 
   941 fun gen_instantiation_instance do_proof after_qed lthy =
   942   let
   943     val (tycos, sorts, sort) = (#arities o the_instantiation) lthy;
   944     val arities_proof = maps (fn tyco => Logic.mk_arities (tyco, sorts, sort)) tycos;
   945     fun after_qed' results =
   946       LocalTheory.theory (fold (AxClass.add_arity o Thm.varifyT) results)
   947       #> after_qed;
   948   in
   949     lthy
   950     |> do_proof after_qed' arities_proof
   951   end;
   952 
   953 val instantiation_instance = gen_instantiation_instance (fn after_qed => fn ts =>
   954   Proof.theorem_i NONE (after_qed o map the_single) (map (fn t => [(t, [])]) ts));
   955 
   956 fun prove_instantiation_instance tac = gen_instantiation_instance (fn after_qed =>
   957   fn ts => fn lthy => after_qed (map (fn t => Goal.prove lthy [] [] t
   958     (fn {context, ...} => tac context)) ts) lthy) I;
   959 
   960 fun conclude_instantiation lthy =
   961   let
   962     val { arities, params } = the_instantiation lthy;
   963     val (tycos, sorts, sort) = arities;
   964     val thy = ProofContext.theory_of lthy;
   965     (*val _ = map (fn (tyco, sorts, sort) =>
   966       if Sign.of_sort thy
   967         (Type (tyco, map TFree (Name.names Name.context Name.aT sorts)), sort)
   968       then () else error ("Missing instance proof for type " ^ quote (Sign.extern_type thy tyco)))
   969         arities; FIXME activate when old instance command is gone*)
   970     val params_of = maps (these o try (#params o AxClass.get_info thy))
   971       o Sign.complete_sort thy;
   972     val missing_params = tycos
   973       |> maps (fn tyco => params_of sort |> map (rpair tyco))
   974       |> filter_out (can (inst thy) o apfst fst);
   975     fun declare_missing ((c, ty0), tyco) thy =
   976     (*fun declare_missing ((c, tyco), (_, ty)) thy =*)
   977       let
   978         val SOME class = AxClass.class_of_param thy c;
   979         val name_inst = AxClass.instance_name (tyco, class) ^ "_inst";
   980         val c' = NameSpace.base c ^ "_" ^ NameSpace.base tyco;
   981         val vs = Name.names Name.context Name.aT (replicate (Sign.arity_number thy tyco) []);
   982         val ty = map_atyps (fn _ => Type (tyco, map TFree vs)) ty0;
   983       in
   984         thy
   985         |> Sign.sticky_prefix name_inst
   986         |> Sign.no_base_names
   987         |> Sign.declare_const [] (c', ty, NoSyn)
   988         |-> (fn const' as Const (c'', _) => Thm.add_def false true
   989               (Thm.def_name c', Logic.mk_equals (const', Const (c, ty)))
   990         #>> Thm.varifyT
   991         #-> (fn thm => add_inst (c, tyco) (c'', Thm.symmetric thm)
   992         #> primitive_note Thm.internalK (c', thm)
   993         #> snd
   994         #> Sign.restore_naming thy))
   995       end;
   996   in
   997     lthy
   998     |> LocalTheory.theory (fold declare_missing missing_params)
   999   end;
  1000 
  1001 end;