26170

1 
(* Title: HOL/Library/Heap_Monad.thy


2 
ID: $Id$


3 
Author: John Matthews, Galois Connections; Alexander Krauss, Lukas Bulwahn & Florian Haftmann, TU Muenchen


4 
*)


5 


6 
header {* A monad with a polymorphic heap *}


7 


8 
theory Heap_Monad


9 
imports Heap


10 
begin


11 


12 
subsection {* The monad *}


13 


14 
subsubsection {* Monad combinators *}


15 


16 
(* The logical model knows a single, message-less exception. *)
datatype exception = Exn

text {* Monadic heap actions either produce values
  and transform the heap, or fail *}
datatype 'a Heap = Heap "heap \<Rightarrow> ('a + exception) \<times> heap"

(* Projection from a heap action to its underlying state transformer.
   Its equation is excluded from code generation. *)
primrec execute :: "'a Heap \<Rightarrow> heap \<Rightarrow> ('a + exception) \<times> heap" where
  "execute (Heap f) = f"

lemmas [code del] = execute.simps


26 


27 
(* Wrapping the projected state transformer gives back the same action. *)
lemma Heap_execute [simp]:
  "Heap (execute f) = f"
  by (cases f) simp_all

(* Extensionality: actions that behave identically on every heap are equal. *)
lemma Heap_eqI:
  "(\<And>h. execute f h = execute g h) \<Longrightarrow> f = g"
  by (cases f, cases g) (auto simp add: expand_fun_eq)

(* Extensionality for functions yielding actions, stated pointwise per heap. *)
lemma Heap_eqI':
  "(\<And>h. (\<lambda>x. execute (f x) h) = (\<lambda>y. execute (g y) h)) \<Longrightarrow> f = g"
  by (auto simp add: expand_fun_eq intro: Heap_eqI)


37 


38 
(* Quantifying over all heap actions is the same as quantifying over all raw
   state transformers wrapped by the Heap constructor. *)
lemma Heap_strip: "(\<And>f. PROP P f) \<equiv> (\<And>g. PROP P (Heap g))"
proof
  fix g :: "heap \<Rightarrow> ('a + exception) \<times> heap"
  assume "\<And>f. PROP P f"
  then show "PROP P (Heap g)" .
next
  fix f :: "'a Heap"
  assume "\<And>g. PROP P (Heap g)"
  (* instantiate with the projection of f, then fold it back via Heap_execute *)
  then have "PROP P (Heap (execute f))" .
  then show "PROP P f" by simp
qed


49 


50 
(* Lift a state transformer that cannot fail into the monad: its result is
   always tagged with Inl, i.e. never an exception. *)
definition heap :: "(heap \<Rightarrow> 'a \<times> heap) \<Rightarrow> 'a Heap" where
  [code del]: "heap f = Heap (\<lambda>h. apfst Inl (f h))"

lemma execute_heap [simp]:
  "execute (heap f) h = apfst Inl (f h)"
  by (simp add: heap_def)

(* run is the identity on actions (see run_drop). *)
definition run :: "'a Heap \<Rightarrow> 'a Heap" where
  run_drop [code del]: "run f = f"


61 


62 
(* Sequential composition: execute f; on a regular result Inl x continue with
   g x in the heap f produced, otherwise pass the exceptional result (and the
   heap f left behind) through unchanged. *)
definition bindM :: "'a Heap \<Rightarrow> ('a \<Rightarrow> 'b Heap) \<Rightarrow> 'b Heap" (infixl ">>=" 54) where
  [code del]: "f >>= g = Heap (\<lambda>h. case execute f h of
                  (Inl x, h') \<Rightarrow> execute (g x) h'
                | r \<Rightarrow> r)"

notation
  bindM (infixl "\<guillemotright>=" 54)

(* Sequencing that ignores the result of the first action. *)
abbreviation chainM :: "'a Heap \<Rightarrow> 'b Heap \<Rightarrow> 'b Heap" (infixl ">>" 54) where
  "f >> g \<equiv> f >>= (\<lambda>_. g)"

notation
  chainM (infixl "\<guillemotright>" 54)


77 


78 
(* Yield x and leave the heap untouched. *)
definition return :: "'a \<Rightarrow> 'a Heap" where
  [code del]: "return x = heap (Pair x)"

lemma execute_return [simp]:
  "execute (return x) h = apfst Inl (x, h)"
  by (simp add: return_def)

(* Fail with the single logical exception; the heap is left unchanged. *)
definition raise :: "string \<Rightarrow> 'a Heap" where -- {* the string is just decoration *}
  [code del]: "raise s = Heap (Pair (Inr Exn))"

