src/Pure/Isar/class.ML
changeset 24657 185502d54c3d
parent 24589 d3fca349736c
child 24701 f8bfd592a6dc
     1.1 --- a/src/Pure/Isar/class.ML	Thu Sep 20 16:37:28 2007 +0200
     1.2 +++ b/src/Pure/Isar/class.ML	Thu Sep 20 16:37:29 2007 +0200
     1.3 @@ -44,6 +44,14 @@
     1.4    val inst_const: theory -> string * string -> string
     1.5    val param_const: theory -> string -> (string * string) option
     1.6    val params_of_sort: theory -> sort -> (string * (string * typ)) list
     1.7 +
     1.8 +  (*experimental*)
     1.9 +  val init_ref: (class -> Proof.context -> (theory -> theory) * Proof.context) ref
    1.10 +  val init: class -> Proof.context -> (theory -> theory) * Proof.context;
    1.11 +  val init_default: class -> Proof.context -> (theory -> theory) * Proof.context;
    1.12 +  val remove_constraints: class -> theory -> (string * typ) list * theory
    1.13 +  val class_term_check: theory -> class -> term list -> Proof.context -> term list * Proof.context
    1.14 +  val local_param: theory -> class -> string -> (term * (class * int)) option
    1.15  end;
    1.16  
    1.17  structure Class : CLASS =
    1.18 @@ -91,6 +99,15 @@
    1.19        | NONE => thm;
    1.20    in strip end;
    1.21  
    1.22 +fun get_remove_contraint c thy =
    1.23 +  let
    1.24 +    val ty = Sign.the_const_constraint thy c;
    1.25 +  in
    1.26 +    thy
    1.27 +    |> Sign.add_const_constraint_i (c, NONE)
    1.28 +    |> pair (c, Logic.unvarifyT ty)
    1.29 +  end;
    1.30 +
    1.31  
    1.32  (** axclass command **)
    1.33  
    1.34 @@ -277,14 +294,6 @@
    1.35        in fold_map read defs cs end;
    1.36      val (defs, other_cs) = read_defs raw_defs cs
    1.37        (fold Sign.primitive_arity arities (Theory.copy theory));
    1.38 -    fun get_remove_contraint c thy =
    1.39 -      let
    1.40 -        val ty = Sign.the_const_constraint thy c;
    1.41 -      in
    1.42 -        thy
    1.43 -        |> Sign.add_const_constraint_i (c, NONE)
    1.44 -        |> pair (c, Logic.unvarifyT ty)
    1.45 -      end;
    1.46      fun after_qed' cs defs =
    1.47        fold Sign.add_const_constraint_i (map (apsnd SOME) cs)
    1.48        #> after_qed defs;
    1.49 @@ -320,30 +329,45 @@
    1.50  datatype class_data = ClassData of {
    1.51    locale: string,
    1.52    consts: (string * string) list
    1.53 -    (*locale parameter ~> toplevel theory constant*),
    1.54 -  v: string option,
    1.55 +    (*locale parameter ~> theory constant name*),
    1.56 +  v: string,
    1.57    inst: typ Symtab.table * term Symtab.table
    1.58      (*canonical interpretation*),
    1.59 -  intro: thm
    1.60 -} * thm list (*derived defs*);
    1.61 +  intro: thm,
    1.62 +  defs: thm list,
    1.63 +  localized: (string * (term * (class * int))) list
    1.64 +    (*theory constant name ~> (locale parameter, (class, instantiaton index of class typ))*)
    1.65 +};
    1.66  
    1.67 -fun rep_classdata (ClassData c) = c;
    1.68 +fun rep_class_data (ClassData d) = d;
    1.69 +fun mk_class_data ((locale, consts, v, inst, intro), (defs, localized)) =
    1.70 +  ClassData { locale = locale, consts = consts, v = v, inst = inst, intro = intro,
    1.71 +    defs = defs, localized = localized };
    1.72 +fun map_class_data f (ClassData { locale, consts, v, inst, intro, defs, localized }) =
    1.73 +  mk_class_data (f ((locale, consts, v, inst, intro), (defs, localized)))
    1.74 +fun merge_class_data _ (ClassData { locale = locale, consts = consts, v = v, inst = inst,
    1.75 +    intro = intro, defs = defs1, localized = localized1 },
    1.76 +  ClassData { locale = _, consts = _, v = _, inst = _, intro = _,
    1.77 +    defs = defs2, localized = localized2 }) =
    1.78 +  mk_class_data ((locale, consts, v, inst, intro),
    1.79 +    (Thm.merge_thms (defs1, defs2), AList.merge (op =) (K true) (localized1, localized2)));
    1.80  
    1.81  fun merge_pair f1 f2 ((x1, y1), (x2, y2)) = (f1 (x1, x2), f2 (y1, y2));
    1.82  
    1.83  structure ClassData = TheoryDataFun
    1.84  (
    1.85 -  type T = class_data Graph.T * class Symtab.table (*locale name ~> class name*);
    1.86 +  type T = class_data Graph.T * class Symtab.table
    1.87 +    (*locale name ~> class name*);
    1.88    val empty = (Graph.empty, Symtab.empty);
    1.89    val copy = I;
    1.90    val extend = I;
    1.91 -  fun merge _ = merge_pair (Graph.merge (K true)) (Symtab.merge (K true));
    1.92 +  fun merge _ = merge_pair (Graph.join merge_class_data) (Symtab.merge (K true));
    1.93  );
    1.94  
    1.95  
    1.96  (* queries *)
    1.97  
    1.98 -val lookup_class_data = Option.map rep_classdata oo try o Graph.get_node
    1.99 +val lookup_class_data = Option.map rep_class_data oo try o Graph.get_node
   1.100    o fst o ClassData.get;
   1.101  fun class_of_locale thy = Symtab.lookup ((snd o ClassData.get) thy);
   1.102  
   1.103 @@ -358,18 +382,23 @@
   1.104      fun params class =
   1.105        let
   1.106          val const_typs = (#params o AxClass.get_definition thy) class;
   1.107 -        val const_names = (#consts o fst o the_class_data thy) class;
   1.108 +        val const_names = (#consts o the_class_data thy) class;
   1.109        in
   1.110          (map o apsnd) (fn c => (c, (the o AList.lookup (op =) const_typs) c)) const_names
   1.111        end;
   1.112    in maps params o ancestry thy end;
   1.113  
   1.114 -fun these_defs thy = maps (these o Option.map snd o lookup_class_data thy) o ancestry thy;
   1.115 +fun these_defs thy = maps (these o Option.map #defs o lookup_class_data thy) o ancestry thy;
   1.116  
   1.117  fun these_intros thy =
   1.118 -  Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o fst o rep_classdata) data))
   1.119 +  Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o rep_class_data) data))
   1.120      ((fst o ClassData.get) thy) [];
   1.121  
   1.122 +fun these_localized thy class =
   1.123 +  maps (#localized o the_class_data thy) (ancestry thy [class]);
   1.124 +
   1.125 +fun local_param thy = AList.lookup (op =) o these_localized thy;
   1.126 +
   1.127  fun print_classes thy =
   1.128    let
   1.129      val algebra = Sign.classes_of thy;
   1.130 @@ -389,7 +418,7 @@
   1.131        (SOME o Pretty.str) ("class " ^ class ^ ":"),
   1.132        (SOME o Pretty.block) [Pretty.str "supersort: ",
   1.133          (Sign.pretty_sort thy o Sign.certify_sort thy o Sign.super_classes thy) class],
   1.134 -      Option.map (Pretty.str o prefix "locale: " o #locale o fst) (lookup_class_data thy class),
   1.135 +      Option.map (Pretty.str o prefix "locale: " o #locale) (lookup_class_data thy class),
   1.136        ((fn [] => NONE | ps => (SOME o Pretty.block o Pretty.fbreaks) (Pretty.str "parameters:" :: ps)) o map mk_param
   1.137          o these o Option.map #params o try (AxClass.get_definition thy)) class,
   1.138        (SOME o Pretty.block o Pretty.breaks) [
   1.139 @@ -408,15 +437,16 @@
   1.140  fun add_class_data ((class, superclasses), (locale, consts, v, inst, intro)) =
   1.141    ClassData.map (fn (gr, tab) => (
   1.142      gr
   1.143 -    |> Graph.new_node (class, ClassData ({ locale = locale, consts = consts,
   1.144 -         v = v, inst = inst, intro = intro }, []))
   1.145 +    |> Graph.new_node (class, mk_class_data ((locale, (map o apfst) fst consts, v, inst, intro),
   1.146 +         ([], map (apsnd (rpair (class, 0) o Free) o swap) consts)))
   1.147      |> fold (curry Graph.add_edge class) superclasses,
   1.148      tab
   1.149      |> Symtab.update (locale, class)
   1.150    ));
   1.151  
   1.152 -fun add_class_const_thm (class, thm) = (ClassData.map o apfst o Graph.map_node class)
   1.153 -  (fn ClassData (data, thms) => ClassData (data, thm :: thms));
   1.154 +fun add_class_const_def (class, (entry, def)) =
   1.155 +  (ClassData.map o apfst o Graph.map_node class o map_class_data o apsnd)
   1.156 +    (fn (defs, localized) => (def :: defs, (apsnd o apsnd) (pair class) entry :: localized));
   1.157  
   1.158  
   1.159  (** rule calculation, tactics and methods **)
   1.160 @@ -452,7 +482,7 @@
   1.161  
   1.162  fun class_interpretation class facts defs thy =
   1.163    let
   1.164 -    val ({ locale, inst, ... }, _) = the_class_data thy class;
   1.165 +    val { locale, inst, ... } = the_class_data thy class;
   1.166      val tac = (ALLGOALS o ProofContext.fact_tac) facts;
   1.167      val prfx = Logic.const_of_class (NameSpace.base class);
   1.168    in
   1.169 @@ -464,7 +494,7 @@
   1.170    let
   1.171      fun mk_axioms class =
   1.172        let
   1.173 -        val ({ locale, inst = (_, insttab), ... }, _) = the_class_data thy class;
   1.174 +        val { locale, inst = (_, insttab), ... } = the_class_data thy class;
   1.175        in
   1.176          Locale.global_asms_of thy locale
   1.177          |> maps snd
   1.178 @@ -546,7 +576,7 @@
   1.179      val sups = filter (is_some o lookup_class_data thy) supclasses
   1.180        |> Sign.certify_sort thy;
   1.181      val supsort = Sign.certify_sort thy supclasses;
   1.182 -    val suplocales = map (Locale.Locale o #locale o fst o the_class_data thy) sups;
   1.183 +    val suplocales = map (Locale.Locale o #locale o the_class_data thy) sups;
   1.184      val supexpr = Locale.Merge (suplocales @ includes);
   1.185      val supparams = (map fst o Locale.parameters_of_expr thy)
   1.186        (Locale.Merge suplocales);
   1.187 @@ -563,10 +593,10 @@
   1.188        let
   1.189          val params = Locale.parameters_of thy name_locale;
   1.190          val v = case (maps typ_tfrees o map (snd o fst)) params
   1.191 -         of (v, _) :: _ => SOME v
   1.192 -          | _ => NONE;
   1.193 +         of (v, _) :: _ => v
   1.194 +          | [] => AxClass.param_tyvarname;
   1.195        in
   1.196 -        (v, (map (fst o fst) params, params
   1.197 +        (v, (map fst params, params
   1.198          |> (map o apfst o apsnd o Term.map_type_tfree) mk_tyvar
   1.199          |> (map o apsnd) (fork_mixfix true NONE #> fst)
   1.200          |> chop (length supconsts)
   1.201 @@ -578,7 +608,6 @@
   1.202          fun subst (Free (c, ty)) =
   1.203                Const ((fst o the o AList.lookup (op =) consts) c, ty)
   1.204            | subst t = t;
   1.205 -        val super_defs = these_defs thy sups;
   1.206          fun prep_asm ((name, atts), ts) =
   1.207            ((NameSpace.base name, map (Attrib.attribute thy) atts),
   1.208              (map o map_aterms) subst ts);
   1.209 @@ -595,15 +624,15 @@
   1.210      |> add_locale (SOME "") bname supexpr ((*elems_constrains @*) elems)
   1.211      |-> (fn name_locale => ProofContext.theory_result (
   1.212        `(fn thy => extract_params thy name_locale)
   1.213 -      #-> (fn (v, (param_names, params)) =>
   1.214 +      #-> (fn (v, (globals, params)) =>
   1.215          AxClass.define_class_params (bname, supsort) params
   1.216            (extract_assumes name_locale params) other_consts
   1.217        #-> (fn (name_axclass, (consts, axioms)) =>
   1.218          `(fn thy => class_intro thy name_locale name_axclass sups)
   1.219        #-> (fn class_intro =>
   1.220          add_class_data ((name_axclass, sups),
   1.221 -          (name_locale, map (fst o fst) params ~~ map fst consts, v,
   1.222 -            (mk_instT name_axclass, mk_inst name_axclass param_names
   1.223 +          (name_locale, map fst params ~~ map fst consts, v,
   1.224 +            (mk_instT name_axclass, mk_inst name_axclass (map fst globals)
   1.225                (map snd supconsts @ consts)), class_intro))
   1.226        #> note_intro name_axclass class_intro
   1.227        #> class_interpretation name_axclass axioms []
   1.228 @@ -619,52 +648,62 @@
   1.229  end; (*local*)
   1.230  
   1.231  
   1.232 +(* class target context *)
   1.233 +
   1.234 +fun remove_constraints class thy =
   1.235 +  thy |> fold_map (get_remove_contraint o fst) (these_localized thy class);
   1.236 +
   1.237 +
   1.238  (* definition in class target *)
   1.239  
   1.240  fun export_fixes thy class =
   1.241    let
   1.242 -    val v = (#v o fst o the_class_data thy) class;
   1.243 -    val constrain_sort = curry (Sorts.inter_sort (Sign.classes_of thy)) [class];
   1.244 -    val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
   1.245 -      if SOME w = v then TFree (w, constrain_sort sort) else TFree var);
   1.246      val consts = params_of_sort thy [class];
   1.247      fun subst_aterm (t as Free (v, ty)) = (case AList.lookup (op =) consts v
   1.248           of SOME (c, _) => Const (c, ty)
   1.249            | NONE => t)
   1.250        | subst_aterm t = t;
   1.251 -  in map_types subst_typ #> Term.map_aterms subst_aterm end;
   1.252 +  in Term.map_aterms subst_aterm end;
   1.253  
   1.254  fun add_const_in_class class ((c, rhs), syn) thy =
   1.255    let
   1.256      val prfx = (Logic.const_of_class o NameSpace.base) class;
   1.257 -    fun mk_name inject c =
   1.258 +    fun mk_name c =
   1.259        let
   1.260          val n1 = Sign.full_name thy c;
   1.261          val n2 = NameSpace.qualifier n1;
   1.262          val n3 = NameSpace.base n1;
   1.263 -      in NameSpace.implode (n2 :: inject @ [n3]) end;
   1.264 -    val abbr' = mk_name [prfx, prfx] c;
   1.265 +      in NameSpace.implode [n2, prfx, n3] end;
   1.266 +    val v = (#v o the_class_data thy) class;
   1.267 +    val constrain_sort = curry (Sorts.inter_sort (Sign.classes_of thy)) [class];
   1.268 +    val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
   1.269 +      if w = v then TFree (w, constrain_sort sort) else TFree var);
   1.270      val rhs' = export_fixes thy class rhs;
   1.271      val ty' = Term.fastype_of rhs';
   1.272 -    val def = (c, Logic.mk_equals (Const (mk_name [prfx] c, ty'), rhs'));
   1.273 +    val ty'' = subst_typ ty';
   1.274 +    val c' = mk_name c;
   1.275 +    val def = (c, Logic.mk_equals (Const (c', ty'), rhs'));
   1.276      val (syn', _) = fork_mixfix true NONE syn;
   1.277 -    fun interpret def =
   1.278 +    fun interpret def thy =
   1.279        let
   1.280          val def' = symmetric def;
   1.281          val def_eq = Thm.prop_of def';
   1.282 +        val typargs = Sign.const_typargs thy (c', fastype_of rhs);
   1.283 +        val typidx = find_index (fn TFree (w, _) => v = w | _ => false) typargs;
   1.284        in
   1.285 -        class_interpretation class [def'] [def_eq]
   1.286 -        #> add_class_const_thm (class, def')
   1.287 +        thy
   1.288 +        |> class_interpretation class [def'] [def_eq]
   1.289 +        |> add_class_const_def (class, ((c', (rhs, typidx)), def'))
   1.290        end;
   1.291    in
   1.292      thy
   1.293 -    |> Sign.hide_consts_i true [abbr']
   1.294      |> Sign.add_path prfx
   1.295      |> Sign.add_consts_authentic [(c, ty', syn')]
   1.296      |> Sign.parent_path
   1.297      |> Sign.sticky_prefix prfx
   1.298      |> PureThy.add_defs_i false [(def, [])]
   1.299      |-> (fn [def] => interpret def)
   1.300 +    |> Sign.add_const_constraint_i (c', SOME ty'')
   1.301      |> Sign.restore_naming thy
   1.302    end;
   1.303  
   1.304 @@ -677,8 +716,8 @@
   1.305    let
   1.306      val class = prep_class theory raw_class;
   1.307      val superclass = prep_class theory raw_superclass;
   1.308 -    val loc_name = (#locale o fst o the_class_data theory) class;
   1.309 -    val loc_expr = (Locale.Locale o #locale o fst o the_class_data theory) superclass;
   1.310 +    val loc_name = (#locale o the_class_data theory) class;
   1.311 +    val loc_expr = (Locale.Locale o #locale o the_class_data theory) superclass;
   1.312      fun prove_classrel (class, superclass) thy =
   1.313        let
   1.314          val classes = (Graph.all_succs o #classes o Sorts.rep_algebra
   1.315 @@ -717,4 +756,52 @@
   1.316  
   1.317  end; (*local*)
   1.318  
   1.319 +(*experimental*)
   1.320 +fun class_term_check thy class =
   1.321 +  let
   1.322 +    val algebra = Sign.classes_of thy;
   1.323 +    val { v, ... } = the_class_data thy class;
   1.324 +    fun add_constrain_classtyp sort' (ty as TFree (v, _)) =
   1.325 +          AList.map_default (op =) (v, []) (curry (Sorts.inter_sort algebra) sort')
   1.326 +      | add_constrain_classtyp sort' (Type (tyco, tys)) = case Sorts.mg_domain algebra tyco sort'
   1.327 +         of sorts => fold2 add_constrain_classtyp sorts tys;
   1.328 +    fun class_arg c idx ty =
   1.329 +      let
   1.330 +        val typargs = Sign.const_typargs thy (c, ty);
   1.331 +        fun classtyp (t as TFree (w, _)) = if w = v then NONE else SOME t
   1.332 +          | classtyp t = SOME t;
   1.333 +      in classtyp (nth typargs idx) end;
   1.334 +    fun add_inst (c, ty) (terminsts, typinsts) = case local_param thy class c
   1.335 +     of NONE => (terminsts, typinsts)
   1.336 +      | SOME (t, (class', idx)) => (case class_arg c idx ty
   1.337 +         of NONE => (((c, ty), t) :: terminsts, typinsts)
   1.338 +          | SOME ty => (terminsts, add_constrain_classtyp [class'] ty typinsts));
   1.339 +  in pair o (fn ts => let
   1.340 +    val cs = (fold o fold_aterms) (fn Const c_ty => insert (op =) c_ty | _ => I) ts [];
   1.341 +    val (terminsts, typinsts) = fold add_inst cs ([], []);
   1.342 +  in
   1.343 +    ts
   1.344 +    |> (map o map_aterms) (fn t as Const c_ty => the_default t (AList.lookup (op =) terminsts c_ty)
   1.345 +         | t => t)
   1.346 +    |> (map o map_types o map_atyps) (fn t as TFree (v, sort) =>
   1.347 +         case AList.lookup (op =) typinsts v
   1.348 +          of SOME sort' => TFree (v, Sorts.inter_sort algebra (sort, sort'))
   1.349 +           | NONE => t)
   1.350 +  end) end;
   1.351 +
   1.352 +val init_ref = ref (K (pair I) : class -> Proof.context -> (theory -> theory) * Proof.context);
   1.353 +fun init class = ! init_ref class;
   1.354 +
   1.355 +fun init_default class ctxt =
   1.356 +  let
   1.357 +    val thy = ProofContext.theory_of ctxt;
   1.358 +    val term_check = class_term_check thy class;
   1.359 +  in
   1.360 +    ctxt
   1.361 +    (*|> ProofContext.theory_result (remove_constraints class)*)
   1.362 +    |> Context.proof_map (Syntax.add_term_check term_check)
   1.363 +    (*|>> fold (fn (c, ty) => Sign.add_const_constraint_i (c, SOME ty))*)
   1.364 +    |> pair I
   1.365 +  end;
   1.366 +
   1.367  end;