added bottom-up merge sort
author nipkow
Sun, 15 Apr 2018 17:48:07 +0200
changeset 67983 487685540a51
parent 67982 7643b005b29a
child 67985 7811748de271
added bottom-up merge sort
src/HOL/Data_Structures/Sorting.thy
--- a/src/HOL/Data_Structures/Sorting.thy	Sun Apr 15 13:57:00 2018 +0100
+++ b/src/HOL/Data_Structures/Sorting.thy	Sun Apr 15 17:48:07 2018 +0200
@@ -20,6 +20,9 @@
   "sorted (xs@ys) = (sorted xs & sorted ys & (\<forall>x \<in> set xs. \<forall>y \<in> set ys. x\<le>y))"
 by (induct xs) (auto)
 
+lemma sorted01: "length xs \<le> 1 \<Longrightarrow> sorted xs"
+by(auto simp: le_Suc_eq length_Suc_conv)
+
 
 subsection "Insertion Sort"
 
@@ -134,7 +137,57 @@
 
 declare msort.simps [simp del]
 
-(* We count the number of comparisons between list elements only *)
+subsubsection "Functional Correctness"
+
+lemma mset_merge: "mset(merge xs ys) = mset xs + mset ys"
+by(induction xs ys rule: merge.induct) auto
+
+lemma "mset (msort xs) = mset xs"
+proof(induction xs rule: msort.induct)
+  case (1 xs)
+  let ?n = "length xs"
+  let ?xs1 = "take (?n div 2) xs"
+  let ?xs2 = "drop (?n div 2) xs"
+  show ?case
+  proof cases
+    assume "?n \<le> 1"
+    thus ?thesis by(simp add: msort.simps[of xs])
+  next
+    assume "\<not> ?n \<le> 1"
+    hence "mset (msort xs) = mset (msort ?xs1) + mset (msort ?xs2)"
+      by(simp add: msort.simps[of xs] mset_merge)
+    also have "\<dots> = mset ?xs1 + mset ?xs2"
+      using \<open>\<not> ?n \<le> 1\<close> by(simp add: "1.IH")
+    also have "\<dots> = mset (?xs1 @ ?xs2)" by (simp del: append_take_drop_id)
+    also have "\<dots> = mset xs" by simp
+    finally show ?thesis .
+  qed
+qed
+
+lemma set_merge: "set(merge xs ys) = set xs \<union> set ys"
+by(induction xs ys rule: merge.induct) (auto)
+
+lemma sorted_merge: "sorted (merge xs ys) \<longleftrightarrow> (sorted xs \<and> sorted ys)"
+by(induction xs ys rule: merge.induct) (auto simp: set_merge)
+
+lemma "sorted (msort xs)"
+proof(induction xs rule: msort.induct)
+  case (1 xs)
+  let ?n = "length xs"
+  show ?case
+  proof cases
+    assume "?n \<le> 1"
+    thus ?thesis by(simp add: msort.simps[of xs] sorted01)
+  next
+    assume "\<not> ?n \<le> 1"
+    thus ?thesis using "1.IH"
+      by(simp add: sorted_merge  msort.simps[of xs] mset_merge)
+  qed
+qed
+
+subsubsection "Time Complexity"
+
+text \<open>We only count the number of comparisons between list elements.\<close>
 
 fun c_merge :: "'a::linorder list \<Rightarrow> 'a list \<Rightarrow> nat" where
 "c_merge [] ys = 0" |
@@ -204,4 +257,110 @@
 using c_msort_le[of xs k] apply (simp add: log_nat_power algebra_simps)
 by (metis (mono_tags) numeral_power_eq_of_nat_cancel_iff of_nat_le_iff of_nat_mult)
 
