--- a/src/Tools/Code/code_thingol.ML Wed Sep 07 13:51:32 2011 +0200
+++ b/src/Tools/Code/code_thingol.ML Wed Sep 07 13:51:34 2011 +0200
@@ -609,6 +609,43 @@
(err_typ ^ "\n" ^ err_class)
end;
+(* inference of type annotations for disambiguation with type classes *)
+
+
+fun annotate_term (Const (c', T'), Const (c, T)) tvar_names =
+ let
+ val tvar_names' = Term.add_tvar_namesT T' tvar_names
+ in
+ (Const (c, if eq_set (op =) (tvar_names, tvar_names') then T else Type("", [T])), tvar_names')
+ end
+ | annotate_term (t1 $ u1, t $ u) tvar_names =
+ let
+ val (u', tvar_names') = annotate_term (u1, u) tvar_names
+ val (t', tvar_names'') = annotate_term (t1, t) tvar_names'
+ in
+ (t' $ u', tvar_names'')
+ end
+ | annotate_term (Abs (_, _, t1) , Abs (x, T, t)) tvar_names =
+ apfst (fn t => Abs (x, T, t)) (annotate_term (t1, t) tvar_names)
+ | annotate_term (_, t) tvar_names = (t, tvar_names)
+
+fun annotate_eqns thy eqns =
+ let
+ val ctxt = ProofContext.init_global thy
+ val erase = map_types (fn _ => Type_Infer.anyT [])
+ val reinfer = singleton (Type_Infer_Context.infer_types ctxt)
+ fun add_annotations ((args, (rhs, some_abs)), (SOME th, proper)) =
+ let
+ val (lhs, drhs) = Logic.dest_equals (prop_of (Thm.unvarify_global th))
+ val drhs' = snd (Logic.dest_equals (reinfer (Logic.mk_equals (lhs, erase drhs))))
+ val (rhs', _) = annotate_term (drhs', rhs) []
+ in
+ ((args, (rhs', some_abs)), (SOME th, proper))
+ end
+ | add_annotations eqn = eqn
+ in
+ map add_annotations eqns
+ end;
(* translation *)