proper implementation of check phase; non-qualified names for class operations
authorhaftmann
Fri, 09 Nov 2007 23:24:31 +0100
changeset 25368 f12613fda79d
parent 25367 98b6b7f64e49
child 25369 5200374fda5d
proper implementation of check phase; non-qualified names for class operations
src/Pure/Isar/class.ML
--- a/src/Pure/Isar/class.ML	Fri Nov 09 23:24:30 2007 +0100
+++ b/src/Pure/Isar/class.ML	Fri Nov 09 23:24:31 2007 +0100
@@ -325,31 +325,27 @@
     (*partial morphism of canonical interpretation*)
   intro: thm,
   defs: thm list,
-  operations: (string * (term * (typ * int))) list,
-    (*constant name ~> (locale term,
-        (constant constraint, instantiaton index of class typ))*)
-  unchecks: (term * term) list
+  operations: (string * (class * (typ * term))) list
 };
 
 fun rep_class_data (ClassData d) = d;
 fun mk_class_data ((consts, base_sort, inst, morphism, intro),
-    (defs, (operations, unchecks))) =
+    (defs, operations)) =
   ClassData { consts = consts, base_sort = base_sort, inst = inst,
     morphism = morphism, intro = intro, defs = defs,
-    operations = operations, unchecks = unchecks };
+    operations = operations };
 fun map_class_data f (ClassData { consts, base_sort, inst, morphism, intro,
-    defs, operations, unchecks }) =
+    defs, operations }) =
   mk_class_data (f ((consts, base_sort, inst, morphism, intro),
-    (defs, (operations, unchecks))));
+    (defs, operations)));
 fun merge_class_data _ (ClassData { consts = consts,
     base_sort = base_sort, inst = inst, morphism = morphism, intro = intro,
-    defs = defs1, operations = operations1, unchecks = unchecks1 },
+    defs = defs1, operations = operations1 },
   ClassData { consts = _, base_sort = _, inst = _, morphism = _, intro = _,
-    defs = defs2, operations = operations2, unchecks = unchecks2 }) =
+    defs = defs2, operations = operations2 }) =
   mk_class_data ((consts, base_sort, inst, morphism, intro),
     (Thm.merge_thms (defs1, defs2),
-      (AList.merge (op =) (K true) (operations1, operations2),
-        Library.merge (op aconv o pairself snd) (unchecks1, unchecks2))));
+      AList.merge (op =) (K true) (operations1, operations2)));
 
 structure ClassData = TheoryDataFun
 (
@@ -395,9 +391,6 @@
 fun these_operations thy =
   maps (#operations o the_class_data thy) o ancestry thy;
 
-fun these_unchecks thy =
-  maps (#unchecks o the_class_data thy) o ancestry thy;
-
 fun print_classes thy =
   let
     val ctxt = ProofContext.init thy;
@@ -438,33 +431,30 @@
 
 fun add_class_data ((class, superclasses), (cs, base_sort, inst, phi, intro)) thy =
   let
-    val operations = map (fn (v_ty, (c, ty)) =>
-      (c, ((Free v_ty, (Logic.varifyT ty, 0))))) cs;
-    val unchecks = map (fn ((v, ty'), (c, _)) =>
-      (Free (v, Type.strip_sorts ty'), Const (c, Type.strip_sorts ty'))) cs;
+    val operations = map (fn (v_ty as (_, ty), (c, _)) =>
+      (c, (class, (ty, Free v_ty)))) cs;
     val cs = (map o pairself) fst cs;
     val add_class = Graph.new_node (class,
-        mk_class_data ((cs, base_sort, map (SOME o Const) inst, phi, intro), ([], (operations, unchecks))))
+        mk_class_data ((cs, base_sort, map (SOME o Const) inst, phi, intro), ([], operations)))
       #> fold (curry Graph.add_edge class) superclasses;
   in
     ClassData.map add_class thy
   end;
 
-fun register_operation class (c, ((t, t_rev), some_def)) thy =
+fun register_operation class (c, (t, some_def)) thy =
   let
-    val ty = Sign.the_const_constraint thy c;
-    val typargs = Sign.const_typargs thy (c, ty);
-    val typidx = find_index (fn TVar ((v, _), _) => Name.aT = v | _ => false) typargs;
+    val base_sort = (#base_sort o the_class_data thy) class;
     val prep_typ = map_atyps
-      (fn TVar (vi as (v, _), _) => if Name.aT = v then TFree (v, []) else TVar (vi, []))
-    val t_rev' = map_types prep_typ t_rev;
-    val ty' = Term.fastype_of t_rev';
+      (fn TVar (vi as (v, _), sort) => if Name.aT = v
+        then TFree (v, base_sort) else TVar (vi, sort));
+    val t' = map_types prep_typ t;
+    val ty' = Term.fastype_of t';
   in
     thy
     |> (ClassData.map o Graph.map_node class o map_class_data o apsnd)
-      (fn (defs, (operations, unchecks)) =>
+      (fn (defs, operations) =>
         (fold cons (the_list some_def) defs,
-          ((c, (t, (ty, typidx))) :: operations, (t_rev', Const (c, ty')) :: unchecks)))
+          (c, (class, (ty', t'))) :: operations))
   end;
 
 
@@ -611,57 +601,49 @@
 
 structure ClassSyntax = ProofDataFun(
   type T = {
-    constraints: (string * typ) list,
+    local_constraints: (string * typ) list,
+    global_constraints: (string * typ) list,
     base_sort: sort,
-    local_operation: string * typ -> (typ * term) option,
+    operations: (string * (typ * term)) list,
     unchecks: (term * term) list,
     passed: bool
-  } option;
-  fun init _ = NONE;
+  };
+  fun init _ = {
+    local_constraints = [],
+    global_constraints = [],
+    base_sort = [],
+    operations = [],
+    unchecks = [],
+    passed = true
+  };;
 );
 
 fun synchronize_syntax sups base_sort ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
-
-    (* constraints *)
+    fun subst_class_typ sort = map_atyps
+      (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
     val operations = these_operations thy sups;
-    fun local_constraint (c, (_, (ty, _))) =
-      let
-        val ty' = ty
-          |> map_atyps (fn ty as TVar ((v, 0), _) =>
-               if v = Name.aT then TVar ((v, 0), base_sort) else ty)
-          |> SOME;
-      in (c, ty') end
-    val constraints = (map o apsnd) (fst o snd) operations;
-
-    (* check phase *)
-    val consts = ProofContext.consts_of ctxt;
+    val local_constraints =
+      (map o apsnd) (subst_class_typ base_sort o fst o snd) operations;
+    val global_constraints =
+      (map o apsnd) (fn (class, (ty, _)) => subst_class_typ [class] ty) operations;
     fun declare_const (c, _) =
       let val b = Sign.base_name c
       in Sign.intern_const thy b = c ? Variable.declare_const (b, c) end;
-    val typargs = Consts.typargs consts;
-    fun check_const (c, ty) (t, (_, typidx)) = ((nth (typargs (c, ty)) typidx), t);
-    fun local_operation (c_ty as (c, _)) = AList.lookup (op =) operations c
-      |> Option.map (check_const c_ty);
-
-    (* uncheck phase *)
-    val basify =
-      map_atyps (fn ty as TFree (v, _) => if Name.aT = v then TFree (v, base_sort)
-        else ty | ty => ty);
-    val unchecks = these_unchecks thy sups
-      |> (map o pairself o map_types) basify;
+    val unchecks = map (fn (c, (_, (ty, t))) => (t, Const (c, ty))) operations;
   in
     ctxt
-(*    |> fold declare_const operations  FIXME *)
-    |> fold (ProofContext.add_const_constraint o local_constraint) operations
-    |> ClassSyntax.put (SOME {
-        constraints = constraints,
+    |> fold declare_const local_constraints
+    |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
+    |> ClassSyntax.put {
+        local_constraints = local_constraints,
+        global_constraints = global_constraints,
         base_sort = base_sort,
-        local_operation = local_operation,
+        operations = (map o apsnd) snd operations,
         unchecks = unchecks,
         passed = false
-      })
+      }
   end;
 
 fun refresh_syntax class ctxt =
@@ -670,37 +652,41 @@
     val base_sort = (#base_sort o the_class_data thy) class;
   in synchronize_syntax [class] base_sort ctxt end;
 
-val mark_passed = (ClassSyntax.map o Option.map)
-  (fn { constraints, base_sort, local_operation, unchecks, passed } =>
-    { constraints = constraints, base_sort = base_sort,
-      local_operation = local_operation, unchecks = unchecks, passed = true });
+val mark_passed = ClassSyntax.map
+  (fn { local_constraints, global_constraints, base_sort, operations, unchecks, passed } =>
+    { local_constraints = local_constraints, global_constraints = global_constraints,
+      base_sort = base_sort, operations = operations, unchecks = unchecks, passed = true });
 
 fun sort_term_check ts ctxt =
   let
-    val { constraints, base_sort, local_operation, passed, ... } =
-      the (ClassSyntax.get ctxt);
-    fun check_typ (c, ty) (TFree (v, _), t) = if v = Name.aT
-          then apfst (AList.update (op =) ((c, ty), t)) else I
-      | check_typ (c, ty) (TVar (vi, _), t) = if TypeInfer.is_param vi
-          then apfst (AList.update (op =) ((c, ty), t))
-            #> apsnd (insert (op =) vi) else I
-      | check_typ _ _ = I;
-    fun add_const (Const c_ty) = Option.map (check_typ c_ty) (local_operation c_ty)
-          |> the_default I
-      | add_const _ = I;
-    val (cs, typarams) = (fold o fold_aterms) add_const ts ([], []);
-    val subst_typ = map_type_tvar (fn var as (vi, _) =>
-      if member (op =) typarams vi then TFree (Name.aT, base_sort) else TVar var);
-    val subst_term = map_aterms
-        (fn t as Const (c, ty) => the_default t (AList.lookup (op =) cs (c, ty)) | t => t)
-          #> map_types subst_typ;
-    val ts' = map subst_term ts;
-  in if eq_list (op aconv) (ts, ts') andalso passed then NONE
+    val { local_constraints, global_constraints, base_sort, operations, passed, ... } =
+      ClassSyntax.get ctxt;
+    fun check_improve (Const (c, ty)) = (case AList.lookup (op =) local_constraints c
+         of SOME ty0 => (case try (Type.raw_match (ty0, ty)) Vartab.empty
+             of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0)
+                 of SOME (_, TVar (tvar as (vi, _))) =>
+                      if TypeInfer.is_param vi then cons tvar else I
+                  | _ => I)
+              | NONE => I)
+          | NONE => I)
+      | check_improve _ = I;
+    val improvements = (fold o fold_aterms) check_improve ts [];
+    val ts' = (map o map_types o map_atyps) (fn ty as TVar tvar =>
+        if member (op =) improvements tvar
+          then TFree (Name.aT, base_sort) else ty | ty => ty) ts;
+    fun check t0 = Envir.expand_term (fn Const (c, ty) => (case AList.lookup (op =) operations c
+         of SOME (ty0, t) =>
+              if Type.typ_instance (ProofContext.tsig_of ctxt) (ty, ty0)
+              then SOME (ty0, check t) else NONE
+          | NONE => NONE)
+      | _ => NONE) t0;
+    val ts'' = map check ts';
+  in if eq_list (op aconv) (ts, ts'') andalso passed then NONE
   else
     ctxt
-    |> fold (ProofContext.add_const_constraint o apsnd SOME) constraints
+    |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints
     |> mark_passed
-    |> pair ts'
+    |> pair ts''
     |> SOME
   end;
 
@@ -709,7 +695,7 @@
 fun sort_term_uncheck ts ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
-    val unchecks = (#unchecks o the o ClassSyntax.get) ctxt;
+    val unchecks = (#unchecks o ClassSyntax.get) ctxt;
     val ts' = if ! uncheck
       then map (Pattern.rewrite_term thy unchecks []) ts else ts;
   in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
@@ -883,7 +869,7 @@
     |> Thm.add_def false (c, def_eq)    (* FIXME PureThy.add_defs_i *)
     |>> Thm.symmetric
     |-> (fn def => class_interpretation class [def] [Thm.prop_of def]
-          #> register_operation class (c', ((dict, dict'), SOME (Thm.varifyT def))))
+          #> register_operation class (c', (dict', SOME (Thm.varifyT def))))
     |> Sign.restore_naming thy
     |> Sign.add_const_constraint (c', SOME ty')
   end;
@@ -906,7 +892,7 @@
     |> Sign.add_abbrev (#1 prmode) pos (c, map_types Type.strip_sorts rhs') |> snd
     |> Sign.add_const_constraint (c', SOME ty')
     |> Sign.notation true prmode [(Const (c', ty'), mx)]
-    |> register_operation class (c', ((rhs, rhs'), NONE))
+    |> register_operation class (c', (rhs', NONE))
     |> Sign.restore_naming thy
   end;