# HG changeset patch # User haftmann # Date 1151498169 -7200 # Node ID 91ba241a1678cb0741e1feca3662a733b582b804 # Parent f992e507020e6c0efa4c5b5fac721a2a89e54310 reduced code, better instance command diff -r f992e507020e -r 91ba241a1678 src/Pure/Tools/class_package.ML --- a/src/Pure/Tools/class_package.ML Wed Jun 28 14:35:51 2006 +0200 +++ b/src/Pure/Tools/class_package.ML Wed Jun 28 14:36:09 2006 +0200 @@ -36,10 +36,7 @@ val operational_sort_of: theory -> sort -> sort val the_superclasses: theory -> class -> class list val the_consts_sign: theory -> class -> string * (string * typ) list - val lookup_const_class: theory -> string -> class option - val the_instances: theory -> class -> (string * ((sort list) * string)) list val the_inst_sign: theory -> class * string -> (string * sort) list * (string * typ) list - val get_classtab: theory -> (string * string) list Symtab.table val print_classes: theory -> unit val intro_classes_tac: thm list -> tactic @@ -57,6 +54,22 @@ struct +(* auxiliary *) + +fun instantiations_of thy (ty, ty') = + let + val vartab = typ_tvars ty; + fun prep_vartab (v, (_, ty)) = + case (the o AList.lookup (op =) vartab) v + of [] => NONE + | sort => SOME ((v, sort), ty); + in case try (Sign.typ_match thy (ty, ty')) Vartab.empty + of NONE => NONE + | SOME vartab => + SOME ((map_filter prep_vartab o Vartab.dest) vartab) + end; + + (* theory data *) datatype class_data = ClassData of { @@ -64,7 +77,7 @@ name_axclass: string, var: string, consts: (string * (string * typ)) list - (*locale parameter ~> toplevel const*) + (*locale parameter ~> toplevel constant*) }; fun rep_classdata (ClassData c) = c; @@ -72,17 +85,12 @@ structure ClassData = TheoryDataFun ( struct val name = "Pure/classes"; - type T = (class_data Graph.T - * (string * (sort list * string)) list Symtab.table) - (*class ~> tyco ~> (arity, thyname)*) - * class Symtab.table; - val empty = ((Graph.empty, Symtab.empty), Symtab.empty); + type T = class_data Graph.T; + val empty = Graph.empty; val copy = I; val extend = I; - fun merge _ (((g1, c1), f1) : T, ((g2, c2), f2)) = - ((Graph.merge (K true) (g1, g2), Symtab.join (fn _ => AList.merge (op =) (op =)) (c1, c2)), - Symtab.merge (op =) (f1, f2)); - fun print thy ((gr, _), _) = + fun merge _ = Graph.merge (K true); + fun print thy gr = let fun pretty_class gr (name, ClassData {name_locale, name_axclass, var, consts}) = (Pretty.block o Pretty.fbreaks) [ @@ -112,9 +120,7 @@ (* queries *) -val lookup_class_data = Option.map rep_classdata oo try o Graph.get_node o fst o fst o ClassData.get; -val the_instances = these oo Symtab.lookup o snd o fst o ClassData.get; -val lookup_const_class = Symtab.lookup o snd o ClassData.get; +val lookup_class_data = Option.map rep_classdata oo try o Graph.get_node o ClassData.get; fun the_class_data thy class = case lookup_class_data thy class @@ -144,13 +150,6 @@ else error ("no class: " ^ class); -fun get_superclass_derivation thy (subclass, superclass) = - if subclass = superclass - then SOME [subclass] - else case Graph.irreducible_paths ((fst o fst o ClassData.get) thy) (subclass, superclass) - of [] => NONE - | (p::_) => (SOME o filter (is_operational_class thy)) p; - fun the_ancestry thy classes = let fun ancestry class anc = @@ -176,7 +175,7 @@ fun the_inst_sign thy (class, tyco) = let val _ = if is_operational_class thy class then () else error ("no operational class: " ^ class); - val arity = Sign.arity_sorts thy tyco [class]; + val asorts = Sign.arity_sorts thy tyco [class]; val clsvar = (#var o the_class_data thy) class; val const_sign = (snd o the_consts_sign thy) class; fun add_var sort used = @@ -187,41 +186,23 @@ [clsvar] |> fold (fn (_, ty) => curry (gen_union (op =)) ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) const_sign - |> fold_map add_var arity; + |> fold_map add_var asorts; val ty_inst = Type (tyco, map TFree vsorts); val inst_signs = map (apsnd (subst_clsvar clsvar ty_inst)) const_sign; in (vsorts, inst_signs) end; -fun get_classtab thy = - (Symtab.map o map) - (fn (tyco, (_, thyname)) => (tyco, thyname)) ((snd o fst o ClassData.get) thy); - (* updaters *) fun add_class_data (class, (superclasses, name_locale, name_axclass, var, consts)) = - ClassData.map (fn ((gr, tab), consttab) => (( - gr - |> Graph.new_node (class, ClassData { - name_locale = name_locale, - name_axclass = name_axclass, - var = var, - consts = consts - }) - |> fold (curry Graph.add_edge_acyclic class) superclasses, - tab - |> Symtab.update (class, [])), - consttab - |> fold (fn (_, (c, _)) => Symtab.update (c, class)) consts - )); - -fun add_inst_data (class, inst) = - ClassData.map (fn ((gr, tab), consttab) => - let - val undef_supclasses = class :: (filter (Symtab.defined tab) (Graph.all_succs gr [class])); - in - ((gr, tab |> fold (fn class => Symtab.map_entry class (AList.update (op =) inst)) undef_supclasses), consttab) - end); + ClassData.map ( + Graph.new_node (class, ClassData { + name_locale = name_locale, + name_axclass = name_axclass, + var = var, + consts = consts }) + #> fold (curry Graph.add_edge_acyclic class) superclasses + ); (* name handling *) @@ -233,7 +214,7 @@ map (fn class => (the_class_data thy class; class)) (Sign.certify_sort thy sort); fun intern_class thy = -certify_class thy o Sign.intern_class thy; + certify_class thy o Sign.intern_class thy; fun intern_sort thy = certify_sort thy o Sign.intern_sort thy; @@ -295,10 +276,10 @@ local -fun add_axclass_i (name, supsort) axs thy = +fun add_axclass_i (name, supsort) params axs thy = let val (c, thy') = thy - |> AxClass.define_class_i (name, supsort) [] axs; + |> AxClass.define_class_i (name, supsort) params axs; val {intro, axioms, ...} = AxClass.get_definition thy' c; in ((c, (intro, axioms)), thy') end; @@ -309,13 +290,11 @@ let val p = setmp show_types true (setmp show_sorts true (setmp print_mode [] (Sign.pretty_typ thy))) ty; val s = c ^ "::" ^ Pretty.output p; - val _ = writeln s; in SOME s end | ad_hoc_term (SOME t) = let val p = setmp show_types true (setmp show_sorts true (setmp print_mode [] (Sign.pretty_term thy))) t; val s = Pretty.output p; - val _ = writeln s; in SOME s end; in thy @@ -368,7 +347,7 @@ map_aterms (fn Free (c, ty) => Const ((fst o the o AList.lookup (op =) cs_mapp) c, ty) | t => t) fun prep_asm ((name, atts), ts) = - ((NameSpace.base name |> print, map (Attrib.attribute thy) atts), map subst_assume ts) + ((NameSpace.base name, map (Attrib.attribute thy) atts), map subst_assume ts) in (map prep_asm o Locale.local_asms_of thy) name_locale end; @@ -387,7 +366,7 @@ #-> (fn mapp_this => `(fn thy => extract_assumes thy name_locale (mapp_sup @ mapp_this)) #-> (fn loc_axioms => - add_axclass_i (bname, supsort) loc_axioms + add_axclass_i (bname, supsort) (map (fst o snd) mapp_this) loc_axioms #-> (fn (name_axclass, (_, ax_axioms)) => fold (add_global_constraint v name_axclass) mapp_this #> add_class_data (name_locale, (supclasses, name_locale, name_axclass, v, mapp_this)) @@ -407,30 +386,20 @@ local -fun gen_add_defs_overloaded prep_att tap_def add_defs tyco raw_defs thy = +fun gen_read_def thy prep_att read_def tyco ((raw_name, raw_atts), raw_t) = let - fun invent_name raw_t = - let - val t = tap_def thy raw_t; - val c = (fst o dest_Const o fst o strip_comb o fst o Logic.dest_equals) t; - in - Thm.def_name (NameSpace.base c ^ "_" ^ NameSpace.base tyco) - end; - fun prep_def (_, (("", a), t)) = - let - val n = invent_name t - in ((n, t), map (prep_att thy) a) end - | prep_def (_, ((n, a), t)) = - ((n, t), map (prep_att thy) a); - in - thy - |> add_defs true (map prep_def raw_defs) - end; + val (_, t) = read_def thy (raw_name, raw_t); + val ((c, ty), _) = Sign.cert_def (Sign.pp thy) t; + val atts = map (prep_att thy) raw_atts; + val name = case raw_name + of "" => Thm.def_name (NameSpace.base c ^ "_" ^ NameSpace.base tyco) + | _ => raw_name; + in (c, (Logic.varifyT ty, ((name, t), atts))) end; -val add_defs_overloaded = gen_add_defs_overloaded Attrib.attribute Sign.read_term PureThy.add_defs; -val add_defs_overloaded_i = gen_add_defs_overloaded (K I) (K I) PureThy.add_defs_i; +fun read_def thy = gen_read_def thy Attrib.attribute read_axm; +fun read_def_i thy = gen_read_def thy (K I) (K I); -fun gen_instance_arity prep_arity prep_att add_defs tap_def do_proof raw_arity (raw_name, raw_atts) raw_defs theory = +fun gen_instance_arity prep_arity prep_att read_def do_proof raw_arity (raw_name, raw_atts) raw_defs theory = let val pp = Sign.pp theory; val arity as (tyco, asorts, sort) = prep_arity theory ((fn ((x, y), z) => (x, y, z)) raw_arity); @@ -439,24 +408,37 @@ of "" => Thm.def_name ((space_implode "_" o map NameSpace.base) sort ^ "_" ^ NameSpace.base tyco) | _ => raw_name; val atts = map (prep_att theory) raw_atts; - fun get_classes thy tyco sort = - let - fun get class classes = - if AList.defined (op =) ((the_instances thy) class) tyco - then classes - else classes - |> cons class - |> fold get (the_superclasses thy class) - in fold get sort [] end; - val classes = get_classes theory tyco sort; - val _ = if null classes then error ("already instantiated") else (); fun get_consts class = let val data = the_class_data theory class; - val subst_ty = map_type_tfree (fn (var as (v, _)) => - if #var data = v then ty_inst else TFree var) - in (map (apsnd subst_ty o snd) o #consts) data end; - val cs = (flat o map get_consts) classes; + fun defined c = + is_some (find_first (fn (_, { lhs = [ty], ...}) => + Sign.typ_instance theory (ty, ty_inst) orelse Sign.typ_instance theory (ty_inst, ty)) + (Defs.specifications_of (Theory.defs_of theory) c)) + val subst_ty = map_type_tfree (fn (v, sort) => + if #var data = v then ty_inst else TVar ((v, 0), sort)); + in + (map_filter (fn (_, (c, ty)) => + if defined c then NONE else SOME (c, subst_ty ty)) o #consts) data + end; + val cs = (maps get_consts o the_ancestry theory) sort; + fun read_defs defs cs = + let + val thy_read = (Sign.primitive_arity (tyco, asorts, sort) o Theory.copy) theory; + fun read raw_def cs = + let + val (c, (ty, def)) = read_def thy_read tyco raw_def; + val def' = case AList.lookup (op =) cs c + of NONE => error ("superfluous definition for constant " ^ quote c) + | SOME ty' => case instantiations_of thy_read (ty, ty') + of NONE => error ("superfluous definition for constant " ^ + quote c ^ "::" ^ Sign.string_of_typ thy_read ty) + | SOME insttab => + (apfst o apsnd o map_term_types) + (Logic.unvarifyT o Term.instantiateT insttab o Logic.varifyT) def + in (def', AList.delete (op =) c cs) end; + in fold_map read defs cs end; + val (defs, _) = read_defs raw_defs cs; fun get_remove_contraint c thy = let val ty = Sign.the_const_constraint thy c; @@ -465,65 +447,31 @@ |> Sign.add_const_constraint_i (c, NONE) |> pair (c, Logic.legacy_unvarifyT ty) end; - fun check_defs0 thy raw_defs c_req = - let - fun get_c raw_def = - (fst o Sign.cert_def pp o tap_def thy o snd) raw_def; - val c_given = map get_c raw_defs; - fun eq_c ((c1 : string, ty1), (c2, ty2)) = - let - val ty1' = Logic.legacy_varifyT ty1; - val ty2' = Logic.legacy_varifyT ty2; - in - c1 = c2 - andalso Sign.typ_instance thy (ty1', ty2') - andalso Sign.typ_instance thy (ty2', ty1') - end; - val _ = case subtract eq_c c_req c_given - of [] => () - | cs => error ("superfluous definition(s) given for " - ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs); - (*val _ = case subtract eq_c c_given c_req - of [] => () - | cs => error ("no definition(s) given for " - ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs);*) - in () end; - fun check_defs1 raw_defs c_req thy = - let - val thy' = (Sign.primitive_arity (tyco, asorts, sort) o Theory.copy) thy - in (check_defs0 thy' raw_defs c_req; thy) end; - fun mangle_alldef_name tyco sort = - Thm.def_name ((space_implode "_" o map NameSpace.base) sort ^ "_" ^ NameSpace.base tyco); - fun note_all tyco sort thms thy = + fun note_all thms thy = thy |> PureThy.note_thmss_i PureThy.internalK [((name, atts), [(thms, [])])] |> snd; fun after_qed cs thy = thy - |> fold (fn class => - add_inst_data (class, (tyco, - (map (operational_sort_of thy) asorts, Context.theory_name thy)))) sort |> fold Sign.add_const_constraint_i (map (apsnd SOME) cs); in theory - |> check_defs1 raw_defs cs |> fold_map get_remove_contraint (map fst cs) - ||>> add_defs tyco (map (pair NONE) raw_defs) - |-> (fn (cs, defnames) => note_all tyco sort defnames #> pair cs) + ||>> PureThy.add_defs_i true defs + |-> (fn (cs, thms) => note_all thms #> pair cs) |-> (fn cs => do_proof (after_qed cs) arity) end; -fun instance_arity' do_proof = gen_instance_arity Sign.read_arity Attrib.attribute add_defs_overloaded - (fn thy => fn t => (snd o read_axm thy) ("", t)) do_proof; -fun instance_arity_i' do_proof = gen_instance_arity Sign.cert_arity (K I) add_defs_overloaded_i - (K I) do_proof; -val setup_proof = axclass_instance_arity_i; +fun instance_arity' do_proof = gen_instance_arity Sign.read_arity Attrib.attribute + read_def do_proof; +fun instance_arity_i' do_proof = gen_instance_arity Sign.cert_arity (K I) + read_def_i do_proof; fun tactic_proof tac after_qed arity = AxClass.prove_arity arity tac #> after_qed; in -val instance_arity = instance_arity' setup_proof; -val instance_arity_i = instance_arity_i' setup_proof; +val instance_arity = instance_arity' axclass_instance_arity_i; +val instance_arity_i = instance_arity_i' axclass_instance_arity_i; val prove_instance_arity = instance_arity_i' o tactic_proof; end; (* local *) @@ -590,6 +538,7 @@ end; (* local *) + (* extracting dictionary obligations from types *) type sortcontext = (string * sort) list; @@ -635,22 +584,15 @@ fun sortlookups_const thy (c, typ_ctxt) = let - val typ_decl = case lookup_const_class thy c + val typ_decl = case AxClass.class_of thy c of NONE => Sign.the_const_type thy c | SOME class => case the_consts_sign thy class of (v, cs) => (Logic.legacy_varifyT o subst_clsvar v (TFree (v, [class]))) ((the o AList.lookup (op =) cs) c) - val vartab = typ_tvars typ_decl; - fun prep_vartab (v, (_, ty)) = - case (the o AList.lookup (op =) vartab) v - of [] => NONE - | sort => SOME (sort, ty); in - Vartab.empty - |> Sign.typ_match thy (typ_decl, typ_ctxt) - |> Vartab.dest - |> map_filter prep_vartab - |> map (sortlookup thy) + instantiations_of thy (typ_decl, typ_ctxt) + |> the + |> map (fn ((_, sort), ty) => sortlookup thy (sort, ty)) |> filter_out null end;