--- a/src/Pure/Tools/codegen_package.ML Fri Feb 10 02:22:59 2006 +0100
+++ b/src/Pure/Tools/codegen_package.ML Fri Feb 10 09:09:07 2006 +0100
@@ -10,7 +10,9 @@
sig
type auxtab;
type eqextr = theory -> auxtab
- -> (string * typ) -> (thm list * typ) option;
+ -> string * typ -> (thm list * typ) option;
+ type eqextr_default = theory -> auxtab
+ -> string * typ -> ((thm list * term option) * typ) option;
type defgen;
type appgen = theory -> auxtab
-> (string * typ) * term list -> CodegenThingol.transact
@@ -19,6 +21,7 @@
val add_appconst: string * ((int * int) * appgen) -> theory -> theory;
val add_appconst_i: xstring * ((int * int) * appgen) -> theory -> theory;
val add_eqextr: string * eqextr -> theory -> theory;
+ val add_eqextr_default: string * eqextr_default -> theory -> theory;
val add_prim_class: xstring -> (string * string)
-> theory -> theory;
val add_prim_tyco: xstring -> (string * string)
@@ -43,12 +46,14 @@
-> appgen;
val appgen_number_of: (term -> term) -> (theory -> term -> IntInf.int) -> string -> string
-> appgen;
+ val eqextr_eq: (theory -> string -> thm list) -> term
+ -> eqextr_default;
val add_case_const: (theory -> string -> (string * int) list option) -> xstring
-> theory -> theory;
val add_case_const_i: (theory -> string -> (string * int) list option) -> string
-> theory -> theory;
- val print_codegen_generated: theory -> unit;
+ val print_code: theory -> unit;
val rename_inconsistent: theory -> theory;
val ensure_datatype_case_consts: (theory -> string list)
-> (theory -> string -> (string * int) list option)
@@ -91,12 +96,15 @@
(* code generator basics *)
+val alias_ref = ref (fn thy : theory => fn s : string => s);
+fun alias_get name = ! alias_ref name;
+
structure InstNameMangler = NameManglerFun (
type ctxt = theory;
type src = string * (class * string);
val ord = prod_ord string_ord (prod_ord string_ord string_ord);
fun mk thy ((thyname, (cls, tyco)), i) =
- NameSpace.base cls ^ "_" ^ NameSpace.base tyco ^ implode (replicate i "'")
+ (NameSpace.base o alias_get thy) cls ^ "_" ^ (NameSpace.base o alias_get thy) tyco ^ implode (replicate i "'")
|> NameSpace.append thyname;
fun is_valid _ _ = true;
fun maybe_unique _ _ = NONE;
@@ -110,7 +118,7 @@
fun mk thy ((c, (ty_decl, ty)), i) =
let
fun mangle (Type (tyco, tys)) =
- NameSpace.base tyco :: Library.flat (List.mapPartial mangle tys) |> SOME
+ (NameSpace.base o alias_get thy) tyco :: Library.flat (List.mapPartial mangle tys) |> SOME
| mangle _ =
NONE
in
@@ -158,7 +166,9 @@
* (InstNameMangler.T * ((typ * typ list) Symtab.table * ConstNameMangler.T)
* DatatypeconsNameMangler.T);
type eqextr = theory -> auxtab
- -> (string * typ) -> (thm list * typ) option;
+ -> string * typ -> (thm list * typ) option;
+type eqextr_default = theory -> auxtab
+ -> string * typ -> ((thm list * term option) * typ) option;
type defgen = theory -> auxtab -> gen_defgen;
type appgen = theory -> auxtab
-> (string * typ) * term list -> transact -> iexpr * transact;
@@ -191,7 +201,7 @@
type gens = {
appconst: ((int * int) * (appgen * stamp)) Symtab.table,
- eqextrs: (string * (eqextr * stamp)) list
+ eqextrs: (string * (eqextr_default * stamp)) list
};
fun map_gens f { appconst, eqextrs } =
@@ -310,11 +320,11 @@
in CodegenData.put { modl = modl, gens = gens,
target_data = target_data, logic_data = logic_data } thy end;
-fun print_codegen_generated thy =
+fun print_code thy =
let
val module = (#modl o CodegenData.get) thy;
in
- (writeln o Pretty.output o Pretty.chunks) [pretty_module module, pretty_deps module]
+ (Pretty.writeln o Pretty.chunks) [pretty_module module, pretty_deps module]
end;
@@ -329,7 +339,7 @@
(tab |> Symtab.update (src, dst),
tab_rev |> Symtab.update (dst, src))))));
-val alias_get = perhaps o Symtab.lookup o fst o #alias o #logic_data o CodegenData.get;
+val _ = alias_ref := (perhaps o Symtab.lookup o fst o #alias o #logic_data o CodegenData.get);
val alias_rev = perhaps o Symtab.lookup o snd o #alias o #logic_data o CodegenData.get;
fun add_nsp shallow name =
@@ -347,7 +357,7 @@
val (modl, shallow) = split_last idf'';
in
if nsp = shallow
- then (SOME o NameSpace.pack) (modl @ [idf_base])
+ then (SOME o NameSpace.pack) (modl @ [idf_base])
else NONE
end;
@@ -427,11 +437,22 @@
(fn (appconst, eqextrs) =>
(appconst, eqextrs
|> Output.update_warn (op =) ("overwriting existing equation extractor " ^ name)
+ (name, ((Option.map o apfst o rpair) NONE ooo eqx , stamp ())))),
+ target_data, logic_data));
+
+fun add_eqextr_default (name, eqx) =
+ map_codegen_data
+ (fn (modl, gens, target_data, logic_data) =>
+ (modl,
+ gens |> map_gens
+ (fn (appconst, eqextrs) =>
+ (appconst, eqextrs
+ |> Output.update_warn (op =) ("overwriting existing equation extractor " ^ name)
(name, (eqx, stamp ())))),
target_data, logic_data));
fun get_eqextrs thy tabs =
- (map (fn (_, (eqx, _)) => eqx thy tabs) o #eqextrs o #gens o CodegenData.get) thy;
+ (map (fn (name, (eqx, _)) => (name, eqx thy tabs)) o #eqextrs o #gens o CodegenData.get) thy;
fun set_get_all_datatype_cons f =
map_codegen_data
@@ -465,7 +486,7 @@
|> Symtab.update (
#ml CodegenSerializer.serializers
|> apsnd (fn seri => seri
- (nsp_dtcon, nsp_class, fn tyco' => tyco' = idf_of_name thy nsp_tyco tyco )
+ (nsp_dtcon, nsp_class, fn tyco' => tyco' = idf_of_name thy nsp_tyco tyco )
[[nsp_module], [nsp_class, nsp_tyco], [nsp_const, nsp_overl, nsp_dtcon, nsp_mem, nsp_inst]]
)
)
@@ -474,27 +495,28 @@
(* sophisticated devarification *)
-fun assert f msg x =
- if f x then x
- else error msg;
-
-val _ : ('a -> bool) -> string -> 'a -> 'a = assert;
+fun eq_typ thy (ty1, ty2) =
+ Sign.typ_instance thy (ty1, ty2)
+ andalso Sign.typ_instance thy (ty2, ty1);
fun devarify_typs tys =
let
- fun add_rename (var as ((v, _), sort)) used =
+ fun add_rename (vi as (v, _), sorts) used =
let
val v' = "'" ^ variant used (unprefix "'" v)
- in (((var, TFree (v', sort)), (v', TVar var)), v' :: used) end;
+ in (map (fn sort => (((vi, sort), TFree (v', sort)), (v', TVar (vi, sort)))) sorts, v' :: used) end;
fun typ_names (Type (tyco, tys)) (vars, names) =
(vars, names |> insert (op =) (NameSpace.base tyco))
|> fold typ_names tys
| typ_names (TFree (v, _)) (vars, names) =
(vars, names |> insert (op =) (unprefix "'" v))
| typ_names (TVar (vi, sort)) (vars, names) =
- (vars |> AList.update (op =) (vi, sort), names);
+ (vars
+ |> AList.default (op =) (vi, [])
+ |> AList.map_entry (op =) vi (cons sort),
+ names);
val (vars, used) = fold typ_names tys ([], []);
- val (renames, reverse) = fold_map add_rename vars used |> fst |> split_list;
+ val (renames, reverse) = fold_map add_rename vars used |> fst |> Library.flat |> split_list;
in
(reverse, map (Term.instantiateT renames) tys)
end;
@@ -513,16 +535,19 @@
fun devarify_terms ts =
let
- fun add_rename (var as ((v, _), ty)) used =
+ fun add_rename (vi as (v, _), tys) used =
let
val v' = variant used v
- in (((var, Free (v', ty)), (v', Var var)), v' :: used) end;
+ in (map (fn ty => (((vi, ty), Free (v', ty)), (v', Var (vi, ty)))) tys, v' :: used) end;
fun term_names (Const (c, _)) (vars, names) =
(vars, names |> insert (op =) (NameSpace.base c))
| term_names (Free (v, _)) (vars, names) =
(vars, names |> insert (op =) v)
- | term_names (Var (v, sort)) (vars, names) =
- (vars |> AList.update (op =) (v, sort), names)
+ | term_names (Var (vi, ty)) (vars, names) =
+ (vars
+ |> AList.default (op =) (vi, [])
+ |> AList.map_entry (op =) vi (cons ty),
+ names)
| term_names (Bound _) vars_names =
vars_names
| term_names (Abs (v, _, _)) (vars, names) =
@@ -530,7 +555,7 @@
| term_names (t1 $ t2) vars_names =
vars_names |> term_names t1 |> term_names t2
val (vars, used) = fold term_names ts ([], []);
- val (renames, reverse) = fold_map add_rename vars used |> fst |> split_list;
+ val (renames, reverse) = fold_map add_rename vars used |> fst |> Library.flat |> split_list;
in
(reverse, map (Term.instantiate ([], renames)) ts)
end;
@@ -576,7 +601,7 @@
fun defgen_datatype thy (tabs as (_, (_, _, dtcontab))) dtco trns =
case name_of_idf thy nsp_tyco dtco
of SOME dtco =>
- (case get_datatype thy dtco
+ (case get_datatype thy dtco
of SOME (vars, cos) =>
let
val cos' = map (fn (co, tys) => (DatatypeconsNameMangler.get thy dtcontab (co, dtco) |>
@@ -635,8 +660,8 @@
|> fold_map (ensure_def_class thy tabs) clss
|-> (fn clss => pair (Lookup (clss, (v |> unprefix "'", i))))
and mk_fun thy tabs (c, ty) trns =
- case get_first (fn eqx => eqx (c, ty)) (get_eqextrs thy tabs)
- of SOME (eq_thms, ty) =>
+ case get_first (fn (name, eqx) => (eqx (c, ty))) (get_eqextrs thy tabs)
+ of SOME ((eq_thms, default), ty) =>
let
val sortctxt = ClassPackage.extract_sortctxt thy ty;
fun dest_eqthm eq_thm =
@@ -649,12 +674,22 @@
^ ", actually defining " ^ quote c')
| _ => error ("illegal function equation for " ^ quote c)
end;
+ fun mk_default t =
+ let
+ val (tys, ty') = strip_type ty;
+ val vs = Term.invent_names (add_term_names (t, [])) "x" (length tys);
+ in
+ if (not o eq_typ thy) (type_of t, ty')
+ then error ("inconsistent type for default rule")
+ else (map2 (curry Free) vs tys, t)
+ end;
in
trns
|> (codegen_eqs thy tabs o map dest_eqthm) eq_thms
+ ||>> (codegen_eqs thy tabs o the_list o Option.map mk_default) default
||>> codegen_type thy tabs [ty]
||>> fold_map (exprgen_tyvar_sort thy tabs) sortctxt
- |-> (fn ((eqs, [ty]), sortctxt) => (pair o SOME) (eqs, (sortctxt, ty)))
+ |-> (fn (((eqs, eq_default), [ty]), sortctxt) => (pair o SOME) (eqs @ eq_default, (sortctxt, ty)))
end
| NONE => (NONE, trns)
and ensure_def_inst thy (tabs as (_, (insttab, _, _))) (cls, tyco) trns =
@@ -825,49 +860,14 @@
trns
|> appgen_default thy tabs ((f, ty), ts);
-(* fun ensure_def_eq thy tabs (dtco, (eqpred, arity)) trns =
- let
- val name_dtco = (the ooo name_of_idf) thy nsp_tyco dtco;
- val idf_eqinst = idf_of_name thy nsp_eq_inst name_dtco;
- val idf_eqpred = idf_of_name thy nsp_eq_pred name_dtco;
- val inst_sortlookup = map (fn (v, _) => [ClassPackage.Lookup ([], (v, 0))]) arity;
- fun mk_eq_pred _ trns =
- trns
- |> succeed (eqpred)
- fun mk_eq_inst _ trns =
- trns
- |> gen_ensure_def [("eqpred", mk_eq_pred)] ("generating equality predicate for " ^ quote dtco) idf_eqpred
- |> succeed (Classinst ((class_eq, (dtco, arity)), ([], [(fun_eq, (idf_eqpred, inst_sortlookup))])));
- in
- trns
- |> gen_ensure_def [("eqinst", mk_eq_inst)] ("generating equality instance for " ^ quote dtco) idf_eqinst
- end; *)
-
-(* expression generators *)
-
-(* fun appgen_eq thy tabs (("op =", Type ("fun", [ty, _])), [t1, t2]) trns =
- trns
- |> invoke_eq (exprgen_type thy tabs) (ensure_def_eq thy tabs) ty
- |-> (fn false => error ("could not derive equality for " ^ Sign.string_of_typ thy ty)
- | true => fn trns => trns
- |> exprgen_term thy tabs t1
- ||>> exprgen_term thy tabs t2
- |-> (fn (e1, e2) => pair (Fun_eq `$ e1 `$ e2))); *)
-
(* function extractors *)
fun eqextr_defs thy ((deftab, _), _) (c, ty) =
- let
- fun eq_typ (ty1, ty2) =
- Sign.typ_instance thy (ty1, ty2)
- andalso Sign.typ_instance thy (ty2, ty1)
- in
- Option.mapPartial (get_first (fn (ty', thm) => if eq_typ (ty, ty')
- then SOME ([thm], ty')
- else NONE
- )) (Symtab.lookup deftab c)
- end;
+ Option.mapPartial (get_first (fn (ty', thm) => if eq_typ thy (ty, ty')
+ then SOME ([thm], ty')
+ else NONE
+ )) (Symtab.lookup deftab c);
(* parametrized generators, for instantiation in HOL *)
@@ -916,6 +916,17 @@
|> exprgen_term thy tabs (mk_int_to_nat bin)
else error ("invalid type constructor for numeral: " ^ quote tyco);
+fun eqextr_eq f fals thy tabs ("op =", ty) =
+ (case ty
+ of Type ("fun", [Type (dtco, _), _]) =>
+ (case f thy dtco
+ of [] => NONE
+ | [eq] => SOME ((Codegen.preprocess thy [eq], NONE), ty)
+ | eqs => SOME ((Codegen.preprocess thy eqs, SOME fals), ty))
+ | _ => NONE)
+ | eqextr_eq f fals thy tabs _ =
+ NONE;
+
fun appgen_datatype_case cos thy tabs ((_, ty), ts) trns =
let
val (ts', t) = split_last ts;
@@ -972,7 +983,7 @@
in if forall is_Var args then SOME ((c, ty), tm) else NONE
end handle TERM _ => NONE;
fun prep_def def = (case Codegen.preprocess thy [def] of
- [def'] => def' | _ => error "mk_auxtab: bad preprocessor");
+ [def'] => def' | _ => error "mk_tabs: bad preprocessor");
fun add_def (name, _) =
case (dest o prep_def o Thm.get_axiom thy) name
of SOME ((c, ty), tm) =>
@@ -990,6 +1001,22 @@
(fn (tyco, thyname) => InstNameMangler.declare thy (thyname, (cls, tyco))) clsinsts)
(ClassPackage.get_classtab thy)
|-> (fn _ => I);
+ fun add_monoeq thy (overltab1, overltab2) =
+ let
+ val c = "op =";
+ val ty = Sign.the_const_type thy c;
+ fun inst dtco =
+ map_atyps (fn _ => Type (dtco,
+ (map (fn (v, sort) => TVar ((v, 0), sort)) o fst o the o get_datatype thy) dtco)) ty
+ val dtcos = fold (insert (op =) o snd) (get_all_datatype_cons thy) [];
+ val tys = map inst dtcos;
+ in
+ (overltab1
+ |> Symtab.update_new (c, (ty, tys)),
+ overltab2
+ |> fold (fn ty' => ConstNameMangler.declare thy
+ (idf_of_name thy nsp_overl c, (ty, ty')) #> snd) tys)
+ end;
fun mk_overltabs thy deftab =
(Symtab.empty, ConstNameMangler.empty)
|> Symtab.fold
@@ -998,19 +1025,20 @@
if (is_none o ClassPackage.lookup_const_class thy) c
then (fn (overltab1, overltab2) => (
overltab1
- |> Symtab.update_new (c, (Sign.the_const_constraint thy c, map fst tytab)),
+ |> Symtab.update_new (c, (Sign.the_const_type thy c, map fst tytab)),
overltab2
|> fold (fn (ty, _) => ConstNameMangler.declare thy
- (idf_of_name thy nsp_overl c, (Sign.the_const_constraint thy c, ty)) #> snd) tytab))
+ (idf_of_name thy nsp_overl c, (Sign.the_const_type thy c, ty)) #> snd) tytab))
else I
- ) deftab;
+ ) deftab
+ |> add_monoeq thy;
fun mk_dtcontab thy =
DatatypeconsNameMangler.empty
|> fold_map
(fn (_, co_dtco) => DatatypeconsNameMangler.declare_multi thy co_dtco)
(fold (fn (co, dtco) =>
let
- val key = ((NameSpace.drop_base o NameSpace.drop_base) co, NameSpace.base co)
+ val key = ((NameSpace.drop_base o NameSpace.drop_base) co, NameSpace.base co);
in AList.default (op =) (key, []) #> AList.map_entry (op =) key (cons (co, dtco)) end
) (get_all_datatype_cons thy) [])
|-> (fn _ => I);
@@ -1030,7 +1058,7 @@
fun get_serializer target =
case Symtab.lookup (!serializers) target
of SOME seri => seri
- | NONE => error ("unknown code target language: " ^ quote target);
+ | NONE => Scan.fail_with (fn _ => "unknown code target language: " ^ quote target) ();
fun map_module f =
map_codegen_data (fn (modl, gens, target_data, logic_data) =>
@@ -1094,7 +1122,7 @@
then ()
else error ("no such constant: " ^ quote c);
val ty = case raw_ty
- of NONE => Sign.the_const_constraint thy c
+ of NONE => Sign.the_const_type thy c
| SOME raw_ty => read_typ thy raw_ty;
in (c, ty) end;
@@ -1127,7 +1155,7 @@
(fn thy => fn tabs => idf_of_const thy tabs o read_const thy)
CodegenSerializer.parse_targetdef;
-val ensure_prim = (map_module oo CodegenThingol.ensure_prim);
+val ensure_prim = map_module oo CodegenThingol.ensure_prim;
(* syntax *)
@@ -1235,7 +1263,7 @@
(** toplevel interface **)
local
-
+
fun generate_code (SOME raw_consts) thy =
let
val consts = map (read_const thy) raw_consts;
@@ -1363,7 +1391,7 @@
P.name -- parse_syntax_const raw_const
))
)
- >> (Toplevel.theory oo fold o fold)
+ >> (Toplevel.theory oo fold o fold)
(fn (target, modifier) => modifier target)
);
@@ -1376,7 +1404,6 @@
val _ = Context.add_setup (
add_eqextr ("defs", eqextr_defs)
-(* add_appconst_i ("op =", ((2, 2), appgen_eq)) *)
);
end; (* local *)