guard combinator
authorhaftmann
Fri, 09 Jul 2010 09:48:53 +0200
changeset 37754 683d1e1bc234
parent 37753 3ac6867279f0
child 37755 7086b7feaaa5
guard combinator
src/HOL/Imperative_HOL/Heap_Monad.thy
--- a/src/HOL/Imperative_HOL/Heap_Monad.thy	Fri Jul 09 09:48:53 2010 +0200
+++ b/src/HOL/Imperative_HOL/Heap_Monad.thy	Fri Jul 09 09:48:53 2010 +0200
@@ -37,6 +37,14 @@
   "execute (heap f) = Some \<circ> f"
   by (simp add: heap_def)
 
+definition guard :: "(heap \<Rightarrow> bool) \<Rightarrow> (heap \<Rightarrow> 'a \<times> heap) \<Rightarrow> 'a Heap" where
+  [code del]: "guard P f = Heap (\<lambda>h. if P h then Some (f h) else None)"
+
+lemma execute_guard [simp]:
+  "\<not> P h \<Longrightarrow> execute (guard P f) h = None"
+  "P h \<Longrightarrow> execute (guard P f) h = Some (f h)"
+  by (simp_all add: guard_def)
+
 lemma heap_cases [case_names succeed fail]:
   fixes f and h
   assumes succeed: "\<And>x h'. execute f h = Some (x, h') \<Longrightarrow> P"
@@ -58,7 +66,7 @@
   "execute (raise s) = (\<lambda>_. None)"
   by (simp add: raise_def)
 
-definition bindM :: "'a Heap \<Rightarrow> ('a \<Rightarrow> 'b Heap) \<Rightarrow> 'b Heap" (infixl ">>=" 54) where
+definition bindM :: "'a Heap \<Rightarrow> ('a \<Rightarrow> 'b Heap) \<Rightarrow> 'b Heap" (infixl ">>=" 54) where (*FIXME just bind*)
   [code del]: "f >>= g = Heap (\<lambda>h. case execute f h of
                   Some (x, h') \<Rightarrow> execute (g x) h'
                 | None \<Rightarrow> None)"
@@ -74,6 +82,12 @@
   "execute (heap f \<guillemotright>= g) h = execute (g (fst (f h))) (snd (f h))"
   by (simp add: bindM_def split_def)
   
+lemma execute_eq_SomeI:
+  assumes "Heap_Monad.execute f h = Some (x, h')"
+    and "Heap_Monad.execute (g x) h' = Some (y, h'')"
+  shows "Heap_Monad.execute (f \<guillemotright>= g) h = Some (y, h'')"
+  using assms by (simp add: bindM_def)
+
 lemma return_bind [simp]: "return x \<guillemotright>= f = f x"
   by (rule Heap_eqI) simp
 
@@ -86,10 +100,10 @@
 lemma raise_bind [simp]: "raise e \<guillemotright>= f = raise e"
   by (rule Heap_eqI) simp
 
-abbreviation chainM :: "'a Heap \<Rightarrow> 'b Heap \<Rightarrow> 'b Heap"  (infixl ">>" 54) where
+abbreviation chain :: "'a Heap \<Rightarrow> 'b Heap \<Rightarrow> 'b Heap"  (infixl ">>" 54) where
   "f >> g \<equiv> f >>= (\<lambda>_. g)"
 
-notation chainM (infixl "\<guillemotright>" 54)
+notation chain (infixl "\<guillemotright>" 54)
 
 
 subsubsection {* do-syntax *}
@@ -105,9 +119,9 @@
 syntax
   "_do" :: "do_expr \<Rightarrow> 'a"
     ("(do (_)//done)" [12] 100)
-  "_bindM" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
+  "_bind" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
     ("_ <- _;//_" [1000, 13, 12] 12)
-  "_chainM" :: "'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
+  "_chain" :: "'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
     ("_;//_" [13, 12] 12)
   "_let" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
     ("let _ = _;//_" [1000, 13, 12] 12)
@@ -115,13 +129,13 @@
     ("_" [12] 12)
 
 syntax (xsymbols)
-  "_bindM" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
+  "_bind" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
     ("_ \<leftarrow> _;//_" [1000, 13, 12] 12)
 
 translations
   "_do f" => "f"
