tweaked time functions for median-of-medians selection in HOL-Data_Structures
authorManuel Eberl <manuel@pruvisto.org>
Thu, 11 Apr 2024 14:13:43 +0200
changeset 80093 c0d689c4fd15
parent 80092 1a9f0159de5b
child 80094 5af76462e3a5
tweaked time functions for median-of-medians selection in HOL-Data_Structures
src/HOL/Data_Structures/Selection.thy
--- a/src/HOL/Data_Structures/Selection.thy	Wed Apr 10 13:23:00 2024 +0200
+++ b/src/HOL/Data_Structures/Selection.thy	Thu Apr 11 14:13:43 2024 +0200
@@ -634,30 +634,31 @@
 lemma T_partition3_eq: "T_partition3 x xs = length xs + 1"
   by (induction x xs rule: T_partition3.induct) auto
 
-definition T_slow_select :: "nat \<Rightarrow> 'a :: linorder list \<Rightarrow> nat" where
-  "T_slow_select k xs = T_insort xs + T_nth (insort xs) k + 1"
+
+time_definition slow_select
+
+lemmas T_slow_select_def [simp del] = T_slow_select.simps
+
 
 definition T_slow_median :: "'a :: linorder list \<Rightarrow> nat" where
-  "T_slow_median xs = T_slow_select ((length xs - 1) div 2) xs + 1"
+  "T_slow_median xs = T_length xs + T_slow_select ((length xs - 1) div 2) xs"
 
-lemma T_slow_select_le: "T_slow_select k xs \<le> length xs ^ 2 + 3 * length xs + 3"
+lemma T_slow_select_le: "T_slow_select k xs \<le> length xs ^ 2 + 3 * length xs + 2"
 proof -
-  have "T_slow_select k xs \<le> (length xs + 1)\<^sup>2 + (length (insort xs) + 1) + 1"
+  have "T_slow_select k xs \<le> (length xs + 1)\<^sup>2 + (length (insort xs) + 1)"
     unfolding T_slow_select_def
     by (intro add_mono T_insort_length) (auto simp: T_nth_eq)
-  also have "\<dots> = length xs ^ 2 + 3 * length xs + 3"
+  also have "\<dots> = length xs ^ 2 + 3 * length xs + 2"
     by (simp add: insort_correct algebra_simps power2_eq_square)
   finally show ?thesis .
 qed
 
-lemma T_slow_median_le: "T_slow_median xs \<le> length xs ^ 2 + 3 * length xs + 4"
-  unfolding T_slow_median_def using T_slow_select_le[of "(length xs - 1) div 2" xs] by simp
+lemma T_slow_median_le: "T_slow_median xs \<le> length xs ^ 2 + 4 * length xs + 3"
+  unfolding T_slow_median_def using T_slow_select_le[of "(length xs - 1) div 2" xs]
+  by (simp add: algebra_simps T_length_eq)
 
 
-fun T_chop :: "nat \<Rightarrow> 'a list \<Rightarrow> nat" where
-  "T_chop 0 _  = 1"
-| "T_chop _ [] = 1"
-| "T_chop n xs = T_take n xs + T_drop n xs + T_chop n (drop n xs)"
+time_fun chop
 
 lemmas [simp del] = T_chop.simps
 
@@ -668,19 +669,20 @@
   by (auto simp: T_chop.simps)
 
 lemma T_chop_reduce:
-  "n > 0 \<Longrightarrow> xs \<noteq> [] \<Longrightarrow> T_chop n xs = T_take n xs + T_drop n xs + T_chop n (drop n xs)"
+  "n > 0 \<Longrightarrow> xs \<noteq> [] \<Longrightarrow> T_chop n xs = T_take n xs + T_drop n xs + T_chop n (drop n xs) + 1"
   by (cases n; cases xs) (auto simp: T_chop.simps)
 
 lemma T_chop_le: "T_chop d xs \<le> 5 * length xs + 1"
   by (induction d xs rule: T_chop.induct) (auto simp: T_chop_reduce T_take_eq T_drop_eq)
 
+
 text \<open>
   The option \<open>domintros\<close> here allows us to explicitly reason about where the function does and
   does not terminate. With this, we can skip the termination proof this time because we can
   reuse the one for \<^const>\<open>mom_select\<close>.
 \<close>
 function (domintros) T_mom_select :: "nat \<Rightarrow> 'a :: linorder list \<Rightarrow> nat" where
-  "T_mom_select k xs = (
+  "T_mom_select k xs = T_length xs + (
      if length xs \<le> 20 then
        T_slow_select k xs
      else
