src/HOL/Multivariate_Analysis/Cartesian_Euclidean_Space.thy
changeset 57418 6ab1c7cb0b8d
parent 56273 def3bbe6f2a5
child 57512 cc97b347b301
--- a/src/HOL/Multivariate_Analysis/Cartesian_Euclidean_Space.thy	Fri Jun 27 22:08:55 2014 +0200
+++ b/src/HOL/Multivariate_Analysis/Cartesian_Euclidean_Space.thy	Sat Jun 28 09:16:42 2014 +0200
@@ -8,23 +8,18 @@
   "(if k=a then 1 else (0::'a::semiring_1)) * (if k=a then 1 else 0) = (if k=a then 1 else 0)"
   by (cases "k=a") auto
 
-lemma setsum_Plus:
-  "\<lbrakk>finite A; finite B\<rbrakk> \<Longrightarrow>
-    (\<Sum>x\<in>A <+> B. g x) = (\<Sum>x\<in>A. g (Inl x)) + (\<Sum>x\<in>B. g (Inr x))"
-  unfolding Plus_def
-  by (subst setsum_Un_disjoint, auto simp add: setsum_reindex)
-
 lemma setsum_UNIV_sum:
   fixes g :: "'a::finite + 'b::finite \<Rightarrow> _"
   shows "(\<Sum>x\<in>UNIV. g x) = (\<Sum>x\<in>UNIV. g (Inl x)) + (\<Sum>x\<in>UNIV. g (Inr x))"
   apply (subst UNIV_Plus_UNIV [symmetric])
-  apply (rule setsum_Plus [OF finite finite])
+  apply (subst setsum.Plus)
+  apply simp_all
   done
 
 lemma setsum_mult_product:
   "setsum h {..<A * B :: nat} = (\<Sum>i\<in>{..<A}. \<Sum>j\<in>{..<B}. h (j + i * B))"
   unfolding setsum_nat_group[of h B A, unfolded atLeast0LessThan, symmetric]
-proof (rule setsum_cong, simp, rule setsum_reindex_cong)
+proof (rule setsum.cong, simp, rule setsum.reindex_cong)
   fix i
   show "inj_on (\<lambda>j. j + i * B) {..<B}" by (auto intro!: inj_onI)
   show "{i * B..<i * B + B} = (\<lambda>j. j + i * B) ` {..<B}"
@@ -110,11 +105,17 @@
 
 subsection {* A naive proof procedure to lift really trivial arithmetic stuff from the basis of the vector space. *}
 
+lemma setsum_cong_aux:
+  "(\<And>x. x \<in> A \<Longrightarrow> f x = g x) \<Longrightarrow> setsum f A = setsum g A"
+  by (auto intro: setsum.cong)
+
+hide_fact (open) setsum_cong_aux
+
 method_setup vector = {*
 let
   val ss1 =
     simpset_of (put_simpset HOL_basic_ss @{context}
-      addsimps [@{thm setsum_addf} RS sym,
+      addsimps [@{thm setsum.distrib} RS sym,
       @{thm setsum_subtractf} RS sym, @{thm setsum_right_distrib},
       @{thm setsum_left_distrib}, @{thm setsum_negf} RS sym])
   val ss2 =
@@ -126,8 +127,8 @@
               @{thm vec_lambda_beta}, @{thm vector_scalar_mult_def}])
   fun vector_arith_tac ctxt ths =
     simp_tac (put_simpset ss1 ctxt)
-    THEN' (fn i => rtac @{thm setsum_cong2} i
-         ORELSE rtac @{thm setsum_0'} i
+    THEN' (fn i => rtac @{thm Cartesian_Euclidean_Space.setsum_cong_aux} i
+         ORELSE rtac @{thm setsum.neutral} i
          ORELSE simp_tac (put_simpset HOL_basic_ss ctxt addsimps [@{thm vec_eq_iff}]) i)
     (* THEN' TRY o clarify_tac HOL_cs  THEN' (TRY o rtac @{thm iffI}) *)
     THEN' asm_full_simp_tac (put_simpset ss2 ctxt addsimps ths)
@@ -357,14 +358,14 @@
 
 lemma mat_0[simp]: "mat 0 = 0" by (vector mat_def)
 lemma matrix_add_ldistrib: "(A ** (B + C)) = (A ** B) + (A ** C)"
-  by (vector matrix_matrix_mult_def setsum_addf[symmetric] field_simps)
+  by (vector matrix_matrix_mult_def setsum.distrib[symmetric] field_simps)
 
 lemma matrix_mul_lid:
   fixes A :: "'a::semiring_1 ^ 'm ^ 'n"
   shows "mat 1 ** A = A"
   apply (simp add: matrix_matrix_mult_def mat_def)
   apply vector