-  "_bindM x f g" => "f \<guillemotright>= (\<lambda>x. g)"
-  "_chainM f g" => "f \<guillemotright> g"
+  "_bind x f g" => "f \<guillemotright>= (\<lambda>x. g)"
+  "_chain f g" => "f \<guillemotright> g"
   "_let x t f" => "CONST Let t (\<lambda>x. f)"
   "_nil f" => "f"
 
@@ -142,12 +156,12 @@
           val v_used = fold_aterms
             (fn Free (w, _) => (fn s => s orelse member (op =) vs w) | _ => I) g' false;
         in if v_used then
-          Const (@{syntax_const "_bindM"}, dummyT) $ v $ f $ unfold_monad g'
+          Const (@{syntax_const "_bind"}, dummyT) $ v $ f $ unfold_monad g'
         else
-          Const (@{syntax_const "_chainM"}, dummyT) $ f $ unfold_monad g'
+          Const (@{syntax_const "_chain"}, dummyT) $ f $ unfold_monad g'
         end
-    | unfold_monad (Const (@{const_syntax chainM}, _) $ f $ g) =
-        Const (@{syntax_const "_chainM"}, dummyT) $ f $ unfold_monad g
+    | unfold_monad (Const (@{const_syntax chain}, _) $ f $ g) =
+        Const (@{syntax_const "_chain"}, dummyT) $ f $ unfold_monad g
     | unfold_monad (Const (@{const_syntax Let}, _) $ f $ g) =
         let
           val (v, g') = dest_abs_eta g;
@@ -155,14 +169,14 @@
     | unfold_monad (Const (@{const_syntax Pair}, _) $ f) =
         Const (@{const_syntax return}, dummyT) $ f
     | unfold_monad f = f;
-  fun contains_bindM (Const (@{const_syntax bindM}, _) $ _ $ _) = true
-    | contains_bindM (Const (@{const_syntax Let}, _) $ _ $ Abs (_, _, t)) =
-        contains_bindM t;
+  fun contains_bind (Const (@{const_syntax bindM}, _) $ _ $ _) = true
+    | contains_bind (Const (@{const_syntax Let}, _) $ _ $ Abs (_, _, t)) =
+        contains_bind t;
   fun bindM_monad_tr' (f::g::ts) = list_comb
     (Const (@{syntax_const "_do"}, dummyT) $
       unfold_monad (Const (@{const_syntax bindM}, dummyT) $ f $ g), ts);
   fun Let_monad_tr' (f :: (g as Abs (_, _, g')) :: ts) =
-    if contains_bindM g' then list_comb
+    if contains_bind g' then list_comb
       (Const (@{syntax_const "_do"}, dummyT) $
         unfold_monad (Const (@{const_syntax Let}, dummyT) $ f $ g), ts)
     else raise Match;
@@ -180,31 +194,55 @@
 definition assert :: "('a \<Rightarrow> bool) \<Rightarrow> 'a \<Rightarrow> 'a Heap" where
   "assert P x = (if P x then return x else raise ''assert'')"
 
+lemma execute_assert [simp]:
+  "P x \<Longrightarrow> execute (assert P x) h = Some (x, h)"
+  "\<not> P x \<Longrightarrow> execute (assert P x) h = None"
+  by (simp_all add: assert_def)
+
 lemma assert_cong [fundef_cong]:
   assumes "P = P'"
   assumes "\<And>x. P' x \<Longrightarrow> f x = f' x"
   shows "(assert P x >>= f) = (assert P' x >>= f')"
-  using assms by (auto simp add: assert_def return_bind raise_bind)
+  by (rule Heap_eqI) (insert assms, simp add: assert_def)
 
-definition liftM :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b Heap" where
-  "liftM f = return o f"
+definition lift :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b Heap" where
+  "lift f = return o f"
 
-lemma liftM_collapse [simp]:
-  "liftM f x = return (f x)"
-  by (simp add: liftM_def)
+lemma lift_collapse [simp]:
+  "lift f x = return (f x)"
+  by (simp add: lift_def)
 
-lemma bind_liftM:
-  "(f \<guillemotright>= liftM g) = (f \<guillemotright>= (\<lambda>x. return (g x)))"
-  by (simp add: liftM_def comp_def)
+lemma bind_lift:
+  "(f \<guillemotright>= lift g) = (f \<guillemotright>= (\<lambda>x. return (g x)))"
+  by (simp add: lift_def comp_def)
 
