improved improvements for instantiaton
authorhaftmann
Wed, 02 Apr 2008 15:58:41 +0200
changeset 26518 3db6a46d8460
parent 26517 ef036a63f6e9
child 26519 6cd53b7ef55c
improved improvements for instantiaton
src/Pure/Isar/class.ML
--- a/src/Pure/Isar/class.ML	Wed Apr 02 15:58:40 2008 +0200
+++ b/src/Pure/Isar/class.ML	Wed Apr 02 15:58:41 2008 +0200
@@ -9,9 +9,9 @@
 sig
   (*classes*)
   val class: bstring -> class list -> Element.context_i list
-    -> string list -> theory -> string * Proof.context
+    -> theory -> string * Proof.context
   val class_cmd: bstring -> xstring list -> Element.context list
-    -> xstring list -> theory -> string * Proof.context
+    -> theory -> string * Proof.context
 
   val init: class -> theory -> Proof.context
   val declare: string -> Markup.property list
@@ -26,7 +26,7 @@
 
   val class_prefix: string -> string
   val is_class: theory -> class -> bool
-  val these_params: theory -> sort -> (string * (string * typ)) list
+  val these_params: theory -> sort -> (string * (class * (string * typ))) list
   val print_classes: theory -> unit
 
   (*instances*)
