|
66271
|
1 |
(* Title: HOL/Library/State_Monad.thy
|
|
|
2 |
Author: Lars Hupel, TU München
|
|
|
3 |
*)
|
|
|
4 |
|
|
|
5 |
section \<open>State monad\<close>
|
|
|
6 |
|
|
|
7 |
theory State_Monad
|
|
|
8 |
imports Monad_Syntax
|
|
|
9 |
begin
|
|
|
10 |
|
|
|
11 |
(* A state computation wraps a function run_state that maps an initial state of
   type 's to a pair of a result of type 'a and a final state. The datatype
   command also provides map_state, set_state and pred_state, together with
   theorems such as state.pred_set and state.map_sel that are used below. *)
datatype ('s, 'a) state = State (run_state: "'s \<Rightarrow> ('a \<times> 's)")
|
|
|
12 |
|
|
|
13 |
(* Characterisation of set_state: x is a possible result of m iff running m
   from some initial state s yields x together with some final state s'. *)
lemma set_state_iff: "x \<in> set_state m \<longleftrightarrow> (\<exists>s s'. run_state m s = (x, s'))"
|


14 |
by (cases m) (simp add: prod_set_defs eq_fst_iff)
|
|
|
15 |
|
|
|
16 |
(* Introduction rule for pred_state: if every result that m can produce, from
   any initial state, satisfies P, then pred_state P m holds. The proof rewrites
   pred_state into a statement about set_state via state.pred_set and then uses
   set_state_iff to recover a concrete run producing each result. *)
lemma pred_stateI[intro]:
|


17 |
assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P a"
|


18 |
shows "pred_state P m"
|


19 |
proof (subst state.pred_set, rule)
|


20 |
fix x
|


21 |
assume "x \<in> set_state m"
|


22 |
(* by set_state_iff, x is produced by some concrete run of m *)
then obtain s s' where "run_state m s = (x, s')"
|


23 |
by (auto simp: set_state_iff)
|


24 |
with assms show "P x" .
|


25 |
qed
|
|
|
26 |
|
|
|
27 |
(* Destruction rule, the converse of pred_stateI: if pred_state P m holds and
   running m from s yields a, then P a. The proof unpacks m = State f, turns
   pred_state into pred_fun / pred_prod on f, and reads P off the first
   component of f s. *)
lemma pred_stateD[dest]:
|


28 |
assumes "pred_state P m" "run_state m s = (a, s')"
|


29 |
shows "P a"
|


30 |
proof (rule state.exhaust[of m])
|


31 |
fix f
|


32 |
assume "m = State f"
|


33 |
with assms have "pred_fun (\<lambda>_. True) (pred_prod P top) f"
|


34 |
by (metis state.pred_inject)
|


35 |
moreover have "f s = (a, s')"
|


36 |
using assms unfolding \<open>m = _\<close> by auto
|


37 |
ultimately show "P a"
|


38 |
unfolding pred_prod_beta pred_fun_def
|


39 |
by (metis fst_conv)
|


40 |
qed
|
|
|
41 |
|
|
|
42 |
(* Consequence of pred_stateD phrased with fst: whatever initial state m is
   run from, its result satisfies P. *)
lemma pred_state_run_state: "pred_state P m \<Longrightarrow> P (fst (run_state m s))"
|


43 |
by (meson pred_stateD prod.exhaust_sel)
|
|
|
44 |
|
|
|
45 |
(* state_io_rel P m: for every initial state s, P relates s to the final state
   snd (run_state m s). The result value of m is ignored. The monotonicity
   notions mono_state and strict_mono_state further below are instances of it. *)
definition state_io_rel :: "('s \<Rightarrow> 's \<Rightarrow> bool) \<Rightarrow> ('s, 'a) state \<Rightarrow> bool" where
|


46 |
"state_io_rel P m = (\<forall>s. P s (snd (run_state m s)))"
|
|
|
47 |
|
|
|
48 |
(* Introduction rule: state_io_rel P m holds if every run of m that starts in s
   and ends in s' satisfies P s s'. *)
lemma state_io_relI[intro]:
|


49 |
assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P s s'"
|


50 |
shows "state_io_rel P m"
|


51 |
using assms unfolding state_io_rel_def
|


52 |
by (metis prod.collapse)
|
|
|
53 |
|
|
|
54 |
(* Destruction rule: from state_io_rel P m and a run of m that starts in s and
   ends in s', conclude P s s'. *)
