(*  Title:      Tools/Compute_Oracle/am_ghc.ML
    Author:     Steven Obua
*)
|
4 |
|
(* Abstract machine backend that translates rewrite rules into a Haskell
   module, compiles it with GHC ("ghc -c"), and evaluates closed terms by
   running a small generated call module through "ghc -e".  The generated code
   writes the result term to a text file, which is parsed back here. *)
structure AM_GHC : ABSTRACT_MACHINE = struct

open AbstractMachine;

(* A compiled program: the GUID naming the generated Haskell module, the path
   of its source file, and the arity table (constant code -> maximal arity). *)
type program = string * string * (int Inttab.table)

(* Number of pattern variables (PVar leaves) in a pattern. *)
fun count_patternvars PVar = 1
  | count_patternvars (PConst (_, ps)) =
      List.foldl (fn (p, count) => count_patternvars p + count) 0 ps

(* Record arity a for constant code, keeping the largest arity seen so far. *)
fun update_arity arity code a =
  (case Inttab.lookup arity code of
    NONE => Inttab.update_new (code, a) arity
  | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)

(* We have to find out the maximal arity of each constant *)
fun collect_pattern_arity PVar arity = arity
  | collect_pattern_arity (PConst (c, args)) arity =
      fold collect_pattern_arity args (update_arity arity c (length args))

local
  (* applevel = number of arguments the current subterm is applied to *)
  fun collect applevel (Var _) arity = arity
    | collect applevel (Const c) arity = update_arity arity c applevel
    | collect applevel (Abs m) arity = collect 0 m arity
    | collect applevel (App (a, b)) arity = collect 0 b (collect (applevel + 1) a arity)
in
  fun collect_term_arity t arity = collect 0 t arity
end

(* Add n to every de Bruijn index that is >= level (i.e. free at this depth). *)
fun nlift level n (Var m) = if m < level then Var m else Var (m + n)
  | nlift level n (Const c) = Const c
  | nlift level n (App (a, b)) = App (nlift level n a, nlift level n b)
  | nlift level n (Abs b) = Abs (nlift (level + 1) n b)

(* List of n copies of x. *)
fun rep n x = if n = 0 then [] else x :: rep (n - 1) x

(* Compute the maximal arity of every constant, then pad each rule whose head
   constant has fewer arguments than its maximal arity with fresh pattern
   variables, applying the right-hand side to them.
   Raises Compile if a rule's pattern is a variable, if the right-hand side has
   unbound variables, or if a constant inside a pattern (at any depth) does not
   carry its maximal arity. *)
fun adjust_rules rules =
  let
    val arity =
      fold (fn (p, t) => fn arity => collect_term_arity t (collect_pattern_arity p arity))
        rules Inttab.empty
    fun arity_of c = the (Inttab.lookup arity c)
    (* Pattern constants are printed as fully applied Haskell data constructors,
       so every one of them, nested or not, must have its maximal arity. *)
    fun check_pattern PVar = ()
      | check_pattern (PConst (c, args)) =
          if length args <> arity_of c then
            raise Compile ("Constant inside pattern must have maximal arity")
          else List.app check_pattern args
    (* Apply t to the n fresh variables Var (n-1), ..., Var 0 (in that order). *)
    fun adjust_term n t = if n = 0 then t else adjust_term (n - 1) (App (t, Var (n - 1)))
    fun adjust_rule (PVar, _) = raise Compile ("pattern may not be a variable")
      | adjust_rule (rule as (p as PConst (c, args), t)) =
          let
            val _ =
              if not (check_freevars (count_patternvars p) t) then
                raise Compile ("unbound variables on right hand side")
              else ()
            val _ = List.app check_pattern args
            val len = length args
            val arity = arity_of c
          in
            if len = arity then
              rule
            else if arity >= len then
              (* the new pattern variables become the innermost indices, so the
                 existing free variables of t are shifted past them *)
              (PConst (c, args @ rep (arity - len) PVar),
               adjust_term (arity - len) (nlift 0 (arity - len) t))
            else raise Compile "internal error in adjust_rule"
          end
  in
    (arity, map adjust_rule rules)
  end

