(* Theory Set2_Join_RBT *)
(* Author: Tobias Nipkow *)

section "Join-Based Implementation of Sets via RBTs"

theory Set2_Join_RBT
imports
  Set2_Join
  RBT_Set
begin

subsection "Code"

text ‹
Function ‹joinL› joins two trees (and an element).
Precondition: @{prop "bheight l ≤ bheight r"}.
Method:
Descend along the left spine of ‹r›
until you find a subtree with the same ‹bheight› as ‹l›,
then combine them into a new red node.
›
(* joinL l x r: combine l, the element x and r into one tree.
   Intended precondition: bheight l ≤ bheight r.
   - Equal black heights: the three parts become a single red node R l x r.
   - Otherwise take one step down the left spine of r and recurse on r's left
     subtree. Below a black root the recursive result is rebalanced with baliL;
     below a red root the red node is rebuilt around the recursive result
     without rebalancing.
   The case expression has no Leaf branch. If the heights differ and r is a Leaf,
   the result is therefore unspecified, so lemmas about joinL carry the
   precondition. *)
fun joinL :: "'a rbt ⇒ 'a ⇒ 'a rbt ⇒ 'a rbt" where
"joinL l x r =
  (if bheight l = bheight r then R l x r
   else case r of
     B l' x' r' ⇒ baliL (joinL l x l') x' r' |
     R l' x' r' ⇒ R (joinL l x l') x' r')"

(* joinR l x r: mirror image of joinL, descending the right spine of l.
   Intended domain: bheight l ≥ bheight r.
   The stopping test deliberately uses ≤ instead of the dual =. This
   "totalizes" the function: lemmas such as set_joinR and inorder_joinR
   below hold without any precondition.
   In the else branch bheight l > bheight r holds, so l is a Node there,
   presumably because bheight Leaf = 0 (defined in RBT_Set, not visible here). *)
fun joinR :: "'a rbt ⇒ 'a ⇒ 'a rbt ⇒ 'a rbt" where
"joinR l x r =
  (if bheight l ≤ bheight r then R l x r
   else case l of
     B l' x' r' ⇒ baliR l' x' (joinR r' x r) |
     R l' x' r' ⇒ R l' x' (joinR r' x r))"

(* join l x r: the join operation supplied to locale Set2_Join in the
   interpretation at the end of this theory.
   It compares the black heights and descends the spine of the taller tree:
   joinR if l is taller, joinL if r is taller. In both cases the root of the
   result is painted black. Equal heights yield the black node B l x r directly.
   Each helper is therefore only called inside its intended domain. *)
fun join :: "'a rbt ⇒ 'a ⇒ 'a rbt ⇒ 'a rbt" where
"join l x r =
  (if bheight l > bheight r
   then paint Black (joinR l x r)
   else if bheight l < bheight r
   then paint Black (joinL l x r)
   else B l x r)"

(* Remove the recursive defining equations of joinL and joinR from the simp set.
   Each right-hand side mentions the function again, so unfolding is kept under
   manual control (presumably to stop the simplifier from unfolding them
   repeatedly). The proofs below unfold exactly one step on demand via
   joinL.simps[of l x r] / joinR.simps[of l x r]. *)
declare joinL.simps[simp del]
declare joinR.simps[simp del]

text ‹
One would expect @{const joinR} to be completely dual to @{const joinL}.
Thus the condition should be @{prop"bheight l = bheight r"}. Instead, we have
totalized the function by weakening the condition to @{prop"bheight l ≤ bheight r"}.
On the intended domain (@{prop "bheight l ≥ bheight r"})
the two versions behave exactly the same, including complexity. Thus from a programmer's
perspective they are equivalent. However, not from a verifier's perspective:
the total version of @{const joinR} is easier
to reason about because lemmas about it may not require preconditions. In particular
@{prop"set_tree (joinR l x r) = set_tree l ∪ {x} ∪ set_tree r"}
is provable outright and hence also
@{prop"set_tree (join l x r) = set_tree l ∪ {x} ∪ set_tree r"}.
This is necessary because locale @{locale Set2_Join} unconditionally assumes
exactly that. Adding preconditions to this assumption would significantly complicate
the proofs within @{locale Set2_Join}, which we want to avoid.

