corrected type inference of primitive definitions
authorhaftmann
Wed, 11 Mar 2009 15:56:52 +0100
changeset 30451 11e5e8bb28f9
parent 30450 7655e6533209
child 30452 f00b993bda0d
corrected type inference of primitive definitions
src/HOL/Tools/primrec_package.ML
--- a/src/HOL/Tools/primrec_package.ML	Wed Mar 11 15:56:51 2009 +0100
+++ b/src/HOL/Tools/primrec_package.ML	Wed Mar 11 15:56:52 2009 +0100
@@ -26,7 +26,7 @@
 fun primrec_error msg = raise PrimrecError (msg, NONE);
 fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);
 
-fun message s = if ! Toplevel.debug then () else writeln s;
+fun message s = if ! Toplevel.debug then tracing s else ();
 
 
 (* preprocessing of equations *)
@@ -187,14 +187,13 @@
 
 fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
   let
-    val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t))
-                    (map snd ls @ [dummyT])
-                    (list_comb (Const (rec_name, dummyT),
-                                fs @ map Bound (0 :: (length ls downto 1))))
+    val SOME (var, varT) = get_first (fn ((b, T), mx) =>
+      if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes;
     val def_name = Thm.def_name (Long_Name.base_name fname);
-    val rhs = singleton (Syntax.check_terms ctxt) raw_rhs;
-    val SOME var = get_first (fn ((b, _), mx) =>
-      if Binding.name_of b = fname then SOME (b, mx) else NONE) fixes;
+    val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT])
+      (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1))))
+    val rhs = singleton (Syntax.check_terms ctxt)
+      (TypeInfer.constrain varT raw_rhs);
   in (var, ((Binding.name def_name, []), rhs)) end;
 
 
@@ -220,12 +219,12 @@
       raw_fixes (map (single o apsnd single) raw_spec) ctxt
   in (fixes, map (apsnd the_single) spec) end;
 
-fun prove_spec ctxt names rec_rewrites defs =
+fun prove_spec ctxt names rec_rewrites defs eqs =
   let
     val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
     fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
     val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
-  in map (fn (a, t) => (a, [Goal.prove ctxt [] [] t tac])) end;
+  in map (fn (a, t) => (a, [Goal.prove ctxt [] [] t tac])) eqs end;
 
 fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
   let