(* Render a term as a Haskell expression.  n is the number of pattern
   variables in scope: indices below the current binder depth d print as bound
   lambda variables b<k>, the others as pattern variables x<k>.  A constant
   with known arity a applied to at least a arguments becomes a call of the
   generated function c<code> (extra arguments via "app"); with fewer
   arguments it is eta-expanded first.  Constants without an arity entry
   become "Const <code>". *)
fun print_term arity_of n =
  let
    fun str x = string_of_int x
    fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^ ")" else s

    fun print_apps d f [] = f
      | print_apps d f (a :: args) =
          print_apps d ("app " ^ protect_blank f ^ " " ^ protect_blank (print_term d a)) args
    and print_call d (App (a, b)) args = print_call d a (b :: args)
      | print_call d (Const c) args =
          (case arity_of c of
            NONE => print_apps d ("Const " ^ str c) args
          | SOME a =>
              let
                val len = length args
              in
                if a <= len then
                  let
                    val s = "c" ^ str c ^
                      implode (map (fn t => " " ^ protect_blank (print_term d t)) (List.take (args, a)))
                  in
                    print_apps d s (List.drop (args, a))
                  end
                else
                  let
                    fun mk_apps n t = if n = 0 then t else mk_apps (n - 1) (App (t, Var (n - 1)))
                    fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n - 1) (Abs t)
                    fun append_args [] t = t
                      | append_args (c :: cs) t = append_args cs (App (t, c))
                  in
                    print_term d (mk_lambdas (a - len) (mk_apps (a - len)
                      (nlift 0 (a - len) (append_args args (Const c)))))
                  end
              end)
      | print_call d t args = print_apps d (print_term d t) args
    and print_term d (Var x) = if x < d then "b" ^ str (d - x - 1) else "x" ^ str (n - (x - d) - 1)
      | print_term d (Abs c) = "Abs (\\b" ^ str d ^ " -> " ^ print_term (d + 1) c ^ ")"
      | print_term d t = print_call d t []
  in
    print_term 0
  end

(* Render a rule as a Haskell equation.  The head constant is the function
   c<code>; nested constants are constructors C<code>; pattern variables are
   named x0, x1, ... from left to right. *)
fun print_rule arity_of (p, t) =
  let
    fun str x = Int.toString x
    fun print_pattern top n PVar = (n + 1, "x" ^ str n)
      | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C") ^ str c)
      | print_pattern top n (PConst (c, args)) =
          let
            val (n, s) = print_pattern_list (n, (if top then "c" else "C") ^ str c) args
          in
            (n, if top then s else "(" ^ s ^ ")")
          end
    and print_pattern_list r [] = r
      | print_pattern_list (n, p) (t :: ts) =
          let
            val (n, t) = print_pattern false n t
          in
            print_pattern_list (n, p ^ " " ^ t) ts
          end
    val (n, pattern) = print_pattern true 0 p
  in
    pattern ^ " = " ^ print_term arity_of n t
  end

(* Group rules by head constant, preserving the original order inside each group. *)
fun group_rules rules =
  let
    fun add_rule (r as (PConst (c, _), _)) groups =
          let
            val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
          in
            Inttab.update (c, r :: rs) groups
          end
      | add_rule _ _ = raise Compile "internal error group_rules"
  in
    fold_rev add_rule rules Inttab.empty
  end

(* Generate the Haskell module called name for the given rules.  Returns the
   arity table together with the module source text. *)
