--- a/src/HOL/Data_Structures/Sorting.thy Sun Apr 15 13:57:00 2018 +0100
+++ b/src/HOL/Data_Structures/Sorting.thy Sun Apr 15 17:48:07 2018 +0200
@@ -20,6 +20,9 @@
"sorted (xs@ys) = (sorted xs & sorted ys & (\<forall>x \<in> set xs. \<forall>y \<in> set ys. x\<le>y))"
by (induct xs) (auto)
+lemma sorted01: "length xs \<le> 1 \<Longrightarrow> sorted xs"
+by(auto simp: le_Suc_eq length_Suc_conv)
+
subsection "Insertion Sort"
@@ -134,7 +137,57 @@
declare msort.simps [simp del]
-(* We count the number of comparisons between list elements only *)
+subsubsection "Functional Correctness"
+
+lemma mset_merge: "mset(merge xs ys) = mset xs + mset ys"
+by(induction xs ys rule: merge.induct) auto
+
+lemma "mset (msort xs) = mset xs"
+proof(induction xs rule: msort.induct)
+ case (1 xs)
+ let ?n = "length xs"
+ let ?xs1 = "take (?n div 2) xs"
+ let ?xs2 = "drop (?n div 2) xs"
+ show ?case
+ proof cases
+ assume "?n \<le> 1"
+ thus ?thesis by(simp add: msort.simps[of xs])
+ next
+ assume "\<not> ?n \<le> 1"
+ hence "mset (msort xs) = mset (msort ?xs1) + mset (msort ?xs2)"
+ by(simp add: msort.simps[of xs] mset_merge)
+ also have "\<dots> = mset ?xs1 + mset ?xs2"
+ using \<open>\<not> ?n \<le> 1\<close> by(simp add: "1.IH")
+ also have "\<dots> = mset (?xs1 @ ?xs2)" by (simp del: append_take_drop_id)
+ also have "\<dots> = mset xs" by simp
+ finally show ?thesis .
+ qed
+qed
+
+lemma set_merge: "set(merge xs ys) = set xs \<union> set ys"
+by(induction xs ys rule: merge.induct) (auto)
+
+lemma sorted_merge: "sorted (merge xs ys) \<longleftrightarrow> (sorted xs \<and> sorted ys)"
+by(induction xs ys rule: merge.induct) (auto simp: set_merge)
+
+lemma "sorted (msort xs)"
+proof(induction xs rule: msort.induct)
+ case (1 xs)
+ let ?n = "length xs"
+ show ?case
+ proof cases
+ assume "?n \<le> 1"
+ thus ?thesis by(simp add: msort.simps[of xs] sorted01)
+ next
+ assume "\<not> ?n \<le> 1"
+ thus ?thesis using "1.IH"
+ by(simp add: sorted_merge msort.simps[of xs] mset_merge)
+ qed
+qed
+
+subsection "Time Complexity"
+
+text \<open>We only count the number of comparisons between list elements.\<close>
fun c_merge :: "'a::linorder list \<Rightarrow> 'a list \<Rightarrow> nat" where
"c_merge [] ys = 0" |
@@ -204,4 +257,110 @@
using c_msort_le[of xs k] apply (simp add: log_nat_power algebra_simps)
by (metis (mono_tags) numeral_power_eq_of_nat_cancel_iff of_nat_le_iff of_nat_mult)
+
+subsection "Bottom-Up Merge Sort"
+
+(* Exercise: make tail recursive *)
+fun merge_adj :: "('a::linorder) list list \<Rightarrow> 'a list list" where
+"merge_adj [] = []" |
+"merge_adj [xs] = [xs]" |
+"merge_adj (xs # ys # zss) = merge xs ys # merge_adj zss"
+
+text \<open>For the termination proof of \<open>merge_all\<close> below.\<close>
+lemma length_merge_adjacent[simp]: "length (merge_adj xs) = (length xs + 1) div 2"
+by (induction xs rule: merge_adj.induct) auto
+
+fun merge_all :: "('a::linorder) list list \<Rightarrow> 'a list" where
+"merge_all [] = undefined" |
+"merge_all [xs] = xs" |
+"merge_all xss = merge_all (merge_adj xss)"
+
+definition msort_bu :: "('a::linorder) list \<Rightarrow> 'a list" where
+"msort_bu xs = (if xs = [] then [] else merge_all (map (\<lambda>x. [x]) xs))"
+
+subsubsection "Functional Correctness"
+
+lemma mset_merge_adj:
+ "\<Union># image_mset mset (mset (merge_adj xss)) = \<Union># image_mset mset (mset xss)"
+by(induction xss rule: merge_adj.induct) (auto simp: mset_merge)
+
+lemma msec_merge_all:
+ "xss \<noteq> [] \<Longrightarrow> mset (merge_all xss) = (\<Union># (mset (map mset xss)))"
+by(induction xss rule: merge_all.induct) (auto simp: mset_merge mset_merge_adj)
+
+lemma sorted_merge_adj:
+ "\<forall>xs \<in> set xss. sorted xs \<Longrightarrow> \<forall>xs \<in> set (merge_adj xss). sorted xs"
+by(induction xss rule: merge_adj.induct) (auto simp: sorted_merge)
+
+lemma sorted_merge_all:
+ "\<forall>xs \<in> set xss. sorted xs \<Longrightarrow> xss \<noteq> [] \<Longrightarrow> sorted (merge_all xss)"
+apply(induction xss rule: merge_all.induct)
+using [[simp_depth_limit=3]] by (auto simp add: sorted_merge_adj)
+
+lemma sorted_msort_bu: "sorted (msort_bu xs)"
+by(simp add: msort_bu_def sorted_merge_all)
+
+lemma mset_msort: "mset (msort_bu xs) = mset xs"
+by(simp add: msort_bu_def msec_merge_all comp_def)
+
+subsection "Time Complexity"
+
+fun c_merge_adj :: "('a::linorder) list list \<Rightarrow> real" where
+"c_merge_adj [] = 0" |
+"c_merge_adj [x] = 0" |
+"c_merge_adj (x # y # zs) = c_merge x y + c_merge_adj zs"
+
+fun c_merge_all :: "('a::linorder) list list \<Rightarrow> real" where
+"c_merge_all [] = 0" |
+"c_merge_all [x] = 0" |
+"c_merge_all xs = c_merge_adj xs + c_merge_all (merge_adj xs)"
+
+definition c_msort_bu :: "('a::linorder) list \<Rightarrow> real" where
+"c_msort_bu xs = (if xs = [] then 0 else c_merge_all (map (\<lambda>x. [x]) xs))"
+
+lemma length_merge_adj:
+ "\<lbrakk> even(length xs); \<forall>x \<in> set xs. length x = m \<rbrakk> \<Longrightarrow> \<forall>x \<in> set (merge_adj xs). length x = 2*m"
+by(induction xs rule: merge_adj.induct) (auto simp: length_merge)
+
+lemma c_merge_adj: "\<forall>x \<in> set xs. length x = m \<Longrightarrow> c_merge_adj xs \<le> m * length xs"
+proof(induction xs rule: c_merge_adj.induct)
+ case 1 thus ?case by simp
+next
+ case 2 thus ?case by simp
+next
+ case (3 x y) thus ?case using c_merge_ub[of x y] by (simp add: algebra_simps)
+qed
+
+lemma c_merge_all: "\<lbrakk> \<forall>x \<in> set xs. length x = m; length xs = 2^k \<rbrakk>
+ \<Longrightarrow> c_merge_all xs \<le> m * k * 2^k"
+proof (induction xs arbitrary: k m rule: c_merge_all.induct)
+ case 1 thus ?case by simp
+next
+ case (2 x)
+ then show ?case by (simp)
+next
+ case (3 x y xs)
+ let ?xs = "x # y # xs"
+ let ?xs2 = "merge_adj ?xs"
+ obtain k' where k': "k = Suc k'" using "3.prems"(2)
+ by (metis length_Cons nat.inject nat_power_eq_Suc_0_iff nat.exhaust)
+ have "even (length xs)" using "3.prems"(2) even_Suc_Suc_iff by fastforce
+ from "3.prems"(1) length_merge_adj[OF this]
+ have 2: "\<forall>x \<in> set(merge_adj ?xs). length x = 2*m" by(auto simp: length_merge)
+ have 3: "length ?xs2 = 2 ^ k'" using "3.prems"(2) k' by auto
+ have 4: "length ?xs div 2 = 2 ^ k'"
+ using "3.prems"(2) k' by(simp add: power_eq_if[of 2 k] split: if_splits)
+ have "c_merge_all ?xs = c_merge_adj ?xs + c_merge_all ?xs2" by simp
+ also have "\<dots> \<le> m * 2^k + c_merge_all ?xs2"
+ using "3.prems"(2) c_merge_adj[OF "3.prems"(1)] by (auto simp: algebra_simps)
+ also have "\<dots> \<le> m * 2^k + (2*m) * k' * 2^k'"
+ using "3.IH"[OF 2 3] by simp
+ also have "\<dots> = m * k * 2^k"
+ using k' by (simp add: algebra_simps)
+ finally show ?case .
+qed
+
+corollary c_msort_bu: "length xs = 2 ^ k \<Longrightarrow> c_msort_bu xs \<le> k * 2 ^ k"
+using c_merge_all[of "map (\<lambda>x. [x]) xs" 1] by (simp add: c_msort_bu_def)
+
end