-  apply (auto simp only: if_distrib cond_application_beta setsum_delta'[OF finite]
+  apply (auto simp only: if_distrib cond_application_beta setsum.delta'[OF finite]
     mult_1_left mult_zero_left if_True UNIV_I)
   done
 
@@ -374,26 +375,26 @@
   shows "A ** mat 1 = A"
   apply (simp add: matrix_matrix_mult_def mat_def)
   apply vector
-  apply (auto simp only: if_distrib cond_application_beta setsum_delta[OF finite]
+  apply (auto simp only: if_distrib cond_application_beta setsum.delta[OF finite]
     mult_1_right mult_zero_right if_True UNIV_I cong: if_cong)
   done
 
 lemma matrix_mul_assoc: "A ** (B ** C) = (A ** B) ** C"
   apply (vector matrix_matrix_mult_def setsum_right_distrib setsum_left_distrib mult_assoc)
-  apply (subst setsum_commute)
+  apply (subst setsum.commute)
   apply simp
   done
 
 lemma matrix_vector_mul_assoc: "A *v (B *v x) = (A ** B) *v x"
   apply (vector matrix_matrix_mult_def matrix_vector_mult_def
     setsum_right_distrib setsum_left_distrib mult_assoc)
-  apply (subst setsum_commute)
+  apply (subst setsum.commute)
   apply simp
   done
 
 lemma matrix_vector_mul_lid: "mat 1 *v x = (x::'a::semiring_1 ^ 'n)"
   apply (vector matrix_vector_mult_def mat_def)
-  apply (simp add: if_distrib cond_application_beta setsum_delta' cong del: if_weak_cong)
+  apply (simp add: if_distrib cond_application_beta setsum.delta' cong del: if_weak_cong)
   done
 
 lemma matrix_transpose_mul:
@@ -410,7 +411,7 @@
   apply (erule_tac x="axis ia 1" in allE)
   apply (erule_tac x="i" in allE)
   apply (auto simp add: if_distrib cond_application_beta axis_def
-    setsum_delta[OF finite] cong del: if_weak_cong)
+    setsum.delta[OF finite] cong del: if_weak_cong)
   done
 
 lemma matrix_vector_mul_component: "((A::real^_^_) *v x)$k = (A$k) \<bullet> x"
@@ -418,7 +419,7 @@
 
 lemma dot_lmul_matrix: "((x::real ^_) v* A) \<bullet> y = x \<bullet> (A *v y)"
   apply (simp add: inner_vec_def matrix_vector_mult_def vector_matrix_mult_def setsum_left_distrib setsum_right_distrib mult_ac)
-  apply (subst setsum_commute)
+  apply (subst setsum.commute)
   apply simp
   done
 
@@ -455,10 +456,10 @@
 
 lemma vector_componentwise:
   "(x::'a::ring_1^'n) = (\<chi> j. \<Sum>i\<in>UNIV. (x$i) * (axis i 1 :: 'a^'n) $ j)"
-  by (simp add: axis_def if_distrib setsum_cases vec_eq_iff)
+  by (simp add: axis_def if_distrib setsum.If_cases vec_eq_iff)
 
 lemma basis_expansion: "setsum (\<lambda>i. (x$i) *s axis i 1) UNIV = (x::('a::ring_1) ^'n)"
-  by (auto simp add: axis_def vec_eq_iff if_distrib setsum_cases cong del: if_weak_cong)
+  by (auto simp add: axis_def vec_eq_iff if_distrib setsum.If_cases cong del: if_weak_cong)
 
 lemma linear_componentwise:
   fixes f:: "real ^'m \<Rightarrow> real ^ _"
@@ -492,7 +493,7 @@
 
 lemma matrix_vector_mul_linear: "linear(\<lambda>x. A *v (x::real ^ _))"
   by (simp add: linear_iff matrix_vector_mult_def vec_eq_iff
-      field_simps setsum_right_distrib setsum_addf)
+      field_simps setsum_right_distrib setsum.distrib)
 
 lemma matrix_works:
   assumes lf: "linear f"
@@ -523,7 +524,7 @@
   apply (rule adjoint_unique)
   apply (simp add: transpose_def inner_vec_def matrix_vector_mult_def
     setsum_left_distrib setsum_right_distrib)
-  apply (subst setsum_commute)
+  apply (subst setsum.commute)
   apply (auto simp add: mult_ac)
   done
 
@@ -707,13 +708,13 @@
             using i(1) by (simp add: field_simps)
           have "setsum (\<lambda>xa. if xa = i then (c + (x$i)) * ((column xa A)$j)
               else (x$xa) * ((column xa A$j))) ?U = setsum (\<lambda>xa. (if xa = i then c * ((column i A)$j) else 0) + ((x$xa) * ((column xa A)$j))) ?U"
-            apply (rule setsum_cong[OF refl])
+            apply (rule setsum.cong[OF refl])
             using th apply blast
             done
           also have "\<dots> = setsum (\<lambda>xa. if xa = i then c * ((column i A)$j) else 0) ?U + setsum (\<lambda>xa. ((x$xa) * ((column xa A)$j))) ?U"
-            by (simp add: setsum_addf)
+            by (simp add: setsum.distrib)
           also have "\<dots> = c * ((column i A)$j) + setsum (\<lambda>xa. ((x$xa) * ((column xa A)$j))) ?U"
-            unfolding setsum_delta[OF fU]
+            unfolding setsum.delta[OF fU]
             using i(1) by simp
           finally show "setsum (\<lambda>xa. if xa = i then (c + (x$i)) * ((column xa A)$j)
             else (x$xa) * ((column xa A$j))) ?U = c * ((column i A)$j) + setsum (\<lambda>xa. ((x$xa) * ((column xa A)$j))) ?U" .