src/HOL/Tools/Function/termination.ML
author griff
Tue Apr 03 17:26:30 2012 +0900 (2012-04-03)
changeset 47433 07f4bf913230
parent 46218 ecf6375e2abb
child 47835 2d48bf79b725
permissions -rw-r--r--
renamed "rel_comp" to "relcomp" (to be consistent with, e.g., "relpow")
     1 (*  Title:       HOL/Tools/Function/termination.ML
     2     Author:      Alexander Krauss, TU Muenchen
     3 
     4 Context data for termination proofs.
     5 *)
     6 
     7 signature TERMINATION =
     8 sig
     9 
    10   type data
    11   datatype cell = Less of thm | LessEq of thm * thm | None of thm * thm | False of thm
    12 
    13   val mk_sumcases : data -> typ -> term list -> term
    14 
    15   val get_num_points : data -> int
    16   val get_types      : data -> int -> typ
    17   val get_measures   : data -> int -> term list
    18 
    19   val get_chain      : data -> term -> term -> thm option option
    20   val get_descent    : data -> term -> term -> term -> cell option
    21 
    22   val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term)
    23 
    24   val CALLS : (term list * int -> tactic) -> int -> tactic
    25 
    26   (* Termination tactics *)
    27   type ttac = data -> int -> tactic
    28 
    29   val TERMINATION : Proof.context -> tactic -> ttac -> int -> tactic
    30 
    31   val wf_union_tac : Proof.context -> tactic
    32 
    33   val decompose_tac : ttac
    34 end
    35 
    36 
    37 
    38 structure Termination : TERMINATION =
    39 struct
    40 
    41 open Function_Lib
    42 
    43 val term2_ord = prod_ord Term_Ord.fast_term_ord Term_Ord.fast_term_ord
    44 structure Term2tab = Table(type key = term * term val ord = term2_ord);
    45 structure Term3tab =
    46   Table(type key = term * (term * term) val ord = prod_ord Term_Ord.fast_term_ord term2_ord);
    47 
    48 (** Analyzing binary trees **)
    49 
    50 (* Skeleton of a tree structure *)
    51 
    52 datatype skel =
    53   SLeaf of int (* index *)
    54 | SBranch of (skel * skel)
    55 
    56 
    57 (* abstract make and dest functions *)
    58 fun mk_tree leaf branch =
    59   let fun mk (SLeaf i) = leaf i
    60         | mk (SBranch (s, t)) = branch (mk s, mk t)
    61   in mk end
    62 
    63 
    64 fun dest_tree split =
    65   let fun dest (SLeaf i) x = [(i, x)]
    66         | dest (SBranch (s, t)) x =
    67           let val (l, r) = split x
    68           in dest s l @ dest t r end
    69   in dest end
    70 
    71 
    72 (* concrete versions for sum types *)
    73 fun is_inj (Const (@{const_name Sum_Type.Inl}, _) $ _) = true
    74   | is_inj (Const (@{const_name Sum_Type.Inr}, _) $ _) = true
    75   | is_inj _ = false
    76 
    77 fun dest_inl (Const (@{const_name Sum_Type.Inl}, _) $ t) = SOME t
    78   | dest_inl _ = NONE
    79 
    80 fun dest_inr (Const (@{const_name Sum_Type.Inr}, _) $ t) = SOME t
    81   | dest_inr _ = NONE
    82 
    83 
    84 fun mk_skel ps =
    85   let
    86     fun skel i ps =
    87       if forall is_inj ps andalso not (null ps)
    88       then let
    89           val (j, s) = skel i (map_filter dest_inl ps)
    90           val (k, t) = skel j (map_filter dest_inr ps)
    91         in (k, SBranch (s, t)) end
    92       else (i + 1, SLeaf i)
    93   in
    94     snd (skel 0 ps)
    95   end
    96 
    97 (* compute list of types for nodes *)
    98 fun node_types sk T = dest_tree (fn Type (@{type_name Sum_Type.sum}, [LT, RT]) => (LT, RT)) sk T |> map snd
    99 
   100 (* find index and raw term *)
   101 fun dest_inj (SLeaf i) trm = (i, trm)
   102   | dest_inj (SBranch (s, t)) trm =
   103     case dest_inl trm of
   104       SOME trm' => dest_inj s trm'
   105     | _ => dest_inj t (the (dest_inr trm))
   106 
   107 
   108 
   109 (** Matrix cell datatype **)
   110 
   111 datatype cell = Less of thm | LessEq of thm * thm | None of thm * thm | False of thm;
   112 
   113 
   114 type data =
   115   skel                            (* structure of the sum type encoding "program points" *)
   116   * (int -> typ)                  (* types of program points *)
   117   * (term list Inttab.table)      (* measures for program points *)
   118   * (term * term -> thm option)   (* which calls form chains? (cached) *)
   119   * (term * (term * term) -> cell)(* local descents (cached) *)
   120 
   121 
   122 (* Build case expression *)
   123 fun mk_sumcases (sk, _, _, _, _) T fs =
   124   mk_tree (fn i => (nth fs i, domain_type (fastype_of (nth fs i))))
   125           (fn ((f, fT), (g, gT)) => (SumTree.mk_sumcase fT gT T f g, SumTree.mk_sumT fT gT))
   126           sk
   127   |> fst
   128 
   129 fun mk_sum_skel rel =
   130   let
   131     val cs = Function_Lib.dest_binop_list @{const_name Lattices.sup} rel
   132     fun collect_pats (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
   133       let
   134         val (Const (@{const_name HOL.conj}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ (Const (@{const_name Pair}, _) $ r $ l)) $ _)
   135           = Term.strip_qnt_body @{const_name Ex} c
   136       in cons r o cons l end
   137   in
   138     mk_skel (fold collect_pats cs [])
   139   end
   140 
   141 fun prove_chain thy chain_tac (c1, c2) =
   142   let
   143     val goal =
   144       HOLogic.mk_eq (HOLogic.mk_binop @{const_name Relation.relcomp} (c1, c2),
   145         Const (@{const_abbrev Set.empty}, fastype_of c1))
   146       |> HOLogic.mk_Trueprop (* "C1 O C2 = {}" *)
   147   in
   148     case Function_Lib.try_proof (cterm_of thy goal) chain_tac of
   149       Function_Lib.Solved thm => SOME thm
   150     | _ => NONE
   151   end
   152 
   153 
   154 fun dest_call' sk (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
   155   let
   156     val vs = Term.strip_qnt_vars @{const_name Ex} c
   157 
   158     (* FIXME: throw error "dest_call" for malformed terms *)
   159     val (Const (@{const_name HOL.conj}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ (Const (@{const_name Pair}, _) $ r $ l)) $ Gam)
   160       = Term.strip_qnt_body @{const_name Ex} c
   161     val (p, l') = dest_inj sk l
   162     val (q, r') = dest_inj sk r
   163   in
   164     (vs, p, l', q, r', Gam)
   165   end
   166   | dest_call' _ _ = error "dest_call"
   167 
   168 fun dest_call (sk, _, _, _, _) = dest_call' sk
   169 
   170 fun mk_desc thy tac vs Gam l r m1 m2 =
   171   let
   172     fun try rel =
   173       try_proof (cterm_of thy
   174         (Logic.list_all (vs,
   175            Logic.mk_implies (HOLogic.mk_Trueprop Gam,
   176              HOLogic.mk_Trueprop (Const (rel, @{typ "nat => nat => bool"})
   177                $ (m2 $ r) $ (m1 $ l)))))) tac
   178   in
   179     case try @{const_name Orderings.less} of
   180        Solved thm => Less thm
   181      | Stuck thm =>
   182        (case try @{const_name Orderings.less_eq} of
   183           Solved thm2 => LessEq (thm2, thm)
   184         | Stuck thm2 =>
   185           if prems_of thm2 = [HOLogic.Trueprop $ @{term False}]
   186           then False thm2 else None (thm2, thm)
   187         | _ => raise Match) (* FIXME *)
   188      | _ => raise Match
   189 end
   190 
   191 fun prove_descent thy tac sk (c, (m1, m2)) =
   192   let
   193     val (vs, _, l, _, r, Gam) = dest_call' sk c
   194   in 
   195     mk_desc thy tac vs Gam l r m1 m2
   196   end
   197 
   198 fun create ctxt chain_tac descent_tac T rel =
   199   let
   200     val thy = Proof_Context.theory_of ctxt
   201     val sk = mk_sum_skel rel
   202     val Ts = node_types sk T
   203     val M = Inttab.make (map_index (apsnd (MeasureFunctions.get_measure_functions ctxt)) Ts)
   204     val chain_cache = Cache.create Term2tab.empty Term2tab.lookup Term2tab.update
   205       (prove_chain thy chain_tac)
   206     val descent_cache = Cache.create Term3tab.empty Term3tab.lookup Term3tab.update
   207       (prove_descent thy descent_tac sk)
   208   in
   209     (sk, nth Ts, M, chain_cache, descent_cache)
   210   end
   211 
   212 fun get_num_points (sk, _, _, _, _) =
   213   let
   214     fun num (SLeaf i) = i + 1
   215       | num (SBranch (s, t)) = num t
   216   in num sk end
   217 
   218 fun get_types (_, T, _, _, _) = T
   219 fun get_measures (_, _, M, _, _) = Inttab.lookup_list M
   220 
   221 fun get_chain (_, _, _, C, _) c1 c2 =
   222   SOME (C (c1, c2))
   223 
   224 fun get_descent (_, _, _, _, D) c m1 m2 =
   225   SOME (D (c, (m1, m2)))
   226 
   227 fun CALLS tac i st =
   228   if Thm.no_prems st then all_tac st
   229   else case Thm.term_of (Thm.cprem_of st i) of
   230     (_ $ (_ $ rel)) => tac (Function_Lib.dest_binop_list @{const_name Lattices.sup} rel, i) st
   231   |_ => no_tac st
   232 
   233 type ttac = data -> int -> tactic
   234 
   235 fun TERMINATION ctxt atac tac =
   236   SUBGOAL (fn (_ $ (Const (@{const_name wf}, wfT) $ rel), i) =>
   237   let
   238     val (T, _) = HOLogic.dest_prodT (HOLogic.dest_setT (domain_type wfT))
   239   in
   240     tac (create ctxt atac atac T rel) i
   241   end)
   242 
   243 
   244 (* A tactic to convert open to closed termination goals *)
   245 local
   246 fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *)
   247   let
   248     val (vars, prop) = Function_Lib.dest_all_all t
   249     val (prems, concl) = Logic.strip_horn prop
   250     val (lhs, rhs) = concl
   251       |> HOLogic.dest_Trueprop
   252       |> HOLogic.dest_mem |> fst
   253       |> HOLogic.dest_prod
   254   in
   255     (vars, prems, lhs, rhs)
   256   end
   257 
   258 fun mk_pair_compr (T, qs, l, r, conds) =
   259   let
   260     val pT = HOLogic.mk_prodT (T, T)
   261     val n = length qs
   262     val peq = HOLogic.eq_const pT $ Bound n $ (HOLogic.pair_const T T $ l $ r)
   263     val conds' = if null conds then [@{term True}] else conds
   264   in
   265     HOLogic.Collect_const pT $
   266     Abs ("uu_", pT,
   267       (foldr1 HOLogic.mk_conj (peq :: conds')
   268       |> fold_rev (fn v => fn t => HOLogic.exists_const (fastype_of v) $ lambda v t) qs))
   269   end
   270 
   271 in
   272 
   273 fun wf_union_tac ctxt st =
   274   let
   275     val thy = Proof_Context.theory_of ctxt
   276     val cert = cterm_of (theory_of_thm st)
   277     val ((_ $ (_ $ rel)) :: ineqs) = prems_of st
   278 
   279     fun mk_compr ineq =
   280       let
   281         val (vars, prems, lhs, rhs) = dest_term ineq
   282       in
   283         mk_pair_compr (fastype_of lhs, vars, lhs, rhs, map (Object_Logic.atomize_term thy) prems)
   284       end
   285 
   286     val relation =
   287       if null ineqs
   288       then Const (@{const_abbrev Set.empty}, fastype_of rel)
   289       else map mk_compr ineqs
   290         |> foldr1 (HOLogic.mk_binop @{const_name Lattices.sup})
   291 
   292     fun solve_membership_tac i =
   293       (EVERY' (replicate (i - 2) (rtac @{thm UnI2}))  (* pick the right component of the union *)
   294       THEN' (fn j => TRY (rtac @{thm UnI1} j))
   295       THEN' (rtac @{thm CollectI})                    (* unfold comprehension *)
   296       THEN' (fn i => REPEAT (rtac @{thm exI} i))      (* Turn existentials into schematic Vars *)
   297       THEN' ((rtac @{thm refl})                       (* unification instantiates all Vars *)
   298         ORELSE' ((rtac @{thm conjI})
   299           THEN' (rtac @{thm refl})
   300           THEN' (blast_tac ctxt)))    (* Solve rest of context... not very elegant *)
   301       ) i
   302   in
   303     ((PRIMITIVE (Drule.cterm_instantiate [(cert rel, cert relation)])
   304      THEN ALLGOALS (fn i => if i = 1 then all_tac else solve_membership_tac i))) st
   305   end
   306 
   307 end
   308 
   309 
   310 
   311 (*** DEPENDENCY GRAPHS ***)
   312 
   313 fun mk_dgraph D cs =
   314   Term_Graph.empty
   315   |> fold (fn c => Term_Graph.new_node (c, ())) cs
   316   |> fold_product (fn c1 => fn c2 =>
   317      if is_none (get_chain D c1 c2 |> the_default NONE)
   318      then Term_Graph.add_edge (c1, c2) else I)
   319      cs cs
   320 
   321 fun ucomp_empty_tac T =
   322   REPEAT_ALL_NEW (rtac @{thm union_comp_emptyR}
   323     ORELSE' rtac @{thm union_comp_emptyL}
   324     ORELSE' SUBGOAL (fn (_ $ (_ $ (_ $ c1 $ c2) $ _), i) => rtac (T c1 c2) i))
   325 
   326 fun regroup_calls_tac cs = CALLS (fn (cs', i) =>
   327  let
   328    val is = map (fn c => find_index (curry op aconv c) cs') cs
   329  in
   330    CONVERSION (Conv.arg_conv (Conv.arg_conv
   331      (Function_Lib.regroup_union_conv is))) i
   332  end)
   333 
   334 
   335 fun solve_trivial_tac D = CALLS (fn ([c], i) =>
   336   (case get_chain D c c of
   337      SOME (SOME thm) => rtac @{thm wf_no_loop} i
   338                         THEN rtac thm i
   339    | _ => no_tac)
   340   | _ => no_tac)
   341 
   342 fun decompose_tac D = CALLS (fn (cs, i) =>
   343   let
   344     val G = mk_dgraph D cs
   345     val sccs = Term_Graph.strong_conn G
   346 
   347     fun split [SCC] i = TRY (solve_trivial_tac D i)
   348       | split (SCC::rest) i =
   349         regroup_calls_tac SCC i
   350         THEN rtac @{thm wf_union_compatible} i
   351         THEN rtac @{thm less_by_empty} (i + 2)
   352         THEN ucomp_empty_tac (the o the oo get_chain D) (i + 2)
   353         THEN split rest (i + 1)
   354         THEN TRY (solve_trivial_tac D i)
   355   in
   356     if length sccs > 1 then split sccs i
   357     else solve_trivial_tac D i
   358   end)
   359 
   360 
   361 end