--- a/src/Pure/Tools/codegen_thingol.ML Mon Dec 12 15:36:46 2005 +0100
+++ b/src/Pure/Tools/codegen_thingol.ML Mon Dec 12 15:37:05 2005 +0100
@@ -38,6 +38,7 @@
val unfold_abs: iexpr -> (vname * itype) list * iexpr;
val unfold_let: iexpr -> (ipat * iexpr) list * iexpr;
val itype_of_iexpr: iexpr -> itype;
+ val itype_of_ipat: ipat -> itype;
val ipat_of_iexpr: iexpr -> ipat;
val eq_itype: itype * itype -> bool;
val tvars_of_itypes: itype list -> string list;
@@ -105,9 +106,9 @@
val Fun_wfrec: iexpr;
val prims: string list;
- val get_eqpred: module -> string -> string option;
- val is_eqtype: module -> itype -> bool;
- val build_eqpred: module -> string -> def;
+ val invoke_eq: ('a -> transact -> itype * transact)
+ -> (string * (def * (string * sort) list) -> transact -> transact)
+ -> 'a -> transact -> bool * transact;
val extract_defs: iexpr -> string list;
val eta_expand: (string -> int) -> module -> module;
val eta_expand_poly: module -> module;
@@ -943,7 +944,6 @@
val cons_pair = "Pair";
val cons_nil = "Nil";
val cons_cons = "Cons";
-val fun_primeq = "primeq"; (*defined for all primitive types*)
val fun_eq = "eq"; (*to class eq*)
val fun_not = "not";
val fun_and = "and";
@@ -1009,63 +1009,61 @@
end; (* local *)
val prims = [class_eq, type_bool, type_integer, type_float, type_pair, type_list,
- cons_true, cons_false, cons_pair, cons_nil, cons_cons, fun_primeq, fun_eq, fun_not, fun_and,
+ cons_true, cons_false, cons_pair, cons_nil, cons_cons, fun_eq, fun_not, fun_and,
fun_or, fun_if, fun_fst, fun_snd, fun_add, fun_mult, fun_minus, fun_lt, fun_le, fun_wfrec];
(** equality handling **)
-fun get_eqpred modl tyco =
- if NameSpace.is_qualified tyco
- then
- case get_def modl tyco
- of Datatype (_, _, insts) =>
- get_first (fn inst =>
- case get_def modl inst
- of Classinst (cls, _, memdefs) =>
- if cls = class_eq
- then (SOME o snd o hd) memdefs
- else NONE) insts
- else SOME fun_primeq;
-
-fun is_eqtype modl (IType (tyco, tys)) =
- forall (is_eqtype modl) tys
- andalso (
- NameSpace.is_qualified tyco
- orelse
- case get_def modl tyco
- of Typesyn (vs, ty) => is_eqtype modl ty
- | Datatype (_, _, insts) =>
- exists (fn inst => case get_def modl inst of Classinst (cls, _, _) => cls = class_eq) insts
- )
- | is_eqtype modl (IFun _) =
- false
- | is_eqtype modl (IVarT (_, sort)) =
- member (op =) sort class_eq;
-
-fun build_eqpred modl dtname =
+fun invoke_eq gen_ty gen_eq x (trns as (_ , modl)) =
let
- val (vs, cons, _) = case get_def modl dtname of Datatype info => info;
- val sortctxt = map (rpair [class_eq] o fst) vs
- val ty = IType (dtname, map IVarT sortctxt);
- fun mk_eq (c, []) =
- ([ICons ((c, []), ty), ICons ((c, []), ty)], Cons_true)
- | mk_eq (c, tys) =
- let
- val vars1 = Term.invent_names [] "a" (length tys);
- val vars2 = Term.invent_names vars1 "b" (length tys);
- fun mk_eq_cons ty' (v1, v2) =
- IConst (fun_eq, ty' `-> ty' `-> Type_bool) `$ IVarE (v1, ty) `$ IVarE (v2, ty)
- fun mk_conj (e1, e2) =
- Fun_and `$ e1 `$ e2;
- in
- ([ICons ((c, map2 (curry IVarP) vars1 tys), ty),
- ICons ((c, map2 (curry IVarP) vars2 tys), ty)],
- foldr1 mk_conj (map2 mk_eq_cons tys (vars1 ~~ vars2)))
- end;
- val eqs = map mk_eq cons @ [([IVarP ("_", ty), IVarP ("_", ty)], Cons_false)];
+ fun mk_eqpred dtname =
+ let
+ val (vs, cons, _) = case get_def modl dtname of Datatype info => info;
+ val arity = map (rpair [class_eq] o fst) vs
+ val ty = IType (dtname, map IVarT arity);
+ fun mk_eq (c, []) =
+ ([ICons ((c, []), ty), ICons ((c, []), ty)], Cons_true)
+ | mk_eq (c, tys) =
+ let
+ val vars1 = Term.invent_names [] "a" (length tys);
+ val vars2 = Term.invent_names vars1 "b" (length tys);
+ fun mk_eq_cons ty' (v1, v2) =
+ IConst (fun_eq, ty' `-> ty' `-> Type_bool) `$ IVarE (v1, ty) `$ IVarE (v2, ty)
+ fun mk_conj (e1, e2) =
+ Fun_and `$ e1 `$ e2;
+ in
+ ([ICons ((c, map2 (curry IVarP) vars1 tys), ty),
+ ICons ((c, map2 (curry IVarP) vars2 tys), ty)],
+ foldr1 mk_conj (map2 mk_eq_cons tys (vars1 ~~ vars2)))
+ end;
+ val eqs = map mk_eq cons @ [([IVarP ("_", ty), IVarP ("_", ty)], Cons_false)];
+ in
+ (Fun (eqs, (arity, ty `-> ty `-> Type_bool)), arity)
+ end;
+ fun invoke' (IType (tyco, tys)) trns =
+ trns
+ |> fold_map invoke' tys
+ |-> (fn is_eq =>
+ if forall I is_eq
+ then if NameSpace.is_qualified tyco
+ then
+ gen_eq (tyco, mk_eqpred tyco)
+ #> pair true
+ else
+ pair true
+ else
+ pair false)
+ | invoke' (IFun _) trns =
+ trns
+ |> pair false
+ | invoke' (IVarT (_, sort)) trns =
+ trns
+ |> pair (member (op =) sort class_eq)
in
- Fun (eqs, (sortctxt, ty `-> ty `-> Type_bool))
+ trns
+ |> gen_ty x
+ |-> (fn ty => invoke' ty)
end;
@@ -1209,7 +1207,7 @@
fun introduce_dicts (Class (supcls, v, membrs, insts)) =
let
val _ = writeln "TRANSFORMING CLASS";
- val _ = PolyML.print (Class (supcls, v, membrs, insts));
+ val _ = print (Class (supcls, v, membrs, insts));
val varname_cls = Term.invent_names (tvars_of_itypes (map (snd o snd) membrs)) "a" 1 |> hd
in
Typesyn ([(varname_cls, supcls)], IDictT (mk_cls_typ_map v membrs (IVarT (varname_cls, []))))
@@ -1217,8 +1215,13 @@
| introduce_dicts (Classinst (clsname, (tyco, arity), memdefs)) =
let
val _ = writeln "TRANSFORMING CLASSINST";
- val _ = PolyML.print (Classinst (clsname, (tyco, arity), memdefs));
- val Class (_, v, members, _) = get_def module clsname;
+ val _ = print (Classinst (clsname, (tyco, arity), memdefs));
+ val Class (_, v, members, _) =
+ if clsname = class_eq
+ then
+ Class ([], "a", [(fun_eq, ([], IVarT ("a", []) `-> IVarT ("a", []) `-> Type_bool))], [])
+ else
+ get_def module clsname;
val ty = tyco `%% map IVarT arity;
val inst_typ_map = mk_cls_typ_map v members ty;
val memdefs_ty = map (fn (memname, memprim) =>
@@ -1242,7 +1245,7 @@
(map snd sortctxt);
val _ = writeln "TRANSFORMING FUN (2)";
val vname_alist = map2 (fn (vt, sort) => fn vs => (vt, vs ~~ sort))
- sortctxt varnames_ctxt |> PolyML.print;
+ sortctxt varnames_ctxt |> print;
val _ = writeln "TRANSFORMING FUN (3)";
val ty' = map (op ** o (fn (vt, vss) => map (fn (_, cls) =>
cls `%% [IVarT (vt, [])]) vss)) vname_alist