diff -r b5ca587d0885 -r ac78f5cdc430 src/HOL/SMT/Tools/smt_monomorph.ML --- a/src/HOL/SMT/Tools/smt_monomorph.ML Tue Dec 01 22:29:46 2009 +0000 +++ b/src/HOL/SMT/Tools/smt_monomorph.ML Thu Dec 03 15:56:06 2009 +0100 @@ -74,16 +74,18 @@ fun incr_tvar_indices i t = let - val incrT = Logic.incr_tvar i + val incrT = Logic.incr_tvar_same i fun incr t = (case t of Const (n, T) => Const (n, incrT T) | Free (n, T) => Free (n, incrT T) - | Abs (n, T, t1) => Abs (n, incrT T, incr t1) - | t1 $ t2 => incr t1 $ incr t2 - | _ => t) - in incr t end + | Abs (n, T, t1) => (Abs (n, incrT T, incr t1 handle Same.SAME => t1) + handle Same.SAME => Abs (n, T, incr t1)) + | t1 $ t2 => (incr t1 $ (incr t2 handle Same.SAME => t2) + handle Same.SAME => t1 $ incr t2) + | _ => Same.same t) + in incr t handle Same.SAME => t end val monomorph_limit = 10 @@ -93,18 +95,17 @@ create copies of terms containing those constants. To prevent non-termination, there is an upper limit for the number of recursions involved in the fixpoint construction. *) -fun monomorph thy ts = +fun monomorph thy = let - val (ps, ms) = List.partition term_has_tvars ts + fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1) + fun incr_indices ts = fst (fold_map incr ts 0) fun with_tvar (n, Ts) = let val Ts' = filter typ_has_tvars Ts in if null Ts' then NONE else SOME (n, Ts') end - fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1) - val rps = fst (fold_map incr ps 0) - |> map (fn r => (r, map_filter with_tvar (consts_of [r]))) + fun extract_consts_with_tvar t = (t, map_filter with_tvar (consts_of [t])) - fun mono count is ces cs ts = + fun mono rps count is ces cs ts = let val spec = specialize thy cs is val (ces', (ts', is')) = fold_map spec (rps ~~ ces) (ts, []) @@ -113,8 +114,15 @@ if null is' then ts' else if count > monomorph_limit then (warning "monomorphization limit reached"; ts') - else mono (count + 1) is' ces' cs' ts' + else mono rps (count + 1) is' ces' cs' ts' end - in mono 0 (consts_of ms) (map (K []) rps) [] ms end + fun mono_all rps ms = if null rps then ms + else mono rps 0 (consts_of ms) (map (K []) rps) [] ms + in + List.partition term_has_tvars + #>> incr_indices + #>> map extract_consts_with_tvar + #-> mono_all + end end