src/HOL/Library/State_Monad.thy
author haftmann
Wed Jul 18 20:51:21 2018 +0200 (11 months ago)
changeset 68658 16cc1161ad7f
parent 67399 eab6ce8368fa
child 68756 7066e83dfe46
permissions -rw-r--r--
tuned equation
(*  Title:      HOL/Library/State_Monad.thy
    Author:     Lars Hupel, TU München
*)
     4 
section \<open>State monad\<close>

(* State monad: a computation maps an initial state to a pair of a result and a
   final state. This theory defines return/ap/bind/get/set/update on it, a list
   traversal, and monotonicity lemmas about how runs change the state.
   Monad_Syntax supplies the generic do-notation; bind is registered for it via
   adhoc_overloading further down. *)
theory State_Monad
imports Monad_Syntax
begin
    10 
    11 datatype ('s, 'a) state = State (run_state: "'s \<Rightarrow> ('a \<times> 's)")
    12 
    13 lemma set_state_iff: "x \<in> set_state m \<longleftrightarrow> (\<exists>s s'. run_state m s = (x, s'))"
    14 by (cases m) (simp add: prod_set_defs eq_fst_iff)
    15 
(* Introduction rule for pred_state: it suffices that P holds for the result of
   every possible run of m, whatever the initial state. *)
lemma pred_stateI[intro]:
  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P a"
  shows "pred_state P m"
proof (subst state.pred_set, rule)
  (* After rewriting pred_state with state.pred_set, it remains to show P for
     an arbitrary element x of set_state m. *)
  fix x
  assume "x \<in> set_state m"
  (* set_state_iff turns membership into a concrete run that produces x. *)
  then obtain s s' where "run_state m s = (x, s')"
    by (auto simp: set_state_iff)
  with assms show "P x" .
qed
    26 