lemma state_io_relD[dest]:
|


55 |
assumes "state_io_rel P m" "run_state m s = (a, s')"
|


56 |
shows "P s s'"
|


57 |
using assms unfolding state_io_rel_def
|


58 |
by (metis snd_conv)
|
|
|
59 |
|
|
|
60 |
(* state_io_rel is monotone in its relation argument: if P \<le> Q (pointwise),
   then every m satisfying state_io_rel P also satisfies state_io_rel Q.
   Tagged with the [mono] attribute. *)
lemma state_io_rel_mono[mono]: "P \<le> Q \<Longrightarrow> state_io_rel P \<le> state_io_rel Q"
|


61 |
by blast
|
|
|
62 |
|
|
|
63 |
(* Extensionality: two state computations are equal if running them from every
   initial state gives the same result and final state. *)
lemma state_ext:
|


64 |
assumes "\<And>s. run_state m s = run_state n s"
|


65 |
shows "m = n"
|


66 |
using assms
|


67 |
by (cases m; cases n) auto
|
|
|
68 |
|
|
|
69 |
(* Unnamed local context. The definitions inside it (return, ap, bind, get,
   set, update) are declared qualified, so outside this context they are only
   accessible in qualified form, e.g. State_Monad.bind, presumably to avoid
   clashes with other constants of the same name such as set. *)
context begin
|
|
|
70 |
|
|
|
71 |
(* return a yields a and leaves the state unchanged: Pair a maps s to (a, s). *)
qualified definition return :: "'a \<Rightarrow> ('s, 'a) state" where
|


72 |
"return a = State (Pair a)"
|
|
|
73 |
|
|
|
74 |
(* Applicative application: run f from s to get a function g and state s',
   then run x from s' to get y and the final state s''; the result is g y.
   The effects of f therefore happen before those of x. *)
qualified definition ap :: "('s, 'a \<Rightarrow> 'b) state \<Rightarrow> ('s, 'a) state \<Rightarrow> ('s, 'b) state" where
|


75 |
"ap f x = State (\<lambda>s. case run_state f s of (g, s') \<Rightarrow> case run_state x s' of (y, s'') \<Rightarrow> (g y, s''))"
|
|
|
76 |
|
|
|
77 |
(* Monadic bind: run x from s to get a result a and state s', then run the
   continuation f a starting from s'. *)
qualified definition bind :: "('s, 'a) state \<Rightarrow> ('a \<Rightarrow> ('s, 'b) state) \<Rightarrow> ('s, 'b) state" where
|


78 |
"bind x f = State (\<lambda>s. case run_state x s of (a, s') \<Rightarrow> run_state (f a) s')"
|
|
|
79 |
|
|
|
80 |
(* Register bind as an instance of the generic Monad_Syntax.bind; this enables
   the do-notation used in traverse_list below. *)
adhoc_overloading Monad_Syntax.bind bind
|
|
|
81 |
|
|
|
82 |
(* Monad law: return is a left identity for bind. *)
lemma bind_left_identity[simp]: "bind (return a) f = f a"
|


83 |
unfolding return_def bind_def by simp
|
|
|
84 |
|
|
|
85 |
(* Monad law: return is a right identity for bind. *)
lemma bind_right_identity[simp]: "bind m return = m"
|


86 |
unfolding return_def bind_def by simp
|
|
|
87 |
|
|
|
88 |
(* Monad law: bind is associative. As a simp rule it re-associates nested binds
   to the right. *)
lemma bind_assoc[simp]: "bind (bind m f) g = bind m (\<lambda>x. bind (f x) g)"
|


89 |
unfolding bind_def by (auto split: prod.splits)
|
|
|
90 |
|
|
|
91 |
(* pred_state is preserved by bind: if every result x of m yields a computation
   f x whose results all satisfy P, then every result of bind m f satisfies P. *)
lemma bind_predI[intro]:
|


92 |
assumes "pred_state (\<lambda>x. pred_state P (f x)) m"
|


93 |
shows "pred_state P (bind m f)"
|


94 |
apply (rule pred_stateI)
|


95 |
unfolding bind_def
|


96 |
using assms by (auto split: prod.splits)
|
|
|
97 |
|
|
|
98 |
(* get yields the current state as its result and leaves the state unchanged. *)
qualified definition get :: "('s, 's) state" where
|


99 |
"get = State (\<lambda>s. (s, s))"
|
|
|
100 |
|
|
|
101 |
(* set s' discards the current state, replaces it by s' and yields (). *)
qualified definition set :: "'s \<Rightarrow> ('s, unit) state" where
|


102 |
"set s' = State (\<lambda>_. ((), s'))"
|
|
|
103 |
|
|
|
104 |
(* Reading the state and immediately writing it back is the same as doing
   nothing, i.e. return (). *)
lemma get_set[simp]: "bind get set = return ()"
|


105 |
unfolding bind_def get_def set_def return_def
|


106 |
by simp
|
|
|
107 |
|
|
|
108 |
(* Of two consecutive writes only the second one is observable. *)
lemma set_set[simp]: "bind (set s) (\<lambda>_. set s') = set s'"
|


109 |
unfolding bind_def set_def
|


110 |
by simp
|
|
|
111 |
|
|
|
112 |
(* traverse_list f xs runs f on each element of xs from left to right, threading
   the state through, and returns the list of results in the same order.
   In the cons equation the do-bound names x and xs shadow the pattern
   variables. The right-hand sides f x and traverse_list f xs still refer to
   the pattern variables, because each is evaluated before its binding. *)
fun traverse_list :: "('a \<Rightarrow> ('b, 'c) state) \<Rightarrow> 'a list \<Rightarrow> ('b, 'c list) state" where
|


113 |
"traverse_list _ [] = return []" |
|


114 |
"traverse_list f (x # xs) = do {
|


115 |
x \<leftarrow> f x;
|


116 |
xs \<leftarrow> traverse_list f xs;
|


117 |
return (x # xs)
|


118 |
}"
|
|
|
119 |
|
|
|
120 |
(* Traversing an appended list is the same as traversing each part in turn and
   appending the two result lists. *)
lemma traverse_list_app[simp]: "traverse_list f (xs @ ys) = do {
|


121 |
xs \<leftarrow> traverse_list f xs;
|


122 |
ys \<leftarrow> traverse_list f ys;
|


123 |
return (xs @ ys)
|


124 |
}"
|