+
+subsection "Bottom-Up Merge Sort"
+
+(* Exercise: make tail recursive *)
+fun merge_adj :: "('a::linorder) list list \<Rightarrow> 'a list list" where
+"merge_adj [] = []" |
+"merge_adj [xs] = [xs]" |
+"merge_adj (xs # ys # zss) = merge xs ys # merge_adj zss"
+
+text \<open>For the termination proof of \<open>merge_all\<close> below.\<close>
+lemma length_merge_adjacent[simp]: "length (merge_adj xs) = (length xs + 1) div 2"
+by (induction xs rule: merge_adj.induct) auto
+
+fun merge_all :: "('a::linorder) list list \<Rightarrow> 'a list" where
+"merge_all [] = undefined" |
+"merge_all [xs] = xs" |
+"merge_all xss = merge_all (merge_adj xss)"
+
+definition msort_bu :: "('a::linorder) list \<Rightarrow> 'a list" where
+"msort_bu xs = (if xs = [] then [] else merge_all (map (\<lambda>x. [x]) xs))"
+
+subsubsection "Functional Correctness"
+
+lemma mset_merge_adj:
+  "\<Union># image_mset mset (mset (merge_adj xss)) = \<Union># image_mset mset (mset xss)"
+by(induction xss rule: merge_adj.induct) (auto simp: mset_merge)
+
+lemma msec_merge_all:
+  "xss \<noteq> [] \<Longrightarrow> mset (merge_all xss) = (\<Union># (mset (map mset xss)))"
+by(induction xss rule: merge_all.induct) (auto simp: mset_merge mset_merge_adj)
+
+lemma sorted_merge_adj:
+  "\<forall>xs \<in> set xss. sorted xs \<Longrightarrow> \<forall>xs \<in> set (merge_adj xss). sorted xs"
+by(induction xss rule: merge_adj.induct) (auto simp: sorted_merge)
+
+lemma sorted_merge_all:
+  "\<forall>xs \<in> set xss. sorted xs \<Longrightarrow> xss \<noteq> [] \<Longrightarrow> sorted (merge_all xss)"
+apply(induction xss rule: merge_all.induct)
+using [[simp_depth_limit=3]] by (auto simp add: sorted_merge_adj)
+
+lemma sorted_msort_bu: "sorted (msort_bu xs)"
+by(simp add: msort_bu_def sorted_merge_all)
+
+lemma mset_msort: "mset (msort_bu xs) = mset xs"
+by(simp add: msort_bu_def msec_merge_all comp_def)
+
+subsubsection "Time Complexity"
+
+fun c_merge_adj :: "('a::linorder) list list \<Rightarrow> real" where
+"c_merge_adj [] = 0" |
+"c_merge_adj [x] = 0" |
+"c_merge_adj (x # y # zs) = c_merge x y + c_merge_adj zs"
+
+fun c_merge_all :: "('a::linorder) list list \<Rightarrow> real" where
+"c_merge_all [] = 0" |
+"c_merge_all [x] = 0" |
+"c_merge_all xs = c_merge_adj xs + c_merge_all (merge_adj xs)"
+
+definition c_msort_bu :: "('a::linorder) list \<Rightarrow> real" where
+"c_msort_bu xs = (if xs = [] then 0 else c_merge_all (map (\<lambda>x. [x]) xs))"
+
+lemma length_merge_adj:
+  "\<lbrakk> even(length xs); \<forall>x \<in> set xs. length x = m \<rbrakk> \<Longrightarrow> \<forall>x \<in> set (merge_adj xs). length x = 2*m"
+by(induction xs rule: merge_adj.induct) (auto simp: length_merge)
+
+lemma c_merge_adj: "\<forall>x \<in> set xs. length x = m \<Longrightarrow> c_merge_adj xs \<le> m * length xs"
+proof(induction xs rule: c_merge_adj.induct)
+  case 1 thus ?case by simp
+next
+  case 2 thus ?case by simp
+next
+  case (3 x y) thus ?case using c_merge_ub[of x y] by (simp add: algebra_simps)
+qed
+
+lemma c_merge_all: "\<lbrakk> \<forall>x \<in> set xs. length x = m; length xs = 2^k \<rbrakk>
+  \<Longrightarrow> c_merge_all xs \<le> m * k * 2^k"
+proof (induction xs arbitrary: k m rule: c_merge_all.induct)
+  case 1 thus ?case by simp
+next
+  case (2 x)
+  then show ?case by (simp)
+next
+  case (3 x y xs)
+  let ?xs = "x # y # xs"
+  let ?xs2 = "merge_adj ?xs"
+  obtain k' where k': "k = Suc k'" using "3.prems"(2)
+    by (metis length_Cons nat.inject nat_power_eq_Suc_0_iff nat.exhaust)
+  have "even (length xs)" using "3.prems"(2) even_Suc_Suc_iff by fastforce
+  from "3.prems"(1) length_merge_adj[OF this]
+  have 2: "\<forall>x \<in> set(merge_adj ?xs). length x = 2*m" by(auto simp: length_merge)
+  have 3: "length ?xs2 = 2 ^ k'" using "3.prems"(2) k' by auto
+  have 4: "length ?xs div 2 = 2 ^ k'"
+    using "3.prems"(2) k' by(simp add: power_eq_if[of 2 k] split: if_splits)
+  have "c_merge_all ?xs = c_merge_adj ?xs + c_merge_all ?xs2" by simp
+  also have "\<dots> \<le> m * 2^k + c_merge_all ?xs2"
+    using "3.prems"(2) c_merge_adj[OF "3.prems"(1)] by (auto simp: algebra_simps)
+  also have "\<dots> \<le> m * 2^k + (2*m) * k' * 2^k'"
+    using "3.IH"[OF 2 3] by simp
+  also have "\<dots> = m * k * 2^k"
+    using k' by (simp add: algebra_simps)
+  finally show ?case .
+qed
+
+corollary c_msort_bu: "length xs = 2 ^ k \<Longrightarrow> c_msort_bu xs \<le> k * 2 ^ k"
+using c_merge_all[of "map (\<lambda>x. [x]) xs" 1] by (simp add: c_msort_bu_def)
+
 end