generalized lambda-lifting such that it is less specifically tailored for SMT (it does not anymore dependent on any SMT-specific code)
--- a/src/HOL/SMT.thy Wed Jul 20 09:23:12 2011 +0200
+++ b/src/HOL/SMT.thy Wed Jul 20 12:23:20 2011 +0200
@@ -7,13 +7,13 @@
theory SMT
imports Record
uses
+ "Tools/lambda_lifting.ML"
"Tools/SMT/smt_utils.ML"
"Tools/SMT/smt_failure.ML"
"Tools/SMT/smt_config.ML"
("Tools/SMT/smt_builtin.ML")
("Tools/SMT/smt_datatypes.ML")
("Tools/SMT/smt_normalize.ML")
- ("Tools/lambda_lifting.ML")
("Tools/SMT/smt_translate.ML")
("Tools/SMT/smt_solver.ML")
("Tools/SMT/smtlib_interface.ML")
@@ -138,7 +138,6 @@
use "Tools/SMT/smt_builtin.ML"
use "Tools/SMT/smt_datatypes.ML"
use "Tools/SMT/smt_normalize.ML"
-use "Tools/lambda_lifting.ML"
use "Tools/SMT/smt_translate.ML"
use "Tools/SMT/smt_solver.ML"
use "Tools/SMT/smtlib_interface.ML"
--- a/src/HOL/Tools/SMT/smt_translate.ML Wed Jul 20 09:23:12 2011 +0200
+++ b/src/HOL/Tools/SMT/smt_translate.ML Wed Jul 20 12:23:20 2011 +0200
@@ -295,7 +295,7 @@
q $ Abs (x, T, in_trigger (T :: Ts) u)
| (q as Const (@{const_name Ex}, _), [Abs (x, T, u)]) =>
q $ Abs (x, T, in_trigger (T :: Ts) u)
- | (q as Const (@{const_name Let}, _), [u1 as Abs _, u2]) =>
+ | (q as Const (@{const_name Let}, _), [u1, u2 as Abs _]) =>
q $ traverse Ts u1 $ traverse Ts u2
| (u as Const (c as (_, T)), ts) =>
(case SMT_Builtin.dest_builtin ctxt c ts of
@@ -537,12 +537,30 @@
((make_tr_context prefixes, ctxt), ts1)
|-> (if with_datatypes then collect_datatypes_and_records else no_dtyps)
+ fun is_binder (Const (@{const_name Let}, _) $ _) = true
+ | is_binder t = Lambda_Lifting.is_quantifier t
+
+ fun mk_trigger ((q as Const (@{const_name All}, _)) $ Abs (n, T, t)) =
+ q $ Abs (n, T, mk_trigger t)
+ | mk_trigger (eq as (Const (@{const_name HOL.eq}, T) $ lhs $ _)) =
+ Term.domain_type T --> @{typ SMT.pattern}
+ |> (fn T => Const (@{const_name SMT.pat}, T) $ lhs)
+ |> HOLogic.mk_list @{typ SMT.pattern} o single
+ |> HOLogic.mk_list @{typ "SMT.pattern list"} o single
+ |> (fn t => @{const SMT.trigger} $ t $ eq)
+ | mk_trigger t = t
+
val (ctxt2, ts3) =
ts2
|> eta_expand ctxt1 is_fol funcs
- |> Lambda_Lifting.lift_lambdas ctxt1 true
- ||> (op @)
- |-> (fn ctxt1' => pair ctxt1' o intro_explicit_application ctxt1 funcs)
+ |> rpair ctxt1
+ |>> tap (map (tracing o PolyML.makestring))
+ |-> Lambda_Lifting.lift_lambdas NONE is_binder
+ |-> (fn (ts', defs) => fn ctxt' =>
+ map mk_trigger defs @ ts'
+ |> tap (map (tracing o PolyML.makestring))
+ |> intro_explicit_application ctxt' funcs
+ |> pair ctxt')
val ((rewrite_rules, extra_thms, builtin), ts4) =
(if is_fol then folify ctxt2 else pair ([], [], I)) ts3
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML Wed Jul 20 09:23:12 2011 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML Wed Jul 20 12:23:20 2011 +0200
@@ -540,8 +540,9 @@
if trans = concealedN then
rpair [] o map (conceal_lambdas ctxt)
else if trans = liftingN then
- map (close_form o Envir.eta_contract)
- #> Lambda_Lifting.lift_lambdas ctxt false #> snd #> swap
+ map (close_form o Envir.eta_contract) #> rpair ctxt
+ #-> Lambda_Lifting.lift_lambdas NONE Lambda_Lifting.is_quantifier
+ #> fst
else if trans = combinatorsN then
rpair [] o map (introduce_combinators ctxt)
else if trans = lambdasN then
--- a/src/HOL/Tools/lambda_lifting.ML Wed Jul 20 09:23:12 2011 +0200
+++ b/src/HOL/Tools/lambda_lifting.ML Wed Jul 20 12:23:20 2011 +0200
@@ -8,32 +8,29 @@
signature LAMBDA_LIFTING =
sig
- val lift_lambdas: Proof.context -> bool -> term list ->
- Proof.context * (term list * term list)
+ type context = (term * term) Termtab.table * Proof.context
+ val init: Proof.context -> context
+ val is_quantifier: term -> bool
+ val lift_lambdas1: (term -> bool) -> string option -> term -> context ->
+ term * context
+ val finish: context -> term list * Proof.context
+ val lift_lambdas: string option -> (term -> bool) -> term list ->
+ Proof.context -> (term list * term list) * Proof.context
end
structure Lambda_Lifting: LAMBDA_LIFTING =
struct
-fun mk_def triggers Ts T lhs rhs =
- let
- val eq = HOLogic.eq_const T $ lhs $ rhs
- fun trigger () =
- [[Const (@{const_name SMT.pat}, T --> @{typ SMT.pattern}) $ lhs]]
- |> map (HOLogic.mk_list @{typ SMT.pattern})
- |> HOLogic.mk_list @{typ "SMT.pattern list"}
- fun mk_all T t = HOLogic.all_const T $ Abs (Name.uu, T, t)
- in
- fold mk_all Ts (if triggers then @{const SMT.trigger} $ trigger () $ eq
- else eq)
- end
+fun mk_def Ts T lhs rhs =
+ let fun mk_all T t = HOLogic.all_const T $ Abs (Name.uu, T, t)
+ in fold mk_all Ts (HOLogic.eq_const T $ lhs $ rhs) end
fun mk_abs Ts = fold (fn T => fn t => Abs (Name.uu, T, t)) Ts
fun dest_abs Ts (Abs (_, T, t)) = dest_abs (T :: Ts) t
| dest_abs Ts t = (Ts, t)
-fun replace_lambda triggers Us Ts t (cx as (defs, ctxt)) =
+fun replace_lambda basename Us Ts t (cx as (defs, ctxt)) =
let
val t1 = mk_abs Us t
val bs = sort int_ord (Term.add_loose_bnos (t1, 0, []))
@@ -51,37 +48,43 @@
SOME (f, _) => (app f, cx)
| NONE =>
let
- val (n, ctxt') =
- yield_singleton Variable.variant_fixes Name.uu ctxt
+ val (n, ctxt') = yield_singleton Variable.variant_fixes basename ctxt
val (is, UTs) = split_list (map_index I (Us @ Ts'))
val f = Free (n, rev UTs ---> T)
val lhs = Term.list_comb (f, map Bound (rev is))
- val def = mk_def triggers UTs (Term.fastype_of1 (Us @ Ts, t)) lhs t3
+ val def = mk_def UTs (Term.fastype_of1 (Us @ Ts, t)) lhs t3
in (app f, (Termtab.update (t4, (f, def)) defs, ctxt')) end)
end
-fun traverse triggers Ts t =
- (case t of
- (q as Const (@{const_name All}, _)) $ Abs a =>
- abs_traverse triggers Ts a #>> (fn a' => q $ Abs a')
- | (q as Const (@{const_name Ex}, _)) $ Abs a =>
- abs_traverse triggers Ts a #>> (fn a' => q $ Abs a')
- | (l as Const (@{const_name Let}, _)) $ u $ Abs a =>
- traverse triggers Ts u ##>> abs_traverse triggers Ts a #>>
- (fn (u', a') => l $ u' $ Abs a')
- | Abs _ =>
- let val (Us, u) = dest_abs [] t
- in traverse triggers (Us @ Ts) u #-> replace_lambda triggers Us Ts end
- | u1 $ u2 => traverse triggers Ts u1 ##>> traverse triggers Ts u2 #>> (op $)
- | _ => pair t)
+type context = (term * term) Termtab.table * Proof.context
+
+fun init ctxt = (Termtab.empty, ctxt)
+
+fun is_quantifier (Const (@{const_name All}, _)) = true
+ | is_quantifier (Const (@{const_name Ex}, _)) = true
+ | is_quantifier _ = false
+
+fun lift_lambdas1 is_binder basename =
+ let
+ val basename' = the_default Name.uu basename
-and abs_traverse triggers Ts (n, T, t) =
- traverse triggers (T::Ts) t #>> (fn t' => (n, T, t'))
+ fun traverse Ts (t $ (u as Abs (n, T, body))) =
+ if is_binder t then
+ traverse Ts t ##>> traverse (T :: Ts) body #>> (fn (t', body') =>
+ t' $ Abs (n, T, body'))
+ else traverse Ts t ##>> traverse Ts u #>> (op $)
+ | traverse Ts (t as Abs _) =
+ let val (Us, u) = dest_abs [] t
+ in traverse (Us @ Ts) u #-> replace_lambda basename' Us Ts end
+ | traverse Ts (t $ u) = traverse Ts t ##>> traverse Ts u #>> (op $)
+ | traverse _ t = pair t
+ in traverse [] end
-fun lift_lambdas ctxt triggers ts =
- (Termtab.empty, ctxt)
- |> fold_map (traverse triggers []) ts
- |> (fn (us, (defs, ctxt')) =>
- (ctxt', (Termtab.fold (cons o snd o snd) defs [], us)))
+fun finish (defs, ctxt) = (Termtab.fold (cons o snd o snd) defs [], ctxt)
+
+fun lift_lambdas basename is_binder ts ctxt =
+ init ctxt
+ |> fold_map (lift_lambdas1 is_binder basename) ts
+ |-> (fn ts' => finish #>> pair ts')
end