src/HOL/Tools/lambda_lifting.ML
author wenzelm
Tue Sep 26 20:54:40 2017 +0200 (23 months ago)
changeset 66695 91500c024c7f
parent 43929 61d432e51aff
child 69593 3dda49e08b9d
permissions -rw-r--r--
tuned;
boehmes@43928
     1
(*  Title:      HOL/Tools/lambda_lifting.ML
boehmes@43928
     2
    Author:     Sascha Boehme, TU Muenchen
boehmes@43928
     3
boehmes@43928
     4
Lambda-lifting on terms, i.e., replacing (some) lambda-abstractions by
boehmes@43928
     5
fresh names accompanied with defining equations for these fresh names in
boehmes@43928
     6
terms of the lambda-abstractions' bodies.
boehmes@43928
     7
*)
boehmes@43928
     8
boehmes@43928
     9
signature LAMBDA_LIFTING =
sig
  (* lifting state: table of introduced definitions plus the proof context *)
  type context = (term * term) Termtab.table * Proof.context

  (* start with an empty table of definitions *)
  val init: Proof.context -> context

  (* recognizes the HOL constants All and Ex *)
  val is_quantifier: term -> bool

  (* lift the lambda-abstractions of a single term; the predicate selects
     heads whose abstraction argument is kept, the optional string is the
     base name for fresh function symbols *)
  val lift_lambdas1:
    (term -> bool) -> string option -> term -> context -> term * context

  (* extract the defining equations collected so far *)
  val finish: context -> term list * Proof.context

  (* lift the lambda-abstractions of several terms at once; the result pairs
     the rewritten terms with the defining equations *)
  val lift_lambdas:
    string option -> (term -> bool) -> term list -> Proof.context ->
    (term list * term list) * Proof.context
end
boehmes@43928
    20
boehmes@43928
    21
structure Lambda_Lifting: LAMBDA_LIFTING =
boehmes@43928
    22
struct
boehmes@43928
    23
boehmes@43929
    24
(* Build the HOL equation "lhs = rhs" at type T and wrap it in one universal
   quantifier per element of Ts.  The quantifiers are added in list order, so
   the head of Ts ends up innermost. *)
fun mk_def Ts T lhs rhs =
  let
    val eq = HOLogic.eq_const T $ lhs $ rhs
    fun quantify U body = HOLogic.all_const U $ Abs (Name.uu, U, body)
  in fold quantify Ts eq end
boehmes@43928
    27
boehmes@43928
    28
(* Wrap a term in one abstraction (named Name.uu) per element of Ts; the head
   of Ts becomes the innermost binder. *)
fun mk_abs Ts body =
  let fun abstract T t = Abs (Name.uu, T, t)
  in fold abstract Ts body end
boehmes@43928
    29
boehmes@43928
    30
(* Strip all leading abstractions of a term, pushing each binder type onto
   the accumulator Ts (so the innermost binder's type comes first).  Returns
   the accumulated types together with the remaining body. *)
fun dest_abs Ts t =
  (case t of
    Abs (_, T, body) => dest_abs (T :: Ts) body
  | _ => (Ts, t))
boehmes@43928
    32
boehmes@43929
    33
(* Replace a lambda-abstraction by an application of a function symbol to
   those enclosing bound variables that occur loose in it.

   basename: base for the fresh symbol's name (made unique in the context
             via Variable.variant_fixes)
   Us:       binder types of the abstraction, in the order produced by
             dest_abs
   Ts:       types of the bound variables of the enclosing term
   t:        body of the abstraction
   cx:       definitions introduced so far, paired with the proof context

   If an equal abstraction (after renumbering its loose bounds) is already in
   the table, its function symbol is reused and the context is returned
   unchanged; otherwise a fresh Free is introduced and its defining equation
   is recorded in the table. *)