125 |
by (induction xs) auto
|
|
|
126 |
|
|
|
127 |
(* Traversing with the composition g \<circ> f is the same as traversing, with g,
   the list mapped with f. As a simp rule it rewrites left to right. *)
lemma traverse_comp[simp]: "traverse_list (g \<circ> f) xs = traverse_list g (map f xs)"
|


128 |
by (induction xs) auto
|
|
|
129 |
|
|
|
130 |
(* mono_state m: for every run of m from s ending in s' we have s \<le> s',
   with respect to the preorder on 's. *)
abbreviation mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
|


131 |
"mono_state \<equiv> state_io_rel (op \<le>)"
|
|
|
132 |
|
|
|
133 |
(* strict_mono_state m: for every run of m from s ending in s' we have s < s',
   with respect to the preorder on 's. *)
abbreviation strict_mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
|


134 |
"strict_mono_state \<equiv> state_io_rel (op <)"
|
|
|
135 |
|
|
|
136 |
(* Strict monotonicity implies monotonicity, via less_imp_le. *)
corollary strict_mono_implies_mono: "strict_mono_state m \<Longrightarrow> mono_state m"
|


137 |
unfolding state_io_rel_def
|


138 |
by (simp add: less_imp_le)
|
|
|
139 |
|
|
|
140 |
(* return leaves the state unchanged, so it is monotone. *)
lemma return_mono[simp, intro]: "mono_state (return x)"
|


141 |
unfolding return_def by auto
|
|
|
142 |
|
|
|
143 |
(* get leaves the state unchanged, so it is monotone. *)
lemma get_mono[simp, intro]: "mono_state get"
|


144 |
unfolding get_def by auto
|
|
|
145 |
|
|
|
146 |
(* set s' is monotone provided s' is above every state, i.e. s' is a greatest
   element. NOTE(review): this lemma is named put_mono although the operation
   is called set in this theory. *)
lemma put_mono:
|


147 |
assumes "\<And>x. s' \<ge> x"
|


148 |
shows "mono_state (set s')"
|


149 |
using assms unfolding set_def
|


150 |
by auto
|
|
|
151 |
|
|
|
152 |
(* Mapping a function f over the results of m (map_state f m) preserves
   monotonicity. *)
lemma map_mono[intro]: "mono_state m \<Longrightarrow> mono_state (map_state f m)"
|


153 |
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)
|
|
|
154 |
|
|
|
155 |
(* Mapping a function f over the results of m (map_state f m) preserves strict
   monotonicity. *)
lemma map_strict_mono[intro]: "strict_mono_state m \<Longrightarrow> strict_mono_state (map_state f m)"
|


156 |
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)
|
|
|
157 |
|
|
|
158 |
(* bind m f is monotone if m is monotone and f x is monotone for every result x
   that m actually produces from some initial state. f need not be monotone on
   other arguments, hence "strong". The two steps are chained with
   order_trans. *)