Why not work with the partial version of @{const joinR} and add the precondition
@{prop "bheight l ≥ bheight r"} to lemmas about @{const joinR}? After all, that is how
we worked with @{const joinL}, and @{const join} ensures that @{const joinL} and @{const joinR}
are only called under the respective precondition. But function @{const bheight}
makes the difference: it descends along the left spine, just like @{const joinL}.
Function @{const joinR}, however, descends along the right spine and thus @{const bheight}
may change all the time. Thus we would need the further precondition @{prop "invh l"}.
This is what we really wanted to avoid in order to satisfy the unconditional assumption
in @{locale Set2_Join}.
›

subsection "Properties"

subsubsection "Color and height invariants"

(* Color invariant for joinL.
   Assumptions: both inputs satisfy invc, and the joinL precondition holds.
   Conclusions:
   - The result satisfies invc2, the weaker color invariant from RBT_Set
     (presumably invc except possibly at the root; confirm there).
   - If in addition the heights differ and r has a black root, the result
     satisfies the full invc.
   Proof: induction along joinL's recursion. The single case is closed by
   unfolding joinL once and splitting on tree shapes and the if. *)
lemma invc2_joinL:
 "⟦ invc l; invc r; bheight l ≤ bheight r ⟧ ⟹
  invc2 (joinL l x r)
  ∧ (bheight l ≠ bheight r ∧ color r = Black ⟶ invc(joinL l x r))"
proof (induct l x r rule: joinL.induct)
  case (1 l x r) thus ?case
    by(auto simp: invc_baliL invc2I joinL.simps[of l x r] split!: tree.splits if_splits)
qed

(* Color invariant for joinR, dual to invc2_joinL.
   Unlike the joinL version it also assumes invh for both inputs. As the text
   above explains, joinR descends the right spine while bheight follows the left
   spine, so the heights can only be related along the way via invh. *)
lemma invc2_joinR:
  "⟦ invc l; invh l; invc r; invh r; bheight l ≥ bheight r ⟧ ⟹
  invc2 (joinR l x r)
  ∧ (bheight l ≠ bheight r ∧ color l = Black ⟶ invc(joinR l x r))"
proof (induct l x r rule: joinR.induct)
  case (1 l x r) thus ?case
    by(fastforce simp: invc_baliR invc2I joinR.simps[of l x r] split!: tree.splits if_splits)
qed

(* joinL preserves the black height of the taller input r
   (given invh on both inputs and the joinL precondition). *)
lemma bheight_joinL:
  "⟦ invh l; invh r; bheight l ≤ bheight r ⟧ ⟹ bheight (joinL l x r) = bheight r"
proof (induct l x r rule: joinL.induct)
  case (1 l x r) thus ?case
    by(auto simp: bheight_baliL joinL.simps[of l x r] split!: tree.split)
qed

(* joinL preserves the black-height invariant invh.
   bheight_joinL supplies the height of the recursive result, which is needed
   to apply invh_baliL in the black-root case. *)
lemma invh_joinL:
  "⟦ invh l;  invh r;  bheight l ≤ bheight r ⟧ ⟹ invh (joinL l x r)"
proof (induct l x r rule: joinL.induct)
  case (1 l x r) thus ?case
    by(auto simp: invh_baliL bheight_joinL joinL.simps[of l x r] split!: tree.split color.split)
qed

(* baliR applied to two subtrees of equal black height produces a tree whose
   black height is one larger. Used in bheight_joinR below.
   NOTE(review): RBT_Set may already provide a lemma of this name; if so, this
   local copy only shadows it and could be dropped. *)
lemma bheight_baliR:
  "bheight l = bheight r ⟹ bheight (baliR l a r) = Suc (bheight l)"
by (cases "(l,a,r)" rule: baliR.cases) auto

(* joinR preserves the black height of the taller input l
   (given invh on both inputs and bheight l ≥ bheight r). *)
lemma bheight_joinR:
  "⟦ invh l;  invh r;  bheight l ≥ bheight r ⟧ ⟹ bheight (joinR l x r) = bheight l"
proof (induct l x r rule: joinR.induct)
  case (1 l x r) thus ?case
    by(fastforce simp: bheight_baliR joinR.simps[of l x r] split!: tree.split)
qed

(* joinR preserves the black-height invariant invh. Dual of invh_joinL;
   bheight_joinR provides the height of the recursive result. *)
lemma invh_joinR:
  "⟦ invh l; invh r; bheight l ≥ bheight r ⟧ ⟹ invh (joinR l x r)"
