src/HOL/Library/State_Monad.thy
author haftmann
Wed Jul 18 20:51:21 2018 +0200 (11 months ago)
changeset 68658 16cc1161ad7f
parent 67399 eab6ce8368fa
child 68756 7066e83dfe46
permissions -rw-r--r--
tuned equation
lars@66271
     1
(*  Title:      HOL/Library/State_Monad.thy
lars@66271
     2
    Author:     Lars Hupel, TU München
lars@66271
     3
*)
lars@66271
     4
lars@66271
     5
section \<open>State monad\<close>
lars@66271
     6
lars@66271
     7
theory State_Monad
lars@66271
     8
imports Monad_Syntax
lars@66271
     9
begin
lars@66271
    10
lars@66271
    11
(* State monad over state type 's with result type 'a: a wrapper around a
   function from an input state to a pair (result, output state).
   run_state is the selector that unwraps the function. *)
datatype ('s, 'a) state = State (run_state: "'s \<Rightarrow> ('a \<times> 's)")
lars@66271
    12
lars@66271
    13
(* Membership in set_state m: x belongs to it exactly when run_state m
   returns x as the result component for some input state s (with some
   output state s').
   NOTE(review): set_state is not defined in this file; presumably it is
   generated by the datatype command above -- confirm. *)
lemma set_state_iff: "x \<in> set_state m \<longleftrightarrow> (\<exists>s s'. run_state m s = (x, s'))"
lars@66271
    14
by (cases m) (simp add: prod_set_defs eq_fst_iff)
lars@66271
    15
lars@66271
    16
(* Introduction rule for pred_state: it suffices that P holds for every
   result a that run_state m returns from any input state.
   The proof rewrites the goal with state.pred_set and then obtains a
   witnessing run via set_state_iff. *)
lemma pred_stateI[intro]:
lars@66271
    17
  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P a"
lars@66271
    18
  shows "pred_state P m"
lars@66271
    19
proof (subst state.pred_set, rule)
lars@66271
    20
  fix x
lars@66271
    21
  assume "x \<in> set_state m"
lars@66271
    22
  then obtain s s' where "run_state m s = (x, s')"
lars@66271
    23
    by (auto simp: set_state_iff)
lars@66271
    24
  with assms show "P x" .
lars@66271
    25
qed
lars@66271
    26
lars@66271
    27
