src/HOL/Tools/Function/termination.ML
author krauss
Fri Oct 23 16:22:10 2009 +0200 (2009-10-23)
changeset 33099 b8cdd3d73022
parent 32683 7c1fe854ca6a
child 33855 cd8acf137c9c
permissions -rw-r--r--
function package: more standard names for structures and files
haftmann@31775
     1
(*  Title:       HOL/Tools/Function/termination.ML
krauss@29125
     2
    Author:      Alexander Krauss, TU Muenchen
krauss@29125
     3
krauss@29125
     4
Context data for termination proofs
krauss@29125
     5
*)
krauss@29125
     6
krauss@29125
     7
krauss@29125
     8
(* Interface of the termination-proof context: the data threaded through
   termination tactics, operations that read and extend it, and tactic
   combinators built on top of it. *)
signature TERMINATION =
sig
  (* the termination context (abstract to clients) *)
  type data

  (* outcome of trying to prove that a measure decreases along a call *)
  datatype cell = Less of thm | LessEq of (thm * thm) | None of (thm * thm) | False of thm

  (* case combinator over the sum-type encoding of program points *)
  val mk_sumcases : data -> typ -> term list -> term

  (* extending the context *)
  val note_measure : int -> term -> data -> data
  val note_chain   : term -> term -> thm option -> data -> data
  val note_descent : term -> term -> term -> cell -> data -> data

  (* program points *)
  val get_num_points : data -> int
  val get_types      : data -> int -> typ
  val get_measures   : data -> int -> term list

  (* cache lookups *)
  val get_chain   : data -> term -> term -> thm option option
  val get_descent : data -> term -> term -> term -> cell option

  (* proving descents and recording them in the cache *)
  val derive_descent  : theory -> tactic -> term -> term -> term -> data -> data
  val derive_descents : theory -> tactic -> term -> data -> data

  (* decomposition of a single call *)
  val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term)

  val CALLS : (term list * int -> tactic) -> int -> tactic

  (* Termination tactics in continuation-passing style: the first argument
     is the success continuation, the second the error continuation. *)
  type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic

  val TERMINATION : Proof.context -> (data -> int -> tactic) -> int -> tactic
  val REPEAT      : ttac -> ttac

  val wf_union_tac : Proof.context -> tactic
end
krauss@29125
    45
krauss@29125
    46
krauss@29125
    47
krauss@29125
    48
structure Termination : TERMINATION =
krauss@29125
    49
struct
krauss@29125
    50
krauss@33099
    51
open Function_Lib

(* Orders on pairs and triples of terms (built with prod_ord from
   TermOrd.fast_term_ord) and the tables keyed by them. *)
fun term2_ord tt = prod_ord TermOrd.fast_term_ord TermOrd.fast_term_ord tt
val term3_ord = prod_ord TermOrd.fast_term_ord term2_ord

structure Term2tab = Table(type key = term * term val ord = term2_ord);
structure Term3tab = Table(type key = term * (term * term) val ord = term3_ord);
krauss@29125
    56
krauss@29125
    57
(** Analyzing binary trees **)
krauss@29125
    58
krauss@29125
    59
(* Shape of a nested sum encoding: a leaf carries the index of a program
   point, and a branch splits into its left and right parts. *)
datatype skel =
    SLeaf of int
  | SBranch of skel * skel
krauss@29125
    64
krauss@29125
    65
krauss@29125
    66
