src/HOL/Tools/Function/termination.ML
author haftmann
Sat Sep 19 07:38:03 2009 +0200 (2009-09-19)
changeset 32683 7c1fe854ca6a
parent 32602 f2b741473860
child 33099 b8cdd3d73022
permissions -rw-r--r--
inter and union are mere abbreviations for inf and sup
     1 (*  Title:       HOL/Tools/Function/termination.ML
     2     Author:      Alexander Krauss, TU Muenchen
     3 
     4 Context data for termination proofs
     5 *)
     6 
     7 
     signature TERMINATION =
     sig

       (* Context data threaded through a termination proof *)
       type data

       (* Entry of the descent matrix together with the theorem(s) backing it;
          see derive_descent for how each constructor is produced *)
       datatype cell =
           Less of thm
         | LessEq of (thm * thm)
         | None of (thm * thm)
         | False of thm

       (* Case combinator over the sum type of program points, built from one
          function term per point *)
       val mk_sumcases : data -> typ -> term list -> term

       (* Recording facts: a measure for a program point, the chain result for
          two calls, the descent cell for a call and two measures *)
       val note_measure : int -> term -> data -> data
       val note_chain   : term -> term -> thm option -> data -> data
       val note_descent : term -> term -> term -> cell -> data -> data

       (* Static information about the program points *)
       val get_num_points : data -> int
       val get_types      : data -> int -> typ
       val get_measures   : data -> int -> term list

       (* Cache lookups; NONE means nothing has been recorded yet *)
       val get_chain      : data -> term -> term -> thm option option
       val get_descent    : data -> term -> term -> term -> cell option

       (* Cache writers: attempt the proofs with the given tactic and store the result *)
       val derive_descent  : theory -> tactic -> term -> term -> term -> data -> data
       val derive_descents : theory -> tactic -> term -> data -> data

       (* Decompose a call (a set comprehension) into
          (quantified variables, point index, term, point index, term, condition) *)
       val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term)

       (* Run a tactic on the list of calls found in the relation of subgoal i *)
       val CALLS : (term list * int -> tactic) -> int -> tactic

       (* Termination tactics, composed sequentially via continuations:
          the first argument is the success continuation, the second the
          error continuation *)
       type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic

       (* Set up fresh data for a wf goal and pass it to the given tactic *)
       val TERMINATION : Proof.context -> (data -> int -> tactic) -> int -> tactic

       val REPEAT : ttac -> ttac

       (* Convert an open termination goal into a closed one *)
       val wf_union_tac : Proof.context -> tactic
     end
    45 
    46 
    47 
    48 structure Termination : TERMINATION =
    49 struct
    50 
    51 open FundefLib
    52 
    (* Ordering on pairs of terms, and tables keyed by pairs and by
       (term, pair) triples of terms *)
    val term2_ord = prod_ord TermOrd.fast_term_ord TermOrd.fast_term_ord

    structure Term2tab = Table(
      type key = term * term
      val ord = term2_ord);

    structure Term3tab = Table(
      type key = term * (term * term)
      val ord = prod_ord TermOrd.fast_term_ord term2_ord);
    56 
    57 (** Analyzing binary trees **)
    58 
    (* Shape of a binary tree: leaves carry an integer index, inner nodes
       have exactly two subtrees *)
    datatype skel =
        SLeaf of int
      | SBranch of skel * skel
    64 
    65 
    (* Generic constructor over a skeleton: "leaf" turns a leaf index into a
       value, "branch" combines the values of the two subtrees (left first) *)
    fun mk_tree leaf branch =
      let
        fun build sk =
          (case sk of
            SLeaf idx => leaf idx
          | SBranch (left, right) => branch (build left, build right))
      in build end
    71 
    72 
    (* Generic destructor over a skeleton: "split" breaks a value into its
       left and right parts at every branch; the result pairs each leaf index
       with the part that ended up at that leaf, in left-to-right order *)
    fun dest_tree split =
      let
        fun walk sk x =
          (case sk of
            SLeaf idx => [(idx, x)]
          | SBranch (left, right) =>
              let val (xl, xr) = split x
              in walk left xl @ walk right xr end)
      in walk end
    79 
    80 
    (* Concrete versions for sum types *)

    (* Is the term an application of Inl or Inr? *)
    fun is_inj (Const (c, _) $ _) =
          c = @{const_name Sum_Type.Inl} orelse c = @{const_name Sum_Type.Inr}
      | is_inj _ = false
    85 
    (* Argument of an Inl application, if the term is one *)
    fun dest_inl t =
      (case t of
        Const (@{const_name Sum_Type.Inl}, _) $ arg => SOME arg
      | _ => NONE)
    88 
    (* Argument of an Inr application, if the term is one *)
    fun dest_inr t =
      (case t of
        Const (@{const_name Sum_Type.Inr}, _) $ arg => SOME arg
      | _ => NONE)
    91 
    92 
    (* Derive the skeleton from a list of patterns: as long as the list is
       non-empty and every pattern is an injection, split into the Inl and
       Inr arguments and recurse; otherwise emit a leaf.  Leaves are numbered
       from 0 in left-to-right order. *)
    fun mk_skel ps =
      let
        (* "next" is the index handed to the next leaf; it is threaded
           through the traversal and returned alongside the subtree *)
        fun build next pats =
          if null pats orelse not (forall is_inj pats)
          then (next + 1, SLeaf next)
          else
            let
              val (next', left) = build next (map_filter dest_inl pats)
              val (next'', right) = build next' (map_filter dest_inr pats)
            in (next'', SBranch (left, right)) end
      in
        #2 (build 0 ps)
      end
   105 
   (* List the types at the leaves of the skeleton, obtained by splitting the
      sum type T along the branches.  A branch over a type that is not a
      binary "+" type raises Match. *)
   fun node_types sk T =
     let
       fun split_sumT (Type ("+", [leftT, rightT])) = (leftT, rightT)
     in
       map snd (dest_tree split_sumT sk T)
     end
   108 
   (* Strip the injections from a term along the skeleton, returning the
      index of the leaf that is reached and the remaining raw term.  At a
      branch, anything that is not an Inl is expected to be an Inr. *)
   fun dest_inj sk trm =
     (case sk of
       SLeaf idx => (idx, trm)
     | SBranch (left, right) =>
         (case dest_inl trm of
           SOME arg => dest_inj left arg
         | NONE => dest_inj right (the (dest_inr trm))))
   115 
   116 
   117 
   (** Matrix cell datatype **)

   (* Result of comparing two measures across a call, as filled in by
      derive_desc_aux:
        Less thm             the strict goal was solved
        LessEq (weak, thm)   only the non-strict goal was solved; thm is the
                             stuck strict attempt
        None (weak, thm)     both attempts got stuck
        False weak           the non-strict attempt got stuck with False as
                             its only premise *)
   datatype cell =
       Less of thm
     | LessEq of (thm * thm)
     | None of (thm * thm)
     | False of thm;
   121 
   122 
   (* The termination context: a 5-tuple of
        1. the skeleton of the sum type that encodes the "program points"
        2. a function giving the type of each program point
        3. the measures known for each program point
        4. the cache recording which pairs of calls form chains
        5. the cache of local descents *)
   type data =
     skel
     * (int -> typ)
     * (term list Inttab.table)
     * (thm option Term2tab.table)
     * (cell Term3tab.table)
   129 
   130 
   (* Component-wise updates of the data tuple *)
   fun map_measures f (sk, Ts, measures, chains, descents) =
     (sk, Ts, f measures, chains, descents)

   fun map_chains f (sk, Ts, measures, chains, descents) =
     (sk, Ts, measures, f chains, descents)

   fun map_descent f (sk, Ts, measures, chains, descents) =
     (sk, Ts, measures, chains, f descents)
   134 
   (* Register measure m for program point p (entries are compared with aconv) *)
   fun note_measure p m D =
     map_measures (Inttab.insert_list (op aconv) (p, m)) D

   (* Store the chain result for the ordered pair of calls (c1, c2) *)
   fun note_chain c1 c2 res D =
     map_chains (Term2tab.update ((c1, c2), res)) D

   (* Store the descent cell for call c under the measures m1 and m2 *)
   fun note_descent c m1 m2 res D =
     map_descent (Term3tab.update ((c, (m1, m2)), res)) D
   138 
   (* Build the nested sum-case expression over all program points.  The
      function for leaf i is the i-th element of fs; T is passed to
      SumTree.mk_sumcase at every branch. *)
   fun mk_sumcases (sk, _, _, _, _) T fs =
     let
       (* each subtree yields its case term together with its domain type *)
       fun leaf idx =
         let val f = nth fs idx
         in (f, domain_type (fastype_of f)) end

       fun branch ((f, fT), (g, gT)) =
         (SumTree.mk_sumcase fT gT T f g, SumTree.mk_sumT fT gT)
     in
       fst (mk_tree leaf branch sk)
     end
   145 
   (* Compute the skeleton of the program-point sum type from a relation
      given as a union (sup) of set comprehensions: both components of the
      pair in every comprehension are collected and handed to mk_skel.
      Components that do not have the expected shape raise Match or Bind. *)
   fun mk_sum_skel rel =
     let
       fun add_pats (Const (@{const_name Collect}, _) $ Abs (_, _, body)) acc =
         let
           val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ _) =
             Term.strip_qnt_body "Ex" body
         in r :: l :: acc end

       val calls = FundefLib.dest_binop_list @{const_name Lattices.sup} rel
     in
       mk_skel (fold add_pats calls [])
     end
   157 
   (* Initial data for a relation rel over the sum type T: the skeleton and
      point types are computed from rel and T, the measures of every point
      come from MeasureFunctions.get_measure_functions, and both caches start
      out empty *)
   fun create ctxt T rel =
     let
       val sk = mk_sum_skel rel
       val point_types = node_types sk T
       val measure_funs = MeasureFunctions.get_measure_functions ctxt
       val measures =
         Inttab.make (map_index (fn (idx, pointT) => (idx, measure_funs pointT)) point_types)
     in
       (sk, nth point_types, measures, Term2tab.empty, Term3tab.empty)
     end
   166 
   (* Number of program points: one more than the index stored at the
      rightmost leaf (mk_skel numbers the leaves from 0, left to right) *)
   fun get_num_points (sk, _, _, _, _) =
     let
       fun rightmost (SBranch (_, right)) = rightmost right
         | rightmost (SLeaf idx) = idx
     in
       rightmost sk + 1
     end
   172 
   (* Type of a program point, by index *)
   fun get_types (D : data) = #2 D

   (* Measures recorded for a program point, by index *)
   fun get_measures (D : data) = Inttab.lookup_list (#3 D)
   175 
   (* Cached chain result for the calls c1 and c2; NONE if nothing is cached *)
   fun get_chain (D : data) c1 c2 =
     let val chains = #4 D
     in Term2tab.lookup chains (c1, c2) end

   (* Cached descent cell for call c under measures m1 and m2; NONE if
      nothing is cached *)
   fun get_descent (D : data) c m1 m2 =
     let val descents = #5 D
     in Term3tab.lookup descents (c, (m1, m2)) end
   181 
   (* Decompose a call, i.e. a set comprehension of the form
        Collect (%_. EX vs. _ = Pair r l & Gam)
      into (vs, p, l', q, r', Gam), where (p, l') and (q, r') are the
      program-point index and raw term obtained by stripping the sum
      injections from l and r respectively.

      Every malformed input is reported uniformly via error "dest_call":
      a term that is not a comprehension, a comprehension body of the wrong
      shape, and a pair component that is not injected according to the
      skeleton. *)
   fun dest_call D (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
         let
           val (sk, _, _, _, _) = D
           val vs = Term.strip_qnt_vars "Ex" c
         in
           (case Term.strip_qnt_body "Ex" c of
             Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ Gam =>
               (* dest_inj raises Option when a component lacks the
                  injection demanded by the skeleton *)
               (let
                  val (p, l') = dest_inj sk l
                  val (q, r') = dest_inj sk r
                in (vs, p, l', q, r', Gam) end
                handle Option.Option => error "dest_call")
           | _ => error "dest_call")
         end
     | dest_call _ _ = error "dest_call"
   197 
   198 
   (* Fill in the descent cell for call c and measures m1, m2 unless it is
      already cached.  The call is given in the decomposed form returned by
      dest_call.  First the strict goal "m2 r' < m1 l'" (under Gam, for all
      vs) is attempted with tac; if that gets stuck, the non-strict goal is
      attempted.  Any other outcome of try_proof raises Match (FIXME). *)
   fun derive_desc_aux thy tac c (vs, _, l', _, r', Gam) m1 m2 D =
     (case get_descent D c m1 m2 of
       NONE =>
         let
           (* certified goal comparing the two measure values with "rel" *)
           fun goal_for rel =
             let
               val cmp = Const (rel, @{typ "nat => nat => bool"}) $ (m2 $ r') $ (m1 $ l')
               val prop = Logic.mk_implies (HOLogic.mk_Trueprop Gam, HOLogic.mk_Trueprop cmp)
             in cterm_of thy (Term.list_all (vs, prop)) end

           (* second attempt, made only after the strict proof got stuck *)
           fun weak_cell strict_thm =
             (case try_proof (goal_for @{const_name HOL.less_eq}) tac of
               Solved weak_thm => LessEq (weak_thm, strict_thm)
             | Stuck weak_thm =>
                 if prems_of weak_thm = [HOLogic.Trueprop $ HOLogic.false_const]
                 then False weak_thm
                 else None (weak_thm, strict_thm)
             | _ => raise Match) (* FIXME *)

           val result =
             (case try_proof (goal_for @{const_name HOL.less}) tac of
               Solved strict_thm => Less strict_thm
             | Stuck strict_thm => weak_cell strict_thm
             | _ => raise Match)
         in
           note_descent c m1 m2 result D
         end
     | SOME _ => D)
   222 
   (* Derive (and cache) the descent cell of call c for one pair of measures *)
   fun derive_descent thy tac c m1 m2 D =
     let val call = dest_call D c
     in derive_desc_aux thy tac c call m1 m2 D end
   225 
   (* Derive the descent cells of call c for every combination of a measure
      of its first program point p with a measure of its second point q *)
   fun derive_descents thy tac c D =
     let
       val call as (_, p, _, q, _, _) = dest_call D c
       val measures_p = get_measures D p
       val measures_q = get_measures D q
     in
       fold_product (derive_desc_aux thy tac c call) measures_p measures_q D
     end
   232 
   (* Apply tac to the list of calls of subgoal i, where the calls are the
      components of the union (sup) found as the argument of the argument of
      the subgoal.  Succeeds trivially on a state without subgoals and fails
      if the subgoal does not have that shape. *)
   fun CALLS tac i st =
     if not (Thm.no_prems st) then
       (case Thm.term_of (Thm.cprem_of st i) of
         _ $ (_ $ rel) =>
           let val calls = FundefLib.dest_binop_list @{const_name Lattices.sup} rel
           in tac (calls, i) st end
       | _ => no_tac st)
     else all_tac st
   238 
   (* Termination tactic in continuation-passing style: takes the success
      continuation, then the error continuation, then the data and subgoal *)
   type ttac =
     (data -> int -> tactic) ->
     (data -> int -> tactic) ->
     data -> int -> tactic
   240 
   (* Entry point for termination tactics: for a subgoal of the form
      "_ (wf rel)", create fresh data from the relation and its element type
      and pass it to tac.  Subgoals of any other shape raise Match. *)
   fun TERMINATION ctxt tac =
     let
       fun on_goal (_ $ (Const (@{const_name wf}, wfT) $ rel), i) =
         let
           (* wf takes a set of pairs; T is the first component's type *)
           val pairT = HOLogic.dest_setT (domain_type wfT)
           val T = fst (HOLogic.dest_prodT pairT)
         in
           tac (create ctxt T rel) i
         end
     in
       SUBGOAL on_goal
     end
   248 
   249 
   250 (* A tactic to convert open to closed termination goals *)
   251 local
   (* Split a goal "!!vars. prems ==> (lhs, rhs) : _" into its quantified
      variables, premises and the two components of the pair *)
   fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *)
     let
       val (vars, prop) = FundefLib.dest_all_all t
       val (prems, concl) = Logic.strip_horn prop
       val (pair, _) = HOLogic.dest_mem (HOLogic.dest_Trueprop concl)
       val (lhs, rhs) = HOLogic.dest_prod pair
     in
       (vars, prems, lhs, rhs)
     end
   263 
   (* Build the set comprehension
        Collect (%uu_. EX qs. uu_ = (l, r) & conds)
      over pairs of type T * T; an empty list of conditions is replaced by
      the single condition True *)
   fun mk_pair_compr (T, qs, l, r, conds) =
     let
       val pairT = HOLogic.mk_prodT (T, T)

       (* one binder is added per element of qs below, so this loose bound
          ends up referring to the enclosing "uu_" abstraction *)
       val pair_eq =
         HOLogic.eq_const pairT $ Bound (length qs) $ (HOLogic.pair_const T T $ l $ r)

       val side_conds = (case conds of [] => [HOLogic.true_const] | _ => conds)

       fun ex_quant v body = HOLogic.exists_const (fastype_of v) $ lambda v body

       val body = fold_rev ex_quant qs (foldr1 HOLogic.mk_conj (pair_eq :: side_conds))
     in
       HOLogic.Collect_const pairT $ Abs ("uu_", pairT, body)
     end
   276 
   277 in
   278 
   (* Instantiate the relation of the first premise "_ (_ rel)" with the
      union of one set comprehension per remaining premise (or with the empty
      set if there is none), then discharge every premise except the first by
      proving membership in the matching component of that union *)
   fun wf_union_tac ctxt st =
     let
       val thy = ProofContext.theory_of ctxt
       val cert = cterm_of (theory_of_thm st)
       val ((_ $ (_ $ rel)) :: ineqs) = prems_of st

       (* comprehension describing a single premise *)
       fun mk_compr ineq =
         let
           val (vars, prems, lhs, rhs) = dest_term ineq
           val conds = map (ObjectLogic.atomize_term thy) prems
         in
           mk_pair_compr (fastype_of lhs, vars, lhs, rhs, conds)
         end

       val relation =
         (case ineqs of
           [] => Const (@{const_name Set.empty}, fastype_of rel)
         | _ => foldr1 (HOLogic.mk_binop @{const_name Lattices.sup}) (map mk_compr ineqs))

       (* fallback when plain reflexivity does not close the goal:
          split the conjunction, solve the equation, and let blast handle
          the rest of the context... not very elegant *)
       fun conj_tac j =
         (rtac @{thm conjI}
          THEN' rtac @{thm refl}
          THEN' blast_tac (claset_of ctxt)) j

       fun solve_membership_tac i =
         (EVERY' (replicate (i - 2) (rtac @{thm UnI2}))  (* pick the right component of the union *)
          THEN' (fn j => TRY (rtac @{thm UnI1} j))
          THEN' rtac @{thm CollectI}                     (* unfold comprehension *)
          THEN' (fn j => REPEAT (rtac @{thm exI} j))     (* turn existentials into schematic Vars *)
          THEN' (rtac @{thm refl} ORELSE' conj_tac)      (* unification instantiates all Vars *)
         ) i

       val instantiate_tac = PRIMITIVE (Drule.cterm_instantiate [(cert rel, cert relation)])

       (* subgoal 1 is the wf goal itself and is left alone *)
       fun goal_tac i = if i = 1 then all_tac else solve_membership_tac i
     in
       (instantiate_tac THEN ALLGOALS goal_tac) st
     end
   312 
   313 
   314 end
   315 
   316 
   (* Continuation-passing repeat combinator: run ttac again after every
      successful step.  A failure of the very first step goes to err_cont;
      once a step has succeeded, a later failure simply proceeds with cont. *)
   fun REPEAT ttac cont err_cont =
     let
       fun again D i = REPEAT ttac cont cont D i
     in
       ttac again err_cont
     end
   320 
   321 
   322 
   323 
   324 end