rudimentary instantiation target
authorhaftmann
Fri, 23 Nov 2007 21:09:35 +0100
changeset 25462 dad0291cb76a
parent 25461 001dfba51869
child 25463 8b9c4582795a
rudimentary instantiation target
src/Pure/Isar/ROOT.ML
src/Pure/Isar/class.ML
src/Pure/Isar/code.ML
src/Pure/Isar/instance.ML
src/Pure/Isar/isar_syn.ML
src/Pure/Isar/theory_target.ML
--- a/src/Pure/Isar/ROOT.ML	Fri Nov 23 21:09:34 2007 +0100
+++ b/src/Pure/Isar/ROOT.ML	Fri Nov 23 21:09:35 2007 +0100
@@ -44,6 +44,9 @@
 (*derived theory and proof elements*)
 use "calculation.ML";
 use "obtain.ML";
+
+(*local theories and target primitives*)
+use "local_theory.ML";
 use "locale.ML";
 use "class.ML";
 
@@ -52,12 +55,11 @@
 use "code.ML";
 
 (*local theories and specifications*)
-use "local_theory.ML";
 use "theory_target.ML";
 use "subclass.ML";
-use "instance.ML";
 use "spec_parse.ML";
 use "specification.ML";
+use "instance.ML";
 use "constdefs.ML";
 
 (*toplevel environment*)
--- a/src/Pure/Isar/class.ML	Fri Nov 23 21:09:34 2007 +0100
+++ b/src/Pure/Isar/class.ML	Fri Nov 23 21:09:35 2007 +0100
@@ -7,11 +7,7 @@
 
 signature CLASS =
 sig
-  val axclass_cmd: bstring * xstring list
-    -> ((bstring * Attrib.src list) * string list) list
-    -> theory -> class * theory
-  val classrel_cmd: xstring * xstring -> theory -> Proof.state
-
+  (*classes*)
   val class: bstring -> class list -> Element.context_i Locale.element list
     -> string list -> theory -> string * Proof.context
   val class_cmd: bstring -> xstring list -> Element.context Locale.element list
@@ -30,8 +26,29 @@
     -> theory -> theory
   val print_classes: theory -> unit
   val class_prefix: string -> string
-  val uncheck: bool ref
 
+  (*instances*)
+  val declare_overloaded: string * typ * mixfix -> theory -> term * theory
+  val define_overloaded: string -> string * term -> theory -> thm * theory
+  val unoverload: theory -> conv
+  val overload: theory -> conv
+  val unoverload_const: theory -> string * typ -> string
+  val inst_const: theory -> string * string -> string
+  val param_const: theory -> string -> (string * string) option
+  val instantiation: arity list -> theory -> local_theory
+  val proof_instantiation: (local_theory -> local_theory) -> local_theory -> Proof.state
+  val prove_instantiation: (Proof.context -> tactic) -> local_theory -> local_theory
+  val conclude_instantiation: local_theory -> local_theory
+  val end_instantiation: local_theory -> Proof.context
+  val instantiation_const: Proof.context -> string -> string option
+
+  (*old axclass layer*)
+  val axclass_cmd: bstring * xstring list
+    -> ((bstring * Attrib.src list) * string list) list
+    -> theory -> class * theory
+  val classrel_cmd: xstring * xstring -> theory -> Proof.state
+
+  (*old instance layer*)
   val instance_arity: (theory -> theory) -> arity list -> theory -> Proof.state
   val instance: arity list -> ((bstring * Attrib.src list) * term) list
     -> (thm list -> theory -> theory)
@@ -43,11 +60,6 @@
   val prove_instance: tactic -> arity list
     -> ((bstring * Attrib.src list) * term) list
     -> theory -> thm list * theory
-  val unoverload: theory -> conv
-  val overload: theory -> conv
-  val unoverload_const: theory -> string * typ -> string
-  val inst_const: theory -> string * string -> string
-  val param_const: theory -> string -> (string * string) option
 end;
 
 structure Class : CLASS =
@@ -141,7 +153,9 @@
 end; (*local*)
 
 
-(** explicit constants for overloaded definitions **)
+(** basic overloading **)
+
+(* bookkeeping *)
 
 structure InstData = TheoryDataFun
 (
@@ -156,8 +170,18 @@
       Symtab.merge (K true) (tabb1, tabb2));
 );
 
+val inst_tyco = Option.map fst o try (dest_Type o the_single) oo Sign.const_typargs;
+
+fun inst thy (c, tyco) =
+  (the o Symtab.lookup ((the o Symtab.lookup (fst (InstData.get thy))) c)) tyco;
+
+val inst_const = fst oo inst;
+
 fun inst_thms thy = (Symtab.fold (Symtab.fold (cons o snd o snd) o snd) o fst)
-    (InstData.get thy) [];
+  (InstData.get thy) [];
+
+val param_const = Symtab.lookup o snd o InstData.get;
+
 fun add_inst (c, tyco) inst = (InstData.map o apfst
       o Symtab.map_default (c, Symtab.empty)) (Symtab.update_new (tyco, inst))
   #> (InstData.map o apsnd) (Symtab.update_new (fst inst, (c, tyco)));