(* Bottom-up construction along a skeleton: leaf : int -> 'a handles the
   leaves, and branch : 'a * 'a -> 'a combines the two sides of a branch. *)
fun mk_tree leaf branch =
  let
    fun build (SBranch (lft, rgt)) = branch (build lft, build rgt)
      | build (SLeaf idx) = leaf idx
  in build end

(* Top-down decomposition along a skeleton: at each branch, split x into
   (left, right) with `split`. Returns the (leaf index, part) pairs in
   left-to-right order. *)
fun dest_tree split =
  let
    fun walk (SLeaf idx) x = [(idx, x)]
      | walk (SBranch (lft, rgt)) x =
          let val (xl, xr) = split x
          in walk lft xl @ walk rgt xr end
  in walk end
krauss@29125
    79
krauss@29125
    80
krauss@29125
    81
(* Sum-type instances: strip an Inl or Inr injection, if present. *)
fun dest_inl (Const (@{const_name Sum_Type.Inl}, _) $ t) = SOME t
  | dest_inl _ = NONE

fun dest_inr (Const (@{const_name Sum_Type.Inr}, _) $ t) = SOME t
  | dest_inr _ = NONE

(* true iff the term is an application of Inl or Inr *)
fun is_inj t = is_some (dest_inl t) orelse is_some (dest_inr t)
krauss@29125
    91
krauss@29125
    92
krauss@29125
    93
(* Infer a skeleton from sample terms. While the (non-empty) samples are
   all Inl/Inr injections, split them into their Inl and Inr payloads and
   recurse on each side; otherwise produce a leaf. Leaves are numbered
   0, 1, ... from left to right. *)
fun mk_skel ps =
  let
    fun build next [] = (next + 1, SLeaf next)
      | build next qs =
          if not (forall is_inj qs) then (next + 1, SLeaf next)
          else
            let
              val (after_left, left) = build next (map_filter dest_inl qs)
              val (after_right, right) = build after_left (map_filter dest_inr qs)
            in (after_right, SBranch (left, right)) end
  in
    #2 (build 0 ps)
  end
krauss@29125
   105
krauss@29125
   106
(* Types of the program points: split the sum type T along the skeleton
   and keep the summand type found at each leaf, in leaf order.
   Raises Match where T is not a "+" type but the skeleton branches. *)
fun node_types sk T =
  let
    fun split_sum (Type ("+", [LT, RT])) = (LT, RT)
  in
    map snd (dest_tree split_sum sk T)
  end
krauss@29125
   108
krauss@29125
   109
(* Locate a term in the sum encoding: strip the injections prescribed by
   the skeleton and return (leaf index, payload). Where the skeleton
   branches but the term is not Inl, it must be Inr; otherwise `the`
   raises Option. *)
fun dest_inj (SLeaf idx) t = (idx, t)
  | dest_inj (SBranch (lft, rgt)) t =
      (case dest_inl t of
         SOME inner => dest_inj lft inner
       | NONE => dest_inj rgt (the (dest_inr t)))
krauss@29125
   115
krauss@29125
   116
krauss@29125
   117
krauss@29125
   118
(** Matrix cell datatype **)
krauss@29125
   119
krauss@29125
   120
(* Outcome of trying to show that a measure decreases along a call
   (as produced by derive_desc_aux):
     Less   - strict decrease proved;
     LessEq - weak decrease proved, paired with the stuck strict attempt;
     None   - both attempts stuck (weak attempt, strict attempt);
     False  - the weak attempt got stuck with False as its only premise. *)
datatype cell =
    Less of thm
  | LessEq of thm * thm
  | None of thm * thm
  | False of thm;

(* Termination context, as a tuple of:
   - the skeleton of the sum type encoding the program points,
   - the type of each program point,
   - candidate measures per program point,
   - cached results about which calls form chains,
   - cached local descents, keyed by (call, (measure, measure)). *)
type data =
    skel
  * (int -> typ)
  * term list Inttab.table
  * thm option Term2tab.table
  * cell Term3tab.table
krauss@29125
   129
krauss@29125
   130
krauss@29125
   131
(* Functional updates of the measure, chain and descent components. *)
fun map_measures f (sk, Ts, ms, cs, ds) = (sk, Ts, f ms, cs, ds)
fun map_chains f (sk, Ts, ms, cs, ds) = (sk, Ts, ms, f cs, ds)
fun map_descent f (sk, Ts, ms, cs, ds) = (sk, Ts, ms, cs, f ds)

(* Record a measure m for program point p (Inttab.insert_list with
   aconv, so alpha-equivalent measures are not added twice). *)
fun note_measure p m D = map_measures (Inttab.insert_list (op aconv) (p, m)) D

(* Record the chain result for the call pair (c1, c2), overwriting any
   earlier entry. *)
fun note_chain c1 c2 res D = map_chains (Term2tab.update ((c1, c2), res)) D

(* Record the descent cell of call c for measures m1, m2, overwriting any
   earlier entry. *)
fun note_descent c m1 m2 res D = map_descent (Term3tab.update ((c, (m1, m2)), res)) D
krauss@29125
   138
krauss@29125
   139
(* Build a case combinator over the sum type of program points. fs holds
   one function per program point (indexed by leaf number). T is passed
   to SumTree.mk_sumcase, presumably as the common result type. The leaf
   functions are combined branch by branch following the skeleton. *)
fun mk_sumcases (sk, _, _, _, _) T fs =
  let
    fun leaf i =
      let val f = nth fs i
      in (f, domain_type (fastype_of f)) end
    fun branch ((f, fT), (g, gT)) =
      (SumTree.mk_sumcase fT gT T f g, SumTree.mk_sumT fT gT)
  in
    fst (mk_tree leaf branch sk)
  end
krauss@29125
   145
krauss@29125
   146
(* Skeleton for the relation rel, a sup-union of Collect comprehensions
   whose body (under existentials) is  _ = Pair r l & Gam.  Both pair
   components of every comprehension are used as samples for mk_skel.
   Comprehensions of another shape raise Bind/Match. *)
fun mk_sum_skel rel =
  let
    val comprs = Function_Lib.dest_binop_list @{const_name Lattices.sup} rel
    fun add_args (Const (@{const_name Collect}, _) $ Abs (_, _, body)) acc =
      let
        val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ _)
          = Term.strip_qnt_body "Ex" body
      in r :: l :: acc end
  in
    mk_skel (fold add_args comprs [])
  end