lemma bind_mono_strong:
|


159 |
assumes "mono_state m"
|


160 |
assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
|


161 |
shows "mono_state (bind m f)"
|


162 |
unfolding bind_def
|


163 |
apply (rule state_io_relI)
|


164 |
using assms by (auto split: prod.splits dest!: state_io_relD intro: order_trans)
|
|
|
165 |
|
|
|
166 |
(* Variant of bind_mono_strong: a monotone m followed by continuations f x that
   are strictly monotone for every reachable result x gives a strictly monotone
   bind m f. The steps are chained with le_less_trans. *)
lemma bind_strict_mono_strong1:
|


167 |
assumes "mono_state m"
|


168 |
assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
|


169 |
shows "strict_mono_state (bind m f)"
|


170 |
unfolding bind_def
|


171 |
apply (rule state_io_relI)
|


172 |
using assms by (auto split: prod.splits dest!: state_io_relD intro: le_less_trans)
|
|
|
173 |
|
|
|
174 |
(* Variant of bind_mono_strong: a strictly monotone m followed by continuations
   f x that are monotone for every reachable result x gives a strictly monotone
   bind m f. The steps are chained with less_le_trans. *)
lemma bind_strict_mono_strong2:
|


175 |
assumes "strict_mono_state m"
|


176 |
assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
|


177 |
shows "strict_mono_state (bind m f)"
|


178 |
unfolding bind_def
|


179 |
apply (rule state_io_relI)
|


180 |
using assms by (auto split: prod.splits dest!: state_io_relD intro: less_le_trans)
|
|
|
181 |
|
|
|
182 |
(* Both m and every reachable continuation f x are strictly monotone. This
   follows from bind_strict_mono_strong1 after weakening m to monotone with
   strict_mono_implies_mono. *)
corollary bind_strict_mono_strong:
|


183 |
assumes "strict_mono_state m"
|


184 |
assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
|


185 |
shows "strict_mono_state (bind m f)"
|


186 |
using assms by (auto intro: bind_strict_mono_strong1 strict_mono_implies_mono)
|
|
|
187 |
|
|
|
188 |
(* update f replaces the current state s by f s and yields (). It is defined by
   reading the state with get and writing back f applied to it with set. *)
qualified definition update :: "('s \<Rightarrow> 's) \<Rightarrow> ('s, unit) state" where
|


189 |
"update f = bind get (set \<circ> f)"
|
|
|
190 |
|
|
|
191 |
(* Updating with the identity function does nothing. *)
lemma update_id[simp]: "update (\<lambda>x. x) = return ()"
|


192 |
unfolding update_def return_def get_def set_def bind_def
|


193 |
by auto
|
|
|
194 |
|
|
|
195 |
(* Two consecutive updates, first with f and then with g, combine into a single
   update with g \<circ> f. *)
lemma update_comp[simp]: "bind (update f) (\<lambda>_. update g) = update (g \<circ> f)"
|


196 |
unfolding update_def return_def get_def set_def bind_def
|


197 |
by auto
|
|
|
198 |
|
|
|
199 |
(* update f is monotone if f is inflationary, i.e. x \<le> f x for all x. *)
lemma update_mono:
|


200 |
assumes "\<And>x. x \<le> f x"
|


201 |
shows "mono_state (update f)"
|


202 |
using assms unfolding update_def get_def set_def bind_def
|


203 |
by (auto intro!: state_io_relI)
|
|
|
204 |
|
|
|
205 |
(* update f is strictly monotone if x < f x for all x. *)
lemma update_strict_mono:
|


206 |
assumes "\<And>x. x < f x"
|


207 |
shows "strict_mono_state (update f)"
|


208 |
using assms unfolding update_def get_def set_def bind_def
|


209 |
by (auto intro!: state_io_relI)
|
|
|
210 |
|
|
|
211 |
(* Closes the unnamed context opened by "context begin" above. *)
end
|
|
|
212 |
|
|
|
213 |
end |