diff -r 9786f64d8285 -r d6daa049c1db src/HOL/Data_Structures/Selection.thy --- a/src/HOL/Data_Structures/Selection.thy Thu Oct 31 15:46:53 2024 +0100 +++ b/src/HOL/Data_Structures/Selection.thy Thu Oct 31 18:43:32 2024 +0100 @@ -525,15 +525,17 @@ \ function mom_select where "mom_select k xs = ( - if length xs \ 20 then + let n = length xs + in if n \ 20 then slow_select k xs else - let M = mom_select (((length xs + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs)); - (ls, es, gs) = partition3 M xs + let M = mom_select (((n + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs)); + (ls, es, gs) = partition3 M xs; + nl = length ls in - if k < length ls then mom_select k ls - else if k < length ls + length es then M - else mom_select (k - length ls - length es) gs + if k < nl then mom_select k ls + else let ne = length es in if k < nl + ne then M + else mom_select (k - nl - ne) gs )" by auto @@ -564,7 +566,8 @@ have length_eq: "length xs = nl + ne + length gs" unfolding nl_def ne_def ls_def es_def gs_def using [[simp_depth_limit = 1]] by (induction xs) auto - note IH = "1.IH"(2,3)[OF False x_def tw refl refl] + note IH = "1.IH"(2)[OF refl False x_def tw refl refl refl] + "1.IH"(3)[OF refl False x_def tw refl refl refl _ refl] have "mom_select k xs = (if k < nl then mom_select k ls else if k < nl + ne then x else mom_select (k - nl - ne) gs)" using "1.hyps" False @@ -572,6 +575,7 @@ also have "\ = (if k < nl then select k ls else if k < nl + ne then x else select (k - nl - ne) gs)" using IH length_eq "1.prems" by (simp add: ls_def es_def gs_def nl_def ne_def) + try0 also have "\ = select k xs" using \k < length xs\ by (subst (3) select_rec_partition[of _ _ x]) (simp_all add: nl_def ne_def flip: tw) finally show "mom_select k xs = select k xs" . @@ -697,17 +701,13 @@ lemma T_chop_le: "T_chop d xs \ 5 * length xs + 1" by (induction d xs rule: T_chop.induct) (auto simp: T_chop_reduce T_take_eq T_drop_eq) +time_fun mom_select -text \ - The option \domintros\ here allows us to explicitly reason about where the function does and - does not terminate. With this, we can skip the termination proof this time because we can - reuse the one for \<^const>\mom_select\. -\ -function (domintros) T_mom_select :: "nat \ 'a :: linorder list \ nat" where - "T_mom_select k xs = T_length xs + ( - if length xs \ 20 then - T_slow_select k xs - else +lemmas [simp del] = T_mom_select.simps + +lemma T_mom_select_simps: + "length xs \ 20 \ T_mom_select k xs = T_slow_select k xs + T_length xs + 1" + "length xs > 20 \ T_mom_select k xs = ( let xss = chop 5 xs; ms = map slow_median xss; idx = (((length xs + 4) div 5 - 1) div 2); @@ -719,27 +719,14 @@ (if k < nl then T_mom_select k ls else T_length es + (if k < nl + ne then 0 else T_mom_select (k - nl - ne) gs)) + T_mom_select idx ms + T_chop 5 xs + T_map T_slow_median xss + - T_partition3 x xs + T_length ls + 1 + T_partition3 x xs + T_length ls + T_length xs + 1 )" - by auto - -termination T_mom_select -proof (rule allI, safe) - fix k :: nat and xs :: "'a :: linorder list" - have "mom_select_dom (k, xs)" - using mom_select_termination by blast - thus "T_mom_select_dom (k, xs)" - by (induction k xs rule: mom_select.pinduct) - (rule T_mom_select.domintros, simp_all) -qed - -lemmas [simp del] = T_mom_select.simps - + by (subst T_mom_select.simps; simp add: Let_def case_prod_unfold)+ function T'_mom_select :: "nat \ nat" where "T'_mom_select n = (if n \ 20 then - 482 + 483 else T'_mom_select (nat \0.2*n\) + T'_mom_select (nat \0.7*n+3\) + 19 * n + 54)" by force+ @@ -748,7 +735,7 @@ lemmas [simp del] = T'_mom_select.simps -lemma T'_mom_select_ge: "T'_mom_select n \ 482" +lemma T'_mom_select_ge: "T'_mom_select n \ 483" by (induction n rule: T'_mom_select.induct; subst T'_mom_select.simps) auto lemma T'_mom_select_mono: @@ -758,7 +745,7 @@ show ?case proof (cases "m \ 20") case True - hence "T'_mom_select m = 482" + hence "T'_mom_select m = 483" by (subst T'_mom_select.simps) auto also have "\ \ T'_mom_select n" by (rule T'_mom_select_ge) @@ -784,24 +771,26 @@ case (1 k xs) define n where [simp]: "n = length xs" define x where - "x = mom_select (((length xs + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))" + "x = mom_select (((n + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))" define ls es gs where "ls = filter (\y. y < x) xs" and "es = filter (\y. y = x) xs" and "gs = filter (\y. y > x) xs" define nl ne where "nl = length ls" and "ne = length es" note defs = nl_def ne_def x_def ls_def es_def gs_def have tw: "(ls, es, gs) = partition3 x xs" unfolding partition3_def defs One_nat_def .. - note IH = "1.IH"(1,2,3)[OF _ refl refl refl x_def tw refl refl refl refl] + note IH = "1.IH"(1)[OF n_def] + "1.IH"(2)[OF n_def _ x_def tw refl refl nl_def] + "1.IH"(3)[OF n_def _ x_def tw refl refl nl_def _ ne_def] show ?case proof (cases "length xs \ 20") case True \ \base case\ - hence "T_mom_select k xs \ (length xs)\<^sup>2 + 4 * length xs + 2" + hence "T_mom_select k xs \ (length xs)\<^sup>2 + 4 * length xs + 3" using T_slow_select_le[of k xs] \k < length xs\ - by (subst T_mom_select.simps) (auto simp: T_length_eq) - also have "\ \ 20\<^sup>2 + 4 * 20 + 2" + by (subst T_mom_select_simps(1)) (auto simp: T_length_eq) + also have "\ \ 20\<^sup>2 + 4 * 20 + 3" using True by (intro add_mono power_mono) auto - also have "\ = 482" + also have "\ = 483" by simp also have "\ = T'_mom_select (length xs)" using True by (simp add: T'_mom_select.simps) @@ -845,11 +834,11 @@ text \The cost of the first recursive call (to compute the median of medians):\ define T_rec1 where - "T_rec1 = T_mom_select (((length xs + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))" + "T_rec1 = T_mom_select (((n + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs))" from False have "((length xs + 4) div 5 - Suc 0) div 2 < nat \real (length xs) / 5\" by linarith hence "T_rec1 \ T'_mom_select (length (map slow_median (chop 5 xs)))" - using False unfolding T_rec1_def by (intro IH(3)) (auto simp: length_chop) + using False unfolding T_rec1_def by (intro IH(1)) (auto simp: length_chop) hence "T_rec1 \ T'_mom_select (nat \0.2 * n\)" by (simp add: length_chop) @@ -865,7 +854,7 @@ hence "T_rec2 = T_mom_select k ls" by (simp add: T_rec2_def) also have "\ \ T'_mom_select (length ls)" - by (rule IH(1)) (use \k < nl\ False in \auto simp: defs\) + by (rule IH(2)) (use \k < nl\ False in \auto simp: defs\) also have "length ls \ nat \0.7 * n + 3\" unfolding ls_def using size_less_than_median_of_medians[of xs] by (auto simp: length_filter_conv_size_filter_mset slow_median_correct[abs_def] x_eq) @@ -883,15 +872,14 @@ hence "T_rec2 = T_mom_select (k - nl - ne) gs" by (simp add: T_rec2_def) also have "\ \ T'_mom_select (length gs)" - unfolding nl_def ne_def - proof (rule IH(2)) - show "\ length xs \ 20" + proof (rule IH(3)) + show "\n \ 20" using False by auto - show "\ k < length ls" "\k < length ls + length es" + show "\ k < nl" "\k < nl + ne" using \k \ nl + ne\ by (auto simp: nl_def ne_def) have "length xs = nl + ne + length gs" unfolding defs by (rule length_partition3) (simp_all add: partition3_def) - thus "k - length ls - length es < length gs" + thus "k - nl - ne < length gs" using \k \ nl + ne\ \k < length xs\ by (auto simp: nl_def ne_def) qed also have "length gs \ nat \0.7 * n + 3\" @@ -903,10 +891,19 @@ qed text \Now for the final inequality chain:\ - have "T_mom_select k xs \ T_rec2 + T_rec1 + T_ms + 2 * n + nl + ne + T_chop 5 xs + 5" using False - by (subst T_mom_select.simps, unfold Let_def tw [symmetric] defs [symmetric]) - (simp_all add: nl_def ne_def T_rec1_def T_rec2_def T_partition3_eq - T_length_eq T_ms_def) + have "T_mom_select k xs = + (if k < nl then T_mom_select k ls + else T_length es + + (if k < nl + ne then 0 else T_mom_select (k - nl - ne) gs)) + + T_mom_select (((n + 4) div 5 - 1) div 2) (map slow_median (chop 5 xs)) + + T_chop 5 xs + T_map T_slow_median (chop 5 xs) + T_partition3 x xs + + T_length ls + T_length xs + 1" using False + by (subst T_mom_select_simps; + unfold Let_def n_def [symmetric] x_def [symmetric] nl_def [symmetric] + ne_def [symmetric] prod.case tw [symmetric]) simp_all + also have "\ \ T_rec2 + T_rec1 + T_ms + 2 * n + nl + ne + T_chop 5 xs + 5" using False + by (auto simp add: T_rec1_def T_rec2_def T_partition3_eq + T_length_eq T_ms_def nl_def ne_def) also have "nl \ n" by (simp add: nl_def ls_def) also have "ne \ n" by (simp add: ne_def es_def) also note \T_ms \ 10 * n + 48\