krauss@29125
   157
krauss@29125
   158
(* Initial termination context for relation rel with element type T (as
   extracted in TERMINATION). Each program point is seeded with the
   measure functions MeasureFunctions finds for its type; the chain and
   descent caches start empty. *)
fun create ctxt T rel =
  let
    val shape = mk_sum_skel rel
    val point_types = node_types shape T
    val measures =
      point_types
      |> map_index (fn (i, U) => (i, MeasureFunctions.get_measure_functions ctxt U))
      |> Inttab.make
  in
    (shape, nth point_types, measures, Term2tab.empty, Term3tab.empty)
  end
krauss@29125
   166
krauss@29125
   167
(* Number of program points: one more than the index of the rightmost
   leaf. That leaf has the largest index because mk_skel numbers leaves
   from left to right. *)
fun get_num_points (sk, _, _, _, _) =
  let
    fun rightmost (SBranch (_, rgt)) = rightmost rgt
      | rightmost (SLeaf i) = i
  in rightmost sk + 1 end
krauss@29125
   172
krauss@29125
   173
(* Accessors: program point types and their measures (empty list if none). *)
fun get_types (_, Ts, _, _, _) = Ts
fun get_measures (_, _, ms, _, _) = Inttab.lookup_list ms

(* Cache lookups: NONE when nothing has been recorded yet. *)
fun get_chain (_, _, _, cs, _) c1 c2 = Term2tab.lookup cs (c1, c2)

fun get_descent (_, _, _, _, ds) c m1 m2 = Term3tab.lookup ds (c, (m1, m2))
krauss@29125
   181
haftmann@32602
   182
