| author | wenzelm | 
| Wed, 21 Jan 2009 22:26:48 +0100 | |
| changeset 29603 | b660ee46f2f6 | 
| parent 25520 | e123c81257a5 | 
| child 30161 | c26e515f1c29 | 
| permissions | -rw-r--r-- | 
| 24584 | 1 | (* Title: Tools/Compute_Oracle/am_ghc.ML | 
| 23663 | 2 | ID: $Id$ | 
| 3 | Author: Steven Obua | |
| 4 | *) | |
| 5 | ||
structure AM_GHC : ABSTRACT_MACHINE = struct

open AbstractMachine;

(* A compiled program: (guid, path of the generated Haskell module file,
   maximal arity of every constant occurring in the rules). *)
type program = string * string * (int Inttab.table)

(* Number of pattern variables (PVar leaves) occurring in a pattern. *)
fun count_patternvars PVar = 1
  | count_patternvars (PConst (_, ps)) =
      List.foldl (fn (p, count) => count_patternvars p + count) 0 ps

(* Record arity a for constant code, keeping the maximum of all arities seen. *)
fun update_arity arity code a =
  (case Inttab.lookup arity code of
    NONE => Inttab.update_new (code, a) arity
  | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)

(* We have to find out the maximal arity of each constant *)
fun collect_pattern_arity PVar arity = arity
  | collect_pattern_arity (PConst (c, args)) arity =
      fold collect_pattern_arity args (update_arity arity c (length args))

local
  (* applevel = number of arguments the current subterm is applied to *)
  fun collect applevel (Var _) arity = arity
    | collect applevel (Const c) arity = update_arity arity c applevel
    | collect applevel (Abs m) arity = collect 0 m arity
    | collect applevel (App (a, b)) arity = collect 0 b (collect (applevel + 1) a arity)
in
  fun collect_term_arity t arity = collect 0 t arity
end

(* Add n to every de Bruijn index that is not bound below `level`. *)
fun nlift level n (Var m) = if m < level then Var m else Var (m + n)
  | nlift level n (Const c) = Const c
  | nlift level n (App (a, b)) = App (nlift level n a, nlift level n b)
  | nlift level n (Abs b) = Abs (nlift (level + 1) n b)

(* List of n copies of x. *)
fun rep n x = if n = 0 then [] else x :: rep (n - 1) x

(* Pad every rule so that its head constant takes its maximal arity.
   Missing pattern arguments become fresh PVars, and the right-hand side is
   lifted and applied to the corresponding new variables.
   Returns (arity table, adjusted rules).
   Raises Compile when:
   - a pattern is a variable;
   - the right-hand side has unbound variables;
   - a constant nested inside a pattern (at any depth) is not applied to
     its maximal arity. *)
fun adjust_rules rules =
  let
    val arity = fold (fn (p, t) => fn arity => collect_term_arity t (collect_pattern_arity p arity)) rules Inttab.empty
    fun arity_of c = the (Inttab.lookup arity c)
    (* Nested constants become Haskell data-constructor patterns of fixed
       arity, so they must be fully applied at every depth, not just at the
       first level below the head. *)
    fun adjust_pattern PVar = PVar
      | adjust_pattern (PConst (c, args)) =
          if length args <> arity_of c then
            raise Compile ("Constant inside pattern must have maximal arity")
          else PConst (c, map adjust_pattern args)
    fun adjust_rule (PVar, t) = raise Compile ("pattern may not be a variable")
      | adjust_rule (rule as (p as PConst (c, args), t)) =
          let
            val _ = if not (check_freevars (count_patternvars p) t) then raise Compile ("unbound variables on right hand side") else ()
            val args = map adjust_pattern args
            val len = length args
            val arity = arity_of c
            (* apply t to Var (n-1), ..., Var 0, in that order *)
            fun adjust_term n t = if n = 0 then t else adjust_term (n - 1) (App (t, Var (n - 1)))
          in
            if len = arity then
              rule
            else if arity >= len then
              (PConst (c, args @ rep (arity - len) PVar),
               adjust_term (arity - len) (nlift 0 (arity - len) t))
            else raise Compile "internal error in adjust_rule"
          end
  in
    (arity, map adjust_rule rules)
  end

