--- a/src/HOL/Library/Heap_Monad.thy Sun Apr 27 17:13:01 2008 +0200
+++ b/src/HOL/Library/Heap_Monad.thy Mon Apr 28 13:41:04 2008 +0200
@@ -299,9 +299,8 @@
subsubsection {* SML *}
code_type Heap (SML "unit/ ->/ _")
-term "op \<guillemotright>="
code_const Heap (SML "raise/ (Fail/ \"bare Heap\")")
-code_monad run "op \<guillemotright>=" "()" SML
+code_monad run "op \<guillemotright>=" return "()" SML
code_const run (SML "_")
code_const return (SML "(fn/ ()/ =>/ _)")
code_const "Heap_Monad.Fail" (SML "Fail")
@@ -312,7 +311,7 @@
code_type Heap (OCaml "_")
code_const Heap (OCaml "failwith/ \"bare Heap\"")
-code_monad run "op \<guillemotright>=" "()" OCaml
+code_monad run "op \<guillemotright>=" return "()" OCaml
code_const run (OCaml "_")
code_const return (OCaml "(fn/ ()/ =>/ _)")
code_const "Heap_Monad.Fail" (OCaml "Failure")
--- a/src/Tools/code/code_target.ML Sun Apr 27 17:13:01 2008 +0200
+++ b/src/Tools/code/code_target.ML Mon Apr 28 13:41:04 2008 +0200
@@ -1855,7 +1855,7 @@
| NONE => error "Illegal message expression";
in (1, pretty) end;
-fun pretty_imperative_monad_bind bind' unit' =
+fun pretty_imperative_monad_bind bind' return' unit' =
let
val dummy_name = "";
val dummy_type = ITyVar dummy_name;
@@ -1869,16 +1869,18 @@
val v = Name.variant vs "x";
val ty' = (hd o fst o CodeThingol.unfold_fun) ty;
in ((v, ty'), t `$ IVar v) end;
+ fun force (t as IConst (c, _) `$ t') = if c = return'
+ then t' else t `$ unitt
+ | force t = t `$ unitt;
fun tr_bind' [(t1, _), (t2, ty2)] =
let
val ((v, ty), t) = dest_abs (t2, ty2);
- in ICase (((t1 `$ unitt, ty), [(IVar v, tr_bind'' t)]), dummy_case_term) end
- and tr_bind'' (t as _ `$ _) = (case CodeThingol.unfold_app t
+ in ICase (((force t1, ty), [(IVar v, tr_bind'' t)]), dummy_case_term) end
+ and tr_bind'' t = case CodeThingol.unfold_app t
of (IConst (c, (_, ty1 :: ty2 :: _)), [x1, x2]) => if c = bind'
then tr_bind' [(x1, ty1), (x2, ty2)]
- else t `$ unitt
- | _ => t `$ unitt)
- | tr_bind'' t = t `$ unitt;
+ else force t
+ | _ => force t;
fun tr_bind ts = (dummy_name, dummy_type)
`|-> ICase (((IVar dummy_name, dummy_type), [(unitt, tr_bind' ts)]), dummy_case_term);
fun pretty pr vars fxy ts = pr vars fxy (tr_bind ts);
@@ -2009,17 +2011,18 @@
fun add_modl_alias target =
map_module_alias target o Symtab.update o apsnd CodeName.check_modulename;
-fun add_monad target c_run c_bind c_unit thy =
+fun add_monad target c_run c_bind c_return_unit thy =
let
val c_run' = CodeUnit.read_const thy c_run;
val c_bind' = CodeUnit.read_const thy c_bind;
val c_bind'' = CodeName.const thy c_bind';
- val c_unit'' = Option.map (CodeName.const thy o CodeUnit.read_const thy) c_unit;
+ val c_return_unit'' = (Option.map o pairself)
+ (CodeName.const thy o CodeUnit.read_const thy) c_return_unit;
val is_haskell = target = target_Haskell;
- val _ = if is_haskell andalso is_some c_unit''
+ val _ = if is_haskell andalso is_some c_return_unit''
then error ("No unit entry may be given for Haskell monad")
else ();
- val _ = if not is_haskell andalso is_none c_unit''
+ val _ = if not is_haskell andalso is_none c_return_unit''
then error ("Unit entry must be given for SML/OCaml monad")
else ();
in if target = target_Haskell then
@@ -2031,7 +2034,8 @@
else
thy
|> gen_add_syntax_const (K I) target c_bind'
- (SOME (pretty_imperative_monad_bind c_bind'' (the c_unit'')))
+ (SOME (pretty_imperative_monad_bind c_bind''
+ ((fst o the) c_return_unit'') ((snd o the) c_return_unit'')))
end;
fun gen_allow_exception prep_cs raw_c thy =
@@ -2185,10 +2189,10 @@
val _ =
OuterSyntax.command "code_monad" "define code syntax for monads" K.thy_decl (
- P.term -- P.term -- ((P.term >> SOME) -- Scan.repeat1 P.name
+ P.term -- P.term -- ((P.term -- P.term >> SOME) -- Scan.repeat1 P.name
|| Scan.succeed NONE -- Scan.repeat1 P.name)
- >> (fn ((raw_run, raw_bind), (raw_unit, targets)) => Toplevel.theory
- (fold (fn target => add_monad target raw_run raw_bind raw_unit) targets))
+ >> (fn ((raw_run, raw_bind), (raw_unit_return, targets)) => Toplevel.theory
+ (fold (fn target => add_monad target raw_run raw_bind raw_unit_return) targets))
);
val _ =