explicit type variables for instantiation
authorhaftmann
Tue, 08 Jan 2008 11:37:30 +0100
changeset 25864 11f531354852
parent 25863 5b4a8b1d0f88
child 25865 a141d6bfd398
explicit type variables for instantiation
src/HOL/Library/Eval.thy
src/HOL/Tools/datatype_codegen.ML
src/HOL/Tools/function_package/size.ML
src/HOL/Tools/typecopy_package.ML
src/Pure/Isar/class.ML
src/Pure/Isar/instance.ML
src/Pure/Isar/theory_target.ML
--- a/src/HOL/Library/Eval.thy	Tue Jan 08 11:37:29 2008 +0100
+++ b/src/HOL/Library/Eval.thy	Tue Jan 08 11:37:30 2008 +0100
@@ -99,10 +99,11 @@
   fun interpretator tyco thy =
     let
       val sorts = replicate (Sign.arity_number thy tyco) @{sort typ_of};
-      val ty = Type (tyco, map TFree (Name.names Name.context "'a" sorts));
+      val vs = Name.names Name.context "'a" sorts;
+      val ty = Type (tyco, map TFree vs);
     in
       thy
-      |> TheoryTarget.instantiation ([tyco], sorts, @{sort typ_of})
+      |> TheoryTarget.instantiation ([tyco], vs, @{sort typ_of})
       |> define_typ_of ty
       |> snd
       |> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
@@ -265,7 +266,7 @@
       val defs = map (mk_terms_of_defs vs) css;
     in if forall (fn tyco => can (Sign.arity_sorts thy tyco) @{sort term_of}) dep_tycos
         andalso not (tycos = [@{type_name typ}])
-      then SOME (sorts, defs)
+      then SOME (vs, defs)
       else NONE
     end;
   fun prep' ctxt proto_eqs =
