revisiting type annotations for Haskell: necessary type annotations are not inferred on the provided theorems but using the arguments and right hand sides, as these might differ in the case of constants with abstract code types
authorbulwahn
Fri, 09 Sep 2011 12:33:09 +0200
changeset 44854 0b3d3570ab31
parent 44852 8ac91e7b6024
child 44855 f4a6786057d9
revisiting type annotations for Haskell: necessary type annotations are not inferred on the provided theorems but using the arguments and right hand sides, as these might differ in the case of constants with abstract code types
src/Tools/Code/code_thingol.ML
--- a/src/Tools/Code/code_thingol.ML	Thu Sep 08 12:23:11 2011 +0200
+++ b/src/Tools/Code/code_thingol.ML	Fri Sep 09 12:33:09 2011 +0200
@@ -598,22 +598,21 @@
     apfst (fn t => Abs (x, T, t)) (annotate_term (t1, t) tvar_names)
   | annotate_term (_, t) tvar_names = (t, tvar_names)
 
-fun annotate_eqns thy eqns = 
+fun annotate_eqns thy (c, ty) eqns = 
   let
     val ctxt = ProofContext.init_global thy |> Config.put Type_Infer_Context.const_sorts false
     val erase = map_types (fn _ => Type_Infer.anyT [])
     val reinfer = singleton (Type_Infer_Context.infer_types ctxt)
-    fun add_annotations ((args, (rhs, some_abs)), (SOME th, proper)) =
+    fun add_annotations (args, (rhs, some_abs)) =
       let
-        val (lhs, drhs) = Logic.dest_equals (prop_of (Thm.unvarify_global th))
-        val drhs' = snd (Logic.dest_equals (reinfer (Logic.mk_equals (lhs, erase drhs))))
-        val (rhs', _) = annotate_term (drhs', rhs) []
+        val lhs = list_comb (Const (c, ty), map (map_types Type.strip_sorts o fst) args)
+        val reinferred_rhs = snd (Logic.dest_equals (reinfer (Logic.mk_equals (lhs, erase rhs))))
+        val (rhs', _) = annotate_term (reinferred_rhs, rhs) []
      in
-        ((args, (rhs', some_abs)), (SOME th, proper))
+        (args, (rhs', some_abs))
      end
-     | add_annotations eqn = eqn
   in
-    map add_annotations eqns
+    map (apfst add_annotations) eqns
   end;
 
 (* translation *)
@@ -640,7 +639,7 @@
     fun stmt_fun cert =
       let
         val ((vs, ty), eqns) = Code.equations_of_cert thy cert;
-        val eqns' = annotate_eqns thy eqns
+        val eqns' = annotate_eqns thy (c, ty) eqns
         val some_case_cong = Code.get_case_cong thy c;
       in
         fold_map (translate_tyvar_sort thy algbr eqngr permissive) vs