proof (induct l x r rule: joinR.induct)
  case (1 l x r) thus ?case
    by(fastforce simp: invh_baliR bheight_joinR joinR.simps[of l x r]
        split!: tree.split color.split)
qed

(* unused *)
(* Joining two trees that satisfy invc and invh yields an rbt.
   rbt_def and color_paint_Black are unfolded, presumably because rbt also
   requires a black root (confirm in RBT_Set). The locale interpretation at the
   end of this theory only needs invc ∧ invh, which is why this lemma is unused. *)
lemma rbt_join: "⟦ invc l; invh l; invc r; invh r ⟧ ⟹ rbt(join l x r)"
by(simp add: invc2_joinL invc2_joinR invc_paint_Black invh_joinL invh_joinR invh_paint rbt_def
    color_paint_Black)

text ‹To make sure the black height is not increased unnecessarily:›

(* Painting the root black increases the black height by at most one. *)
lemma bheight_paint_Black: "bheight(paint Black t) ≤ bheight t + 1"
by(cases t) auto

(* For rbt inputs, the black height of join's result exceeds that of the taller
   input by at most one.
   The bheight lemmas are instantiated at join's concrete arguments. The
   instantiation order [of l r x] follows the order in which the variables first
   occur in those lemmas. auto then handles the case split on join's ifs. *)
lemma "⟦ rbt l; rbt r ⟧ ⟹ bheight(join l x r) ≤ max (bheight l) (bheight r) + 1"
using bheight_paint_Black[of "joinL l x r"] bheight_paint_Black[of "joinR l x r"]
  bheight_joinL[of l r x] bheight_joinR[of l r x]
by(auto simp: max_def rbt_def)


subsubsection "Inorder properties"

text "Currently unused. Instead @{const set_tree} and @{const bst} properties are proved directly."

(* The inorder traversal of joinL is the traversal of l, then x, then the
   traversal of r. The joinL precondition is still required: outside it,
   joinL's case expression has no Leaf branch. *)
lemma inorder_joinL: "bheight l ≤ bheight r ⟹ inorder(joinL l x r) = inorder l @ x # inorder r"
proof(induction l x r rule: joinL.induct)
  case (1 l x r)
  thus ?case by(auto simp: inorder_baliL joinL.simps[of l x r] split!: tree.splits color.splits)
qed

(* Inorder traversal of joinR. No precondition is needed because joinR was
   totalized (its stopping test uses ≤). *)
lemma inorder_joinR:
  "inorder(joinR l x r) = inorder l @ x # inorder r"
proof(induction l x r rule: joinR.induct)
  case (1 l x r)
  thus ?case by (force simp: inorder_baliR joinR.simps[of l x r] split!: tree.splits color.splits)
qed

(* Inorder traversal of join, unconditionally.
   Proof: split on join's ifs and on tree shapes and colors.
   dest!: arg_cong[where f = inorder] turns equations between trees arising from
   the splits into equations between their inorder traversals, presumably so that
   they can be combined with inorder_paint. *)
lemma "inorder(join l x r) = inorder l @ x # inorder r"
by(auto simp: inorder_joinL inorder_joinR inorder_paint split!: tree.splits color.splits if_splits
      dest!: arg_cong[where f = inorder])


subsubsection "Set and bst properties"

(* Rebalancing with baliL does not change the set of elements. *)
lemma set_baliL:
  "set_tree(baliL l a r) = set_tree l ∪ {a} ∪ set_tree r"
by(cases "(l,a,r)" rule: baliL.cases) (auto)

(* Element set of joinL under its precondition. *)
lemma set_joinL:
  "bheight l ≤ bheight r ⟹ set_tree (joinL l x r) = set_tree l ∪ {x} ∪ set_tree r"
proof(induction l x r rule: joinL.induct)
  case (1 l x r)
  thus ?case by(auto simp: set_baliL joinL.simps[of l x r] split!: tree.splits color.splits)
qed

(* Rebalancing with baliR does not change the set of elements. *)
lemma set_baliR:
  "set_tree(baliR l a r) = set_tree l ∪ {a} ∪ set_tree r"
by(cases "(l,a,r)" rule: baliR.cases) (auto)

(* Element set of joinR, with no precondition. This unconditional form is what
   makes set_join below provable outright. *)
lemma set_joinR:
  "set_tree (joinR l x r) = set_tree l ∪ {x} ∪ set_tree r"
proof(induction l x r rule: joinR.induct)
  case (1 l x r)
  thus ?case by(force simp: set_baliR joinR.simps[of l x r] split!: tree.splits color.splits)
qed

(* Recoloring the root does not change the element set. *)
lemma set_paint: "set_tree (paint c t) = set_tree t"
by (cases t) auto

(* Element set of join, unconditionally. This discharges case 1 of the
   Set2_Join interpretation at the end of this theory. simp splits join's ifs;
   in the joinL branch, the precondition of set_joinL follows from the branch
   conditions. *)
lemma set_join: "set_tree (join l x r) = set_tree l ∪ {x} ∪ set_tree r"
by(simp add: set_joinL set_joinR set_paint)

(* baliL yields a BST when l and r are BSTs and separated by k. *)
lemma bst_baliL:
  "⟦bst l; bst r; ∀x∈set_tree l. x < k; ∀x∈set_tree r. k < x⟧
   ⟹ bst (baliL l k r)"
by(cases "(l,k,r)" rule: baliL.cases) (auto simp: ball_Un)

(* baliR yields a BST when l and r are BSTs and separated by k. *)
lemma bst_baliR:
  "⟦bst l; bst r; ∀x∈set_tree l. x < k; ∀x∈set_tree r. k < x⟧
   ⟹ bst (baliR l k r)"
by(cases "(l,k,r)" rule: baliR.cases) (auto simp: ball_Un)

(* joinL yields a BST when l and r are BSTs, everything in l is below k,
   everything in r is above k, and the joinL precondition holds.
   set_joinL is needed to describe the elements of the recursive result
   before bst_baliL can be applied. *)
lemma bst_joinL:
  "⟦bst l; bst r; ∀x∈set_tree l. x < k; ∀y∈set_tree r. k < y; bheight l ≤ bheight r⟧
  ⟹ bst (joinL l k r)"
proof(induction l k r rule: joinL.induct)
  case (1 l x r)
  thus ?case
    by(auto simp: set_baliL joinL.simps[of l x r] set_joinL ball_Un intro!: bst_baliL
        split!: tree.splits color.splits)
qed

(* joinR yields a BST under the same separation assumptions as bst_joinL, but
   without a height precondition, since joinR is total. *)
lemma bst_joinR:
  "⟦bst l; bst r; ∀x∈set_tree l. x < k; ∀y∈set_tree r. k < y ⟧
  ⟹ bst (joinR l k r)"
proof(induction l k r rule: joinR.induct)
  case (1 l x r)
  thus ?case
    by(auto simp: set_baliR joinR.simps[of l x r] set_joinR ball_Un intro!: bst_baliR
        split!: tree.splits color.splits)
qed

(* Recoloring the root does not affect the BST property. *)
lemma bst_paint: "bst (paint c t) = bst t"
by(cases t) auto

(* join yields a BST when its arguments are BSTs separated by k.
   This discharges case 2 of the Set2_Join interpretation below. *)
lemma bst_join:
  "⟦bst l; bst r; ∀x∈set_tree l. x < k; ∀y∈set_tree r. k < y ⟧
  ⟹ bst (join l k r)"
by(auto simp: bst_paint bst_joinL bst_joinR)


subsubsection "Interpretation of @{locale Set2_Join} with Red-Black Tree"

(* Instantiate the generic join-based set implementation with red-black trees.
   The join operation is join above; the invariant is invc ∧ invh (no black-root
   requirement). The defines clause introduces global constants for the
   operations derived inside the locale (insert, delete, split, join2,
   split_min). The goal_cases are the locale's assumptions in declaration order;
   see Set2_Join for their exact statements. *)
global_interpretation RBT: Set2_Join
where join = join and inv = "λt. invc t ∧ invh t"
defines insert_rbt = RBT.insert and delete_rbt = RBT.delete and split_rbt = RBT.split
and join2_rbt = RBT.join2 and split_min_rbt = RBT.split_min
proof (standard, goal_cases)
  (* element set of join *)
  case 1 show ?case by (rule set_join)
next
  (* BST property of join *)
  case 2 thus ?case by (rule bst_join)
next
  case 3 show ?case by simp
next
  (* invariant preservation by join, via the color and height lemmas above *)
  case 4 thus ?case
    by (simp add: invc2_joinL invc2_joinR invc_paint_Black invh_joinL invh_joinR invh_paint)
next
  case 5 thus ?case by simp
qed

text ‹The invariant does not guarantee that the root node is black. This is not required
to guarantee that the height is logarithmic in the size --- Exercise.›

end