src/Tools/Code/code_thingol.ML
changeset 44997 e534939f880d
parent 44996 410eea28b0f7
child 44998 f12ef61ea76e
--- a/src/Tools/Code/code_thingol.ML	Mon Sep 19 16:18:18 2011 +0200
+++ b/src/Tools/Code/code_thingol.ML	Mon Sep 19 16:18:19 2011 +0200
@@ -581,26 +581,35 @@
 
 (* inference of type annotations for disambiguation with type classes *)
 
-fun annotate_term (Const (c', T'), Const (c, T)) tvar_names =
-    let
-      val tvar_names' = Term.add_tvar_namesT T' tvar_names
-    in
-      (Const (c, if eq_set (op =) (tvar_names, tvar_names') then T else Type("", [T])), tvar_names')
-    end
-  | annotate_term (t1 $ u1, t $ u) tvar_names =
-    let
-      val (u', tvar_names') = annotate_term (u1, u) tvar_names
-      val (t', tvar_names'') = annotate_term (t1, t) tvar_names'    
-    in
-      (t' $ u', tvar_names'')
-    end
-  | annotate_term (Abs (_, _, t1) , Abs (x, T, t)) tvar_names =
-    apfst (fn t => Abs (x, T, t)) (annotate_term (t1, t) tvar_names)
-  | annotate_term (Free _, t as Free _) tvar_names = (t, tvar_names)
-  | annotate_term (Var _, t as Var _) tvar_names = (t, tvar_names)
-  | annotate_term (Bound   _, t as Bound _) tvar_names = (t, tvar_names)
+fun annotate_term (proj_sort, _) eqngr =
+  let
+    val has_sort_constraints = exists (not o null) o map proj_sort o Code_Preproc.sortargs eqngr
+    fun annotate (Const (c', T'), Const (c, T)) tvar_names =
+      let
+        val tvar_names' = Term.add_tvar_namesT T' tvar_names
+      in
+        if not (eq_set (op =) (tvar_names, tvar_names')) andalso has_sort_constraints c then
+          (Const (c, Type ("", [T])), tvar_names')
+        else
+          (Const (c, T), tvar_names)
+      end
+      | annotate (t1 $ u1, t $ u) tvar_names =
+      let
+        val (u', tvar_names') = annotate (u1, u) tvar_names
+        val (t', tvar_names'') = annotate (t1, t) tvar_names'    
+      in
+        (t' $ u', tvar_names'')
+      end
+      | annotate (Abs (_, _, t1) , Abs (x, T, t)) tvar_names =
+        apfst (fn t => Abs (x, T, t)) (annotate (t1, t) tvar_names)
+      | annotate (Free _, t as Free _) tvar_names = (t, tvar_names)
+      | annotate (Var _, t as Var _) tvar_names = (t, tvar_names)
+      | annotate (Bound   _, t as Bound _) tvar_names = (t, tvar_names)
+  in
+    annotate
+  end
 
-fun annotate thy (c, ty) args rhs =
+fun annotate thy algbr eqngr (c, ty) args rhs =
   let
     val ctxt = ProofContext.init_global thy |> Config.put Type_Infer_Context.const_sorts false
     val erase = map_types (fn _ => Type_Infer.anyT [])
@@ -608,11 +617,12 @@
     val lhs = list_comb (Const (c, ty), map (map_types Type.strip_sorts o fst) args)
     val reinferred_rhs = snd (Logic.dest_equals (reinfer (Logic.mk_equals (lhs, erase rhs))))
   in
-    fst (annotate_term (reinferred_rhs, rhs) [])
+    fst (annotate_term algbr eqngr (reinferred_rhs, rhs) [])
   end
 
-fun annotate_eqns thy (c, ty) eqns = 
-  map (apfst (fn (args, (rhs, some_abs)) => (args, (annotate thy (c, ty) args rhs, some_abs)))) eqns
+fun annotate_eqns thy algbr eqngr (c, ty) eqns = 
+  map (apfst (fn (args, (rhs, some_abs)) => (args,
+    (annotate thy algbr eqngr (c, ty) args rhs, some_abs)))) eqns
 
 (* translation *)
 
@@ -638,7 +648,7 @@
     fun stmt_fun cert =
       let
         val ((vs, ty), eqns) = Code.equations_of_cert thy cert;
-        val eqns' = annotate_eqns thy (c, ty) eqns
+        val eqns' = annotate_eqns thy algbr eqngr (c, ty) eqns
         val some_case_cong = Code.get_case_cong thy c;
       in
         fold_map (translate_tyvar_sort thy algbr eqngr permissive) vs
@@ -919,7 +929,7 @@
     val ty = fastype_of t;
     val vs = fold_term_types (K (fold_atyps (insert (eq_fst op =)
       o dest_TFree))) t [];
-    val t' = annotate thy (Term.dummy_patternN, ty) [] (Code.subst_signatures thy t) 
+    val t' = annotate thy algbr eqngr (Term.dummy_patternN, ty) [] (Code.subst_signatures thy t) 
     val stmt_value =
       fold_map (translate_tyvar_sort thy algbr eqngr false) vs
       ##>> translate_typ thy algbr eqngr false ty