src/HOL/Library/State_Monad.thy
 author haftmann Fri Mar 22 19:18:08 2019 +0000 (3 months ago) changeset 69946 494934c30f38 parent 68756 7066e83dfe46 permissions -rw-r--r--
improved code equations taken over from AFP
```     1 (*  Title:      HOL/Library/State_Monad.thy
```
```     2     Author:     Lars Hupel, TU München
```
```     3 *)
```
```     4
```
```     5 section \<open>State monad\<close>
```
```     6
```
```     7 theory State_Monad
```
```     8 imports Monad_Syntax
```
```     9 begin
```
```    10
```
(* A state computation: a function from an initial state of type 's to a
   result of type 'a paired with the successor state. The selector run_state
   extracts that function. The datatype package also supplies set_state,
   pred_state and map_state, which the lemmas below rely on. *)
datatype ('s, 'a) state = State (run_state: "'s \<Rightarrow> ('a \<times> 's)")

(* x is among the results of m exactly when running m from some initial
   state yields x, together with some successor state. *)
lemma set_state_iff: "x \<in> set_state m \<longleftrightarrow> (\<exists>s s'. run_state m s = (x, s'))"
by (cases m) (simp add: prod_set_defs eq_fst_iff)

(* Introduction rule for pred_state: to show that P holds of every result
   of m, show P a for each run of m from any s that yields (a, s').
   Declared [intro] so that classical reasoners can apply it. *)
lemma pred_stateI[intro]:
  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P a"
  shows "pred_state P m"
proof (subst state.pred_set, rule)
  fix x
  assume "x \<in> set_state m"
  (* set_state_iff turns membership into an explicit run of m *)
  then obtain s s' where "run_state m s = (x, s')"
    by (auto simp: set_state_iff)
  with assms show "P x" .
qed

