src/HOL/Tools/lambda_lifting.ML
author wenzelm
Tue Sep 26 20:54:40 2017 +0200 (23 months ago)
changeset 66695 91500c024c7f
parent 43929 61d432e51aff
child 69593 3dda49e08b9d
permissions -rw-r--r--
tuned;
     1 (*  Title:      HOL/Tools/lambda_lifting.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Lambda-lifting on terms, i.e., replacing (some) lambda-abstractions by
     5 fresh names accompanied with defining equations for these fresh names in
     6 terms of the lambda-abstractions' bodies.
     7 *)
     8 
     signature LAMBDA_LIFTING =
     sig
       (* Lifting state: a term-indexed table whose entries are pairs of terms
          (the name introduced for a lifted abstraction and its defining
          equation), together with the proof context in which the fresh names
          are fixed. *)
       type context = (term * term) Termtab.table * Proof.context

       (* Start lifting in the given proof context with an empty table. *)
       val init: Proof.context -> context

       (* Recognizes the HOL constants All and Ex. *)
       val is_quantifier: term -> bool

       (* Lift the abstractions of a single term, threading the lifting state.
          Arguments: binder predicate, optional base name for the fresh names,
          the term to process. *)
       val lift_lambdas1:
         (term -> bool) -> string option -> term -> context -> term * context

       (* Extract the accumulated defining equations and the final proof
          context from the lifting state. *)
       val finish: context -> term list * Proof.context

       (* Lift the abstractions of a list of terms in one go; the result is
          (processed terms, defining equations) plus the final proof context.
          Note the argument order: base name first, binder predicate second
          (the reverse of lift_lambdas1). *)
       val lift_lambdas:
         string option -> (term -> bool) -> term list -> Proof.context ->
         (term list * term list) * Proof.context
     end
    20 
    21 structure Lambda_Lifting: LAMBDA_LIFTING =
    22 struct
    23 
    (* Build the defining equation "lhs = rhs" (HOL equality at type T) and wrap
       it in one HOL universal quantifier per element of Ts.  The binders are
       anonymous (Name.uu); the first element of Ts ends up as the innermost
       quantifier. *)
    fun mk_def Ts T lhs rhs =
      let
        val equation = HOLogic.eq_const T $ lhs $ rhs
        fun quantify U body = HOLogic.all_const U $ Abs (Name.uu, U, body)
      in fold quantify Ts equation end
    27 
    (* Wrap a term in one anonymous (Name.uu) abstraction per element of Ts; the
       first element of Ts becomes the innermost binder. *)
    fun mk_abs Ts =
      let fun abstract T body = Abs (Name.uu, T, body)
      in fold abstract Ts end
    29 
    (* Strip all leading abstractions from a term.  The binder types are pushed
       onto the accumulator Ts as they are encountered, so the innermost binder's
       type comes first in the result; binder names are discarded. *)
    fun dest_abs Ts t =
      (case t of
        Abs (_, T, body) => dest_abs (T :: Ts) body
      | _ => (Ts, t))
    32 
    (* Replace one lambda-abstraction by an application of a (possibly fresh)
       free variable.

         basename  base for the fresh name
         Us        types of the binders of the abstraction itself
         Ts        types of the enclosing bound variables
         t         body of the abstraction (below the Us binders)
         cx        lifting state: table of known definitions and proof context

       Result: the replacement term (the name applied to the enclosing bound
       variables that the abstraction actually uses) and the updated state. *)
    fun replace_lambda basename Us Ts t (cx as (defs, ctxt)) =
      let
        (* re-attach the abstraction's own binders and find out which enclosing
           bound variables it refers to (ascending de Bruijn indices) *)
        val lambda = mk_abs Us t
        val used = sort int_ord (Term.add_loose_bnos (lambda, 0, []))

        (* renumber the used enclosing bounds to 0, 1, 2, ... in ascending
           order; indices that are not used are mapped to themselves *)
        fun renumber i next =
          if member (op =) used i then (Bound next, next + 1) else (Bound i, next)
        val (substs, _) = fold_map renumber (0 upto length Ts - 1) 0
        val compact = Term.subst_bounds (substs, lambda)

        val used_Ts = map (nth Ts) used
        val (_, rhs) = dest_abs [] compact

        (* the abstraction, additionally abstracted over the used enclosing
           bounds: this is the lookup key under which definitions are shared *)
        val key = mk_abs used_Ts compact

        val T = Term.fastype_of1 (Us @ Ts, t)
        fun apply_to_used f = Term.list_comb (f, map Bound (rev used))
      in
        (case Termtab.lookup defs key of
          NONE =>
            let
              val (name, ctxt') =
                yield_singleton Variable.variant_fixes basename ctxt
              val arg_Ts = Us @ used_Ts
              val f = Free (name, rev arg_Ts ---> T)
              (* inside the definition every argument is a bound variable *)
              val args = map Bound (rev (0 upto length arg_Ts - 1))
              val def = mk_def arg_Ts T (Term.list_comb (f, args)) rhs
            in (apply_to_used f, (Termtab.update (key, (f, def)) defs, ctxt')) end
        | SOME (f, _) =>
            (* same abstraction seen before: reuse its name, state unchanged *)
            (apply_to_used f, cx))
      end
    58 
    (* Lifting state: table from terms to pairs of terms (as filled in by
       replace_lambda: the introduced name and its defining equation), paired
       with the current proof context. *)
    type context = (term * term) Termtab.table * Proof.context

    (* Initial lifting state: no definitions yet, proof context as given. *)
    fun init ctxt : context = (Termtab.empty, ctxt)
    62 
    (* True exactly for the constants All and Ex (at any type); false for every
       other term. *)
    fun is_quantifier (Const (c, _)) =
          c = @{const_name All} orelse c = @{const_name Ex}
      | is_quantifier _ = false
    66 
    (* Lift the abstractions occurring in a single term.

         is_binder  if it holds for the head t of an application "t $ Abs ...",
                    the abstraction is kept in place and only its body is
                    processed; otherwise the abstraction is lifted
         basename   optional base for the fresh names (Name.uu if NONE)

       The result is a function from the term to a state transformer on the
       lifting context. *)
    fun lift_lambdas1 is_binder basename =
      let
        val name = the_default Name.uu basename

        (* Ts: types of the bound variables enclosing the current subterm *)
        fun traverse Ts tm =
          (case tm of
            t $ (u as Abs (x, T, body)) =>
              if is_binder t then
                (* keep the binder's abstraction, descend below it *)
                traverse Ts t
                ##>> traverse (T :: Ts) body
                #>> (fn (t', body') => t' $ Abs (x, T, body'))
              else traverse_comb Ts t u
          | Abs _ =>
              (* process the body first, then replace the whole abstraction *)
              let val (Us, body) = dest_abs [] tm
              in traverse (Us @ Ts) body #-> replace_lambda name Us Ts end
          | t $ u => traverse_comb Ts t u
          | _ => pair tm)

        (* ordinary application: process function and argument independently *)
        and traverse_comb Ts t u = traverse Ts t ##>> traverse Ts u #>> (op $)
      in traverse [] end
    82 
    (* Collect the second component (the defining equation) of every table entry
       and return the resulting list together with the proof context. *)
    fun finish (defs, ctxt) =
      (Termtab.fold (fn (_, (_, def)) => cons def) defs [], ctxt)
    84 
    (* Lift the abstractions of all terms ts, starting from a fresh lifting state
       for ctxt.  Returns the processed terms paired with the accumulated
       defining equations, plus the final proof context. *)
    fun lift_lambdas basename is_binder ts ctxt =
      let
        val (ts', cx) =
          fold_map (lift_lambdas1 is_binder basename) ts (init ctxt)
        val (defs, ctxt') = finish cx
      in ((ts', defs), ctxt') end
    89 
    90 end