notation (latex output)
  "raise" ("\<^raw:{\textsf{raise}}>")

lemma execute_raise [simp]:
  "execute (raise s) h = (Inr Exn, h)"
  by (simp add: raise_def)


96 


97 


98 
subsubsection {* do-syntax *}


99 


100 
text {*


101 
We provide a convenient do-notation for monadic expressions


102 
well-known from Haskell. @{const Let} is printed


103 
specially in do-expressions.


104 
*}


105 


106 
(* Concrete syntax for do-blocks. A do_expr is a ;-separated sequence of
   binds (x <- f), plain steps (f), lets, ending in a final expression. *)
nonterminals do_expr

syntax
  "_do" :: "do_expr \<Rightarrow> 'a"
    ("(do (_)//done)" [12] 100)
  "_bindM" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("_ <- _;//_" [1000, 13, 12] 12)
  "_chainM" :: "'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("_;//_" [13, 12] 12)
  "_let" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("let _ = _;//_" [1000, 13, 12] 12)
  "_nil" :: "'a \<Rightarrow> do_expr"
    ("_" [12] 12)

(* Symbolic arrow for binds, mirroring the ASCII "<-" form above. *)
syntax (xsymbols)
  "_bindM" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("_ \<leftarrow> _;//_" [1000, 13, 12] 12)
syntax (latex output)
  "_do" :: "do_expr \<Rightarrow> 'a"
    ("(\<^raw:{\textsf{do}}> (_))" [12] 100)
  "_let" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
    ("\<^raw:\textsf{let}> _ = _;//_" [1000, 13, 12] 12)
notation (latex output)
  "return" ("\<^raw:{\textsf{return}}>")

