src/HOL/Tools/Function/termination.ML
author wenzelm
Fri Mar 06 15:58:56 2015 +0100 (2015-03-06)
changeset 59621 291934bac95e
parent 59618 e6939796717e
child 59625 aacdce52b2fc
permissions -rw-r--r--
Thm.cterm_of and Thm.ctyp_of operate on local context;
haftmann@31775
     1
(*  Title:       HOL/Tools/Function/termination.ML
krauss@29125
     2
    Author:      Alexander Krauss, TU Muenchen
krauss@29125
     3
krauss@41114
     4
Context data for termination proofs.
krauss@29125
     5
*)
krauss@29125
     6
krauss@29125
     7
signature TERMINATION =
sig
  (* context data of a termination proof *)
  type data

  (* result of comparing two measures along one call *)
  datatype cell =
      Less of thm
    | LessEq of thm * thm
    | None of thm * thm
    | False of thm

  (* build a case expression over the sum type encoding the program points *)
  val mk_sumcases : data -> typ -> term list -> term

  (* accessors for the components of the context data *)
  val get_num_points : data -> int
  val get_types      : data -> int -> typ
  val get_measures   : data -> int -> term list

  (* cached proof results *)
  val get_chain      : data -> term -> term -> thm option option
  val get_descent    : data -> term -> term -> term -> cell option

  (* take a call apart: quantified variables, two (index, term) pairs, condition *)
  val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term)

  (* apply a tactic to the list of calls found in a subgoal *)
  val CALLS : (term list * int -> tactic) -> int -> tactic

  (* Termination tactics *)
  type ttac = data -> int -> tactic

  val TERMINATION : Proof.context -> tactic -> ttac -> int -> tactic

  val wf_union_tac : Proof.context -> tactic

  val decompose_tac : Proof.context -> ttac
end
krauss@29125
    34
krauss@29125
    35
krauss@29125
    36
krauss@29125
    37
structure Termination : TERMINATION =
krauss@29125
    38
struct
krauss@29125
    39
krauss@33099
    40
open Function_Lib
krauss@29125
    41
wenzelm@35408
    42
(* ordering on pairs of terms, built from the fast term ordering *)
val term2_ord =
  let val ord = Term_Ord.fast_term_ord
  in prod_ord ord ord end

(* tables keyed by a pair of terms *)
structure Term2tab = Table(type key = term * term val ord = term2_ord);

(* tables keyed by a term together with a pair of terms *)
structure Term3tab =
  Table(
    type key = term * (term * term)
    val ord = prod_ord Term_Ord.fast_term_ord term2_ord);
krauss@29125
    46
krauss@29125
    47
(** Analyzing binary trees **)
krauss@29125
    48
krauss@29125
    49
(* Skeleton of a tree structure *)
krauss@29125
    50
krauss@29125
    51
(* Shape of a binary tree: leaves carry an index, inner nodes have two subtrees. *)
datatype skel =
    SLeaf of int              (* index *)
  | SBranch of skel * skel
krauss@29125
    54
krauss@29125
    55
krauss@29125
    56
(* abstract make and dest functions *)
krauss@29125
    57
(* Fold over a skeleton: "leaf" maps a leaf index to a value, "branch"
   combines the values of the left and right subtree (left is computed first). *)
fun mk_tree leaf branch sk =
  (case sk of
    SLeaf idx => leaf idx
  | SBranch (left, right) =>
      branch (mk_tree leaf branch left, mk_tree leaf branch right))
krauss@29125
    61
krauss@29125
    62
krauss@29125
    63
(* Take a value apart along a skeleton: "split" divides a value into its
   left and right part at every branch; the result lists (index, part)
   for all leaves, from left to right. *)
fun dest_tree split sk x =
  (case sk of
    SLeaf idx => [(idx, x)]
  | SBranch (left, right) =>
      let val (xl, xr) = split x
      in dest_tree split left xl @ dest_tree split right xr end)
krauss@29125
    69
krauss@29125
    70
krauss@29125
    71
(* concrete versions for sum types *)
haftmann@32602
    72
(* Is the term an application of one of the sum injections Inl / Inr? *)
fun is_inj (Const (c, _) $ _) =
      c = @{const_name Sum_Type.Inl} orelse c = @{const_name Sum_Type.Inr}
  | is_inj _ = false