-primrec mapM :: "('a \<Rightarrow> 'b Heap) \<Rightarrow> 'a list \<Rightarrow> 'b list Heap" where
+primrec mapM :: "('a \<Rightarrow> 'b Heap) \<Rightarrow> 'a list \<Rightarrow> 'b list Heap" where (*FIXME just map?*)
   "mapM f [] = return []"
-| "mapM f (x#xs) = do
+| "mapM f (x # xs) = do
      y \<leftarrow> f x;
      ys \<leftarrow> mapM f xs;
      return (y # ys)
    done"
 
+lemma mapM_append:
+  "mapM f (xs @ ys) = mapM f xs \<guillemotright>= (\<lambda>xs. mapM f ys \<guillemotright>= (\<lambda>ys. return (xs @ ys)))"
+  by (induct xs) simp_all
+
+lemma execute_mapM_unchanged_heap:
+  assumes "\<And>x. x \<in> set xs \<Longrightarrow> \<exists>y. execute (f x) h = Some (y, h)"
+  shows "execute (mapM f xs) h =
+    Some (List.map (\<lambda>x. fst (the (execute (f x) h))) xs, h)"
+using assms proof (induct xs)
+  case Nil show ?case by simp
+next
+  case (Cons x xs)
+  from Cons.prems obtain y
+    where y: "execute (f x) h = Some (y, h)" by auto
+  moreover from Cons.prems Cons.hyps have "execute (mapM f xs) h =
+    Some (map (\<lambda>x. fst (the (execute (f x) h))) xs, h)" by auto
+  ultimately show ?case by (simp, simp only: execute_bind(1), simp)
+qed
+
 
 subsubsection {* A monadic combinator for simple recursive functions *}
 
@@ -371,7 +409,7 @@
 code_const return (SML "!(fn/ ()/ =>/ _)")
 code_const Heap_Monad.raise' (SML "!(raise/ Fail/ _)")
 
-code_type Heap (OCaml "_")
+code_type Heap (OCaml "unit/ ->/ _")
 code_const "op \<guillemotright>=" (OCaml "!(fun/ f'_/ ()/ ->/ f'_/ (_/ ())/ ())")
 code_const return (OCaml "!(fun/ ()/ ->/ _)")
 code_const Heap_Monad.raise' (OCaml "failwith/ _")
@@ -388,7 +426,7 @@
     fun is_const c = case lookup_const naming c
      of SOME c' => (fn c'' => c' = c'')
       | NONE => K false;
-    val is_bindM = is_const @{const_name bindM};
+    val is_bind = is_const @{const_name bindM};
     val is_return = is_const @{const_name return};
     val dummy_name = "";
     val dummy_type = ITyVar dummy_name;
@@ -412,13 +450,13 @@
         val ((v, ty), t) = dest_abs (t2, ty2);
       in ICase (((force t1, ty), [(IVar v, tr_bind'' t)]), dummy_case_term) end
     and tr_bind'' t = case unfold_app t
-         of (IConst (c, (_, ty1 :: ty2 :: _)), [x1, x2]) => if is_bindM c
+         of (IConst (c, (_, ty1 :: ty2 :: _)), [x1, x2]) => if is_bind c
               then tr_bind' [(x1, ty1), (x2, ty2)]
               else force t
           | _ => force t;
     fun imp_monad_bind'' ts = (SOME dummy_name, dummy_type) `|=> ICase (((IVar (SOME dummy_name), dummy_type),
       [(unitt, tr_bind' ts)]), dummy_case_term)
-    and imp_monad_bind' (const as (c, (_, tys))) ts = if is_bindM c then case (ts, tys)
+    and imp_monad_bind' (const as (c, (_, tys))) ts = if is_bind c then case (ts, tys)
        of ([t1, t2], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)]
         | ([t1, t2, t3], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)] `$ t3
         | (ts, _) => imp_monad_bind (eta_expand 2 (const, ts))
@@ -489,6 +527,6 @@
 code_const return (Haskell "return")
 code_const Heap_Monad.raise' (Haskell "error/ _")
 
-hide_const (open) Heap heap execute raise'
+hide_const (open) Heap heap guard execute raise'
 
 end