generalized lambda-lifting such that it is less specifically tailored for SMT (it does not anymore dependent on any SMT-specific code)
authorboehmes
Wed, 20 Jul 2011 12:23:20 +0200
changeset 43929 61d432e51aff
parent 43928 24d6e759753f
child 43930 cb7914f6e9b3
generalized lambda-lifting such that it is less specifically tailored for SMT (it does not anymore dependent on any SMT-specific code)
src/HOL/SMT.thy
src/HOL/Tools/SMT/smt_translate.ML
src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML
src/HOL/Tools/lambda_lifting.ML
--- a/src/HOL/SMT.thy	Wed Jul 20 09:23:12 2011 +0200
+++ b/src/HOL/SMT.thy	Wed Jul 20 12:23:20 2011 +0200
@@ -7,13 +7,13 @@
 theory SMT
 imports Record
 uses
+  "Tools/lambda_lifting.ML"
   "Tools/SMT/smt_utils.ML"
   "Tools/SMT/smt_failure.ML"
   "Tools/SMT/smt_config.ML"
   ("Tools/SMT/smt_builtin.ML")
   ("Tools/SMT/smt_datatypes.ML")
   ("Tools/SMT/smt_normalize.ML")
-  ("Tools/lambda_lifting.ML")
   ("Tools/SMT/smt_translate.ML")
   ("Tools/SMT/smt_solver.ML")
   ("Tools/SMT/smtlib_interface.ML")
@@ -138,7 +138,6 @@
 use "Tools/SMT/smt_builtin.ML"
 use "Tools/SMT/smt_datatypes.ML"
 use "Tools/SMT/smt_normalize.ML"
-use "Tools/lambda_lifting.ML"
 use "Tools/SMT/smt_translate.ML"
 use "Tools/SMT/smt_solver.ML"
 use "Tools/SMT/smtlib_interface.ML"
--- a/src/HOL/Tools/SMT/smt_translate.ML	Wed Jul 20 09:23:12 2011 +0200
+++ b/src/HOL/Tools/SMT/smt_translate.ML	Wed Jul 20 12:23:20 2011 +0200
@@ -295,7 +295,7 @@
           q $ Abs (x, T, in_trigger (T :: Ts) u)
       | (q as Const (@{const_name Ex}, _), [Abs (x, T, u)]) =>
           q $ Abs (x, T, in_trigger (T :: Ts) u)
