heutistics for type annotations in Haskell
authorhaftmann
Thu, 13 Dec 2007 07:09:09 +0100
changeset 25621 97ebdbdb0299
parent 25620 a6cb8f60cff7
child 25622 6067d838041a
heutistics for type annotations in Haskell
src/Tools/code/code_target.ML
src/Tools/code/code_thingol.ML
--- a/src/Tools/code/code_target.ML	Thu Dec 13 07:09:08 2007 +0100
+++ b/src/Tools/code/code_target.ML	Thu Dec 13 07:09:09 2007 +0100
@@ -1130,7 +1130,7 @@
 in
 
 fun pr_haskell class_syntax tyco_syntax const_syntax labelled_name
-    init_syms deresolv_here deresolv is_cons deriving_show def =
+    init_syms deresolv_here deresolv is_cons contr_classparam_typs deriving_show def =
   let
     fun class_name class = case class_syntax class
      of NONE => deresolv class
@@ -1139,17 +1139,17 @@
      of NONE => deresolv_here classparam
       | SOME (_, classparam_syntax) => case classparam_syntax classparam
          of NONE => (snd o dest_name) classparam
-          | SOME classparam => classparam
-    fun pr_typparms tyvars vs =
-      case maps (fn (v, sort) => map (pair v) sort) vs
-       of [] => str ""
-        | xs => Pretty.block [
-            Pretty.enum "," "(" ")" (
-              map (fn (v, class) => str
-                (class_name class ^ " " ^ CodeName.lookup_var tyvars v)) xs
-            ),
-            str " => "
-          ];
+          | SOME classparam => classparam;
+    fun pr_typcontext tyvars vs = case maps (fn (v, sort) => map (pair v) sort) vs
+     of [] => []
+      | classbinds => Pretty.enum "," "(" ")" (
+          map (fn (v, class) =>
+            str (class_name class ^ " " ^ CodeName.lookup_var tyvars v)) classbinds)
+          @@ str " => ";
+    fun pr_typforall tyvars vs = case map fst vs
+     of [] => []
+      | vnames => str "forall " :: Pretty.breaks
+          (map (str o CodeName.lookup_var tyvars) vnames) @ str "." @@ Pretty.brk 1;
     fun pr_tycoexpr tyvars fxy (tyco, tys) =
       brackify fxy (str tyco :: map (pr_typ tyvars BR) tys)
     and pr_typ tyvars fxy (tycoexpr as tyco `%% tys) =
@@ -1164,66 +1164,78 @@
                 else pr (pr_typ tyvars) fxy tys)
       | pr_typ tyvars fxy (ITyVar v) =
           (str o CodeName.lookup_var tyvars) v;
-    fun pr_typscheme_expr tyvars (vs, tycoexpr) =
-      Pretty.block (pr_typparms tyvars vs @@ pr_tycoexpr tyvars NOBR tycoexpr);
+    fun pr_typdecl tyvars (vs, tycoexpr) =
+      Pretty.block (pr_typcontext tyvars vs @| pr_tycoexpr tyvars NOBR tycoexpr);
     fun pr_typscheme tyvars (vs, ty) =
-      Pretty.block (pr_typparms tyvars vs @@ pr_typ tyvars NOBR ty);
-    fun pr_term lhs vars fxy (IConst c) =
-          pr_app lhs vars fxy (c, [])
-      | pr_term lhs vars fxy (t as (t1 `$ t2)) =
+      Pretty.block (pr_typforall tyvars vs @ pr_typcontext tyvars vs @| pr_typ tyvars NOBR ty);
+    fun pr_term tyvars lhs vars fxy (IConst c) =
+          pr_app tyvars lhs vars fxy (c, [])
+      | pr_term tyvars lhs vars fxy (t as (t1 `$ t2)) =
           (case CodeThingol.unfold_const_app t
-           of SOME app => pr_app lhs vars fxy app
+           of SOME app => pr_app tyvars lhs vars fxy app
             | _ =>
                 brackify fxy [
-                  pr_term lhs vars NOBR t1,
-                  pr_term lhs vars BR t2
+                  pr_term tyvars lhs vars NOBR t1,
+                  pr_term tyvars lhs vars BR t2
                 ])
-      | pr_term lhs vars fxy (IVar v) =
+      | pr_term tyvars lhs vars fxy (IVar v) =
           (str o CodeName.lookup_var vars) v
-      | pr_term lhs vars fxy (t as _ `|-> _) =
+      | pr_term tyvars lhs vars fxy (t as _ `|-> _) =
           let
             val (binds, t') = CodeThingol.unfold_abs t;
-            fun pr ((v, pat), ty) = pr_bind BR ((SOME v, pat), ty);
+            fun pr ((v, pat), ty) = pr_bind tyvars BR ((SOME v, pat), ty);
             val (ps, vars') = fold_map pr binds vars;
-          in brackets (str "\\" :: ps @ str "->" @@ pr_term lhs vars' NOBR t') end
-      | pr_term lhs vars fxy (ICase (cases as (_, t0))) =
+          in brackets (str "\\" :: ps @ str "->" @@ pr_term tyvars lhs vars' NOBR t') end
+      | pr_term tyvars lhs vars fxy (ICase (cases as (_, t0))) =
           (case CodeThingol.unfold_const_app t0
            of SOME (c_ts as ((c, _), _)) => if is_none (const_syntax c)
-                then pr_case vars fxy cases
-                else pr_app lhs vars fxy c_ts
-            | NONE => pr_case vars fxy cases)
-    and pr_app' lhs vars ((c, _), ts) =
-      (str o deresolv) c :: map (pr_term lhs vars BR) ts
-    and pr_app lhs vars = gen_pr_app pr_app' pr_term const_syntax
+                then pr_case tyvars vars fxy cases
+                else pr_app tyvars lhs vars fxy c_ts
+            | NONE => pr_case tyvars vars fxy cases)
+    and pr_app' tyvars lhs vars ((c, (_, tys)), ts) = case contr_classparam_typs c
+     of [] => (str o deresolv) c :: map (pr_term tyvars lhs vars BR) ts
+      | fingerprint => let
+          val ts_fingerprint = ts ~~ curry Library.take (length ts) fingerprint;
+          val needs_annotation = forall (fn (_, NONE) => true | (t, SOME _) =>
+            (not o CodeThingol.locally_monomorphic) t) ts_fingerprint;
+          fun pr_term_anno (t, NONE) _ = pr_term tyvars lhs vars BR t
+            | pr_term_anno (t, SOME _) ty =
+                brackets [pr_term tyvars lhs vars NOBR t, str "::", pr_typ tyvars NOBR ty];
+        in
+          if needs_annotation then
+            (str o deresolv) c :: map2 pr_term_anno ts_fingerprint (curry Library.take (length ts) tys)
+          else (str o deresolv) c :: map (pr_term tyvars lhs vars BR) ts
+        end
+    and pr_app tyvars lhs vars  = gen_pr_app (pr_app' tyvars) (pr_term tyvars) const_syntax
       labelled_name is_cons lhs vars
-    and pr_bind fxy = pr_bind_haskell pr_term fxy
-    and pr_case vars fxy (cases as ((_, [_]), _)) =
+    and pr_bind tyvars = pr_bind_haskell (pr_term tyvars)
+    and pr_case tyvars vars fxy (cases as ((_, [_]), _)) =
           let
             val (binds, t) = CodeThingol.unfold_let (ICase cases);
             fun pr ((pat, ty), t) vars =
               vars
-              |> pr_bind BR ((NONE, SOME pat), ty)
-              |>> (fn p => semicolon [p, str "=", pr_term false vars NOBR t])
+              |> pr_bind tyvars BR ((NONE, SOME pat), ty)
+              |>> (fn p => semicolon [p, str "=", pr_term tyvars false vars NOBR t])
             val (ps, vars') = fold_map pr binds vars;
           in
             Pretty.block_enclose (
               str "let {",
-              concat [str "}", str "in", pr_term false vars' NOBR t]
+              concat [str "}", str "in", pr_term tyvars false vars' NOBR t]
             ) ps
           end
-      | pr_case vars fxy (((td, ty), bs as _ :: _), _) =
+      | pr_case tyvars vars fxy (((td, ty), bs as _ :: _), _) =
           let
             fun pr (pat, t) =
               let
-                val (p, vars') = pr_bind NOBR ((NONE, SOME pat), ty) vars;
-              in semicolon [p, str "->", pr_term false vars' NOBR t] end;
+                val (p, vars') = pr_bind tyvars NOBR ((NONE, SOME pat), ty) vars;
+              in semicolon [p, str "->", pr_term tyvars false vars' NOBR t] end;
           in
             Pretty.block_enclose (
-              concat [str "(case", pr_term false vars NOBR td, str "of", str "{"],
+              concat [str "(case", pr_term tyvars false vars NOBR td, str "of", str "{"],
               str "})"
             ) (map pr bs)
           end
-      | pr_case vars fxy ((_, []), _) = str "error \"empty case\"";
+      | pr_case tyvars vars fxy ((_, []), _) = str "error \"empty case\"";
     fun pr_def (name, CodeThingol.Fun ((vs, ty), [])) =
           let
             val tyvars = CodeName.intro_vars (map fst vs) init_syms;
@@ -1262,9 +1274,9 @@
               in
                 semicolon (
                   (str o deresolv_here) name
-                  :: map (pr_term true vars BR) ts
+                  :: map (pr_term tyvars true vars BR) ts
                   @ str "="
-                  @@ pr_term false vars NOBR t
+                  @@ pr_term tyvars false vars NOBR t
                 )
               end;
           in
@@ -1284,7 +1296,7 @@
           in
             semicolon [
               str "data",
-              pr_typscheme_expr tyvars (vs, (deresolv_here name, map (ITyVar o fst) vs))
+              pr_typdecl tyvars (vs, (deresolv_here name, map (ITyVar o fst) vs))
             ]
           end
       | pr_def (name, CodeThingol.Datatype (vs, [(co, [ty])])) =
@@ -1293,7 +1305,7 @@
           in
             semicolon (
               str "newtype"
-              :: pr_typscheme_expr tyvars (vs, (deresolv_here name, map (ITyVar o fst) vs))
+              :: pr_typdecl tyvars (vs, (deresolv_here name, map (ITyVar o fst) vs))
               :: str "="
               :: (str o deresolv_here) co
               :: pr_typ tyvars BR ty
@@ -1311,7 +1323,7 @@
           in
             semicolon (
               str "data"
-              :: pr_typscheme_expr tyvars (vs, (deresolv_here name, map (ITyVar o fst) vs))
+              :: pr_typdecl tyvars (vs, (deresolv_here name, map (ITyVar o fst) vs))
               :: str "="
               :: pr_co co
               :: map ((fn p => Pretty.block [str "| ", p]) o pr_co) cos
@@ -1331,7 +1343,7 @@
             Pretty.block_enclose (
               Pretty.block [
                 str "class ",
-                pr_typparms tyvars [(v, map fst superclasses)],
+                Pretty.block (pr_typcontext tyvars [(v, map fst superclasses)]),
                 str (deresolv_here name ^ " " ^ CodeName.lookup_var tyvars v),
                 str " where {"
               ],
@@ -1345,13 +1357,13 @@
               semicolon [
                 (str o classparam_name class) classparam,
                 str "=",
-                pr_app false init_syms NOBR (c_inst, [])
+                pr_app tyvars false init_syms NOBR (c_inst, [])
               ];
           in
             Pretty.block_enclose (
               Pretty.block [
                 str "instance ",
-                pr_typparms tyvars vs,
+                Pretty.block (pr_typcontext tyvars vs),
                 str (class_name class ^ " "),
                 pr_typ tyvars BR (tyco `%% map (ITyVar o fst) vs),
                 str " where {"
@@ -1388,6 +1400,7 @@
   let
     val _ = Option.map File.check destination;
     val is_cons = CodeThingol.is_cons code;
+    val contr_classparam_typs = CodeThingol.contr_classparam_typs code;
     val module_alias = if is_some module then K module else raw_module_alias;
     val init_names = Name.make_context reserved_syms;
     val name_modl = mk_modl_name_tab init_names module_prefix module_alias code;
@@ -1460,6 +1473,7 @@
     fun seri_def qualified = pr_haskell class_syntax tyco_syntax
       const_syntax labelled_name init_syms
       deresolv_here (if qualified then deresolv else deresolv_here) is_cons
+      contr_classparam_typs
       (if string_classes then deriving_show else K false);
     fun write_modulefile (SOME destination) modlname =
           let
@@ -1541,7 +1555,7 @@
           ])
       | pr_fun _ = NONE
     val pr = pr_haskell (K NONE) pr_fun (K NONE) labelled_name init_names
