src/HOL/Tools/Function/termination.ML
author wenzelm
Sat Feb 27 21:56:55 2010 +0100 (2010-02-27)
changeset 35404 de56579ae229
parent 35402 115a5a95710a
child 35408 b48ab741683b
permissions -rw-r--r--
just one copy of structure Term_Graph (in Pure);
     1 (*  Title:       HOL/Tools/Function/termination.ML
     2     Author:      Alexander Krauss, TU Muenchen
     3 
     4 Context data for termination proofs
     5 *)
     6 
     7 
     8 signature TERMINATION =
     9 sig
    10 
    11   type data
    12   datatype cell = Less of thm | LessEq of (thm * thm) | None of (thm * thm) | False of thm
    13 
    14   val mk_sumcases : data -> typ -> term list -> term
    15 
    16   val get_num_points : data -> int
    17   val get_types      : data -> int -> typ
    18   val get_measures   : data -> int -> term list
    19 
    20   (* read from cache *)
    21   val get_chain      : data -> term -> term -> thm option option
    22   val get_descent    : data -> term -> term -> term -> cell option
    23 
    24   val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term)
    25 
    26   val CALLS : (term list * int -> tactic) -> int -> tactic
    27 
    28   (* Termination tactics. Sequential composition via continuations. (2nd argument is the error continuation) *)
    29   type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic
    30 
    31   val TERMINATION : Proof.context -> (data -> int -> tactic) -> int -> tactic
    32 
    33   val REPEAT : ttac -> ttac
    34 
    35   val wf_union_tac : Proof.context -> tactic
    36 
    37   val decompose_tac : Proof.context -> tactic -> ttac
    38 
    39   val derive_diag : Proof.context -> tactic -> 
    40     (data -> int -> tactic) -> data -> int -> tactic
    41 
    42   val derive_all  : Proof.context -> tactic ->
    43     (data -> int -> tactic) -> data -> int -> tactic
    44 
    45 end
    46 
    47 
    48 
    49 structure Termination : TERMINATION =
    50 struct
    51 
    52 open Function_Lib
    53 
    54 val term2_ord = prod_ord TermOrd.fast_term_ord TermOrd.fast_term_ord
    55 structure Term2tab = Table(type key = term * term val ord = term2_ord);
    56 structure Term3tab = Table(type key = term * (term * term) val ord = prod_ord TermOrd.fast_term_ord term2_ord);
    57 
    58 (** Analyzing binary trees **)
    59 
    60 (* Skeleton of a tree structure *)
    61 
    62 datatype skel =
    63   SLeaf of int (* index *)
    64 | SBranch of (skel * skel)
    65 
    66 
    67 (* abstract make and dest functions *)
    68 fun mk_tree leaf branch =
    69   let fun mk (SLeaf i) = leaf i
    70         | mk (SBranch (s, t)) = branch (mk s, mk t)
    71   in mk end
    72 
    73 
    74 fun dest_tree split =
    75   let fun dest (SLeaf i) x = [(i, x)]
    76         | dest (SBranch (s, t)) x =
    77           let val (l, r) = split x
    78           in dest s l @ dest t r end
    79   in dest end
    80 
    81 
    82 (* concrete versions for sum types *)
    83 fun is_inj (Const (@{const_name Sum_Type.Inl}, _) $ _) = true
    84   | is_inj (Const (@{const_name Sum_Type.Inr}, _) $ _) = true
    85   | is_inj _ = false
    86 
    87 fun dest_inl (Const (@{const_name Sum_Type.Inl}, _) $ t) = SOME t
    88   | dest_inl _ = NONE
    89 
    90 fun dest_inr (Const (@{const_name Sum_Type.Inr}, _) $ t) = SOME t
    91   | dest_inr _ = NONE
    92 
    93 
    94 fun mk_skel ps =
    95   let
    96     fun skel i ps =
    97       if forall is_inj ps andalso not (null ps)
    98       then let
    99           val (j, s) = skel i (map_filter dest_inl ps)
   100           val (k, t) = skel j (map_filter dest_inr ps)
   101         in (k, SBranch (s, t)) end
   102       else (i + 1, SLeaf i)
   103   in
   104     snd (skel 0 ps)
   105   end
   106 
   107 (* compute list of types for nodes *)
   108 fun node_types sk T = dest_tree (fn Type ("+", [LT, RT]) => (LT, RT)) sk T |> map snd
   109 
   110 (* find index and raw term *)
   111 fun dest_inj (SLeaf i) trm = (i, trm)
   112   | dest_inj (SBranch (s, t)) trm =
   113     case dest_inl trm of
   114       SOME trm' => dest_inj s trm'
   115     | _ => dest_inj t (the (dest_inr trm))
   116 
   117 
   118 
   119 (** Matrix cell datatype **)
   120 
   121 datatype cell = Less of thm | LessEq of (thm * thm) | None of (thm * thm) | False of thm;
   122 
   123 
   124 type data =
   125   skel                            (* structure of the sum type encoding "program points" *)
   126   * (int -> typ)                  (* types of program points *)
   127   * (term list Inttab.table)      (* measures for program points *)
   128   * (thm option Term2tab.table)   (* which calls form chains? *)
   129   * (cell Term3tab.table)         (* local descents *)
   130 
   131 
   132 fun map_chains f (p, T, M, C, D) = (p, T, M, f C, D)
   133 fun map_descent f (p, T, M, C, D) = (p, T, M, C, f D)
   134 
   135 fun note_chain c1 c2 res = map_chains (Term2tab.update ((c1, c2), res))
   136 fun note_descent c m1 m2 res = map_descent (Term3tab.update ((c,(m1, m2)), res))
   137 
   138 (* Build case expression *)
   139 fun mk_sumcases (sk, _, _, _, _) T fs =
   140   mk_tree (fn i => (nth fs i, domain_type (fastype_of (nth fs i))))
   141           (fn ((f, fT), (g, gT)) => (SumTree.mk_sumcase fT gT T f g, SumTree.mk_sumT fT gT))
   142           sk
   143   |> fst
   144 
   145 fun mk_sum_skel rel =
   146   let
   147     val cs = Function_Lib.dest_binop_list @{const_name Lattices.sup} rel
   148     fun collect_pats (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
   149       let
   150         val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ _)
   151           = Term.strip_qnt_body "Ex" c
   152       in cons r o cons l end
   153   in
   154     mk_skel (fold collect_pats cs [])
   155   end
   156 
   157 fun create ctxt T rel =
   158   let
   159     val sk = mk_sum_skel rel
   160     val Ts = node_types sk T
   161     val M = Inttab.make (map_index (apsnd (MeasureFunctions.get_measure_functions ctxt)) Ts)
   162   in
   163     (sk, nth Ts, M, Term2tab.empty, Term3tab.empty)
   164   end
   165 
   166 fun get_num_points (sk, _, _, _, _) =
   167   let
   168     fun num (SLeaf i) = i + 1
   169       | num (SBranch (s, t)) = num t
   170   in num sk end
   171 
   172 fun get_types (_, T, _, _, _) = T
   173 fun get_measures (_, _, M, _, _) = Inttab.lookup_list M
   174 
   175 fun get_chain (_, _, _, C, _) c1 c2 =
   176   Term2tab.lookup C (c1, c2)
   177 
   178 fun get_descent (_, _, _, _, D) c m1 m2 =
   179   Term3tab.lookup D (c, (m1, m2))
   180 
   181 fun dest_call D (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
   182   let
   183     val (sk, _, _, _, _) = D
   184     val vs = Term.strip_qnt_vars "Ex" c
   185 
   186     (* FIXME: throw error "dest_call" for malformed terms *)
   187     val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ Gam)
   188       = Term.strip_qnt_body "Ex" c
   189     val (p, l') = dest_inj sk l
   190     val (q, r') = dest_inj sk r
   191   in
   192     (vs, p, l', q, r', Gam)
   193   end
   194   | dest_call D t = error "dest_call"
   195 
   196 
   197 fun mk_desc thy tac vs Gam l r m1 m2 =
   198   let
   199     fun try rel =
   200       try_proof (cterm_of thy
   201         (Term.list_all (vs,
   202            Logic.mk_implies (HOLogic.mk_Trueprop Gam,
   203              HOLogic.mk_Trueprop (Const (rel, @{typ "nat => nat => bool"})
   204                $ (m2 $ r) $ (m1 $ l)))))) tac
   205   in
   206     case try @{const_name Orderings.less} of
   207        Solved thm => Less thm
   208      | Stuck thm =>
   209        (case try @{const_name Orderings.less_eq} of
   210           Solved thm2 => LessEq (thm2, thm)
   211         | Stuck thm2 =>
   212           if prems_of thm2 = [HOLogic.Trueprop $ HOLogic.false_const]
   213           then False thm2 else None (thm2, thm)
   214         | _ => raise Match) (* FIXME *)
   215      | _ => raise Match
   216 end
   217 
   218 fun derive_descent thy tac c m1 m2 D =
   219   case get_descent D c m1 m2 of
   220     SOME _ => D
   221   | NONE => 
   222     let
   223       val (vs, _, l, _, r, Gam) = dest_call D c
   224     in 
   225       note_descent c m1 m2 (mk_desc thy tac vs Gam l r m1 m2) D
   226     end
   227 
   228 fun CALLS tac i st =
   229   if Thm.no_prems st then all_tac st
   230   else case Thm.term_of (Thm.cprem_of st i) of
   231     (_ $ (_ $ rel)) => tac (Function_Lib.dest_binop_list @{const_name Lattices.sup} rel, i) st
   232   |_ => no_tac st
   233 
   234 type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic
   235 
   236 fun TERMINATION ctxt tac =
   237   SUBGOAL (fn (_ $ (Const (@{const_name wf}, wfT) $ rel), i) =>
   238   let
   239     val (T, _) = HOLogic.dest_prodT (HOLogic.dest_setT (domain_type wfT))
   240   in
   241     tac (create ctxt T rel) i
   242   end)
   243 
   244 
   245 (* A tactic to convert open to closed termination goals *)
   246 local
   247 fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *)
   248   let
   249     val (vars, prop) = Function_Lib.dest_all_all t
   250     val (prems, concl) = Logic.strip_horn prop
   251     val (lhs, rhs) = concl
   252       |> HOLogic.dest_Trueprop
   253       |> HOLogic.dest_mem |> fst
   254       |> HOLogic.dest_prod
   255   in
   256     (vars, prems, lhs, rhs)
   257   end
   258 
   259 fun mk_pair_compr (T, qs, l, r, conds) =
   260   let
   261     val pT = HOLogic.mk_prodT (T, T)
   262     val n = length qs
   263     val peq = HOLogic.eq_const pT $ Bound n $ (HOLogic.pair_const T T $ l $ r)
   264     val conds' = if null conds then [HOLogic.true_const] else conds
   265   in
   266     HOLogic.Collect_const pT $
   267     Abs ("uu_", pT,
   268       (foldr1 HOLogic.mk_conj (peq :: conds')
   269       |> fold_rev (fn v => fn t => HOLogic.exists_const (fastype_of v) $ lambda v t) qs))
   270   end
   271 
   272 in
   273 
   274 fun wf_union_tac ctxt st =
   275   let
   276     val thy = ProofContext.theory_of ctxt
   277     val cert = cterm_of (theory_of_thm st)
   278     val ((_ $ (_ $ rel)) :: ineqs) = prems_of st
   279 
   280     fun mk_compr ineq =
   281       let
   282         val (vars, prems, lhs, rhs) = dest_term ineq
   283       in
   284         mk_pair_compr (fastype_of lhs, vars, lhs, rhs, map (ObjectLogic.atomize_term thy) prems)
   285       end
   286 
   287     val relation =
   288       if null ineqs
   289       then Const (@{const_abbrev Set.empty}, fastype_of rel)
   290       else map mk_compr ineqs
   291         |> foldr1 (HOLogic.mk_binop @{const_name Lattices.sup})
   292 
   293     fun solve_membership_tac i =
   294       (EVERY' (replicate (i - 2) (rtac @{thm UnI2}))  (* pick the right component of the union *)
   295       THEN' (fn j => TRY (rtac @{thm UnI1} j))
   296       THEN' (rtac @{thm CollectI})                    (* unfold comprehension *)
   297       THEN' (fn i => REPEAT (rtac @{thm exI} i))      (* Turn existentials into schematic Vars *)
   298       THEN' ((rtac @{thm refl})                       (* unification instantiates all Vars *)
   299         ORELSE' ((rtac @{thm conjI})
   300           THEN' (rtac @{thm refl})
   301           THEN' (blast_tac (claset_of ctxt))))  (* Solve rest of context... not very elegant *)
   302       ) i
   303   in
   304     ((PRIMITIVE (Drule.cterm_instantiate [(cert rel, cert relation)])
   305      THEN ALLGOALS (fn i => if i = 1 then all_tac else solve_membership_tac i))) st
   306   end
   307 
   308 end
   309 
   310 
   311 (* continuation passing repeat combinator *)
   312 fun REPEAT ttac cont err_cont =
   313     ttac (fn D => fn i => (REPEAT ttac cont cont D i)) err_cont
   314 
   315 (*** DEPENDENCY GRAPHS ***)
   316 
   317 fun prove_chain thy chain_tac c1 c2 =
   318   let
   319     val goal =
   320       HOLogic.mk_eq (HOLogic.mk_binop @{const_name Relation.rel_comp} (c1, c2),
   321         Const (@{const_abbrev Set.empty}, fastype_of c1))
   322       |> HOLogic.mk_Trueprop (* "C1 O C2 = {}" *)
   323   in
   324     case Function_Lib.try_proof (cterm_of thy goal) chain_tac of
   325       Function_Lib.Solved thm => SOME thm
   326     | _ => NONE
   327   end
   328 
   329 fun derive_chains ctxt chain_tac cont D = CALLS (fn (cs, i) =>
   330   let
   331     val thy = ProofContext.theory_of ctxt
   332 
   333     fun derive_chain c1 c2 D =
   334       if is_some (get_chain D c1 c2) then D else
   335       note_chain c1 c2 (prove_chain thy chain_tac c1 c2) D
   336   in
   337     cont (fold_product derive_chain cs cs D) i
   338   end)
   339 
   340 
   341 fun mk_dgraph D cs =
   342   Term_Graph.empty
   343   |> fold (fn c => Term_Graph.new_node (c, ())) cs
   344   |> fold_product (fn c1 => fn c2 =>
   345      if is_none (get_chain D c1 c2 |> the_default NONE)
   346      then Term_Graph.add_edge (c1, c2) else I)
   347      cs cs
   348 
   349 fun ucomp_empty_tac T =
   350   REPEAT_ALL_NEW (rtac @{thm union_comp_emptyR}
   351     ORELSE' rtac @{thm union_comp_emptyL}
   352     ORELSE' SUBGOAL (fn (_ $ (_ $ (_ $ c1 $ c2) $ _), i) => rtac (T c1 c2) i))
   353 
   354 fun regroup_calls_tac cs = CALLS (fn (cs', i) =>
   355  let
   356    val is = map (fn c => find_index (curry op aconv c) cs') cs
   357  in
   358    CONVERSION (Conv.arg_conv (Conv.arg_conv
   359      (Function_Lib.regroup_union_conv is))) i
   360  end)
   361 
   362 
   363 fun solve_trivial_tac D = CALLS (fn ([c], i) =>
   364   (case get_chain D c c of
   365      SOME (SOME thm) => rtac @{thm wf_no_loop} i
   366                         THEN rtac thm i
   367    | _ => no_tac)
   368   | _ => no_tac)
   369 
   370 fun decompose_tac' cont err_cont D = CALLS (fn (cs, i) =>
   371   let
   372     val G = mk_dgraph D cs
   373     val sccs = Term_Graph.strong_conn G
   374 
   375     fun split [SCC] i = (solve_trivial_tac D i ORELSE cont D i)
   376       | split (SCC::rest) i =
   377         regroup_calls_tac SCC i
   378         THEN rtac @{thm wf_union_compatible} i
   379         THEN rtac @{thm less_by_empty} (i + 2)
   380         THEN ucomp_empty_tac (the o the oo get_chain D) (i + 2)
   381         THEN split rest (i + 1)
   382         THEN (solve_trivial_tac D i ORELSE cont D i)
   383   in
   384     if length sccs > 1 then split sccs i
   385     else solve_trivial_tac D i ORELSE err_cont D i
   386   end)
   387 
   388 fun decompose_tac ctxt chain_tac cont err_cont =
   389   derive_chains ctxt chain_tac (decompose_tac' cont err_cont)
   390 
   391 
   392 (*** Local Descent Proofs ***)
   393 
   394 fun gen_descent diag ctxt tac cont D = CALLS (fn (cs, i) =>
   395   let
   396     val thy = ProofContext.theory_of ctxt
   397     val measures_of = get_measures D
   398 
   399     fun derive c D =
   400       let
   401         val (_, p, _, q, _, _) = dest_call D c
   402       in
   403         if diag andalso p = q
   404         then fold (fn m => derive_descent thy tac c m m) (measures_of p) D
   405         else fold_product (derive_descent thy tac c)
   406                (measures_of p) (measures_of q) D
   407       end
   408   in
   409     cont (Function_Common.PROFILE "deriving descents" (fold derive cs) D) i
   410   end)
   411 
   412 fun derive_diag ctxt = gen_descent true ctxt
   413 fun derive_all ctxt = gen_descent false ctxt
   414 
   415 
   416 end