(* Destruction rule for pred_state: if pred_state P m holds and the run of m
   from s yields (a, s'), then P a. *)
lemma pred_stateD[dest]:
  assumes "pred_state P m" "run_state m s = (a, s')"
  shows "P a"
proof (rule state.exhaust[of m])
  (* state has the single constructor State, so m = State f for some f. *)
  fix f
  assume "m = State f"
  (* Restate the predicate directly on the underlying function f. *)
  with assms have "pred_fun (\<lambda>_. True) (pred_prod P top) f"
    by (metis state.pred_inject)
  moreover have "f s = (a, s')"
    using assms unfolding \<open>m = _\<close> by auto
  (* P holds on the first component of f s, which is a. *)
  ultimately show "P a"
    unfolding pred_prod_beta pred_fun_def
    by (metis fst_conv)
qed
    41 
    42 lemma pred_state_run_state: "pred_state P m \<Longrightarrow> P (fst (run_state m s))"
    43 by (meson pred_stateD prod.exhaust_sel)
    44 
    45 definition state_io_rel :: "('s \<Rightarrow> 's \<Rightarrow> bool) \<Rightarrow> ('s, 'a) state \<Rightarrow> bool" where
    46 "state_io_rel P m = (\<forall>s. P s (snd (run_state m s)))"
    47 
    48 lemma state_io_relI[intro]:
    49   assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P s s'"
    50   shows "state_io_rel P m"
    51 using assms unfolding state_io_rel_def
    52 by (metis prod.collapse)
    53 
    54 lemma state_io_relD[dest]:
    55   assumes "state_io_rel P m" "run_state m s = (a, s')"
    56   shows "P s s'"
    57 using assms unfolding state_io_rel_def
    58 by (metis snd_conv)
    59 
    60 lemma state_io_rel_mono[mono]: "P \<le> Q \<Longrightarrow> state_io_rel P \<le> state_io_rel Q"
    61 by blast
    62 
    63 lemma state_ext:
    64   assumes "\<And>s. run_state m s = run_state n s"
    65   shows "m = n"
    66 using assms
    67 by (cases m; cases n) auto
    68 
(* Unnamed local context for the "qualified" definitions below (return, ap,
   bind, get, set, update). Presumably these are meant to be referenced by
   qualified name (e.g. State_Monad.bind) outside this context, so they do not
   clash with other constants such as List.set -- confirm against users. *)
context begin
    70 
    71 qualified definition return :: "'a \<Rightarrow> ('s, 'a) state" where
    72 "return a = State (Pair a)"
    73 
    74 lemma run_state_return[simp]: "run_state (return x) s = (x, s)"
    75 unfolding return_def
    76 by simp
    77 
    78 qualified definition ap :: "('s, 'a \<Rightarrow> 'b) state \<Rightarrow> ('s, 'a) state \<Rightarrow> ('s, 'b) state" where
    79 "ap f x = State (\<lambda>s. case run_state f s of (g, s') \<Rightarrow> case run_state x s' of (y, s'') \<Rightarrow> (g y, s''))"
    80 
    81 qualified definition bind :: "('s, 'a) state \<Rightarrow> ('a \<Rightarrow> ('s, 'b) state) \<Rightarrow> ('s, 'b) state" where
    82 "bind x f = State (\<lambda>s. case run_state x s of (a, s') \<Rightarrow> run_state (f a) s')"
    83 
(* Register this theory's bind as an instance of Monad_Syntax.bind, so that
   do-notation (as used by traverse_list below) elaborates to it. *)
adhoc_overloading Monad_Syntax.bind bind
    85 
    86 lemma bind_left_identity[simp]: "bind (return a) f = f a"
    87 unfolding return_def bind_def by simp
    88 
    89 lemma bind_right_identity[simp]: "bind m return = m"
    90 unfolding return_def bind_def by simp
    91 
    92 lemma bind_assoc[simp]: "bind (bind m f) g = bind m (\<lambda>x. bind (f x) g)"
    93 unfolding bind_def by (auto split: prod.splits)
    94 
(* If every possible result x of m yields a computation f x all of whose
   results satisfy P, then every result of bind m f satisfies P. *)
lemma bind_predI[intro]:
  assumes "pred_state (\<lambda>x. pred_state P (f x)) m"
  shows "pred_state P (bind m f)"
apply (rule pred_stateI)
(* Reduce to a single run of bind m f, then expose the intermediate run of m. *)
unfolding bind_def
using assms by (auto split: prod.splits)
   101 
   102 qualified definition get :: "('s, 's) state" where
   103 "get = State (\<lambda>s. (s, s))"
   104 
   105 qualified definition set :: "'s \<Rightarrow> ('s, unit) state" where
   106 "set s' = State (\<lambda>_. ((), s'))"
   107 
   108 lemma get_set[simp]: "bind get set = return ()"
   109 unfolding bind_def get_def set_def return_def
   110 by simp
   111 
   112 lemma set_set[simp]: "bind (set s) (\<lambda>_. set s') = set s'"
   113 unfolding bind_def set_def
   114 by simp
   115 
   116 lemma get_bind_set[simp]: "bind get (\<lambda>s. bind (set s) (f s)) = bind get (\<lambda>s. f s ())"
   117 unfolding bind_def get_def set_def
   118 by simp
   119 
   120 lemma get_const[simp]: "bind get (\<lambda>_. m) = m"
   121 unfolding get_def bind_def
   122 by simp
   123 
   124 fun traverse_list :: "('a \<Rightarrow> ('b, 'c) state) \<Rightarrow> 'a list \<Rightarrow> ('b, 'c list) state" where
   125 "traverse_list _ [] = return []" |
   126 "traverse_list f (x # xs) = do {
   127   x \<leftarrow> f x;
   128   xs \<leftarrow> traverse_list f xs;
   129   return (x # xs)
   130 }"
   131 
   132 lemma traverse_list_app[simp]: "traverse_list f (xs @ ys) = do {
   133   xs \<leftarrow> traverse_list f xs;
   134   ys \<leftarrow> traverse_list f ys;
   135   return (xs @ ys)
   136 }"
   137 by (induction xs) auto
   138 
   139 lemma traverse_comp[simp]: "traverse_list (g \<circ> f) xs = traverse_list g (map f xs)"
   140 by (induction xs) auto
   141 
   142 abbreviation mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
   143 "mono_state \<equiv> state_io_rel (\<le>)"
   144 
   145 abbreviation strict_mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
   146 "strict_mono_state \<equiv> state_io_rel (<)"
   147 
   148 corollary strict_mono_implies_mono: "strict_mono_state m \<Longrightarrow> mono_state m"
   149 unfolding state_io_rel_def
   150 by (simp add: less_imp_le)
   151 
   152 lemma return_mono[simp, intro]: "mono_state (return x)"
   153 unfolding return_def by auto
   154 
   155 lemma get_mono[simp, intro]: "mono_state get"
   156 unfolding get_def by auto
   157 
   158 lemma put_mono:
   159   assumes "\<And>x. s' \<ge> x"
   160   shows "mono_state (set s')"
   161 using assms unfolding set_def
   162 by auto
   163 
   164 lemma map_mono[intro]: "mono_state m \<Longrightarrow> mono_state (map_state f m)"
   165 by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)
   166 
   167 lemma map_strict_mono[intro]: "strict_mono_state m \<Longrightarrow> strict_mono_state (map_state f m)"
   168 by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)
   169 
   170 lemma bind_mono_strong:
   171   assumes "mono_state m"
   172   assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
   173   shows "mono_state (bind m f)"
   174 unfolding bind_def
   175 apply (rule state_io_relI)
   176 using assms by (auto split: prod.splits dest!: state_io_relD intro: order_trans)
   177 
   178 lemma bind_strict_mono_strong1:
   179   assumes "mono_state m"
   180   assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
   181   shows "strict_mono_state (bind m f)"
   182 unfolding bind_def
   183 apply (rule state_io_relI)
   184 using assms by (auto split: prod.splits dest!: state_io_relD intro: le_less_trans)
   185 
   186 lemma bind_strict_mono_strong2:
   187   assumes "strict_mono_state m"
   188   assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
   189   shows "strict_mono_state (bind m f)"
   190 unfolding bind_def
   191 apply (rule state_io_relI)
   192 using assms by (auto split: prod.splits dest!: state_io_relD intro: less_le_trans)
   193 
   194 corollary bind_strict_mono_strong:
   195   assumes "strict_mono_state m"
   196   assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
   197   shows "strict_mono_state (bind m f)"
   198 using assms by (auto intro: bind_strict_mono_strong1 strict_mono_implies_mono)
   199 
   200 qualified definition update :: "('s \<Rightarrow> 's) \<Rightarrow> ('s, unit) state" where
   201 "update f = bind get (set \<circ> f)"
   202 
   203 lemma update_id[simp]: "update (\<lambda>x. x) = return ()"
   204 unfolding update_def return_def get_def set_def bind_def
   205 by auto
   206 
   207 lemma update_comp[simp]: "bind (update f) (\<lambda>_. update g) = update (g \<circ> f)"
   208 unfolding update_def return_def get_def set_def bind_def
   209 by auto
   210 
   211 lemma set_update[simp]: "bind (set s) (\<lambda>_. update f) = set (f s)"
   212 unfolding set_def update_def bind_def get_def set_def
   213 by simp
   214 
   215 lemma set_bind_update[simp]: "bind (set s) (\<lambda>_. bind (update f) g) = bind (set (f s)) g"
   216 unfolding set_def update_def bind_def get_def set_def
   217 by simp
   218 
   219 lemma update_mono:
   220   assumes "\<And>x. x \<le> f x"
   221   shows "mono_state (update f)"
   222 using assms unfolding update_def get_def set_def bind_def
   223 by (auto intro!: state_io_relI)
   224 
   225 lemma update_strict_mono:
   226   assumes "\<And>x. x < f x"
   227   shows "strict_mono_state (update f)"
   228 using assms unfolding update_def get_def set_def bind_def
   229 by (auto intro!: state_io_relI)
   230 
end (* of the unnamed context opened by "context begin" *)

end (* of theory State_Monad *)