-      I I (K false) (K false);
+      I I (K false) (K []) (K false);
   in
     []
     |> Graph.fold (fn (name, (def, _)) =>
--- a/src/Tools/code/code_thingol.ML	Thu Dec 13 07:09:08 2007 +0100
+++ b/src/Tools/code/code_thingol.ML	Thu Dec 13 07:09:09 2007 +0100
@@ -55,6 +55,7 @@
     -> (iterm * itype) * (iterm * iterm) list;
   val eta_expand: (string * (dict list list * itype list)) * iterm list -> int -> iterm;
   val contains_dictvar: iterm -> bool;
+  val locally_monomorphic: iterm -> bool;
   val fold_constnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a;
   val fold_varnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a;
   val fold_unbound_varnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a;
@@ -78,6 +79,7 @@
     -> code -> code;
   val empty_funs: code -> string list;
   val is_cons: code -> string -> bool;
+  val contr_classparam_typs: code -> string -> itype option list;
 
   type transact;
   val ensure_const: theory -> ((sort -> sort) * Sorts.algebra) * Consts.T
@@ -244,6 +246,13 @@
       (fn IConst (_, (dss, _)) => (fold o fold) contains dss | _ => I) t false
   end;
   
+fun locally_monomorphic (IConst _) = false
+  | locally_monomorphic (IVar _) = true
+  | locally_monomorphic (t `$ _) = locally_monomorphic t
+  | locally_monomorphic (_ `|-> t) = locally_monomorphic t
+  | locally_monomorphic (ICase ((_, ds), _)) = exists (locally_monomorphic o snd) ds;
+
+
 
 (** definitions, transactions **)
 
@@ -313,6 +322,19 @@
  of Datatypecons _ => true
   | _ => false;
 
+fun contr_classparam_typs code name = case Graph.get_node code name
+ of Classparam class => let
+        val Class (_, (_, params)) = Graph.get_node code class;
+        val SOME ty = AList.lookup (op =) params name;
+        val (tys, res_ty) = unfold_fun ty;
+        fun no_tyvar (_ `%% tys) = forall no_tyvar tys
+          | no_tyvar (ITyVar _) = false;
+      in if no_tyvar res_ty
+        then map (fn ty => if no_tyvar ty then NONE else SOME ty) tys
+        else []
+      end
+  | _ => [];
+
 
 (* transaction protocol *)
 
@@ -486,10 +508,10 @@
     val c' = CodeName.const thy c;
     fun stmt_datatypecons tyco =
       ensure_tyco thy algbr funcgr tyco
-      #>> K (Datatypecons c');
+      #>> Datatypecons;
     fun stmt_classparam class =
       ensure_class thy algbr funcgr class
-      #>> K (Classparam c');
+      #>> Classparam;
     fun stmt_fun trns =
       let
         val raw_thms = CodeFuncgr.funcs funcgr c;