src/HOL/Data_Structures/AVL_Set.thy
changeset 61232 c46faf9762f7
child 61428 5e1938107371
equal deleted inserted replaced
61231:cc6969542f8d 61232:c46faf9762f7
       
     1 (*
       
     2 Author:     Tobias Nipkow
       
     3 Derived from AFP entry AVL.
       
     4 *)
       
     5 
       
     6 section "AVL Tree Implementation of Sets"
       
     7 
       
     8 theory AVL_Set
       
     9 imports Isin2
       
    10 begin
       
    11 
       
    (* An AVL tree is a binary tree whose Node constructor carries a nat
       annotation as its first argument (Node h l a r). The annotation caches
       the height of the subtree; the invariant avl below requires this. *)
    12 type_synonym 'a avl_tree = "('a,nat) tree"
       
    13 
       
    14 text {* Invariant: *}
       
    15 
       
    (* The AVL invariant. At every node the heights of the two subtrees differ
       by at most one, the stored annotation h equals the real height
       max (height l) (height r) + 1, and both subtrees satisfy the invariant
       recursively. Leaf is trivially balanced. *)
    16 fun avl :: "'a avl_tree \<Rightarrow> bool" where
       
    17 "avl Leaf = True" |
       
    18 "avl (Node h l a r) =
       
    19  ((height l = height r \<or> height l = height r + 1 \<or> height r = height l + 1) \<and> 
       
    20   h = max (height l) (height r) + 1 \<and> avl l \<and> avl r)"
       
    21 
       
    (* Height as recorded in the root annotation; 0 for Leaf. Unlike the
       recursive height, this is a single pattern match. It agrees with height
       on trees satisfying avl (see the [simp] lemma avl t ==> ht t = height t
       below). *)
    22 fun ht :: "'a avl_tree \<Rightarrow> nat" where
       
    23 "ht Leaf = 0" |
       
    24 "ht (Node h l a r) = h"
       
    25 
       
    (* Smart constructor: builds a Node whose annotation is computed from the
       children's cached heights (ht). It does not rebalance. The result is an
       avl tree only when l and r are avl and their heights differ by at most
       one (lemma avl_node). *)
    26 definition node :: "'a avl_tree \<Rightarrow> 'a \<Rightarrow> 'a avl_tree \<Rightarrow> 'a avl_tree" where
       
    27 "node l a r = Node (max (ht l) (ht r) + 1) l a r"
       
    28 
       
    (* Rebalancing constructor, used when the left subtree may have become too
       high (after insertion on the left, or deletion on the right).
       - If ht l = ht r + 2 and l = Node _ bl b br:
         * if ht bl < ht br: double rotation through br = Node _ cl c cr,
           giving node (node bl b cl) c (node cr a r);
         * otherwise: single right rotation, giving node bl b (node br a r).
       - Otherwise the node is built unchanged with node.
       The Leaf alternatives omitted from both case expressions are
       unreachable. ht l = ht r + 2 forces l to be a Node. ht bl < ht br forces
       ht br > 0, so br is a Node, because ht Leaf = 0. *)
    29 definition node_bal_l :: "'a avl_tree \<Rightarrow> 'a \<Rightarrow> 'a avl_tree \<Rightarrow> 'a avl_tree" where
       
    30 "node_bal_l l a r = (
       
    31   if ht l = ht r + 2 then (case l of 
       
    32     Node _ bl b br \<Rightarrow> (if ht bl < ht br
       
    33     then case br of
       
    34       Node _ cl c cr \<Rightarrow> node (node bl b cl) c (node cr a r)
       
    35     else node bl b (node br a r)))
       
    36   else node l a r)"
       
    37 
       
    (* Mirror image of node_bal_l, used when the right subtree may have become
       too high (after insertion on the right, or deletion on the left).
       - If ht r = ht l + 2 and r = Node _ bl b br:
         * if ht bl > ht br: double rotation through bl = Node _ cl c cr,
           giving node (node l a cl) c (node cr b br);
         * otherwise: single left rotation, giving node (node l a bl) b br.
       - Otherwise the node is built unchanged with node.
       The omitted Leaf alternatives are unreachable. ht r = ht l + 2 forces r
       to be a Node, and ht bl > ht br forces bl to be a Node. *)
    38 definition node_bal_r :: "'a avl_tree \<Rightarrow> 'a \<Rightarrow> 'a avl_tree \<Rightarrow> 'a avl_tree" where
       
    39 "node_bal_r l a r = (
       
    40   if ht r = ht l + 2 then (case r of
       
    41     Node _ bl b br \<Rightarrow> (if ht bl > ht br
       
    42     then case bl of
       
    43       Node _ cl c cr \<Rightarrow> node (node l a cl) c (node cr b br)
       
    44     else node (node l a bl) b br))
       
    45   else node l a r)"
       
    46 
       
    (* Insert x into an AVL tree. If the search reaches a node whose key is x,
       that subtree is returned unchanged. Otherwise x ends up in a new
       singleton node of height 1. Every node on the search path is rebuilt
       with node_bal_l or node_bal_r. These rotate when the updated child is
       now two levels higher than its sibling. *)
    47 fun insert :: "'a::order \<Rightarrow> 'a avl_tree \<Rightarrow> 'a avl_tree" where
       
    48 "insert x Leaf = Node 1 Leaf x Leaf" |
       
    49 "insert x (Node h l a r) = 
       
    50    (if x=a then Node h l a r
       
    51     else if x<a
       
    52       then node_bal_l (insert x l) a r
       
    53       else node_bal_r l a (insert x r))"
       
    54 
       
    (* Remove the rightmost node of a non-empty tree. Returns the remaining tree
       together with the removed key, which is the last element of the inorder
       traversal (lemma inorder_delete_maxD). Only the right subtree can shrink,
       so nodes on the way back up are rebuilt with node_bal_l.
       There is no equation for Leaf, so the result on Leaf is unspecified and
       callers must pass a non-empty tree. *)
    55 fun delete_max :: "'a avl_tree \<Rightarrow> 'a avl_tree * 'a" where
       
    56 "delete_max (Node _ l a Leaf) = (l,a)" |
       
    57 "delete_max (Node _ l a r) = (
       
    58   let (r',a') = delete_max r in
       
    59   (node_bal_l l a r', a'))"
       
    60 
       
    (* Induction rule for delete_max with readable case names, in equation
       order:
       - Leaf names the base equation, where the RIGHT subtree is Leaf;
       - Node names the recursive equation, where the right subtree is a Node
         (used as case (Node h l a rh rl b rr) in avl_delete_max).
       NOTE(review): any case the generated rule has for the unspecified
       pattern delete_max Leaf is not named here. Confirm against the rule. *)
    61 lemmas delete_max_induct = delete_max.induct[case_names Leaf Node]
       
    62 
       
    (* Remove the root key of a non-empty tree.
       - If either child is Leaf, the other child is the result.
       - Otherwise the root key is replaced by the rightmost key of the left
         subtree (its inorder-last element), obtained with delete_max. The
         node is rebuilt with node_bal_r, because the left subtree may have
         lost one level.
       There is no equation for Leaf, so the result on Leaf is unspecified. *)
    63 fun delete_root :: "'a avl_tree \<Rightarrow> 'a avl_tree" where
       
    64 "delete_root (Node h Leaf a r) = r" |
       
    65 "delete_root (Node h l a Leaf) = l" |
       
    66 "delete_root (Node h l a r) =
       
    67   (let (l', a') = delete_max l in node_bal_r l' a' r)"
       
    68 
       
    (* Case-analysis rule for delete_root, with case names in equation order:
       - Leaf_t: the left child is Leaf;
       - Node_Leaf: the right child is Leaf;
       - Node_Node: both children are Nodes (used by avl_delete_root and
         height_delete_root).
       NOTE(review): the generated rule may also contain a case for the
       unspecified pattern delete_root Leaf. TODO confirm. *)
    69 lemmas delete_root_cases = delete_root.cases[case_names Leaf_t Node_Leaf Node_Node]
       
    70 
       
    (* Delete x from an AVL tree.
       - Deleting from Leaf yields Leaf.
       - If x is the root key, the root is removed with delete_root.
       - Deleting in the left subtree can only lower it, so the node is rebuilt
         with node_bal_r, since the right side may now be two levels higher.
       - Symmetrically, node_bal_l is used after deleting on the right. *)
    71 fun delete :: "'a::order \<Rightarrow> 'a avl_tree \<Rightarrow> 'a avl_tree" where
       
    72 "delete _ Leaf = Leaf" |
       
    73 "delete x (Node h l a r) = (
       
    74    if x = a then delete_root (Node h l a r)
       
    75    else if x < a then node_bal_r (delete x l) a r
       
    76    else node_bal_l l a (delete x r))"
       
    77 
       
    78 
       
    79 subsection {* Functional Correctness Proofs *}
       
    80 
       
    81 text{* These proofs are very different from those in the AFP entry AVL. *}
       
    82 
       
    83 
       
    84 subsubsection "Proofs for insert"
       
    85 
       
    (* node_bal_l, including its rotations, preserves the inorder traversal. *)
    86 lemma inorder_node_bal_l:
       
    87   "inorder (node_bal_l l a r) = inorder l @ a # inorder r"
       
    88 by (auto simp: node_def node_bal_l_def split:tree.splits)
       
    89 
       
    (* node_bal_r, including its rotations, preserves the inorder traversal. *)
    90 lemma inorder_node_bal_r:
       
    91   "inorder (node_bal_r l a r) = inorder l @ a # inorder r"
       
    92 by (auto simp: node_def node_bal_r_def split:tree.splits)
       
    93 
       
    (* Functional correctness of insert. On a tree with a sorted inorder
       traversal, insert acts on the traversal like ins_list. Proved by
       structural induction, using the fact that the rebalancing constructors
       preserve the inorder traversal. *)
    94 theorem inorder_insert:
       
    95   "sorted(inorder t) \<Longrightarrow> inorder(insert x t) = ins_list x (inorder t)"
       
    96 by (induct t) 
       
    97    (auto simp: ins_list_simps inorder_node_bal_l inorder_node_bal_r)
       
    98 
       
    99 
       
   100 subsubsection "Proofs for delete"
       
   101 
       
   (* delete_max removes the last element of the inorder traversal. For a
      non-empty t, the remaining tree's traversal followed by the removed key
      equals the traversal of t. *)
   102 lemma inorder_delete_maxD:
       
   103   "\<lbrakk> delete_max t = (t',a); t \<noteq> Leaf \<rbrakk> \<Longrightarrow>
       
   104    inorder t' @ [a] = inorder t"
       
   105 by(induction t arbitrary: t' rule: delete_max.induct)
       
   106   (auto simp: inorder_node_bal_l split: prod.splits tree.split)
       
   107 
       
   (* Removing the root concatenates the children's inorder traversals. The
      induction follows delete_root's equations (delete_root.induct), with the
      node components generalized. *)
   108 lemma inorder_delete_root:
       
   109   "inorder (delete_root (Node h l a r)) = inorder l @ inorder r"
       
   110 by(induction "Node h l a r" arbitrary: l a r h rule: delete_root.induct)
       
   111   (auto simp: inorder_node_bal_r inorder_delete_maxD split: prod.splits)
       
   112 
       
   (* Functional correctness of delete. On a tree with a sorted inorder
      traversal, delete acts on the traversal like del_list. *)
   113 theorem inorder_delete:
       
   114   "sorted(inorder t) \<Longrightarrow> inorder (delete x t) = del_list x (inorder t)"
       
   115 by(induction t)
       
   116   (auto simp: del_list_simps inorder_node_bal_l inorder_node_bal_r
       
   117     inorder_delete_root inorder_delete_maxD split: prod.splits)
       
   118 
       
   119 
       
   120 subsubsection "Overall functional correctness"
       
   121 
       
   (* Instantiate the ordered-set specification Set_by_Ordered with the
      operations defined above. The well-formedness predicate wf is the
      constant True predicate, because functional correctness does not need
      the balance invariant. The invariant is treated separately in the next
      subsection. Goals 2, 3 and 4 are discharged by isin_set, inorder_insert
      and inorder_delete respectively. *)
   122 interpretation Set_by_Ordered
       
   123 where empty = Leaf and isin = isin and insert = insert and delete = delete
       
   124 and inorder = inorder and wf = "\<lambda>_. True"
       
   125 proof (standard, goal_cases)
       
   126   case 1 show ?case by simp
       
   127 next
       
   128   case 2 thus ?case by(simp add: isin_set)
       
   129 next
       
   130   case 3 thus ?case by(simp add: inorder_insert)
       
   131 next
       
   132   case 4 thus ?case by(simp add: inorder_delete)
       
   133 next
       
   134   case 5 thus ?case ..
       
   135 qed
       
   136 
       
   137 
       
   138 subsection {* AVL invariants *}
       
   139 
       
   140 text{* These are essentially the proofs from the AFP entry AVL. *}
       
   141 
       
   142 
       
   143 subsubsection {* Insertion maintains AVL balance *}
       
   144 
       
   (* From here on, the simplifier unfolds let-expressions (as used in
      delete_max and delete_root) automatically. *)
   145 declare Let_def [simp]
       
   146 
       
   (* On trees satisfying the invariant, the cached annotation ht coincides
      with the real height. It is declared [simp] so that the code-level ht in
      node, node_bal_l and node_bal_r is rewritten to height in the proofs
      below. *)
   147 lemma [simp]: "avl t \<Longrightarrow> ht t = height t"
       
   148 by (induct t) simp_all
       
   149 
       
   (* Height after a rebalancing step on the left. If the left subtree is
      exactly two levels higher than the right and both are balanced, the
      result has height height r + 2 or height r + 3. *)
   150 lemma height_node_bal_l:
       
   151   "\<lbrakk> height l = height r + 2; avl l; avl r \<rbrakk> \<Longrightarrow>
       
   152    height (node_bal_l l a r) = height r + 2 \<or>
       
   153    height (node_bal_l l a r) = height r + 3"
       
   154 by (cases l) (auto simp:node_def node_bal_l_def split:tree.split)
       
   155        
       
   (* Mirror of height_node_bal_l. If the right subtree is exactly two levels
      higher than the left and both are balanced, the result has height
      height l + 2 or height l + 3. *)
   156 lemma height_node_bal_r:
       
   157   "\<lbrakk> height r = height l + 2; avl l; avl r \<rbrakk> \<Longrightarrow>
       
   158    height (node_bal_r l a r) = height l + 2 \<or>
       
   159    height (node_bal_r l a r) = height l + 3"
       
   160 by (cases r) (auto simp add:node_def node_bal_r_def split:tree.split)
       
   161 
       
   (* Height of a tree built with the smart constructor node, expressed via
      the children's real heights. No avl premise is needed. *)
   162 lemma [simp]: "height(node l a r) = max (height l) (height r) + 1"
       
   163 by (simp add: node_def)
       
   164 
       
   (* The smart constructor node yields an avl tree when both children are
      avl and their heights differ by at most one. *)
   165 lemma avl_node:
       
   166   "\<lbrakk> avl l; avl r;
       
   167      height l = height r \<or> height l = height r + 1 \<or> height r = height l + 1
       
   168    \<rbrakk> \<Longrightarrow> avl(node l a r)"
       
   169 by (auto simp add:max_def node_def)
       
   170 
       
   (* When node_bal_l does not rebalance (height l is not height r + 2), the
      result is an ordinary node of height 1 + max (height l) (height r). *)
   171 lemma height_node_bal_l2:
       
   172   "\<lbrakk> avl l; avl r; height l \<noteq> height r + 2 \<rbrakk> \<Longrightarrow>
       
   173    height (node_bal_l l a r) = (1 + max (height l) (height r))"
       
   174 by (cases l, cases r) (simp_all add: node_bal_l_def)
       
   175 
       
   (* When node_bal_r does not rebalance (height r is not height l + 2), the
      result is an ordinary node of height 1 + max (height l) (height r). *)
   176 lemma height_node_bal_r2:
       
   177   "\<lbrakk> avl l;  avl r;  height r \<noteq> height l + 2 \<rbrakk> \<Longrightarrow>
       
   178    height (node_bal_r l a r) = (1 + max (height l) (height r))"
       
   179 by (cases l, cases r) (simp_all add: node_bal_r_def)
       
   180 
       
   (* node_bal_l restores the invariant when both children are avl, the left
      subtree is at most two levels higher than the right, and the right
      subtree is at most one level higher than the left. The proof splits on
      the shape of l and then on whether rebalancing is triggered.
      NOTE(review): the names bound by case (Node ln ll lr lh) do not follow
      the constructor's argument order Node h l a r. They are unused below;
      the proof refers only to the fact Node. *)
   181 lemma avl_node_bal_l: 
       
   182   assumes "avl l" "avl r" and "height l = height r \<or> height l = height r + 1
       
   183     \<or> height r = height l + 1 \<or> height l = height r + 2" 
       
   184   shows "avl(node_bal_l l a r)"
       
   185 proof(cases l)
       
   186   case Leaf
       
   187   with assms show ?thesis by (simp add: node_def node_bal_l_def)
       
   188 next
       
   189   case (Node ln ll lr lh)
       
   190   with assms show ?thesis
       
   191   proof(cases "height l = height r + 2")
       
   192     case True
       
   193     from True Node assms show ?thesis
       
   194       by (auto simp: node_bal_l_def intro!: avl_node split: tree.split) arith+
       
   195   next
       
   196     case False
       
   197     with assms show ?thesis by (simp add: avl_node node_bal_l_def)
       
   198   qed
       
   199 qed
       
   200 
       
   (* Mirror of avl_node_bal_l. node_bal_r restores the invariant when both
      children are avl, the right subtree is at most two levels higher than
      the left, and the left subtree is at most one level higher than the
      right.
      NOTE(review): the names bound by case (Node rn rl rr rh) do not follow
      the constructor's argument order Node h l a r. They are unused below;
      the proof refers only to the fact Node. *)
   201 lemma avl_node_bal_r: 
       
   202   assumes "avl l" and "avl r" and "height l = height r \<or> height l = height r + 1
       
   203     \<or> height r = height l + 1 \<or> height r = height l + 2" 
       
   204   shows "avl(node_bal_r l a r)"
       
   205 proof(cases r)
       
   206   case Leaf
       
   207   with assms show ?thesis by (simp add: node_def node_bal_r_def)
       
   208 next
       
   209   case (Node rn rl rr rh)
       
   210   with assms show ?thesis
       
   211   proof(cases "height r = height l + 2")
       
   212     case True
       
   213       from True Node assms show ?thesis
       
   214         by (auto simp: node_bal_r_def intro!: avl_node split: tree.split) arith+
       
   215   next
       
   216     case False
       
   217     with assms show ?thesis by (simp add: node_bal_r_def avl_node)
       
   218   qed
       
   219 qed
       
   220 
       
   221 (* The two properties below (balance and the height bound) are proved simultaneously, because the balance proof for a node needs the height bound for the recursive call: *)
       
   222 
       
   223 text{* Insertion maintains the AVL property: *}
       
   224 
       
   (* Insertion preserves the invariant and increases the height by at most
      one. Both claims are proved in a single induction on t, because the
      balance part for a node relies on the height bound for the recursive
      call (via avl_node_bal_l / avl_node_bal_r).
      In the height part, Node(2) and Node(4) serve as the height hypotheses
      for l and r respectively. This presumably reflects the IH order
      (l balance, l height, r balance, r height); re-check if the statement
      is reordered. *)
   225 theorem avl_insert_aux:
       
   226   assumes "avl t"
       
   227   shows "avl(insert x t)"
       
   228         "(height (insert x t) = height t \<or> height (insert x t) = height t + 1)"
       
   229 using assms
       
   230 proof (induction t)
         (* The Leaf goals are closed by the final qed simp_all. *)
       
   231   case (Node h l a r)
       
   232   case 1
       
   233   with Node show ?case
       
   234   proof(cases "x = a")
       
   235     case True
       
   236     with Node 1 show ?thesis by simp
       
   237   next
       
   238     case False
       
   239     with Node 1 show ?thesis 
       
   240     proof(cases "x<a")
       
   241       case True
       
   242       with Node 1 show ?thesis by (auto simp add:avl_node_bal_l)
       
   243     next
       
   244       case False
       
   245       with Node 1 `x\<noteq>a` show ?thesis by (auto simp add:avl_node_bal_r)
       
   246     qed
       
   247   qed
       
   248   case 2
       
   249   from 2 Node show ?case
       
   250   proof(cases "x = a")
       
   251     case True
       
   252     with Node 1 show ?thesis by simp
       
   253   next
       
   254     case False
       
   255     with Node 1 show ?thesis 
       
   256      proof(cases "x<a")
       
   257       case True
       
   258       with Node 2 show ?thesis
       
   259       proof(cases "height (insert x l) = height r + 2")
       
   260         case False with Node 2 `x < a` show ?thesis by (auto simp: height_node_bal_l2)
       
   261       next
       
   262         case True 
       
   263         hence "(height (node_bal_l (insert x l) a r) = height r + 2) \<or>
       
   264           (height (node_bal_l (insert x l) a r) = height r + 3)" (is "?A \<or> ?B")
       
   265           using Node 2 by (intro height_node_bal_l) simp_all
       
   266         thus ?thesis
       
   267         proof
       
   268           assume ?A
       
   269           with 2 `x < a` show ?thesis by (auto)
       
   270         next
       
   271           assume ?B
       
   272           with True 1 Node(2) `x < a` show ?thesis by (simp) arith
       
   273         qed
       
   274       qed
       
   275     next
       
   276       case False
       
   277       with Node 2 show ?thesis 
       
   278       proof(cases "height (insert x r) = height l + 2")
       
   279         case False
       
   280         with Node 2 `\<not>x < a` show ?thesis by (auto simp: height_node_bal_r2)
       
   281       next
       
   282         case True 
       
   283         hence "(height (node_bal_r l a (insert x r)) = height l + 2) \<or>
       
   284           (height (node_bal_r l a (insert x r)) = height l + 3)"  (is "?A \<or> ?B")
       
   285           using Node 2 by (intro height_node_bal_r) simp_all
       
   286         thus ?thesis 
       
   287         proof
       
   288           assume ?A
       
   289           with 2 `\<not>x < a` show ?thesis by (auto)
       
   290         next
       
   291           assume ?B
       
   292           with True 1 Node(4) `\<not>x < a` show ?thesis by (simp) arith
       
   293         qed
       
   294       qed
       
   295     qed
       
   296   qed
       
   297 qed simp_all
       
   298 
       
   299 
       
   300 subsubsection {* Deletion maintains AVL balance *}
       
   301 
       
   (* delete_max preserves the invariant and lowers the height by at most one.
      The induction uses delete_max_induct. Only the recursive Node case
      (right subtree is a Node) is proved explicitly, once per conclusion.
      The remaining cases are closed by qed auto; the premise x ~= Leaf rules
      out the empty tree. *)
   302 lemma avl_delete_max:
       
   303   assumes "avl x" and "x \<noteq> Leaf"
       
   304   shows "avl (fst (delete_max x))" "height x = height(fst (delete_max x)) \<or>
       
   305          height x = height(fst (delete_max x)) + 1"
       
   306 using assms
       
   307 proof (induct x rule: delete_max_induct)
       
   308   case (Node h l a rh rl b rr)
       
   309   case 1
       
   310   with Node have "avl l" "avl (fst (delete_max (Node rh rl b rr)))" by auto
       
   311   with 1 Node have "avl (node_bal_l l a (fst (delete_max (Node rh rl b rr))))"
       
   312     by (intro avl_node_bal_l) fastforce+
       
   313   thus ?case 
       
   314     by (auto simp: height_node_bal_l height_node_bal_l2
       
   315       linorder_class.max.absorb1 linorder_class.max.absorb2
       
   316       split:prod.split)
       
   317 next
       
   318   case (Node h l a rh rl b rr)
       
   319   case 2
       
   320   let ?r = "Node rh rl b rr"
       
   321   let ?r' = "fst (delete_max ?r)"
       
   322   from `avl x` Node 2 have "avl l" and "avl ?r" by simp_all
       
   323   thus ?case using Node 2 height_node_bal_l[of l ?r' a] height_node_bal_l2[of l ?r' a]
       
   324     apply (auto split:prod.splits simp del:avl.simps) by arith+
       
   325 qed auto
       
   326 
       
   (* delete_root preserves the invariant. Only the Node_Node case needs work.
      l' (l without its rightmost key) is avl and is at most one level lower
      than l. Hence r is at most two levels higher than l', and l' at most one
      level higher than r, which is exactly the premise of avl_node_bal_r.
      In the other cases the result is one of the (balanced) children; they
      are handled by qed simp_all. *)
   327 lemma avl_delete_root:
       
   328   assumes "avl t" and "t \<noteq> Leaf"
       
   329   shows "avl(delete_root t)" 
       
   330 using assms
       
   331 proof (cases t rule:delete_root_cases)
       
   332   case (Node_Node h lh ll ln lr n rh rl rn rr)
       
   333   let ?l = "Node lh ll ln lr"
       
   334   let ?r = "Node rh rl rn rr"
       
   335   let ?l' = "fst (delete_max ?l)"
       
   336   from `avl t` and Node_Node have "avl ?r" by simp
       
   337   from `avl t` and Node_Node have "avl ?l" by simp
       
   338   hence "avl(?l')" "height ?l = height(?l') \<or>
       
   339          height ?l = height(?l') + 1" by (rule avl_delete_max,simp)+
       
   340   with `avl t` Node_Node have "height ?l' = height ?r \<or> height ?l' = height ?r + 1
       
   341             \<or> height ?r = height ?l' + 1 \<or> height ?r = height ?l' + 2" by fastforce
       
   342   with `avl ?l'` `avl ?r` have "avl(node_bal_r ?l' (snd(delete_max ?l)) ?r)"
       
   343     by (rule avl_node_bal_r)
       
   344   with Node_Node show ?thesis by (auto split:prod.splits)
       
   345 qed simp_all
       
   346 
       
   (* delete_root lowers the height of a non-empty avl tree by at most one.
      In the Node_Node case the proof splits on whether node_bal_r has to
      rebalance (height ?r = height ?l' + 2). It then combines
      height_node_bal_r2 or height_node_bal_r with the bound l'_height on l'
      and the height t_height of t. *)
   347 lemma height_delete_root:
       
   348   assumes "avl t" and "t \<noteq> Leaf" 
       
   349   shows "height t = height(delete_root t) \<or> height t = height(delete_root t) + 1"
       
   350 using assms
       
   351 proof (cases t rule: delete_root_cases)
       
   352   case (Node_Node h lh ll ln lr n rh rl rn rr)
       
   353   let ?l = "Node lh ll ln lr"
       
   354   let ?r = "Node rh rl rn rr"
       
   355   let ?l' = "fst (delete_max ?l)"
       
   356   let ?t' = "node_bal_r ?l' (snd(delete_max ?l)) ?r"
       
   357   from `avl t` and Node_Node have "avl ?r" by simp
       
   358   from `avl t` and Node_Node have "avl ?l" by simp
       
   359   hence "avl(?l')"  by (rule avl_delete_max,simp)
       
   360   have l'_height: "height ?l = height ?l' \<or> height ?l = height ?l' + 1" using `avl ?l` by (intro avl_delete_max) auto
       
   361   have t_height: "height t = 1 + max (height ?l) (height ?r)" using `avl t` Node_Node by simp
       
   362   have "height t = height ?t' \<or> height t = height ?t' + 1" using  `avl t` Node_Node
       
   363   proof(cases "height ?r = height ?l' + 2")
       
   364     case False
       
   365     show ?thesis using l'_height t_height False by (subst  height_node_bal_r2[OF `avl ?l'` `avl ?r` False])+ arith
       
   366   next
       
   367     case True
       
   368     show ?thesis
       
   369     proof(cases rule: disjE[OF height_node_bal_r[OF True `avl ?l'` `avl ?r`, of "snd (delete_max ?l)"]])
       
   370       case 1
       
   371       thus ?thesis using l'_height t_height True by arith
       
   372     next
       
   373       case 2
       
   374       thus ?thesis using l'_height t_height True by arith
       
   375     qed
       
   376   qed
       
   377   thus ?thesis using Node_Node by (auto split:prod.splits)
       
   378 qed simp_all
       
   379 
       
   380 text{* Deletion maintains the AVL property: *}
       
   381 
       
   (* Deletion preserves the invariant and lowers the height by at most one.
      Proved by induction on t, with the same structure as avl_insert_aux.
      When x is the root key, the root case relies on avl_delete_root and
      height_delete_root. Otherwise the proof uses the balance and height
      lemmas for node_bal_r (deletion on the left) and node_bal_l (deletion
      on the right). *)
   382 theorem avl_delete_aux:
       
   383   assumes "avl t" 
       
   384   shows "avl(delete x t)" and "height t = (height (delete x t)) \<or> height t = height (delete x t) + 1"
       
   385 using assms
       
   386 proof (induct t)
         (* The Leaf goals are closed by the final qed simp_all. *)
       
   387   case (Node h l n r)
       
   388   case 1
       
   389   with Node show ?case
       
   390   proof(cases "x = n")
       
   391     case True
       
   392     with Node 1 show ?thesis by (auto simp:avl_delete_root)
       
   393   next
       
   394     case False
       
   395     with Node 1 show ?thesis 
       
   396     proof(cases "x<n")
       
   397       case True
       
   398       with Node 1 show ?thesis by (auto simp add:avl_node_bal_r)
       
   399     next
       
   400       case False
       
   401       with Node 1 `x\<noteq>n` show ?thesis by (auto simp add:avl_node_bal_l)
       
   402     qed
       
   403   qed
       
   404   case 2
       
   405   with Node show ?case
       
   406   proof(cases "x = n")
       
   407     case True
           (* The root is deleted: the height bound comes from height_delete_root. *)
       
   408     with 1 have "height (Node h l n r) = height(delete_root (Node h l n r))
       
   409       \<or> height (Node h l n r) = height(delete_root (Node h l n r)) + 1"
       
   410       by (subst height_delete_root,simp_all)
       
   411     with True show ?thesis by simp
       
   412   next
       
   413     case False
       
   414     with Node 1 show ?thesis 
       
   415      proof(cases "x<n")
       
   416       case True
       
   417       show ?thesis
       
   418       proof(cases "height r = height (delete x l) + 2")
       
   419         case False with Node 1 `x < n` show ?thesis by(auto simp: node_bal_r_def)
       
   420       next
       
   421         case True 
       
   422         hence "(height (node_bal_r (delete x l) n r) = height (delete x l) + 2) \<or>
       
   423           height (node_bal_r (delete x l) n r) = height (delete x l) + 3" (is "?A \<or> ?B")
       
   424           using Node 2 by (intro height_node_bal_r) auto
       
   425         thus ?thesis 
       
   426         proof
       
   427           assume ?A
       
   428           with `x < n` Node 2 show ?thesis by(auto simp: node_bal_r_def)
       
   429         next
       
   430           assume ?B
       
   431           with `x < n` Node 2 show ?thesis by(auto simp: node_bal_r_def)
       
   432         qed
       
   433       qed
       
   434     next
       
   435       case False
       
   436       show ?thesis
       
   437       proof(cases "height l = height (delete x r) + 2")
       
   438         case False with Node 1 `\<not>x < n` `x \<noteq> n` show ?thesis by(auto simp: node_bal_l_def)
       
   439       next
       
   440         case True 
       
   441         hence "(height (node_bal_l l n (delete x r)) = height (delete x r) + 2) \<or>
       
   442           height (node_bal_l l n (delete x r)) = height (delete x r) + 3" (is "?A \<or> ?B")
       
   443           using Node 2 by (intro height_node_bal_l) auto
       
   444         thus ?thesis 
       
   445         proof
       
   446           assume ?A
       
   447           with `\<not>x < n` `x \<noteq> n` Node 2 show ?thesis by(auto simp: node_bal_l_def)
       
   448         next
       
   449           assume ?B
       
   450           with `\<not>x < n` `x \<noteq> n` Node 2 show ?thesis by(auto simp: node_bal_l_def)
       
   451         qed
       
   452       qed
       
   453     qed
       
   454   qed
       
   455 qed simp_all
       
   456 
       
   457 end