26170

1 
(*  Title:      HOL/Library/Heap_Monad.thy
    ID:         $Id$
    Author:     John Matthews, Galois Connections; Alexander Krauss, Lukas Bulwahn & Florian Haftmann, TU Muenchen
*)


5 


6 
header {* A monad with a polymorphic heap *}


7 


8 
theory Heap_Monad


9 
imports Heap


10 
begin


11 


12 
subsection {* The monad *}

subsubsection {* Monad combinators *}

(* The single exception value.  A failing heap action records only the fact
   that it failed, not the reason. *)
datatype exception = Exn

text {* Monadic heap actions either produce values
  and transform the heap, or fail *}

(* A heap action wraps a function from an input heap to a pair consisting of
   either a result (Inl) or an exception (Inr), and the heap afterwards. *)
datatype 'a Heap = Heap "heap \<Rightarrow> ('a + exception) \<times> heap"


21 


22 
(* Unwrap a heap action into its underlying state-transformer function.
   The defining equation is excluded from code generation via [code del]. *)
primrec
  execute :: "'a Heap \<Rightarrow> heap \<Rightarrow> ('a + exception) \<times> heap" where
  "execute (Heap f) = f"
lemmas [code del] = execute.simps


26 


27 
(* Re-wrapping the unwrapped function gives back the original action. *)
lemma Heap_execute [simp]:
  "Heap (execute f) = f" by (cases f) simp_all

(* Extensionality: two heap actions are equal if they produce identical
   results on every input heap. *)
lemma Heap_eqI:
  "(\<And>h. execute f h = execute g h) \<Longrightarrow> f = g"
    by (cases f, cases g) (auto simp: expand_fun_eq)

(* Extensionality lifted to functions that return heap actions. *)
lemma Heap_eqI':
  "(\<And>h. (\<lambda>x. execute (f x) h) = (\<lambda>y. execute (g y) h)) \<Longrightarrow> f = g"
    by (auto simp: expand_fun_eq intro: Heap_eqI)

(* Quantifying over all heap actions is equivalent to quantifying over all
   raw state-transformer functions wrapped by the constructor Heap. *)
lemma Heap_strip: "(\<And>f. PROP P f) \<equiv> (\<And>g. PROP P (Heap g))"
proof
  fix g :: "heap \<Rightarrow> ('a + exception) \<times> heap"
  assume "\<And>f. PROP P f"
  then show "PROP P (Heap g)" .
next
  fix f :: "'a Heap"
  assume assm: "\<And>g. PROP P (Heap g)"
  then have "PROP P (Heap (execute f))" .
  (* the simp rule Heap_execute rewrites Heap (execute f) back to f *)
  then show "PROP P f" by simp
qed


49 


50 
(* Lift a state transformer that cannot fail into the heap monad; its
   result is always tagged as a success (Inl). *)
definition
  heap :: "(heap \<Rightarrow> 'a \<times> heap) \<Rightarrow> 'a Heap" where
  [code del]: "heap f = Heap (\<lambda>h. apfst Inl (f h))"

lemma execute_heap [simp]:
  "execute (heap f) h = apfst Inl (f h)"
  by (simp add: heap_def)


57 


58 
(* Sequential composition.  f runs on the current heap; on success (Inl x)
   the continuation g x runs on the resulting heap h'.  Any other outcome
   (a failure, together with the heap at the point of failure) is passed
   through unchanged and g is not run. *)
definition
  bindM :: "'a Heap \<Rightarrow> ('a \<Rightarrow> 'b Heap) \<Rightarrow> 'b Heap" (infixl ">>=" 54) where
  [code del]: "f >>= g = Heap (\<lambda>h. case execute f h of
                  (Inl x, h') \<Rightarrow> execute (g x) h'
                | r \<Rightarrow> r)"

notation
  bindM (infixl "\<guillemotright>=" 54)

(* Sequencing that discards the result of the first action. *)
abbreviation
  chainM :: "'a Heap \<Rightarrow> 'b Heap \<Rightarrow> 'b Heap" (infixl ">>" 54) where
  "f >> g \<equiv> f >>= (\<lambda>_. g)"

notation
  chainM (infixl "\<guillemotright>" 54)


73 


74 
(* Succeed with x, leaving the heap unchanged. *)
definition
  return :: "'a \<Rightarrow> 'a Heap" where
  [code del]: "return x = heap (Pair x)"

lemma execute_return [simp]:
  "execute (return x) h = apfst Inl (x, h)"
  by (simp add: return_def)

(* Fail with Exn, leaving the heap unchanged; in the logic the argument s
   is ignored. *)
