fixed wrong syntax treatment in class target
authorhaftmann
Thu, 20 Sep 2007 16:37:29 +0200
changeset 24657 185502d54c3d
parent 24656 67f6bf194ca6
child 24658 49adbdcc52e2
fixed wrong syntax treatment in class target
src/HOL/List.thy
src/HOL/ex/Classpackage.thy
src/Pure/Isar/class.ML
src/Pure/Isar/theory_target.ML
--- a/src/HOL/List.thy	Thu Sep 20 16:37:28 2007 +0200
+++ b/src/HOL/List.thy	Thu Sep 20 16:37:29 2007 +0200
@@ -215,18 +215,23 @@
 text{* The following simple sort functions are intended for proofs,
 not for efficient implementations. *}
 
-fun (in linorder) sorted :: "'a list \<Rightarrow> bool" where
-"sorted [] \<longleftrightarrow> True" |
-"sorted [x] \<longleftrightarrow> True" |
-"sorted (x#y#zs) \<longleftrightarrow> x \<^loc><= y \<and> sorted (y#zs)"
-
-fun (in linorder) insort :: "'a \<Rightarrow> 'a list \<Rightarrow> 'a list" where
-"insort x [] = [x]" |
-"insort x (y#ys) = (if x \<^loc><= y then (x#y#ys) else y#(insort x ys))"
-
-fun (in linorder) sort :: "'a list \<Rightarrow> 'a list" where
-"sort [] = []" |
-"sort (x#xs) = insort x (sort xs)"
+context linorder
+begin
+
+fun  sorted :: "'a list \<Rightarrow> bool" where
+  "sorted [] \<longleftrightarrow> True" |
+  "sorted [x] \<longleftrightarrow> True" |
+  "sorted (x#y#zs) \<longleftrightarrow> x \<^loc><= y \<and> sorted (y#zs)"
+
+fun insort :: "'a \<Rightarrow> 'a list \<Rightarrow> 'a list" where
+  "insort x [] = [x]" |
+  "insort x (y#ys) = (if x \<^loc><= y then (x#y#ys) else y#(insort x ys))"
+
+fun sort :: "'a list \<Rightarrow> 'a list" where
+  "sort [] = []" |
+  "sort (x#xs) = insort x (sort xs)"
+
+end
 
 
 subsubsection {* List comprehension *}
--- a/src/HOL/ex/Classpackage.thy	Thu Sep 20 16:37:28 2007 +0200
+++ b/src/HOL/ex/Classpackage.thy	Thu Sep 20 16:37:29 2007 +0200
@@ -81,8 +81,6 @@
   units :: "'a set" where
   "units = {y. \<exists>x. x \<^loc>\<otimes> y = \<^loc>\<one> \<and> y \<^loc>\<otimes> x = \<^loc>\<one>}"
 
-end context monoid begin
-
 lemma inv_obtain:
   assumes "x \<in> units"
   obtains y where "y \<^loc>\<otimes> x = \<^loc>\<one>" and "x \<^loc>\<otimes> y = \<^loc>\<one>"
@@ -120,8 +118,6 @@
   "npow 0 x = \<^loc>\<one>"
   | "npow (Suc n) x = x \<^loc>\<otimes> npow n x"
 
-end context monoid begin
-
 abbreviation
   npow_syn :: "'a \<Rightarrow> nat \<Rightarrow> 'a" (infix "\<^loc>\<up>" 75) where
   "x \<^loc>\<up> n \<equiv> npow n x"
--- a/src/Pure/Isar/class.ML	Thu Sep 20 16:37:28 2007 +0200
+++ b/src/Pure/Isar/class.ML	Thu Sep 20 16:37:29 2007 +0200
@@ -44,6 +44,14 @@
   val inst_const: theory -> string * string -> string
   val param_const: theory -> string -> (string * string) option
   val params_of_sort: theory -> sort -> (string * (string * typ)) list
