src/Pure/Isar/class.ML
changeset 24748 ee0a0eb6b738
parent 24731 c25aa6ae64ec
child 24766 d0de4e48b526
--- a/src/Pure/Isar/class.ML	Fri Sep 28 10:35:53 2007 +0200
+++ b/src/Pure/Isar/class.ML	Sat Sep 29 08:58:51 2007 +0200
@@ -13,9 +13,9 @@
     -> ((bstring * Attrib.src list) * string list) list
     -> theory -> class * theory
   val classrel_cmd: xstring * xstring -> theory -> Proof.state
-  val class: bstring -> class list -> Element.context_i Locale.element list
+  val class: bool -> bstring -> class list -> Element.context_i Locale.element list
     -> string list -> theory -> string * Proof.context
-  val class_cmd: bstring -> xstring list -> Element.context Locale.element list
+  val class_cmd: bool -> bstring -> xstring list -> Element.context Locale.element list
     -> xstring list -> theory -> string * Proof.context
   val add_const_in_class: string -> (string * term) * Syntax.mixfix
     -> theory -> theory
@@ -25,6 +25,7 @@
   val intro_classes_tac: thm list -> tactic
   val default_intro_classes_tac: thm list -> tactic
   val class_of_locale: theory -> string -> class option
+  val local_syntax: theory -> class -> bool
   val print_classes: theory -> unit
 
   val instance_arity: (theory -> theory) -> arity list -> theory -> Proof.state
@@ -46,12 +47,12 @@
   val params_of_sort: theory -> sort -> (string * (string * typ)) list
 
   (*experimental*)
-  val init_ref: (class -> Proof.context -> (theory -> theory) * Proof.context) ref
-  val init: class -> Proof.context -> (theory -> theory) * Proof.context;
-  val init_default: class -> Proof.context -> (theory -> theory) * Proof.context;
-  val remove_constraints: class -> theory -> (string * typ) list * theory
-  val class_term_check: theory -> class -> term list -> Proof.context -> term list * Proof.context
-  val local_param: theory -> class -> string -> (term * (class * int)) option
+  val init_ref: (sort -> Proof.context -> Proof.context) ref
+  val init: sort -> Proof.context -> Proof.context;
+  val init_default: sort -> Proof.context -> Proof.context;
+  val local_param: theory -> sort -> string -> (term * (class * int)) option
+  val remove_constraints': sort -> theory -> (string * typ) list * theory
+  val remove_constraints: sort -> Proof.context -> (string * typ) list * Proof.context
 end;
 
 structure Class : CLASS =
@@ -99,7 +100,7 @@
       | NONE => thm;
   in strip end;
 
-fun get_remove_contraint c thy =
+fun get_remove_constraint c thy =
   let
     val ty = Sign.the_const_constraint thy c;
   in
@@ -298,7 +299,7 @@
       #> after_qed defs;
   in
     theory
