src/Tools/code/code_target.ML
changeset 26753 094d70c81243
parent 26752 6b276119139b
child 26998 2c4032d59586
     1.1 --- a/src/Tools/code/code_target.ML	Sun Apr 27 17:13:01 2008 +0200
     1.2 +++ b/src/Tools/code/code_target.ML	Mon Apr 28 13:41:04 2008 +0200
     1.3 @@ -1855,7 +1855,7 @@
     1.4          | NONE => error "Illegal message expression";
     1.5    in (1, pretty) end;
     1.6  
     1.7 -fun pretty_imperative_monad_bind bind' unit' =
     1.8 +fun pretty_imperative_monad_bind bind' return' unit' =
     1.9    let
    1.10      val dummy_name = "";
    1.11      val dummy_type = ITyVar dummy_name;
    1.12 @@ -1869,16 +1869,18 @@
    1.13              val v = Name.variant vs "x";
    1.14              val ty' = (hd o fst o CodeThingol.unfold_fun) ty;
    1.15            in ((v, ty'), t `$ IVar v) end;
    1.16 +    fun force (t as IConst (c, _) `$ t') = if c = return'
    1.17 +          then t' else t `$ unitt
    1.18 +      | force t = t `$ unitt;
    1.19      fun tr_bind' [(t1, _), (t2, ty2)] =
    1.20        let
    1.21          val ((v, ty), t) = dest_abs (t2, ty2);
    1.22 -      in ICase (((t1 `$ unitt, ty), [(IVar v, tr_bind'' t)]), dummy_case_term) end
    1.23 -    and tr_bind'' (t as _ `$ _) = (case CodeThingol.unfold_app t
    1.24 +      in ICase (((force t1, ty), [(IVar v, tr_bind'' t)]), dummy_case_term) end
    1.25 +    and tr_bind'' t = case CodeThingol.unfold_app t
    1.26           of (IConst (c, (_, ty1 :: ty2 :: _)), [x1, x2]) => if c = bind'
    1.27                then tr_bind' [(x1, ty1), (x2, ty2)]
    1.28 -              else t `$ unitt
    1.29 -          | _ => t `$ unitt)
    1.30 -      | tr_bind'' t = t `$ unitt;
    1.31 +              else force t
    1.32 +          | _ => force t;
    1.33      fun tr_bind ts = (dummy_name, dummy_type)
    1.34        `|-> ICase (((IVar dummy_name, dummy_type), [(unitt, tr_bind' ts)]), dummy_case_term);
    1.35      fun pretty pr vars fxy ts = pr vars fxy (tr_bind ts);
    1.36 @@ -2009,17 +2011,18 @@
    1.37  fun add_modl_alias target =
    1.38    map_module_alias target o Symtab.update o apsnd CodeName.check_modulename;
    1.39  
    1.40 -fun add_monad target c_run c_bind c_unit thy =
    1.41 +fun add_monad target c_run c_bind c_return_unit thy =
    1.42    let
    1.43      val c_run' = CodeUnit.read_const thy c_run;
    1.44      val c_bind' = CodeUnit.read_const thy c_bind;
    1.45      val c_bind'' = CodeName.const thy c_bind';
    1.46 -    val c_unit'' = Option.map (CodeName.const thy o CodeUnit.read_const thy) c_unit;
    1.47 +    val c_return_unit'' = (Option.map o pairself)
    1.48 +      (CodeName.const thy o CodeUnit.read_const thy) c_return_unit;
    1.49      val is_haskell = target = target_Haskell;
    1.50 -    val _ = if is_haskell andalso is_some c_unit''
    1.51 +    val _ = if is_haskell andalso is_some c_return_unit''
    1.52        then error ("No unit entry may be given for Haskell monad")
    1.53        else ();
    1.54 -    val _ = if not is_haskell andalso is_none c_unit''
    1.55 +    val _ = if not is_haskell andalso is_none c_return_unit''
    1.56        then error ("Unit entry must be given for SML/OCaml monad")
    1.57        else ();
    1.58    in if target = target_Haskell then
    1.59 @@ -2031,7 +2034,8 @@
    1.60    else
    1.61      thy
    1.62      |> gen_add_syntax_const (K I) target c_bind'
    1.63 -          (SOME (pretty_imperative_monad_bind c_bind'' (the c_unit'')))
    1.64 +          (SOME (pretty_imperative_monad_bind c_bind''
    1.65 +            ((fst o the) c_return_unit'') ((snd o the) c_return_unit'')))
    1.66    end;
    1.67  
    1.68  fun gen_allow_exception prep_cs raw_c thy =
    1.69 @@ -2185,10 +2189,10 @@
    1.70  
    1.71  val _ =
    1.72    OuterSyntax.command "code_monad" "define code syntax for monads" K.thy_decl (
    1.73 -    P.term -- P.term -- ((P.term >> SOME) -- Scan.repeat1 P.name
    1.74 +    P.term -- P.term -- ((P.term -- P.term >> SOME) -- Scan.repeat1 P.name
    1.75        || Scan.succeed NONE -- Scan.repeat1 P.name)
    1.76 -    >> (fn ((raw_run, raw_bind), (raw_unit, targets)) => Toplevel.theory 
    1.77 -          (fold (fn target => add_monad target raw_run raw_bind raw_unit) targets))
    1.78 +    >> (fn ((raw_run, raw_bind), (raw_unit_return, targets)) => Toplevel.theory 
    1.79 +          (fold (fn target => add_monad target raw_run raw_bind raw_unit_return) targets))
    1.80    );
    1.81  
    1.82  val _ =