(* Destruction rule for pred_state: if pred_state P m holds and running m
   on s yields (a, s'), then P a.
   The proof takes m apart as State f, turns the predicate on m into one
   on f (state.pred_inject), and evaluates f at s. *)
lemma pred_stateD[dest]:
lars@66271
    28
  assumes "pred_state P m" "run_state m s = (a, s')"
lars@66271
    29
  shows "P a"
lars@66271
    30
proof (rule state.exhaust[of m])
lars@66271
    31
  fix f
lars@66271
    32
  assume "m = State f"
lars@66271
    33
  with assms have "pred_fun (\<lambda>_. True) (pred_prod P top) f"
lars@66271
    34
    by (metis state.pred_inject)
lars@66271
    35
  moreover have "f s = (a, s')"
lars@66271
    36
    using assms unfolding \<open>m = _\<close> by auto
lars@66271
    37
  ultimately show "P a"
lars@66271
    38
    unfolding pred_prod_beta pred_fun_def
lars@66271
    39
    by (metis fst_conv)
lars@66271
    40
qed
lars@66271
    41
lars@66271
    42
(* Projection form of pred_stateD: under pred_state P m, P holds of the
   first component of run_state m s for any input state s. *)
lemma pred_state_run_state: "pred_state P m \<Longrightarrow> P (fst (run_state m s))"
lars@66271
    43
by (meson pred_stateD prod.exhaust_sel)
lars@66271
    44
lars@66271
    45
(* state_io_rel P m holds when P relates every input state s to the
   output state (second component) obtained by running m on s. *)
definition state_io_rel :: "('s \<Rightarrow> 's \<Rightarrow> bool) \<Rightarrow> ('s, 'a) state \<Rightarrow> bool" where
lars@66271
    46
"state_io_rel P m = (\<forall>s. P s (snd (run_state m s)))"
lars@66271
    47
lars@66271
    48
(* Introduction rule for state_io_rel: show P s s' for every run of m
   from s that ends in output state s'. *)
lemma state_io_relI[intro]:
lars@66271
    49
  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P s s'"
lars@66271
    50
  shows "state_io_rel P m"
lars@66271
    51
using assms unfolding state_io_rel_def
lars@66271
    52
by (metis prod.collapse)
lars@66271
    53
lars@66271
    54
(* Destruction rule for state_io_rel: from state_io_rel P m and a concrete
   run of m from s to s', conclude P s s'. *)
lemma state_io_relD[dest]:
lars@66271
    55
  assumes "state_io_rel P m" "run_state m s = (a, s')"
lars@66271
    56
  shows "P s s'"
lars@66271
    57
using assms unfolding state_io_rel_def
lars@66271
    58
by (metis snd_conv)
lars@66271
    59
lars@66271
    60
(* state_io_rel is monotone in its relation argument: weakening P to Q
   weakens state_io_rel P to state_io_rel Q. Declared [mono]. *)
lemma state_io_rel_mono[mono]: "P \<le> Q \<Longrightarrow> state_io_rel P \<le> state_io_rel Q"
lars@66271
    61
by blast
lars@66271
    62
lars@66271
    63
(* Extensionality for state computations: two computations are equal if
   run_state agrees on them for every input state. *)
lemma state_ext:
lars@66271
    64
  assumes "\<And>s. run_state m s = run_state n s"
lars@66271
    65
  shows "m = n"
lars@66271
    66
using assms
lars@66271
    67
by (cases m; cases n) auto
lars@66271
    68
lars@66271
    69
context begin
lars@66271
    70
lars@66271
    71
(* return a: the computation that yields a and passes the state through
   unchanged (Pair a maps s to (a, s)). *)
qualified definition return :: "'a \<Rightarrow> ('s, 'a) state" where
lars@66271
    72
"return a = State (Pair a)"
lars@66271
    73
lars@66275
    74
(* Simp rule: running return x on s gives (x, s). *)
lemma run_state_return[simp]: "run_state (return x) s = (x, s)"
lars@66275
    75
unfolding return_def
lars@66275
    76
by simp
lars@66275
    77
lars@66271
    78
(* Applicative-style application: run f first to get a function g and an
   intermediate state s', then run x from s' to get y and s'', and return
   g y together with s''. *)
qualified definition ap :: "('s, 'a \<Rightarrow> 'b) state \<Rightarrow> ('s, 'a) state \<Rightarrow> ('s, 'b) state" where
lars@66271
    79
"ap f x = State (\<lambda>s. case run_state f s of (g, s') \<Rightarrow> case run_state x s' of (y, s'') \<Rightarrow> (g y, s''))"
lars@66271
    80
lars@66271
    81
(* Monadic bind: run x from s to get a result a and state s', then run
   the continuation f a from s'. *)
qualified definition bind :: "('s, 'a) state \<Rightarrow> ('a \<Rightarrow> ('s, 'b) state) \<Rightarrow> ('s, 'b) state" where
lars@66271
    82
"bind x f = State (\<lambda>s. case run_state x s of (a, s') \<Rightarrow> run_state (f a) s')"
lars@66271
    83
lars@66271
    84
(* Register this bind as a variant of Monad_Syntax.bind; the do-blocks
   further down in this theory are written against it. *)
adhoc_overloading Monad_Syntax.bind bind
lars@66271
    85
lars@66271
    86
(* Monad law (left identity): binding return a to f is just f a. *)
lemma bind_left_identity[simp]: "bind (return a) f = f a"
lars@66271
    87
unfolding return_def bind_def by simp
lars@66271
    88
lars@66271
    89
(* Monad law (right identity): binding m to return gives m back. *)
lemma bind_right_identity[simp]: "bind m return = m"
lars@66271
    90
unfolding return_def bind_def by simp
lars@66271
    91
lars@66271
    92
(* Monad law (associativity): nested binds can be re-associated to the
   right. As a simp rule it rewrites left-nested binds into right-nested
   ones. *)
lemma bind_assoc[simp]: "bind (bind m f) g = bind m (\<lambda>x. bind (f x) g)"
lars@66271
    93
unfolding bind_def by (auto split: prod.splits)
lars@66271
    94
lars@66271
    95
(* Introduction rule for pred_state over bind: if every result x of m
   leads to a continuation f x all of whose results satisfy P, then all
   results of bind m f satisfy P. *)
lemma bind_predI[intro]:
lars@66271
    96
  assumes "pred_state (\<lambda>x. pred_state P (f x)) m"
lars@66271
    97
  shows "pred_state P (bind m f)"
lars@66271
    98
apply (rule pred_stateI)
lars@66271
    99
unfolding bind_def
lars@66271
   100
using assms by (auto split: prod.splits)
lars@66271
   101
lars@66271
   102
(* get: returns the current state as the result and leaves it unchanged. *)
qualified definition get :: "('s, 's) state" where
lars@66271
   103
"get = State (\<lambda>s. (s, s))"
lars@66271
   104
lars@66271
   105
(* set s': ignores the incoming state, installs s' as the new state and
   returns the unit value. *)
qualified definition set :: "'s \<Rightarrow> ('s, unit) state" where
lars@66271
   106
"set s' = State (\<lambda>_. ((), s'))"
lars@66271
   107
lars@66271
   108
(* Reading the state and immediately writing it back is a no-op. *)
lemma get_set[simp]: "bind get set = return ()"
lars@66271
   109
unfolding bind_def get_def set_def return_def
lars@66271
   110
by simp
lars@66271
   111
lars@66271
   112
(* Two consecutive writes collapse to the second one. *)
lemma set_set[simp]: "bind (set s) (\<lambda>_. set s') = set s'"
lars@66271
   113
unfolding bind_def set_def
lars@66271
   114
by simp
lars@66271
   115
lars@66275
   116
(* Writing back the state that was just read can be dropped in front of a
   continuation: the continuation f s is applied to () directly. *)
lemma get_bind_set[simp]: "bind get (\<lambda>s. bind (set s) (f s)) = bind get (\<lambda>s. f s ())"
lars@66275
   117
unfolding bind_def get_def set_def
lars@66275
   118
by simp
lars@66275
   119
lars@66275
   120
(* A get whose result is ignored can be removed. *)
lemma get_const[simp]: "bind get (\<lambda>_. m) = m"
lars@66275
   121
unfolding get_def bind_def
lars@66275
   122
by simp
lars@66275
   123
lars@66271
   124
(* traverse_list f xs runs f on each element of xs, head first, threading
   the state through, and collects the results in a list of the same
   order. The empty list yields return []. *)
fun traverse_list :: "('a \<Rightarrow> ('b, 'c) state) \<Rightarrow> 'a list \<Rightarrow> ('b, 'c list) state" where
lars@66271
   125
"traverse_list _ [] = return []" |
lars@66271
   126
"traverse_list f (x # xs) = do {
lars@66271
   127
  x \<leftarrow> f x;
lars@66271
   128
  xs \<leftarrow> traverse_list f xs;
lars@66271
   129
  return (x # xs)
lars@66271
   130
}"
lars@66271
   131
lars@66271
   132
(* Traversing an appended list equals traversing xs, then ys, and
   appending the two result lists. Proved by induction on xs. *)
lemma traverse_list_app[simp]: "traverse_list f (xs @ ys) = do {
lars@66271
   133
  xs \<leftarrow> traverse_list f xs;
lars@66271
   134
  ys \<leftarrow> traverse_list f ys;
lars@66271
   135
  return (xs @ ys)
lars@66271
   136
}"
lars@66271
   137
by (induction xs) auto
lars@66271
   138
lars@66271
   139
(* Traversing with a composition g o f equals mapping f over the list
   first and then traversing with g. Proved by induction on xs. *)
lemma traverse_comp[simp]: "traverse_list (g \<circ> f) xs = traverse_list g (map f xs)"
lars@66271
   140
by (induction xs) auto
lars@66271
   141
lars@66271
   142
(* mono_state m: every run of m ends in an output state that is
   greater than or equal to its input state (state type is a preorder). *)
abbreviation mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
nipkow@67399
   143
"mono_state \<equiv> state_io_rel (\<le>)"
lars@66271
   144
lars@66271
   145
(* strict_mono_state m: every run of m ends in an output state that is
   strictly greater than its input state (state type is a preorder). *)
abbreviation strict_mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
nipkow@67399
   146
"strict_mono_state \<equiv> state_io_rel (<)"
lars@66271
   147
lars@66271
   148
(* A strictly increasing computation is in particular increasing;
   follows from less_imp_le after unfolding state_io_rel. *)
corollary strict_mono_implies_mono: "strict_mono_state m \<Longrightarrow> mono_state m"
lars@66271
   149
unfolding state_io_rel_def
lars@66271
   150
by (simp add: less_imp_le)
lars@66271
   151
lars@66271
   152
(* return leaves the state unchanged, so it is mono_state. *)
lemma return_mono[simp, intro]: "mono_state (return x)"
lars@66271
   153
unfolding return_def by auto
lars@66271
   154
lars@66271
   155
(* get leaves the state unchanged, so it is mono_state. *)
lemma get_mono[simp, intro]: "mono_state get"
lars@66271
   156
unfolding get_def by auto
lars@66271
   157
lars@66271
   158
(* set s' is mono_state provided s' is greater than or equal to every
   state x.
   NOTE(review): the lemma is named put_mono although the operation is
   called set in this theory; the name is kept for compatibility. *)
lemma put_mono:
lars@66271
   159
  assumes "\<And>x. s' \<ge> x"
lars@66271
   160
  shows "mono_state (set s')"
lars@66271
   161
using assms unfolding set_def
lars@66271
   162
by auto
lars@66271
   163
lars@66271
   164
(* mono_state is preserved by map_state f.
   NOTE(review): map_state is not defined in this file; presumably it is
   the map function generated by the datatype command -- confirm. *)
lemma map_mono[intro]: "mono_state m \<Longrightarrow> mono_state (map_state f m)"
lars@66271
   165
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)
lars@66271
   166
lars@66271
   167
(* strict_mono_state is preserved by map_state f; same proof as map_mono. *)
lemma map_strict_mono[intro]: "strict_mono_state m \<Longrightarrow> strict_mono_state (map_state f m)"
lars@66271
   168
by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)
lars@66271
   169
