src/HOL/Library/DAList_Multiset.thy
author haftmann
Sat Dec 17 15:22:13 2016 +0100 (2016-12-17)
changeset 64587 8355a6e2df79
parent 63830 2ea3725a34bd
child 66148 5e60c2d0a1f1
permissions -rw-r--r--
standardized notation
     1 (*  Title:      HOL/Library/DAList_Multiset.thy
     2     Author:     Lukas Bulwahn, TU Muenchen
     3 *)
     4 
     5 section \<open>Multisets partially implemented by association lists\<close>
     6 
     7 theory DAList_Multiset
     8 imports Multiset DAList
     9 begin
    10 
     11 text \<open>Delete preexisting code equations\<close>
    12 
    13 lemma [code, code del]: "{#} = {#}" ..
    14 
    15 lemma [code, code del]: "Multiset.is_empty = Multiset.is_empty" ..
    16 
    17 lemma [code, code del]: "add_mset = add_mset" ..
    18 
    19 lemma [code, code del]: "plus = (plus :: 'a multiset \<Rightarrow> _)" ..
    20 
    21 lemma [code, code del]: "minus = (minus :: 'a multiset \<Rightarrow> _)" ..
    22 
    23 lemma [code, code del]: "inf_subset_mset = (inf_subset_mset :: 'a multiset \<Rightarrow> _)" ..
    24 
    25 lemma [code, code del]: "sup_subset_mset = (sup_subset_mset :: 'a multiset \<Rightarrow> _)" ..
    26 
    27 lemma [code, code del]: "image_mset = image_mset" ..
    28 
    29 lemma [code, code del]: "filter_mset = filter_mset" ..
    30 
    31 lemma [code, code del]: "count = count" ..
    32 
    33 lemma [code, code del]: "size = (size :: _ multiset \<Rightarrow> nat)" ..
    34 
    35 lemma [code, code del]: "sum_mset = sum_mset" ..
    36 
    37 lemma [code, code del]: "prod_mset = prod_mset" ..
    38 
    39 lemma [code, code del]: "set_mset = set_mset" ..
    40 
    41 lemma [code, code del]: "sorted_list_of_multiset = sorted_list_of_multiset" ..
    42 
    43 lemma [code, code del]: "subset_mset = subset_mset" ..
    44 
    45 lemma [code, code del]: "subseteq_mset = subseteq_mset" ..
    46 
    47 lemma [code, code del]: "equal_multiset_inst.equal_multiset = equal_multiset_inst.equal_multiset" ..
    48 
    49 
    50 text \<open>Raw operations on lists\<close>
    51 
     (* Raw join of two association lists: the entries (k, v) of ys are folded
        into xs with foldr, each one through map_default k v, where the update
        function receives the value v' currently stored under k and yields
        f k (v', v). The resulting lookups are characterized by
        map_of_join_raw below. *)
     52 definition join_raw ::
     53     "('key \<Rightarrow> 'val \<times> 'val \<Rightarrow> 'val) \<Rightarrow>
     54       ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list"
     55   where "join_raw f xs ys = foldr (\<lambda>(k, v). map_default k v (\<lambda>v'. f k (v', v))) ys xs"
    56 
     (* Simp rules unfolding join_raw on its second list argument: an empty ys
        returns xs unchanged; an entry (k, v) at the head of ys is merged with
        map_default on top of the join of the tail. *)
     57 lemma join_raw_Nil [simp]: "join_raw f xs [] = xs"
     58   by (simp add: join_raw_def)
     59 
     60 lemma join_raw_Cons [simp]:
     61   "join_raw f xs ((k, v) # ys) = map_default k v (\<lambda>v'. f k (v', v)) (join_raw f xs ys)"
     62   by (simp add: join_raw_def)
    63 
     (* Lookup in join_raw f xs ys when the keys of ys are distinct:
        a key found only in xs keeps its xs value, a key found only in ys gets
        its ys value, and a key found in both maps to f x (v, v'), with v from
        xs and v' from ys. *)
     64 lemma map_of_join_raw:
     65   assumes "distinct (map fst ys)"
     66   shows "map_of (join_raw f xs ys) x =
     67     (case map_of xs x of
     68       None \<Rightarrow> map_of ys x
     69     | Some v \<Rightarrow> (case map_of ys x of None \<Rightarrow> Some v | Some v' \<Rightarrow> Some (f x (v, v'))))"
     70   using assms
          (* Induction on ys; map_of_map_default describes the lookup after each
             map_default step, and metis closes the two remaining subgoals. *)
     71   apply (induct ys)
     72   apply (auto simp add: map_of_map_default split: option.split)
     73   apply (metis map_of_eq_None_iff option.simps(2) weak_map_of_SomeI)
     74   apply (metis Some_eq_map_of_iff map_of_eq_None_iff option.simps(2))
     75   done
    76 
     (* join_raw preserves distinctness of the keys of xs; no distinctness
        assumption on ys is required. *)
     77 lemma distinct_join_raw:
     78   assumes "distinct (map fst xs)"
     79   shows "distinct (map fst (join_raw f xs ys))"
     80   using assms
     81 proof (induct ys)
     82   case Nil
     83   then show ?case by simp
     84 next
     85   case (Cons y ys)
          (* each step is a map_default, handled by distinct_map_default *)
     86   then show ?case by (cases y) (simp add: distinct_map_default)
     87 qed
    88 
     (* Raw subtraction of association lists: for each entry (k, v) of ys,
        AList.map_entry k applies the function from v' to v' - v to the value
        stored under k in xs, if any. Keys absent from xs are not added; see
        map_of_subtract_entries_raw below. *)
     89 definition "subtract_entries_raw xs ys = foldr (\<lambda>(k, v). AList.map_entry k (\<lambda>v'. v' - v)) ys xs"
    90 
     (* Lookup in subtract_entries_raw xs ys when the keys of ys are distinct:
        a key absent from xs stays absent, a key present only in xs keeps its
        value, and a key present in both maps to v - v' (v from xs, v' from
        ys). *)
     91 lemma map_of_subtract_entries_raw:
     92   assumes "distinct (map fst ys)"
     93   shows "map_of (subtract_entries_raw xs ys) x =
     94     (case map_of xs x of
     95       None \<Rightarrow> None
     96     | Some v \<Rightarrow> (case map_of ys x of None \<Rightarrow> Some v | Some v' \<Rightarrow> Some (v - v')))"
     97   using assms
     98   unfolding subtract_entries_raw_def
          (* Induction on ys with option case splits; map_of_map_entry describes
             the lookup after map_entry, and metis closes the last two subgoals. *)
     99   apply (induct ys)
    100   apply auto
    101   apply (simp split: option.split)
    102   apply (simp add: map_of_map_entry)
    103   apply (auto split: option.split)
    104   apply (metis map_of_eq_None_iff option.simps(3) option.simps(4))
    105   apply (metis map_of_eq_None_iff option.simps(4) option.simps(5))
    106   done
   107 
    (* subtract_entries_raw preserves distinctness of the keys of xs; the
       induction step over ys uses distinct_map_entry. *)
    108 lemma distinct_subtract_entries_raw:
    109   assumes "distinct (map fst xs)"
    110   shows "distinct (map fst (subtract_entries_raw xs ys))"
    111   using assms
    112   unfolding subtract_entries_raw_def
    113   by (induct ys) (auto simp add: distinct_map_entry)
   114 
   115 
   116 text \<open>Operations on alists with distinct keys\<close>
   117 
    (* join_raw lifted to alists with distinct keys; the lifting obligation
       (the result again has distinct keys) follows from distinct_join_raw. *)
    118 lift_definition join :: "('a \<Rightarrow> 'b \<times> 'b \<Rightarrow> 'b) \<Rightarrow> ('a, 'b) alist \<Rightarrow> ('a, 'b) alist \<Rightarrow> ('a, 'b) alist"
    119   is join_raw
    120   by (simp add: distinct_join_raw)
   121 
    (* subtract_entries_raw lifted to alists whose value type has minus; the
       distinct-keys obligation follows from distinct_subtract_entries_raw. *)
    122 lift_definition subtract_entries :: "('a, ('b :: minus)) alist \<Rightarrow> ('a, 'b) alist \<Rightarrow> ('a, 'b) alist"
    123   is subtract_entries_raw
    124   by (simp add: distinct_subtract_entries_raw)
   125 
   126 
   127 text \<open>Implementing multisets by means of association lists\<close>
   128 
    (* Count of x in a raw association list: the value map_of finds for x,
       or 0 when x does not occur as a key. *)
    129 definition count_of :: "('a \<times> nat) list \<Rightarrow> 'a \<Rightarrow> nat"
    130   where "count_of xs x = (case map_of xs x of None \<Rightarrow> 0 | Some n \<Rightarrow> n)"
   131 
    (* count_of xs is a valid multiset count function. The proof shows that
       the set ?A of keys with positive count lies inside dom (map_of xs),
       which is finite, and then unfolds multiset_def. *)
    132 lemma count_of_multiset: "count_of xs \<in> multiset"
    133 proof -
    134   let ?A = "{x::'a. 0 < (case map_of xs x of None \<Rightarrow> 0::nat | Some n \<Rightarrow> n)}"
    135   have "?A \<subseteq> dom (map_of xs)"
    136   proof
    137     fix x
    138     assume "x \<in> ?A"
    139     then have "0 < (case map_of xs x of None \<Rightarrow> 0::nat | Some n \<Rightarrow> n)"
    140       by simp
              (* a positive count is impossible in the None case *)
    141     then have "map_of xs x \<noteq> None"
    142       by (cases "map_of xs x") auto
    143     then show "x \<in> dom (map_of xs)"
    144       by auto
    145   qed
    146   with finite_dom_map_of [of xs] have "finite ?A"
    147     by (auto intro: finite_subset)
    148   then show ?thesis
    149     by (simp add: count_of_def fun_eq_iff multiset_def)
    150 qed
   151 
    (* Simp rules evaluating count_of on the empty list and on a cons cell:
       the head entry decides the count of its own key, and all other keys
       are looked up in the tail. *)
    152 lemma count_simps [simp]:
    153   "count_of [] = (\<lambda>_. 0)"
    154   "count_of ((x, n) # xs) = (\<lambda>y. if x = y then n else count_of xs y)"
    155   by (simp_all add: count_of_def fun_eq_iff)
   156 
    (* A key that does not occur in the list has count 0. *)
    157 lemma count_of_empty: "x \<notin> fst ` set xs \<Longrightarrow> count_of xs x = 0"
    158   by (induct xs) (simp_all add: count_of_def)
   159 
    (* Filtering the entries by a predicate on their keys keeps the count of
       x when P x holds and yields 0 otherwise. *)
    160 lemma count_of_filter: "count_of (List.filter (P \<circ> fst) xs) x = (if P x then count_of xs x else 0)"
    161   by (induct xs) auto
   162 
    (* map_default with an increment by b adds b to the count of key x (an
       absent x has count 0 before) and leaves all other keys unchanged.
       Note: inside the lambda the bound x is the stored count, not the key. *)
    163 lemma count_of_map_default [simp]:
    164   "count_of (map_default x b (\<lambda>x. x + b) xs) y =
    165     (if x = y then count_of xs x + b else count_of xs y)"
    166   unfolding count_of_def by (simp add: map_of_map_default split: option.split)
   167 
    (* Joining with addition adds counts pointwise; the keys of ys must be
       distinct. In the combining function the first bound x is the key, which
       is shadowed by the pair pattern binding the two counts. *)
    168 lemma count_of_join_raw:
    169   "distinct (map fst ys) \<Longrightarrow>
    170     count_of xs x + count_of ys x = count_of (join_raw (\<lambda>x (x, y). x + y) xs ys) x"
    171   unfolding count_of_def by (simp add: map_of_join_raw split: option.split)
   172 
    (* Subtracting entries subtracts counts pointwise (subtraction on nat);
       the keys of ys must be distinct. *)
    173 lemma count_of_subtract_entries_raw:
    174   "distinct (map fst ys) \<Longrightarrow>
    175     count_of xs x - count_of ys x = count_of (subtract_entries_raw xs ys) x"
    176   unfolding count_of_def by (simp add: map_of_subtract_entries_raw split: option.split)
   177 
   178 
   179 text \<open>Code equations for multiset operations\<close>
   180 
    (* Bag xs is the multiset whose count function is count_of applied to the
       list underlying the alist xs. code_datatype makes Bag the code
       constructor for multisets, so the code equations below are stated by
       matching on Bag. *)
    181 definition Bag :: "('a, nat) alist \<Rightarrow> 'a multiset"
    182   where "Bag xs = Abs_multiset (count_of (DAList.impl_of xs))"
    183 
    184 code_datatype Bag
   185 
    (* The count of a Bag is count_of of its underlying list; the proof needs
       count_of_multiset so that Abs_multiset can be inverted. Also serves as
       the code equation for count. *)
    186 lemma count_Bag [simp, code]: "count (Bag xs) = count_of (DAList.impl_of xs)"
    187   by (simp add: Bag_def count_of_multiset)
   188 
    (* Code equation: the empty multiset is the Bag of the empty alist. *)
    189 lemma Mempty_Bag [code]: "{#} = Bag (DAList.empty)"
    190   by (simp add: multiset_eq_iff alist.Alist_inverse DAList.empty_def)
   191 
    (* Emptiness test on the underlying list: every stored count is 0. *)
    192 lift_definition is_empty_Bag_impl :: "('a, nat) alist \<Rightarrow> bool" is
    193   "\<lambda>xs. list_all (\<lambda>x. snd x = 0) xs" .
   194 
    (* Code equation for Multiset.is_empty on a Bag: the multiset is empty
       exactly when every count stored in the alist is 0. *)
    195 lemma is_empty_Bag [code]: "Multiset.is_empty (Bag xs) \<longleftrightarrow> is_empty_Bag_impl xs"
    196 proof -
          (* step 1: emptiness means every count is 0 *)
    197   have "Multiset.is_empty (Bag xs) \<longleftrightarrow> (\<forall>x. count (Bag xs) x = 0)"
    198     unfolding Multiset.is_empty_def multiset_eq_iff by simp
          (* step 2: it suffices to check keys present in the list *)
    199   also have "\<dots> \<longleftrightarrow> (\<forall>x\<in>fst ` set (alist.impl_of xs). count (Bag xs) x = 0)"
    200   proof (intro iffI allI ballI)
    201     fix x assume A: "\<forall>x\<in>fst ` set (alist.impl_of xs). count (Bag xs) x = 0"
    202     thus "count (Bag xs) x = 0"
    203     proof (cases "x \<in> fst ` set (alist.impl_of xs)")
    204       case False
                (* an absent key has count 0 by count_of_def *)
    205       thus ?thesis by (force simp: count_of_def split: option.splits)
    206     qed (insert A, auto)
    207   qed simp_all
          (* step 3: restate as list_all over the stored counts *)
    208   also have "\<dots> \<longleftrightarrow> list_all (\<lambda>x. snd x = 0) (alist.impl_of xs)" 
    209     by (auto simp: count_of_def list_all_def)
    210   finally show ?thesis by (simp add: is_empty_Bag_impl.rep_eq)
    211 qed
   212 
    (* Code equation for +: join the alists, adding the counts of keys
       present in both; proved pointwise via count_of_join_raw. *)
    213 lemma union_Bag [code]: "Bag xs + Bag ys = Bag (join (\<lambda>x (n1, n2). n1 + n2) xs ys)"
    214   by (rule multiset_eqI)
    215     (simp add: count_of_join_raw alist.Alist_inverse distinct_join_raw join_def)
   216 
    (* Code equation for add_mset: join a one-entry alist mapping x to 1 with
       xs; the proof reduces to union_Bag via add_mset_add_single. *)
    217 lemma add_mset_Bag [code]: "add_mset x (Bag xs) =
    218     Bag (join (\<lambda>x (n1, n2). n1 + n2) (DAList.update x 1 DAList.empty) xs)"
    219   unfolding add_mset_add_single[of x "Bag xs"] union_Bag[symmetric]
    220   by (simp add: multiset_eq_iff update.rep_eq empty.rep_eq)
   221 
    (* Code equation for multiset difference: subtract the counts of ys from
       those of xs entry by entry; proved pointwise via
       count_of_subtract_entries_raw. *)
    222 lemma minus_Bag [code]: "Bag xs - Bag ys = Bag (subtract_entries xs ys)"
    223   by (rule multiset_eqI)
    224     (simp add: count_of_subtract_entries_raw alist.Alist_inverse
    225       distinct_subtract_entries_raw subtract_entries_def)
   226 
    (* Code equation for filter_mset: keep exactly the entries whose key
       satisfies P. *)
    227 lemma filter_Bag [code]: "filter_mset P (Bag xs) = Bag (DAList.filter (P \<circ> fst) xs)"
    228   by (rule multiset_eqI) (simp add: count_of_filter DAList.filter.rep_eq)
   229 
   230 
    (* Code equation for equality of multisets over a type with equal: two
       multisets are equal iff each is a submultiset of the other. *)
    231 lemma mset_eq [code]: "HOL.equal (m1::'a::equal multiset) m2 \<longleftrightarrow> m1 \<subseteq># m2 \<and> m2 \<subseteq># m1"
    232   by (metis equal_multiset_def subset_mset.eq_iff)
   233 
    234 text \<open>By default the code for \<open>\<subset>#\<close> is @{prop "xs \<subset># ys \<longleftrightarrow> xs \<subseteq># ys \<and> \<not> xs = ys"}.
    235 With equality implemented by \<open>\<subseteq>#\<close>, this leads to three calls of \<open>\<subseteq>#\<close>.
    236 Here is a more efficient version that needs only two:\<close>
    (* Code equation for strict submultiset: xs is contained in ys but not
       the other way round, which avoids a separate equality test. *)
    237 lemma mset_less[code]: "xs \<subset># (ys :: 'a multiset) \<longleftrightarrow> xs \<subseteq># ys \<and> \<not> ys \<subseteq># xs"
    238   by (rule subset_mset.less_le_not_le)
   239 
   240 lemma mset_less_eq_Bag0:
   241   "Bag xs \<subseteq># A \<longleftrightarrow> (\<forall>(x, n) \<in> set (DAList.impl_of xs). count_of (DAList.impl_of xs) x \<le> count A x)"
   242     (is "?lhs \<longleftrightarrow> ?rhs")
   243 proof
   244   assume ?lhs
   245   then show ?rhs by (auto simp add: subseteq_mset_def)
   246 next
   247   assume ?rhs
   248   show ?lhs
   249   proof (rule mset_subset_eqI)
   250     fix x
   251     from \<open>?rhs\<close> have "count_of (DAList.impl_of xs) x \<le> count A x"
   252       by (cases "x \<in> fst ` set (DAList.impl_of xs)") (auto simp add: count_of_empty)
   253     then show "count (Bag xs) x \<le> count A x" by (simp add: subset_mset_def)
   254   qed
   255 qed
   256 
    (* Code equation for the submultiset test with a Bag on the left: each
       stored count n is compared directly with the count of its key in A.
       Reduced to mset_less_eq_Bag0 by showing that count_of returns exactly
       the stored n for every entry (x, n). *)
    257 lemma mset_less_eq_Bag [code]:
    258   "Bag xs \<subseteq># (A :: 'a multiset) \<longleftrightarrow> (\<forall>(x, n) \<in> set (DAList.impl_of xs). n \<le> count A x)"
    259 proof -
          (* auxiliary fact, proved on the raw list after transfer, where the
             distinct-keys invariant is available as a premise *)
    260   {
    261     fix x n
    262     assume "(x,n) \<in> set (DAList.impl_of xs)"
    263     then have "count_of (DAList.impl_of xs) x = n"
    264     proof transfer
    265       fix x n
    266       fix xs :: "('a \<times> nat) list"
    267       show "(distinct \<circ> map fst) xs \<Longrightarrow> (x, n) \<in> set xs \<Longrightarrow> count_of xs x = n"
    268       proof (induct xs)
    269         case Nil
    270         then show ?case by simp
    271       next
    272         case (Cons ym ys)
    273         obtain y m where ym: "ym = (y,m)" by force
    274         note Cons = Cons[unfolded ym]
    275         show ?case
    276         proof (cases "x = y")
    277           case False
    278           with Cons show ?thesis
    279             unfolding ym by auto
    280         next
    281           case True
                  (* with distinct keys, the head entry is the only one for x *)
    282           with Cons(2-3) have "m = n" by force
    283           with True show ?thesis
    284             unfolding ym by auto
    285         qed
    286       qed
    287     qed
    288   }
    289   then show ?thesis
    290     unfolding mset_less_eq_Bag0 by auto
    291 qed
   292 
    (* Register these existing definitions/simp rules as code equations.
       NOTE(review): presumably they re-supply code for inf_subset_mset and
       sup_subset_mset, whose equations were deleted at the top of the theory;
       confirm against Multiset.thy. *)
    293 declare multiset_inter_def [code]
    294 declare sup_subset_mset_def [code]
    295 declare mset.simps [code]
   296 
   297 
    (* Left fold over a raw (key, count) list: for each entry (a, n), in list
       order, the accumulator e is replaced by fn a n e. *)
    298 fun fold_impl :: "('a \<Rightarrow> nat \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> 'b \<Rightarrow> ('a \<times> nat) list \<Rightarrow> 'b"
    299 where
    300   "fold_impl fn e ((a,n) # ms) = (fold_impl fn ((fn a n) e) ms)"
    301 | "fold_impl fn e [] = e"
   302 
    (* Anonymous context making fold a qualified name, referred to as
       DAList_Multiset.fold; it runs fold_impl on the list underlying an
       alist. *)
    303 context
    304 begin
    305 
    306 qualified definition fold :: "('a \<Rightarrow> nat \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> 'b \<Rightarrow> ('a, nat) alist \<Rightarrow> 'b"
    307   where "fold f e al = fold_impl f e (DAList.impl_of al)"
    308 
    309 end
   310 
    (* Within comp_fun_commute f: folding f over Bag al equals
       DAList_Multiset.fold fn e al, provided fn a n is the n-fold iterate of
       f a. This is the bridge used by the fold-based code equations below. *)
    311 context comp_fun_commute
    312 begin
    313 
    314 lemma DAList_Multiset_fold:
    315   assumes fn: "\<And>a n x. fn a n x = (f a ^^ n) x"
    316   shows "fold_mset f e (Bag al) = DAList_Multiset.fold fn e al"
    317   unfolding DAList_Multiset.fold_def
          (* alist induction: the single case yields a raw list ys satisfying the
             distinct-keys invariant ?inv, wrapped by Alist *)
    318 proof (induct al)
    319   fix ys
    320   let ?inv = "{xs :: ('a \<times> nat) list. (distinct \<circ> map fst) xs}"
          (* count_simps is taken out of the simpset and kept as cs, so count_of
             is only unfolded on request; count reduces count after Abs_multiset *)
    321   note cs[simp del] = count_simps
    322   have count[simp]: "\<And>x. count (Abs_multiset (count_of x)) = count_of x"
    323     by (rule Abs_multiset_inverse[OF count_of_multiset])
    324   assume ys: "ys \<in> ?inv"
    325   then show "fold_mset f e (Bag (Alist ys)) = fold_impl fn e (DAList.impl_of (Alist ys))"
    326     unfolding Bag_def unfolding Alist_inverse[OF ys]
            (* inner induction on the raw list, generalizing the start value e *)
    327   proof (induct ys arbitrary: e rule: list.induct)
    328     case Nil
    329     show ?case
    330       by (rule trans[OF arg_cong[of _ "{#}" "fold_mset f e", OF multiset_eqI]])
    331          (auto, simp add: cs)
    332   next
    333     case (Cons pair ys e)
    334     obtain a n where pair: "pair = (a,n)"
    335       by force
    336     from fn[of a n] have [simp]: "fn a n = (f a ^^ n)"
    337       by auto
    338     have inv: "ys \<in> ?inv"
    339       using Cons(2) by auto
    340     note IH = Cons(1)[OF inv]
            (* key step id: the multiset of (a, n) # ys is the multiset Ys of ys
               with a singleton of a added n times *)
    341     define Ys where "Ys = Abs_multiset (count_of ys)"
    342     have id: "Abs_multiset (count_of ((a, n) # ys)) = ((op + {# a #}) ^^ n) Ys"
    343       unfolding Ys_def
    344     proof (rule multiset_eqI, unfold count)
    345       fix c
    346       show "count_of ((a, n) # ys) c =
    347         count ((op + {#a#} ^^ n) (Abs_multiset (count_of ys))) c" (is "?l = ?r")
    348       proof (cases "c = a")
    349         case False
    350         then show ?thesis
    351           unfolding cs by (induct n) auto
    352       next
    353         case True
    354         then have "?l = n" by (simp add: cs)
    355         also have "n = ?r" unfolding True
    356         proof (induct n)
    357           case 0
                    (* distinct keys: a does not occur again in ys, so its count there is 0 *)
    358           from Cons(2)[unfolded pair] have "a \<notin> fst ` set ys" by auto
    359           then show ?case by (induct ys) (simp, auto simp: cs)
    360         next
    361           case Suc
    362           then show ?case by simp
    363         qed
    364         finally show ?thesis .
    365       qed
    366     qed
            (* rewrite with the IH and id, then induct on n; each step moves one
               application of f through fold_mset via fold_mset_fun_left_comm *)
    367     show ?case
    368       unfolding pair
    369       apply (simp add: IH[symmetric])
    370       unfolding id Ys_def[symmetric]
    371       apply (induct n)
    372       apply (auto simp: fold_mset_fun_left_comm[symmetric])
    373       done
    374   qed
    375 qed
    376 
    377 end
   378 
    (* Anonymous context so that the helper single_alist_entry stays private.
       single_alist_entry a b is the one-entry alist containing (a, b). *)
    379 context
    380 begin
    381 
    382 private lift_definition single_alist_entry :: "'a \<Rightarrow> 'b \<Rightarrow> ('a, 'b) alist" is "\<lambda>a b. [(a, b)]"
    383   by auto
    384 
        (* Code equation for image_mset: fold over the alist, adding for every
           entry (a, n) a one-entry Bag that maps f a to n. *)
    385 lemma image_mset_Bag [code]:
    386   "image_mset f (Bag ms) =
    387     DAList_Multiset.fold (\<lambda>a n m. Bag (single_alist_entry (f a) n) + m) {#} ms"
    388   unfolding image_mset_def
          (* instantiate DAList_Multiset_fold; the first remaining subgoal is
             closed by auto with ac_simps (presumably the commutativity
             obligation from unfold_locales), leaving the iterate equation *)
    389 proof (rule comp_fun_commute.DAList_Multiset_fold, unfold_locales, (auto simp: ac_simps)[1])
    390   fix a n m
    391   show "Bag (single_alist_entry (f a) n) + m = ((add_mset \<circ> f) a ^^ n) m" (is "?l = ?r")
    392   proof (rule multiset_eqI)
    393     fix x
    394     have "count ?r x = (if x = f a then n + count m x else count m x)"
    395       by (induct n) auto
    396     also have "\<dots> = count ?l x"
    397       by (simp add: single_alist_entry.rep_eq)
    398     finally show "count ?l x = count ?r x" ..
    399   qed
    400 qed
    401 
    402 end
   403 
   404 (* we cannot use (\<lambda>a n. op + (a * n)) for folding, since * is not defined
   405    in comm_monoid_add *)
    (* Code equation for sum_mset: for each entry (a, n) add a to the
       accumulator n times, starting from 0. Proved by DAList_Multiset_fold
       after unfolding sum_mset.eq_fold. *)
    406 lemma sum_mset_Bag[code]: "sum_mset (Bag ms) = DAList_Multiset.fold (\<lambda>a n. ((op + a) ^^ n)) 0 ms"
    407   unfolding sum_mset.eq_fold
    408   apply (rule comp_fun_commute.DAList_Multiset_fold)
    409   apply unfold_locales
    410   apply (auto simp: ac_simps)
    411   done
   412 
   413 (* we cannot use (\<lambda>a n. op * (a ^ n)) for folding, since ^ is not defined
   414    in comm_monoid_mult *)
    (* Code equation for prod_mset: for each entry (a, n) multiply the
       accumulator by a, n times, starting from 1. Proved by
       DAList_Multiset_fold after unfolding prod_mset.eq_fold. *)
    415 lemma prod_mset_Bag[code]: "prod_mset (Bag ms) = DAList_Multiset.fold (\<lambda>a n. ((op * a) ^^ n)) 1 ms"
    416   unfolding prod_mset.eq_fold
    417   apply (rule comp_fun_commute.DAList_Multiset_fold)
    418   apply unfold_locales
    419   apply (auto simp: ac_simps)
    420   done
   421 
    (* size as a fold that applies Suc once per element, starting from 0;
       proved by multiset induction under the comp_fun_commute
       interpretation for the constant-Suc function. *)
    422 lemma size_fold: "size A = fold_mset (\<lambda>_. Suc) 0 A" (is "_ = fold_mset ?f _ _")
    423 proof -
    424   interpret comp_fun_commute ?f by standard auto
    425   show ?thesis by (induct A) auto
    426 qed
   427 
    (* Code equation for size: the sum of all stored counts, computed by
       folding addition of n over the entries (a, n). *)
    428 lemma size_Bag[code]: "size (Bag ms) = DAList_Multiset.fold (\<lambda>a n. op + n) 0 ms"
    429   unfolding size_fold
    430 proof (rule comp_fun_commute.DAList_Multiset_fold, unfold_locales, simp)
    431   fix a n x
          (* adding n equals applying Suc n times *)
    432   show "n + x = (Suc ^^ n) x"
    433     by (induct n) auto
    434 qed
   435 
   436 
    (* set_mset as a fold inserting each element into the empty set; proved
       by multiset induction under the comp_fun_commute interpretation for
       insert. *)
    437 lemma set_mset_fold: "set_mset A = fold_mset insert {} A" (is "_ = fold_mset ?f _ _")
    438 proof -
    439   interpret comp_fun_commute ?f by standard auto
    440   show ?thesis by (induct A) auto
    441 qed
   442 
    (* Code equation for set_mset: entries with count 0 contribute nothing,
       and every other entry (a, n) inserts a once. *)
    443 lemma set_mset_Bag[code]:
    444   "set_mset (Bag ms) = DAList_Multiset.fold (\<lambda>a n. (if n = 0 then (\<lambda>m. m) else insert a)) {} ms"
    445   unfolding set_mset_fold
    446 proof (rule comp_fun_commute.DAList_Multiset_fold, unfold_locales, (auto simp: ac_simps)[1])
    447   fix a n x
          (* inserting a n times is the same as inserting it once, for n > 0 *)
    448   show "(if n = 0 then \<lambda>m. m else insert a) x = (insert a ^^ n) x" (is "?l n = ?r n")
    449   proof (cases n)
    450     case 0
    451     then show ?thesis by simp
    452   next
    453     case (Suc m)
    454     then have "?l n = insert a x" by simp
    455     moreover have "?r n = insert a x" unfolding Suc by (induct m) auto
    456     ultimately show ?thesis by auto
    457   qed
    458 qed
   459 
   460 
    (* Quickcheck exhaustive enumeration for multisets: enumerate alists with
       the exhaustive generator and pass each one to f wrapped as a Bag. *)
    461 instantiation multiset :: (exhaustive) exhaustive
    462 begin
    463 
    464 definition exhaustive_multiset ::
    465   "('a multiset \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
    466   where "exhaustive_multiset f i = Quickcheck_Exhaustive.exhaustive (\<lambda>xs. f (Bag xs)) i"
    467 
    468 instance ..
    469 
    470 end
   471 
   472 end
   473