--- a/src/HOL/SMT/Tools/smt_monomorph.ML	Tue Dec 01 22:29:46 2009 +0000
+++ b/src/HOL/SMT/Tools/smt_monomorph.ML	Thu Dec 03 15:56:06 2009 +0100
@@ -74,16 +74,18 @@
 
 fun incr_tvar_indices i t =
   let
-    val incrT = Logic.incr_tvar i
+    val incrT = Logic.incr_tvar_same i
 
     fun incr t =
       (case t of
         Const (n, T) => Const (n, incrT T)
       | Free (n, T) => Free (n, incrT T)
-      | Abs (n, T, t1) => Abs (n, incrT T, incr t1)
-      | t1 $ t2 => incr t1 $ incr t2
-      | _ => t)
-  in incr t end
+      | Abs (n, T, t1) => (Abs (n, incrT T, incr t1 handle Same.SAME => t1)
+          handle Same.SAME => Abs (n, T, incr t1))
+      | t1 $ t2 => (incr t1 $ (incr t2 handle Same.SAME => t2)
+          handle Same.SAME => t1 $ incr t2)
+      | _ => Same.same t)
+  in incr t handle Same.SAME => t end
 
 
 val monomorph_limit = 10
@@ -93,18 +95,17 @@
    create copies of terms containing those constants.
    To prevent non-termination, there is an upper limit for the number of
    recursions involved in the fixpoint construction. *)
-fun monomorph thy ts =
+fun monomorph thy =
   let
-    val (ps, ms) = List.partition term_has_tvars ts
+    fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
+    fun incr_indices ts = fst (fold_map incr ts 0)
 
     fun with_tvar (n, Ts) =
       let val Ts' = filter typ_has_tvars Ts
       in if null Ts' then NONE else SOME (n, Ts') end
-    fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
-    val rps = fst (fold_map incr ps 0)
-      |> map (fn r => (r, map_filter with_tvar (consts_of [r])))
+    fun extract_consts_with_tvar t = (t, map_filter with_tvar (consts_of [t]))
 
-    fun mono count is ces cs ts =
+    fun mono rps count is ces cs ts =
       let
         val spec = specialize thy cs is
         val (ces', (ts', is')) = fold_map spec (rps ~~ ces) (ts, [])
@@ -113,8 +114,15 @@
         if null is' then ts'
         else if count > monomorph_limit then
           (warning "monomorphization limit reached"; ts')
-        else mono (count + 1) is' ces' cs' ts'
+        else mono rps (count + 1) is' ces' cs' ts'
       end
-  in mono 0 (consts_of ms) (map (K []) rps) [] ms end
+    fun mono_all rps ms = if null rps then ms
+      else mono rps 0 (consts_of ms) (map (K []) rps) [] ms
+  in
+    List.partition term_has_tvars
+    #>> incr_indices
+    #>> map extract_consts_with_tvar
+    #-> mono_all
+  end
 
 end