instantiation less liberal with dangling constraints
authorhaftmann
Wed, 19 Mar 2008 07:20:33 +0100
changeset 26329 3e58e4c67a2a
parent 26328 b2d6f520172c
child 26330 e493bdd1cff2
instantiation less liberal with dangling constraints
src/Pure/Isar/class.ML
--- a/src/Pure/Isar/class.ML	Wed Mar 19 07:20:32 2008 +0100
+++ b/src/Pure/Isar/class.ML	Wed Mar 19 07:20:33 2008 +0100
@@ -666,17 +666,6 @@
     val thy = ProofContext.theory_of ctxt;
     fun subst_class_typ sort = map_atyps
       (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
-    val operations = these_operations thy sort;
-    val global_constraints = (*map_filter (fn (c, (class, (ty, _))) =>
-      if exists (fn ((c', _), _) => c = c') params
-        then SOME (c, subst_class_typ [class] ty)
-        else NONE) operations;*)[];
-          (*| NONE => (case map_filter
-               (fn ((c', _), (_, ty')) => if c' = c then SOME ty' else NONE) params
-             of [ty'] => (case Sign.const_typargs thy (c, ty)
-                 of [TVar (vi, _)] => if TypeInfer.is_param vi then SOME (ty, ty') else NONE
-                  | _ => NONE)
-              | _ => NONE*);
     fun subst (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
          of SOME tyco => (case AList.lookup (op =) params (c, tyco)
              of SOME (v_ty as (_, ty)) => SOME (ty, Free v_ty)
@@ -688,7 +677,7 @@
     ctxt
     |> Overloading.map_improvable_syntax
          (fn (((local_constraints, _), ((improve, _), _)), _) =>
-            (((local_constraints, global_constraints), ((improve, subst), unchecks)), false))
+            (((local_constraints, []), ((improve, subst), unchecks)), false))
   end;
 
 
@@ -721,23 +710,43 @@
     fun get_param tyco (param, (c, ty)) = if can (AxClass.param_of_inst thy) (c, tyco)
       then NONE else SOME ((c, tyco),
         (param ^ "_" ^ type_name tyco, map_atyps (K (Type (tyco, map TFree vs))) ty));
-    val params = map_product get_param tycos (these_params thy sort) |> map_filter I;
-    val operations = these_operations thy sort;
-    val local_constraints = (map o apsnd) (subst_class_typ [] o fst o snd) operations;
-    fun improve (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
-     of SOME tyco => (case AList.lookup (op =) params (c, tyco)
+    val class_of = fst o the o AList.lookup (op =) (these_operations thy sort);
+    val params = these_params thy sort;
+    val inst_params = map_product get_param tycos (these_params thy sort) |> map_filter I;
+    val local_constraints = map (apsnd (subst_class_typ []) o snd) params;
+    val pseudo_constraints = map (fn (_, (c, _)) => (c, class_of c)) params;
+    val typ_of_sort = Type.typ_of_sort (Sign.classes_of thy);
+    val typarg = the_single o Sign.const_typargs thy;
+    val pp = Sign.pp thy;
+    fun constrain_class (c, ty) class =
+      let
+        val ty' = typarg (c, ty);
+        val tyenv = typ_of_sort ty' [class] Vartab.empty
+          handle Sorts.CLASS_ERROR e => Sorts.class_error pp e;
+      in
+        map_atyps (fn TVar (vi, _) => TVar (vi, the (Vartab.lookup tyenv vi))
+          | ty => ty) ty
+      end;
+    fun improve_param (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
+     of SOME tyco => (case AList.lookup (op =) inst_params (c, tyco)
          of SOME (_, ty') => SOME (ty, ty')
           | NONE => NONE)
       | NONE => NONE;
+    fun improve (c, ty) = case improve_param (c, ty)
+     of SOME ty_ty' => SOME ty_ty'
+      | NONE => (case AList.lookup (op =) pseudo_constraints c
+         of SOME class =>
+              SOME (ty, constrain_class (c, ty) class)
+          | NONE => NONE);
   in
     thy
     |> ProofContext.init
-    |> Instantiation.put (mk_instantiation ((tycos, vs, sort), params))
+    |> Instantiation.put (mk_instantiation ((tycos, vs, sort), inst_params))
     |> fold (Variable.declare_term o Logic.mk_type o TFree) vs
-    |> fold (Variable.declare_names o Free o snd) params
+    |> fold (Variable.declare_names o Free o snd) inst_params
     |> (Overloading.map_improvable_syntax o apfst)
-         (fn ((_, global_constraints), ((_, subst), unchecks)) =>
-            ((local_constraints, global_constraints), ((improve, subst), unchecks)))
+         (fn ((_, _), ((_, subst), unchecks)) =>
+            ((local_constraints, []), ((improve, K NONE), [])))
     |> Overloading.add_improvable_syntax
     |> synchronize_inst_syntax
   end;