```isabelle
(*  Title:      HOL/Library/State_Monad.thy
    Author:     Lars Hupel, TU München
*)

section \<open>State monad\<close>

theory State_Monad
imports Monad_Syntax
begin

datatype ('s, 'a) state = State (run_state: "'s \<Rightarrow> ('a \<times> 's)")

lemma set_state_iff: "x \<in> set_state m \<longleftrightarrow> (\<exists>s s'. run_state m s = (x, s'))"
by (cases m) (simp add: prod_set_defs eq_fst_iff)

lemma pred_stateI[intro]:
  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P a"
  shows "pred_state P m"
proof (subst state.pred_set, rule)
  fix x
  assume "x \<in> set_state m"
  then obtain s s' where "run_state m s = (x, s')"
    by (auto simp: set_state_iff)
  with assms show "P x" .
qed

lemma pred_stateD[dest]:
  assumes "pred_state P m" "run_state m s = (a, s')"
  shows "P a"
proof (rule state.exhaust[of m])
  fix f
  assume "m = State f"
  with assms have "pred_fun (\<lambda>_. True) (pred_prod P top) f"
    by (metis state.pred_inject)
  moreover have "f s = (a, s')"
    using assms unfolding \<open>m = _\<close> by auto
  ultimately show "P a"
    unfolding pred_prod_beta pred_fun_def
    by (metis fst_conv)
qed

lemma pred_state_run_state: "pred_state P m \<Longrightarrow> P (fst (run_state m s))"
by (meson pred_stateD prod.exhaust_sel)

definition state_io_rel :: "('s \<Rightarrow> 's \<Rightarrow> bool) \<Rightarrow> ('s, 'a) state \<Rightarrow> bool" where
"state_io_rel P m = (\<forall>s. P s (snd (run_state m s)))"

lemma state_io_relI[intro]:
  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P s s'"
  shows "state_io_rel P m"
using assms unfolding state_io_rel_def
by (metis prod.collapse)

lemma state_io_relD[dest]:
  assumes "state_io_rel P m" "run_state m s = (a, s')"
  shows "P s s'"
using assms unfolding state_io_rel_def
by (metis snd_conv)

lemma state_io_rel_mono[mono]: "P \<le> Q \<Longrightarrow> state_io_rel P \<le> state_io_rel Q"
by blast

lemma state_ext:
  assumes "\<And>s. run_state m s = run_state n s"
  shows "m = n"
using assms
by (cases m; cases n) auto

context begin

qualified definition return :: "'a \<Rightarrow> ('s, 'a) state" where
"return a = State (Pair a)"

lemma run_state_return[simp]: "run_state (return x) s = (x, s)"
unfolding return_def
by simp

qualified definition ap :: "('s, 'a \<Rightarrow> 'b) state \<Rightarrow> ('s, 'a) state \<Rightarrow> ('s, 'b) state" where
"ap f x = State (\<lambda>s. case run_state f s of (g, s') \<Rightarrow> case run_state x s' of (y, s'') \<Rightarrow> (g y, s''))"

qualified definition bind :: "('s, 'a) state \<Rightarrow> ('a \<Rightarrow> ('s, 'b) state) \<Rightarrow> ('s, 'b) state" where
"bind x f = State (\<lambda>s. case run_state x s of (a, s') \<Rightarrow> run_state (f a) s')"

adhoc_overloading Monad_Syntax.bind bind

lemma bind_left_identity[simp]: "bind (return a) f = f a"
unfolding return_def bind_def by simp

lemma bind_right_identity[simp]: "bind m return = m"
unfolding return_def bind_def by simp

lemma bind_assoc[simp]: "bind (bind m f) g = bind m (\<lambda>x. bind (f x) g)"
unfolding bind_def by (auto split: prod.splits)

lemma bind_predI[intro]:
  assumes "pred_state (\<lambda>x. pred_state P (f x)) m"
  shows "pred_state P (bind m f)"
apply (rule pred_stateI)
unfolding bind_def
using assms by (auto split: prod.splits)

qualified definition get :: "('s, 's) state" where
"get = State (\<lambda>s. (s, s))"

qualified definition set :: "'s \<Rightarrow> ('s, unit) state" where
"set s' = State (\<lambda>_. ((), s'))"

lemma get_set[simp]: "bind get set = return ()"
unfolding bind_def get_def set_def return_def
by simp

lemma set_set[simp]: "bind (set s) (\<lambda>_. set s') = set s'"
unfolding bind_def set_def
by simp

lemma get_bind_set[simp]: "bind get (\<lambda>s. bind (set s) (f s)) = bind get (\<lambda>s. f s ())"
unfolding bind_def get_def set_def
by simp

lemma get_const[simp]: "bind get (\<lambda>_. m) = m"
unfolding get_def bind_def
by simp

fun traverse_list :: "('a \<Rightarrow> ('b, 'c) state) \<Rightarrow> 'a list \<Rightarrow> ('b, 'c list) state" where
"traverse_list _ [] = return []" |
"traverse_list f (x # xs) = do {
  x \<leftarrow> f x;
  xs \<leftarrow> traverse_list f xs;
  return (x # xs)
}"

lemma traverse_list_app[simp]: "traverse_list f (xs @ ys) = do {
  xs \<leftarrow> traverse_list f xs;
  ys \<leftarrow> traverse_list f ys;
  return (xs @ ys)
}"
by (induction xs) auto

lemma traverse_comp[simp]: "traverse_list (g \<circ> f) xs = traverse_list g (map f xs)"
by (induction xs) auto

abbreviation mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
"mono_state \<equiv> state_io_rel (\<le>)"

abbreviation strict_mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
"strict_mono_state \<equiv> state_io_rel (<)"

corollary strict_mono_implies_mono: "strict_mono_state m \<Longrightarrow> mono_state m"
unfolding state_io_rel_def
by (simp add: less_imp_le)

lemma return_mono[simp, intro]: "mono_state (return x)"
unfolding return_def by auto

lemma get_mono[simp, intro]: "mono_state get"
unfolding get_def by auto

lemma put_mono:
  assumes "\<And>x. s' \<ge> x"
  shows "mono_state (set s')"
using assms unfolding set_def
by auto

lemma map_mono[intro]: "mono_state m \<Longrightarrow> mono_state (map_state f m)"
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)

lemma map_strict_mono[intro]: "strict_mono_state m \<Longrightarrow> strict_mono_state (map_state f m)"
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)

lemma bind_mono_strong:
  assumes "mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
  shows "mono_state (bind m f)"
unfolding bind_def
apply (rule state_io_relI)
using assms by (auto split: prod.splits dest!: state_io_relD intro: order_trans)

lemma bind_strict_mono_strong1:
  assumes "mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
  shows "strict_mono_state (bind m f)"
unfolding bind_def
apply (rule state_io_relI)
using assms by (auto split: prod.splits dest!: state_io_relD intro: le_less_trans)

lemma bind_strict_mono_strong2:
  assumes "strict_mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
  shows "strict_mono_state (bind m f)"
unfolding bind_def
apply (rule state_io_relI)
using assms by (auto split: prod.splits dest!: state_io_relD intro: less_le_trans)

corollary bind_strict_mono_strong:
  assumes "strict_mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
  shows "strict_mono_state (bind m f)"
using assms by (auto intro: bind_strict_mono_strong1 strict_mono_implies_mono)

qualified definition update :: "('s \<Rightarrow> 's) \<Rightarrow> ('s, unit) state" where
"update f = bind get (set \<circ> f)"

lemma update_id[simp]: "update (\<lambda>x. x) = return ()"
unfolding update_def return_def get_def set_def bind_def
by auto

lemma update_comp[simp]: "bind (update f) (\<lambda>_. update g) = update (g \<circ> f)"
unfolding update_def return_def get_def set_def bind_def
by auto

lemma set_update[simp]: "bind (set s) (\<lambda>_. update f) = set (f s)"
unfolding set_def update_def bind_def get_def set_def
by simp

lemma set_bind_update[simp]: "bind (set s) (\<lambda>_. bind (update f) g) = bind (set (f s)) g"
unfolding set_def update_def bind_def get_def set_def
by simp

lemma update_mono:
  assumes "\<And>x. x \<le> f x"
  shows "mono_state (update f)"
using assms unfolding update_def get_def set_def bind_def
by (auto intro!: state_io_relI)

lemma update_strict_mono:
  assumes "\<And>x. x < f x"
  shows "strict_mono_state (update f)"
using assms unfolding update_def get_def set_def bind_def
by (auto intro!: state_io_relI)

end

end
```
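The definitions above can be exercised with a small example. Below is a minimal sketch, assuming the theory is loaded as the session-qualified `HOL-Library.State_Monad`. The theory name `State_Monad_Example`, the constant `tick` and the lemma `run_state_tick` are hypothetical and not part of the library.

Because `return`, `bind`, `get` and `set` are declared `qualified`, they (and their `_def` facts) have to be written with the `State_Monad.` prefix outside the theory. The example reads the state, stores its successor, and returns the old value.

```isabelle
theory State_Monad_Example
  imports "HOL-Library.State_Monad"
begin

(* Hypothetical counter: read the state n, write Suc n, return the old value n. *)
definition tick :: "(nat, nat) state" where
  "tick = State_Monad.bind State_Monad.get
            (\<lambda>n. State_Monad.bind (State_Monad.set (Suc n)) (\<lambda>_. State_Monad.return n))"

(* Running tick from state n yields the value n and the new state Suc n. *)
lemma run_state_tick: "run_state tick n = (n, Suc n)"
  by (simp add: tick_def State_Monad.bind_def State_Monad.get_def State_Monad.set_def)

end
```

The proof unfolds `bind`, `get` and `set` directly. The remaining `run_state (State_Monad.return n) (Suc n)` is closed by the simp rule `run_state_return`.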