(* Decompose a call, i.e. one comprehension of the call relation whose
   body (under existentials) is  _ = Pair r l & Gam,  into
   (vs, p, l', q, r', Gam):
     vs      - the existentially bound variables,
     (p, l') - program point and payload of the second Pair component l,
     (q, r') - program point and payload of the first Pair component r,
     Gam     - the condition of the call.
   Terms of any other shape raise ERROR "dest_call". Malformed injections
   inside l or r still surface as exceptions from dest_inj. *)
fun dest_call D (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
      let
        val (sk, _, _, _, _) = D
        val vs = Term.strip_qnt_vars "Ex" c
      in
        (case Term.strip_qnt_body "Ex" c of
          Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ Gam =>
            let
              val (p, l') = dest_inj sk l
              val (q, r') = dest_inj sk r
            in
              (vs, p, l', q, r', Gam)
            end
        | _ => error "dest_call")
      end
  | dest_call _ _ = error "dest_call"
krauss@29125
   197
krauss@29125
   198
krauss@29125
   199
(* Unless a descent for (c, m1, m2) is already cached, try to prove with
   tac that, for all vs and under Gam, m2 r' < m1 l' holds, and failing
   that m2 r' <= m1 l'. The outcome is recorded as a cell. try_proof
   results other than Solved/Stuck raise Match (see FIXMEs). *)
fun derive_desc_aux thy tac c (vs, _, l', _, r', Gam) m1 m2 D =
  if is_some (get_descent D c m1 m2) then D
  else
    let
      fun goal_for rel_name =
        cterm_of thy
          (Term.list_all (vs,
            Logic.mk_implies (HOLogic.mk_Trueprop Gam,
              HOLogic.mk_Trueprop
                (Const (rel_name, @{typ "nat => nat => bool"}) $ (m2 $ r') $ (m1 $ l')))))

      (* weak attempt, after the strict one got stuck in strict_stuck *)
      fun weak_attempt strict_stuck =
        (case try_proof (goal_for @{const_name HOL.less_eq}) tac of
          Solved le_thm => LessEq (le_thm, strict_stuck)
        | Stuck le_stuck =>
            if prems_of le_stuck = [HOLogic.Trueprop $ HOLogic.false_const]
            then False le_stuck
            else None (le_stuck, strict_stuck)
        | _ => raise Match) (* FIXME *)

      val result =
        (case try_proof (goal_for @{const_name HOL.less}) tac of
          Solved lt_thm => Less lt_thm
        | Stuck lt_stuck => weak_attempt lt_stuck
        | _ => raise Match) (* FIXME *)
    in
      note_descent c m1 m2 result D
    end
krauss@29125
   222
krauss@29125
   223
(* Derive and cache the descent of call c for measures m1 and m2. *)
fun derive_descent thy tac c m1 m2 D =
  let val call = dest_call D c
  in derive_desc_aux thy tac c call m1 m2 D end
krauss@29125
   225
krauss@29125
   226
(* All descents of call c in one go: every pair (m1, m2) where m1 is a
   measure of point p and m2 a measure of point q, with p and q taken
   from dest_call. *)
fun derive_descents thy tac c D =
  let
    val call = dest_call D c
    val (_, p, _, q, _, _) = call
  in
    fold_product (derive_desc_aux thy tac c call) (get_measures D p) (get_measures D q) D
  end
krauss@29125
   232
krauss@29125
   233
(* Pass the calls, i.e. the sup-components of rel in subgoal i of shape
   _ (_ rel), together with i to tac. Succeeds at once when there are no
   subgoals, and fails on subgoals of another shape. *)
fun CALLS tac i st =
  let
    fun with_calls (_ $ (_ $ rel)) =
          tac (Function_Lib.dest_binop_list @{const_name Lattices.sup} rel, i) st
      | with_calls _ = no_tac st
  in
    if Thm.no_prems st then all_tac st
    else with_calls (Thm.term_of (Thm.cprem_of st i))
  end
krauss@29125
   238
krauss@29125
   239
(* Continuation-passing termination tactic: success continuation, error
   continuation, then the context and the subgoal index. *)
type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic

(* For a subgoal of shape _ (wf rel), build the initial context from rel
   and the first component type of wf's set-of-pairs argument, then
   continue with tac. *)
fun TERMINATION ctxt tac =
  let
    fun setup (_ $ (Const (@{const_name wf}, wfT) $ rel), i) =
      let
        val elemT = fst (HOLogic.dest_prodT (HOLogic.dest_setT (domain_type wfT)))
      in
        tac (create ctxt elemT rel) i
      end
  in
    SUBGOAL setup
  end
krauss@29125
   248
krauss@29125
   249
krauss@29125
   250
(* A tactic to convert open to closed termination goals *)
krauss@29125
   251
local
krauss@29125
   252
(* Split an open termination goal  !!vars. prems ==> (lhs, rhs) : _
   into (vars, prems, lhs, rhs). *)
fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *)
  let
    val (vars, prop) = Function_Lib.dest_all_all t
    val (prems, concl) = Logic.strip_horn prop
    val pair = fst (HOLogic.dest_mem (HOLogic.dest_Trueprop concl))
    val (lhs, rhs) = HOLogic.dest_prod pair
  in
    (vars, prems, lhs, rhs)
  end
krauss@29125
   263
krauss@29125
   264
(* Build  {x. EX qs. x = (l, r) & conds}  over pairs of type T * T. The
   variables qs are bound by existentials, and an empty condition list
   becomes True. *)
fun mk_pair_compr (T, qs, l, r, conds) =
  let
    val pT = HOLogic.mk_prodT (T, T)
    val peq = HOLogic.eq_const pT $ Bound (length qs) $ (HOLogic.pair_const T T $ l $ r)
    val body_conds = (case conds of [] => [HOLogic.true_const] | _ => conds)
    fun bind_ex v t = HOLogic.exists_const (fastype_of v) $ lambda v t
    val body = fold_rev bind_ex qs (foldr1 HOLogic.mk_conj (peq :: body_conds))
  in
    HOLogic.Collect_const pT $ Abs ("uu_", pT, body)
  end
krauss@29125
   276
krauss@29125
   277
in
krauss@29125
   278
wenzelm@30607
   279
(* Close an open termination goal. The first premise is  _ (_ rel) ; the
   remaining premises are open goals  (lhs, rhs) : rel . Instantiate rel
   with the union of one comprehension per open goal (or the empty set),
   leave subgoal 1 in place, and prove every membership subgoal. *)
fun wf_union_tac ctxt st =
  let
    val thy = ProofContext.theory_of ctxt
    val cert = cterm_of (theory_of_thm st)
    val ((_ $ (_ $ rel)) :: ineqs) = prems_of st

    (* comprehension describing a single open goal *)
    fun mk_compr ineq =
      let val (vars, prems, lhs, rhs) = dest_term ineq
      in mk_pair_compr (fastype_of lhs, vars, lhs, rhs, map (ObjectLogic.atomize_term thy) prems) end

    val relation =
      (case ineqs of
        [] => Const (@{const_name Set.empty}, fastype_of rel)
      | _ => foldr1 (HOLogic.mk_binop @{const_name Lattices.sup}) (map mk_compr ineqs))

    (* Subgoal i (i >= 2) concerns the (i - 1)-th component of the
       right-nested union. Step right i - 2 times, then try stepping left
       (this fails, and is skipped, for the last component). Then unfold
       the comprehension, turn the existentials into schematic variables,
       and let unification plus blast finish. REPEAT here is the standard
       tactical; the continuation-passing REPEAT is defined further down. *)
    fun solve_membership_tac i =
      let
        val pick_component =
          EVERY' (replicate (i - 2) (rtac @{thm UnI2})) THEN' (TRY o rtac @{thm UnI1})
        val discharge =
          rtac @{thm refl}
          ORELSE' (rtac @{thm conjI} THEN' rtac @{thm refl} THEN' blast_tac (claset_of ctxt))
      in
        (pick_component
         THEN' rtac @{thm CollectI}
         THEN' (fn j => REPEAT (rtac @{thm exI} j))
         THEN' discharge) i
      end
  in
    (PRIMITIVE (Drule.cterm_instantiate [(cert rel, cert relation)])
     THEN ALLGOALS (fn i => if i = 1 then all_tac else solve_membership_tac i)) st
  end
krauss@29125
   312
krauss@29125
   313
krauss@29125
   314
end
krauss@29125
   315
krauss@29125
   316
krauss@29125
   317
(* Continuation-passing repeat: run ttac, and on success run it again.
   Following the ttac convention (the second argument is the error
   continuation), a failure of the very first attempt goes to err_cont.
   Once ttac has succeeded, a later failure continues with cont. *)
fun REPEAT ttac cont err_cont =
  let
    fun loop on_error = ttac (fn D => fn i => loop cont D i) on_error
  in
    loop err_cont
  end
krauss@29125
   320
krauss@29125
   321
krauss@29125
   322
krauss@29125
   323
krauss@29125
   324
end