src/HOL/Tools/function_package/termination.ML
changeset 29125 d41182a8135c
child 29269 5c25a2012975
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/function_package/termination.ML	Tue Dec 16 08:46:07 2008 +0100
@@ -0,0 +1,324 @@
+(*  Title:       HOL/Tools/function_package/termination_data.ML
+    Author:      Alexander Krauss, TU Muenchen
+
+Context data for termination proofs
+*)
+
+
+signature TERMINATION =
+sig
+
+  type data
+  datatype cell = Less of thm | LessEq of (thm * thm) | None of (thm * thm) | False of thm
+
+  val mk_sumcases : data -> typ -> term list -> term
+
+  val note_measure : int -> term -> data -> data
+  val note_chain   : term -> term -> thm option -> data -> data
+  val note_descent : term -> term -> term -> cell -> data -> data
+
+  val get_num_points : data -> int
+  val get_types      : data -> int -> typ
+  val get_measures   : data -> int -> term list
+
+  (* read from cache *)
+  val get_chain      : data -> term -> term -> thm option option
+  val get_descent    : data -> term -> term -> term -> cell option
+
+  (* writes *)
+  val derive_descent  : theory -> tactic -> term -> term -> term -> data -> data
+  val derive_descents : theory -> tactic -> term -> data -> data
+
+  val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term)
+
+  val CALLS : (term list * int -> tactic) -> int -> tactic
+
+  (* Termination tactics. Sequential composition via continuations. (2nd argument is the error continuation) *)
+  type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic
+
+  val TERMINATION : Proof.context -> (data -> int -> tactic) -> int -> tactic
+
+  val REPEAT : ttac -> ttac
+
+  val wf_union_tac : tactic
+end
+
+
+
+structure Termination : TERMINATION =
+struct
+
+open FundefLib
+
+val term2_ord = prod_ord Term.fast_term_ord Term.fast_term_ord
+structure Term2tab = TableFun(type key = term * term val ord = term2_ord);
+structure Term3tab = TableFun(type key = term * (term * term) val ord = prod_ord Term.fast_term_ord term2_ord);
+
+(** Analyzing binary trees **)
+
+(* Skeleton of a tree structure *)
+
+datatype skel =
+  SLeaf of int (* index *)
+| SBranch of (skel * skel)
+
+
+(* abstract make and dest functions *)
+fun mk_tree leaf branch =
+  let fun mk (SLeaf i) = leaf i
+        | mk (SBranch (s, t)) = branch (mk s, mk t)
+  in mk end
+
+
+fun dest_tree split =
+  let fun dest (SLeaf i) x = [(i, x)]
+        | dest (SBranch (s, t)) x =
+          let val (l, r) = split x
+          in dest s l @ dest t r end
+  in dest end
+
+
+(* concrete versions for sum types *)
+fun is_inj (Const ("Sum_Type.Inl", _) $ _) = true
+  | is_inj (Const ("Sum_Type.Inr", _) $ _) = true
+  | is_inj _ = false
+
+fun dest_inl (Const ("Sum_Type.Inl", _) $ t) = SOME t
+  | dest_inl _ = NONE
+
+fun dest_inr (Const ("Sum_Type.Inr", _) $ t) = SOME t
+  | dest_inr _ = NONE
+
+
+fun mk_skel ps =
+  let
+    fun skel i ps =
+      if forall is_inj ps andalso not (null ps)
+      then let
+          val (j, s) = skel i (map_filter dest_inl ps)
+          val (k, t) = skel j (map_filter dest_inr ps)
+        in (k, SBranch (s, t)) end
+      else (i + 1, SLeaf i)
+  in
+    snd (skel 0 ps)
+  end
+
+(* compute list of types for nodes *)
+fun node_types sk T = dest_tree (fn Type ("+", [LT, RT]) => (LT, RT)) sk T |> map snd
+
+(* find index and raw term *)
+fun dest_inj (SLeaf i) trm = (i, trm)
+  | dest_inj (SBranch (s, t)) trm =
+    case dest_inl trm of
+      SOME trm' => dest_inj s trm'
+    | _ => dest_inj t (the (dest_inr trm))
+
+
+
+(** Matrix cell datatype **)
+
+datatype cell = Less of thm | LessEq of (thm * thm) | None of (thm * thm) | False of thm;
+
+
+type data =
+  skel                            (* structure of the sum type encoding "program points" *)
+  * (int -> typ)                  (* types of program points *)
+  * (term list Inttab.table)      (* measures for program points *)
+  * (thm option Term2tab.table)   (* which calls form chains? *)
+  * (cell Term3tab.table)         (* local descents *)
+
+
+fun map_measures f (p, T, M, C, D) = (p, T, f M, C, D)
+fun map_chains f   (p, T, M, C, D) = (p, T, M, f C, D)
+fun map_descent f  (p, T, M, C, D) = (p, T, M, C, f D)
+
+fun note_measure p m = map_measures (Inttab.insert_list (op aconv) (p, m))
+fun note_chain c1 c2 res = map_chains (Term2tab.update ((c1, c2), res))
+fun note_descent c m1 m2 res = map_descent (Term3tab.update ((c,(m1, m2)), res))
+
+(* Build case expression *)
+fun mk_sumcases (sk, _, _, _, _) T fs =
+  mk_tree (fn i => (nth fs i, domain_type (fastype_of (nth fs i))))
+          (fn ((f, fT), (g, gT)) => (SumTree.mk_sumcase fT gT T f g, SumTree.mk_sumT fT gT))
+          sk
+  |> fst
+
+fun mk_sum_skel rel =
+  let
+    val cs = FundefLib.dest_binop_list @{const_name "op Un"} rel
+    fun collect_pats (Const ("Collect", _) $ Abs (_, _, c)) =
+      let
+        val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ Gam)
+          = Term.strip_qnt_body "Ex" c
+      in cons r o cons l end
+  in
+    mk_skel (fold collect_pats cs [])
+  end
+
+fun create ctxt T rel =
+  let
+    val sk = mk_sum_skel rel
+    val Ts = node_types sk T
+    val M = Inttab.make (map_index (apsnd (MeasureFunctions.get_measure_functions ctxt)) Ts)
+  in
+    (sk, nth Ts, M, Term2tab.empty, Term3tab.empty)
+  end
+
+fun get_num_points (sk, _, _, _, _) =
+  let
+    fun num (SLeaf i) = i + 1
+      | num (SBranch (s, t)) = num t
+  in num sk end
+
+fun get_types (_, T, _, _, _) = T
+fun get_measures (_, _, M, _, _) = Inttab.lookup_list M
+
+fun get_chain (_, _, _, C, _) c1 c2 =
+  Term2tab.lookup C (c1, c2)
+
+fun get_descent (_, _, _, _, D) c m1 m2 =
+  Term3tab.lookup D (c, (m1, m2))
+
+fun dest_call D (Const ("Collect", _) $ Abs (_, _, c)) =
+  let
+    val n = get_num_points D
+    val (sk, _, _, _, _) = D
+    val vs = Term.strip_qnt_vars "Ex" c
+
+    (* FIXME: throw error "dest_call" for malformed terms *)
+    val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ Gam)
+      = Term.strip_qnt_body "Ex" c
+    val (p, l') = dest_inj sk l
+    val (q, r') = dest_inj sk r
+  in
+    (vs, p, l', q, r', Gam)
+  end
+  | dest_call D t = error "dest_call"
+
+
+fun derive_desc_aux thy tac c (vs, p, l', q, r', Gam) m1 m2 D =
+  case get_descent D c m1 m2 of
+    SOME _ => D
+  | NONE => let
+    fun cgoal rel =
+      Term.list_all (vs,
+        Logic.mk_implies (HOLogic.mk_Trueprop Gam,
+          HOLogic.mk_Trueprop (Const (rel, @{typ "nat => nat => bool"})
+            $ (m2 $ r') $ (m1 $ l'))))
+      |> cterm_of thy
+    in
+      note_descent c m1 m2
+        (case try_proof (cgoal @{const_name HOL.less}) tac of
+           Solved thm => Less thm
+         | Stuck thm =>
+           (case try_proof (cgoal @{const_name HOL.less_eq}) tac of
+              Solved thm2 => LessEq (thm2, thm)
+            | Stuck thm2 =>
+              if prems_of thm2 = [HOLogic.Trueprop $ HOLogic.false_const]
+              then False thm2 else None (thm2, thm)
+            | _ => raise Match) (* FIXME *)
+         | _ => raise Match) D
+      end
+
+fun derive_descent thy tac c m1 m2 D =
+  derive_desc_aux thy tac c (dest_call D c) m1 m2 D
+
+(* all descents in one go *)
+fun derive_descents thy tac c D =
+  let val cdesc as (vs, p, l', q, r', Gam) = dest_call D c
+  in fold_product (derive_desc_aux thy tac c cdesc)
+       (get_measures D p) (get_measures D q) D
+  end
+
+fun CALLS tac i st =
+  if Thm.no_prems st then all_tac st
+  else case Thm.term_of (Thm.cprem_of st i) of
+    (_ $ (_ $ rel)) => tac (FundefLib.dest_binop_list @{const_name "op Un"} rel, i) st
+  |_ => no_tac st
+
+type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic
+
+fun TERMINATION ctxt tac =
+  SUBGOAL (fn (_ $ (Const (@{const_name "wf"}, wfT) $ rel), i) =>
+  let
+    val (T, _) = HOLogic.dest_prodT (HOLogic.dest_setT (domain_type wfT))
+  in
+    tac (create ctxt T rel) i
+  end)
+
+
+(* A tactic to convert open to closed termination goals *)
+local
+fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *)
+    let
+      val (vars, prop) = FundefLib.dest_all_all t
+      val (prems, concl) = Logic.strip_horn prop
+      val (lhs, rhs) = concl
+                         |> HOLogic.dest_Trueprop
+                         |> HOLogic.dest_mem |> fst
+                         |> HOLogic.dest_prod
+    in
+      (vars, prems, lhs, rhs)
+    end
+
+fun mk_pair_compr (T, qs, l, r, conds) =
+    let
+      val pT = HOLogic.mk_prodT (T, T)
+      val n = length qs
+      val peq = HOLogic.eq_const pT $ Bound n $ (HOLogic.pair_const T T $ l $ r)
+      val conds' = if null conds then [HOLogic.true_const] else conds
+    in
+      HOLogic.Collect_const pT $
+      Abs ("uu_", pT,
+           (foldr1 HOLogic.mk_conj (peq :: conds')
+            |> fold_rev (fn v => fn t => HOLogic.exists_const (fastype_of v) $ lambda v t) qs))
+    end
+
+in
+
+fun wf_union_tac st =
+    let
+      val thy = theory_of_thm st
+      val cert = cterm_of (theory_of_thm st)
+      val ((trueprop $ (wf $ rel)) :: ineqs) = prems_of st
+
+      fun mk_compr ineq =
+          let
+            val (vars, prems, lhs, rhs) = dest_term ineq
+          in
+            mk_pair_compr (fastype_of lhs, vars, lhs, rhs, map (ObjectLogic.atomize_term thy) prems)
+          end
+
+      val relation =
+          if null ineqs then
+              Const (@{const_name "{}"}, fastype_of rel)
+          else
+              foldr1 (HOLogic.mk_binop @{const_name "op Un"}) (map mk_compr ineqs)
+
+      fun solve_membership_tac i =
+          (EVERY' (replicate (i - 2) (rtac @{thm UnI2}))  (* pick the right component of the union *)
+          THEN' (fn j => TRY (rtac @{thm UnI1} j))
+          THEN' (rtac @{thm CollectI})                    (* unfold comprehension *)
+          THEN' (fn i => REPEAT (rtac @{thm exI} i))      (* Turn existentials into schematic Vars *)
+          THEN' ((rtac @{thm refl})                       (* unification instantiates all Vars *)
+                 ORELSE' ((rtac @{thm conjI})
+                          THEN' (rtac @{thm refl})
+                          THEN' (CLASET' blast_tac)))     (* Solve rest of context... not very elegant *)
+          ) i
+    in
+      ((PRIMITIVE (Drule.cterm_instantiate [(cert rel, cert relation)])
+      THEN ALLGOALS (fn i => if i = 1 then all_tac else solve_membership_tac i))) st
+    end
+
+
+end
+
+
+(* continuation passing repeat combinator *)
+fun REPEAT ttac cont err_cont =
+    ttac (fn D => fn i => (REPEAT ttac cont cont D i)) err_cont
+
+
+
+
+end