+
+  (*experimental*)
+  val init_ref: (class -> Proof.context -> (theory -> theory) * Proof.context) ref
+  val init: class -> Proof.context -> (theory -> theory) * Proof.context;
+  val init_default: class -> Proof.context -> (theory -> theory) * Proof.context;
+  val remove_constraints: class -> theory -> (string * typ) list * theory
+  val class_term_check: theory -> class -> term list -> Proof.context -> term list * Proof.context
+  val local_param: theory -> class -> string -> (term * (class * int)) option
 end;
 
 structure Class : CLASS =
@@ -91,6 +99,15 @@
       | NONE => thm;
   in strip end;
 
+fun get_remove_contraint c thy =
+  let
+    val ty = Sign.the_const_constraint thy c;
+  in
+    thy
+    |> Sign.add_const_constraint_i (c, NONE)
+    |> pair (c, Logic.unvarifyT ty)
+  end;
+
 
 (** axclass command **)
 
@@ -277,14 +294,6 @@
       in fold_map read defs cs end;
     val (defs, other_cs) = read_defs raw_defs cs
       (fold Sign.primitive_arity arities (Theory.copy theory));
-    fun get_remove_contraint c thy =
-      let
-        val ty = Sign.the_const_constraint thy c;
-      in
-        thy
-        |> Sign.add_const_constraint_i (c, NONE)
-        |> pair (c, Logic.unvarifyT ty)
-      end;
     fun after_qed' cs defs =
       fold Sign.add_const_constraint_i (map (apsnd SOME) cs)
       #> after_qed defs;
@@ -320,30 +329,45 @@
 datatype class_data = ClassData of {
   locale: string,
   consts: (string * string) list
-    (*locale parameter ~> toplevel theory constant*),
-  v: string option,
+    (*locale parameter ~> theory constant name*),
+  v: string,
   inst: typ Symtab.table * term Symtab.table
     (*canonical interpretation*),
-  intro: thm
-} * thm list (*derived defs*);
+  intro: thm,
+  defs: thm list,
+  localized: (string * (term * (class * int))) list
+    (*theory constant name ~> (locale parameter, (class, instantiaton index of class typ))*)
+};
 
-fun rep_classdata (ClassData c) = c;
+fun rep_class_data (ClassData d) = d;
+fun mk_class_data ((locale, consts, v, inst, intro), (defs, localized)) =
+  ClassData { locale = locale, consts = consts, v = v, inst = inst, intro = intro,
+    defs = defs, localized = localized };
+fun map_class_data f (ClassData { locale, consts, v, inst, intro, defs, localized }) =
+  mk_class_data (f ((locale, consts, v, inst, intro), (defs, localized)))
+fun merge_class_data _ (ClassData { locale = locale, consts = consts, v = v, inst = inst,
+    intro = intro, defs = defs1, localized = localized1 },
+  ClassData { locale = _, consts = _, v = _, inst = _, intro = _,
+    defs = defs2, localized = localized2 }) =
+  mk_class_data ((locale, consts, v, inst, intro),
+    (Thm.merge_thms (defs1, defs2), AList.merge (op =) (K true) (localized1, localized2)));
 
 fun merge_pair f1 f2 ((x1, y1), (x2, y2)) = (f1 (x1, x2), f2 (y1, y2));
 
 structure ClassData = TheoryDataFun
 (
-  type T = class_data Graph.T * class Symtab.table (*locale name ~> class name*);
+  type T = class_data Graph.T * class Symtab.table
+    (*locale name ~> class name*);
   val empty = (Graph.empty, Symtab.empty);
   val copy = I;
   val extend = I;
-  fun merge _ = merge_pair (Graph.merge (K true)) (Symtab.merge (K true));
+  fun merge _ = merge_pair (Graph.join merge_class_data) (Symtab.merge (K true));
 );
 
 
 (* queries *)
 
