improvements in class and eq handling
authorhaftmann
Mon, 12 Dec 2005 15:37:05 +0100
changeset 18385 d0071d93978e
parent 18384 fa38cca42913
child 18386 e6240d62a7e6
improvements in class and eq handling
src/Pure/Tools/codegen_package.ML
src/Pure/Tools/codegen_serializer.ML
src/Pure/Tools/codegen_thingol.ML
--- a/src/Pure/Tools/codegen_package.ML	Mon Dec 12 15:36:46 2005 +0100
+++ b/src/Pure/Tools/codegen_package.ML	Mon Dec 12 15:37:05 2005 +0100
@@ -65,10 +65,6 @@
     -> defgen;
   val defgen_datacons: (theory -> string * string -> typ list option)
     -> defgen;
-  val defgen_datatype_eq: (theory -> string -> ((string * sort) list * string list) option)
-    -> defgen;
-  val defgen_datatype_eqinst: (theory -> string -> ((string * sort) list * string list) option)
-    -> defgen;
   val defgen_recfun: (theory -> string * typ -> (term list * term) list * typ)
     -> defgen;
 
@@ -81,7 +77,14 @@
 structure CodegenPackage : CODEGEN_PACKAGE =
 struct
 
-open CodegenThingol;
+open CodegenThingolOp;
+infix 8 `%%;
+infixr 6 `->;
+infixr 6 `-->;
+infix 4 `$;
+infix 4 `$$;
+infixr 5 `|->;
+infixr 5 `|-->;
 
 (* auxiliary *)
 
@@ -125,8 +128,8 @@
 val nsp_dtcon = "dtcon"; (*NOT OPERATIONAL YET*)
 val nsp_mem = "mem";
 val nsp_inst = "inst";
-val nsp_eq_class = "eq_class";
-val nsp_eq = "eq";
+val nsp_eq_inst = "eq_inst";
+val nsp_eq_pred = "eq";
 
 
 (* serializer *)
@@ -135,7 +138,7 @@
   let
     val name_root = "Generated";
     val nsp_conn = [
-      [nsp_class, nsp_type, nsp_eq_class], [nsp_const, nsp_dtcon, nsp_inst, nsp_mem, nsp_eq]
+      [nsp_class, nsp_type], [nsp_const, nsp_dtcon, nsp_inst, nsp_mem, nsp_eq_inst, nsp_eq_pred]
     ];
   in CodegenSerializer.ml_from_thingol nsp_conn name_root end;
 
@@ -143,7 +146,7 @@
   let
     val name_root = "Generated";
     val nsp_conn = [
-      [nsp_class, nsp_eq_class], [nsp_type], [nsp_const, nsp_mem, nsp_eq], [nsp_dtcon], [nsp_inst]
+      [nsp_class], [nsp_type], [nsp_const, nsp_mem, nsp_eq_pred], [nsp_dtcon], [nsp_inst, nsp_eq_inst]
     ];
   in CodegenSerializer.haskell_from_thingol nsp_conn name_root end;
 
@@ -272,6 +275,7 @@
           { serializer = serializer_ml : CodegenSerializer.serializer,
             primitives =
               CodegenSerializer.empty_prims
+              |> CodegenSerializer.add_prim ("Eq", ("type 'a Eq = {eq: 'a -> 'a -> bool};", []))
               |> CodegenSerializer.add_prim ("fst", ("fun fst (x, _) = x;", []))
               |> CodegenSerializer.add_prim ("snd", ("fun snd (_, y) = y;", []))
               |> CodegenSerializer.add_prim ("wfrec", ("fun wfrec f x = f (wfrec f) x;", [])),
@@ -571,10 +575,17 @@
 
 fun fix_nargs thy defs gen (imin, imax) (t, ts) trns =
   if length ts < imin then
-    trns
-    |> debug 10 (fn _ => "eta-expanding")
-    |> gen (strip_comb (Codegen.eta_expand t ts imin))
-    |-> succeed
+    let
+      val d = imin - length ts;
+      val vs = Term.invent_names (add_term_names (t, [])) "x" d;
+      val tys = Library.take (d, ((fst o strip_type o fastype_of) t));
+    in
+      trns
+      |> debug 10 (fn _ => "eta-expanding")
+      |> fold_map (invoke_cg_type thy defs) tys
+      ||>> gen (t, ts @ map2 (curry Free) vs tys)
+      |-> (fn (tys, e) => succeed ((vs ~~ tys) `|--> e))
+    end
   else if length ts > imax then
     trns
     |> debug 10 (fn _ => "splitting arguments (" ^ string_of_int imax ^ ", " ^ string_of_int (length ts) ^ ")")