fun replace_lambda basename Us Ts t (cx as (defs, ctxt)) =
  let
    (* rebuild the abstraction and collect its loose bound indices *)
    val t1 = mk_abs Us t
    val bs = sort int_ord (Term.add_loose_bnos (t1, 0, []))

    (* renumber the occurring loose bounds consecutively, starting at 0 *)
    fun rep i k = if member (op =) bs i then (Bound k, k + 1) else (Bound i, k)
    val (rs, _) = fold_map rep (0 upto length Ts - 1) 0
    val t2 = Term.subst_bounds (rs, t1)

    (* types of the occurring loose bounds *)
    val Ts' = map (nth Ts) bs
    val (_, t3) = dest_abs [] t2
    (* abstraction over the occurring bounds as well: lookup key in defs *)
    val t4 = mk_abs Ts' t2

    (* result type of the abstraction body, also used for the definition *)
    val T = Term.fastype_of1 (Us @ Ts, t)
    fun app f = Term.list_comb (f, map Bound (rev bs))
  in
    (case Termtab.lookup defs t4 of
      SOME (f, _) => (app f, cx)
    | NONE =>
        let
          val (n, ctxt') = yield_singleton Variable.variant_fixes basename ctxt
          val (is, UTs) = split_list (map_index I (Us @ Ts'))
          val f = Free (n, rev UTs ---> T)
          val lhs = Term.list_comb (f, map Bound (rev is))
          (* reuse T instead of recomputing the same type *)
          val def = mk_def UTs T lhs t3
        in (app f, (Termtab.update (t4, (f, def)) defs, ctxt')) end)
  end
boehmes@43928
    58
boehmes@43929
    59
(* Lifting state: a term-indexed table whose entries pair a function symbol
   with its defining equation (as stored by replace_lambda), together with
   the current proof context. *)
type context =
  (term * term) Termtab.table * Proof.context
boehmes@43929
    60
boehmes@43929
    61
(* Initial lifting state: no definitions yet, the given proof context. *)
fun init ctxt = pair Termtab.empty ctxt
boehmes@43929
    62
boehmes@43929
    63
(* True exactly for the constants All and Ex (at any type). *)
fun is_quantifier (Const (c, _)) =
      (c = @{const_name All}) orelse (c = @{const_name Ex})
  | is_quantifier _ = false
boehmes@43929
    66
boehmes@43929
    67
(* Lift the lambda-abstractions of one term, threading the lifting context.

   is_binder: when it holds for the head of an application whose argument is
              an abstraction, that abstraction is kept and only its body is
              traversed
   basename:  optional base name for fresh function symbols; Name.uu is used
              when it is NONE

   Every other abstraction is handed to replace_lambda after its body has
   been traversed. *)
fun lift_lambdas1 is_binder basename =
  let
    val fresh_base = the_default Name.uu basename

    (* Ts: types of the bound variables enclosing the current subterm *)
    fun traverse Ts tm =
      (case tm of
        f $ (arg as Abs (x, T, body)) =>
          if is_binder f then
            traverse Ts f ##>> traverse (T :: Ts) body
            #>> (fn (f', body') => f' $ Abs (x, T, body'))
          else traverse_app Ts f arg
      | Abs _ =>
          let val (Us, body) = dest_abs [] tm
          in traverse (Us @ Ts) body #-> replace_lambda fresh_base Us Ts end
      | f $ arg => traverse_app Ts f arg
      | _ => pair tm)

    (* traverse both sides of an application and recombine them *)
    and traverse_app Ts f arg = traverse Ts f ##>> traverse Ts arg #>> (op $)
  in traverse [] end
boehmes@43928
    82
boehmes@43929
    83
(* Collect the defining equations stored in the table and return them
   together with the proof context. *)
fun finish (defs, ctxt) =
  let fun collect (_, (_, def)) acc = def :: acc
  in (Termtab.fold collect defs [], ctxt) end
boehmes@43929
    84
boehmes@43929
    85
(* Lift the lambda-abstractions of all terms in ts, sharing one lifting
   context.  Returns the rewritten terms paired with the collected defining
   equations, plus the resulting proof context. *)
fun lift_lambdas basename is_binder ts ctxt =
  let
    val (ts', cx) = fold_map (lift_lambdas1 is_binder basename) ts (init ctxt)
    val (defs, ctxt') = finish cx
  in ((ts', defs), ctxt') end
boehmes@43928
    89
boehmes@43928
    90
end