@@ -693,10 +695,9 @@
            ne = length es
        in
          (if k < nl then T_mom_select k ls 
-          else if k < nl + ne then 0
-          else T_mom_select (k - nl - ne) gs) +
+          else T_length es + (if k < nl + ne then 0 else T_mom_select (k - nl - ne) gs)) +
          T_mom_select idx ms + T_chop 5 xs + T_map T_slow_median xss +
-         T_partition3 x xs + T_length ls + T_length es + 1
+         T_partition3 x xs + T_length ls + 1
       )"
   by auto
 
@@ -716,16 +717,16 @@
 function T'_mom_select :: "nat \<Rightarrow> nat" where
   "T'_mom_select n =
      (if n \<le> 20 then
-        463
+        483
       else
-        T'_mom_select (nat \<lceil>0.2*n\<rceil>) + T'_mom_select (nat \<lceil>0.7*n+3\<rceil>) + 17 * n + 50)"
+        T'_mom_select (nat \<lceil>0.2*n\<rceil>) + T'_mom_select (nat \<lceil>0.7*n+3\<rceil>) + 19 * n + 55)"
   by force+
 termination by (relation "measure id"; simp; linarith)
 
 lemmas [simp del] = T'_mom_select.simps
 
 
-lemma T'_mom_select_ge: "T'_mom_select n \<ge> 463"
+lemma T'_mom_select_ge: "T'_mom_select n \<ge> 483"
   by (induction n rule: T'_mom_select.induct; subst T'_mom_select.simps) auto
 
 lemma T'_mom_select_mono:
@@ -735,7 +736,7 @@
   show ?case
   proof (cases "m \<le> 20")
     case True