-    |> fold_map get_remove_contraint (map fst cs |> distinct (op =))
+    |> fold_map get_remove_constraint (map fst cs |> distinct (op =))
     ||>> fold_map add_def defs
     ||> fold (fn (c, ((class, tyco), ty)) => add_inst_def' (class, tyco) (c, ty)) other_cs
     |-> (fn (cs, defs) => do_proof (after_qed' cs defs) arities defs)
@@ -329,26 +330,27 @@
   locale: string,
   consts: (string * string) list
     (*locale parameter ~> theory constant name*),
-  v: string,
+  local_sort: sort,
   inst: typ Symtab.table * term Symtab.table
     (*canonical interpretation*),
   intro: thm,
+  local_syntax: bool,
   defs: thm list,
   localized: (string * (term * (class * int))) list
     (*theory constant name ~> (locale parameter, (class, instantiaton index of class typ))*)
 };
 
 fun rep_class_data (ClassData d) = d;
-fun mk_class_data ((locale, consts, v, inst, intro), (defs, localized)) =
-  ClassData { locale = locale, consts = consts, v = v, inst = inst, intro = intro,
-    defs = defs, localized = localized };
-fun map_class_data f (ClassData { locale, consts, v, inst, intro, defs, localized }) =
-  mk_class_data (f ((locale, consts, v, inst, intro), (defs, localized)))
-fun merge_class_data _ (ClassData { locale = locale, consts = consts, v = v, inst = inst,
-    intro = intro, defs = defs1, localized = localized1 },
-  ClassData { locale = _, consts = _, v = _, inst = _, intro = _,
+fun mk_class_data ((locale, consts, local_sort, inst, intro, local_syntax), (defs, localized)) =
+  ClassData { locale = locale, consts = consts, local_sort = local_sort, inst = inst, intro = intro,
+    local_syntax = local_syntax, defs = defs, localized = localized };
+fun map_class_data f (ClassData { locale, consts, local_sort, inst, intro, local_syntax, defs, localized }) =
+  mk_class_data (f ((locale, consts, local_sort, inst, intro, local_syntax), (defs, localized)))
+fun merge_class_data _ (ClassData { locale = locale, consts = consts, local_sort = local_sort, inst = inst,
+    intro = intro, local_syntax = local_syntax, defs = defs1, localized = localized1 },
+  ClassData { locale = _, consts = _, local_sort = _, inst = _, intro = _, local_syntax = _,
     defs = defs2, localized = localized2 }) =
-  mk_class_data ((locale, consts, v, inst, intro),
+  mk_class_data ((locale, consts, local_sort, inst, intro, local_syntax),
     (Thm.merge_thms (defs1, defs2), AList.merge (op =) (K true) (localized1, localized2)));
 
 fun merge_pair f1 f2 ((x1, y1), (x2, y2)) = (f1 (x1, x2), f2 (y1, y2));
@@ -393,11 +395,13 @@
   Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o rep_class_data) data))
     ((fst o ClassData.get) thy) [];
 
-fun these_localized thy class =
-  maps (#localized o the_class_data thy) (ancestry thy [class]);
+fun these_localized thy =
+  maps (#localized o the_class_data thy) o ancestry thy;
 
 fun local_param thy = AList.lookup (op =) o these_localized thy;
 
+fun local_syntax thy = #local_syntax o the_class_data thy
+
 fun print_classes thy =
   let
     val algebra = Sign.classes_of thy;
@@ -433,11 +437,11 @@
 
 (* updaters *)
 
-fun add_class_data ((class, superclasses), (locale, consts, v, inst, intro)) =
+fun add_class_data ((class, superclasses), (locale, consts, local_sort, inst, intro, local_syntax)) =
   ClassData.map (fn (gr, tab) => (
     gr
-    |> Graph.new_node (class, mk_class_data ((locale, (map o apfst) fst consts, v, inst, intro),
-         ([], map (apsnd (rpair (class, 0) o Free) o swap) consts)))
+    |> Graph.new_node (class, mk_class_data ((locale, (map o apfst) fst consts, local_sort, inst,
+         intro, local_syntax), ([], map (apsnd (rpair (class, 0) o Free) o swap) consts)))
     |> fold (curry Graph.add_edge class) superclasses,
     tab
     |> Symtab.update (locale, class)
@@ -539,6 +543,90 @@
 
 (** classes and class target **)
 
+(* class context initialization *)
+
+(*experimental*)
+fun get_remove_constraint_ctxt c ctxt =
+  let
+    val ty = ProofContext.the_const_constraint ctxt c;
+  in
+    ctxt
+    |> ProofContext.add_const_constraint (c, NONE)
+    |> pair (c, ty)
+  end;
+
+fun remove_constraints' class thy =
+  thy |> fold_map (get_remove_constraint o fst) (these_localized thy class);
+
+fun remove_constraints class ctxt =
+  ctxt |> fold_map (get_remove_constraint_ctxt o fst) (these_localized (ProofContext.theory_of ctxt) class);
+
+fun default_typ ctxt constraints c =
+  case AList.lookup (op =) constraints c
+   of SOME ty => SOME ty
+    | NONE => try (Consts.the_constraint (ProofContext.consts_of ctxt)) c;
+
+fun infer_constraints ctxt constraints ts =
+    TypeInfer.infer_types (ProofContext.pp ctxt) (Sign.tsig_of (ProofContext.theory_of ctxt))
+     (Syntax.check_typs ctxt)
+      (default_typ ctxt constraints) (K NONE)
+      (Variable.names_of ctxt) true (map (rpair dummyT) ts)
+    |> #1 |> map #1
+  handle TYPE (msg, _, _) => error msg
+
+fun subst_typ local_sort =
+  map_atyps (fn (t as TFree (v, _)) => if v = AxClass.param_tyvarname
+        then TFree (v, local_sort)
+        else t
+    | t => t);
+
+fun sort_typ_check thy sort =
+  let
+    val local_sort = (#local_sort o the_class_data thy) (hd sort);
+  in
+    pair o map (subst_typ local_sort)
+  end;
+
+fun sort_term_check thy sort constraints =
+  let
+    val algebra = Sign.classes_of thy;
+    val local_sort = (#local_sort o the_class_data thy) (hd sort);
+    val v = AxClass.param_tyvarname;
+    val local_param = local_param thy sort;
+      (*FIXME efficiency*)
+    fun class_arg c idx ty =
+      let
+        val typargs = Sign.const_typargs thy (c, ty);
+        fun classtyp (TFree (w, _)) = w = v
+          | classtyp t = false;
+      in classtyp (nth typargs idx) end;
+    fun subst (t as Const (c, ty)) = (case local_param c
+         of NONE => t
+          | SOME (t', (_, idx)) => if class_arg c idx ty
+             then t' else t)
+      | subst t = t;
+  in fn ts => fn ctxt =>
+    ((map (map_aterms subst) #> infer_constraints ctxt constraints) ts, ctxt)
+  end;
+
+fun init_default sort ctxt =
+  let
+    val thy = ProofContext.theory_of ctxt;
+    val typ_check = sort_typ_check thy sort;
+    val term_check = sort_term_check thy sort;
+  in
+    ctxt
+    |> remove_constraints sort
+    ||> Variable.declare_term (Logic.mk_type (TFree (AxClass.param_tyvarname, sort)))
+    ||> Context.proof_map (Syntax.add_typ_check typ_check)
+    |-> (fn constraints =>
+        Context.proof_map (Syntax.add_term_check (term_check constraints)))
+  end;
+
+val init_ref = ref (K I : sort -> Proof.context -> Proof.context);
+fun init class = ! init_ref class;
+
+
 (* class definition *)
 
 local
@@ -551,52 +639,62 @@
     | NONE => error ("Not a constant: " ^ Sign.string_of_term thy t)
   end;
 
-fun gen_class add_locale prep_class prep_param bname
-    raw_supclasses raw_elems raw_other_consts thy =
+fun gen_class_spec prep_class prep_expr process_expr thy raw_supclasses raw_includes_elems =
   let
+    val supclasses = map (prep_class thy) raw_supclasses;
+    val sups = filter (is_some o lookup_class_data thy) supclasses
+      |> Sign.minimize_sort thy;
+    val supsort = Sign.minimize_sort thy supclasses;
+    val suplocales = map (Locale.Locale o #locale o the_class_data thy) sups;
+    val (raw_elems, includes) = fold_rev (fn Locale.Elem e => apfst (cons e)
+      | Locale.Expr i => apsnd (cons (prep_expr thy i))) raw_includes_elems ([], []);
+    val supexpr = Locale.Merge suplocales;
+    val supparams = (map fst o Locale.parameters_of_expr thy) supexpr;
+    val supconsts = AList.make (the o AList.lookup (op =) (params_of_sort thy sups))
+      (map fst supparams);
+    val mergeexpr = Locale.Merge (suplocales @ includes);
+    val constrain = Element.Constrains ((map o apsnd o map_atyps)
+      (fn TFree (_, sort) => TFree (AxClass.param_tyvarname, sort)) supparams);
+  in
+    ProofContext.init thy
+    |> Locale.cert_expr supexpr [constrain]
+    |> snd
+    |> init supsort
+    |> process_expr Locale.empty raw_elems
+    |> fst
+    |> (fn elems => ((((sups, supconsts), (supsort, mergeexpr)),
+          (*FIXME*) if null includes then constrain :: elems else elems)))
+  end;
+
+val read_class_spec = gen_class_spec Sign.intern_class Locale.intern_expr Locale.read_expr;
+val check_class_spec = gen_class_spec (K I) (K I) Locale.cert_expr;
+
+fun gen_class prep_spec prep_param local_syntax bname
+    raw_supclasses raw_includes_elems raw_other_consts thy =
+  let
+    val (((sups, supconsts), (supsort, mergeexpr)), elems) =
+      prep_spec thy raw_supclasses raw_includes_elems;
+    val other_consts = map (prep_param thy) raw_other_consts;
     fun mk_instT class = Symtab.empty
       |> Symtab.update (AxClass.param_tyvarname, TFree (AxClass.param_tyvarname, [class]));
     fun mk_inst class param_names cs =
       Symtab.empty
       |> fold2 (fn v => fn (c, ty) => Symtab.update (v, Const
            (c, Term.map_type_tfree (fn (v, _) => TFree (v, [class])) ty))) param_names cs;
-    (*FIXME need proper concept for reading locale statements*)
-    fun subst_classtyvar (_, _) =
-          TFree (AxClass.param_tyvarname, [])
-      | subst_classtyvar (v, sort) =
-          error ("Sort constraint illegal in type class, for type variable "
-            ^ v ^ "::" ^ Sign.string_of_sort thy sort);
-    (*val subst_classtyvars = Element.map_ctxt {name = I, var = I, term = I,
-      typ = Term.map_type_tfree subst_classtyvar, fact = I, attrib = I};*)
-    val other_consts = map (prep_param thy) raw_other_consts;
-    val (elems, includes) = fold_rev (fn Locale.Elem e => apfst (cons e)
-      | Locale.Expr i => apsnd (cons i)) raw_elems ([], []);
-    val supclasses = map (prep_class thy) raw_supclasses;
-    val sups = filter (is_some o lookup_class_data thy) supclasses
-      |> Sign.minimize_sort thy;
-    val supsort = Sign.minimize_sort thy supclasses;
-    val suplocales = map (Locale.Locale o #locale o the_class_data thy) sups;
-    val supexpr = Locale.Merge (suplocales @ includes);
-    val supparams = (map fst o Locale.parameters_of_expr thy)
-      (Locale.Merge suplocales);
-    val supconsts = AList.make (the o AList.lookup (op =) (params_of_sort thy sups))
-      (map fst supparams);
-    (*val elems_constrains = map
-      (Element.Constrains o apsnd (Term.map_type_tfree subst_classtyvar)) supparams;*)
-    fun mk_tyvar (_, sort) = TFree (AxClass.param_tyvarname,
-      if Sign.subsort thy (supsort, sort) then sort else error
-        ("Sort " ^ Sign.string_of_sort thy sort
-          ^ " is less general than permitted least general sort "
-          ^ Sign.string_of_sort thy supsort));
     fun extract_params thy name_locale =
       let
         val params = Locale.parameters_of thy name_locale;
-        val v = case (maps typ_tfrees o map (snd o fst)) params
-         of (v, _) :: _ => v
-          | [] => AxClass.param_tyvarname;
+        val local_sort = case AList.group (op =) ((maps typ_tfrees o map (snd o fst)) params)
+         of [(_, local_sort :: _)] => local_sort
+          | _ => Sign.defaultS thy
+          | vs => error ("exactly one type variable required: " ^ commas (map fst vs));
+        val _ = if Sign.subsort thy (supsort, local_sort) then () else error
+          ("Sort " ^ Sign.string_of_sort thy local_sort
+            ^ " is less general than permitted least general sort "
+            ^ Sign.string_of_sort thy supsort);
       in
-        (v, (map fst params, params
-        |> (map o apfst o apsnd o Term.map_type_tfree) mk_tyvar
+        (local_sort, (map fst params, params
+        |> (map o apfst o apsnd o Term.map_type_tfree) (K (TFree (AxClass.param_tyvarname, local_sort)))
         |> (map o apsnd) (fork_mixfix true NONE #> fst)
         |> chop (length supconsts)
         |> snd))
@@ -620,19 +718,19 @@
       #> snd
   in
     thy
-    |> add_locale (SOME "") bname supexpr ((*elems_constrains @*) elems)
+    |> Locale.add_locale_i (SOME "") bname mergeexpr elems
     |-> (fn name_locale => ProofContext.theory_result (
       `(fn thy => extract_params thy name_locale)
-      #-> (fn (v, (globals, params)) =>
+      #-> (fn (local_sort, (globals, params)) =>
         AxClass.define_class_params (bname, supsort) params
           (extract_assumes name_locale params) other_consts
       #-> (fn (name_axclass, (consts, axioms)) =>
         `(fn thy => class_intro thy name_locale name_axclass sups)
       #-> (fn class_intro =>
         add_class_data ((name_axclass, sups),
-          (name_locale, map fst params ~~ map fst consts, v,
+          (name_locale, map fst params ~~ map fst consts, local_sort,
             (mk_instT name_axclass, mk_inst name_axclass (map fst globals)
-              (map snd supconsts @ consts)), class_intro))
+              (map snd supconsts @ consts)), class_intro, local_syntax))
       #> note_intro name_axclass class_intro
       #> class_interpretation name_axclass axioms []
       #> pair name_axclass
@@ -641,18 +739,12 @@
 
 in
 
-val class_cmd = gen_class Locale.add_locale Sign.intern_class read_param;
-val class = gen_class Locale.add_locale_i Sign.certify_class (K I);
+val class_cmd = gen_class read_class_spec read_param;
+val class = gen_class check_class_spec (K I);
 
 end; (*local*)
 
 
-(* class target context *)
-
-fun remove_constraints class thy =
-  thy |> fold_map (get_remove_contraint o fst) (these_localized thy class);
-
-
 (* definition in class target *)
 
 fun export_fixes thy class =
@@ -673,11 +765,10 @@
         val n2 = NameSpace.qualifier n1;
         val n3 = NameSpace.base n1;
       in NameSpace.implode [n2, prfx, n3] end;
-    val v = (#v o the_class_data thy) class;
     val constrain_sort = curry (Sorts.inter_sort (Sign.classes_of thy)) [class];
+    val rhs' = export_fixes thy class rhs;
     val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
-      if w = v then TFree (w, constrain_sort sort) else TFree var);
-    val rhs' = export_fixes thy class rhs;
+      if w = AxClass.param_tyvarname then TFree (w, constrain_sort sort) else TFree var);
     val ty' = Term.fastype_of rhs';
     val ty'' = subst_typ ty';
     val c' = mk_name c;
@@ -688,7 +779,7 @@
         val def' = symmetric def;
         val def_eq = Thm.prop_of def';
         val typargs = Sign.const_typargs thy (c', fastype_of rhs);
-        val typidx = find_index (fn TFree (w, _) => v = w | _ => false) typargs;
+        val typidx = find_index (fn TFree (w, _) => AxClass.param_tyvarname = w | _ => false) typargs;
       in
         thy
         |> class_interpretation class [def'] [def_eq]
@@ -754,52 +845,4 @@
 
 end; (*local*)
 
-(*experimental*)
-fun class_term_check thy class =
-  let
-    val algebra = Sign.classes_of thy;
-    val { v, ... } = the_class_data thy class;
-    fun add_constrain_classtyp sort' (ty as TFree (v, _)) =
-          AList.map_default (op =) (v, []) (curry (Sorts.inter_sort algebra) sort')
-      | add_constrain_classtyp sort' (Type (tyco, tys)) = case Sorts.mg_domain algebra tyco sort'
-         of sorts => fold2 add_constrain_classtyp sorts tys;
-    fun class_arg c idx ty =
-      let
-        val typargs = Sign.const_typargs thy (c, ty);
-        fun classtyp (t as TFree (w, _)) = if w = v then NONE else SOME t
-          | classtyp t = SOME t;
-      in classtyp (nth typargs idx) end;
-    fun add_inst (c, ty) (terminsts, typinsts) = case local_param thy class c
-     of NONE => (terminsts, typinsts)
-      | SOME (t, (class', idx)) => (case class_arg c idx ty
-         of NONE => (((c, ty), t) :: terminsts, typinsts)
-          | SOME ty => (terminsts, add_constrain_classtyp [class'] ty typinsts));
-  in pair o (fn ts => let
-    val cs = (fold o fold_aterms) (fn Const c_ty => insert (op =) c_ty | _ => I) ts [];
-    val (terminsts, typinsts) = fold add_inst cs ([], []);
-  in
-    ts
-    |> (map o map_aterms) (fn t as Const c_ty => the_default t (AList.lookup (op =) terminsts c_ty)
-         | t => t)
-    |> (map o map_types o map_atyps) (fn t as TFree (v, sort) =>
-         case AList.lookup (op =) typinsts v
-          of SOME sort' => TFree (v, Sorts.inter_sort algebra (sort, sort'))
-           | NONE => t)
-  end) end;
-
-val init_ref = ref (K (pair I) : class -> Proof.context -> (theory -> theory) * Proof.context);
-fun init class = ! init_ref class;
-
-fun init_default class ctxt =
-  let
-    val thy = ProofContext.theory_of ctxt;
-    val term_check = class_term_check thy class;
-  in
-    ctxt
-    (*|> ProofContext.theory_result (remove_constraints class)*)
-    |> Context.proof_map (Syntax.add_term_check term_check)
-    (*|>> fold (fn (c, ty) => Sign.add_const_constraint_i (c, SOME ty))*)
-    |> pair I
-  end;
-
 end;