--- a/src/Pure/Tools/ROOT.ML Fri Mar 17 10:04:27 2006 +0100
+++ b/src/Pure/Tools/ROOT.ML Fri Mar 17 14:19:24 2006 +0100
@@ -7,6 +7,9 @@
(*class package*)
use "class_package.ML";
+(*code generator theorems*)
+use "codegen_theorems.ML";
+
(*code generator, 1st generation*)
use "../codegen.ML";
--- a/src/Pure/Tools/class_package.ML Fri Mar 17 10:04:27 2006 +0100
+++ b/src/Pure/Tools/class_package.ML Fri Mar 17 14:19:24 2006 +0100
@@ -65,7 +65,8 @@
name_axclass: string,
intro: thm option,
var: string,
- consts: (string * typ) list
+ consts: (string * (string * typ)) list
+ (*locale parameter ~> toplevel const*)
};
structure ClassData = TheoryDataFun (
@@ -95,7 +96,7 @@
Pretty.str ("class variable: " ^ var),
(Pretty.block o Pretty.fbreaks) (
Pretty.str "constants: "
- :: map (fn (c, ty) => Pretty.str (c ^ " :: " ^ Sign.string_of_typ thy ty)) consts
+ :: map (fn (_, (c, ty)) => Pretty.str (c ^ " :: " ^ Sign.string_of_typ thy ty)) consts
)
]
in
@@ -120,7 +121,9 @@
of NONE => error ("undeclared operational class " ^ quote class)
| SOME data => data;
-fun is_class thy cls =
+val is_class = is_some oo lookup_class_data;
+
+fun is_operational_class thy cls =
lookup_class_data thy cls
|> Option.map (not o null o #consts)
|> the_default false;
@@ -129,7 +132,7 @@
let
val classes = Sign.classes_of thy;
fun get_sort class =
- if is_class thy class
+ if is_operational_class thy class
then [class]
else operational_sort_of thy (Sorts.superclasses classes class);
in
@@ -144,14 +147,14 @@
Sorts.superclasses (Sign.classes_of thy) class
|> operational_sort_of thy
else
- error ("no syntactic class: " ^ class);
+ error ("no class: " ^ class);
fun get_superclass_derivation thy (subclass, superclass) =
if subclass = superclass
then SOME [subclass]
else case Graph.find_paths ((fst o fst o ClassData.get) thy) (subclass, superclass)
of [] => NONE
- | (p::_) => (SOME o filter (is_class thy)) p;
+ | (p::_) => (SOME o filter (is_operational_class thy)) p;
fun the_ancestry thy classes =
let
@@ -170,14 +173,19 @@
map_type_tfree (fn u as (w, _) =>
if w = v then ty_subst else TFree u);
+fun the_parm_map thy class =
+ let
+ val data = the_class_data thy class
+ in (#consts data) end;
+
fun the_consts_sign thy class =
let
val data = the_class_data thy class
- in (#var data, #consts data) end;
+ in (#var data, (map snd o #consts) data) end;
fun the_inst_sign thy (class, tyco) =
let
- val _ = if is_class thy class then () else error ("no syntactic class: " ^ class);
+ val _ = if is_operational_class thy class then () else error ("no operational class: " ^ class);
val arity =
Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class];
val clsvar = (#var o the_class_data thy) class;
@@ -216,13 +224,16 @@
tab
|> Symtab.update (class, [])),
consttab
- |> fold (fn (c, _) => Symtab.update (c, class)) consts
+ |> fold (fn (_, (c, _)) => Symtab.update (c, class)) consts
));
fun add_inst_data (class, inst) =
ClassData.map (fn ((gr, tab), consttab) =>
- ((gr, tab |>
- (Symtab.map_entry class (AList.update (op =) inst))), consttab));
+ let
+ val undef_supclasses = class :: (filter (Symtab.defined tab) (Graph.all_succs gr [class]));
+ in
+ ((gr, tab |> fold (fn class => Symtab.map_entry class (AList.update (op =) inst)) undef_supclasses), consttab)
+ end);
(* name handling *)
@@ -234,7 +245,7 @@
map (fn class => (the_class_data thy class; class)) (Sign.certify_sort thy sort);
fun intern_class thy =
- certify_class thy o Sign.intern_class thy;
+certify_class thy o Sign.intern_class thy;
fun intern_sort thy =
certify_sort thy o Sign.intern_sort thy;
@@ -356,18 +367,19 @@
|> map (#name_axclass o the_class_data thy)
|> Sorts.certify_sort (Sign.classes_of thy)
|> null ? K (Sign.defaultS thy);
- val supcs = (Library.flat o map (snd o the_consts_sign thy) o the_ancestry thy)
- supclasses;
val expr = if null supclasses
then Locale.empty
else
(Locale.Merge o map (Locale.Locale o #name_locale o the_class_data thy)) supclasses;
+ val mapp_sup = AList.make
+ (the o AList.lookup (op =) ((Library.flat o map (the_parm_map thy) o the_ancestry thy) supclasses))
+ ((map (fst o fst) o Locale.parameters_of_expr thy) expr);
fun extract_tyvar_consts thy name_locale =
let
fun extract_tyvar_name thy tys =
fold (curry add_typ_tfrees) tys []
|> (fn [(v, sort)] =>
- if Sorts.sort_le (Sign.classes_of thy) (swap (sort, supsort))
+ if Sorts.sort_le (Sign.classes_of thy) (swap (sort, supsort))
then v
else error ("illegal sort constraint on class type variable: " ^ Sign.string_of_sort thy sort)
| [] => error ("no class type variable")
@@ -377,10 +389,9 @@
|> map (apsnd Syntax.unlocalize_mixfix)
val v = (extract_tyvar_name thy o map (snd o fst)) consts1;
val consts2 = map ((apfst o apsnd) (subst_clsvar v (TFree (v, [])))) consts1;
- in (v, chop (length supcs) consts2) end;
+ in (v, chop (length mapp_sup) consts2) end;
fun add_consts v raw_cs_sup raw_cs_this thy =
let
- val mapp_sub = map2 (fn ((c, _), _) => pair c) raw_cs_sup supcs
fun add_global_const ((c, ty), syn) thy =
thy
|> Sign.add_consts_i [(c, ty |> subst_clsvar v (TFree (v, Sign.defaultS thy)), syn)]
@@ -388,7 +399,6 @@
in
thy
|> fold_map add_global_const raw_cs_this
- |-> (fn mapp_this => pair (mapp_sub @ mapp_this, map snd mapp_this))
end;
fun extract_assumes thy name_locale cs_mapp =
let
@@ -400,7 +410,7 @@
in
(map prep_asm o Locale.local_asms_of thy) name_locale
end;
- fun add_global_constraint v class (c, ty) thy =
+ fun add_global_constraint v class (_, (c, ty)) thy =
thy
|> Sign.add_const_constraint_i (c, SOME (subst_clsvar v (TFree (v, [class])) ty));
fun mk_const thy class v (c, ty) =
@@ -412,15 +422,15 @@
`(fn thy => extract_tyvar_consts thy name_locale)
#-> (fn (v, (raw_cs_sup, raw_cs_this)) =>
add_consts v raw_cs_sup raw_cs_this
- #-> (fn (cs_map, cs_this) =>
- `(fn thy => extract_assumes thy name_locale cs_map)
+ #-> (fn mapp_this =>
+ `(fn thy => extract_assumes thy name_locale (mapp_sup @ mapp_this))
#-> (fn loc_axioms =>
add_axclass_i (bname, supsort) loc_axioms
#-> (fn (name_axclass, (_, ax_axioms)) =>
- fold (add_global_constraint v name_axclass) cs_this
- #> add_class_data (name_locale, (supclasses, name_locale, name_axclass, intro, v, cs_this))
+ fold (add_global_constraint v name_axclass) mapp_this
+ #> add_class_data (name_locale, (supclasses, name_locale, name_axclass, intro, v, mapp_this))
#> prove_interpretation_i (NameSpace.base name_locale, [])
- (Locale.Locale name_locale) (map (SOME o mk_const thy name_axclass v) (supcs @ cs_this))
+ (Locale.Locale name_locale) (map (SOME o mk_const thy name_axclass v) (map snd (mapp_sup @ mapp_this)))
((ALLGOALS o resolve_tac) ax_axioms)
#> pair ctxt
)))))
@@ -490,7 +500,7 @@
val data = the_class_data theory class;
val subst_ty = map_type_tfree (fn (var as (v, _)) =>
if #var data = v then ty_inst else TFree var)
- in (map (apsnd subst_ty) o #consts) data end;
+ in (map (apsnd subst_ty o snd) o #consts) data end;
val cs = (Library.flat o map get_consts) classes;
fun get_remove_contraint c thy =
let
@@ -570,7 +580,7 @@
val _ = writeln ("sub " ^ name)
val suplocales = (fn Locale.Merge es => map (fn Locale.Locale n => n) es) expr;
val _ = writeln ("super " ^ commas suplocales)
- fun get_c name =
+ fun get_c name =
(map (NameSpace.base o fst o fst) o Locale.parameters_of thy) name;
fun get_a name =
(map (NameSpace.base o fst o fst) o Locale.local_asms_of thy) name;
@@ -663,7 +673,7 @@
fun mk_lookup (sort_def, (Type (tyco, tys))) =
map (fn class => Instance ((class, tyco),
map2 (curry mk_lookup)
- ((fst o the o AList.lookup (op =) (the_instances thy class)) tyco)
+ (map (operational_sort_of thy) (Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class]))
tys)
) sort_def
| mk_lookup (sort_def, TVar ((vname, _), sort_use)) =
@@ -673,7 +683,7 @@
in Lookup (deriv, (vname, classindex)) end;
in map mk_look sort_def end;
in
- sortctxt
+ sortctxt
|> map (tab_lookup o fst)
|> map (apfst (operational_sort_of thy))
|> filter (not o null o fst)
@@ -690,7 +700,7 @@
| SOME class =>
let
val data = the_class_data thy class;
- val sign = (Type.varifyT o the o AList.lookup (op =) (#consts data)) c;
+ val sign = (Type.varifyT o the o AList.lookup (op =) ((map snd o #consts) data)) c;
val match_tab = Sign.typ_match thy (sign, typ_def) Vartab.empty;
val v : string = case Vartab.lookup match_tab (#var data, 0)
of SOME (_, TVar ((v, _), _)) => v;
@@ -751,13 +761,18 @@
Scan.optional (P.$$$ "+" |-- P.!!! (Scan.repeat1 P.context_element)) [] ||
Scan.repeat1 P.context_element >> pair Locale.empty);
+val class_subP = P.name -- Scan.repeat (P.$$$ "+" |-- P.name) >> (op ::);
+val class_bodyP = P.!!! (Scan.repeat1 P.context_element);
+
val classP =
OuterSyntax.command classK "operational type classes" K.thy_decl (
P.name --| P.$$$ "="
- -- Scan.optional (Scan.repeat1 (P.name --| P.$$$ "+")) []
- -- Scan.optional (P.!!! (Scan.repeat1 P.context_element)) []
- >> (Toplevel.theory_context
- o (fn ((bname, supclasses), elems) => class bname supclasses elems)));
+ -- (
+ class_subP --| P.$$$ "+" -- class_bodyP
+ || class_subP >> rpair []
+ || class_bodyP >> pair []
+ ) >> (Toplevel.theory_context
+ o (fn (bname, (supclasses, elems)) => class bname supclasses elems)));
val instanceP =
OuterSyntax.command instanceK "prove type arity or subclass relation" K.thy_goal ((
--- a/src/Pure/Tools/codegen_package.ML Fri Mar 17 10:04:27 2006 +0100
+++ b/src/Pure/Tools/codegen_package.ML Fri Mar 17 14:19:24 2006 +0100
@@ -136,7 +136,7 @@
fun eq_typ thy (ty1, ty2) =
Sign.typ_instance thy (ty1, ty2)
- andalso Sign.typ_instance thy (ty2, ty1);
+ andalso Sign.typ_instance thy (ty2, ty1);
fun is_overloaded thy c = case Defs.specifications_of (Theory.defs_of thy) c
of [] => true
@@ -754,7 +754,8 @@
fun gen_membr (m, ty) trns =
trns
|> mk_fun thy tabs true (m, ty)
- |-> (fn NONE => error ("could not derive definition for member " ^ quote m)
+ |-> (fn NONE => error ("could not derive definition for member "
+ ^ quote m ^ " :: " ^ Sign.string_of_typ thy ty)
| SOME (funn, ty_use) =>
(fold_map o fold_map) (exprgen_classlookup thy tabs)
(ClassPackage.extract_classlookup_member thy (ty, ty_use))
@@ -913,7 +914,7 @@
fun eqextr_defs thy (deftab, _) (c, ty) =
Option.mapPartial (get_first (fn (ty', (thm, _)) =>
- if eq_typ thy (ty, ty')
+ if Sign.typ_instance thy (ty, ty')
then SOME ([thm], ty')
else NONE
)) (Symtab.lookup deftab c);
--- a/src/Pure/Tools/codegen_serializer.ML Fri Mar 17 10:04:27 2006 +0100
+++ b/src/Pure/Tools/codegen_serializer.ML Fri Mar 17 14:19:24 2006 +0100
@@ -413,7 +413,7 @@
str ")"
]
end;
- fun ml_from_sortlookup fxy ls =
+ fun ml_from_sortlookup fxy lss =
let
fun from_label l =
Pretty.block [str "#", ml_from_label l];
@@ -437,10 +437,10 @@
from_lookup BR classes (str v)
| from_classlookup fxy (Lookup (classes, (v, i))) =
from_lookup BR (string_of_int (i+1) :: classes) (str v)
- in case ls
+ in case lss
of [] => str "()"
- | [l] => from_classlookup fxy l
- | ls => (Pretty.list "(" ")" o map (from_classlookup NOBR)) ls
+ | [ls] => from_classlookup fxy ls
+ | lss => (Pretty.list "(" ")" o map (from_classlookup NOBR)) lss
end;
fun ml_from_tycoexpr fxy (tyco, tys) =
let
@@ -738,7 +738,9 @@
ml_from_label supclass
:: str "="
:: (str o resolv) supinst
- :: map (ml_from_sortlookup NOBR) lss
+ :: (if null lss andalso (not o null) arity
+ then [str "()"]
+ else map (ml_from_sortlookup NOBR) lss)
);
fun from_memdef (m, ((m', def), lss)) =
(ml_from_funs [(m', def)], (Pretty.block o Pretty.breaks) (
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Pure/Tools/codegen_theorems.ML Fri Mar 17 14:19:24 2006 +0100
@@ -0,0 +1,148 @@
+(* Title: Pure/Tools/CODEGEN_THEOREMS.ML
+ ID: $Id$
+ Author: Florian Haftmann, TU Muenchen
+
+Theorems used for code generation.
+*)
+
+signature CODEGEN_THEOREMS =
+sig
+ val add_notify: (string option -> theory -> theory) -> theory -> theory;
+ val add_preproc: (theory -> thm list -> thm list) -> theory -> theory;
+ val add_funn: thm -> theory -> theory;
+ val add_pred: thm -> theory -> theory;
+ val add_unfold: thm -> theory -> theory;
+ val preprocess: theory -> thm list -> thm list;
+ val preprocess_term: theory -> term -> term;
+end;
+
+structure CodegenTheorems: CODEGEN_THEOREMS =
+struct
+
+(** auxiliary **)
+
+fun dest_funn thm =
+ case try (fst o dest_Const o fst o strip_comb o fst o Logic.dest_equals o prop_of) thm
+ of SOME c => SOME (c, thm)
+ | NONE => NONE;
+
+fun dest_pred thm =
+ case try (fst o dest_Const o fst o strip_comb o snd o Logic.dest_implies o prop_of) thm
+ of SOME c => SOME (c, thm)
+ | NONE => NONE;
+
+
+(** theory data **)
+
+datatype procs = Procs of {
+ preprocs: (serial * (theory -> thm list -> thm list)) list,
+ notify: (serial * (string option -> theory -> theory)) list
+};
+
+fun mk_procs (preprocs, notify) = Procs { preprocs = preprocs, notify = notify };
+fun map_procs f (Procs { preprocs, notify }) = mk_procs (f (preprocs, notify));
+fun merge_procs _ (Procs { preprocs = preprocs1, notify = notify1 },
+ Procs { preprocs = preprocs2, notify = notify2 }) =
+ mk_procs (AList.merge (op =) (K true) (preprocs1, preprocs2),
+ AList.merge (op =) (K true) (notify1, notify2));
+
+datatype codethms = Codethms of {
+ funns: thm list Symtab.table,
+ preds: thm list Symtab.table,
+ unfolds: thm list
+};
+
+fun mk_codethms ((funns, preds), unfolds) =
+ Codethms { funns = funns, preds = preds, unfolds = unfolds };
+fun map_codethms f (Codethms { funns, preds, unfolds }) =
+ mk_codethms (f ((funns, preds), unfolds));
+fun merge_codethms _ (Codethms { funns = funns1, preds = preds1, unfolds = unfolds1 },
+ Codethms { funns = funns2, preds = preds2, unfolds = unfolds2 }) =
+ mk_codethms ((Symtab.join (K (uncurry (fold (insert eq_thm)))) (funns1, funns2),
+ Symtab.join (K (uncurry (fold (insert eq_thm)))) (preds1, preds2)),
+ fold (insert eq_thm) unfolds1 unfolds2);
+
+datatype T = T of {
+ procs: procs,
+ codethms: codethms
+};
+
+fun mk_T (procs, codethms) = T { procs = procs, codethms = codethms };
+fun map_T f (T { procs, codethms }) = mk_T (f (procs, codethms));
+fun merge_T pp (T { procs = procs1, codethms = codethms1 },
+ T { procs = procs2, codethms = codethms2 }) =
+ mk_T (merge_procs pp (procs1, procs2), merge_codethms pp (codethms1, codethms2));
+
+structure CodegenTheorems = TheoryDataFun
+(struct
+ val name = "Pure/CodegenTheorems";
+ type T = T;
+ val empty = mk_T (mk_procs ([], []),
+ mk_codethms ((Symtab.empty, Symtab.empty), []));
+ val copy = I;
+ val extend = I;
+ val merge = merge_T;
+ fun print _ _ = ();
+end);
+
+val _ = Context.add_setup CodegenTheorems.init;
+
+
+(** notifiers and preprocessors **)
+
+fun add_notify f =
+ CodegenTheorems.map (map_T (fn (procs, codethms) =>
+ (procs |> map_procs (fn (preprocs, notify) =>
+ (preprocs, (serial (), f) :: notify)), codethms)));
+
+fun notify_all c thy =
+ fold (fn f => f c) (((fn Procs { notify, ... } => map snd notify)
+ o (fn T { procs, ... } => procs) o CodegenTheorems.get) thy) thy;
+
+fun add_preproc f =
+ CodegenTheorems.map (map_T (fn (procs, codethms) =>
+ (procs |> map_procs (fn (preprocs, notify) =>
+ ((serial (), f) :: preprocs, notify)), codethms)))
+ #> notify_all NONE;
+
+fun preprocess thy =
+ fold (fn f => f thy) (((fn Procs { preprocs, ... } => map snd preprocs)
+ o (fn T { procs, ... } => procs) o CodegenTheorems.get) thy);
+
+fun preprocess_term thy t =
+ let
+ val x = Free (variant (add_term_names (t, [])) "x", fastype_of t);
+ (* fake definition *)
+ val eq = setmp quick_and_dirty true (SkipProof.make_thm thy)
+ (Logic.mk_equals (x, t));
+ fun err () = error "preprocess_term: bad preprocessor"
+ in case map prop_of (preprocess thy [eq]) of
+ [Const ("==", _) $ x' $ t'] => if x = x' then t' else err ()
+ | _ => err ()
+ end;
+
+fun add_unfold thm =
+ CodegenTheorems.map (map_T (fn (procs, codethms) =>
+ (procs, codethms |> map_codethms (fn (defs, unfolds) =>
+ (defs, thm :: unfolds)))))
+
+fun add_funn thm =
+ case dest_funn thm
+ of SOME (c, thm) =>
+ CodegenTheorems.map (map_T (fn (procs, codethms) =>
+ (procs, codethms |> map_codethms (fn ((funns, preds), unfolds) =>
+ ((funns |> Symtab.default (c, []) |> Symtab.map (fn thms => thms @ [thm]), preds), unfolds)))))
+ | NONE => error ("not a function equation: " ^ string_of_thm thm);
+
+fun add_pred thm =
+ case dest_pred thm
+ of SOME (c, thm) =>
+ CodegenTheorems.map (map_T (fn (procs, codethms) =>
+ (procs, codethms |> map_codethms (fn ((funns, preds), unfolds) =>
+ ((funns, preds |> Symtab.default (c, []) |> Symtab.map (fn thms => thms @ [thm])), unfolds)))))
+ | NONE => error ("not a predicate clause: " ^ string_of_thm thm);
+
+
+(** isar **)
+
+end; (* struct *)