@@ -165,19 +189,65 @@
 fun unoverload thy = MetaSimplifier.rewrite true (inst_thms thy);
 fun overload thy = MetaSimplifier.rewrite true (map Thm.symmetric (inst_thms thy));
 
-fun inst_const thy (c, tyco) =
-  (fst o the o Symtab.lookup ((the o Symtab.lookup (fst (InstData.get thy))) c)) tyco;
 fun unoverload_const thy (c_ty as (c, _)) =
   case AxClass.class_of_param thy c
-   of SOME class => (case Sign.const_typargs thy c_ty
-       of [Type (tyco, _)] => (case Symtab.lookup
-           ((the o Symtab.lookup (fst (InstData.get thy))) c) tyco
+   of SOME class => (case inst_tyco thy c_ty
+       of SOME tyco => (case try (inst thy) (c, tyco)
              of SOME (c, _) => c
               | NONE => c)
-        | [_] => c)
+        | NONE => c)
     | NONE => c;
 
-val param_const = Symtab.lookup o snd o InstData.get;
+
+(* declaration and definition of instances of overloaded constants *)
+
+fun primitive_note kind (name, thm) =
+  PureThy.note_thmss_i kind [((name, []), [([thm], [])])]
+  #>> (fn [(_, [thm])] => thm);
+
+fun declare_overloaded (c, ty, mx) thy =
+  let
+    val SOME class = AxClass.class_of_param thy c;
+    val SOME tyco = inst_tyco thy (c, ty);
+    val name_inst = NameSpace.base class ^ "_" ^ NameSpace.base tyco ^ "_inst";
+    val c' = NameSpace.base c;
+    val ty' = Type.strip_sorts ty;
+  in
+    thy
+    |> Sign.sticky_prefix name_inst
+    |> Sign.no_base_names
+    |> Sign.notation true Syntax.mode_default [(Const (c, ty), mx)]
+    |> Sign.declare_const [] (c', ty', NoSyn)
+    |-> (fn const' as Const (c'', _) => Thm.add_def true
+          (Thm.def_name c', Logic.mk_equals (Const (c, ty'), const'))
+    #>> Thm.varifyT
+    #-> (fn thm => add_inst (c, tyco) (c'', thm)
+    #> primitive_note Thm.internalK (c', thm)
+    #> snd
+    #> Sign.restore_naming thy
+    #> pair (Const (c, ty))))
+  end;
+
+fun define_overloaded name (c, t) thy =
+  let
+    val ty = Term.fastype_of t;
+    val SOME tyco = inst_tyco thy (c, ty);
+    val (c', eq) = inst thy (c, tyco);
+    val [Type (_, tys)] = Sign.const_typargs thy (c, ty);
+    val eq' = eq
+      |> Drule.instantiate' (map (SOME o Thm.ctyp_of thy) tys) [];
+          (*FIXME proper recover_sort mechanism*)
+    val prop = Logic.mk_equals (Const (c', ty), t);
+    val name' = if name = "" then
+      Thm.def_name (NameSpace.base c ^ "_" ^ NameSpace.base tyco) else name;
+  in
+    thy
+    |> Thm.add_def false (name', prop)
+    |>> (fn thm => Thm.transitive eq' thm)
+  end;
+
+
+(* legacy *)
 
 fun add_inst_def (class, tyco) (c, ty) thy =
   let
@@ -206,15 +276,10 @@
   let
     val ((lhs as Const (c, ty), args), rhs) =
       (apfst Term.strip_comb o Logic.dest_equals) prop;
-    fun (*add_inst' def ([], (Const (c_inst, ty))) =
-          if forall (fn TFree _ => true | _ => false) (Sign.const_typargs thy (c_inst, ty))
-          then add_inst (c, tyco) (c_inst, def)
-          else add_inst_def (class, tyco) (c, ty)
-      |*) add_inst' _ t = add_inst_def (class, tyco) (c, ty);
   in
     thy
     |> PureThy.add_defs_i true [((name, prop), map (Attrib.attribute_i thy) atts)]
-    |-> (fn [def] => add_inst' def (args, rhs) #> pair def)
+    |-> (fn [def] => add_inst_def (class, tyco) (c, ty) #> pair def)
   end;
 
 
@@ -690,14 +755,11 @@
     |> SOME
   end;
 
-val uncheck = ref true;
-
 fun sort_term_uncheck ts ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
     val unchecks = (#unchecks o ClassSyntax.get) ctxt;
-    val ts' = if ! uncheck
-      then map (Pattern.rewrite_term thy unchecks []) ts else ts;
+    val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
   in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
 
 fun init_ctxt sups base_sort ctxt =
@@ -896,4 +958,174 @@
     |> Sign.restore_naming thy
   end;
 
+
+(** instantiation target **)
+
+(* bookkeeping *)
+
+datatype instantiation = Instantiation of {
+  arities: arity list,
+  params: ((string * string) * (string * typ)) list
+}
+
+structure Instantiation = ProofDataFun
+(
+  type T = instantiation
+  fun init _ = Instantiation { arities = [], params = [] };
+);
+
+fun mk_instantiation (arities, params) = Instantiation {
+    arities = arities, params = params
+  };
+fun map_instantiation f (Instantiation { arities, params }) =
+  mk_instantiation (f (arities, params));
+
+fun the_instantiation ctxt = case Instantiation.get ctxt
+ of Instantiation { arities = [], ... } => error "No instantiation target"
+  | Instantiation data => data;
+
+fun init_instantiation arities ctxt =
+  let
+    val thy = ProofContext.theory_of ctxt;
+    val _ = if null arities then error "At least one arity must be given" else ();
+    val _ = case (duplicates (op =) o map #1) arities
+     of [] => ()
+      | dupl_tycos => error ("Type constructors occur more than once in arities: "
+          ^ commas_quote dupl_tycos);
+    val ty_insts = map (fn (tyco, sorts, _) =>
+        (tyco, Type (tyco, map TFree (Name.names Name.context Name.aT sorts))))
+      arities;
+    val ty_inst = the o AList.lookup (op =) ty_insts;
+    fun type_name "*" = "prod"
+      | type_name "+" = "sum"
+      | type_name s = NameSpace.base s; (*FIXME*)
+    fun get_param tyco sorts (param, (c, ty)) =
+      ((unoverload_const thy (c, ty), tyco),
+        (param ^ "_" ^ type_name tyco,
+          map_atyps (K (ty_inst tyco)) ty));
+    fun get_params (tyco, sorts, sort) =
+      map (get_param tyco sorts) (these_params thy sort)
+    val params = maps get_params arities;
+  in
+    ctxt
+    |> Instantiation.put (mk_instantiation (arities, params))
+    |> fold (Variable.declare_term o Logic.mk_type o snd) ty_insts
+    |> fold (Variable.declare_term o Free o snd) params
+  end;
+
+val instantiation_params = #params o the_instantiation;
+
+fun instantiation_const ctxt v = instantiation_params ctxt
+  |> find_first (fn (_, (v', _)) => v = v')
+  |> Option.map (fst o fst);
+
+
+(* syntax *)
+
+fun inst_term_check ts ctxt =
+  let
+    val params = instantiation_params ctxt;
+    val tsig = ProofContext.tsig_of ctxt;
+    val thy = ProofContext.theory_of ctxt;
+
+    fun check_improve (Const (c, ty)) = (case inst_tyco thy (c, ty)
+         of SOME tyco => (case AList.lookup (op =) params (c, tyco)
+             of SOME (_, ty') => Type.typ_match tsig (ty, ty')
+              | NONE => I)
+          | NONE => I)
+      | check_improve _ = I;
+    val improvement = (fold o fold_aterms) check_improve ts Vartab.empty;
+    val ts' = (map o map_types) (Envir.typ_subst_TVars improvement) ts;
+    val ts'' = (map o map_aterms) (fn t as Const (c, ty) => (case inst_tyco thy (c, ty)
+         of SOME tyco => (case AList.lookup (op =) params (c, tyco)
+             of SOME v_ty => Free v_ty
+              | NONE => t)
+          | NONE => t)
+      | t => t) ts';
+  in if eq_list (op aconv) (ts, ts'') then NONE else SOME (ts'', ctxt) end;
+
+fun inst_term_uncheck ts ctxt =
+  let
+    val params = instantiation_params ctxt;
+    val ts' = (map o map_aterms) (fn t as Free (v, ty) =>
+       (case get_first (fn ((c, _), (v', _)) => if v = v' then SOME c else NONE) params
+         of SOME c => Const (c, ty)
+          | NONE => t)
+      | t => t) ts;
+  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
+
+
+(* target *)
+
+fun instantiation arities =
+  ProofContext.init
+  #> init_instantiation arities
+  #> fold ProofContext.add_arity arities
+  #> Context.proof_map (
+      Syntax.add_term_check 0 "instance" inst_term_check
+      #> Syntax.add_term_uncheck 0 "instance" inst_term_uncheck);
+
+fun gen_proof_instantiation do_proof after_qed lthy =
+  let
+    (*FIXME should work on fresh context but continue local theory afterwards*)
+    val ctxt = LocalTheory.target_of lthy;
+    val arities = (#arities o the_instantiation) ctxt;
+    val arities_proof = maps
+      (Logic.mk_arities o Sign.cert_arity (ProofContext.theory_of ctxt)) arities;
+    fun after_qed' results =
+      LocalTheory.theory (fold (AxClass.add_arity o Thm.varifyT) results)
+      #> after_qed;
+  in
+    lthy
+    |> do_proof after_qed' arities_proof
+  end;
+
+val proof_instantiation = gen_proof_instantiation (fn after_qed => fn ts =>
+  Proof.theorem_i NONE (after_qed o map the_single) (map (fn t => [(t, [])]) ts));
+
+fun prove_instantiation tac = gen_proof_instantiation (fn after_qed =>
+  fn ts => fn lthy => after_qed (Goal.prove_multi lthy [] [] ts
+    (fn {context, ...} => tac context)) lthy) I;
+
+fun conclude_instantiation lthy =
+  let
+    val arities = (#arities o the_instantiation) lthy;
+    val thy = ProofContext.theory_of lthy;
+    (*val _ = map (fn (tyco, sorts, sort) =>
+      if Sign.of_sort thy
+        (Type (tyco, map TFree (Name.names Name.context Name.aT sorts)), sort)
+      then () else error ("Missing instance proof for type " ^ quote (Sign.extern_type thy tyco)))
+        arities; FIXME activate when old instance command is gone*)
+    val params_of = maps (these o try (#params o AxClass.get_info thy))
+      o Sign.complete_sort thy;
+    val missing_params = arities
+      |> maps (fn (tyco, _, sort) => params_of sort |> map (rpair tyco))
+      |> filter_out (can (inst thy) o apfst fst);
+    fun declare_missing ((c, ty), tyco) thy =
+      let
+        val SOME class = AxClass.class_of_param thy c;
+        val name_inst = NameSpace.base class ^ "_" ^ NameSpace.base tyco ^ "_inst";
+        val vs = Name.names Name.context Name.aT (replicate (Sign.arity_number thy tyco) []);
+        val ty' = map_atyps (fn _ => Type (tyco, map TFree vs)) ty;
+        val c' = NameSpace.base c;
+      in
+        thy
+        |> Sign.sticky_prefix name_inst
+        |> Sign.no_base_names
+        |> Sign.declare_const [] (c', ty', NoSyn)
+        |-> (fn const' as Const (c'', _) => Thm.add_def true
+              (Thm.def_name c', Logic.mk_equals (const', Const (c, ty')))
+        #>> Thm.varifyT
+        #-> (fn thm => add_inst (c, tyco) (c'', Thm.symmetric thm)
+        #> primitive_note Thm.internalK (c', thm)
+        #> snd
+        #> Sign.restore_naming thy))
+      end;
+  in
+    lthy
+    |> LocalTheory.theory (fold declare_missing missing_params)
+  end;
+
+val end_instantiation = conclude_instantiation #> LocalTheory.target_of;
+
 end;
--- a/src/Pure/Isar/code.ML	Fri Nov 23 21:09:34 2007 +0100
+++ b/src/Pure/Isar/code.ML	Fri Nov 23 21:09:35 2007 +0100
@@ -24,6 +24,7 @@
   val del_post: thm -> theory -> theory
   val add_datatype: (string * typ) list -> theory -> theory
   val add_datatype_cmd: string list -> theory -> theory
+  val type_interpretation: (string * string list -> theory -> theory) -> theory -> theory
   val add_case: thm -> theory -> theory
   val add_undefined: string -> theory -> theory
 
@@ -537,18 +538,15 @@
 fun aggregate f [] = NONE
   | aggregate f (x::xs) = SOME (aggr_neutr f x xs);
 
-fun inter_sorts thy =
-  let
-    val algebra = Sign.classes_of thy;
-    val inters = curry (Sorts.inter_sort algebra);
-  in aggregate (map2 inters) end;
+fun inter_sorts algebra =
+  aggregate (map2 (curry (Sorts.inter_sort algebra)));
 
 fun specific_constraints thy (class, tyco) =
   let
     val vs = Name.invents Name.context "" (Sign.arity_number thy tyco);
     val classparams = (map fst o these o try (#params o AxClass.get_info thy)) class;
     val funcs = classparams
-      |> map (fn c => Class.inst_const thy (c, tyco))
+      |> map_filter (fn c => try (Class.inst_const thy) (c, tyco))
       |> map (Symtab.lookup ((the_funcs o the_exec) thy))
       |> (map o Option.map) (Susp.force o fst o snd)
       |> maps these
@@ -558,37 +556,53 @@
     val sorts = map (sorts_of o Sign.const_typargs thy o CodeUnit.head_func) funcs;
   in sorts end;
 
-fun weakest_constraints thy (class, tyco) =
+fun weakest_constraints thy algebra (class, tyco) =
   let
-    val all_superclasses = Sign.complete_sort thy [class];
-  in case inter_sorts thy (maps (fn class => specific_constraints thy (class, tyco)) all_superclasses)
+    val all_superclasses = Sorts.complete_sort algebra [class];
+  in case inter_sorts algebra (maps (fn class => specific_constraints thy (class, tyco)) all_superclasses)
    of SOME sorts => sorts
-    | NONE => Sign.arity_sorts thy tyco [class]
+    | NONE => Sorts.mg_domain algebra tyco [class]
   end;
 
-fun strongest_constraints thy (class, tyco) =
+fun strongest_constraints thy algebra (class, tyco) =
   let
-    val algebra = Sign.classes_of thy;
     val all_subclasses = class :: Graph.all_preds ((#classes o Sorts.rep_algebra) algebra) [class];
     val inst_subclasses = filter (can (Sorts.mg_domain algebra tyco) o single) all_subclasses;
-  in case inter_sorts thy (maps (fn class => specific_constraints thy (class, tyco)) inst_subclasses)
+  in case inter_sorts algebra (maps (fn class => specific_constraints thy (class, tyco)) inst_subclasses)
    of SOME sorts => sorts
     | NONE => replicate
-        (Sign.arity_number thy tyco) (Sign.minimize_sort thy (Sign.all_classes thy))
+        (Sign.arity_number thy tyco) (Sorts.minimize_sort algebra (Sorts.all_classes algebra))
+  end;
+
+fun get_algebra thy (class, tyco) =
+  let
+    val base_algebra = Sign.classes_of thy;
+  in if can (Sorts.mg_domain base_algebra tyco) [class]
+    then base_algebra
+    else let
+      val superclasses = Sorts.super_classes base_algebra class;
+      val sorts = inter_sorts base_algebra
+          (map_filter (fn class => try (Sorts.mg_domain base_algebra tyco) [class]) superclasses)
+        |> the_default (replicate (Sign.arity_number thy tyco) [])
+    in
+      base_algebra
+      |> Sorts.add_arities (Sign.pp thy) (tyco, [(class, sorts)])
+    end
   end;
 
 fun gen_classparam_typ constr thy class (c, tyco) = 
   let
+    val algebra = get_algebra thy (class, tyco);
     val cs = these (try (#params o AxClass.get_info thy) class);
-    val ty = (the o AList.lookup (op =) cs) c;
+    val SOME ty = AList.lookup (op =) cs c;
     val sort_args = Name.names (Name.declare Name.aT Name.context) Name.aT
-      (constr thy (class, tyco));
+      (constr thy algebra (class, tyco));
     val ty_inst = Type (tyco, map TFree sort_args);
   in Logic.varifyT (map_type_tfree (K ty_inst) ty) end;
 
 fun retrieve_algebra thy operational =
   Sorts.subalgebra (Sign.pp thy) operational
-    (weakest_constraints thy)
+    (weakest_constraints thy (Sign.classes_of thy))
     (Sign.classes_of thy);
 
 in
@@ -763,18 +777,22 @@
 val add_default_func_attr = Attrib.internal (fn _ => Thm.declaration_attribute
   (fn thm => Context.mapping (add_default_func thm) I));
 
+structure TypeInterpretation = InterpretationFun(type T = string * string list val eq = op =);
+val type_interpretation = TypeInterpretation.interpretation;
+
 fun add_datatype raw_cs thy =
   let
     val cs = map (fn c_ty as (_, ty) => (Class.unoverload_const thy c_ty, ty)) raw_cs;
     val (tyco, vs_cos) = CodeUnit.constrset_of_consts thy cs;
-    val purge_cs = map fst (snd vs_cos);
-    val purge_cs' = case Symtab.lookup ((the_dtyps o the_exec) thy) tyco
-     of SOME (vs, cos) => if null cos then NONE else SOME (purge_cs @ map fst cos)
+    val cs' = map fst (snd vs_cos);
+    val purge_cs = case Symtab.lookup ((the_dtyps o the_exec) thy) tyco
+     of SOME (vs, cos) => if null cos then NONE else SOME (cs' @ map fst cos)
       | NONE => NONE;
   in
     thy
-    |> map_exec_purge purge_cs' (map_dtyps (Symtab.update (tyco, vs_cos))
+    |> map_exec_purge purge_cs (map_dtyps (Symtab.update (tyco, vs_cos))
         #> map_funcs (fold (Symtab.delete_safe o fst) cs))
+    |> TypeInterpretation.data (tyco, cs')
   end;
 
 fun add_datatype_cmd raw_cs thy =
@@ -837,7 +855,8 @@
       add_attribute (name, Args.del |-- Scan.succeed (mk_attribute del)
         || Scan.succeed (mk_attribute add))
   in
-    add_del_attribute ("func", (add_func, del_func))
+    TypeInterpretation.init
+    #> add_del_attribute ("func", (add_func, del_func))
     #> add_del_attribute ("inline", (add_inline, del_inline))
     #> add_del_attribute ("post", (add_post, del_post))
   end);
--- a/src/Pure/Isar/instance.ML	Fri Nov 23 21:09:34 2007 +0100
+++ b/src/Pure/Isar/instance.ML	Fri Nov 23 21:09:35 2007 +0100
@@ -2,79 +2,74 @@
     ID:         $Id$
     Author:     Florian Haftmann, TU Muenchen
 
-User-level instantiation interface for classes.
-FIXME not operative for the moment
+A primitive instance command, based on instantiation target.
 *)
 
 signature INSTANCE =
 sig
-  val begin_instantiation: arity list -> theory -> local_theory
-  val begin_instantiation_cmd: (xstring * string list * string) list
+  val instantiate: arity list -> (local_theory -> local_theory)
+    -> (Proof.context -> tactic) -> theory -> theory
+  val instance: arity list -> ((bstring * Attrib.src list) * term) list
+    -> (thm list -> theory -> theory)
+    -> theory -> Proof.state
+  val prove_instance: tactic -> arity list -> ((bstring * Attrib.src list) * term) list
+    -> theory -> thm list * theory
+
+  val instantiation_cmd: (xstring * sort * xstring) list
     -> theory -> local_theory
-  val proof_instantiation: local_theory -> Proof.state
+  val instance_cmd: (xstring * sort * xstring) list -> ((bstring * Attrib.src list) * xstring) list
+    -> (thm list -> theory -> theory)
+    -> theory -> Proof.state
 end;
 
 structure Instance : INSTANCE =
 struct
 
-structure Instantiation = ProofDataFun
-(
-  type T = ((string * (string * sort) list) * sort) list * ((string * typ) * string) list;
-  fun init _ = ([], []);
-);
+fun instantiation_cmd raw_arities thy =
+  TheoryTarget.instantiation (map (Sign.read_arity thy) raw_arities) thy;
+
+fun instantiate arities f tac =
+  TheoryTarget.instantiation arities
+  #> f
+  #> Class.prove_instantiation tac
+  #> LocalTheory.exit
+  #> ProofContext.theory_of;
 
-local
-
-fun gen_begin_instantiation prep_arity raw_arities thy =
+fun gen_instance prep_arity prep_attr prep_term do_proof raw_arities defs after_qed thy =
   let
-    fun prep_arity' raw_arity names =
+    fun export_defs ctxt = 
+      let
+        val ctxt_thy = ProofContext.init (ProofContext.theory_of ctxt);
+      in
+        map (snd o snd)
+        #> map (Assumption.export false ctxt ctxt_thy)
+        #> Variable.export ctxt ctxt_thy
+      end;
+    fun mk_def ctxt ((name, raw_attr), raw_t) =
       let
-        val arity as (tyco, sorts, sort) = prep_arity thy raw_arity;
-        val vs = Name.invents names Name.aT (length sorts);
-        val names' = fold Name.declare vs names;
-      in (((tyco, vs ~~ sorts), sort), names') end;
-    val (arities, _) = fold_map prep_arity' raw_arities Name.context;
-    fun get_param tyco ty_subst (param, (c, ty)) =
-      ((param ^ "_" ^ NameSpace.base tyco, map_atyps (K ty_subst) ty),
-        Class.unoverload_const thy (c, ty));
-    fun get_params ((tyco, vs), sort) =
-      Class.these_params thy sort
-      |> map (get_param tyco (Type (tyco, map TFree vs)));
-    val params = maps get_params arities;
-    val ctxt =
-      ProofContext.init thy
-      |> Instantiation.put (arities, params);
-    val thy_target = TheoryTarget.begin "" ctxt;
-    val operations = {
-        pretty = LocalTheory.pretty,
-        axioms = LocalTheory.axioms,
-        abbrev = LocalTheory.abbrev,
-        define = LocalTheory.define,
-        notes = LocalTheory.notes,
-        type_syntax = LocalTheory.type_syntax,
-        term_syntax = LocalTheory.term_syntax,
-        declaration = LocalTheory.pretty,
-        reinit = LocalTheory.reinit,
-        exit = LocalTheory.exit
-      };
-  in TheoryTarget.begin "" ctxt end;
+        val attr = map (prep_attr thy) raw_attr;
+        val t = prep_term ctxt raw_t;
+      in (NONE, ((name, attr), t)) end;
+    val arities = map (prep_arity thy) raw_arities;
+  in
+    thy
+    |> TheoryTarget.instantiation arities
+    |> `(fn ctxt => map (mk_def ctxt) defs)
+    |-> (fn defs => fold_map Specification.definition defs)
+    |-> (fn defs => `(fn ctxt => export_defs ctxt defs))
+    ||> LocalTheory.exit
+    ||> ProofContext.theory_of
+    ||> TheoryTarget.instantiation arities
+    |-> (fn defs => do_proof defs (LocalTheory.theory (after_qed defs)))
+  end;
 
-in
-
-val begin_instantiation = gen_begin_instantiation Sign.cert_arity;
-val begin_instantiation_cmd = gen_begin_instantiation Sign.read_arity;
+val instance = gen_instance Sign.cert_arity (K I) (K I)
+  (fn _ => fn after_qed => Class.proof_instantiation (after_qed #> Class.conclude_instantiation));
+val instance_cmd = gen_instance Sign.read_arity Attrib.intern_src
+  (fn ctxt => Syntax.parse_prop ctxt #> Syntax.check_prop ctxt)
+  (fn _ => fn after_qed => Class.proof_instantiation (after_qed #> Class.conclude_instantiation));
+fun prove_instance tac arities defs = gen_instance Sign.cert_arity (K I) (K I)
+  (fn defs => fn after_qed => Class.prove_instantiation (K tac)
+    #> after_qed #> Class.conclude_instantiation #> ProofContext.theory_of #> pair defs) arities defs (K I);
 
 end;
-
-fun gen_proof_instantiation do_proof after_qed lthy =
-  let
-    val ctxt = LocalTheory.target_of lthy;
-    val arities = case Instantiation.get ctxt
-     of ([], _) => error "no instantiation target"
-      | (arities, _) => map (fn ((tyco, vs), sort) => (tyco, map snd vs, sort)) arities;
-    val thy = ProofContext.theory_of ctxt;
-  in (do_proof after_qed arities) thy end;
-
-val proof_instantiation = gen_proof_instantiation Class.instance_arity I;
-
-end;
--- a/src/Pure/Isar/isar_syn.ML	Fri Nov 23 21:09:34 2007 +0100
+++ b/src/Pure/Isar/isar_syn.ML	Fri Nov 23 21:09:35 2007 +0100
@@ -113,7 +113,7 @@
 val _ =
   OuterSyntax.command "typedecl" "type declaration" K.thy_decl
     (P.type_args -- P.name -- P.opt_infix >> (fn ((args, a), mx) =>
-      Toplevel.theory (Sign.add_typedecls [(a, args, mx)])));
+      Toplevel.theory (Typedecl.add (a, args, mx) #> snd)));
 
 val _ =
   OuterSyntax.command "types" "declare type abbreviations" K.thy_decl
@@ -448,18 +448,18 @@
   OuterSyntax.command "instance" "prove type arity or subclass relation" K.thy_goal
   ((P.xname -- ((P.$$$ "\\<subseteq>" || P.$$$ "<") |-- P.!!! P.xname) >> Class.classrel_cmd ||
     P.and_list1 P.arity -- Scan.repeat (SpecParse.opt_thm_name ":" -- P.prop)
-      >> (fn (arities, defs) => Class.instance_cmd arities defs (fold Code.add_default_func (* FIXME ? *))))
+      >> (fn (arities, defs) => Class.instance_cmd arities defs (fold Code.add_default_func)))
     >> (Toplevel.print oo Toplevel.theory_to_proof));
 
 val _ =
   OuterSyntax.command "instantiation" "prove type arity" K.thy_decl
-   (P.and_list1 P.arity -- P.opt_begin
-     >> (fn (arities, begin) => (begin ? Toplevel.print) o
-         Toplevel.begin_local_theory begin (Instance.begin_instantiation_cmd arities)));
+   (P.and_list1 P.arity --| P.begin
+     >> (fn arities => Toplevel.print o
+         Toplevel.begin_local_theory true (Instance.instantiation_cmd arities)));
 
 val _ =  (* FIXME incorporate into "instance" *)
   OuterSyntax.command "instance_proof" "prove type arity relation" K.thy_goal
-    (Scan.succeed (Toplevel.print o Toplevel.local_theory_to_proof NONE Instance.proof_instantiation));
+    (Scan.succeed (Toplevel.print o Toplevel.local_theory_to_proof NONE (Class.proof_instantiation I)));
 
 
 (* code generation *)
--- a/src/Pure/Isar/theory_target.ML	Fri Nov 23 21:09:34 2007 +0100
+++ b/src/Pure/Isar/theory_target.ML	Fri Nov 23 21:09:35 2007 +0100
@@ -2,15 +2,17 @@
     ID:         $Id$
     Author:     Makarius
 
-Common theory/locale/class targets.
+Common theory/locale/class/instantiation targets.
 *)
 
 signature THEORY_TARGET =
 sig
-  val peek: local_theory -> {target: string, is_locale: bool, is_class: bool}
+  val peek: local_theory -> {target: string, is_locale: bool,
+    is_class: bool, instantiation: arity list}
   val init: string option -> theory -> local_theory
   val begin: string -> Proof.context -> local_theory
   val context: xstring -> theory -> local_theory
+  val instantiation: arity list -> theory -> local_theory
 end;
 
 structure TheoryTarget: THEORY_TARGET =
@@ -18,12 +20,14 @@
 
 (* context data *)
 
-datatype target = Target of {target: string, is_locale: bool, is_class: bool};
+datatype target = Target of {target: string, is_locale: bool,
+  is_class: bool, instantiation: arity list};
 
-fun make_target target is_locale is_class =
-  Target {target = target, is_locale = is_locale, is_class = is_class};
+fun make_target target is_locale is_class instantiation =
+  Target {target = target, is_locale = is_locale,
+    is_class = is_class, instantiation = instantiation};
 
-val global_target = make_target "" false false;
+val global_target = make_target "" false false [];
 
 structure Data = ProofDataFun
 (
@@ -36,7 +40,7 @@
 
 (* pretty *)
 
-fun pretty (Target {target, is_locale, is_class}) ctxt =
+fun pretty (Target {target, is_locale, is_class, instantiation}) ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
     val target_name = (if is_class then "class " else "locale ") ^ Locale.extern thy target;
@@ -186,13 +190,18 @@
              Morphism.form (ProofContext.target_notation true prmode [(lhs', mx)]))))
   end;
 
-fun declare_const (ta as Target {target, is_locale, is_class}) depends ((c, T), mx) lthy =
+fun declare_const (ta as Target {target, is_locale, is_class, instantiation}) depends ((c, T), mx) lthy =
   let
     val pos = ContextPosition.properties_of lthy;
     val xs = filter depends (#1 (ProofContext.inferred_fixes (LocalTheory.target_of lthy)));
     val U = map #2 xs ---> T;
     val (mx1, mx2, mx3) = fork_mixfix ta mx;
-    val (const, lthy') = lthy |> LocalTheory.theory_result (Sign.declare_const pos (c, U, mx3));
+    val declare_const = if null instantiation
+      then Sign.declare_const pos (c, U, mx3)
+      else case Class.instantiation_const lthy c
+       of SOME c' => Class.declare_overloaded (c', U, mx3)
+        | NONE => Sign.declare_const pos (c, U, mx3);
+    val (const, lthy') = lthy |> LocalTheory.theory_result declare_const;
     val t = Term.list_comb (const, map Free xs);
   in
     lthy'
@@ -204,7 +213,7 @@
 
 (* abbrev *)
 
-fun abbrev (ta as Target {target, is_locale, is_class}) prmode ((c, mx), t) lthy =
+fun abbrev (ta as Target {target, is_locale, is_class, instantiation}) prmode ((c, mx), t) lthy =
   let
     val pos = ContextPosition.properties_of lthy;
     val thy_ctxt = ProofContext.init (ProofContext.theory_of lthy);
@@ -236,7 +245,7 @@
 
 (* define *)
 
-fun define (ta as Target {target, is_locale, is_class})
+fun define (ta as Target {target, is_locale, is_class, instantiation})
     kind ((c, mx), ((name, atts), rhs)) lthy =
   let
     val thy = ProofContext.theory_of lthy;
@@ -253,12 +262,18 @@
     val (_, lhs') = Logic.dest_equals (Thm.prop_of local_def);
 
     (*def*)
+    val is_instantiation = not (null instantiation)
+      andalso is_some (Class.instantiation_const lthy c);
+    val define_const = if not is_instantiation
+      then (fn name => fn eq => Thm.add_def false (name, Logic.mk_equals eq))
+      else (fn name => fn (Const (c, _), rhs) => Class.define_overloaded name (c, rhs));
     val (global_def, lthy3) = lthy2
-      |> LocalTheory.theory_result (Thm.add_def false (name', Logic.mk_equals (lhs', rhs')));
-    val def = LocalDefs.trans_terms lthy3
+      |> LocalTheory.theory_result (define_const name' (lhs', rhs'));
+    val def = if not is_instantiation then LocalDefs.trans_terms lthy3
       [(*c == global.c xs*)     local_def,
        (*global.c xs == rhs'*)  global_def,
-       (*rhs' == rhs*)          Thm.symmetric rhs_conv];
+       (*rhs' == rhs*)          Thm.symmetric rhs_conv]
+      else Thm.transitive local_def global_def;
 
     (*note*)
     val ([(res_name, [res])], lthy4) = lthy3
@@ -298,14 +313,18 @@
 local
 
 fun init_target _ NONE = global_target
-  | init_target thy (SOME target) = make_target target true (Class.is_class thy target);
+  | init_target thy (SOME target) = make_target target true (Class.is_class thy target) [];
+
+fun init_instantiaton arities = make_target "" false false arities
 
-fun init_ctxt (Target {target, is_locale, is_class}) =
-  if not is_locale then ProofContext.init
-  else if not is_class then Locale.init target
-  else Class.init target;
+fun init_ctxt (Target {target, is_locale, is_class, instantiation}) =
+  if null instantiation then
+    if not is_locale then ProofContext.init
+    else if not is_class then Locale.init target
+    else Class.init target
+  else Class.instantiation instantiation;
 
-fun init_lthy (ta as Target {target, ...}) =
+fun init_lthy (ta as Target {target, instantiation, ...}) =
   Data.put ta #>
   LocalTheory.init (NameSpace.base target)
    {pretty = pretty ta,
@@ -317,7 +336,7 @@
     term_syntax = term_syntax ta,
     declaration = declaration ta,
     reinit = fn lthy => init_lthy_ctxt ta (ProofContext.theory_of lthy),
-    exit = LocalTheory.target_of}
+    exit = if null instantiation then LocalTheory.target_of else Class.end_instantiation}
 and init_lthy_ctxt ta = init_lthy ta o init_ctxt ta;
 
 in
@@ -328,6 +347,9 @@
 fun context "-" thy = init NONE thy
   | context target thy = init (SOME (Locale.intern thy target)) thy;
 
+fun instantiation raw_arities thy =
+  init_lthy_ctxt (init_instantiaton (map (Sign.cert_arity thy) raw_arities)) thy;
+
 end;
 
 end;