(* Render a term as a Haskell expression.
   - arity_of c = SOME a means constant c is emitted as function c<c> taking
     a arguments; NONE means it is emitted as "Const <c>".
   - n is the number of pattern variables x0 .. x(n-1) in scope; a loose
     index Var i at depth d prints as x(n-(i-d)-1).
   - Abstractions bind b0, b1, ... by depth. *)
fun print_term arity_of n =
  let
    fun str x = string_of_int x
    fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^ ")" else s

    fun print_apps d f [] = f
      | print_apps d f (a :: args) =
          print_apps d ("app " ^ protect_blank f ^ " " ^ protect_blank (print_term d a)) args
    and print_call d (App (a, b)) args = print_call d a (b :: args)
      | print_call d (Const c) args =
          (case arity_of c of
            NONE => print_apps d ("Const " ^ str c) args
          | SOME a =>
              let
                val len = length args
              in
                if a <= len then
                  let
                    val s = "c" ^ str c ^ concat (map (fn t => " " ^ protect_blank (print_term d t)) (List.take (args, a)))
                  in
                    print_apps d s (List.drop (args, a))
                  end
                else
                  let
                    (* under-applied constant: eta-expand up to its full arity *)
                    fun mk_apps n t = if n = 0 then t else mk_apps (n - 1) (App (t, Var (n - 1)))
                    fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n - 1) (Abs t)
                    fun append_args [] t = t
                      | append_args (c :: cs) t = append_args cs (App (t, c))
                  in
                    print_term d (mk_lambdas (a - len) (mk_apps (a - len) (nlift 0 (a - len) (append_args args (Const c)))))
                  end
              end)
      | print_call d t args = print_apps d (print_term d t) args
    and print_term d (Var x) = if x < d then "b" ^ str (d - x - 1) else "x" ^ str (n - (x - d) - 1)
      | print_term d (Abs c) = "Abs (\\b" ^ str d ^ " -> " ^ print_term (d + 1) c ^ ")"
      | print_term d t = print_call d t []
  in
    print_term 0
  end

(* Render one rule as a Haskell equation "pattern = rhs".
   The head constant becomes function c<c>; nested constants become
   constructors C<c>; pattern variables are numbered x0, x1, ... from left to
   right. *)
fun print_rule arity_of (p, t) =
  let
    fun str x = Int.toString x
    fun head top c = (if top then "c" else "C") ^ str c
    fun print_pattern top n PVar = (n + 1, "x" ^ str n)
      | print_pattern top n (PConst (c, [])) = (n, head top c)
      | print_pattern top n (PConst (c, args)) =
          let
            val (n, s) = print_pattern_list (n, head top c) args
          in
            (n, if top then s else "(" ^ s ^ ")")
          end
    and print_pattern_list r [] = r
      | print_pattern_list (n, p) (t :: ts) =
          let
            val (n, t) = print_pattern false n t
          in
            print_pattern_list (n, p ^ " " ^ t) ts
          end
    val (n, pattern) = print_pattern true 0 p
  in
    pattern ^ " = " ^ print_term arity_of n t
  end