@@ -587,16 +598,6 @@
     |> gen (t, ts)
     |-> succeed;
 
-local
-  open CodegenThingolOp;
-  infix 8 `%%;
-  infixr 6 `->;
-  infixr 6 `-->;
-  infix 4 `$;
-  infix 4 `$$;
-  infixr 5 `|->;
-  infixr 5 `|-->;
-in
 
 (* code generators *)
 
@@ -701,25 +702,42 @@
       trns
       |> fail ("not a negation: " ^ quote f);
 
-fun exprgen_term_eq thy defs (Const ("op =", Type ("fun", [ty, _]))) trns =
-  trns
+fun appgen_eq thy defs (f as ("op =", Type ("fun", [ty, _])), ts) trns =
+      let
+        fun mk_eq_pred_inst (dtco, (eqpred, arity)) trns =
+          let
+            val name_dtco = (the oo tname_of_idf) thy dtco;
+            val idf_eqinst = idf_of_name thy nsp_eq_inst name_dtco;
+            val idf_eqpred = idf_of_name thy nsp_eq_pred name_dtco;
+            fun mk_eq_pred _ trns =
+              trns
+              |> succeed (eqpred, [])
+            fun mk_eq_inst _ trns =
+              trns
+              |> gen_ensure_def [("eqpred", mk_eq_pred)] ("generating equality predicate for " ^ quote dtco) idf_eqpred
+              |> succeed (Classinst (class_eq, (dtco, arity), [(fun_eq, idf_eqpred)]), [])
+          in
+            trns
+            |> gen_ensure_def [("eqinst", mk_eq_inst)] ("generating equality instance for " ^ quote dtco) idf_eqinst
+          end;
+        fun mk_eq_expr (_, [t1, t2]) trns =
+          trns
+          |> invoke_eq (invoke_cg_type thy defs) mk_eq_pred_inst ty
+          |-> (fn false => error ("could not derive equality for " ^ Sign.string_of_typ thy ty)
+                | true => fn trns => trns
+          |> invoke_cg_expr thy defs t1
+          ||>> invoke_cg_expr thy defs t2
+          |-> (fn (e1, e2) => pair (Fun_eq `$ e1 `$ e2)))
+      in
+        trns
+        |> fix_nargs thy defs mk_eq_expr (2, 2) (Const f, ts)
+      end
+  | appgen_eq thy defs ((f, _), _) trns =
+      trns
+      |> fail ("not an equality: " ^ quote f);
 
-(*fun codegen_eq thy defs t trns =
- let
-   fun cg_eq (Const ("op =", _), [t, u]) =
-         trns
-         |> invoke_cg_type thy defs (type_of t)
-         |-> (fn ty => invoke_ensure_eqinst nsp_eq_class nsp_eq ty #> pair ty)
-         ||>> invoke_cg_expr thy defs t
-         ||>> invoke_cg_expr thy defs u
-         |-> (fn ((ty, t'), u') => succeed (
-               IConst (fun_eq, ty `-> ty `-> Type_bool)
-                 `$ t' `$ u'))
-     | cg_eq _ =
-         trns
-         |> fail ("no equality: " ^ Sign.string_of_term thy t)
-  in cg_eq (strip_comb t) end;*)
 
+(* invoke_eq: ((string * def) -> transact -> transact) -> itype -> transact -> bool * transact;  *)
 
 (* definition generators *)
 
@@ -946,16 +964,6 @@
           end
   end;
 
