src/HOL/Tools/Function/termination.ML
author wenzelm
Fri Mar 06 15:58:56 2015 +0100 (2015-03-06)
changeset 59621 291934bac95e
parent 59618 e6939796717e
child 59625 aacdce52b2fc
permissions -rw-r--r--
Thm.cterm_of and Thm.ctyp_of operate on local context;
     1 (*  Title:       HOL/Tools/Function/termination.ML
     2     Author:      Alexander Krauss, TU Muenchen
     3 
     4 Context data for termination proofs.
     5 *)
     6 
signature TERMINATION =
sig
  (* Analysis data of one termination goal: program-point skeleton, point types,
     measure functions, and cached chain / descent proof results *)
  type data
  (* Outcome of a descent proof attempt for one call and one pair of measures:
     Less: strict descent proved;
     LessEq: non-strict descent proved, paired with the stuck strict attempt;
     None: neither proved (stuck non-strict attempt, stuck strict attempt);
     False: the stuck non-strict attempt has exactly the premise False *)
  datatype cell = Less of thm | LessEq of thm * thm | None of thm * thm | False of thm

  (* Combine one function per program point into a single function on the sum type;
     the typ argument is passed on as the result type of the case expressions *)
  val mk_sumcases : data -> typ -> term list -> term

  (* Number of program points *)
  val get_num_points : data -> int
  (* Type of the program point with the given index *)
  val get_types      : data -> int -> typ
  (* Candidate measure functions for the program point with the given index *)
  val get_measures   : data -> int -> term list

  (* Cached attempt to prove "c1 O c2 = {}" for two calls *)
  val get_chain      : data -> term -> term -> thm option option
  (* Cached descent cell for a call and two measures *)
  val get_descent    : data -> term -> term -> term -> cell option

  (* For a call {(r, l). EX vs. Gam}: (vs, point of l, argument of l,
     point of r, argument of r, Gam) *)
  val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term)

  (* Apply a tactic to the list of union components (calls) of the relation in subgoal i *)
  val CALLS : (term list * int -> tactic) -> int -> tactic

  (* Termination tactics *)
  type ttac = data -> int -> tactic

  (* Build the data for a "wf rel" subgoal and run a termination tactic on it; the plain
     tactic argument is used for both chain and descent proofs *)
  val TERMINATION : Proof.context -> tactic -> ttac -> int -> tactic

  (* Convert open termination goals (given as premises) into a closed relation *)
  val wf_union_tac : Proof.context -> tactic

  (* Decompose the calls into strongly connected components of their dependency graph *)
  val decompose_tac : Proof.context -> ttac
end
    34 
    35 
    36 
    37 structure Termination : TERMINATION =
    38 struct
    39 
    40 open Function_Lib
    41 
    42 val term2_ord = prod_ord Term_Ord.fast_term_ord Term_Ord.fast_term_ord
    43 structure Term2tab = Table(type key = term * term val ord = term2_ord);
    44 structure Term3tab =
    45   Table(type key = term * (term * term) val ord = prod_ord Term_Ord.fast_term_ord term2_ord);
    46 
    47 (** Analyzing binary trees **)
    48 
    49 (* Skeleton of a tree structure *)
    50 
(* Shape of the nested sum type (Inl / Inr injections) that encodes the program points.
   Leaves carry the program-point index; mk_skel numbers them from 0, left to right. *)
datatype skel =
  SLeaf of int (* index *)
| SBranch of (skel * skel)
    54 
    55 
    56 (* abstract make and dest functions *)
    57 fun mk_tree leaf branch =
    58   let fun mk (SLeaf i) = leaf i
    59         | mk (SBranch (s, t)) = branch (mk s, mk t)
    60   in mk end
    61 
    62 
    63 fun dest_tree split =
    64   let fun dest (SLeaf i) x = [(i, x)]
    65         | dest (SBranch (s, t)) x =
    66           let val (l, r) = split x
    67           in dest s l @ dest t r end
    68   in dest end
    69 
    70 
    71 (* concrete versions for sum types *)
    72 fun is_inj (Const (@{const_name Sum_Type.Inl}, _) $ _) = true
    73   | is_inj (Const (@{const_name Sum_Type.Inr}, _) $ _) = true
    74   | is_inj _ = false
    75 
    76 fun dest_inl (Const (@{const_name Sum_Type.Inl}, _) $ t) = SOME t
    77   | dest_inl _ = NONE
    78 
    79 fun dest_inr (Const (@{const_name Sum_Type.Inr}, _) $ t) = SOME t
    80   | dest_inr _ = NONE
    81 
    82 
    83 fun mk_skel ps =
    84   let
    85     fun skel i ps =
    86       if forall is_inj ps andalso not (null ps)
    87       then let
    88           val (j, s) = skel i (map_filter dest_inl ps)
    89           val (k, t) = skel j (map_filter dest_inr ps)
    90         in (k, SBranch (s, t)) end
    91       else (i + 1, SLeaf i)
    92   in
    93     snd (skel 0 ps)
    94   end
    95 
    96 (* compute list of types for nodes *)
    97 fun node_types sk T = dest_tree (fn Type (@{type_name Sum_Type.sum}, [LT, RT]) => (LT, RT)) sk T |> map snd
    98 
    99 (* find index and raw term *)
   100 fun dest_inj (SLeaf i) trm = (i, trm)
   101   | dest_inj (SBranch (s, t)) trm =
   102     case dest_inl trm of
   103       SOME trm' => dest_inj s trm'
   104     | _ => dest_inj t (the (dest_inr trm))
   105 
   106 
   107 
   108 (** Matrix cell datatype **)
   109 