(* Group rules by head constant. fold_rev prepends from the last rule
   backwards, so each group keeps the rules' original relative order. *)
fun group_rules rules =
  let
    fun add_rule (r as (PConst (c, _), _)) groups =
          let
            val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
          in
            Inttab.update (c, r :: rs) groups
          end
      | add_rule _ _ = raise Compile "internal error group_rules"
  in
    fold_rev add_rule rules Inttab.empty
  end

(* Generate the Haskell module `name` for the given rules.
   Returns (arity table, module source text). *)
fun haskell_prog name rules =
  let
    (* Chunks are collected in reverse and joined once at the end. Repeated
       string concatenation onto one buffer would be quadratic. *)
    val chunks = ref ([] : string list)
    fun write s = chunks := s :: !chunks
    fun writeln s = (write s; write "\n")
    fun writelist ss = List.app writeln ss
    fun str i = Int.toString i
    val (arity, rules) = adjust_rules rules
    val rules = group_rules rules
    val constants = Inttab.keys arity
    fun arity_of c = Inttab.lookup arity c
    fun rep_str s n = concat (rep n s)
    fun indexed s n = s ^ str n
    (* [0, 1, ..., n-1] *)
    fun section n = List.tabulate (n, fn i => i)
    fun make_show c =
      let
        val args = section (the (arity_of c))
      in
        "  show (" ^ indexed "C" c ^ concat (map (indexed " a") args) ^ ") = "
        ^ "\"" ^ indexed "C" c ^ "\"" ^ concat (map (fn a => "++(show " ^ indexed "a" a ^ ")") args)
      end
    fun default_case c =
      let
        val args = concat (map (indexed " x") (section (the (arity_of c))))
      in
        indexed "c" c ^ args ^ " = " ^ indexed "C" c ^ args
      end
    val _ = writelist [
      "module " ^ name ^ " where",
      "",
      "data Term = Const Integer | App Term Term | Abs (Term -> Term)",
      " " ^ concat (map (fn c => " | C" ^ str c ^ rep_str " Term" (the (arity_of c))) constants),
      "",
      "instance Show Term where"]
    val _ = writelist (map make_show constants)
    (* These bindings must use the same indentation as make_show's lines.
       Otherwise Haskell's layout rule ends the instance's where-block early. *)
    val _ = writelist [
      "  show (Const c) = \"c\"++(show c)",
      "  show (App a b) = \"A\"++(show a)++(show b)",
      "  show (Abs _) = \"L\"",
      ""]
    val _ = writelist [
      "app (Abs a) b = a b",
      "app a b = App a b",
      "",
      "calc s c = writeFile s (show c)",
      ""]
    (* Emit the equations for constant c.
       - No rules: only the default case, which builds the constructor C<c>.
       - A rule with an argument-less pattern: it must be the constant's only
         rule, and no default case is added.
       - Otherwise: the rules followed by the default case. *)
    fun list_group c =
      writelist
        (case Inttab.lookup rules c of
          NONE => [default_case c, ""]
        | SOME (rs as ((PConst (_, []), _) :: rs')) =>
            if not (null rs') then raise Compile "multiple declaration of constant"
            else map (print_rule arity_of) rs @ [""]
        | SOME rs => map (print_rule arity_of) rs @ [default_case c, ""])
    val _ = List.app list_group constants
  in
    (arity, String.concat (rev (!chunks)))
  end

val guid_counter = ref 0

(* Identifier for generated modules/files: current time in microseconds
   followed by a per-process running counter. *)
fun get_guid () =
  let
    val c = !guid_counter
    val _ = guid_counter := c + 1
  in
    LargeInt.toString (Time.toMicroseconds (Time.now ())) ^ string_of_int c
  end

(* Absolute path for file name s inside the directory given by File.tmp_path. *)
fun tmp_file s = Path.implode (Path.expand (File.tmp_path (Path.make [s])));

(* Enclose s in double quotes, e.g. to pass a path as one shell word. *)
fun wrap s = "\"" ^ s ^ "\""

fun writeTextFile name s = File.write (Path.explode name) s

(* GHC command; taken from the GHC_PATH environment variable when it is set. *)
val ghc = ref (case getenv "GHC_PATH" of "" => "ghc" | s => s)

fun fileExists name = ((OS.FileSys.fileSize name; true) handle OS.SysErr _ => false)

(* Translate the equations to a Haskell module and compile it with `ghc -c`.
   Raises Compile if any equation has guards, or if no object file appears. *)
fun compile cache_patterns const_arity eqs =
  let
    val _ = if exists (fn (a, b, c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else ()
    val eqs = map (fn (a, b, c) => (b, c)) eqs
    val guid = get_guid ()
    val module = "AMGHC_Prog_" ^ guid
    val (arity, source) = haskell_prog module eqs
    val module_file = tmp_file (module ^ ".hs")
    val object_file = tmp_file (module ^ ".o")
    val _ = writeTextFile module_file source
    (* quote the path so a temporary directory containing blanks still works *)
    val _ = system ((!ghc) ^ " -c " ^ wrap module_file)
    val _ = if not (fileExists object_file) then raise Compile ("Failure compiling haskell code (GHC_PATH = '" ^ (!ghc) ^ "')") else ()
  in
    (guid, module_file, arity)
  end

fun readResultFile name = File.read (Path.explode name)

(* Parse the text written by the generated Show instance:
   - "C<n>" followed by arity-many encoded arguments (constructor term);
   - "c<n>" for a bare Const;
   - "A" followed by two encodings for App;
   - "L" for an abstraction, which is rejected.
   Trailing spaces/newlines are allowed; anything else raises Run. *)
fun parse_result arity_of result =
  let
    val result = String.explode result
    fun shift NONE x = SOME x
      | shift (SOME y) x = SOME (y * 10 + x)
    (* read a possibly empty run of decimal digits *)
    fun parse_int' x (ch :: rest) =
          if Char.isDigit ch then parse_int' (shift x (Char.ord ch - Char.ord #"0")) rest
          else (x, ch :: rest)
      | parse_int' x [] = (x, [])
    fun parse_int rest = parse_int' NONE rest

    fun parse (#"C" :: rest) =
          (case parse_int rest of
            (SOME c, rest) =>
              let
                val (args, rest) = parse_list (the (arity_of c)) rest
                fun app_args [] t = t
                  | app_args (x :: xs) t = app_args xs (App (t, x))
              in
                (app_args args (Const c), rest)
              end
          | (NONE, _) => raise Run "parse C")
      | parse (#"c" :: rest) =
          (case parse_int rest of
            (SOME c, rest) => (Const c, rest)
          | _ => raise Run "parse c")
      | parse (#"A" :: rest) =
          let
            val (a, rest) = parse rest
            val (b, rest) = parse rest
          in
            (App (a, b), rest)
          end
      | parse (#"L" :: rest) = raise Run "there may be no abstraction in the result"
      | parse _ = raise Run "invalid result"
    and parse_list n rest =
          if n = 0 then
            ([], rest)
          else
            let
              val (x, rest) = parse rest
              val (xs, rest) = parse_list (n - 1) rest
            in
              (x :: xs, rest)
            end
    val (parsed, rest) = parse result
    fun is_blank (#" " :: rest) = is_blank rest
      | is_blank (#"\n" :: rest) = is_blank rest
      | is_blank [] = true
      | is_blank _ = false
  in
    if is_blank rest then parsed else raise Run "non-blank suffix in result file"
  end

(* Evaluate closed term t with the compiled program.
   Writes a call module, runs it via `ghc -e`, and parses the result file. *)
fun run (guid, module_file, arity) t =
  let
    val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
    fun arity_of c = Inttab.lookup arity c
    val callguid = get_guid ()
    val module = "AMGHC_Prog_" ^ guid
    val call = module ^ "_Call_" ^ callguid
    val result_file = tmp_file (module ^ "_Result_" ^ callguid ^ ".txt")
    val call_file = tmp_file (call ^ ".hs")
    val term = print_term arity_of 0 t
    val call_source = "module " ^ call ^ " where\n\nimport " ^ module ^ "\n\ncall = " ^ module ^ ".calc \"" ^ result_file ^ "\" (" ^ term ^ ")"
    val _ = writeTextFile call_file call_source
    val _ = system ((!ghc) ^ " -e \"" ^ call ^ ".call\" " ^ wrap module_file ^ " " ^ wrap call_file)
    (* Read first, then delete the temporary files before interpreting the
       outcome. A missing result or a parse error then no longer leaks them. *)
    val result = SOME (readResultFile result_file) handle IO.Io _ => NONE
    fun remove_quietly file = OS.FileSys.remove file handle OS.SysErr _ => ()
    val _ = remove_quietly call_file
    val _ = remove_quietly result_file
  in
    (case result of
      NONE => raise Run ("Failure running haskell compiler (GHC_PATH = '" ^ (!ghc) ^ "')")
    | SOME text => parse_result arity_of text)
  end

fun discard _ = ()

end
| 326 |