moved lambda-lifting on terms into a separate structure (for better re-use in tools other than SMT)
authorboehmes
Wed Jul 20 09:23:12 2011 +0200 (2011-07-20)
changeset 4392824d6e759753f
parent 43927 3a87cb597832
child 43929 61d432e51aff
moved lambda-lifting on terms into a separate structure (for better re-use in tools other than SMT)
src/HOL/SMT.thy
src/HOL/Tools/SMT/smt_translate.ML
src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML
src/HOL/Tools/lambda_lifting.ML
     1.1 --- a/src/HOL/SMT.thy	Wed Jul 20 09:23:09 2011 +0200
     1.2 +++ b/src/HOL/SMT.thy	Wed Jul 20 09:23:12 2011 +0200
     1.3 @@ -13,6 +13,7 @@
     1.4    ("Tools/SMT/smt_builtin.ML")
     1.5    ("Tools/SMT/smt_datatypes.ML")
     1.6    ("Tools/SMT/smt_normalize.ML")
     1.7 +  ("Tools/lambda_lifting.ML")
     1.8    ("Tools/SMT/smt_translate.ML")
     1.9    ("Tools/SMT/smt_solver.ML")
    1.10    ("Tools/SMT/smtlib_interface.ML")
    1.11 @@ -137,6 +138,7 @@
    1.12  use "Tools/SMT/smt_builtin.ML"
    1.13  use "Tools/SMT/smt_datatypes.ML"
    1.14  use "Tools/SMT/smt_normalize.ML"
    1.15 +use "Tools/lambda_lifting.ML"
    1.16  use "Tools/SMT/smt_translate.ML"
    1.17  use "Tools/SMT/smt_solver.ML"
    1.18  use "Tools/SMT/smtlib_interface.ML"
     2.1 --- a/src/HOL/Tools/SMT/smt_translate.ML	Wed Jul 20 09:23:09 2011 +0200
     2.2 +++ b/src/HOL/Tools/SMT/smt_translate.ML	Wed Jul 20 09:23:12 2011 +0200
     2.3 @@ -38,8 +38,6 @@
     2.4    (*translation*)
     2.5    val add_config: SMT_Utils.class * (Proof.context -> config) ->
     2.6      Context.generic -> Context.generic 
     2.7 -  val lift_lambdas: Proof.context -> bool -> term list ->
     2.8 -    Proof.context * (term list * term list)
     2.9    val translate: Proof.context -> string list -> (int * thm) list ->
    2.10      string * recon
    2.11  end
    2.12 @@ -243,82 +241,6 @@
    2.13  end
    2.14  
    2.15  
    2.16 -(** lambda-lifting **)
    2.17 -
    2.18 -local
    2.19 -  fun mk_def triggers Ts T lhs rhs =
    2.20 -    let
    2.21 -      val eq = HOLogic.eq_const T $ lhs $ rhs
    2.22 -      fun trigger () =
    2.23 -        [[Const (@{const_name SMT.pat}, T --> @{typ SMT.pattern}) $ lhs]]
    2.24 -        |> map (HOLogic.mk_list @{typ SMT.pattern})
    2.25 -        |> HOLogic.mk_list @{typ "SMT.pattern list"}
    2.26 -      fun mk_all T t = HOLogic.all_const T $ Abs (Name.uu, T, t)
    2.27 -    in
    2.28 -      fold mk_all Ts (if triggers then @{const SMT.trigger} $ trigger () $ eq
    2.29 -        else eq)
    2.30 -    end
    2.31 -
    2.32 -  fun mk_abs Ts = fold (fn T => fn t => Abs (Name.uu, T, t)) Ts
    2.33 -
    2.34 -  fun dest_abs Ts (Abs (_, T, t)) = dest_abs (T :: Ts) t
    2.35 -    | dest_abs Ts t = (Ts, t)
    2.36 -
    2.37 -  fun replace_lambda triggers Us Ts t (cx as (defs, ctxt)) =
    2.38 -    let
    2.39 -      val t1 = mk_abs Us t
    2.40 -      val bs = sort int_ord (Term.add_loose_bnos (t1, 0, []))
    2.41 -      fun rep i k = if member (op =) bs i then (Bound k, k+1) else (Bound i, k)
    2.42 -      val (rs, _) = fold_map rep (0 upto length Ts - 1) 0
    2.43 -      val t2 = Term.subst_bounds (rs, t1)
    2.44 -      val Ts' = map (nth Ts) bs 
    2.45 -      val (_, t3) = dest_abs [] t2
    2.46 -      val t4 = mk_abs Ts' t2
    2.47 -
    2.48 -      val T = Term.fastype_of1 (Us @ Ts, t)
    2.49 -      fun app f = Term.list_comb (f, map Bound (rev bs))
    2.50 -    in
    2.51 -      (case Termtab.lookup defs t4 of
    2.52 -        SOME (f, _) => (app f, cx)
    2.53 -      | NONE =>
    2.54 -          let
    2.55 -            val (n, ctxt') =
    2.56 -              yield_singleton Variable.variant_fixes Name.uu ctxt
    2.57 -            val (is, UTs) = split_list (map_index I (Us @ Ts'))
    2.58 -            val f = Free (n, rev UTs ---> T)
    2.59 -            val lhs = Term.list_comb (f, map Bound (rev is))
    2.60 -            val def = mk_def triggers UTs (Term.fastype_of1 (Us @ Ts, t)) lhs t3
    2.61 -          in (app f, (Termtab.update (t4, (f, def)) defs, ctxt')) end)
    2.62 -    end
    2.63 -
    2.64 -  fun traverse triggers Ts t =
    2.65 -    (case t of
    2.66 -      (q as Const (@{const_name All}, _)) $ Abs a =>
    2.67 -        abs_traverse triggers Ts a #>> (fn a' => q $ Abs a')
    2.68 -    | (q as Const (@{const_name Ex}, _)) $ Abs a =>
    2.69 -        abs_traverse triggers Ts a #>> (fn a' => q $ Abs a')
    2.70 -    | (l as Const (@{const_name Let}, _)) $ u $ Abs a =>
    2.71 -        traverse triggers Ts u ##>> abs_traverse triggers Ts a #>>
    2.72 -        (fn (u', a') => l $ u' $ Abs a')
    2.73 -    | Abs _ =>
    2.74 -        let val (Us, u) = dest_abs [] t
    2.75 -        in traverse triggers (Us @ Ts) u #-> replace_lambda triggers Us Ts end
    2.76 -    | u1 $ u2 => traverse triggers Ts u1 ##>> traverse triggers Ts u2 #>> (op $)
    2.77 -    | _ => pair t)
    2.78 -
    2.79 -  and abs_traverse triggers Ts (n, T, t) =
    2.80 -    traverse triggers (T::Ts) t #>> (fn t' => (n, T, t'))
    2.81 -in
    2.82 -
    2.83 -fun lift_lambdas ctxt triggers ts =
    2.84 -  (Termtab.empty, ctxt)
    2.85 -  |> fold_map (traverse triggers []) ts
    2.86 -  |> (fn (us, (defs, ctxt')) =>
    2.87 -       (ctxt', (Termtab.fold (cons o snd o snd) defs [], us)))
    2.88 -
    2.89 -end
    2.90 -
    2.91 -
    2.92  (** introduce explicit applications **)
    2.93  
    2.94  local
    2.95 @@ -618,7 +540,7 @@
    2.96      val (ctxt2, ts3) =
    2.97        ts2
    2.98        |> eta_expand ctxt1 is_fol funcs
    2.99 -      |> lift_lambdas ctxt1 true
   2.100 +      |> Lambda_Lifting.lift_lambdas ctxt1 true
   2.101        ||> (op @)
   2.102        |-> (fn ctxt1' => pair ctxt1' o intro_explicit_application ctxt1 funcs)
   2.103  
     3.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Jul 20 09:23:09 2011 +0200
     3.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Jul 20 09:23:12 2011 +0200
     3.3 @@ -541,7 +541,7 @@
     3.4             rpair [] o map (conceal_lambdas ctxt)
     3.5           else if trans = liftingN then
     3.6             map (close_form o Envir.eta_contract)
     3.7 -           #> SMT_Translate.lift_lambdas ctxt false #> snd #> swap
     3.8 +           #> Lambda_Lifting.lift_lambdas ctxt false #> snd #> swap
     3.9           else if trans = combinatorsN then
    3.10             rpair [] o map (introduce_combinators ctxt)
    3.11           else if trans = lambdasN then
     4.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     4.2 +++ b/src/HOL/Tools/lambda_lifting.ML	Wed Jul 20 09:23:12 2011 +0200
     4.3 @@ -0,0 +1,87 @@
     4.4 +(*  Title:      HOL/Tools/lambda_lifting.ML
     4.5 +    Author:     Sascha Boehme, TU Muenchen
     4.6 +
     4.7 +Lambda-lifting on terms, i.e., replacing (some) lambda-abstractions by
     4.8 +fresh names accompanied with defining equations for these fresh names in
     4.9 +terms of the lambda-abstractions' bodies.
    4.10 +*)
    4.11 +
    4.12 +signature LAMBDA_LIFTING =
    4.13 +sig
    4.14 +  val lift_lambdas: Proof.context -> bool -> term list ->
    4.15 +    Proof.context * (term list * term list)
    4.16 +end
    4.17 +
    4.18 +structure Lambda_Lifting: LAMBDA_LIFTING =
    4.19 +struct
    4.20 +
    4.21 +fun mk_def triggers Ts T lhs rhs =
    4.22 +  let
    4.23 +    val eq = HOLogic.eq_const T $ lhs $ rhs
    4.24 +    fun trigger () =
    4.25 +      [[Const (@{const_name SMT.pat}, T --> @{typ SMT.pattern}) $ lhs]]
    4.26 +      |> map (HOLogic.mk_list @{typ SMT.pattern})
    4.27 +      |> HOLogic.mk_list @{typ "SMT.pattern list"}
    4.28 +    fun mk_all T t = HOLogic.all_const T $ Abs (Name.uu, T, t)
    4.29 +  in
    4.30 +    fold mk_all Ts (if triggers then @{const SMT.trigger} $ trigger () $ eq
    4.31 +      else eq)
    4.32 +  end
    4.33 +
    4.34 +fun mk_abs Ts = fold (fn T => fn t => Abs (Name.uu, T, t)) Ts
    4.35 +
    4.36 +fun dest_abs Ts (Abs (_, T, t)) = dest_abs (T :: Ts) t
    4.37 +  | dest_abs Ts t = (Ts, t)
    4.38 +
    4.39 +fun replace_lambda triggers Us Ts t (cx as (defs, ctxt)) =
    4.40 +  let
    4.41 +    val t1 = mk_abs Us t
    4.42 +    val bs = sort int_ord (Term.add_loose_bnos (t1, 0, []))
    4.43 +    fun rep i k = if member (op =) bs i then (Bound k, k+1) else (Bound i, k)
    4.44 +    val (rs, _) = fold_map rep (0 upto length Ts - 1) 0
    4.45 +    val t2 = Term.subst_bounds (rs, t1)
    4.46 +    val Ts' = map (nth Ts) bs 
    4.47 +    val (_, t3) = dest_abs [] t2
    4.48 +    val t4 = mk_abs Ts' t2
    4.49 +
    4.50 +    val T = Term.fastype_of1 (Us @ Ts, t)
    4.51 +    fun app f = Term.list_comb (f, map Bound (rev bs))
    4.52 +  in
    4.53 +    (case Termtab.lookup defs t4 of
    4.54 +      SOME (f, _) => (app f, cx)
    4.55 +    | NONE =>
    4.56 +        let
    4.57 +          val (n, ctxt') =
    4.58 +            yield_singleton Variable.variant_fixes Name.uu ctxt
    4.59 +          val (is, UTs) = split_list (map_index I (Us @ Ts'))
    4.60 +          val f = Free (n, rev UTs ---> T)
    4.61 +          val lhs = Term.list_comb (f, map Bound (rev is))
    4.62 +          val def = mk_def triggers UTs (Term.fastype_of1 (Us @ Ts, t)) lhs t3
    4.63 +        in (app f, (Termtab.update (t4, (f, def)) defs, ctxt')) end)
    4.64 +  end
    4.65 +
    4.66 +fun traverse triggers Ts t =
    4.67 +  (case t of
    4.68 +    (q as Const (@{const_name All}, _)) $ Abs a =>
    4.69 +      abs_traverse triggers Ts a #>> (fn a' => q $ Abs a')
    4.70 +  | (q as Const (@{const_name Ex}, _)) $ Abs a =>
    4.71 +      abs_traverse triggers Ts a #>> (fn a' => q $ Abs a')
    4.72 +  | (l as Const (@{const_name Let}, _)) $ u $ Abs a =>
    4.73 +      traverse triggers Ts u ##>> abs_traverse triggers Ts a #>>
    4.74 +      (fn (u', a') => l $ u' $ Abs a')
    4.75 +  | Abs _ =>
    4.76 +      let val (Us, u) = dest_abs [] t
    4.77 +      in traverse triggers (Us @ Ts) u #-> replace_lambda triggers Us Ts end
    4.78 +  | u1 $ u2 => traverse triggers Ts u1 ##>> traverse triggers Ts u2 #>> (op $)
    4.79 +  | _ => pair t)
    4.80 +
    4.81 +and abs_traverse triggers Ts (n, T, t) =
    4.82 +  traverse triggers (T::Ts) t #>> (fn t' => (n, T, t'))
    4.83 +
    4.84 +fun lift_lambdas ctxt triggers ts =
    4.85 +  (Termtab.empty, ctxt)
    4.86 +  |> fold_map (traverse triggers []) ts
    4.87 +  |> (fn (us, (defs, ctxt')) =>
    4.88 +       (ctxt', (Termtab.fold (cons o snd o snd) defs [], us)))
    4.89 +
    4.90 +end