state monad
authorLars Hupel <lars.hupel@mytum.de>
Tue, 11 Jul 2017 20:47:19 +0200
changeset 66271 d157195a468a
parent 66270 403d84138c5c
child 66272 c6714a9562ae
state monad
src/HOL/Library/Library.thy
src/HOL/Library/State_Monad.thy
--- a/src/HOL/Library/Library.thy	Tue Jul 11 17:22:33 2017 +0200
+++ b/src/HOL/Library/Library.thy	Tue Jul 11 20:47:19 2017 +0200
@@ -71,6 +71,7 @@
   Rewrite
   Saturated
   Set_Algebras
+  State_Monad
   Stirling
   Stream
   Sublist
@@ -81,4 +82,4 @@
   While_Combinator
 begin
 end
-(*>*)
+(*>*)
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Library/State_Monad.thy	Tue Jul 11 20:47:19 2017 +0200
@@ -0,0 +1,213 @@
+(*  Title:      HOL/Library/State_Monad.thy
+    Author:     Lars Hupel, TU München
+*)
+
+section \<open>State monad\<close>
+
+theory State_Monad
+imports Monad_Syntax
+begin
+
+datatype ('s, 'a) state = State (run_state: "'s \<Rightarrow> ('a \<times> 's)")
+
+lemma set_state_iff: "x \<in> set_state m \<longleftrightarrow> (\<exists>s s'. run_state m s = (x, s'))"
+by (cases m) (simp add: prod_set_defs eq_fst_iff)
+
+lemma pred_stateI[intro]:
+  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P a"
+  shows "pred_state P m"
+proof (subst state.pred_set, rule)
+  fix x
+  assume "x \<in> set_state m"
+  then obtain s s' where "run_state m s = (x, s')"
+    by (auto simp: set_state_iff)
+  with assms show "P x" .
+qed
+
+lemma pred_stateD[dest]:
+  assumes "pred_state P m" "run_state m s = (a, s')"
+  shows "P a"
+proof (rule state.exhaust[of m])
+  fix f
+  assume "m = State f"
+  with assms have "pred_fun (\<lambda>_. True) (pred_prod P top) f"
+    by (metis state.pred_inject)
+  moreover have "f s = (a, s')"
+    using assms unfolding \<open>m = _\<close> by auto
+  ultimately show "P a"
+    unfolding pred_prod_beta pred_fun_def
+    by (metis fst_conv)
+qed
+
+lemma pred_state_run_state: "pred_state P m \<Longrightarrow> P (fst (run_state m s))"
+by (meson pred_stateD prod.exhaust_sel)
+
+definition state_io_rel :: "('s \<Rightarrow> 's \<Rightarrow> bool) \<Rightarrow> ('s, 'a) state \<Rightarrow> bool" where
+"state_io_rel P m = (\<forall>s. P s (snd (run_state m s)))"
+
+lemma state_io_relI[intro]:
+  assumes "\<And>a s s'. run_state m s = (a, s') \<Longrightarrow> P s s'"
+  shows "state_io_rel P m"
+using assms unfolding state_io_rel_def
+by (metis prod.collapse)
+
+lemma state_io_relD[dest]:
+  assumes "state_io_rel P m" "run_state m s = (a, s')"
+  shows "P s s'"
+using assms unfolding state_io_rel_def
+by (metis snd_conv)
+
+lemma state_io_rel_mono[mono]: "P \<le> Q \<Longrightarrow> state_io_rel P \<le> state_io_rel Q"
+by blast
+
+lemma state_ext:
+  assumes "\<And>s. run_state m s = run_state n s"
+  shows "m = n"
+using assms
+by (cases m; cases n) auto
+
+context begin
+
+qualified definition return :: "'a \<Rightarrow> ('s, 'a) state" where
+"return a = State (Pair a)"
+
+qualified definition ap :: "('s, 'a \<Rightarrow> 'b) state \<Rightarrow> ('s, 'a) state \<Rightarrow> ('s, 'b) state" where
+"ap f x = State (\<lambda>s. case run_state f s of (g, s') \<Rightarrow> case run_state x s' of (y, s'') \<Rightarrow> (g y, s''))"
+
+qualified definition bind :: "('s, 'a) state \<Rightarrow> ('a \<Rightarrow> ('s, 'b) state) \<Rightarrow> ('s, 'b) state" where
+"bind x f = State (\<lambda>s. case run_state x s of (a, s') \<Rightarrow> run_state (f a) s')"
+
+adhoc_overloading Monad_Syntax.bind bind
+
+lemma bind_left_identity[simp]: "bind (return a) f = f a"
+unfolding return_def bind_def by simp
+
+lemma bind_right_identity[simp]: "bind m return = m"
+unfolding return_def bind_def by simp
+
+lemma bind_assoc[simp]: "bind (bind m f) g = bind m (\<lambda>x. bind (f x) g)"
+unfolding bind_def by (auto split: prod.splits)
+
+lemma bind_predI[intro]:
+  assumes "pred_state (\<lambda>x. pred_state P (f x)) m"
+  shows "pred_state P (bind m f)"
+apply (rule pred_stateI)
+unfolding bind_def
+using assms by (auto split: prod.splits)
+
+qualified definition get :: "('s, 's) state" where
+"get = State (\<lambda>s. (s, s))"
+
+qualified definition set :: "'s \<Rightarrow> ('s, unit) state" where
+"set s' = State (\<lambda>_. ((), s'))"
+
+lemma get_set[simp]: "bind get set = return ()"
+unfolding bind_def get_def set_def return_def
+by simp
+
+lemma set_set[simp]: "bind (set s) (\<lambda>_. set s') = set s'"
+unfolding bind_def set_def
+by simp
+
+fun traverse_list :: "('a \<Rightarrow> ('b, 'c) state) \<Rightarrow> 'a list \<Rightarrow> ('b, 'c list) state" where
+"traverse_list _ [] = return []" |
+"traverse_list f (x # xs) = do {
+  x \<leftarrow> f x;
+  xs \<leftarrow> traverse_list f xs;
+  return (x # xs)
+}"
+
+lemma traverse_list_app[simp]: "traverse_list f (xs @ ys) = do {
+  xs \<leftarrow> traverse_list f xs;
+  ys \<leftarrow> traverse_list f ys;
+  return (xs @ ys)
+}"
+by (induction xs) auto
+
+lemma traverse_comp[simp]: "traverse_list (g \<circ> f) xs = traverse_list g (map f xs)"
+by (induction xs) auto
+
+abbreviation mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
+"mono_state \<equiv> state_io_rel (op \<le>)"
+
+abbreviation strict_mono_state :: "('s::preorder, 'a) state \<Rightarrow> bool" where
+"strict_mono_state \<equiv> state_io_rel (op <)"
+
+corollary strict_mono_implies_mono: "strict_mono_state m \<Longrightarrow> mono_state m"
+unfolding state_io_rel_def
+by (simp add: less_imp_le)
+
+lemma return_mono[simp, intro]: "mono_state (return x)"
+unfolding return_def by auto
+
+lemma get_mono[simp, intro]: "mono_state get"
+unfolding get_def by auto
+
+lemma put_mono:
+  assumes "\<And>x. s' \<ge> x"
+  shows "mono_state (set s')"
+using assms unfolding set_def
+by auto
+
+lemma map_mono[intro]: "mono_state m \<Longrightarrow> mono_state (map_state f m)"
+by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)
+
+lemma map_strict_mono[intro]: "strict_mono_state m \<Longrightarrow> strict_mono_state (map_state f m)"
+by (auto intro!: state_io_relI split: prod.splits simp: map_prod_def state.map_sel)
+
+lemma bind_mono_strong:
+  assumes "mono_state m"
+  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
+  shows "mono_state (bind m f)"
+unfolding bind_def
+apply (rule state_io_relI)
+using assms by (auto split: prod.splits dest!: state_io_relD intro: order_trans)
+
+lemma bind_strict_mono_strong1:
+  assumes "mono_state m"
+  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
+  shows "strict_mono_state (bind m f)"
+unfolding bind_def
+apply (rule state_io_relI)
+using assms by (auto split: prod.splits dest!: state_io_relD intro: le_less_trans)
+
+lemma bind_strict_mono_strong2:
+  assumes "strict_mono_state m"
+  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> mono_state (f x)"
+  shows "strict_mono_state (bind m f)"
+unfolding bind_def
+apply (rule state_io_relI)
+using assms by (auto split: prod.splits dest!: state_io_relD intro: less_le_trans)
+
+corollary bind_strict_mono_strong:
+  assumes "strict_mono_state m"
+  assumes "\<And>x s s'. run_state m s = (x, s') \<Longrightarrow> strict_mono_state (f x)"
+  shows "strict_mono_state (bind m f)"
+using assms by (auto intro: bind_strict_mono_strong1 strict_mono_implies_mono)
+
+qualified definition update :: "('s \<Rightarrow> 's) \<Rightarrow> ('s, unit) state" where
+"update f = bind get (set \<circ> f)"
+
+lemma update_id[simp]: "update (\<lambda>x. x) = return ()"
+unfolding update_def return_def get_def set_def bind_def
+by auto
+
+lemma update_comp[simp]: "bind (update f) (\<lambda>_. update g) = update (g \<circ> f)"
+unfolding update_def return_def get_def set_def bind_def
+by auto
+
+lemma update_mono:
+  assumes "\<And>x. x \<le> f x"
+  shows "mono_state (update f)"
+using assms unfolding update_def get_def set_def bind_def
+by (auto intro!: state_io_relI)
+
+lemma update_strict_mono:
+  assumes "\<And>x. x < f x"
+  shows "strict_mono_state (update f)"
+using assms unfolding update_def get_def set_def bind_def
+by (auto intro!: state_io_relI)
+
+end
+
+end
\ No newline at end of file