definition
  raise :: "string \<Rightarrow> 'a Heap" where -- {* the string is just decoration *}
  [code del]: "raise s = Heap (Pair (Inr Exn))"

notation (latex output)
  "raise" ("\<^raw:{\textsf{raise}}>")

lemma execute_raise [simp]:
  "execute (raise s) h = (Inr Exn, h)"
  by (simp add: raise_def)


92 


93 


94 
subsubsection {* do-syntax *}

text {*
  We provide a convenient do-notation for monadic expressions
  well-known from Haskell.  @{const Let} is printed
  specially in do-expressions.
*}

nonterminals do_expr

syntax
  "_do" :: "do_expr \<Rightarrow> 'a"
    ("(do (_)//done)" [12] 100)
  "_bindM" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("_ <- _;//_" [1000, 13, 12] 12)
  "_chainM" :: "'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("_;//_" [13, 12] 12)
  "_let" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("let _ = _;//_" [1000, 13, 12] 12)
  "_nil" :: "'a \<Rightarrow> do_expr"
    ("_" [12] 12)

syntax (xsymbols)
  "_bindM" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("_ \<leftarrow> _;//_" [1000, 13, 12] 12)
syntax (latex output)
  "_do" :: "do_expr \<Rightarrow> 'a"
    ("(\<^raw:{\textsf{do}}> (_))" [12] 100)
  "_let" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("\<^raw:\textsf{let}> _ = _;//_" [1000, 13, 12] 12)
notation (latex output)
  "return" ("\<^raw:{\textsf{return}}>")

(* Parse direction: a do-block desugars into nested bindM, chainM and Let
   applications; a bound pattern becomes a lambda over the remaining block. *)
translations
  "_do f" => "f"
  "_bindM x f g" => "f \<guillemotright>= (\<lambda>x. g)"
  "_chainM f g" => "f \<guillemotright> g"
  "_let x t f" => "CONST Let t (\<lambda>x. f)"
  "_nil f" => "f"


133 


134 
print_translation {*
(* Print nested bindM / chainM / Let terms back in do-notation.  A bind whose
   bound variable does not occur in the continuation is shown as a plain
   sequencing step (_chainM).  Raising Match means "not applicable". *)
let
  (* Split an abstraction into a fresh free variable and its body; a term
     that is not an abstraction is eta-expanded (applied to Bound 0) first. *)
  fun dest_abs_eta (Abs (abs as (_, ty, _))) =
        let
          val (v, t) = Syntax.variant_abs abs;
        in (Free (v, ty), t) end
    | dest_abs_eta t =
        let
          val (v, t) = Syntax.variant_abs ("", dummyT, t $ Bound 0);
        in (Free (v, dummyT), t) end;
  fun unfold_monad (Const (@{const_syntax bindM}, _) $ f $ g) =
        let
          val (v, g') = dest_abs_eta g;
          (* names of the free variables occurring in v *)
          val vs = fold_aterms (fn Free (v, _) => insert (op =) v | _ => I) v [];
          (* does any of them occur in the continuation? *)
          val v_used = fold_aterms
            (fn Free (w, _) => (fn s => s orelse member (op =) vs w) | _ => I) g' false;
        in if v_used then
          Const ("_bindM", dummyT) $ v $ f $ unfold_monad g'
        else
          Const ("_chainM", dummyT) $ f $ unfold_monad g'
        end
    | unfold_monad (Const (@{const_syntax chainM}, _) $ f $ g) =
        Const ("_chainM", dummyT) $ f $ unfold_monad g
    | unfold_monad (Const (@{const_syntax Let}, _) $ f $ g) =
        let
          val (v, g') = dest_abs_eta g;
        in Const ("_let", dummyT) $ v $ f $ unfold_monad g' end
    | unfold_monad (Const (@{const_syntax Pair}, _) $ f) =
        (* a partially applied Pair is printed as return *)
        Const (@{const_syntax return}, dummyT) $ f
    | unfold_monad f = f;
  (* Does a chain of Lets end in a bindM?  Only then is a Let printed as a
     do-block.  The final catch-all makes the match exhaustive; the caller
     still raises Match when the answer is false. *)
  fun contains_bindM (Const (@{const_syntax bindM}, _) $ _ $ _) = true
    | contains_bindM (Const (@{const_syntax Let}, _) $ _ $ Abs (_, _, t)) =
        contains_bindM t
    | contains_bindM _ = false;
  fun bindM_monad_tr' (f::g::ts) = list_comb
    (Const ("_do", dummyT) $ unfold_monad (Const (@{const_syntax bindM}, dummyT) $ f $ g), ts);
  fun Let_monad_tr' (f :: (g as Abs (_, _, g')) :: ts) = if contains_bindM g' then list_comb
      (Const ("_do", dummyT) $ unfold_monad (Const (@{const_syntax Let}, dummyT) $ f $ g), ts)
    else raise Match;