-local
-
-fun add_eqinst get_datacons thy modl dtname cnames =
-  if forall (is_eqtype modl)
-    (Library.flat (map (fn cname => get_datacons thy (cname, dtname)) cnames))
-  then append [idf_of_name thy nsp_eq_class dtname]
-  else I
-
-in
-
 fun defgen_datatype get_datatype get_datacons thy defs idf trns =
   case tname_of_idf thy idf
    of SOME dtco =>
@@ -983,9 +991,6 @@
     | NONE =>
         trns
         |> fail ("not a type constructor: " ^ quote idf)
-  end;
-
-end; (* local *)
 
 fun defgen_datacons get_datacons thy defs f trns =
   let
@@ -1016,43 +1021,6 @@
           |> fail ("not a constant: " ^ quote f)
   end;
 
-fun defgen_datatype_eq get_datatype thy defs f trns =
-  case name_of_idf thy nsp_eq f
-   of SOME dtname =>
-        (case get_datatype thy dtname
-         of SOME (_, cnames) =>
-              trns
-              |> debug 5 (fn _ => "trying defgen datatype_eq for " ^ quote dtname)
-              |> ensure_def_tyco thy defs (idf_of_tname thy dtname)
-              ||>> fold_map (ensure_def_const thy defs) (cnames
-                   |> map (idf_of_name thy nsp_const)
-                   |> map (fn "0" => "const.Zero" | c => c))
-              ||>> `(fn (_, modl) => build_eqpred modl dtname)
-              |-> (fn (_, eqpred) => succeed (eqpred, []))
-          | NONE =>
-              trns
-              |> fail ("no datatype found for " ^ quote dtname))
-    | NONE =>
-        trns
-        |> fail ("not an equality predicate: " ^ quote f)
-
-fun defgen_datatype_eqinst get_datatype thy defs f trns =
-  case name_of_idf thy nsp_eq_class f
-   of SOME dtname =>
-        (case get_datatype thy dtname
-         of SOME (vars, _) =>
-              trns
-              |> debug 5 (fn _ => "trying defgen datatype_eqinst for " ^ quote dtname)
-              |> ensure_def_const thy defs (idf_of_name thy nsp_eq dtname)
-              |-> (fn pred_eq => succeed (Classinst (class_eq, (dtname,
-                    map (fn (v, _) => (v, [class_eq])) vars), [(fun_eq, pred_eq)]), []))
-          | NONE =>
-              trns
-              |> fail ("no datatype found for " ^ quote dtname))
-    | NONE =>
-        trns
-        |> fail ("not an equality instance: " ^ quote f)
-
 fun defgen_recfun get_equations thy defs f trns =
   case cname_of_idf thy defs f
    of SOME (f, ty) =>
@@ -1455,7 +1423,7 @@
     add_codegen_type ("default", exprgen_type_default),
     add_codegen_expr ("default", exprgen_term_default),
     add_appgen ("default", appgen_default),
-(*     add_codegen_expr ("eq", codegen_eq),  *)
+    add_appgen ("eq", appgen_eq),
     add_appgen ("neg", appgen_neg),
     add_defgen ("clsdecl", defgen_clsdecl),
     add_defgen ("tyco_fallback", defgen_tyco_fallback),
--- a/src/Pure/Tools/codegen_serializer.ML	Mon Dec 12 15:36:46 2005 +0100
+++ b/src/Pure/Tools/codegen_serializer.ML	Mon Dec 12 15:37:05 2005 +0100
@@ -143,7 +143,7 @@
 
 local
 
-fun ml_from_defs tyco_syntax const_syntax resolv ds =
+fun ml_from_defs tyco_syntax const_syntax is_dicttype resolv ds =
   let
     fun chunk_defs ps =
       let
@@ -340,14 +340,6 @@
           Pretty.str "true"
       | ml_from_app br ("False", []) =
           Pretty.str "false"
-      | ml_from_app br ("primeq", [e1, e2]) =
-          brackify (eval_br br (INFX (4, L))) [
-            ml_from_expr (INFX (4, L)) e1,
-            Pretty.str "=",
-            ml_from_expr (INFX (4, X)) e2,
-            Pretty.str ":",
-            ml_from_type NOBR (itype_of_iexpr e2)
-          ]
       | ml_from_app br ("Pair", [e1, e2]) =
           Pretty.list "(" ")" [
             ml_from_expr NOBR e1,
@@ -417,6 +409,7 @@
                 in mk_app_p br (pr (map (ml_from_expr BR) es1) (ml_from_expr BR)) es2 end;
     fun ml_from_funs (ds as d::ds_tl) =
       let
+        val _ = debug 15 (fn _ => "(1) FUN") ();
         fun mk_definer [] = "val"
           | mk_definer _ = "fun"
         fun check_args (_, Fun ((pats, _)::_, _)) NONE =
@@ -427,19 +420,39 @@
               else error ("mixing simultaneous vals and funs not implemented")
           | check_args _ _ =
               error ("function definition block containing other definitions than functions")
+        val _ = debug 15 (fn _ => "(2) FUN") ();
         val definer = the (fold check_args ds NONE);
+        val _ = debug 15 (fn _ => "(3) FUN") ();
         fun mk_eq definer f ty (pats, expr) =
           let
+            val _ = debug 15 (fn _ => "(5) FUN") ();
+            fun mk_pat_arg p =
+              case itype_of_ipat p
+               of ty as IType (tyco, _) =>
+                    if is_dicttype tyco
+                    then Pretty.block [
+                        Pretty.str "(",
+                        ml_from_pat NOBR p,
+                        Pretty.str ":",
+                        ml_from_type NOBR ty,
+                        Pretty.str ")"
+                      ]
+                    else ml_from_pat BR p
+                | _ => ml_from_pat BR p;
+            val _ = debug 15 (fn _ => "(6) FUN") ();
             val lhs = [Pretty.str (definer ^ " " ^ f)]
                        @ (if null pats
                           then [Pretty.str ":", ml_from_type NOBR ty]
-                          else map (ml_from_pat BR) pats)
+                          else map mk_pat_arg pats)
+            val _ = debug 15 (fn _ => "(7) FUN") ();
             val rhs = [Pretty.str "=", ml_from_expr NOBR expr]
+            val _ = debug 15 (fn _ => "(8) FUN") ();
           in
             Pretty.block (separate (Pretty.brk 1) (lhs @ rhs))
           end
         fun mk_fun definer (f, Fun (eqs as eq::eq_tl, (_, ty))) =
           let
+            val _ = debug 15 (fn _ => "(4) FUN") ();
             val (pats_hd::pats_tl) = (fst o split_list) eqs;
             val shift = if null eq_tl then I else map (Pretty.block o single);
           in (Pretty.block o Pretty.fbreaks o shift) (
@@ -507,7 +520,7 @@
           NONE
       | ml_from_def (name, Classinst _) =
           error ("can't serialize instance declaration " ^ quote name ^ " to ML")
-  in (debug 10 (fn _ => "*** defs " ^ commas (map fst ds)) ();
+  in (debug 10 (fn _ => "*** defs " ^ commas (map (fn (n, d) => n ^ " = " ^ (Pretty.output o pretty_def) d) ds)) ();
   case ds
    of (_, Fun _)::_ => ml_from_funs ds
     | (_, Datatypecons _)::_ => ml_from_datatypes ds
@@ -547,9 +560,12 @@
         Pretty.str "",
         Pretty.str ("end; (* struct " ^ name ^ " *)")
       ]);
+    fun is_dicttype tyco =
+      case get_def module tyco
+       of Typesyn (_, IDictT _) => true
+        | _ => false;
     fun eta_expander "Pair" = 2
       | eta_expander "Cons" = 2
-      | eta_expander "primeq" = 2
       | eta_expander "and" = 2
       | eta_expander "or" = 2
       | eta_expander "if" = 3
@@ -569,7 +585,7 @@
                 const_syntax s
                 |> Option.map fst
                 |> the_default 0
-          else 0
+          else 0;
   in
     module
     |> debug 12 (Pretty.output o pretty_module)
@@ -584,8 +600,8 @@
     |> debug 3 (fn _ => "eliminating classes...")
     |> eliminate_classes
     |> debug 12 (Pretty.output o pretty_module)
-    |> debug 3 (fn _ => "generating...")
-    |> serialize (ml_from_defs tyco_syntax const_syntax) ml_from_module ml_validator nspgrp name_root
+    |> debug 3 (fn _ => "serializing...")
+    |> serialize (ml_from_defs tyco_syntax const_syntax is_dicttype) ml_from_module ml_validator nspgrp name_root
     |> (fn p => Pretty.chunks [setmp print_mode [] (Pretty.str o mk_prims) prims, p])
   end;
 
@@ -807,12 +823,6 @@
           Pretty.str "[]"
       | haskell_from_app br ("Cons", es) =
           mk_app_p br (Pretty.str "(:)") es
-      | haskell_from_app br ("primeq", [e1, e2]) =
-          brackify (eval_br br (INFX (4, L))) [
-            haskell_from_expr (INFX (4, L)) e1,
-            Pretty.str "==",
-            haskell_from_expr (INFX (4, X)) e2
-          ]
       | haskell_from_app br ("eq", [e1, e2]) =
           brackify (eval_br br (INFX (4, L))) [
             haskell_from_expr (INFX (4, L)) e1,
@@ -884,7 +894,9 @@
           else error ("empty statement during serialization: " ^ quote name)
       | haskell_from_def (name, Fun (eqs, (_, ty))) =
           let
+            val _ = print "(1) FUN";
             fun from_eq name (args, rhs) =
+              (print args; print rhs;
               Pretty.block [
                 Pretty.str (lower_first name),
                 Pretty.block (map (fn p => Pretty.block [Pretty.brk 1, haskell_from_pat BR p]) args),
@@ -892,7 +904,8 @@
                 Pretty.str ("="),
                 Pretty.brk 1,
                 haskell_from_expr NOBR rhs
-              ]
+              ])
+            val _ = print "(2) FUN";
           in
             Pretty.chunks [
               Pretty.block [
@@ -953,6 +966,17 @@
           end
       | haskell_from_def (name, Classmember _) =
           NONE
+      | haskell_from_def (_, Classinst ("Eq", (tyco, arity), [(_, eqpred)])) = 
+          Pretty.block [
+            Pretty.str "instance ",
+            haskell_from_sctxt arity,
+            Pretty.str "Eq",
+            Pretty.str " ",
+            haskell_from_type NOBR (IType (tyco, (map (IVarT o rpair [] o fst)) arity)),
+            Pretty.str " where",
+            Pretty.fbrk,
+            Pretty.str ("(==) = " ^ (lower_first o resolv) eqpred)
+          ] |> SOME
       | haskell_from_def (_, Classinst (clsname, (tyco, arity), instmems)) = 
           Pretty.block [
             Pretty.str "instance ",
@@ -967,7 +991,7 @@
             ) instmems)
           ] |> SOME
   in
-    case List.mapPartial (fn (name, def) => haskell_from_def (name, def)) defs
+    case List.mapPartial (fn (name, def) => (print ("serializing " ^ name); haskell_from_def (name, def))) defs
      of [] => NONE
       | l => (SOME o Pretty.block) l
   end;
@@ -1028,7 +1052,7 @@
     |> (if is_some select then (partof o the) select else I)
     |> debug 3 (fn _ => "eta-expanding...")
     |> eta_expand eta_expander
-    |> debug 3 (fn _ => "generating...")
+    |> debug 3 (fn _ => "serializing...")
     |> serialize (haskell_from_defs tyco_syntax const_syntax is_cons) haskell_from_module haskell_validator nspgrp name_root
     |> (fn p => Pretty.chunks [setmp print_mode [] (Pretty.str o mk_prims) prims, p])
   end;
--- a/src/Pure/Tools/codegen_thingol.ML	Mon Dec 12 15:36:46 2005 +0100
+++ b/src/Pure/Tools/codegen_thingol.ML	Mon Dec 12 15:37:05 2005 +0100
@@ -38,6 +38,7 @@
   val unfold_abs: iexpr -> (vname * itype) list * iexpr;
   val unfold_let: iexpr -> (ipat * iexpr) list * iexpr;
   val itype_of_iexpr: iexpr -> itype;
+  val itype_of_ipat: ipat -> itype;
   val ipat_of_iexpr: iexpr -> ipat;
   val eq_itype: itype * itype -> bool;
   val tvars_of_itypes: itype list -> string list;
@@ -105,9 +106,9 @@
   val Fun_wfrec: iexpr;
 
   val prims: string list;
-  val get_eqpred: module -> string -> string option;
-  val is_eqtype: module -> itype -> bool;
-  val build_eqpred: module -> string -> def;
+  val invoke_eq: ('a -> transact -> itype * transact)
+    -> (string * (def * (string * sort) list) -> transact -> transact)
+    -> 'a -> transact -> bool * transact;
   val extract_defs: iexpr -> string list;
   val eta_expand: (string -> int) -> module -> module;
   val eta_expand_poly: module -> module;
@@ -943,7 +944,6 @@
 val cons_pair = "Pair";
 val cons_nil = "Nil";
 val cons_cons = "Cons";
-val fun_primeq = "primeq"; (*defined for all primitive types*)
 val fun_eq = "eq"; (*to class eq*)
 val fun_not = "not";
 val fun_and = "and";
@@ -1009,63 +1009,61 @@
 end; (* local *)
 
 val prims = [class_eq, type_bool, type_integer, type_float, type_pair, type_list,
-  cons_true, cons_false, cons_pair, cons_nil, cons_cons, fun_primeq, fun_eq, fun_not, fun_and,
+  cons_true, cons_false, cons_pair, cons_nil, cons_cons, fun_eq, fun_not, fun_and,
   fun_or, fun_if, fun_fst, fun_snd, fun_add, fun_mult, fun_minus, fun_lt, fun_le, fun_wfrec];
 
 
 (** equality handling **)
 
-fun get_eqpred modl tyco =
-  if NameSpace.is_qualified tyco
-  then
-    case get_def modl tyco
-     of Datatype (_, _, insts) =>
-          get_first (fn inst =>
-            case get_def modl inst
-             of Classinst (cls, _, memdefs) =>
-                if cls = class_eq
-                then (SOME o snd o hd) memdefs
-                else NONE) insts
-  else SOME fun_primeq;
-
-fun is_eqtype modl (IType (tyco, tys)) =
-      forall (is_eqtype modl) tys
-      andalso (
-        NameSpace.is_qualified tyco
-        orelse
-          case get_def modl tyco
-           of Typesyn (vs, ty) => is_eqtype modl ty
-            | Datatype (_, _, insts) =>
-                exists (fn inst => case get_def modl inst of Classinst (cls, _, _) => cls = class_eq) insts
-      )
-  | is_eqtype modl (IFun _) =
-      false
-  | is_eqtype modl (IVarT (_, sort)) =
-      member (op =) sort class_eq;
-
-fun build_eqpred modl dtname =
+fun invoke_eq gen_ty gen_eq x (trns as (_ , modl)) =
   let
-    val (vs, cons, _) = case get_def modl dtname of Datatype info => info;
-    val sortctxt = map (rpair [class_eq] o fst) vs
-    val ty = IType (dtname, map IVarT sortctxt);
-    fun mk_eq (c, []) =
-          ([ICons ((c, []), ty), ICons ((c, []), ty)], Cons_true)
-      | mk_eq (c, tys) =
-          let
-            val vars1 = Term.invent_names [] "a" (length tys);
-            val vars2 = Term.invent_names vars1 "b" (length tys);
-            fun mk_eq_cons ty' (v1, v2) =
-              IConst (fun_eq, ty' `-> ty' `-> Type_bool) `$ IVarE (v1, ty) `$ IVarE (v2, ty)
-            fun mk_conj (e1, e2) =
-              Fun_and `$ e1 `$ e2;
-          in
-            ([ICons ((c, map2 (curry IVarP) vars1 tys), ty),
-              ICons ((c, map2 (curry IVarP) vars2 tys), ty)],
-              foldr1 mk_conj (map2 mk_eq_cons tys (vars1 ~~ vars2)))
-          end;
-    val eqs = map mk_eq cons @ [([IVarP ("_", ty), IVarP ("_", ty)], Cons_false)];
+    fun mk_eqpred dtname =
+      let
+        val (vs, cons, _) = case get_def modl dtname of Datatype info => info;
+        val arity = map (rpair [class_eq] o fst) vs
+        val ty = IType (dtname, map IVarT arity);
+        fun mk_eq (c, []) =
+              ([ICons ((c, []), ty), ICons ((c, []), ty)], Cons_true)
+          | mk_eq (c, tys) =
+              let
+                val vars1 = Term.invent_names [] "a" (length tys);
+                val vars2 = Term.invent_names vars1 "b" (length tys);
+                fun mk_eq_cons ty' (v1, v2) =
+                  IConst (fun_eq, ty' `-> ty' `-> Type_bool) `$ IVarE (v1, ty) `$ IVarE (v2, ty)
+                fun mk_conj (e1, e2) =
+                  Fun_and `$ e1 `$ e2;
+              in
+                ([ICons ((c, map2 (curry IVarP) vars1 tys), ty),
+                  ICons ((c, map2 (curry IVarP) vars2 tys), ty)],
+                  foldr1 mk_conj (map2 mk_eq_cons tys (vars1 ~~ vars2)))
+              end;
+        val eqs = map mk_eq cons @ [([IVarP ("_", ty), IVarP ("_", ty)], Cons_false)];
+      in
+        (Fun (eqs, (arity, ty `-> ty `-> Type_bool)), arity)
+      end;
+    fun invoke' (IType (tyco, tys)) trns =
+          trns
+          |> fold_map invoke' tys
+          |-> (fn is_eq =>
+                if forall I is_eq
+                  then if NameSpace.is_qualified tyco
+                  then
+                    gen_eq (tyco, mk_eqpred tyco)
+                    #> pair true
+                  else
+                    pair true
+                else
+                  pair false)
+      | invoke' (IFun _) trns =
+          trns 
+          |> pair false
+      | invoke' (IVarT (_, sort)) trns =
+          trns 
+          |> pair (member (op =) sort class_eq)
   in
-    Fun (eqs, (sortctxt, ty `-> ty `-> Type_bool))
+    trns
+    |> gen_ty x
+    |-> (fn ty => invoke' ty)
   end;
 
 
@@ -1209,7 +1207,7 @@
     fun introduce_dicts (Class (supcls, v, membrs, insts)) =
           let
             val _ = writeln "TRANSFORMING CLASS";
-            val _ = PolyML.print (Class (supcls, v, membrs, insts));
+            val _ = print (Class (supcls, v, membrs, insts));
             val varname_cls = Term.invent_names (tvars_of_itypes (map (snd o snd) membrs)) "a" 1 |> hd
           in
             Typesyn ([(varname_cls, supcls)], IDictT (mk_cls_typ_map v membrs (IVarT (varname_cls, []))))
@@ -1217,8 +1215,13 @@
       | introduce_dicts (Classinst (clsname, (tyco, arity), memdefs)) =
           let
             val _ = writeln "TRANSFORMING CLASSINST";
-            val _ = PolyML.print (Classinst (clsname, (tyco, arity), memdefs));
-            val Class (_, v, members, _) = get_def module clsname;
+            val _ = print (Classinst (clsname, (tyco, arity), memdefs));
+            val Class (_, v, members, _) =
+              if clsname = class_eq
+              then
+                Class ([], "a", [(fun_eq, ([], IVarT ("a", []) `-> IVarT ("a", []) `-> Type_bool))], [])
+              else
+                get_def module clsname;
             val ty = tyco `%% map IVarT arity;
             val inst_typ_map = mk_cls_typ_map v members ty;
             val memdefs_ty = map (fn (memname, memprim) =>
@@ -1242,7 +1245,7 @@
                 (map snd sortctxt);
             val _ = writeln "TRANSFORMING FUN (2)";
             val vname_alist = map2 (fn (vt, sort) => fn vs => (vt, vs ~~ sort))
-              sortctxt varnames_ctxt |> PolyML.print;
+              sortctxt varnames_ctxt |> print;
             val _ = writeln "TRANSFORMING FUN (3)";
             val ty' = map (op ** o (fn (vt, vss) => map (fn (_, cls) =>
               cls `%% [IVarT (vt, [])]) vss)) vname_alist