generic merge sort
author haftmann
Wed Nov 07 11:08:10 2018 +0000 (6 months ago)
changeset 69250 1011f0b46af7
parent 69249 27423819534c
child 69251 d240598e8637
generic merge sort
src/HOL/Library/Sorting_Algorithms.thy
     1.1 --- a/src/HOL/Library/Sorting_Algorithms.thy	Tue Nov 06 15:06:30 2018 +0100
     1.2 +++ b/src/HOL/Library/Sorting_Algorithms.thy	Wed Nov 07 11:08:10 2018 +0000
     1.3 @@ -392,9 +392,9 @@
     1.4  lemma sort_by_quicksort_rec:
     1.5    "sort cmp xs = sort cmp [x\<leftarrow>xs. compare cmp x (xs ! (length xs div 2)) = Less]
     1.6      @ stable_segment cmp (xs ! (length xs div 2)) xs
     1.7 -    @ sort cmp [x\<leftarrow>xs. compare cmp x (xs ! (length xs div 2)) = Greater]" (is "sort cmp ?lhs = ?rhs")
     1.8 +    @ sort cmp [x\<leftarrow>xs. compare cmp x (xs ! (length xs div 2)) = Greater]" (is "_ = ?rhs")
     1.9  proof (rule sort_eqI)
    1.10 -  show "mset ?lhs = mset ?rhs"
    1.11 +  show "mset xs = mset ?rhs"
    1.12      by (rule multiset_eqI) (auto simp add: compare.sym intro: comp.exhaust)
    1.13  next
    1.14    show "sorted cmp ?rhs"
    1.15 @@ -402,7 +402,6 @@
    1.16  next
    1.17    let ?pivot = "xs ! (length xs div 2)"
    1.18    fix l
    1.19 -  assume "l \<in> set xs"
    1.20    have "compare cmp x ?pivot = comp \<and> compare cmp l x = Equiv
    1.21      \<longleftrightarrow> compare cmp l ?pivot = comp \<and> compare cmp l x = Equiv" for x comp
    1.22    proof -
    1.23 @@ -411,7 +410,7 @@
    1.24        using that by (simp add: compare.equiv_subst_left compare.sym)
    1.25      then show ?thesis by blast
    1.26    qed
    1.27 -  then show "stable_segment cmp l ?lhs = stable_segment cmp l ?rhs"
    1.28 +  then show "stable_segment cmp l xs = stable_segment cmp l ?rhs"
    1.29      by (simp add: stable_sort compare.sym [of _ ?pivot])
    1.30        (cases "compare cmp l ?pivot", simp_all)
    1.31  qed
    1.32 @@ -444,10 +443,8 @@
    1.33          in quicksort cmp lts @ eqs @ quicksort cmp gts)"
    1.34  proof (cases "length xs \<ge> 3")
    1.35    case False
    1.36 -  then have "length xs \<le> 2"
    1.37 -    by simp
    1.38 -  then have "length xs = 0 \<or> length xs = 1 \<or> length xs = 2"
    1.39 -    using le_neq_trans less_2_cases by auto
    1.40 +  then have "length xs \<in> {0, 1, 2}"
    1.41 +    by (auto simp add: not_le le_less less_antisym)
    1.42    then consider "xs = []" | x where "xs = [x]" | x y where "xs = [x, y]"
    1.43      by (auto simp add: length_Suc_conv numeral_2_eq_2)
    1.44    then show ?thesis
    1.45 @@ -466,9 +463,208 @@
    1.46  
    1.47  end
    1.48  
    1.49 -text \<open>Evaluation example\<close>
    1.50 +
    1.51 +subsection \<open>Mergesort\<close>
    1.52 +
    1.53 +definition mergesort :: "'a comparator \<Rightarrow> 'a list \<Rightarrow> 'a list"
    1.54 +  where mergesort_is_sort [simp]: "mergesort = sort"
    1.55 +
    1.56 +lemma sort_by_mergesort:
    1.57 +  "sort = mergesort"
    1.58 +  by simp
    1.59 +
    1.60 +context
    1.61 +  fixes cmp :: "'a comparator"
    1.62 +begin
    1.63 +
    1.64 +qualified function merge :: "'a list \<Rightarrow> 'a list \<Rightarrow> 'a list"
    1.65 +  where "merge [] ys = ys"
    1.66 +  | "merge xs [] = xs"
    1.67 +  | "merge (x # xs) (y # ys) = (if compare cmp x y = Greater
    1.68 +      then y # merge (x # xs) ys else x # merge xs (y # ys))"
    1.69 +  by pat_completeness auto
    1.70 +
    1.71 +qualified termination by lexicographic_order
    1.72 +
    1.73 +lemma mset_merge:
    1.74 +  "mset (merge xs ys) = mset xs + mset ys"
    1.75 +  by (induction xs ys rule: merge.induct) simp_all
    1.76 +
    1.77 +lemma merge_eq_Cons_imp:
    1.78 +  "xs \<noteq> [] \<and> z = hd xs \<or> ys \<noteq> [] \<and> z = hd ys"
    1.79 +    if "merge xs ys = z # zs"
    1.80 +  using that by (induction xs ys rule: merge.induct) (auto split: if_splits)
    1.81 +
    1.82 +lemma filter_merge:
    1.83 +  "filter P (merge xs ys) = merge (filter P xs) (filter P ys)"
    1.84 +    if "sorted cmp xs" and "sorted cmp ys"
    1.85 +using that proof (induction xs ys rule: merge.induct)
    1.86 +  case (1 ys)
    1.87 +  then show ?case
    1.88 +    by simp
    1.89 +next
    1.90 +  case (2 xs)
    1.91 +  then show ?case
    1.92 +    by simp
    1.93 +next
    1.94 +  case (3 x xs y ys)
    1.95 +  show ?case
    1.96 +  proof (cases "compare cmp x y = Greater")
    1.97 +    case True
    1.98 +    with 3 have hyp: "filter P (merge (x # xs) ys) =
    1.99 +      merge (filter P (x # xs)) (filter P ys)"
   1.100 +      by (simp add: sorted_Cons_imp_sorted)
   1.101 +    show ?thesis
   1.102 +    proof (cases "\<not> P x \<and> P y")
   1.103 +      case False
   1.104 +      with \<open>compare cmp x y = Greater\<close> show ?thesis
   1.105 +        by (auto simp add: hyp)
   1.106 +    next
   1.107 +      case True
   1.108 +      from \<open>compare cmp x y = Greater\<close> "3.prems"
   1.109 +      have *: "compare cmp z y = Greater" if "z \<in> set (filter P xs)" for z
   1.110 +        using that by (auto dest: compare.trans_not_greater sorted_Cons_imp_not_less)
   1.111 +      from \<open>compare cmp x y = Greater\<close> show ?thesis
   1.112 +        by (cases "filter P xs") (simp_all add: hyp *)
   1.113 +    qed
   1.114 +  next
   1.115 +    case False
   1.116 +    with 3 have hyp: "filter P (merge xs (y # ys)) =
   1.117 +      merge (filter P xs) (filter P (y # ys))"
   1.118 +      by (simp add: sorted_Cons_imp_sorted)
   1.119 +    show ?thesis
   1.120 +    proof (cases "P x \<and> \<not> P y")
   1.121 +      case False
   1.122 +      with \<open>compare cmp x y \<noteq> Greater\<close> show ?thesis
   1.123 +        by (auto simp add: hyp)
   1.124 +    next
   1.125 +      case True
   1.126 +      from \<open>compare cmp x y \<noteq> Greater\<close> "3.prems"
   1.127 +      have *: "compare cmp x z \<noteq> Greater" if "z \<in> set (filter P ys)" for z
   1.128 +        using that by (auto dest: compare.trans_not_greater sorted_Cons_imp_not_less)
   1.129 +      from \<open>compare cmp x y \<noteq> Greater\<close> show ?thesis
   1.130 +        by (cases "filter P ys") (simp_all add: hyp *)
   1.131 +    qed
   1.132 +  qed
   1.133 +qed
   1.134  
   1.135 -value "let cmp = key abs (reversed default)
   1.136 -  in quicksort cmp [65, 1705, -2322, 734, 4, (-17::int)]"
   1.137 +lemma sorted_merge:
   1.138 +  "sorted cmp (merge xs ys)" if "sorted cmp xs" and "sorted cmp ys"
   1.139 +using that proof (induction xs ys rule: merge.induct)
   1.140 +  case (1 ys)
   1.141 +  then show ?case
   1.142 +    by simp
   1.143 +next
   1.144 +  case (2 xs)
   1.145 +  then show ?case
   1.146 +    by simp
   1.147 +next
   1.148 +  case (3 x xs y ys)
   1.149 +  show ?case
   1.150 +  proof (cases "compare cmp x y = Greater")
   1.151 +    case True
   1.152 +    with 3 have "sorted cmp (merge (x # xs) ys)"
   1.153 +      by (simp add: sorted_Cons_imp_sorted)
   1.154 +    then have "sorted cmp (y # merge (x # xs) ys)"
   1.155 +    proof (rule sorted_ConsI)
   1.156 +      fix z zs
   1.157 +      assume "merge (x # xs) ys = z # zs"
   1.158 +      with 3(4) True show "compare cmp y z \<noteq> Greater"
   1.159 +        by (clarsimp simp add: sorted_Cons_imp_sorted dest!: merge_eq_Cons_imp)
   1.160 +          (auto simp add: compare.asym_greater sorted_Cons_imp_not_less)
   1.161 +    qed
   1.162 +    with True show ?thesis
   1.163 +      by simp
   1.164 +  next
   1.165 +    case False
   1.166 +    with 3 have "sorted cmp (merge xs (y # ys))"
   1.167 +      by (simp add: sorted_Cons_imp_sorted)
   1.168 +    then have "sorted cmp (x # merge xs (y # ys))"
   1.169 +    proof (rule sorted_ConsI)
   1.170 +      fix z zs
   1.171 +      assume "merge xs (y # ys) = z # zs"
   1.172 +      with 3(3) False show "compare cmp x z \<noteq> Greater"
   1.173 +        by (clarsimp simp add: sorted_Cons_imp_sorted dest!: merge_eq_Cons_imp)
   1.174 +          (auto simp add: compare.asym_greater sorted_Cons_imp_not_less)
   1.175 +    qed
   1.176 +    with False show ?thesis
   1.177 +      by simp
   1.178 +  qed
   1.179 +qed
   1.180 +
   1.181 +lemma merge_eq_appendI:
   1.182 +  "merge xs ys = xs @ ys"
   1.183 +    if "\<And>x y. x \<in> set xs \<Longrightarrow> y \<in> set ys \<Longrightarrow> compare cmp x y \<noteq> Greater"
   1.184 +  using that by (induction xs ys rule: merge.induct) simp_all
   1.185 +
   1.186 +lemma merge_stable_segments:
   1.187 +  "merge (stable_segment cmp l xs) (stable_segment cmp l ys) =
   1.188 +     stable_segment cmp l xs @ stable_segment cmp l ys"
   1.189 +  by (rule merge_eq_appendI) (auto dest: compare.trans_equiv_greater)
   1.190 +
   1.191 +lemma sort_by_mergesort_rec:
   1.192 +  "sort cmp xs =
   1.193 +    merge (sort cmp (take (length xs div 2) xs))
   1.194 +      (sort cmp (drop (length xs div 2) xs))" (is "_ = ?rhs")
   1.195 +proof (rule sort_eqI)
   1.196 +  have "mset (take (length xs div 2) xs) + mset (drop (length xs div 2) xs) =
   1.197 +    mset (take (length xs div 2) xs @ drop (length xs div 2) xs)"
   1.198 +    by (simp only: mset_append)
   1.199 +  then show "mset xs = mset ?rhs"
   1.200 +    by (simp add: mset_merge)
   1.201 +next
   1.202 +  show "sorted cmp ?rhs"
   1.203 +    by (simp add: sorted_merge)
   1.204 +next
   1.205 +  fix l
   1.206 +  have "stable_segment cmp l (take (length xs div 2) xs) @ stable_segment cmp l (drop (length xs div 2) xs)
   1.207 +    = stable_segment cmp l xs"
   1.208 +    by (simp only: filter_append [symmetric] append_take_drop_id)
   1.209 +  have "merge (stable_segment cmp l (take (length xs div 2) xs))
   1.210 +    (stable_segment cmp l (drop (length xs div 2) xs)) =
   1.211 +    stable_segment cmp l (take (length xs div 2) xs) @ stable_segment cmp l (drop (length xs div 2) xs)"
   1.212 +    by (rule merge_eq_appendI) (auto simp add: compare.trans_equiv_greater)
   1.213 +  also have "\<dots> = stable_segment cmp l xs"
   1.214 +    by (simp only: filter_append [symmetric] append_take_drop_id)
   1.215 +  finally show "stable_segment cmp l xs = stable_segment cmp l ?rhs"
   1.216 +    by (simp add: stable_sort filter_merge)
   1.217 +qed
   1.218 +
   1.219 +lemma mergesort_code [code]:
   1.220 +  "mergesort cmp xs =
   1.221 +    (case xs of
   1.222 +      [] \<Rightarrow> []
   1.223 +    | [x] \<Rightarrow> xs
   1.224 +    | [x, y] \<Rightarrow> (if compare cmp x y \<noteq> Greater then xs else [y, x])
   1.225 +    | _ \<Rightarrow>
   1.226 +        let
   1.227 +          half = length xs div 2;
   1.228 +          ys = take half xs;
   1.229 +          zs = drop half xs
   1.230 +        in merge (mergesort cmp ys) (mergesort cmp zs))"
   1.231 +proof (cases "length xs \<ge> 3")
   1.232 +  case False
   1.233 +  then have "length xs \<in> {0, 1, 2}"
   1.234 +    by (auto simp add: not_le le_less less_antisym)
   1.235 +  then consider "xs = []" | x where "xs = [x]" | x y where "xs = [x, y]"
   1.236 +    by (auto simp add: length_Suc_conv numeral_2_eq_2)
   1.237 +  then show ?thesis
   1.238 +    by cases simp_all
   1.239 +next
   1.240 +  case True
   1.241 +  then obtain x y z zs where "xs = x # y # z # zs"
   1.242 +    by (metis le_0_eq length_0_conv length_Cons list.exhaust not_less_eq_eq numeral_3_eq_3)
   1.243 +  moreover have "mergesort cmp xs =
   1.244 +    (let
   1.245 +       half = length xs div 2;
   1.246 +       ys = take half xs;
   1.247 +       zs = drop half xs
   1.248 +     in merge (mergesort cmp ys) (mergesort cmp zs))"
   1.249 +    using sort_by_mergesort_rec [of xs] by (simp add: Let_def)
   1.250 +  ultimately show ?thesis
   1.251 +    by simp
   1.252 +qed
   1.253  
   1.254  end
   1.255 +
   1.256 +end