-      | (q as Const (@{const_name Let}, _), [u1 as Abs _, u2]) =>
+      | (q as Const (@{const_name Let}, _), [u1, u2 as Abs _]) =>
           q $ traverse Ts u1 $ traverse Ts u2
       | (u as Const (c as (_, T)), ts) =>
           (case SMT_Builtin.dest_builtin ctxt c ts of
@@ -537,12 +537,30 @@
       ((make_tr_context prefixes, ctxt), ts1)
       |-> (if with_datatypes then collect_datatypes_and_records else no_dtyps)
 
+    fun is_binder (Const (@{const_name Let}, _) $ _) = true
+      | is_binder t = Lambda_Lifting.is_quantifier t
+
+    fun mk_trigger ((q as Const (@{const_name All}, _)) $ Abs (n, T, t)) =
+          q $ Abs (n, T, mk_trigger t)
+      | mk_trigger (eq as (Const (@{const_name HOL.eq}, T) $ lhs $ _)) =
+          Term.domain_type T --> @{typ SMT.pattern}
+          |> (fn T => Const (@{const_name SMT.pat}, T) $ lhs)
+          |> HOLogic.mk_list @{typ SMT.pattern} o single
+          |> HOLogic.mk_list @{typ "SMT.pattern list"} o single
+          |> (fn t => @{const SMT.trigger} $ t $ eq)
+      | mk_trigger t = t
+
     val (ctxt2, ts3) =
       ts2
       |> eta_expand ctxt1 is_fol funcs
-      |> Lambda_Lifting.lift_lambdas ctxt1 true
-      ||> (op @)
-      |-> (fn ctxt1' => pair ctxt1' o intro_explicit_application ctxt1 funcs)
+      |> rpair ctxt1
+      |>> tap (map (tracing o PolyML.makestring))
+      |-> Lambda_Lifting.lift_lambdas NONE is_binder
+      |-> (fn (ts', defs) => fn ctxt' =>
+          map mk_trigger defs @ ts'
+          |> tap (map (tracing o PolyML.makestring))
+          |> intro_explicit_application ctxt' funcs 
+          |> pair ctxt')
 
     val ((rewrite_rules, extra_thms, builtin), ts4) =
       (if is_fol then folify ctxt2 else pair ([], [], I)) ts3
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Jul 20 09:23:12 2011 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Jul 20 12:23:20 2011 +0200
@@ -540,8 +540,9 @@
          if trans = concealedN then
            rpair [] o map (conceal_lambdas ctxt)
          else if trans = liftingN then
-           map (close_form o Envir.eta_contract)
-           #> Lambda_Lifting.lift_lambdas ctxt false #> snd #> swap
+           map (close_form o Envir.eta_contract) #> rpair ctxt
+           #-> Lambda_Lifting.lift_lambdas NONE Lambda_Lifting.is_quantifier
+           #> fst
          else if trans = combinatorsN then
            rpair [] o map (introduce_combinators ctxt)
          else if trans = lambdasN then
--- a/src/HOL/Tools/lambda_lifting.ML	Wed Jul 20 09:23:12 2011 +0200
+++ b/src/HOL/Tools/lambda_lifting.ML	Wed Jul 20 12:23:20 2011 +0200
@@ -8,32 +8,29 @@
 
 signature LAMBDA_LIFTING =
 sig
-  val lift_lambdas: Proof.context -> bool -> term list ->
-    Proof.context * (term list * term list)
+  type context = (term * term) Termtab.table * Proof.context
+  val init: Proof.context -> context
+  val is_quantifier: term -> bool
+  val lift_lambdas1: (term -> bool) -> string option -> term -> context ->
+    term * context
+  val finish: context -> term list * Proof.context
+  val lift_lambdas: string option -> (term -> bool) -> term list ->
+    Proof.context -> (term list * term list) * Proof.context
 end
 
 structure Lambda_Lifting: LAMBDA_LIFTING =
 struct
 
-fun mk_def triggers Ts T lhs rhs =
-  let
-    val eq = HOLogic.eq_const T $ lhs $ rhs
-    fun trigger () =
-      [[Const (@{const_name SMT.pat}, T --> @{typ SMT.pattern}) $ lhs]]
-      |> map (HOLogic.mk_list @{typ SMT.pattern})
-      |> HOLogic.mk_list @{typ "SMT.pattern list"}
-    fun mk_all T t = HOLogic.all_const T $ Abs (Name.uu, T, t)
-  in
-    fold mk_all Ts (if triggers then @{const SMT.trigger} $ trigger () $ eq
-      else eq)
-  end
+fun mk_def Ts T lhs rhs =
+  let fun mk_all T t = HOLogic.all_const T $ Abs (Name.uu, T, t)
+  in fold mk_all Ts (HOLogic.eq_const T $ lhs $ rhs) end
 
 fun mk_abs Ts = fold (fn T => fn t => Abs (Name.uu, T, t)) Ts
 
 fun dest_abs Ts (Abs (_, T, t)) = dest_abs (T :: Ts) t
   | dest_abs Ts t = (Ts, t)
 
-fun replace_lambda triggers Us Ts t (cx as (defs, ctxt)) =
+fun replace_lambda basename Us Ts t (cx as (defs, ctxt)) =
   let
     val t1 = mk_abs Us t
     val bs = sort int_ord (Term.add_loose_bnos (t1, 0, []))
@@ -51,37 +48,43 @@
       SOME (f, _) => (app f, cx)
     | NONE =>
         let
-          val (n, ctxt') =
-            yield_singleton Variable.variant_fixes Name.uu ctxt
+          val (n, ctxt') = yield_singleton Variable.variant_fixes basename ctxt
           val (is, UTs) = split_list (map_index I (Us @ Ts'))
           val f = Free (n, rev UTs ---> T)
           val lhs = Term.list_comb (f, map Bound (rev is))
-          val def = mk_def triggers UTs (Term.fastype_of1 (Us @ Ts, t)) lhs t3
+          val def = mk_def UTs (Term.fastype_of1 (Us @ Ts, t)) lhs t3
         in (app f, (Termtab.update (t4, (f, def)) defs, ctxt')) end)
   end
 
-fun traverse triggers Ts t =
-  (case t of
-    (q as Const (@{const_name All}, _)) $ Abs a =>
-      abs_traverse triggers Ts a #>> (fn a' => q $ Abs a')
-  | (q as Const (@{const_name Ex}, _)) $ Abs a =>
-      abs_traverse triggers Ts a #>> (fn a' => q $ Abs a')
-  | (l as Const (@{const_name Let}, _)) $ u $ Abs a =>
-      traverse triggers Ts u ##>> abs_traverse triggers Ts a #>>
-      (fn (u', a') => l $ u' $ Abs a')
-  | Abs _ =>
-      let val (Us, u) = dest_abs [] t
-      in traverse triggers (Us @ Ts) u #-> replace_lambda triggers Us Ts end
-  | u1 $ u2 => traverse triggers Ts u1 ##>> traverse triggers Ts u2 #>> (op $)
-  | _ => pair t)
+type context = (term * term) Termtab.table * Proof.context
+
+fun init ctxt = (Termtab.empty, ctxt)
+
+fun is_quantifier (Const (@{const_name All}, _)) = true
+  | is_quantifier (Const (@{const_name Ex}, _)) = true
+  | is_quantifier _ = false
+
+fun lift_lambdas1 is_binder basename =
+  let
+    val basename' = the_default Name.uu basename
 
-and abs_traverse triggers Ts (n, T, t) =
-  traverse triggers (T::Ts) t #>> (fn t' => (n, T, t'))
+    fun traverse Ts (t $ (u as Abs (n, T, body))) =
+          if is_binder t then
+            traverse Ts t ##>> traverse (T :: Ts) body #>> (fn (t', body') =>
+            t' $ Abs (n, T, body'))
+          else traverse Ts t ##>> traverse Ts u #>> (op $)
+      | traverse Ts (t as Abs _) =
+          let val (Us, u) = dest_abs [] t
+          in traverse (Us @ Ts) u #-> replace_lambda basename' Us Ts end
+      | traverse Ts (t $ u) = traverse Ts t ##>> traverse Ts u #>> (op $)
+      | traverse _ t = pair t
+  in traverse [] end
 
-fun lift_lambdas ctxt triggers ts =
-  (Termtab.empty, ctxt)
-  |> fold_map (traverse triggers []) ts
-  |> (fn (us, (defs, ctxt')) =>
-       (ctxt', (Termtab.fold (cons o snd o snd) defs [], us)))
+fun finish (defs, ctxt) = (Termtab.fold (cons o snd o snd) defs [], ctxt)
+
+fun lift_lambdas basename is_binder ts ctxt =
+  init ctxt
+  |> fold_map (lift_lambdas1 is_binder basename) ts
+  |-> (fn ts' => finish #>> pair ts')
 
 end