(* Outcome of a descent proof attempt for one call and a pair of measures (see mk_desc):
     Less t         the strict goal was solved, t is the theorem;
     LessEq (t, s)  only the non-strict goal was solved (t); s is the stuck strict attempt;
     None (w, s)    neither was solved; w, s are the stuck non-strict and strict attempts;
     False w        the stuck non-strict attempt w has exactly the premise False. *)
datatype cell = Less of thm | LessEq of thm * thm | None of thm * thm | False of thm;


type data =
  skel                            (* structure of the sum type encoding "program points" *)
  * (int -> typ)                  (* types of program points *)
  * (term list Inttab.table)      (* measures for program points *)
  * (term * term -> thm option)   (* SOME thm proves "c1 O c2 = {}": calls do not chain (cached) *)
  * (term * (term * term) -> cell)(* local descents (cached) *)
   119 
   120 
   121 (* Build case expression *)
   122 fun mk_sumcases (sk, _, _, _, _) T fs =
   123   mk_tree (fn i => (nth fs i, domain_type (fastype_of (nth fs i))))
   124           (fn ((f, fT), (g, gT)) => (Sum_Tree.mk_sumcase fT gT T f g, Sum_Tree.mk_sumT fT gT))
   125           sk
   126   |> fst
   127 
   128 fun mk_sum_skel rel =
   129   let
   130     val cs = Function_Lib.dest_binop_list @{const_name Lattices.sup} rel
   131     fun collect_pats (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
   132       let
   133         val (Const (@{const_name HOL.conj}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ (Const (@{const_name Pair}, _) $ r $ l)) $ _)
   134           = Term.strip_qnt_body @{const_name Ex} c
   135       in cons r o cons l end
   136   in
   137     mk_skel (fold collect_pats cs [])
   138   end
   139 
   140 fun prove_chain ctxt chain_tac (c1, c2) =
   141   let
   142     val goal =
   143       HOLogic.mk_eq (HOLogic.mk_binop @{const_name Relation.relcomp} (c1, c2),
   144         Const (@{const_abbrev Set.empty}, fastype_of c1))
   145       |> HOLogic.mk_Trueprop (* "C1 O C2 = {}" *)
   146   in
   147     case Function_Lib.try_proof (Thm.cterm_of ctxt goal) chain_tac of
   148       Function_Lib.Solved thm => SOME thm
   149     | _ => NONE
   150   end
   151 
   152 
   153 fun dest_call' sk (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
   154   let
   155     val vs = Term.strip_qnt_vars @{const_name Ex} c
   156 
   157     (* FIXME: throw error "dest_call" for malformed terms *)
   158     val (Const (@{const_name HOL.conj}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ (Const (@{const_name Pair}, _) $ r $ l)) $ Gam)
   159       = Term.strip_qnt_body @{const_name Ex} c
   160     val (p, l') = dest_inj sk l
   161     val (q, r') = dest_inj sk r
   162   in
   163     (vs, p, l', q, r', Gam)
   164   end
   165   | dest_call' _ _ = error "dest_call"
   166 
   167 fun dest_call (sk, _, _, _, _) = dest_call' sk
   168 
   169 fun mk_desc ctxt tac vs Gam l r m1 m2 =
   170   let
   171     fun try rel =
   172       try_proof (Thm.cterm_of ctxt
   173         (Logic.list_all (vs,
   174            Logic.mk_implies (HOLogic.mk_Trueprop Gam,
   175              HOLogic.mk_Trueprop (Const (rel, @{typ "nat => nat => bool"})
   176                $ (m2 $ r) $ (m1 $ l)))))) tac
   177   in
   178     case try @{const_name Orderings.less} of
   179        Solved thm => Less thm
   180      | Stuck thm =>
   181        (case try @{const_name Orderings.less_eq} of
   182           Solved thm2 => LessEq (thm2, thm)
   183         | Stuck thm2 =>
   184           if Thm.prems_of thm2 = [HOLogic.Trueprop $ @{term False}]
   185           then False thm2 else None (thm2, thm)
   186         | _ => raise Match) (* FIXME *)
   187      | _ => raise Match
   188 end
   189 
   190 fun prove_descent ctxt tac sk (c, (m1, m2)) =
   191   let
   192     val (vs, _, l, _, r, Gam) = dest_call' sk c
   193   in 
   194     mk_desc ctxt tac vs Gam l r m1 m2
   195   end
   196 
   197 fun create ctxt chain_tac descent_tac T rel =
   198   let
   199     val sk = mk_sum_skel rel
   200     val Ts = node_types sk T
   201     val M = Inttab.make (map_index (apsnd (Measure_Functions.get_measure_functions ctxt)) Ts)
   202     val chain_cache =
   203       Cache.create Term2tab.empty Term2tab.lookup Term2tab.update
   204         (prove_chain ctxt chain_tac)
   205     val descent_cache =
   206       Cache.create Term3tab.empty Term3tab.lookup Term3tab.update
   207         (prove_descent ctxt descent_tac sk)
   208   in
   209     (sk, nth Ts, M, chain_cache, descent_cache)
   210   end
   211 
   212 fun get_num_points (sk, _, _, _, _) =
   213   let
   214     fun num (SLeaf i) = i + 1
   215       | num (SBranch (s, t)) = num t
   216   in num sk end
   217 
   218 fun get_types (_, T, _, _, _) = T
   219 fun get_measures (_, _, M, _, _) = Inttab.lookup_list M
   220 
   221 fun get_chain (_, _, _, C, _) c1 c2 =
   222   SOME (C (c1, c2))
   223 
   224 fun get_descent (_, _, _, _, D) c m1 m2 =
   225   SOME (D (c, (m1, m2)))
   226 
   227 fun CALLS tac i st =
   228   if Thm.no_prems st then all_tac st
   229   else case Thm.term_of (Thm.cprem_of st i) of
   230     (_ $ (_ $ rel)) => tac (Function_Lib.dest_binop_list @{const_name Lattices.sup} rel, i) st
   231   |_ => no_tac st
   232 
(* A termination tactic: takes the analysis data and a subgoal index *)
type ttac = data -> int -> tactic
   234 
   235 fun TERMINATION ctxt atac tac =
   236   SUBGOAL (fn (_ $ (Const (@{const_name wf}, wfT) $ rel), i) =>
   237   let
   238     val (T, _) = HOLogic.dest_prodT (HOLogic.dest_setT (domain_type wfT))
   239   in
   240     tac (create ctxt atac atac T rel) i
   241   end)
   242 
   243 
   244 (* A tactic to convert open to closed termination goals *)
local
(* Destructure an inequality "!!vars. prems ==> (lhs, rhs) : S" into (vars, prems, lhs, rhs) *)
fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *)
  let
    val (vars, prop) = Function_Lib.dest_all_all t
    val (prems, concl) = Logic.strip_horn prop
    val (lhs, rhs) = concl
      |> HOLogic.dest_Trueprop
      |> HOLogic.dest_mem |> fst
      |> HOLogic.dest_prod
  in
    (vars, prems, lhs, rhs)
  end

(* Build the comprehension {uu_. EX qs. uu_ = (l, r) & conds} over pairs of type T * T;
   an empty list of conditions is replaced by [True]. *)
fun mk_pair_compr (T, qs, l, r, conds) =
  let
    val pT = HOLogic.mk_prodT (T, T)
    val n = length qs
    (* uu_ is Bound n: it lies outside the n existential binders added below *)
    val peq = HOLogic.eq_const pT $ Bound n $ (HOLogic.pair_const T T $ l $ r)
    val conds' = if null conds then [@{term True}] else conds
  in
    HOLogic.Collect_const pT $
    Abs ("uu_", pT,
      (foldr1 HOLogic.mk_conj (peq :: conds')
      |> fold_rev (fn v => fn t => HOLogic.exists_const (fastype_of v) $ lambda v t) qs))
  end

(* Meta-equations for union (Un_ac plus absorption), used to remove duplicate components *)
val Un_aci_simps =
  map mk_meta_eq @{thms Un_ac Un_absorb}

in

(* Close the termination goal: the first premise of st has the shape _ $ (_ $ rel) and every
   further premise is an inequality "!!vars. prems ==> (lhs, rhs) : ...".  rel is replaced
   via Drule.cterm_instantiate (so presumably it is a schematic variable -- confirm at call
   sites) by the union of one comprehension per inequality, or by {} if there are none.
   Each inequality is then proved by membership in its component, and subgoal 1 is
   rewritten with Un_aci_simps to eliminate duplicate components. *)
fun wf_union_tac ctxt st = SUBGOAL (fn _ =>
  let
    val thy = Proof_Context.theory_of ctxt
    val ((_ $ (_ $ rel)) :: ineqs) = Thm.prems_of st

    (* one comprehension per inequality; its premises are atomized into conditions *)
    fun mk_compr ineq =
      let
        val (vars, prems, lhs, rhs) = dest_term ineq
      in
        mk_pair_compr (fastype_of lhs, vars, lhs, rhs, map (Object_Logic.atomize_term thy) prems)
      end

    val relation =
      if null ineqs
      then Const (@{const_abbrev Set.empty}, fastype_of rel)
      else map mk_compr ineqs
        |> foldr1 (HOLogic.mk_binop @{const_name Lattices.sup})

    (* Subgoal i (i >= 2) stems from the (i - 1)-th inequality, whose comprehension is the
       (i - 1)-th component of the right-nested union built above *)
    fun solve_membership_tac i =
      (EVERY' (replicate (i - 2) (rtac @{thm UnI2}))  (* pick the right component of the union *)
      THEN' (fn j => TRY (rtac @{thm UnI1} j))
      THEN' (rtac @{thm CollectI})                    (* unfold comprehension *)
      THEN' (fn i => REPEAT (rtac @{thm exI} i))      (* Turn existentials into schematic Vars *)
      THEN' ((rtac @{thm refl})                       (* unification instantiates all Vars *)
        ORELSE' ((rtac @{thm conjI})
          THEN' (rtac @{thm refl})
          THEN' (blast_tac ctxt)))    (* Solve rest of context... not very elegant *)
      ) i
  in
    (PRIMITIVE (Drule.cterm_instantiate [apply2 (Thm.cterm_of ctxt) (rel, relation)])
     (* subgoal 1 (the premise that contained rel) is left open *)
     THEN ALLGOALS (fn i => if i = 1 then all_tac else solve_membership_tac i)
     THEN rewrite_goal_tac ctxt Un_aci_simps 1)  (* eliminate duplicates *)
  (* SUBGOAL at index 1 only requires a first subgoal to exist; its term is ignored *)
  end) 1 st

end
   311 
   312 
   313 
   314 (*** DEPENDENCY GRAPHS ***)
   315 
   316 fun mk_dgraph D cs =
   317   Term_Graph.empty
   318   |> fold (fn c => Term_Graph.new_node (c, ())) cs
   319   |> fold_product (fn c1 => fn c2 =>
   320      if is_none (get_chain D c1 c2 |> the_default NONE)
   321      then Term_Graph.add_edge (c2, c1) else I)
   322      cs cs
   323 
   324 fun ucomp_empty_tac T =
   325   REPEAT_ALL_NEW (rtac @{thm union_comp_emptyR}
   326     ORELSE' rtac @{thm union_comp_emptyL}
   327     ORELSE' SUBGOAL (fn (_ $ (_ $ (_ $ c1 $ c2) $ _), i) => rtac (T c1 c2) i))
   328 
   329 fun regroup_calls_tac ctxt cs = CALLS (fn (cs', i) =>
   330  let
   331    val is = map (fn c => find_index (curry op aconv c) cs') cs
   332  in
   333    CONVERSION (Conv.arg_conv (Conv.arg_conv
   334      (Function_Lib.regroup_union_conv is))) i
   335  end)
   336 
   337 
   338 fun solve_trivial_tac D = CALLS (fn ([c], i) =>
   339   (case get_chain D c c of
   340      SOME (SOME thm) => rtac @{thm wf_no_loop} i
   341                         THEN rtac thm i
   342    | _ => no_tac)
   343   | _ => no_tac)
   344 
   345 fun decompose_tac ctxt D = CALLS (fn (cs, i) =>
   346   let
   347     val G = mk_dgraph D cs
   348     val sccs = Term_Graph.strong_conn G
   349 
   350     fun split [SCC] i = TRY (solve_trivial_tac D i)
   351       | split (SCC::rest) i =
   352         regroup_calls_tac ctxt SCC i
   353         THEN rtac @{thm wf_union_compatible} i
   354         THEN rtac @{thm less_by_empty} (i + 2)
   355         THEN ucomp_empty_tac (the o the oo get_chain D) (i + 2)
   356         THEN split rest (i + 1)
   357         THEN TRY (solve_trivial_tac D i)
   358   in
   359     if length sccs > 1 then split sccs i
   360     else solve_trivial_tac D i
   361   end)
   362 
   363 
   364 end