@@ -184,7 +184,8 @@
         val const_typs = (#params o AxClass.get_info thy) class;
         val const_names = (#consts o the_class_data thy) class;
       in
-        (map o apsnd) (fn c => (c, (the o AList.lookup (op =) const_typs) c)) const_names
+        (map o apsnd)
+          (fn c => (class, (c, (the o AList.lookup (op =) const_typs) c))) const_names
       end;
   in maps params o ancestry thy end;
 
@@ -252,8 +253,8 @@
 fun register_operation class (c, (t, some_def)) thy =
   let
     val base_sort = (#base_sort o the_class_data thy) class;
-    val prep_typ = map_atyps
-      (fn TVar (vi as (v, _), sort) => if Name.aT = v
+    val prep_typ = map_type_tvar
+      (fn (vi as (v, _), sort) => if Name.aT = v
         then TFree (v, base_sort) else TVar (vi, sort));
     val t' = map_types prep_typ t;
     val ty' = Term.fastype_of t';
@@ -273,9 +274,8 @@
 fun calculate sups base_sort assm_axiom param_map class thy =
   let
     (*static parts of morphism*)
-    val subst_typ = map_atyps (fn TFree (v, sort) =>
-          if v = Name.aT then TVar ((v, 0), [class]) else TVar ((v, 0), sort)
-      | ty => ty);
+    val subst_typ = map_type_tfree (fn (v, sort) =>
+          if v = Name.aT then TVar ((v, 0), [class]) else TVar ((v, 0), sort));
     fun subst_aterm (t as Free (v, ty)) = (case AList.lookup (op =) param_map v
          of SOME (c, _) => Const (c, ty)
           | NONE => t)
@@ -347,10 +347,9 @@
 
 fun class_interpretation class facts defs thy =
   let
-    val params = these_params thy [class];
-    val consts = map (fst o snd) params;
-    val constraints = map (fn c => map_atyps (K (TFree (Name.aT,
-      [the (AxClass.class_of_param thy c)]))) (Sign.the_const_type thy c)) consts;
+    val consts = map (apsnd fst o snd) (these_params thy [class]);
+    val constraints = map (fn (class, c) => map_atyps (K (TFree (Name.aT,
+      [class]))) (Sign.the_const_type thy c)) consts;
     val no_constraints = map (map_atyps (K (TFree (Name.aT, [])))) constraints;
     fun add_constraint c T = Sign.add_const_constraint (c, SOME T);
     val inst = (#inst o the_class_data thy) class;
@@ -358,10 +357,10 @@
     val prfx = class_prefix class;
   in
     thy
-    |> fold2 add_constraint consts no_constraints
+    |> fold2 add_constraint (map snd consts) no_constraints
     |> prove_interpretation tac ((false, prfx), []) (Locale.Locale class)
           (inst, map (fn def => (("", []), def)) defs)
-    |> fold2 add_constraint consts constraints
+    |> fold2 add_constraint (map snd consts) constraints
   end;
 
 fun prove_subclass (sub, sup) thm thy =
@@ -412,17 +411,16 @@
 fun synchronize_class_syntax sups base_sort ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
-    fun subst_class_typ sort = map_atyps
-      (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
     val operations = these_operations thy sups;
-    val local_constraints =
+    fun subst_class_typ sort = map_type_tfree (K (TVar ((Name.aT, 0), sort)));
+    val primary_constraints =
       (map o apsnd) (subst_class_typ base_sort o fst o snd) operations;
-    val global_constraints =
+    val secondary_constraints =
       (map o apsnd) (fn (class, (ty, _)) => subst_class_typ [class] ty) operations;
     fun declare_const (c, _) =
       let val b = Sign.base_name c
       in Sign.intern_const thy b = c ? Variable.declare_const (b, c) end;
-    fun improve (c, ty) = (case AList.lookup (op =) local_constraints c
+    fun improve (c, ty) = (case AList.lookup (op =) primary_constraints c
      of SOME ty' => (case try (Type.raw_match (ty', ty)) Vartab.empty
          of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0)
              of SOME (_, ty' as TVar (tvar as (vi, _))) =>
@@ -436,10 +434,10 @@
     val unchecks = map (fn (c, (_, (ty, t))) => (t, Const (c, ty))) operations;
   in
     ctxt
-    |> fold declare_const local_constraints
-    |> Overloading.map_improvable_syntax (K (((local_constraints, global_constraints),
+    |> fold declare_const primary_constraints
+    |> Overloading.map_improvable_syntax (K (((primary_constraints, secondary_constraints),
         ((improve, subst), unchecks)), false))
-    |> Overloading.set_local_constraints
+    |> Overloading.set_primary_constraints
   end;
 
 fun refresh_syntax class ctxt =
@@ -500,10 +498,10 @@
 val read_class_spec = gen_class_spec Sign.intern_class Locale.read_expr;
 val check_class_spec = gen_class_spec (K I) Locale.cert_expr;
 
-fun adjungate_axclass bname class base_sort sups supsort supparams global_syntax other_consts thy =
+fun adjungate_axclass bname class base_sort sups supsort supparams global_syntax thy =
   let
     val supconsts = map fst supparams
-      |> AList.make (the o AList.lookup (op =) (these_params thy sups))
+      |> AList.make (snd o the o AList.lookup (op =) (these_params thy sups))
       |> (map o apsnd o apsnd o map_atyps o K o TFree) (Name.aT, [class]);
     val all_params = map fst (Locale.parameters_of thy class);
     fun add_const (v, raw_ty) thy =
@@ -538,7 +536,7 @@
     thy
     |> add_consts ((snd o chop (length supparams)) all_params)
     |-> (fn (param_map, params) => AxClass.define_class (bname, supsort)
-          (map (fst o snd) params @ other_consts)
+          (map (fst o snd) params)
           [((bname ^ "_" ^ AxClass.axiomsN, []), map (globalize param_map) raw_pred)]
     #> snd
     #> `get_axiom
@@ -546,19 +544,17 @@
     #> pair (param_map, params, assm_axiom)))
   end;
 
-fun gen_class prep_spec prep_param bname
-    raw_supclasses raw_elems raw_other_consts thy =
+fun gen_class prep_spec bname raw_supclasses raw_elems thy =
   let
     val class = Sign.full_name thy bname;
     val (((sups, supparams), (supsort, base_sort, mergeexpr)), (elems, global_syntax)) =
       prep_spec thy raw_supclasses raw_elems;
-    val other_consts = map (tap (Sign.the_const_type thy) o prep_param thy) raw_other_consts;
   in
     thy
     |> Locale.add_locale_i (SOME "") bname mergeexpr elems
     |> snd
     |> ProofContext.theory_of
-    |> adjungate_axclass bname class base_sort sups supsort supparams global_syntax other_consts
+    |> adjungate_axclass bname class base_sort sups supsort supparams global_syntax
     |-> (fn (param_map, params, assm_axiom) =>
         calculate sups base_sort assm_axiom param_map class
     #-> (fn (morphism, axiom, assm_intro, of_class) =>
@@ -569,12 +565,10 @@
     |> pair class
   end;
 
-fun read_const thy = #1 o Term.dest_Const o ProofContext.read_const (ProofContext.init thy);
-
 in
 
-val class_cmd = gen_class read_class_spec read_const;
-val class = gen_class check_class_spec (K I);
+val class_cmd = gen_class read_class_spec;
+val class = gen_class check_class_spec;
 
 end; (*local*)
 
@@ -667,8 +661,6 @@
   let
     val Instantiation { arities = (_, _, sort), params = params } = Instantiation.get ctxt;
     val thy = ProofContext.theory_of ctxt;
-    fun subst_class_typ sort = map_atyps
-      (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
     fun subst (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
          of SOME tyco => (case AList.lookup (op =) params (c, tyco)
              of SOME (v_ty as (_, ty)) => SOME (ty, Free v_ty)
@@ -679,8 +671,8 @@
   in
     ctxt
     |> Overloading.map_improvable_syntax
-         (fn (((local_constraints, _), ((improve, _), _)), _) =>
-            (((local_constraints, []), ((improve, subst), unchecks)), false))
+         (fn (((primary_constraints, _), ((improve, _), _)), _) =>
+            (((primary_constraints, []), ((improve, subst), unchecks)), false))
   end;
 
 
@@ -705,43 +697,43 @@
   | type_name "+" = "sum"
   | type_name s = sanatize_name (NameSpace.base s); (*FIXME*)
 
+fun resort_terms pp algebra consts constraints ts =
+  let
+    fun matchings (Const (c_ty as (c, _))) = (case constraints c
+         of NONE => I
+          | SOME sorts => fold2 (curry (Sorts.meet_sort algebra))
+              (Consts.typargs consts c_ty) sorts)
+      | matchings _ = I
+    val tvartab = (fold o fold_aterms) matchings ts Vartab.empty
+      handle Sorts.CLASS_ERROR e => Sorts.class_error pp e;
+    val inst = map_type_tvar (fn (vi, _) => TVar (vi, the (Vartab.lookup tvartab vi)));
+  in if Vartab.is_empty tvartab then NONE else SOME ((map o map_types) inst ts) end;
+
 fun init_instantiation (tycos, vs, sort) thy =
   let
     val _ = if null tycos then error "At least one arity must be given" else ();
-    fun subst_class_typ sort = map_atyps
-      (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
-    fun get_param tyco (param, (c, ty)) = if can (AxClass.param_of_inst thy) (c, tyco)
+    val params = these_params thy sort;
+    fun get_param tyco (param, (_, (c, ty))) = if can (AxClass.param_of_inst thy) (c, tyco)
       then NONE else SOME ((c, tyco),
         (param ^ "_" ^ type_name tyco, map_atyps (K (Type (tyco, map TFree vs))) ty));
-    val class_of = fst o the o AList.lookup (op =) (these_operations thy sort);
-    val params = these_params thy sort;
-    val inst_params = map_product get_param tycos (these_params thy sort) |> map_filter I;
-    val local_constraints = map (apsnd (subst_class_typ []) o snd) params;
-    val pseudo_constraints = map (fn (_, (c, _)) => (c, class_of c)) params;
-    val typ_of_sort = Type.typ_of_sort (Sign.classes_of thy);
-    val typargs = Sign.const_typargs thy;
+    val inst_params = map_product get_param tycos params |> map_filter I;
+    val primary_constraints = map (apsnd
+      (map_atyps (K (TVar ((Name.aT, 0), [])))) o snd o snd) params;
     val pp = Sign.pp thy;
-    fun constrain_typ tys sorts ty =
-      let
-        val tyenv = fold2 typ_of_sort tys sorts Vartab.empty
-          handle Sorts.CLASS_ERROR e => Sorts.class_error pp e;
-      in
-        map_atyps (fn TVar (vi, _) => TVar (vi, the (Vartab.lookup tyenv vi))
-          | ty => ty) ty
-      end;
-    fun constrain_class (c, ty) class =
-      constrain_typ (typargs (c, ty)) [[class]] ty;
-    fun improve_param (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
+    val algebra = Sign.classes_of thy
+      |> fold (fn tyco => Sorts.add_arities pp
+            (tyco, map (fn class => (class, map snd vs)) sort)) tycos;
+    val consts = Sign.consts_of thy;
+    val improve_constraints = AList.lookup (op =)
+      (map (fn (_, (class, (c, _))) => (c, [[class]])) params);
+    fun resort_check ts ctxt = case resort_terms pp algebra consts improve_constraints ts
+     of NONE => NONE
+      | SOME ts' => SOME (ts', ctxt);
+    fun improve (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
      of SOME tyco => (case AList.lookup (op =) inst_params (c, tyco)
-         of SOME (_, ty') => SOME (ty, ty')
+         of SOME (_, ty') => if Type.raw_instance (ty', ty) then SOME (ty, ty') else NONE
           | NONE => NONE)
       | NONE => NONE;
-    fun improve (c, ty) = case improve_param (c, ty)
-     of SOME ty_ty' => SOME ty_ty'
-      | NONE => (case AList.lookup (op =) pseudo_constraints c
-         of SOME class =>
-              SOME (ty, constrain_class (c, ty) class)
-          | NONE => NONE);
   in
     thy
     |> ProofContext.init
@@ -750,8 +742,9 @@
     |> fold (Variable.declare_names o Free o snd) inst_params
     |> (Overloading.map_improvable_syntax o apfst)
          (fn ((_, _), ((_, subst), unchecks)) =>
-            ((local_constraints, []), ((improve, K NONE), [])))
+            ((primary_constraints, []), ((improve, K NONE), [])))
     |> Overloading.add_improvable_syntax
+    |> Context.proof_map (Syntax.add_term_check 0 "resorting" resort_check)
     |> synchronize_inst_syntax
   end;
 
@@ -787,11 +780,6 @@
         (Type (tyco, map TFree vs), sort)
       then () else error ("Missing instance proof for type " ^ quote (Sign.extern_type thy tyco)))
         tycos;
-    (*this checkpoint should move to AxClass as soon as "attach" has disappeared*)
-    val _ = case map (fst o snd) params
-     of [] => ()
-      | cs => Output.legacy_feature
-          ("Missing specifications for overloaded parameters " ^ commas_quote cs)
   in lthy end;
 
 fun pretty_instantiation lthy =