(* A do-block is wrapped in run; binds, chains and lets map to the monad
   combinators and HOL's Let. *)
translations
  "_do f" => "CONST run f"
  "_bindM x f g" => "f \<guillemotright>= (\<lambda>x. g)"
  "_chainM f g" => "f \<guillemotright> g"
  "_let x t f" => "CONST Let t (\<lambda>x. f)"
  "_nil f" => "f"


137 


138 
print_translation {*
let
  (* Split an abstraction into a fresh variable and body; a non-abstraction
     is eta-expanded with a dummy-typed variable first. *)
  fun dest_abs_eta (Abs (abs as (_, ty, _))) =
        let
          val (v, t) = Syntax.variant_abs abs;
        in ((v, ty), t) end
    | dest_abs_eta t =
        let
          val (v, t) = Syntax.variant_abs ("", dummyT, t $ Bound 0);
        in ((v, dummyT), t) end
  (* Turn a bindM / chainM / Let chain back into do_expr syntax. A bind whose
     bound variable does not occur in the continuation is shown as a plain
     step (_chainM). *)
  fun unfold_monad (Const (@{const_syntax bindM}, _) $ f $ g) =
        let
          val ((v, ty), g') = dest_abs_eta g;
          val v_used = fold_aterms
            (fn Free (w, _) => (fn s => s orelse v = w) | _ => I) g' false;
        in if v_used then
          Const ("_bindM", dummyT) $ Free (v, ty) $ f $ unfold_monad g'
        else
          Const ("_chainM", dummyT) $ f $ unfold_monad g'
        end
    | unfold_monad (Const (@{const_syntax chainM}, _) $ f $ g) =
        Const ("_chainM", dummyT) $ f $ unfold_monad g
    | unfold_monad (Const (@{const_syntax Let}, _) $ f $ g) =
        let
          val ((v, ty), g') = dest_abs_eta g;
        in Const ("_let", dummyT) $ Free (v, ty) $ f $ unfold_monad g' end
    | unfold_monad (Const (@{const_syntax Pair}, _) $ f) =
        Const ("return", dummyT) $ f
    | unfold_monad f = f;
  (* run applied to an action prints as a do-block; any further arguments
     are re-applied to the resulting syntax. *)
  fun tr' (f::ts) =
    list_comb (Const ("_do", dummyT) $ unfold_monad f, ts)
in [(@{const_syntax "run"}, tr')] end;
*}


171 


172 


173 
subsection {* Monad properties *}


174 


175 
subsubsection {* Superfluous runs *}


176 


177 
text {* @{term run} is the identity; it merely marks do-blocks for parsing and printing *}


178 


179 
(* Since run is the identity (run_drop), it can be erased around actions,
   inside binds and on either side of an equation. *)
lemma run_simp [simp]:
  "\<And>f. run (run f) = run f"
  "\<And>f g. run f \<guillemotright>= g = f \<guillemotright>= g"
  "\<And>f g. run f \<guillemotright> g = f \<guillemotright> g"
  "\<And>f g. f \<guillemotright>= (\<lambda>x. run g) = f \<guillemotright>= (\<lambda>x. g)"
  "\<And>f g. f \<guillemotright> run g = f \<guillemotright> g"
  "\<And>f. f = run g \<longleftrightarrow> f = g"
  "\<And>f. run f = g \<longleftrightarrow> f = g"
  unfolding run_drop by rule+


188 


189 
subsubsection {* Monad laws *}


190 


191 
(* Left identity. *)
lemma return_bind: "return x \<guillemotright>= f = f x"
  by (simp add: bindM_def return_def)

(* Right identity. *)
lemma bind_return: "f \<guillemotright>= return = f"
proof (rule Heap_eqI)
  fix h
  show "execute (f \<guillemotright>= return) h = execute f h"
    by (auto simp add: bindM_def return_def split: sum.splits prod.splits)
qed

(* Associativity. *)
lemma bind_bind: "(f \<guillemotright>= g) \<guillemotright>= h = f \<guillemotright>= (\<lambda>x. g x \<guillemotright>= h)"
  by (rule Heap_eqI) (auto simp add: bindM_def split: sum.splits prod.splits)

(* Re-associate a dependent bind by pairing both intermediate results. *)
lemma bind_bind': "f \<guillemotright>= (\<lambda>x. g x \<guillemotright>= h x) = f \<guillemotright>= (\<lambda>x. g x \<guillemotright>= (\<lambda>y. return (x, y))) \<guillemotright>= (\<lambda>(x, y). h x y)"
  by (rule Heap_eqI) (auto simp add: bindM_def split: sum.splits prod.splits)

(* An exception short-circuits the rest of a bind chain. *)
lemma raise_bind: "raise e \<guillemotright>= f = raise e"
  by (simp add: raise_def bindM_def)

lemmas monad_simp = return_bind bind_return bind_bind raise_bind


212 


213 


214 
subsection {* Generic combinators *}


215 


216 
(* Lift a pure function to an action-valued one that always succeeds. *)
definition liftM :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b Heap" where
  "liftM f = return o f"

(* Kleisli composition: run f x, then feed its result to g. *)
definition compM :: "('a \<Rightarrow> 'b Heap) \<Rightarrow> ('b \<Rightarrow> 'c Heap) \<Rightarrow> 'a \<Rightarrow> 'c Heap" (infixl ">>==" 54) where
  "(f >>== g) = (\<lambda>x. f x \<guillemotright>= g)"

notation
  compM (infixl "\<guillemotright>==" 54)

lemma liftM_collapse: "liftM f x = return (f x)"
  by (simp add: liftM_def)

lemma liftM_compM: "liftM f \<guillemotright>== g = g o f"
  by (auto intro: Heap_eqI' simp add: expand_fun_eq liftM_def compM_def bindM_def)

lemma compM_return: "f \<guillemotright>== return = f"
  by (simp add: compM_def monad_simp)

lemma compM_compM: "(f \<guillemotright>== g) \<guillemotright>== h = f \<guillemotright>== (g \<guillemotright>== h)"
  by (simp add: compM_def monad_simp)

lemma liftM_bind:
  "(\<lambda>x. liftM f x \<guillemotright>= liftM g) = liftM (\<lambda>x. g (f x))"
  by (rule Heap_eqI') (simp add: monad_simp liftM_def bindM_def)

lemma liftM_comp:
  "liftM f o g = liftM (f o g)"
  by (rule Heap_eqI') (simp add: liftM_def)

(* Monad laws together with the combinator laws above. *)
lemmas monad_simp' = monad_simp liftM_compM compM_return
  compM_compM liftM_bind liftM_comp


251 


252 
(* Apply an action to each list element in order, collecting the results. *)
primrec mapM :: "('a \<Rightarrow> 'b Heap) \<Rightarrow> 'a list \<Rightarrow> 'b list Heap" where
  "mapM f [] = return []"
| "mapM f (x#xs) = do y \<leftarrow> f x;
                      ys \<leftarrow> mapM f xs;
                      return (y # ys)
                   done"

(* Thread an accumulator s through f, left to right over the list. *)
primrec foldM :: "('a \<Rightarrow> 'b \<Rightarrow> 'b Heap) \<Rightarrow> 'a list \<Rightarrow> 'b \<Rightarrow> 'b Heap" where
  "foldM f [] s = return s"
| "foldM f (x#xs) s = f x s \<guillemotright>= foldM f xs"

hide (open) const heap execute


268 

26182

269 


270 
subsection {* Code generator setup *}


271 


272 
subsubsection {* Logical intermediate layer *}


273 


274 
(* Logically every failure message collapses to the single exception Exn. *)
definition Fail :: "message_string \<Rightarrow> exception" where
  [code func del]: "Fail s = Exn"

(* Raise a given exception; logically the same as raise with an empty string. *)
definition raise_exc :: "exception \<Rightarrow> 'a Heap" where
  [code func del]: "raise_exc e = raise []"

(* Code equation, also inlined: raise s is generated as raise_exc (Fail (STR s)),
   so the message string reaches the target language. *)
lemma raise_raise_exc [code func, code inline]:
  "raise s = raise_exc (Fail (STR s))"
  unfolding Fail_def raise_exc_def raise_def ..

hide (open) const Fail raise_exc


289 


290 

27707

291 
subsubsection {* SML and OCaml *}

26182

292 

26752

293 
(* SML: an action is a thunk taking unit; bind, return and raise_exc are
   serialised as such thunks. *)
code_type Heap (SML "unit/ ->/ _")
code_const Heap (SML "raise/ (Fail/ \"bare Heap\")")
code_const "op \<guillemotright>=" (SML "!(fn/ f/ =>/ fn/ g/ =>/ fn/ ()/ =>/ g/ (f/ ())/ ())")
code_const run (SML "_")
code_const return (SML "!(fn/ ()/ =>/ _)")
code_const "Heap_Monad.Fail" (SML "Fail")
code_const "Heap_Monad.raise_exc" (SML "!(fn/ ()/ =>/ raise/ _)")

(* OCaml: the same thunk encoding. The type must therefore be unit -> _ as
   well; plain "_" would disagree with the fun () -> ... terms below. *)
code_type Heap (OCaml "unit/ ->/ _")
code_const Heap (OCaml "failwith/ \"bare Heap\"")
code_const "op \<guillemotright>=" (OCaml "!(fun/ f/ g/ ()/ ->/ g/ (f/ ())/ ())")
code_const run (OCaml "_")
code_const return (OCaml "!(fun/ ()/ ->/ _)")
code_const "Heap_Monad.Fail" (OCaml "Failure")
code_const "Heap_Monad.raise_exc" (OCaml "!(fun/ ()/ ->/ raise/ _)")


308 


309 
ML {*
(* Program transformation behind the SML_imp / OCaml_imp targets. Applications
   of bindM are rewritten into nested case expressions over actions applied to
   unit, and a forced "return x" is short-cut to x. *)
local

open CodeThingol;

val bind' = CodeName.const @{theory} @{const_name bindM};
val return' = CodeName.const @{theory} @{const_name return};
val unit' = CodeName.const @{theory} @{const_name Unity};

(* ts = [(action, type), (continuation, type)] of a saturated bind. The result
   is a lambda over a dummy variable that is matched against unit; its body
   sequences the bind chain. *)
fun imp_monad_bind'' ts =
  let
    val dummy_name = "";
    val dummy_type = ITyVar dummy_name;
    val dummy_case_term = IVar dummy_name;
    (*assumption: dummy values are not relevant for serialization*)
    val unitt = IConst (unit', ([], []));
    (* split a continuation into bound variable and body; eta-expand with a
       fresh variant of "x" if it is not a syntactic abstraction *)
    fun dest_abs ((v, ty) `|-> t, _) = ((v, ty), t)
      | dest_abs (t, ty) =
          let
            val vs = CodeThingol.fold_varnames cons t [];
            val v = Name.variant vs "x";
            val ty' = (hd o fst o CodeThingol.unfold_fun) ty;
          in ((v, ty'), t `$ IVar v) end;
    (* run an action by applying it to unit; "return t'" yields t' directly *)
    fun force (t as IConst (c, _) `$ t') = if c = return'
          then t' else t `$ unitt
      | force t = t `$ unitt;
    (* one bind step: case on the forced action, binding its result to v *)
    fun tr_bind' [(t1, _), (t2, ty2)] =
      let
        val ((v, ty), t) = dest_abs (t2, ty2);
      in ICase (((force t1, ty), [(IVar v, tr_bind'' t)]), dummy_case_term) end
    (* continue the chain if the body is again a saturated bind, else force it *)
    and tr_bind'' t = case CodeThingol.unfold_app t
         of (IConst (c, (_, ty1 :: ty2 :: _)), [x1, x2]) => if c = bind'
              then tr_bind' [(x1, ty1), (x2, ty2)]
              else force t
          | _ => force t;
  in (dummy_name, dummy_type) `|-> ICase (((IVar dummy_name, dummy_type),
    [(unitt, tr_bind' ts)]), dummy_case_term) end
(* Constant applications: bindM with two (or three) arguments is translated,
   an under-applied bindM is eta-expanded first, anything else is traversed. *)
and imp_monad_bind' (const as (c, (_, tys))) ts = if c = bind' then case (ts, tys)
   of ([t1, t2], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)]
    | ([t1, t2, t3], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)] `$ t3
    | (ts, _) => imp_monad_bind (eta_expand 2 (const, ts))
  else IConst const `$$ map imp_monad_bind ts
(* Structural traversal of an iterm.
   NOTE(review): the actions forced inside a translated chain, non-bind
   continuation bodies, and the extra argument t3 are not traversed again.
   Binds nested there presumably keep the generic bindM serialisation
   (code_const "op >>=") -- confirm this is intended. *)
and imp_monad_bind (IConst const) = imp_monad_bind' const []
  | imp_monad_bind (t as IVar _) = t
  | imp_monad_bind (t as _ `$ _) = (case unfold_app t
     of (IConst const, ts) => imp_monad_bind' const ts
      | (t, ts) => imp_monad_bind t `$$ map imp_monad_bind ts)
  | imp_monad_bind (v_ty `|-> t) = v_ty `|-> imp_monad_bind t
  | imp_monad_bind (ICase (((t, ty), pats), t0)) = ICase
      (((imp_monad_bind t, ty), (map o pairself) imp_monad_bind pats), imp_monad_bind t0);

in

val imp_program = (Graph.map_nodes o map_terms_stmt) imp_monad_bind;

end
*}

(* Imperative variants of the SML and OCaml targets that apply the
   transformation above to the whole program before serialisation. *)
setup {* CodeTarget.extend_target ("SML_imp", ("SML", imp_program)) *}
setup {* CodeTarget.extend_target ("OCaml_imp", ("OCaml", imp_program)) *}

code_reserved OCaml Failure raise


371 


372 


373 
subsubsection {* Haskell *}


374 


375 
text {* Adaptation layer *}


376 


377 
(* Haskell module wrapping Control.Monad.ST, Data.STRef and Data.Array.ST
   under the names used by the Haskell serialisation. *)
code_include Haskell "STMonad"
{*import qualified Control.Monad;
import qualified Control.Monad.ST;
import qualified Data.STRef;
import qualified Data.Array.ST;

type RealWorld = Control.Monad.ST.RealWorld;
type ST s a = Control.Monad.ST.ST s a;
type STRef s a = Data.STRef.STRef s a;
type STArray s a = Data.Array.ST.STArray s Int a;

runST :: (forall s. ST s a) -> a;
runST s = Control.Monad.ST.runST s;

newSTRef = Data.STRef.newSTRef;
readSTRef = Data.STRef.readSTRef;
writeSTRef = Data.STRef.writeSTRef;

newArray :: (Int, Int) -> a -> ST s (STArray s a);
newArray = Data.Array.ST.newArray;

newListArray :: (Int, Int) -> [a] -> ST s (STArray s a);
newListArray = Data.Array.ST.newListArray;

-- NOTE(review): this yields the upper index bound. It equals the element
-- count only if arrays are allocated with bounds (0, n) -- confirm against
-- the callers that create arrays.
lengthArray :: STArray s a -> ST s Int;
lengthArray a = Control.Monad.liftM snd (Data.Array.ST.getBounds a);

readArray :: STArray s a -> Int -> ST s a;
readArray = Data.Array.ST.readArray;

writeArray :: STArray s a -> Int -> a -> ST s ();
writeArray = Data.Array.ST.writeArray;*}


409 

27695

410 
(* Reserve every name the STMonad include defines, so generated code cannot
   clash with them (readSTRef was previously misspelt, and STArray missing). *)
code_reserved Haskell RealWorld ST STRef Array STArray
  runST
  newSTRef readSTRef writeSTRef
  newArray newListArray lengthArray readArray writeArray

26182

414 


415 
text {* Monad *}


416 

27695

417 
(* Haskell: actions are ST computations over RealWorld, and bind is
   serialised through do-notation via code_monad. *)
code_type Heap
  (Haskell "ST/ RealWorld/ _")
code_const Heap
  (Haskell "error/ \"bare Heap\"")
code_monad run "op \<guillemotright>=" Haskell
code_const return
  (Haskell "return")
(* the failure message is passed straight to error *)
code_const "Heap_Monad.Fail"
  (Haskell "_")
code_const "Heap_Monad.raise_exc"
  (Haskell "error")


423 

26170

424 
end