krauss@29125
    75
haftmann@32602
    76
(* Argument of a left injection, or NONE for any other term. *)
fun dest_inl trm =
  (case trm of
    Const (@{const_name Sum_Type.Inl}, _) $ arg => SOME arg
  | _ => NONE)
krauss@29125
    78
haftmann@32602
    79
(* Argument of a right injection, or NONE for any other term. *)
fun dest_inr trm =
  (case trm of
    Const (@{const_name Sum_Type.Inr}, _) $ arg => SOME arg
  | _ => NONE)
krauss@29125
    81
krauss@29125
    82
krauss@29125
    83
(* Compute the skeleton described by a list of patterns.  A non-empty list
   consisting only of injections yields a branch whose subtrees are built
   from the Inl arguments and the Inr arguments respectively; anything else
   (including the empty list) yields a leaf.  Leaves are numbered from 0 in
   left-to-right order. *)
fun mk_skel ps =
  let
    (* "next" is the first unused leaf index; returns the updated counter *)
    fun build next [] = (next + 1, SLeaf next)
      | build next pats =
          if exists (not o is_inj) pats then (next + 1, SLeaf next)
          else
            let
              val (next', left) = build next (map_filter dest_inl pats)
              val (next'', right) = build next' (map_filter dest_inr pats)
            in (next'', SBranch (left, right)) end
  in #2 (build 0 ps) end
krauss@29125
    95
krauss@29125
    96
(* compute list of types for nodes *)
haftmann@37678
    97
(* Types of the leaves of skeleton sk within the nested sum type T, from
   left to right.  Raises Match if a branch meets a type that is not a
   binary sum. *)
fun node_types sk T =
  let
    fun dest_sumT (Type (@{type_name Sum_Type.sum}, [LT, RT])) = (LT, RT)
  in map snd (dest_tree dest_sumT sk T) end
krauss@29125
    98
krauss@29125
    99
(* find index and raw term *)
krauss@29125
   100
(* Follow the injections of trm down the skeleton; returns the index of the
   leaf that is reached together with the remaining term.  At a branch the
   term must be an Inl or Inr application, otherwise "the" fails. *)
fun dest_inj (SLeaf i) trm = (i, trm)
  | dest_inj (SBranch (s, t)) trm =
      (case dest_inl trm of
        NONE => dest_inj t (the (dest_inr trm))
      | SOME trm' => dest_inj s trm')
krauss@29125
   105
krauss@29125
   106
krauss@29125
   107
krauss@29125
   108
(** Matrix cell datatype **)
krauss@29125
   109
krauss@39924
   110
(* Outcome of trying to prove a descent for one call and one pair of measures. *)
datatype cell =
    Less of thm           (* proof of the strict inequality *)
  | LessEq of thm * thm   (* proof of the non-strict inequality, stuck state of the strict attempt *)
  | None of thm * thm     (* stuck states of the non-strict and of the strict attempt *)
  | False of thm;         (* stuck non-strict attempt whose only premise is False *)
krauss@29125
   111
krauss@29125
   112
krauss@29125
   113
(* Context data of a termination proof, as a five-tuple. *)
type data =
  skel *                              (* structure of the sum type encoding "program points" *)
  (int -> typ) *                      (* types of program points *)
  (term list Inttab.table) *          (* measures for program points *)
  (term * term -> thm option) *       (* which calls form chains? (cached) *)
  (term * (term * term) -> cell)      (* local descents (cached) *)
krauss@29125
   119
krauss@29125
   120
krauss@29125
   121
(* Build case expression *)
krauss@29125
   122
(* Build a nested sum-case expression with result type T along the skeleton
   of the data: leaf i contributes the function "nth fs i", and every branch
   combines its two subtrees with Sum_Tree.mk_sumcase. *)
fun mk_sumcases (sk, _, _, _, _) T fs =
  let
    (* a function paired with its domain type *)
    fun leaf i =
      let val f = nth fs i
      in (f, domain_type (fastype_of f)) end

    fun branch ((f, fT), (g, gT)) =
      (Sum_Tree.mk_sumcase fT gT T f g, Sum_Tree.mk_sumT fT gT)
  in fst (mk_tree leaf branch sk) end
krauss@29125
   127
krauss@29125
   128
(* Compute the skeleton of the sum type used by relation rel.  The relation
   is split into the components of a "sup" chain; each component has to be a
   set comprehension whose body, below the existential quantifiers, is a
   conjunction starting with an equation against a pair.  Both components of
   all these pairs are the patterns handed to mk_skel. *)
fun mk_sum_skel rel =
  let
    fun add_pats (Const (@{const_name Collect}, _) $ Abs (_, _, body)) pats =
      let
        val (Const (@{const_name HOL.conj}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ (Const (@{const_name Pair}, _) $ r $ l)) $ _) =
          Term.strip_qnt_body @{const_name Ex} body
      in r :: l :: pats end

    val calls = Function_Lib.dest_binop_list @{const_name Lattices.sup} rel
  in mk_skel (fold add_pats calls []) end
krauss@29125
   139
wenzelm@59618
   140
(* Try to prove with chain_tac that the relational composition of c1 and c2
   is empty.  Returns SOME of the theorem if the proof attempt is solved,
   NONE for any other outcome. *)
fun prove_chain ctxt chain_tac (c1, c2) =
  let
    val comp = HOLogic.mk_binop @{const_name Relation.relcomp} (c1, c2)
    val empty = Const (@{const_abbrev Set.empty}, fastype_of c1)
    (* "C1 O C2 = {}" *)
    val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (comp, empty))
  in
    (case Function_Lib.try_proof (Thm.cterm_of ctxt goal) chain_tac of
      Function_Lib.Solved thm => SOME thm
    | _ => NONE)
  end