@@ -279,9 +280,9 @@
       val (fixes, eqnss) = split_list (map (prep' ctxt) primrecs);
     in PrimrecPackage.add_primrec fixes (flat eqnss) ctxt end;
   fun interpretator tycos thy = case prep thy tycos
-   of SOME (sorts, defs) =>
+   of SOME (vs, defs) =>
       thy
-      |> TheoryTarget.instantiation (tycos, sorts, @{sort term_of})
+      |> TheoryTarget.instantiation (tycos, vs, @{sort term_of})
       |> primrec defs
       |> snd
       |> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
--- a/src/HOL/Tools/datatype_codegen.ML	Tue Jan 08 11:37:29 2008 +0100
+++ b/src/HOL/Tools/datatype_codegen.ML	Tue Jan 08 11:37:30 2008 +0100
@@ -431,11 +431,11 @@
       in
         Code.add_funcl (const, Susp.delay get_thms) thy
       end;
-    val sorts_eq =
-      map (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq] o snd) vs;
+    val vs' = (map o apsnd)
+      (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs;
   in
     thy
-    |> TheoryTarget.instantiation (dtcos, sorts_eq, [HOLogic.class_eq])
+    |> TheoryTarget.instantiation (dtcos, vs', [HOLogic.class_eq])
     |> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
     |> LocalTheory.exit
     |> ProofContext.theory_of
--- a/src/HOL/Tools/function_package/size.ML	Tue Jan 08 11:37:29 2008 +0100
+++ b/src/HOL/Tools/function_package/size.ML	Tue Jan 08 11:37:30 2008 +0100
@@ -60,23 +60,23 @@
 
 fun prove_size_thms (info : datatype_info) new_type_names thy =
   let
-    val {descr, alt_names, sorts, rec_names, rec_rewrites, induction, ...} = info;
+    val {descr, alt_names, sorts = raw_sorts, rec_names, rec_rewrites, induction, ...} = info;
 
     (*normalize type variable names to accomodate policy imposed by instantiation target*)
-    val tvars = (map dest_TFree o snd o dest_Type o hd) (get_rec_types descr sorts)
-      ~~ Name.invents Name.context Name.aT (length sorts);
-    val norm_tvars = map_atyps
-      (fn TFree (v, sort) => TFree (the (AList.lookup (op =) tvars v), sort));
+    val tvars = (map dest_TFree o snd o dest_Type o hd) (get_rec_types descr raw_sorts)
+      ~~ Name.invents Name.context Name.aT (length raw_sorts);
+    val sorts = tvars
+      |> map (fn (v, _) => (v, the (AList.lookup (op =) raw_sorts v)));
 
     val l = length new_type_names;
     val alt_names' = (case alt_names of
       NONE => replicate l NONE | SOME names => map SOME names);
     val descr' = List.take (descr, l);
     val (rec_names1, rec_names2) = chop l rec_names;
-    val recTs = map norm_tvars (get_rec_types descr sorts);
+    val recTs = get_rec_types descr sorts;
     val (recTs1, recTs2) = chop l recTs;
     val (_, (_, paramdts, _)) :: _ = descr;
-    val paramTs = map (norm_tvars o typ_of_dtyp descr sorts) paramdts;
+    val paramTs = map (typ_of_dtyp descr sorts) paramdts;
     val ((param_size_fs, param_size_fTs), f_names) = paramTs |>
       map (fn T as TFree (s, _) =>
         let
@@ -108,7 +108,7 @@
     (* instantiation for primrec combinator *)
     fun size_of_constr b size_ofp ((_, cargs), (_, cargs')) =
       let
-        val Ts = map (norm_tvars o typ_of_dtyp descr sorts) cargs;
+        val Ts = map (typ_of_dtyp descr sorts) cargs;
         val k = length (filter is_rec_type cargs);
         val (ts, _, _) = fold_rev (fn ((dt, dt'), T) => fn (us, i, j) =>
           if is_rec_type dt then (Bound i :: us, i + 1, j + 1)
@@ -139,18 +139,12 @@
     fun define_overloaded (def_name, eq) lthy =
       let
         val (Free (c, _), rhs) = (Logic.dest_equals o Syntax.check_term lthy) eq;
-        val (thm, lthy') = lthy
+        val ((_, (_, thm)), lthy') = lthy
           |> LocalTheory.define Thm.definitionK ((c, NoSyn), ((def_name, []), rhs));
         val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy');
-        val thm' = thm
-          |> Assumption.export false lthy' ctxt_thy o snd o snd
-          |> singleton (Variable.export lthy' ctxt_thy)
+        val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
       in (thm', lthy') end;
 
-    val size_sorts = tvars
-      |> map (fn (v, _) => Sorts.inter_sort (Sign.classes_of thy) (HOLogic.typeS,
-           the (AList.lookup (op =) sorts v)));
-
     val ((size_def_thms, size_def_thms'), thy') =
       thy
       |> Sign.add_consts_i (map (fn (s, T) =>
@@ -160,7 +154,7 @@
         (map (Thm.no_attributes o apsnd (Logic.mk_equals o apsnd (app fs)))
            (def_names ~~ (size_fns ~~ rec_combs1)))
       ||> TheoryTarget.instantiation
-           (map (#1 o snd) descr', size_sorts, [HOLogic.class_size])
+           (map (#1 o snd) descr', sorts, [HOLogic.class_size])
       ||>> fold_map define_overloaded
         (def_names' ~~ map Logic.mk_equals (overloaded_size_fns ~~ map (app fs') rec_combs1))
       ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
@@ -192,7 +186,7 @@
     (* characteristic equations for size functions *)
     fun gen_mk_size_eq p size_of size_ofp size_const T (cname, cargs) =
       let
-        val Ts = map (norm_tvars o typ_of_dtyp descr sorts) cargs;
+        val Ts = map (typ_of_dtyp descr sorts) cargs;
         val tnames = Name.variant_list f_names (DatatypeProp.make_tnames Ts);
         val ts = List.mapPartial (fn (sT as (s, T), dt) =>
           Option.map (fn sz => sz $ Free sT)
--- a/src/HOL/Tools/typecopy_package.ML	Tue Jan 08 11:37:29 2008 +0100
+++ b/src/HOL/Tools/typecopy_package.ML	Tue Jan 08 11:37:30 2008 +0100
@@ -124,14 +124,13 @@
 fun add_typecopy_spec tyco thy =
   let
     val SOME { constr, proj_def, inject, vs, ... } = get_info thy tyco;
-    val sorts_eq =
-      map (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq] o snd) vs;
+    val vs' = (map o apsnd) (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs;
     val ty = Logic.unvarifyT (Sign.the_const_type thy constr);
   in
     thy
     |> Code.add_datatype [(constr, ty)]
     |> Code.add_func proj_def
-    |> TheoryTarget.instantiation ([tyco], sorts_eq, [HOLogic.class_eq])
+    |> TheoryTarget.instantiation ([tyco], vs', [HOLogic.class_eq])
     |> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
     |> LocalTheory.exit
     |> ProofContext.theory_of
--- a/src/Pure/Isar/class.ML	Tue Jan 08 11:37:29 2008 +0100
+++ b/src/Pure/Isar/class.ML	Tue Jan 08 11:37:30 2008 +0100
@@ -30,7 +30,7 @@
   val print_classes: theory -> unit
 
   (*instances*)
-  val init_instantiation: string list * sort list * sort -> theory -> local_theory
+  val init_instantiation: string list * (string * sort) list * sort -> theory -> local_theory
   val instantiation_instance: (local_theory -> local_theory) -> local_theory -> Proof.state
   val prove_instantiation_instance: (Proof.context -> tactic) -> local_theory -> local_theory
   val conclude_instantiation: local_theory -> local_theory
@@ -681,7 +681,7 @@
 (* bookkeeping *)
 
 datatype instantiation = Instantiation of {
-  arities: string list * sort list * sort,
+  arities: string list * (string * sort) list * sort,
   params: ((string * string) * (string * typ)) list
     (*(instantiation parameter, type constructor), (local instantiation parameter, typ)*)
 }
@@ -768,25 +768,24 @@
     explode #> scan_valids #> implode
   end;
 
-fun init_instantiation (tycos, sorts, sort) thy =
+fun init_instantiation (tycos, vs, sort) thy =
   let
     val _ = if null tycos then error "At least one arity must be given" else ();
     val _ = map (the_class_data thy) sort;
-    val vs = map TFree (Name.names Name.context Name.aT sorts);
     fun type_name "*" = "prod"
       | type_name "+" = "sum"
       | type_name s = sanatize_name (NameSpace.base s); (*FIXME*)
     fun get_param tyco (param, (c, ty)) = if can (AxClass.param_of_inst thy) (c, tyco)
       then NONE else SOME ((c, tyco),
-        (param ^ "_" ^ type_name tyco, map_atyps (K (Type (tyco, vs))) ty));
+        (param ^ "_" ^ type_name tyco, map_atyps (K (Type (tyco, map TFree vs))) ty));
     val params = map_product get_param tycos (these_params thy sort) |> map_filter I;
   in
     thy
     |> ProofContext.init
-    |> Instantiation.put (mk_instantiation ((tycos, sorts, sort), params))
-    |> fold (Variable.declare_term o Logic.mk_type) vs
+    |> Instantiation.put (mk_instantiation ((tycos, vs, sort), params))
+    |> fold (Variable.declare_term o Logic.mk_type o TFree) vs
     |> fold (Variable.declare_names o Free o snd) params
-    |> fold (fn tyco => ProofContext.add_arity (tyco, sorts, sort)) tycos
+    |> fold (fn tyco => ProofContext.add_arity (tyco, map snd vs, sort)) tycos
     |> Context.proof_map (
         Syntax.add_term_check 0 "instance" inst_term_check
         #> Syntax.add_term_uncheck 0 "instance" inst_term_uncheck)
@@ -794,8 +793,8 @@
 
 fun gen_instantiation_instance do_proof after_qed lthy =
   let
-    val (tycos, sorts, sort) = (#arities o the_instantiation) lthy;
-    val arities_proof = maps (fn tyco => Logic.mk_arities (tyco, sorts, sort)) tycos;
+    val (tycos, vs, sort) = (#arities o the_instantiation) lthy;
+    val arities_proof = maps (fn tyco => Logic.mk_arities (tyco, map snd vs, sort)) tycos;
     fun after_qed' results =
       LocalTheory.theory (fold (AxClass.add_arity o Thm.varifyT) results)
       #> after_qed;
@@ -814,10 +813,10 @@
 fun conclude_instantiation lthy =
   let
     val { arities, params } = the_instantiation lthy;
-    val (tycos, sorts, sort) = arities;
+    val (tycos, vs, sort) = arities;
     val thy = ProofContext.theory_of lthy;
     val _ = map (fn tyco => if Sign.of_sort thy
-        (Type (tyco, map TFree (Name.names Name.context Name.aT sorts)), sort)
+        (Type (tyco, map TFree vs), sort)
       then () else error ("Missing instance proof for type " ^ quote (Sign.extern_type thy tyco)))
         tycos;
     (*this checkpoint should move to AxClass as soon as "attach" has disappeared*)
@@ -830,12 +829,12 @@
 fun pretty_instantiation lthy =
   let
     val { arities, params } = the_instantiation lthy;
-    val (tycos, sorts, sort) = arities;
+    val (tycos, vs, sort) = arities;
     val thy = ProofContext.theory_of lthy;
-    fun pr_arity tyco = Syntax.pretty_arity lthy (tyco, sorts, sort);
+    fun pr_arity tyco = Syntax.pretty_arity lthy (tyco, map snd vs, sort);
     fun pr_param ((c, _), (v, ty)) =
-      (Pretty.block o Pretty.breaks) [(Pretty.str o Sign.extern_const thy) c, Pretty.str "::",
-        Sign.pretty_typ thy ty, Pretty.str "as", Pretty.str v];
+      (Pretty.block o Pretty.breaks) [Pretty.str v, Pretty.str "==",
+        (Pretty.str o Sign.extern_const thy) c, Pretty.str "::", Sign.pretty_typ thy ty];
   in
     (Pretty.block o Pretty.fbreaks)
       (Pretty.str "instantiation" :: map pr_arity tycos @ map pr_param params)
--- a/src/Pure/Isar/instance.ML	Tue Jan 08 11:37:29 2008 +0100
+++ b/src/Pure/Isar/instance.ML	Tue Jan 08 11:37:30 2008 +0100
@@ -15,32 +15,31 @@
 structure Instance : INSTANCE =
 struct
 
+fun read_single_arity thy (raw_tyco, raw_sorts, raw_sort) =
+  let
+    val (tyco, sorts, sort) = Sign.read_arity thy (raw_tyco, raw_sorts, raw_sort);
+    val vs = Name.names Name.context Name.aT sorts;
+  in (tyco, vs, sort) end;
+
 fun read_multi_arity thy (raw_tycos, raw_sorts, raw_sort) =
   let
     val all_arities = map (fn raw_tyco => Sign.read_arity thy
       (raw_tyco, raw_sorts, raw_sort)) raw_tycos;
     val tycos = map #1 all_arities;
     val (_, sorts, sort) = hd all_arities;
-  in (tycos, sorts, sort) end;
+    val vs = Name.names Name.context Name.aT sorts;
+  in (tycos, vs, sort) end;
 
 fun instantiation_cmd raw_arities thy =
   TheoryTarget.instantiation (read_multi_arity thy raw_arities) thy;
 
-fun gen_instance prep_arity prep_attr parse_term do_proof do_proof' raw_arities defs thy =
+fun instance_cmd raw_arities defs thy =
   let
-    val (tyco, sorts, sort) = prep_arity thy raw_arities;
-    fun export_defs ctxt = 
-      let
-        val ctxt_thy = ProofContext.init (ProofContext.theory_of ctxt);
-      in
-        map (snd o snd)
-        #> map (Assumption.export false ctxt ctxt_thy)
-        #> Variable.export ctxt ctxt_thy
-      end;
+    val (tyco, vs, sort) = read_single_arity thy raw_arities;
     fun mk_def ctxt ((name, raw_attr), raw_t) =
       let
-        val attr = map (prep_attr thy) raw_attr;
-        val t = parse_term ctxt raw_t;
+        val attr = map (Attrib.intern_src thy) raw_attr;
+        val t = Syntax.parse_prop ctxt raw_t;
       in (NONE, ((name, attr), t)) end;
     fun define def ctxt =
       let
@@ -51,18 +50,15 @@
   in if not (null defs) orelse forall (Class.is_class thy) sort
   then
     thy
-    |> TheoryTarget.instantiation ([tyco], sorts, sort)
+    |> TheoryTarget.instantiation ([tyco], vs, sort)
     |> `(fn ctxt => map (mk_def ctxt) defs)
     |-> (fn defs => fold_map Specification.definition defs)
-    |-> (fn defs => `(fn ctxt => export_defs ctxt defs))
-    ||> LocalTheory.reinit
-    |-> (fn defs => do_proof defs)
+    |> snd
+    |> LocalTheory.reinit
+    |> Class.instantiation_instance Class.conclude_instantiation
   else
     thy
-    |> do_proof' (tyco, sorts, sort)
+    |> Class.instance_arity I (tyco, map snd vs, sort)
   end;
 
-val instance_cmd = gen_instance Sign.read_arity Attrib.intern_src Syntax.parse_prop
-  (fn _ => Class.instantiation_instance Class.conclude_instantiation) (Class.instance_arity I);
-
 end;
--- a/src/Pure/Isar/theory_target.ML	Tue Jan 08 11:37:29 2008 +0100
+++ b/src/Pure/Isar/theory_target.ML	Tue Jan 08 11:37:30 2008 +0100
@@ -8,14 +8,14 @@
 signature THEORY_TARGET =
 sig
   val peek: local_theory -> {target: string, is_locale: bool,
-    is_class: bool, instantiation: string list * sort list * sort,
-    overloading: ((string * typ) * (string * bool)) list}
+    is_class: bool, instantiation: string list * (string * sort) list * sort,
+    overloading: (string * (string * typ) * bool) list}
   val init: string option -> theory -> local_theory
   val begin: string -> Proof.context -> local_theory
   val context: xstring -> theory -> local_theory
-  val instantiation: string list * sort list * sort -> theory -> local_theory
-  val overloading: ((string * typ) * (string * bool)) list -> theory -> local_theory
-  val overloading_cmd: (((xstring * xstring) * string) * bool) list -> theory -> local_theory
+  val instantiation: string list * (string * sort) list * sort -> theory -> local_theory
+  val overloading: (string * (string * typ) * bool) list -> theory -> local_theory
+  val overloading_cmd: ((xstring * xstring) * bool) list -> theory -> local_theory
 end;
 
 structure TheoryTarget: THEORY_TARGET =
@@ -24,8 +24,8 @@
 (* context data *)
 
 datatype target = Target of {target: string, is_locale: bool,
-  is_class: bool, instantiation: string list * sort list * sort,
-  overloading: ((string * typ) * (string * bool)) list};
+  is_class: bool, instantiation: string list * (string * sort) list * sort,
+  overloading: (string * (string * typ) * bool) list};
 
 fun make_target target is_locale is_class instantiation overloading =
   Target {target = target, is_locale = is_locale,
@@ -366,7 +366,8 @@
 in
 
 fun init target thy = init_lthy_ctxt (init_target thy target) thy;
-fun begin target ctxt = init_lthy (init_target (ProofContext.theory_of ctxt) (SOME target)) ctxt;
+fun begin target ctxt = init_lthy (init_target (ProofContext.theory_of ctxt)
+  (SOME target)) ctxt;
 
 fun context "-" thy = init NONE thy
   | context target thy = init (SOME (Locale.intern thy target)) thy;
@@ -375,18 +376,18 @@
 
 fun gen_overloading prep_operation raw_operations thy =
   let
-    val check_const = dest_Const o Syntax.check_term (ProofContext.init thy) o Const;
+    val check_const = dest_Const o Syntax.check_term (ProofContext.init thy);
     val operations = raw_operations
       |> map (prep_operation thy)
-      |> (map o apfst) check_const;
+      |> map (fn (v, cTt, checked) => (v, check_const cTt, checked));
   in
     thy
     |> init_lthy_ctxt (init_overloading operations)
   end;
 
-val overloading = gen_overloading (K I);
-val overloading_cmd = gen_overloading (fn thy => fn (((raw_c, rawT), v), checked) =>
-  ((Sign.intern_const thy raw_c, Sign.read_typ thy rawT), (v, checked)));
+val overloading = gen_overloading (fn _ => fn (v, cT, checked) => (v, Const cT, checked));
+val overloading_cmd = gen_overloading (fn thy => fn ((v, raw_cT), checked) =>
+  (v, Sign.read_term thy raw_cT, checked));
 
 end;