lars@66271
   170
(* bind preserves mono_state. The continuation f only needs to be
   mono_state on results x that m can actually produce (the premise
   run_state m s = (x, s') restricts x). Combines both steps with
   order_trans. *)
lemma bind_mono_strong:
lars@66271
   171
  assumes "mono_state m"
lars@66271
   172
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
lars@66271
   173
  shows "mono_state (bind m f)"
lars@66271
   174
unfolding bind_def
lars@66271
   175
apply (rule state_io_relI)
lars@66271
   176
using assms by (auto split: prod.splits dest!: state_io_relD intro: order_trans)
lars@66271
   177
lars@66271
   178
(* bind m f is strictly increasing when m is increasing and f x is
   strictly increasing for every result x that m can produce.
   Combines both steps with le_less_trans. *)
lemma bind_strict_mono_strong1:
lars@66271
   179
  assumes "mono_state m"
lars@66271
   180
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
lars@66271
   181
  shows "strict_mono_state (bind m f)"
lars@66271
   182
unfolding bind_def
lars@66271
   183
apply (rule state_io_relI)
lars@66271
   184
using assms by (auto split: prod.splits dest!: state_io_relD intro: le_less_trans)
lars@66271
   185
lars@66271
   186
(* bind m f is strictly increasing when m is strictly increasing and f x
   is increasing for every result x that m can produce.
   Combines both steps with less_le_trans. *)
lemma bind_strict_mono_strong2:
lars@66271
   187
  assumes "strict_mono_state m"
lars@66271
   188
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
lars@66271
   189
  shows "strict_mono_state (bind m f)"
lars@66271
   190
unfolding bind_def
lars@66271
   191
apply (rule state_io_relI)
lars@66271
   192
using assms by (auto split: prod.splits dest!: state_io_relD intro: less_le_trans)
lars@66271
   193
lars@66271
   194
(* bind m f is strictly increasing when both m and every reachable f x
   are strictly increasing. Obtained from bind_strict_mono_strong1 after
   weakening the assumption on m via strict_mono_implies_mono. *)
corollary bind_strict_mono_strong:
lars@66271
   195
  assumes "strict_mono_state m"
lars@66271
   196
  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
lars@66271
   197
  shows "strict_mono_state (bind m f)"
lars@66271
   198
using assms by (auto intro: bind_strict_mono_strong1 strict_mono_implies_mono)
lars@66271
   199
lars@66271
   200
(* update f: reads the current state with get and writes back f applied
   to it; the result is the unit value. *)
qualified definition update :: "('s \<Rightarrow> 's) \<Rightarrow> ('s, unit) state" where
lars@66271
   201
"update f = bind get (set \<circ> f)"
lars@66271
   202
lars@66271
   203
(* Updating with the identity function is a no-op. *)
lemma update_id[simp]: "update (\<lambda>x. x) = return ()"
lars@66271
   204
unfolding update_def return_def get_def set_def bind_def
lars@66271
   205
by auto
lars@66271
   206
lars@66271
   207
(* Two consecutive updates fuse into one update with the composed
   function; f is applied first, then g. *)
lemma update_comp[simp]: "bind (update f) (\<lambda>_. update g) = update (g \<circ> f)"
lars@66271
   208
unfolding update_def return_def get_def set_def bind_def
lars@66271
   209
by auto
lars@66271
   210
lars@66275
   211
(* Writing s and then updating with f is the same as writing f s directly.
   The unfolding list names each definition exactly once. *)
lemma set_update[simp]: "bind (set s) (\<lambda>_. update f) = set (f s)"
lars@66275
   212
unfolding update_def bind_def get_def set_def
lars@66275
   213
by simp
lars@66275
   214
lars@66275
   215
(* Continuation form of set_update: a write of s followed by update f and
   then g equals a write of f s followed by g.
   The unfolding list names each definition exactly once. *)
lemma set_bind_update[simp]: "bind (set s) (\<lambda>_. bind (update f) g) = bind (set (f s)) g"
lars@66275
   216
unfolding update_def bind_def get_def set_def
lars@66275
   217
by simp
lars@66275
   218
lars@66271
   219
(* update f is mono_state when f never decreases its argument. *)
lemma update_mono:
lars@66271
   220
  assumes "\<And>x. x \<le> f x"
lars@66271
   221
  shows "mono_state (update f)"
lars@66271
   222
using assms unfolding update_def get_def set_def bind_def
lars@66271
   223
by (auto intro!: state_io_relI)
lars@66271
   224
lars@66271
   225
(* update f is strict_mono_state when f strictly increases its argument. *)
lemma update_strict_mono:
lars@66271
   226
  assumes "\<And>x. x < f x"
lars@66271
   227
  shows "strict_mono_state (update f)"
lars@66271
   228
using assms unfolding update_def get_def set_def bind_def
lars@66271
   229
by (auto intro!: state_io_relI)
lars@66271
   230
lars@66271
   231
end
lars@66271
   232
nipkow@67399
   233
end