-    hence "T'_mom_select m = 463"
+    hence "T'_mom_select m = 483"
       by (subst T'_mom_select.simps) auto
     also have "\<dots> \<le> T'_mom_select n"
       by (rule T'_mom_select_ge)
@@ -743,9 +744,9 @@
   next
     case False
     hence "T'_mom_select m =
-             T'_mom_select (nat \<lceil>0.2*m\<rceil>) + T'_mom_select (nat \<lceil>0.7*m + 3\<rceil>) + 17 * m + 50"
+             T'_mom_select (nat \<lceil>0.2*m\<rceil>) + T'_mom_select (nat \<lceil>0.7*m + 3\<rceil>) + 19 * m + 55"
       by (subst T'_mom_select.simps) auto
-    also have "\<dots> \<le> T'_mom_select (nat \<lceil>0.2*n\<rceil>) + T'_mom_select (nat \<lceil>0.7*n + 3\<rceil>) + 17 * n + 50"
+    also have "\<dots> \<le> T'_mom_select (nat \<lceil>0.2*n\<rceil>) + T'_mom_select (nat \<lceil>0.7*n + 3\<rceil>) + 19 * n + 55"
       using \<open>m \<le> n\<close> and False by (intro add_mono less.IH; linarith)
     also have "\<dots> = T'_mom_select n"
       using \<open>m \<le> n\<close> and False by (subst T'_mom_select.simps) auto
@@ -770,11 +771,11 @@
   show ?case
   proof (cases "length xs \<le> 20")
     case True \<comment> \<open>base case\<close>
-    hence "T_mom_select k xs \<le> (length xs)\<^sup>2 + 3 * length xs + 3"
-      using T_slow_select_le[of k xs] by (subst T_mom_select.simps) auto
-    also have "\<dots> \<le> 20\<^sup>2 + 3 * 20 + 3"
+    hence "T_mom_select k xs \<le> (length xs)\<^sup>2 + 4 * length xs + 3"
+      using T_slow_select_le[of k xs] by (subst T_mom_select.simps) (auto simp: T_length_eq)
+    also have "\<dots> \<le> 20\<^sup>2 + 4 * 20 + 3"
       using True by (intro add_mono power_mono) auto
-    also have "\<dots> \<le> 463"
+    also have "\<dots> = 483"
       by simp
     also have "\<dots> = T'_mom_select (length xs)"
       using True by (simp add: T'_mom_select.simps)
@@ -793,27 +794,27 @@
 
     text \<open>The cost of computing the medians of all the subgroups:\<close>
     define T_ms where "T_ms = T_map T_slow_median (chop 5 xs)"
-    have "T_ms \<le> 9 * n + 45"
+    have "T_ms \<le> 10 * n + 49"
     proof -
       have "T_ms = (\<Sum>ys\<leftarrow>chop 5 xs. T_slow_median ys) + length (chop 5 xs) + 1"
         by (simp add: T_ms_def T_map_eq)
-      also have "(\<Sum>ys\<leftarrow>chop 5 xs. T_slow_median ys) \<le> (\<Sum>ys\<leftarrow>chop 5 xs. 44)"
+      also have "(\<Sum>ys\<leftarrow>chop 5 xs. T_slow_median ys) \<le> (\<Sum>ys\<leftarrow>chop 5 xs. 48)"
       proof (intro sum_list_mono)
         fix ys assume "ys \<in> set (chop 5 xs)"
         hence "length ys \<le> 5"
           using length_chop_part_le by blast
-        have "T_slow_median ys \<le> (length ys) ^ 2 + 3 * length ys + 4"
+        have "T_slow_median ys \<le> (length ys) ^ 2 + 4 * length ys + 3"
           by (rule T_slow_median_le)
-        also have "\<dots> \<le> 5 ^ 2 + 3 * 5 + 4"
+        also have "\<dots> \<le> 5 ^ 2 + 4 * 5 + 3"
           using \<open>length ys \<le> 5\<close> by (intro add_mono power_mono) auto
-        finally show "T_slow_median ys \<le> 44" by simp
+        finally show "T_slow_median ys \<le> 48" by simp
       qed
-      also have "(\<Sum>ys\<leftarrow>chop 5 xs. 44) + length (chop 5 xs) + 1 =
-                   45 * nat \<lceil>real n / 5\<rceil> + 1"
+      also have "(\<Sum>ys\<leftarrow>chop 5 xs. 48) + length (chop 5 xs) + 1 =
+                   49 * nat \<lceil>real n / 5\<rceil> + 1"
         by (simp add: map_replicate_const length_chop)
-      also have "\<dots> \<le> 9 * n + 45"
+      also have "\<dots> \<le> 10 * n + 49"
         by linarith
-      finally show "T_ms \<le> 9 * n + 45" by simp
+      finally show "T_ms \<le> 10 * n + 49" by simp
     qed
 
     text \<open>The cost of the first recursive call (to compute the median of medians):\<close>
@@ -864,19 +865,19 @@
     qed
 
     text \<open>Now for the final inequality chain:\<close>
-    have "T_mom_select k xs = T_rec2 + T_rec1 + T_ms + n + nl + ne + T_chop 5 xs + 4" using False
+    have "T_mom_select k xs \<le> T_rec2 + T_rec1 + T_ms + 2 * n + nl + ne + T_chop 5 xs + 5" using False
       by (subst T_mom_select.simps, unfold Let_def tw [symmetric] defs [symmetric])
          (simp_all add: nl_def ne_def T_rec1_def T_rec2_def T_partition3_eq
                         T_length_eq T_ms_def)
     also have "nl \<le> n" by (simp add: nl_def ls_def)
     also have "ne \<le> n" by (simp add: ne_def es_def)
-    also note \<open>T_ms \<le> 9 * n + 45\<close>
+    also note \<open>T_ms \<le> 10 * n + 49\<close>
     also have "T_chop 5 xs \<le> 5 * n + 1"
       using T_chop_le[of 5 xs] by simp 
     also note \<open>T_rec1 \<le> T'_mom_select (nat \<lceil>0.2*n\<rceil>)\<close>
     also note \<open>T_rec2 \<le> T'_mom_select (nat \<lceil>0.7*n + 3\<rceil>)\<close>
     finally have "T_mom_select k xs \<le>
-                    T'_mom_select (nat \<lceil>0.7*n + 3\<rceil>) + T'_mom_select (nat \<lceil>0.2*n\<rceil>) + 17 * n + 50"
+                    T'_mom_select (nat \<lceil>0.7*n + 3\<rceil>) + T'_mom_select (nat \<lceil>0.2*n\<rceil>) + 19 * n + 55"
       by simp
     also have "\<dots> = T'_mom_select n"
       using False by (subst T'_mom_select.simps) auto
@@ -1033,7 +1034,7 @@
 lemma T'_mom_select_le': "\<exists>C\<^sub>1 C\<^sub>2. \<forall>n. T'_mom_select n \<le> C\<^sub>1 * n + C\<^sub>2"
 proof (rule akra_bazzi_light_nat)
   show "\<forall>n>20. T'_mom_select n = T'_mom_select (nat \<lceil>0.2 * n + 0\<rceil>) +
-                 T'_mom_select (nat \<lceil>0.7 * n + 3\<rceil>) + 17 * n + 50"
+                 T'_mom_select (nat \<lceil>0.7 * n + 3\<rceil>) + 19 * n + 55"
     using T'_mom_select.simps by auto
 qed auto