-val lookup_class_data = Option.map rep_classdata oo try o Graph.get_node
+val lookup_class_data = Option.map rep_class_data oo try o Graph.get_node
   o fst o ClassData.get;
 fun class_of_locale thy = Symtab.lookup ((snd o ClassData.get) thy);
 
@@ -358,18 +382,23 @@
     fun params class =
       let
         val const_typs = (#params o AxClass.get_definition thy) class;
-        val const_names = (#consts o fst o the_class_data thy) class;
+        val const_names = (#consts o the_class_data thy) class;
       in
         (map o apsnd) (fn c => (c, (the o AList.lookup (op =) const_typs) c)) const_names
       end;
   in maps params o ancestry thy end;
 
-fun these_defs thy = maps (these o Option.map snd o lookup_class_data thy) o ancestry thy;
+fun these_defs thy = maps (these o Option.map #defs o lookup_class_data thy) o ancestry thy;
 
 fun these_intros thy =
-  Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o fst o rep_classdata) data))
+  Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o rep_class_data) data))
     ((fst o ClassData.get) thy) [];
 
+fun these_localized thy class =
+  maps (#localized o the_class_data thy) (ancestry thy [class]);
+
+fun local_param thy = AList.lookup (op =) o these_localized thy;
+
 fun print_classes thy =
   let
     val algebra = Sign.classes_of thy;
@@ -389,7 +418,7 @@
       (SOME o Pretty.str) ("class " ^ class ^ ":"),
       (SOME o Pretty.block) [Pretty.str "supersort: ",
         (Sign.pretty_sort thy o Sign.certify_sort thy o Sign.super_classes thy) class],
-      Option.map (Pretty.str o prefix "locale: " o #locale o fst) (lookup_class_data thy class),
+      Option.map (Pretty.str o prefix "locale: " o #locale) (lookup_class_data thy class),
       ((fn [] => NONE | ps => (SOME o Pretty.block o Pretty.fbreaks) (Pretty.str "parameters:" :: ps)) o map mk_param
         o these o Option.map #params o try (AxClass.get_definition thy)) class,
       (SOME o Pretty.block o Pretty.breaks) [
@@ -408,15 +437,16 @@
 fun add_class_data ((class, superclasses), (locale, consts, v, inst, intro)) =
   ClassData.map (fn (gr, tab) => (
     gr
-    |> Graph.new_node (class, ClassData ({ locale = locale, consts = consts,
-         v = v, inst = inst, intro = intro }, []))
+    |> Graph.new_node (class, mk_class_data ((locale, (map o apfst) fst consts, v, inst, intro),
+         ([], map (apsnd (rpair (class, 0) o Free) o swap) consts)))
     |> fold (curry Graph.add_edge class) superclasses,
     tab
     |> Symtab.update (locale, class)
   ));
 
-fun add_class_const_thm (class, thm) = (ClassData.map o apfst o Graph.map_node class)
-  (fn ClassData (data, thms) => ClassData (data, thm :: thms));
+fun add_class_const_def (class, (entry, def)) =
+  (ClassData.map o apfst o Graph.map_node class o map_class_data o apsnd)
+    (fn (defs, localized) => (def :: defs, (apsnd o apsnd) (pair class) entry :: localized));
 
 
 (** rule calculation, tactics and methods **)
@@ -452,7 +482,7 @@
 
 fun class_interpretation class facts defs thy =
   let
-    val ({ locale, inst, ... }, _) = the_class_data thy class;
+    val { locale, inst, ... } = the_class_data thy class;
     val tac = (ALLGOALS o ProofContext.fact_tac) facts;
     val prfx = Logic.const_of_class (NameSpace.base class);
   in
@@ -464,7 +494,7 @@
   let
     fun mk_axioms class =
       let
-        val ({ locale, inst = (_, insttab), ... }, _) = the_class_data thy class;
+        val { locale, inst = (_, insttab), ... } = the_class_data thy class;
       in
         Locale.global_asms_of thy locale
         |> maps snd
@@ -546,7 +576,7 @@
     val sups = filter (is_some o lookup_class_data thy) supclasses
       |> Sign.certify_sort thy;
     val supsort = Sign.certify_sort thy supclasses;
-    val suplocales = map (Locale.Locale o #locale o fst o the_class_data thy) sups;
+    val suplocales = map (Locale.Locale o #locale o the_class_data thy) sups;
     val supexpr = Locale.Merge (suplocales @ includes);
     val supparams = (map fst o Locale.parameters_of_expr thy)
       (Locale.Merge suplocales);
@@ -563,10 +593,10 @@
       let
         val params = Locale.parameters_of thy name_locale;
         val v = case (maps typ_tfrees o map (snd o fst)) params
-         of (v, _) :: _ => SOME v
-          | _ => NONE;
+         of (v, _) :: _ => v
+          | [] => AxClass.param_tyvarname;
       in
-        (v, (map (fst o fst) params, params
+        (v, (map fst params, params
         |> (map o apfst o apsnd o Term.map_type_tfree) mk_tyvar
         |> (map o apsnd) (fork_mixfix true NONE #> fst)
         |> chop (length supconsts)
@@ -578,7 +608,6 @@
         fun subst (Free (c, ty)) =
               Const ((fst o the o AList.lookup (op =) consts) c, ty)
           | subst t = t;
-        val super_defs = these_defs thy sups;
         fun prep_asm ((name, atts), ts) =
           ((NameSpace.base name, map (Attrib.attribute thy) atts),
             (map o map_aterms) subst ts);
@@ -595,15 +624,15 @@
     |> add_locale (SOME "") bname supexpr ((*elems_constrains @*) elems)
     |-> (fn name_locale => ProofContext.theory_result (
       `(fn thy => extract_params thy name_locale)
-      #-> (fn (v, (param_names, params)) =>
+      #-> (fn (v, (globals, params)) =>
         AxClass.define_class_params (bname, supsort) params
           (extract_assumes name_locale params) other_consts
       #-> (fn (name_axclass, (consts, axioms)) =>
         `(fn thy => class_intro thy name_locale name_axclass sups)
       #-> (fn class_intro =>
         add_class_data ((name_axclass, sups),
-          (name_locale, map (fst o fst) params ~~ map fst consts, v,
-            (mk_instT name_axclass, mk_inst name_axclass param_names
+          (name_locale, map fst params ~~ map fst consts, v,
+            (mk_instT name_axclass, mk_inst name_axclass (map fst globals)
               (map snd supconsts @ consts)), class_intro))
       #> note_intro name_axclass class_intro
       #> class_interpretation name_axclass axioms []
@@ -619,52 +648,62 @@
 end; (*local*)
 
 
+(* class target context *)
+
+fun remove_constraints class thy =
+  thy |> fold_map (get_remove_contraint o fst) (these_localized thy class);
+
+
 (* definition in class target *)
 
 fun export_fixes thy class =
   let
-    val v = (#v o fst o the_class_data thy) class;
-    val constrain_sort = curry (Sorts.inter_sort (Sign.classes_of thy)) [class];
-    val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
-      if SOME w = v then TFree (w, constrain_sort sort) else TFree var);
     val consts = params_of_sort thy [class];
     fun subst_aterm (t as Free (v, ty)) = (case AList.lookup (op =) consts v
          of SOME (c, _) => Const (c, ty)
           | NONE => t)
       | subst_aterm t = t;
-  in map_types subst_typ #> Term.map_aterms subst_aterm end;
+  in Term.map_aterms subst_aterm end;
 
 fun add_const_in_class class ((c, rhs), syn) thy =
   let
     val prfx = (Logic.const_of_class o NameSpace.base) class;
-    fun mk_name inject c =
+    fun mk_name c =
       let
         val n1 = Sign.full_name thy c;
         val n2 = NameSpace.qualifier n1;
         val n3 = NameSpace.base n1;
-      in NameSpace.implode (n2 :: inject @ [n3]) end;
-    val abbr' = mk_name [prfx, prfx] c;
+      in NameSpace.implode [n2, prfx, n3] end;
+    val v = (#v o the_class_data thy) class;
+    val constrain_sort = curry (Sorts.inter_sort (Sign.classes_of thy)) [class];
+    val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
+      if w = v then TFree (w, constrain_sort sort) else TFree var);
     val rhs' = export_fixes thy class rhs;
     val ty' = Term.fastype_of rhs';
-    val def = (c, Logic.mk_equals (Const (mk_name [prfx] c, ty'), rhs'));
+    val ty'' = subst_typ ty';
+    val c' = mk_name c;
+    val def = (c, Logic.mk_equals (Const (c', ty'), rhs'));
     val (syn', _) = fork_mixfix true NONE syn;
-    fun interpret def =
+    fun interpret def thy =
       let
         val def' = symmetric def;
         val def_eq = Thm.prop_of def';
+        val typargs = Sign.const_typargs thy (c', fastype_of rhs);
+        val typidx = find_index (fn TFree (w, _) => v = w | _ => false) typargs;
       in
-        class_interpretation class [def'] [def_eq]
-        #> add_class_const_thm (class, def')
+        thy
+        |> class_interpretation class [def'] [def_eq]
+        |> add_class_const_def (class, ((c', (rhs, typidx)), def'))
       end;
   in
     thy
-    |> Sign.hide_consts_i true [abbr']
     |> Sign.add_path prfx
     |> Sign.add_consts_authentic [(c, ty', syn')]
     |> Sign.parent_path
     |> Sign.sticky_prefix prfx
     |> PureThy.add_defs_i false [(def, [])]
     |-> (fn [def] => interpret def)
+    |> Sign.add_const_constraint_i (c', SOME ty'')
     |> Sign.restore_naming thy
   end;
 
@@ -677,8 +716,8 @@
   let
     val class = prep_class theory raw_class;
     val superclass = prep_class theory raw_superclass;
-    val loc_name = (#locale o fst o the_class_data theory) class;
-    val loc_expr = (Locale.Locale o #locale o fst o the_class_data theory) superclass;
+    val loc_name = (#locale o the_class_data theory) class;
+    val loc_expr = (Locale.Locale o #locale o the_class_data theory) superclass;
     fun prove_classrel (class, superclass) thy =
       let
         val classes = (Graph.all_succs o #classes o Sorts.rep_algebra
@@ -717,4 +756,52 @@
 
 end; (*local*)
 
+(*experimental*)
+fun class_term_check thy class =
+  let
+    val algebra = Sign.classes_of thy;
+    val { v, ... } = the_class_data thy class;
+    fun add_constrain_classtyp sort' (ty as TFree (v, _)) =
+          AList.map_default (op =) (v, []) (curry (Sorts.inter_sort algebra) sort')
+      | add_constrain_classtyp sort' (Type (tyco, tys)) = case Sorts.mg_domain algebra tyco sort'
+         of sorts => fold2 add_constrain_classtyp sorts tys;
+    fun class_arg c idx ty =
+      let
+        val typargs = Sign.const_typargs thy (c, ty);
+        fun classtyp (t as TFree (w, _)) = if w = v then NONE else SOME t
+          | classtyp t = SOME t;
+      in classtyp (nth typargs idx) end;
+    fun add_inst (c, ty) (terminsts, typinsts) = case local_param thy class c
+     of NONE => (terminsts, typinsts)
+      | SOME (t, (class', idx)) => (case class_arg c idx ty
+         of NONE => (((c, ty), t) :: terminsts, typinsts)
+          | SOME ty => (terminsts, add_constrain_classtyp [class'] ty typinsts));
+  in pair o (fn ts => let
+    val cs = (fold o fold_aterms) (fn Const c_ty => insert (op =) c_ty | _ => I) ts [];
+    val (terminsts, typinsts) = fold add_inst cs ([], []);
+  in
+    ts
+    |> (map o map_aterms) (fn t as Const c_ty => the_default t (AList.lookup (op =) terminsts c_ty)
+         | t => t)
+    |> (map o map_types o map_atyps) (fn t as TFree (v, sort) =>
+         case AList.lookup (op =) typinsts v
+          of SOME sort' => TFree (v, Sorts.inter_sort algebra (sort, sort'))
+           | NONE => t)
+  end) end;
+
+val init_ref = ref (K (pair I) : class -> Proof.context -> (theory -> theory) * Proof.context);
+fun init class = ! init_ref class;
+
+fun init_default class ctxt =
+  let
+    val thy = ProofContext.theory_of ctxt;
+    val term_check = class_term_check thy class;
+  in
+    ctxt
+    (*|> ProofContext.theory_result (remove_constraints class)*)
+    |> Context.proof_map (Syntax.add_term_check term_check)
+    (*|>> fold (fn (c, ty) => Sign.add_const_constraint_i (c, SOME ty))*)
+    |> pair I
+  end;
+
 end;
--- a/src/Pure/Isar/theory_target.ML	Thu Sep 20 16:37:28 2007 +0200
+++ b/src/Pure/Isar/theory_target.ML	Thu Sep 20 16:37:29 2007 +0200
@@ -90,14 +90,31 @@
     fun const_class (SOME class) ((c, _), mx) (_, t) =
           Class.add_const_in_class class ((c, t), mx)
       | const_class NONE _ _ = I;
-
+    fun hide_abbrev (SOME class) abbrs thy =
+          let
+            val raw_cs = map (fst o fst) abbrs;
+            val prfx = (Logic.const_of_class o NameSpace.base) class;
+            fun mk_name c =
+              let
+                val n1 = Sign.full_name thy c;
+                val n2 = NameSpace.qualifier n1;
+                val n3 = NameSpace.base n1;
+              in NameSpace.implode [n2, prfx, prfx, n3] end;
+            val cs = map mk_name raw_cs;
+          in
+            Sign.hide_consts_i true cs thy
+          end
+      | hide_abbrev NONE _ thy = thy;
     val (abbrs, lthy') = lthy
       |> LocalTheory.theory_result (fold_map const decls)
     val defs = map (apsnd (pair ("", []))) abbrs;
+    
   in
     lthy'
+    |> LocalTheory.raw_theory (fold2 (const_class some_class) decls abbrs)
     |> is_loc ? fold (internal_abbrev Syntax.default_mode) abbrs
-    |> LocalTheory.raw_theory (fold2 (const_class some_class) decls abbrs)
+    |> LocalTheory.raw_theory (hide_abbrev some_class abbrs)
+        (*FIXME abbreviations should never occur*)
     |> LocalDefs.add_defs defs
     |>> map (apsnd snd)
   end;
@@ -342,10 +359,15 @@
     val thy = ProofContext.theory_of ctxt;
     val is_loc = loc <> "";
     val some_class = Class.class_of_locale thy loc;
+    fun class_init_exit (SOME class) =
+          Class.init class
+      | class_init_exit NONE =
+          pair I;
   in
     ctxt
     |> Data.put (if is_loc then SOME loc else NONE)
-    |> LocalTheory.init (NameSpace.base loc)
+    |> class_init_exit some_class
+    |-> (fn exit => LocalTheory.init (NameSpace.base loc)
      {pretty = pretty loc,
       consts = consts is_loc some_class,
       axioms = axioms,
@@ -358,8 +380,9 @@
       target_morphism = target_morphism loc,
       target_naming = target_naming loc,
       reinit = fn _ =>
-        begin loc o (if is_loc then Locale.init loc else ProofContext.init),
-      exit = LocalTheory.target_of}
+        (if is_loc then Locale.init loc else ProofContext.init)
+        #> begin loc,
+      exit = LocalTheory.target_of #> ProofContext.theory exit })
   end;
 
 fun init_i NONE thy = begin "" (ProofContext.init thy)