krauss@29125
   151
krauss@39923
   152
krauss@39923
   153
(* Take a call relation apart.

   The term has to be a set comprehension whose body, below a prefix of
   existential quantifiers, is a conjunction of an equation against a pair
   "Pair r l" and a condition Gam.

   Result: (vs, p, l', q, r', Gam) where vs are the existentially quantified
   variables, (p, l') is the leaf index and remaining term of l with respect
   to skeleton sk, (q, r') the same for r, and Gam is the condition.

   Any term that does not have the expected comprehension shape is reported
   via error "dest_call" -- this covers a malformed comprehension body as
   well, which previously escaped as a Bind exception.  Failures of dest_inj
   (a pattern that does not fit the skeleton) are not intercepted here. *)
fun dest_call' sk (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
      let
        val vs = Term.strip_qnt_vars @{const_name Ex} c
      in
        (case Term.strip_qnt_body @{const_name Ex} c of
          Const (@{const_name HOL.conj}, _) $
            (Const (@{const_name HOL.eq}, _) $ _ $ (Const (@{const_name Pair}, _) $ r $ l)) $ Gam =>
            let
              val (p, l') = dest_inj sk l
              val (q, r') = dest_inj sk r
            in
              (vs, p, l', q, r', Gam)
            end
        | _ => error "dest_call")
      end
  | dest_call' _ _ = error "dest_call"
krauss@29125
   166
krauss@39923
   167
(* Take a call apart with respect to the skeleton stored in the data. *)
fun dest_call (D: data) = dest_call' (#1 D)
krauss@29125
   168
wenzelm@59618
   169
(* Classify the descent of one call for the measures m1 (applied to l) and
   m2 (applied to r).  For a relation rel on nat the goal is

     !!vs. Gam ==> rel (m2 r) (m1 l)

   and tac is tried on it, first with "less", then with "less_eq":
     - "less" solved               => Less
     - "less" stuck, "less_eq" solved => LessEq
     - both stuck                  => False if the premises of the stuck
                                      "less_eq" state are exactly [False],
                                      None otherwise
   Every other outcome of try_proof raises Match. *)
fun mk_desc ctxt tac vs Gam l r m1 m2 =
  let
    val lhs = m2 $ r
    val rhs = m1 $ l

    fun attempt rel =
      let
        val concl =
          HOLogic.mk_Trueprop (Const (rel, @{typ "nat => nat => bool"}) $ lhs $ rhs)
        val goal =
          Logic.list_all (vs, Logic.mk_implies (HOLogic.mk_Trueprop Gam, concl))
      in try_proof (Thm.cterm_of ctxt goal) tac end
  in
    (case attempt @{const_name Orderings.less} of
      Solved less_thm => Less less_thm
    | Stuck less_st =>
        (case attempt @{const_name Orderings.less_eq} of
          Solved le_thm => LessEq (le_thm, less_st)
        | Stuck le_st =>
            if Thm.prems_of le_st = [HOLogic.Trueprop $ @{term False}]
            then False le_st
            else None (le_st, less_st)
        | _ => raise Match) (* FIXME *)
    | _ => raise Match)
  end
krauss@34229
   189
wenzelm@59618
   190
(* Classify the descent of call c for the measure pair (m1, m2): the call is
   taken apart with dest_call' and the pieces are passed on to mk_desc. *)
fun prove_descent ctxt tac sk (c, (m1, m2)) =
  (case dest_call' sk c of
    (vs, _, l, _, r, Gam) => mk_desc ctxt tac vs Gam l r m1 m2)
krauss@39923
   196
krauss@39923
   197
(* Set up the context data for relation rel over the sum type T:
   the skeleton of rel, the types of its program points, the measure
   functions for every point, and two caches in front of prove_chain
   (using chain_tac) and prove_descent (using descent_tac). *)
fun create ctxt chain_tac descent_tac T rel =
  let
    val sk = mk_sum_skel rel
    val point_types = node_types sk T

    fun type_of_point i = nth point_types i

    (* point index |-> measure functions for the type of that point *)
    val measures =
      point_types
      |> map_index (fn (i, pT) => (i, Measure_Functions.get_measure_functions ctxt pT))
      |> Inttab.make

    val chains =
      Cache.create Term2tab.empty Term2tab.lookup Term2tab.update
        (prove_chain ctxt chain_tac)

    val descents =
      Cache.create Term3tab.empty Term3tab.lookup Term3tab.update
        (prove_descent ctxt descent_tac sk)
  in
    (sk, type_of_point, measures, chains, descents)
  end
krauss@39923
   211
krauss@39923
   212
(* Number of program points: one more than the index of the rightmost leaf
   of the skeleton. *)
fun get_num_points (sk, _, _, _, _) =
  let
    fun rightmost (SBranch (_, t)) = rightmost t
      | rightmost (SLeaf i) = i
  in rightmost sk + 1 end
krauss@39923
   217
krauss@39923
   218
(* Second component of the data: maps a program point to its type. *)
fun get_types (D: data) = #2 D
krauss@39923
   219
(* Measures stored for a program point, looked up in the third component
   of the data via Inttab.lookup_list. *)
fun get_measures (D: data) = Inttab.lookup_list (#3 D)
krauss@39923
   220
krauss@39923
   221
(* Query the chain component of the data for the pair of calls (c1, c2);
   the result is always wrapped in SOME. *)
fun get_chain ((_, _, _, chains, _): data) c1 c2 =
  (c1, c2) |> chains |> SOME
krauss@39923
   223
krauss@39923
   224
(* Query the descent component of the data for call c and the measures
   m1, m2; the result is always wrapped in SOME. *)
fun get_descent ((_, _, _, _, descents): data) c m1 m2 =
  (c, (m1, m2)) |> descents |> SOME
krauss@39923
   226
krauss@29125
   227
(* Tactical that hands the calls of subgoal i to tac.  A state without
   premises is left unchanged.  Otherwise premise i has to be of the form
   "_ $ (_ $ rel)"; rel is split into the components of its "sup" chain and
   tac receives that list together with i.  Any other premise fails. *)
fun CALLS tac i st =
  if Thm.no_prems st then all_tac st
  else
    (case Thm.term_of (Thm.cprem_of st i) of
      _ $ (_ $ rel) =>
        let val calls = Function_Lib.dest_binop_list @{const_name Lattices.sup} rel
        in tac (calls, i) st end
    | _ => no_tac st)
krauss@29125
   232
krauss@39924
   233
(* termination tactic: a tactic on a subgoal that may consult the context data *)
type ttac = data -> int -> tactic
krauss@29125
   234
krauss@39923
   235
(* Run termination tactic tac on a subgoal of the form "_ $ (wf $ rel)".
   The element type of the relation is read off the type of the wf constant,
   and the context data is created for it with atac serving both as chain
   tactic and as descent tactic.  Subgoals of another shape raise Match. *)
fun TERMINATION ctxt atac tac =
  let
    fun wf_goal_tac (_ $ (Const (@{const_name wf}, wfT) $ rel), i) =
      let
        val T =
          wfT |> domain_type |> HOLogic.dest_setT |> HOLogic.dest_prodT |> fst
      in tac (create ctxt atac atac T rel) i end
  in SUBGOAL wf_goal_tac end
krauss@29125
   242
krauss@29125
   243
krauss@29125
   244
(* A tactic to convert open to closed termination goals *)
krauss@29125
   245
local
krauss@29125
   246
(* Take a membership goal apart: strips the outer quantifiers and the
   premises, then reads the two components of the pair on the left of the
   membership in the conclusion.  Returns (vars, prems, lhs, rhs). *)
fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *)
  let
    val (vars, prop) = Function_Lib.dest_all_all t
    val (prems, concl) = Logic.strip_horn prop
    val (pair, _) = HOLogic.dest_mem (HOLogic.dest_Trueprop concl)
    val (lhs, rhs) = HOLogic.dest_prod pair
  in (vars, prems, lhs, rhs) end
krauss@29125
   257
krauss@29125
   258
(* Build the set comprehension

     Collect (%uu_. EX qs. uu_ = (l, r) & conds)

   over pairs of type T * T.  An empty list of conditions is replaced by the
   single condition True. *)
fun mk_pair_compr (T, qs, l, r, conds) =
  let
    val pT = HOLogic.mk_prodT (T, T)
    val pair = HOLogic.pair_const T T $ l $ r
    (* the comprehension variable sits below all quantifiers for qs *)
    val peq = HOLogic.eq_const pT $ Bound (length qs) $ pair
    val body =
      foldr1 HOLogic.mk_conj (peq :: (if null conds then [@{term True}] else conds))

    fun ex v t = HOLogic.exists_const (fastype_of v) $ lambda v t
  in
    HOLogic.Collect_const pT $ Abs ("uu_", pT, fold_rev ex qs body)
  end
krauss@29125
   270
krauss@52723
   271
(* the facts Un_ac and Un_absorb, turned into meta-equations *)
val Un_aci_simps =
  @{thms Un_ac Un_absorb} |> map mk_meta_eq
krauss@52723
   273
krauss@29125
   274
in
krauss@29125
   275
wenzelm@56231
   276
(* Instantiate the relation of the first premise with an explicit union of
   set comprehensions, one for each of the remaining premises (or with the
   empty set if there are none), then solve the resulting membership goals
   and finally rewrite goal 1 with Un_aci_simps. *)
fun wf_union_tac ctxt st = SUBGOAL (fn _ =>
  let
    val thy = Proof_Context.theory_of ctxt
    val ((_ $ (_ $ rel)) :: ineqs) = Thm.prems_of st

    (* comprehension describing one premise *)
    fun mk_compr ineq =
      let
        val (vars, prems, lhs, rhs) = dest_term ineq
        val T = fastype_of lhs
        val conds = map (Object_Logic.atomize_term thy) prems
      in mk_pair_compr (T, vars, lhs, rhs, conds) end

    val relation =
      (case ineqs of
        [] => Const (@{const_abbrev Set.empty}, fastype_of rel)
      | _ => foldr1 (HOLogic.mk_binop @{const_name Lattices.sup}) (map mk_compr ineqs))

    fun solve_membership_tac i =
      let
        (* pick the right component of the union *)
        val pick_component = EVERY' (replicate (i - 2) (rtac @{thm UnI2}))
        (* unification instantiates all Vars; otherwise solve the rest of
           the context with blast ... not very elegant *)
        val close_tac =
          rtac @{thm refl}
          ORELSE' (rtac @{thm conjI} THEN' rtac @{thm refl} THEN' blast_tac ctxt)
      in
        (pick_component
          THEN' (fn j => TRY (rtac @{thm UnI1} j))
          THEN' rtac @{thm CollectI}                  (* unfold comprehension *)
          THEN' (fn j => REPEAT (rtac @{thm exI} j))  (* Turn existentials into schematic Vars *)
          THEN' close_tac) i
      end

    val inst = [apply2 (Thm.cterm_of ctxt) (rel, relation)]
  in
    PRIMITIVE (Drule.cterm_instantiate inst)
    THEN ALLGOALS (fn i => if i = 1 then all_tac else solve_membership_tac i)
    THEN rewrite_goal_tac ctxt Un_aci_simps 1  (* eliminate duplicates *)
  end) 1 st
krauss@29125
   309
krauss@29125
   310
end
krauss@29125
   311
krauss@29125
   312
krauss@29125
   313
krauss@34228
   314
(*** DEPENDENCY GRAPHS ***)
krauss@29125
   315
krauss@34228
   316
(* Graph with one node per call in cs.  For every ordered pair (c1, c2) an
   edge from c2 to c1 is added unless get_chain yields a theorem for the
   pair. *)
fun mk_dgraph D cs =
  let
    fun add_call c = Term_Graph.new_node (c, ())

    fun add_dep c1 c2 =
      (case get_chain D c1 c2 of
        SOME (SOME _) => I
      | _ => Term_Graph.add_edge (c2, c1))
  in
    Term_Graph.empty
    |> fold add_call cs
    |> fold_product add_dep cs cs
  end
krauss@34228
   323
krauss@34228
   324
(* Repeatedly apply the rules union_comp_emptyR and union_comp_emptyL; when
   neither applies, read the two relations c1, c2 off the subgoal and apply
   the rule "T c1 c2". *)
fun ucomp_empty_tac T =
  let
    fun chain_tac (_ $ (_ $ (_ $ c1 $ c2) $ _), i) = rtac (T c1 c2) i
  in
    REPEAT_ALL_NEW
      (rtac @{thm union_comp_emptyR}
        ORELSE' rtac @{thm union_comp_emptyL}
        ORELSE' SUBGOAL chain_tac)
  end
krauss@34228
   328
wenzelm@59618
   329
(* Regroup the union of calls in a subgoal: every call of cs is located
   (up to aconv) in the list of calls of the goal, and the resulting indices
   are passed to Function_Lib.regroup_union_conv, which is applied to the
   argument of the argument of the goal.  The context parameter is not used
   in this body. *)
fun regroup_calls_tac ctxt cs =
  let
    fun regroup (cs', i) =
      let
        fun position c = find_index (fn c' => c aconv c') cs'
        val conv = Function_Lib.regroup_union_conv (map position cs)
      in CONVERSION (Conv.arg_conv (Conv.arg_conv conv)) i end
  in CALLS regroup end
krauss@34228
   336
krauss@34228
   337
krauss@34232
   338
(* Solve a goal consisting of a single call c for which get_chain yields a
   theorem for the pair (c, c): apply wf_no_loop and then that theorem.
   Fails in every other situation. *)
fun solve_trivial_tac D =
  let
    fun no_loop_tac ([c], i) =
          (case get_chain D c c of
            SOME (SOME thm) => rtac @{thm wf_no_loop} i THEN rtac thm i
          | _ => no_tac)
      | no_loop_tac _ = no_tac
  in CALLS no_loop_tac end
krauss@34228
   344
wenzelm@59618
   345
(* Decompose a termination goal along the strongly connected components of
   the graph built by mk_dgraph.  With more than one component the
   components are split off one after the other using wf_union_compatible;
   with at most one component only solve_trivial_tac is tried. *)
fun decompose_tac ctxt D = CALLS (fn (cs, i) =>
  let
    val sccs = Term_Graph.strong_conn (mk_dgraph D cs)

    (* chain theorem for two calls; fails if there is none *)
    val chain_thm = the o the oo get_chain D

    fun split [_] j = TRY (solve_trivial_tac D j)
      | split (scc :: rest) j =
          regroup_calls_tac ctxt scc j
          THEN rtac @{thm wf_union_compatible} j
          THEN rtac @{thm less_by_empty} (j + 2)
          THEN ucomp_empty_tac chain_thm (j + 2)
          THEN split rest (j + 1)
          THEN TRY (solve_trivial_tac D j)
  in
    (case sccs of
      _ :: _ :: _ => split sccs i
    | _ => solve_trivial_tac D i)
  end)
krauss@34228
   362
krauss@29125
   363
krauss@29125
   364
end