fun haskell_prog name rules =
  let
    (* Output is accumulated as a reversed list of chunks and concatenated once
       at the end, avoiding quadratic repeated string appends. *)
    val chunks = Unsynchronized.ref ([] : string list)
    fun write s = chunks := s :: !chunks
    fun writeln s = (write s; write "\n")
    fun writelist ss = List.app writeln ss
    fun str i = Int.toString i
    val (arity, rules) = adjust_rules rules
    val rules = group_rules rules
    val constants = Inttab.keys arity
    fun arity_of c = Inttab.lookup arity c
    fun rep_str s n = implode (rep n s)
    fun indexed s n = s ^ str n
    (* [0, 1, ..., n-1] *)
    fun section n = if n = 0 then [] else section (n - 1) @ [n - 1]
    fun make_show c =
      let
        val args = section (the (arity_of c))
      in
        " show (" ^ indexed "C" c ^ implode (map (indexed " a") args) ^ ") = "
        ^ "\"" ^ indexed "C" c ^ "\"" ^ implode (map (fn a => "++(show " ^ indexed "a" a ^ ")") args)
      end
    (* Fallback equation: an irreducible call c<code> x0 ... becomes the
       constructor C<code> x0 ... *)
    fun default_case c =
      let
        val args = implode (map (indexed " x") (section (the (arity_of c))))
      in
        indexed "c" c ^ args ^ " = " ^ indexed "C" c ^ args
      end
    val _ = writelist [
      "module " ^ name ^ " where",
      "",
      "data Term = Const Integer | App Term Term | Abs (Term -> Term)",
      " " ^ implode (map (fn c => " | C" ^ str c ^ rep_str " Term" (the (arity_of c))) constants),
      "",
      "instance Show Term where"]
    val _ = writelist (map make_show constants)
    val _ = writelist [
      " show (Const c) = \"c\"++(show c)",
      " show (App a b) = \"A\"++(show a)++(show b)",
      " show (Abs _) = \"L\"",
      ""]
    val _ = writelist [
      "app (Abs a) b = a b",
      "app a b = App a b",
      "",
      "calc s c = writeFile s (show c)",
      ""]
    fun list_group c =
      writelist
        (case Inttab.lookup rules c of
          NONE => [default_case c, ""]
        | SOME (rs as ((PConst (_, []), _) :: rs')) =>
            if not (null rs') then raise Compile "multiple declaration of constant"
            else map (print_rule arity_of) rs @ [""]
        | SOME rs => map (print_rule arity_of) rs @ [default_case c, ""])
    val _ = List.app list_group constants
  in
    (arity, String.concat (rev (!chunks)))
  end

val guid_counter = Unsynchronized.ref 0

(* Fresh identifier: current time in microseconds followed by a call counter. *)
fun get_guid () =
  let
    val c = !guid_counter
    val _ = guid_counter := !guid_counter + 1
  in
    LargeInt.toString (Time.toMicroseconds (Time.now ())) ^ string_of_int c
  end

fun tmp_file s = Path.implode (Path.expand (File.tmp_path (Path.make [s])));

(* Double-quote s for the shell command line. *)
fun wrap s = "\"" ^ s ^ "\""

(* Haskell string literal for s, escaping backslashes and double quotes. *)
fun haskell_string s =
  "\"" ^ String.translate (fn #"\\" => "\\\\" | #"\"" => "\\\"" | c => String.str c) s ^ "\""

fun writeTextFile name s = File.write (Path.explode name) s

(* GHC executable, taken from the GHC_PATH environment variable when set. *)
val ghc = Unsynchronized.ref (case getenv "GHC_PATH" of "" => "ghc" | s => s)

fun fileExists name = ((OS.FileSys.fileSize name; true) handle OS.SysErr _ => false)

(* Translate the equations to Haskell and compile them with "ghc -c".
   Guarded equations are rejected with Compile; a missing object file after
   the GHC run is reported as Compile as well. *)
fun compile cache_patterns const_arity eqs =
  let
    val _ =
      if exists (fn (guards, _, _) => not (null guards)) eqs then
        raise Compile ("cannot deal with guards")
      else ()
    val eqs = map (fn (_, p, t) => (p, t)) eqs
    val guid = get_guid ()
    val module = "AMGHC_Prog_" ^ guid
    val (arity, source) = haskell_prog module eqs
    val module_file = tmp_file (module ^ ".hs")
    val object_file = tmp_file (module ^ ".o")
    val _ = writeTextFile module_file source
    (* quote the path so a temporary directory containing blanks survives the shell *)
    val _ = bash ((!ghc) ^ " -c " ^ wrap module_file)
    val _ =
      if not (fileExists object_file) then
        raise Compile ("Failure compiling haskell code (GHC_PATH = '" ^ (!ghc) ^ "')")
      else ()
  in
    (guid, module_file, arity)
  end

fun readResultFile name = File.read (Path.explode name)

(* Parse the output of the generated Show instance:
     C<code> t1 ... tk   constructor applied to its k = arity subterms
     c<code>             constant without arity entry
     A t1 t2             application
     L                   abstraction (rejected)
   Trailing blanks and newlines are accepted; anything else raises Run. *)
fun parse_result arity_of result =
  let
    val result = String.explode result
    fun shift NONE x = SOME x
      | shift (SOME y) x = SOME (y * 10 + x)
    fun parse_int' x (ch :: rest) =
          if Char.isDigit ch then parse_int' (shift x (Char.ord ch - Char.ord #"0")) rest
          else (x, ch :: rest)
      | parse_int' x [] = (x, [])
    fun parse_int rest = parse_int' NONE rest

    fun parse (#"C" :: rest) =
          (case parse_int rest of
            (SOME c, rest) =>
              let
                val (args, rest) = parse_list (the (arity_of c)) rest
                fun app_args [] t = t
                  | app_args (x :: xs) t = app_args xs (App (t, x))
              in
                (app_args args (Const c), rest)
              end
          | (NONE, rest) => raise Run "parse C")
      | parse (#"c" :: rest) =
          (case parse_int rest of
            (SOME c, rest) => (Const c, rest)
          | _ => raise Run "parse c")
      | parse (#"A" :: rest) =
          let
            val (a, rest) = parse rest
            val (b, rest) = parse rest
          in
            (App (a, b), rest)
          end
      | parse (#"L" :: rest) = raise Run "there may be no abstraction in the result"
      | parse _ = raise Run "invalid result"
    and parse_list n rest =
          if n = 0 then
            ([], rest)
          else
            let
              val (x, rest) = parse rest
              val (xs, rest) = parse_list (n - 1) rest
            in
              (x :: xs, rest)
            end
    val (parsed, rest) = parse result
    fun is_blank (#" " :: rest) = is_blank rest
      | is_blank (#"\n" :: rest) = is_blank rest
      | is_blank [] = true
      | is_blank _ = false
  in
    if is_blank rest then parsed else raise Run "non-blank suffix in result file"
  end

(* Evaluate the closed term t with a compiled program: write a call module,
   run it via "ghc -e", and parse the result file.  The call and result files
   are removed afterwards, also when evaluation fails. *)
fun run (guid, module_file, arity) t =
  let
    val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
    fun arity_of c = Inttab.lookup arity c
    val callguid = get_guid ()
    val module = "AMGHC_Prog_" ^ guid
    val call = module ^ "_Call_" ^ callguid
    val result_file = tmp_file (module ^ "_Result_" ^ callguid ^ ".txt")
    val call_file = tmp_file (call ^ ".hs")
    val term = print_term arity_of 0 t
    val call_source =
      "module " ^ call ^ " where\n\nimport " ^ module ^ "\n\ncall = " ^ module ^
      ".calc " ^ haskell_string result_file ^ " (" ^ term ^ ")"
    (* best-effort removal: the result file does not exist if GHC failed *)
    fun remove_quietly file = OS.FileSys.remove file handle OS.SysErr _ => ()
    fun cleanup () = (remove_quietly call_file; remove_quietly result_file)
    val _ = writeTextFile call_file call_source
    val t' =
      (let
        val _ =
          bash ((!ghc) ^ " -e \"" ^ call ^ ".call\" " ^ wrap module_file ^ " " ^ wrap call_file)
        val result =
          readResultFile result_file
            handle IO.Io _ =>
              raise Run ("Failure running haskell compiler (GHC_PATH = '" ^ (!ghc) ^ "')")
      in
        parse_result arity_of result
      end) handle exn => (cleanup (); raise exn)
    val _ = cleanup ()
  in
    t'
  end

fun discard _ = ()

end
|
325 |
|