(* Destruction rule for pred_state: if pred_state P m holds and running m
   from s yields (a, s'), then P a. Declared [dest]. *)
lemma pred_stateD[dest]:
  assumes "pred_state P m" "run_state m s = (a, s')"
  shows "P a"
proof (rule state.exhaust[of m])
  fix f
  assume "m = State f"
  (* pred_state on a concrete State f unfolds to a predicate on f *)
  with assms have "pred_fun (\<lambda>_. True) (pred_prod P top) f"
    by (metis state.pred_inject)
  moreover have "f s = (a, s')"
    using assms unfolding \<open>m = _\<close> by auto
  ultimately show "P a"
    unfolding pred_prod_beta pred_fun_def
    by (metis fst_conv)
qed

(* Projection form of pred_stateD: P holds of the result component of
   run_state m s, for any initial state s. *)
lemma pred_state_run_state: "pred_state P m \<Longrightarrow> P (fst (run_state m s))"
by (meson pred_stateD prod.exhaust_sel)

(* state_io_rel P m: for every initial state s, P relates s to the state
   that m leaves behind (the second component of run_state m s). The result
   component of m is not constrained. *)
definition state_io_rel :: "('s \<Rightarrow> 's \<Rightarrow> bool) \<Rightarrow> ('s, 'a) state \<Rightarrow> bool" where
"state_io_rel P m = (\<forall>s. P s (snd (run_state m s)))"

(* Introduction rule for state_io_rel: it suffices to show P s s' for every
   run of m that starts in s and ends in s'. *)
lemma state_io_relI[intro]:
  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P s s'"
  shows "state_io_rel P m"
using assms unfolding state_io_rel_def
by (metis prod.collapse)

(* Destruction rule for state_io_rel: from state_io_rel P m and a concrete
   run of m from s to s', conclude P s s'. *)
lemma state_io_relD[dest]:
  assumes "state_io_rel P m" "run_state m s = (a, s')"
  shows "P s s'"
using assms unfolding state_io_rel_def
by (metis snd_conv)

(* state_io_rel is monotone in its relation argument; the [mono] attribute
   registers this as a monotonicity rule (e.g. for use by inductive). *)
lemma state_io_rel_mono[mono]: "P \<le> Q \<Longrightarrow> state_io_rel P \<le> state_io_rel Q"
by blast

(* Extensionality: two state computations are equal if running them from
   every initial state gives the same result and successor state. *)
lemma state_ext:
  assumes "\<And>s. run_state m s = run_state n s"
  shows "m = n"
using assms
by (cases m; cases n) auto

(* Anonymous context for the monad operations. The constants declared with
   qualified inside it (return, ap, bind, get, set, update) are, after the
   matching end, accessible only via their theory-qualified names such as
   State_Monad.bind. Presumably this avoids clashes with equally named
   constants elsewhere in HOL. *)
context begin

(* return a: yields a and leaves the state unchanged. *)
qualified definition return :: "'a \<Rightarrow> ('s, 'a) state" where
"return a = State (Pair a)"

(* Computation rule: running return x yields x and passes the input state
   through untouched. *)
lemma run_state_return[simp]: "run_state (return x) s = (x, s)"
by (simp add: return_def)

(* Applicative application. Run f from the initial state to obtain a
   function g and an intermediate state s'. Then run x from s' to obtain an
   argument y and the final state s''. The result is g y, with s'' as the
   final state. *)
qualified definition ap :: "('s, 'a \<Rightarrow> 'b) state \<Rightarrow> ('s, 'a) state \<Rightarrow> ('s, 'b) state" where
"ap f x = State (\<lambda>s. case run_state f s of (g, s') \<Rightarrow> case run_state x s' of (y, s'') \<Rightarrow> (g y, s''))"

(* Computation rule for ap. Running ap f x first runs f, then runs x on the
   intermediate state, and applies the obtained function to the obtained
   argument. *)
lemma run_state_ap[simp]:
  "run_state (ap f x) s = (case run_state f s of (g, s') \<Rightarrow> case run_state x s' of (y, s'') \<Rightarrow> (g y, s''))"
by (simp add: ap_def)

(* Monadic bind: run x, then run f a from the state x left behind, where a
   is the result of x. *)
qualified definition bind :: "('s, 'a) state \<Rightarrow> ('a \<Rightarrow> ('s, 'b) state) \<Rightarrow> ('s, 'b) state" where
"bind x f = State (\<lambda>s. case run_state x s of (a, s') \<Rightarrow> run_state (f a) s')"

(* Computation rule for bind: the continuation f is run on the result and
   the successor state produced by x. *)
lemma run_state_bind[simp]:
  "run_state (bind x f) s = (case run_state x s of (a, s') \<Rightarrow> run_state (f a) s')"
by (simp add: bind_def)

(* Register this bind as an instance of Monad_Syntax.bind, so that the
   do-notation used below (e.g. in traverse_list) works for state
   computations. *)
adhoc_overloading Monad_Syntax.bind bind

(* Monad law (left identity): binding the result of return a into f is just
   f a. *)
lemma bind_left_identity[simp]: "bind (return a) f = f a"
unfolding return_def bind_def by simp

(* Monad law (right identity): binding m into return gives back m. *)
lemma bind_right_identity[simp]: "bind m return = m"
unfolding return_def bind_def by simp

(* Monad law (associativity). As a simp rule it re-nests left-nested binds
   to the right. *)
lemma bind_assoc[simp]: "bind (bind m f) g = bind m (\<lambda>x. bind (f x) g)"
unfolding bind_def by (auto split: prod.splits)

(* pred_state is preserved by bind. Assume that for every result x of m,
   all results of f x satisfy P. Then all results of bind m f satisfy P. *)
lemma bind_predI[intro]:
  assumes "pred_state (\<lambda>x. pred_state P (f x)) m"
  shows "pred_state P (bind m f)"
apply (rule pred_stateI)
unfolding bind_def
using assms by (auto split: prod.splits)

(* get: yields the current state as its result and leaves the state
   unchanged. *)
qualified definition get :: "('s, 's) state" where
"get = State (\<lambda>s. (s, s))"

(* Computation rule: get returns the input state and keeps it. *)
lemma run_state_get[simp]: "run_state get s = (s, s)"
by (simp add: get_def)

(* set s': replaces the state by s', whatever it was, and yields (). *)
qualified definition set :: "'s \<Rightarrow> ('s, unit) state" where
"set s' = State (\<lambda>_. ((), s'))"

(* Computation rule: set s' discards the input state and installs s'. *)
lemma run_state_set[simp]: "run_state (set s') s = ((), s')"
by (simp add: set_def)

(* Reading the state and immediately writing it back has no effect. *)
lemma get_set[simp]: "bind get set = return ()"
unfolding bind_def get_def set_def return_def
by simp

(* Of two consecutive writes only the second one matters. *)
lemma set_set[simp]: "bind (set s) (\<lambda>_. set s') = set s'"
unfolding bind_def set_def
by simp

(* Writing back the state that was just read can be dropped. The
   continuation then receives the read state and (). *)
lemma get_bind_set[simp]: "bind get (\<lambda>s. bind (set s) (f s)) = bind get (\<lambda>s. f s ())"
unfolding bind_def get_def set_def
by simp

(* A get whose result is ignored can be removed. *)
lemma get_const[simp]: "bind get (\<lambda>_. m) = m"
unfolding get_def bind_def
by simp

(* traverse_list f xs runs f on every element of xs from left to right,
   threading the state through the calls. It collects the results in the
   original order. The do-binders y and ys are distinct from the pattern
   variables x and xs. *)
fun traverse_list :: "('a \<Rightarrow> ('b, 'c) state) \<Rightarrow> 'a list \<Rightarrow> ('b, 'c list) state" where
"traverse_list _ [] = return []" |
"traverse_list f (x # xs) = do {
  y \<leftarrow> f x;
  ys \<leftarrow> traverse_list f xs;
  return (y # ys)
}"

(* Traversing xs @ ys equals traversing xs, then traversing ys, and
   appending the two result lists. *)
lemma traverse_list_app[simp]: "traverse_list f (xs @ ys) = do {
  rs \<leftarrow> traverse_list f xs;
  rs' \<leftarrow> traverse_list f ys;
  return (rs @ rs')
}"
by (induction xs) auto

(* Traversing with a composition g o f equals traversing g over the list
   mapped by f. *)
lemma traverse_comp[simp]: "traverse_list (g \<circ> f) xs = traverse_list g (map f xs)"
by (induction xs) auto

(* mono_state m: for every initial state s, the state m leaves behind is
   >= s in the preorder. Abbreviation for state_io_rel instantiated with the
   order relation. *)
abbreviation mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
"mono_state \<equiv> state_io_rel (\<le>)"

(* strict_mono_state m: for every initial state s, the state m leaves behind
   is strictly greater than s. *)
abbreviation strict_mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
"strict_mono_state \<equiv> state_io_rel (<)"

(* A strictly monotone computation is in particular monotone. *)
corollary strict_mono_implies_mono: "strict_mono_state m \<Longrightarrow> mono_state m"
unfolding state_io_rel_def
by (simp add: less_imp_le)

(* return leaves the state unchanged, so it is monotone. *)
lemma return_mono[simp, intro]: "mono_state (return x)"
unfolding return_def by auto

(* get leaves the state unchanged, so it is monotone. *)
lemma get_mono[simp, intro]: "mono_state get"
unfolding get_def by auto

(* set s' is monotone provided s' is >= every state, since it may be run
   from any initial state. Note that the lemma is about the operation set,
   despite its name. *)
lemma put_mono:
  assumes "\<And>x. s' \<ge> x"
  shows "mono_state (set s')"
using assms unfolding set_def
by auto

(* map_state f only transforms the result component, which state_io_rel
   ignores. Hence monotonicity of m carries over to map_state f m. *)
lemma map_mono[intro]: "mono_state m \<Longrightarrow> mono_state (map_state f m)"
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)

(* Strict counterpart of map_mono: map_state f preserves strict
   monotonicity. *)
lemma map_strict_mono[intro]: "strict_mono_state m \<Longrightarrow> strict_mono_state (map_state f m)"
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)

(* bind m f is monotone if m is monotone and f x is monotone for every
   result x that m actually produces from some state. No requirement is
   placed on other values of x. Proved by transitivity of the order. *)
lemma bind_mono_strong:
  assumes "mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
  shows "mono_state (bind m f)"
unfolding bind_def
apply (rule state_io_relI)
using assms by (auto split: prod.splits dest!: state_io_relD intro: order_trans)

(* bind m f is strictly monotone if m is monotone and f x is strictly
   monotone for every result x that m produces (uses le_less_trans). *)
lemma bind_strict_mono_strong1:
  assumes "mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
  shows "strict_mono_state (bind m f)"
unfolding bind_def
apply (rule state_io_relI)
using assms by (auto split: prod.splits dest!: state_io_relD intro: le_less_trans)

(* bind m f is strictly monotone if m is strictly monotone and f x is
   monotone for every result x that m produces (uses less_le_trans). *)
lemma bind_strict_mono_strong2:
  assumes "strict_mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
  shows "strict_mono_state (bind m f)"
unfolding bind_def
apply (rule state_io_relI)
using assms by (auto split: prod.splits dest!: state_io_relD intro: less_le_trans)

(* Both parts strictly monotone: follows from bind_strict_mono_strong1 by
   weakening the assumption on m with strict_mono_implies_mono. *)
corollary bind_strict_mono_strong:
  assumes "strict_mono_state m"
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
  shows "strict_mono_state (bind m f)"
using assms by (auto intro: bind_strict_mono_strong1 strict_mono_implies_mono)

(* update f: reads the current state with get and writes back f applied to
   it; yields (). *)
qualified definition update :: "('s \<Rightarrow> 's) \<Rightarrow> ('s, unit) state" where
"update f = bind get (set \<circ> f)"

(* Updating with the identity function has no effect. *)
lemma update_id[simp]: "update (\<lambda>x. x) = return ()"
unfolding update_def return_def get_def set_def bind_def
by auto

(* Two consecutive updates, first f and then g, fuse into a single update
   with g o f. *)
lemma update_comp[simp]: "bind (update f) (\<lambda>_. update g) = update (g \<circ> f)"
unfolding update_def return_def get_def set_def bind_def
by auto

(* Setting the state to s and then updating it with f is the same as
   setting it to f s. *)
lemma set_update[simp]: "bind (set s) (\<lambda>_. update f) = set (f s)"
unfolding set_def update_def bind_def get_def set_def
by simp

(* Variant of set_update with a continuation g after the update. *)
lemma set_bind_update[simp]: "bind (set s) (\<lambda>_. bind (update f) g) = bind (set (f s)) g"
unfolding set_def update_def bind_def get_def set_def
by simp

(* update f is monotone if f never decreases its argument (x <= f x). *)
lemma update_mono:
  assumes "\<And>x. x \<le> f x"
  shows "mono_state (update f)"
using assms unfolding update_def get_def set_def bind_def
by (auto intro!: state_io_relI)

(* update f is strictly monotone if f strictly increases every argument
   (x < f x). *)
lemma update_strict_mono:
  assumes "\<And>x. x < f x"
  shows "strict_mono_state (update f)"
using assms unfolding update_def get_def set_def bind_def
by (auto intro!: state_io_relI)

(* Closes the anonymous context opened by context begin. From here on the
   qualified constants must be referred to by their qualified names. *)
end

(* End of theory State_Monad. *)
end