in [
  (@{const_syntax bindM}, bindM_monad_tr'),
  (@{const_syntax Let}, Let_monad_tr')
] end;
*}


177 


178 


179 
subsection {* Monad properties *}

subsubsection {* Monad laws *}

(* Left identity. *)
lemma return_bind: "return x \<guillemotright>= f = f x"
  by (simp add: bindM_def return_def)

(* Right identity. *)
lemma bind_return: "f \<guillemotright>= return = f"
proof (rule Heap_eqI)
  fix h
  show "execute (f \<guillemotright>= return) h = execute f h"
    by (auto simp add: bindM_def return_def split: sum.splits prod.splits)
qed

(* Associativity. *)
lemma bind_bind: "(f \<guillemotright>= g) \<guillemotright>= h = f \<guillemotright>= (\<lambda>x. g x \<guillemotright>= h)"
  by (rule Heap_eqI) (auto simp add: bindM_def split: sum.splits prod.splits)

(* A continuation that needs both intermediate results can instead pair them
   up with return (x, y) and consume the pair afterwards. *)
lemma bind_bind': "f \<guillemotright>= (\<lambda>x. g x \<guillemotright>= h x) = f \<guillemotright>= (\<lambda>x. g x \<guillemotright>= (\<lambda>y. return (x, y))) \<guillemotright>= (\<lambda>(x, y). h x y)"
  by (rule Heap_eqI) (auto simp add: bindM_def split: sum.splits prod.splits)

(* A failure short-circuits the rest of the computation. *)
lemma raise_bind: "raise e \<guillemotright>= f = raise e"
  by (simp add: raise_def bindM_def)

(* Rewrite set for normalizing monadic terms (bind_bind' is not part of it). *)
lemmas monad_simp = return_bind bind_return bind_bind raise_bind


204 


205 


206 
subsection {* Generic combinators *}

(* Lift a pure function to a monadic function that leaves the heap unchanged. *)
definition
  liftM :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b Heap"
where
  "liftM f = return o f"

(* Kleisli composition: run f on the argument, then feed its result to g. *)
definition
  compM :: "('a \<Rightarrow> 'b Heap) \<Rightarrow> ('b \<Rightarrow> 'c Heap) \<Rightarrow> 'a \<Rightarrow> 'c Heap" (infixl ">>==" 54)
where
  "(f >>== g) = (\<lambda>x. f x \<guillemotright>= g)"

notation
  compM (infixl "\<guillemotright>==" 54)

lemma liftM_collapse: "liftM f x = return (f x)"
  by (simp add: liftM_def)

lemma liftM_compM: "liftM f \<guillemotright>== g = g o f"
  by (auto intro: Heap_eqI' simp add: expand_fun_eq liftM_def compM_def bindM_def)

lemma compM_return: "f \<guillemotright>== return = f"
  by (simp add: compM_def monad_simp)

lemma compM_compM: "(f \<guillemotright>== g) \<guillemotright>== h = f \<guillemotright>== (g \<guillemotright>== h)"
  by (simp add: compM_def monad_simp)

lemma liftM_bind:
  "(\<lambda>x. liftM f x \<guillemotright>= liftM g) = liftM (\<lambda>x. g (f x))"
  by (rule Heap_eqI') (simp add: monad_simp liftM_def bindM_def)

lemma liftM_comp:
  "liftM f o g = liftM (f o g)"
  by (rule Heap_eqI') (simp add: liftM_def)

(* monad_simp extended with the laws for liftM and compM. *)
lemmas monad_simp' = monad_simp liftM_compM compM_return
  compM_compM liftM_bind liftM_comp


243 


244 
(* Run f on each list element from left to right and collect the results;
   by the definition of bindM, the first failure aborts the rest. *)
primrec
  mapM :: "('a \<Rightarrow> 'b Heap) \<Rightarrow> 'a list \<Rightarrow> 'b list Heap"
where
  "mapM f [] = return []"
  | "mapM f (x#xs) = do y \<leftarrow> f x;
                        ys \<leftarrow> mapM f xs;
                        return (y # ys)
                     done"

(* Monadic left fold: thread the accumulator s through f, element by element. *)
primrec
  foldM :: "('a \<Rightarrow> 'b \<Rightarrow> 'b Heap) \<Rightarrow> 'a list \<Rightarrow> 'b \<Rightarrow> 'b Heap"
where
  "foldM f [] s = return s"
  | "foldM f (x#xs) s = f x s \<guillemotright>= foldM f xs"

hide (open) const heap execute


260 

26182

261 


262 
subsection {* Code generator setup *}

subsubsection {* Logical intermediate layer *}

(* Logical stand-ins that carry a message through code generation.  In the
   logic, Fail collapses to Exn and raise_exc to raise; the target-specific
   setup maps them to native exceptions. *)
definition
  Fail :: "message_string \<Rightarrow> exception"
where
  [code func del]: "Fail s = Exn"

definition
  raise_exc :: "exception \<Rightarrow> 'a Heap"
where
  [code func del]: "raise_exc e = raise []"

(* Code equation and inline rule: generated code replaces raise s by
   raise_exc (Fail (STR s)), which keeps the message string. *)
lemma raise_raise_exc [code func, code inline]:
  "raise s = raise_exc (Fail (STR s))"
  unfolding Fail_def raise_exc_def raise_def ..

hide (open) const Fail raise_exc


281 


282 

27707

283 
subsubsection {* SML and OCaml *}

(* SML: a heap action is a thunk of type unit -> 'a; bind and return build
   new thunks, and running an action means applying it to (). *)
code_type Heap (SML "unit/ ->/ _")
code_const Heap (SML "raise/ (Fail/ \"bare Heap\")")
code_const "op \<guillemotright>=" (SML "!(fn/ f'_/ =>/ fn/ ()/ =>/ f'_/ (_/ ())/ ())")
code_const return (SML "!(fn/ ()/ =>/ _)")
code_const "Heap_Monad.Fail" (SML "Fail")
code_const "Heap_Monad.raise_exc" (SML "!(fn/ ()/ =>/ raise/ _)")

(* OCaml: the same thunk encoding for terms.
   NOTE(review): the type is mapped to the bare result type "_", while the
   term templates produce functions of unit -- confirm this is intended. *)
code_type Heap (OCaml "_")
code_const Heap (OCaml "failwith/ \"bare Heap\"")
code_const "op \<guillemotright>=" (OCaml "!(fun/ f'_/ ()/ ->/ f'_/ (_/ ())/ ())")
code_const return (OCaml "!(fun/ ()/ ->/ _)")
code_const "Heap_Monad.Fail" (OCaml "Failure")
code_const "Heap_Monad.raise_exc" (OCaml "!(fun/ ()/ ->/ raise/ _)")


298 


299 
ML {*
(* Program transformation for the imperative SML/OCaml targets.  Every
   bindM application in the intermediate (Thingol) program is rewritten into
   a function of unit.  Its body runs the first action (applies it to the
   unit value, or unwraps it directly if it is a return), matches the result
   with ICase against the bound variable, and continues with the rest of the
   chain.  All other terms are traversed structurally. *)
local

open Code_Thingol;

val bind' = Code_Name.const @{theory} @{const_name bindM};
val return' = Code_Name.const @{theory} @{const_name return};
val unit' = Code_Name.const @{theory} @{const_name Unity};

fun imp_monad_bind'' ts =
  let
    val dummy_name = "";
    val dummy_type = ITyVar dummy_name;
    val dummy_case_term = IVar dummy_name;
    (*assumption: dummy values are not relevant for serialization*)
    val unitt = IConst (unit', ([], []));
    (* split a continuation into bound variable and body, eta-expanding
       with a fresh variable "x" if it is not an abstraction *)
    fun dest_abs ((v, ty) `|-> t, _) = ((v, ty), t)
      | dest_abs (t, ty) =
          let
            val vs = Code_Thingol.fold_varnames cons t [];
            val v = Name.variant vs "x";
            val ty' = (hd o fst o Code_Thingol.unfold_fun) ty;
          in ((v, ty'), t `$ IVar v) end;
    (* run an action: return t' yields t' directly, anything else is
       applied to unit *)
    fun force (t as IConst (c, _) `$ t') = if c = return'
          then t' else t `$ unitt
      | force t = t `$ unitt;
    fun tr_bind' [(t1, _), (t2, ty2)] =
      let
        val ((v, ty), t) = dest_abs (t2, ty2);
      in ICase (((force t1, ty), [(IVar v, tr_bind'' t)]), dummy_case_term) end
    and tr_bind'' t = case Code_Thingol.unfold_app t
         of (IConst (c, (_, ty1 :: ty2 :: _)), [x1, x2]) => if c = bind'
              then tr_bind' [(x1, ty1), (x2, ty2)]
              else force t
          | _ => force t;
  in (dummy_name, dummy_type) `|-> ICase (((IVar dummy_name, dummy_type),
    [(unitt, tr_bind' ts)]), dummy_case_term) end
and imp_monad_bind' (const as (c, (_, tys))) ts = if c = bind' then case (ts, tys)
   of ([t1, t2], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)]
    | ([t1, t2, t3], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)] `$ t3
    (* partially applied bind: eta-expand to two arguments and retry *)
    | (ts, _) => imp_monad_bind (eta_expand 2 (const, ts))
  else IConst const `$$ map imp_monad_bind ts
and imp_monad_bind (IConst const) = imp_monad_bind' const []
  | imp_monad_bind (t as IVar _) = t
  | imp_monad_bind (t as _ `$ _) = (case unfold_app t
     of (IConst const, ts) => imp_monad_bind' const ts
      | (t, ts) => imp_monad_bind t `$$ map imp_monad_bind ts)
  | imp_monad_bind (v_ty `|-> t) = v_ty `|-> imp_monad_bind t
  | imp_monad_bind (ICase (((t, ty), pats), t0)) = ICase
      (((imp_monad_bind t, ty), (map o pairself) imp_monad_bind pats), imp_monad_bind t0);

in

val imp_program = (Graph.map_nodes o map_terms_stmt) imp_monad_bind;

end
*}


356 

28054

357 
(* Register the targets SML_imp and OCaml_imp: the existing SML and OCaml
   serializers combined with the imp_program transformation (presumably
   applied to the program before serialization). *)
setup {* Code_Target.extend_target ("SML_imp", ("SML", imp_program)) *}
setup {* Code_Target.extend_target ("OCaml_imp", ("OCaml", imp_program)) *}

(* Failure and raise occur in the OCaml templates above and must not be
   used as generated identifiers. *)
code_reserved OCaml Failure raise


361 


362 


363 
subsubsection {* Haskell *}

text {* Adaptation layer *}

(* Haskell module "STMonad": short aliases for Control.Monad.ST,
   Data.STRef and Data.Array.ST (arrays are indexed by Int).
   NOTE(review): lengthArray returns the upper bound of the index range
   (snd of getBounds); this is the element count only if arrays are
   allocated with bounds (0, n) -- confirm against the callers. *)
code_include Haskell "STMonad"
{*import qualified Control.Monad;
import qualified Control.Monad.ST;
import qualified Data.STRef;
import qualified Data.Array.ST;

type RealWorld = Control.Monad.ST.RealWorld;
type ST s a = Control.Monad.ST.ST s a;
type STRef s a = Data.STRef.STRef s a;
type STArray s a = Data.Array.ST.STArray s Int a;

runST :: (forall s. ST s a) -> a;
runST s = Control.Monad.ST.runST s;

newSTRef = Data.STRef.newSTRef;
readSTRef = Data.STRef.readSTRef;
writeSTRef = Data.STRef.writeSTRef;

newArray :: (Int, Int) -> a -> ST s (STArray s a);
newArray = Data.Array.ST.newArray;

newListArray :: (Int, Int) -> [a] -> ST s (STArray s a);
newListArray = Data.Array.ST.newListArray;

lengthArray :: STArray s a -> ST s Int;
lengthArray a = Control.Monad.liftM snd (Data.Array.ST.getBounds a);

readArray :: STArray s a -> Int -> ST s a;
readArray = Data.Array.ST.readArray;

writeArray :: STArray s a -> Int -> a -> ST s ();
writeArray = Data.Array.ST.writeArray;*}


399 

27695

400 
(* Names defined by the STMonad include must not be reused as generated
   identifiers.  Fixes the misspelled "reasSTRef", which left readSTRef
   unreserved; also reserves the include-defined type STArray (Array is
   kept so nothing previously reserved is dropped). *)
code_reserved Haskell RealWorld ST STRef Array STArray
  runST
  newSTRef readSTRef writeSTRef
  newArray newListArray lengthArray readArray writeArray

26182

404 


405 
text {* Monad *}

(* Haskell: heap actions become ST computations over RealWorld.  bindM is
   declared as the monadic bind via code_monad, return maps to Haskell's
   return, and raising an exception becomes a call to error with the
   message string (Fail is the identity on that string). *)
code_type Heap (Haskell "ST/ RealWorld/ _")
code_const Heap (Haskell "error/ \"bare Heap\"")
code_monad "op \<guillemotright>=" Haskell
code_const return (Haskell "return")
code_const "Heap_Monad.Fail" (Haskell "_")
code_const "Heap_Monad.raise_exc" (Haskell "error")


413 

26170

414 
end
