src/Pure/Tools/codegen_thingol.ML
changeset 18385 d0071d93978e
parent 18380 9668764224a7
child 18441 7488d8ea61bc
--- a/src/Pure/Tools/codegen_thingol.ML	Mon Dec 12 15:36:46 2005 +0100
+++ b/src/Pure/Tools/codegen_thingol.ML	Mon Dec 12 15:37:05 2005 +0100
@@ -38,6 +38,7 @@
   val unfold_abs: iexpr -> (vname * itype) list * iexpr;
   val unfold_let: iexpr -> (ipat * iexpr) list * iexpr;
   val itype_of_iexpr: iexpr -> itype;
+  val itype_of_ipat: ipat -> itype;
   val ipat_of_iexpr: iexpr -> ipat;
   val eq_itype: itype * itype -> bool;
   val tvars_of_itypes: itype list -> string list;
@@ -105,9 +106,9 @@
   val Fun_wfrec: iexpr;
 
   val prims: string list;
-  val get_eqpred: module -> string -> string option;
-  val is_eqtype: module -> itype -> bool;
-  val build_eqpred: module -> string -> def;
+  val invoke_eq: ('a -> transact -> itype * transact)
+    -> (string * (def * (string * sort) list) -> transact -> transact)
+    -> 'a -> transact -> bool * transact;
   val extract_defs: iexpr -> string list;
   val eta_expand: (string -> int) -> module -> module;
   val eta_expand_poly: module -> module;
@@ -943,7 +944,6 @@
 val cons_pair = "Pair";
 val cons_nil = "Nil";
 val cons_cons = "Cons";
-val fun_primeq = "primeq"; (*defined for all primitive types*)
 val fun_eq = "eq"; (*to class eq*)
 val fun_not = "not";
 val fun_and = "and";
@@ -1009,63 +1009,61 @@
 end; (* local *)
 
 val prims = [class_eq, type_bool, type_integer, type_float, type_pair, type_list,
-  cons_true, cons_false, cons_pair, cons_nil, cons_cons, fun_primeq, fun_eq, fun_not, fun_and,
+  cons_true, cons_false, cons_pair, cons_nil, cons_cons, fun_eq, fun_not, fun_and,
   fun_or, fun_if, fun_fst, fun_snd, fun_add, fun_mult, fun_minus, fun_lt, fun_le, fun_wfrec];
 
 
 (** equality handling **)
 
-fun get_eqpred modl tyco =
-  if NameSpace.is_qualified tyco
-  then
-    case get_def modl tyco
-     of Datatype (_, _, insts) =>
-          get_first (fn inst =>
-            case get_def modl inst
-             of Classinst (cls, _, memdefs) =>
-                if cls = class_eq
-                then (SOME o snd o hd) memdefs
-                else NONE) insts
-  else SOME fun_primeq;
-
-fun is_eqtype modl (IType (tyco, tys)) =
-      forall (is_eqtype modl) tys
-      andalso (
-        NameSpace.is_qualified tyco
-        orelse
-          case get_def modl tyco
-           of Typesyn (vs, ty) => is_eqtype modl ty
-            | Datatype (_, _, insts) =>
-                exists (fn inst => case get_def modl inst of Classinst (cls, _, _) => cls = class_eq) insts
-      )
-  | is_eqtype modl (IFun _) =
-      false
-  | is_eqtype modl (IVarT (_, sort)) =
-      member (op =) sort class_eq;
-
-fun build_eqpred modl dtname =
+fun invoke_eq gen_ty gen_eq x (trns as (_ , modl)) =
   let
-    val (vs, cons, _) = case get_def modl dtname of Datatype info => info;
-    val sortctxt = map (rpair [class_eq] o fst) vs
-    val ty = IType (dtname, map IVarT sortctxt);
-    fun mk_eq (c, []) =
-          ([ICons ((c, []), ty), ICons ((c, []), ty)], Cons_true)
-      | mk_eq (c, tys) =
-          let
-            val vars1 = Term.invent_names [] "a" (length tys);
-            val vars2 = Term.invent_names vars1 "b" (length tys);
-            fun mk_eq_cons ty' (v1, v2) =
-              IConst (fun_eq, ty' `-> ty' `-> Type_bool) `$ IVarE (v1, ty) `$ IVarE (v2, ty)
-            fun mk_conj (e1, e2) =
-              Fun_and `$ e1 `$ e2;
-          in
-            ([ICons ((c, map2 (curry IVarP) vars1 tys), ty),
-              ICons ((c, map2 (curry IVarP) vars2 tys), ty)],
-              foldr1 mk_conj (map2 mk_eq_cons tys (vars1 ~~ vars2)))
-          end;
-    val eqs = map mk_eq cons @ [([IVarP ("_", ty), IVarP ("_", ty)], Cons_false)];
+    fun mk_eqpred dtname =
+      let
+        val (vs, cons, _) = case get_def modl dtname of Datatype info => info;
+        val arity = map (rpair [class_eq] o fst) vs
+        val ty = IType (dtname, map IVarT arity);
+        fun mk_eq (c, []) =
+              ([ICons ((c, []), ty), ICons ((c, []), ty)], Cons_true)
+          | mk_eq (c, tys) =
+              let
+                val vars1 = Term.invent_names [] "a" (length tys);
+                val vars2 = Term.invent_names vars1 "b" (length tys);
+                fun mk_eq_cons ty' (v1, v2) =
+                  IConst (fun_eq, ty' `-> ty' `-> Type_bool) `$ IVarE (v1, ty) `$ IVarE (v2, ty)
+                fun mk_conj (e1, e2) =
+                  Fun_and `$ e1 `$ e2;
+              in
+                ([ICons ((c, map2 (curry IVarP) vars1 tys), ty),
+                  ICons ((c, map2 (curry IVarP) vars2 tys), ty)],
+                  foldr1 mk_conj (map2 mk_eq_cons tys (vars1 ~~ vars2)))
+              end;
+        val eqs = map mk_eq cons @ [([IVarP ("_", ty), IVarP ("_", ty)], Cons_false)];
+      in
+        (Fun (eqs, (arity, ty `-> ty `-> Type_bool)), arity)
+      end;
+    fun invoke' (IType (tyco, tys)) trns =
+          trns
+          |> fold_map invoke' tys
+          |-> (fn is_eq =>
+                if forall I is_eq
+                  then if NameSpace.is_qualified tyco
+                  then
+                    gen_eq (tyco, mk_eqpred tyco)
+                    #> pair true
+                  else
+                    pair true
+                else
+                  pair false)
+      | invoke' (IFun _) trns =
+          trns 
+          |> pair false
+      | invoke' (IVarT (_, sort)) trns =
+          trns 
+          |> pair (member (op =) sort class_eq)
   in
-    Fun (eqs, (sortctxt, ty `-> ty `-> Type_bool))
+    trns
+    |> gen_ty x
+    |-> (fn ty => invoke' ty)
   end;
 
 
@@ -1209,7 +1207,7 @@
     fun introduce_dicts (Class (supcls, v, membrs, insts)) =
           let
             val _ = writeln "TRANSFORMING CLASS";
-            val _ = PolyML.print (Class (supcls, v, membrs, insts));
+            val _ = print (Class (supcls, v, membrs, insts));
             val varname_cls = Term.invent_names (tvars_of_itypes (map (snd o snd) membrs)) "a" 1 |> hd
           in
             Typesyn ([(varname_cls, supcls)], IDictT (mk_cls_typ_map v membrs (IVarT (varname_cls, []))))
@@ -1217,8 +1215,13 @@
       | introduce_dicts (Classinst (clsname, (tyco, arity), memdefs)) =
           let
             val _ = writeln "TRANSFORMING CLASSINST";
-            val _ = PolyML.print (Classinst (clsname, (tyco, arity), memdefs));
-            val Class (_, v, members, _) = get_def module clsname;
+            val _ = print (Classinst (clsname, (tyco, arity), memdefs));
+            val Class (_, v, members, _) =
+              if clsname = class_eq
+              then
+                Class ([], "a", [(fun_eq, ([], IVarT ("a", []) `-> IVarT ("a", []) `-> Type_bool))], [])
+              else
+                get_def module clsname;
             val ty = tyco `%% map IVarT arity;
             val inst_typ_map = mk_cls_typ_map v members ty;
             val memdefs_ty = map (fn (memname, memprim) =>
@@ -1242,7 +1245,7 @@
                 (map snd sortctxt);
             val _ = writeln "TRANSFORMING FUN (2)";
             val vname_alist = map2 (fn (vt, sort) => fn vs => (vt, vs ~~ sort))
-              sortctxt varnames_ctxt |> PolyML.print;
+              sortctxt varnames_ctxt |> print;
             val _ = writeln "TRANSFORMING FUN (3)";
             val ty' = map (op ** o (fn (vt, vss) => map (fn (_, cls) =>
               cls `%% [IVarT (vt, [])]) vss)) vname_alist