author | wenzelm |
Fri, 20 Feb 2009 21:01:52 +0100 | |
changeset 29773 | cbaee647ea29 |
parent 28268 | ac8431ecd57e |
child 30161 | c26e515f1c29 |
permissions | -rw-r--r-- |
24584 | 1 |
(* Title: Tools/Compute_Oracle/am_sml.ML |
23663 | 2 |
ID: $Id$ |
3 |
Author: Steven Obua |
|
4 |
||
5 |
ToDO: "parameterless rewrite cannot be used in pattern": In a lot of cases it CAN be used, and these cases should be handled properly; |
|
6 |
right now, all cases throw an exception. |
|
7 |
||
8 |
*) |
|
9 |
||
10 |
(* Abstract machine backend that translates rewrite rules into SML source,
   loads that source, and evaluates terms with the compiled code. *)
signature AM_SML =
sig
  include ABSTRACT_MACHINE
  (* When SOME path, compile writes the generated SML source to that file. *)
  val dump_output : (string option) ref
  (* Called by generated code: stores a (link code, result term) pair. *)
  val save_result : (string * term) -> unit
  (* Called by generated code: registers its evaluation function. *)
  val set_compiled_rewriter : (term -> term) -> unit
  (* List.nth, used by generated code to look up bound variables. *)
  val list_nth : 'a list * int -> 'a
end
18 |
||
19 |
structure AM_SML : AM_SML = struct |
|
20 |
||
21 |
open AbstractMachine; |
|
22 |
||
25548 | 23 |
(* Target file for the generated SML source; NONE disables dumping. *)
val dump_output : string option ref = ref NONE

(* A compiled program: module name, link code, arity table, toplevel arity
   table, inlined parameterless rewrites, and the compiled evaluation function. *)
type program = string * string * (int Inttab.table) * (int Inttab.table) * (term Inttab.table) * (term -> term)

(* Result slot written by the generated export function. *)
val saved_result : (string * term) option ref = ref NONE

fun save_result r = saved_result := SOME r
fun clear_result () = saved_result := NONE

fun list_nth (xs, i) = List.nth (xs, i)

(*fun list_nth (l,n) = (writeln (makestring ("list_nth", (length l,n))); List.nth (l,n))*)

(* Evaluation function registered by the most recently loaded generated module. *)
val compiled_rewriter : (term -> term) option ref = ref NONE

fun set_compiled_rewriter r = compiled_rewriter := SOME r
|
39 |
||
40 |
(* Number of PVar leaves in a pattern. *)
fun count_patternvars PVar = 1
  | count_patternvars (PConst (_, subpatterns)) =
      List.foldl (op +) 0 (map count_patternvars subpatterns)
|
43 |
||
44 |
(* Record arity a for constant code, keeping the largest arity seen so far.
   The int constraint on the old value is kept for SML/NJ. *)
fun update_arity arity code a =
  case Inttab.lookup arity code of
    SOME (a': int) => if a <= a' then arity else Inttab.update (code, a) arity
  | NONE => Inttab.update_new (code, a) arity
23663 | 48 |
|
49 |
(* We have to find out the maximal arity of each constant: every constant in a
   pattern contributes the number of arguments it is applied to there. *)
fun collect_pattern_arity PVar arity = arity
  | collect_pattern_arity (PConst (c, args)) arity =
      let
        val arity' = update_arity arity c (length args)
      in
        fold collect_pattern_arity args arity'
      end
|
52 |
||
53 |
(* We also need to find out the maximal toplevel arity of each function
   constant: only the head constant of a rule's pattern is recorded. *)
fun collect_pattern_toplevel_arity (PConst (c, args)) arity = update_arity arity c (length args)
  | collect_pattern_toplevel_arity PVar _ = raise Compile "internal error: collect_pattern_toplevel_arity"
|
56 |
||
57 |
local
  (* applevel = number of arguments the current subterm is applied to *)
  fun collect _ (Var _) arity = arity
    | collect applevel (Const c) arity = update_arity arity c applevel
    | collect _ (Abs body) arity = collect 0 body arity
    | collect applevel (App (f, x)) arity =
        let
          val arity' = collect (applevel + 1) f arity
        in
          collect 0 x arity'
        end
in
  (* Record, for every constant in t, the number of arguments it is applied to. *)
  fun collect_term_arity t arity = collect 0 t arity
end
|
65 |
||
66 |
(* Both sides of a guard contribute to the arity table, left side first. *)
fun collect_guard_arity (Guard (lhs, rhs)) arity =
  let
    val arity' = collect_term_arity lhs arity
  in
    collect_term_arity rhs arity'
  end
|
67 |
||
68 |
||
69 |
(* rep n x = [x, ..., x] with n copies; a negative n is an internal error. *)
fun rep n x =
  if n < 0 then raise Compile "internal error: rep"
  else List.tabulate (n, fn _ => x)
|
70 |
||
71 |
(* Beta-normalise a term in de Bruijn notation.
   Computed nodes hold terms that are already evaluated. The generated
   convert_computed rejects Abs and Var inside them, so they are closed. They
   are therefore left untouched by beta, subst, lift and unlift. Previously
   subst, lift and unlift had no Computed case and raised Match when a redex
   argument contained one. *)
fun beta (Const c) = Const c
  | beta (Var i) = Var i
  | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b)))
  | beta (App (a, b)) =
      (case beta a of
        Abs m => beta (App (Abs m, b))
      | a => App (a, beta b))
  | beta (Abs m) = Abs (beta m)
  | beta (Computed t) = Computed t
(* subst x m t: replace Var x in m by t, lifting t when going under a binder *)
and subst x (Const c) t = Const c
  | subst x (Var i) t = if i = x then t else Var i
  | subst x (App (a,b)) t = App (subst x a t, subst x b t)
  | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t))
  | subst _ (Computed c) _ = Computed c
(* lift level t: increment every Var index >= level *)
and lift level (Const c) = Const c
  | lift level (App (a,b)) = App (lift level a, lift level b)
  | lift level (Var i) = if i < level then Var i else Var (i+1)
  | lift level (Abs m) = Abs (lift (level + 1) m)
  | lift _ (Computed t) = Computed t
(* unlift level t: decrement every Var index >= level *)
and unlift level (Const c) = Const c
  | unlift level (App (a, b)) = App (unlift level a, unlift level b)
  | unlift level (Abs m) = Abs (unlift (level+1) m)
  | unlift level (Var i) = if i < level then Var i else Var (i-1)
  | unlift _ (Computed t) = Computed t
|
92 |
||
93 |
(* nlift level n t: add n to every Var index >= level, i.e. shift the
   variables that are free relative to level.  Computed nodes are closed
   (see beta) and are returned unchanged instead of raising Match. *)
fun nlift level n (Var m) = if m < level then Var m else Var (m + n)
  | nlift _ _ (Const c) = Const c
  | nlift level n (App (a, b)) = App (nlift level n a, nlift level n b)
  | nlift level n (Abs b) = Abs (nlift (level + 1) n b)
  | nlift _ _ (Computed t) = Computed t
|
97 |
||
98 |
(* Replace every occurrence of constant c by the term t.  t is inserted
   as-is, without lifting its variables under binders. *)
fun subst_const (c, t) term =
  let
    fun go (Const c') = if c = c' then t else Const c'
      | go (Var i) = Var i
      | go (App (a, b)) = App (go a, go b)
      | go (Abs m) = Abs (go m)
  in
    go term
  end
|
102 |
||
103 |
(* Remove all rules that are just parameterless rewrites. This is necessary because
   SML does not allow functions with no parameters.  Each rewrite c = r found is
   substituted into the remaining rules and into the right-hand sides inlined so
   far.  Returns the table of inlined constants and the remaining rules. *)
fun inline_rules rules =
  let
    fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b
      | term_contains_const c (Abs m) = term_contains_const c m
      | term_contains_const _ (Var _) = false
      | term_contains_const c (Const c') = (c = c')
    (* first rule whose pattern is a bare constant, validated *)
    fun find_rewrite [] = NONE
      | find_rewrite ((prems, PConst (c, []), r) :: _) =
          if not (check_freevars 0 r) then
            raise Compile "unbound variable on right hand side or guards of rule"
          else if term_contains_const c r then
            raise Compile "parameterless rewrite is caught in cycle"
          else if null prems then
            SOME (c, r)
          else
            raise Compile "parameterless rewrite may not be guarded"
      | find_rewrite (_ :: rest) = find_rewrite rest
    (* drop every rule for c; any rule for c other than an identical copy is an error *)
    fun remove_rewrite _ [] = []
      | remove_rewrite (cr as (c, r)) ((rule as (prems', PConst (c', args), r')) :: rest) =
          if c <> c' then
            rule :: remove_rewrite cr rest
          else if null args andalso r = r' andalso null prems' then
            remove_rewrite cr rest
          else
            raise Compile "incompatible parameterless rewrites found"
      | remove_rewrite cr (rule :: rest) = rule :: remove_rewrite cr rest
    fun pattern_contains_const _ PVar = false
      | pattern_contains_const c (PConst (c', args)) =
          c = c' orelse exists (pattern_contains_const c) args
    fun inline_rewrite (ct as (c, _)) (prems, p, r) =
      if pattern_contains_const c p then
        raise Compile "parameterless rewrite cannot be used in pattern"
      else
        let
          fun subst_guard (Guard (a, b)) = Guard (subst_const ct a, subst_const ct b)
        in
          (map subst_guard prems, p, subst_const ct r)
        end
    fun inline inlined rules =
      (case find_rewrite rules of
        NONE => (Inttab.make inlined, rules)
      | SOME ct =>
          let
            val rules' = map (inline_rewrite ct) (remove_rewrite ct rules)
            val inlined' = ct :: map (fn (c', r) => (c', subst_const ct r)) inlined
          in
            inline inlined' rules'
          end)
  in
    inline [] rules
  end
|
149 |
||
150 |
||
151 |
(*
  Calculate the arity, the toplevel_arity, and adjust rules so that all toplevel pattern constants have maximal arity.
  Also beta reduce the adjusted right hand side of a rule.
  Returns (arity table, toplevel arity table, adjusted rules).
*)
fun adjust_rules rules =
  let
    val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty
    val toplevel_arity = fold (fn (_, p, _) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty
    fun arity_of c = the (Inttab.lookup arity c)
    (* constants nested inside a pattern are not eta-expanded, so they must
       already occur with their maximal arity *)
    fun test_pattern PVar = ()
      | test_pattern (PConst (c, args)) =
          if length args <> arity_of c then raise Compile ("Constant inside pattern must have maximal arity")
          else List.app test_pattern args
    fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable")
      | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters")
      | adjust_rule (rule as (prems, p as PConst (c, args), t)) =
          let
            val patternvars_counted = count_patternvars p
            fun check_fv t = check_freevars patternvars_counted t
            val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else ()
            val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else ()
            val _ = List.app test_pattern args
            (* number of extra pattern variables needed to reach maximal arity *)
            val missing = arity_of c - length args
            (* eta-expansion: shift free variables by n, then apply to the new variables *)
            fun addapps_tm n t = if n = 0 then t else addapps_tm (n-1) (App (t, Var (n-1)))
            fun adjust_term n t = addapps_tm n (nlift 0 n t)
            fun adjust_guard n (Guard (a,b)) = Guard (nlift 0 n a, nlift 0 n b)
          in
            if missing = 0 then
              rule
            else if missing > 0 then
              (map (adjust_guard missing) prems, PConst (c, args @ (rep missing PVar)), adjust_term missing t)
            else
              raise Compile "internal error in adjust_rule"
          end
    fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule")
  in
    (arity, toplevel_arity, map (beta_rule o adjust_rule) rules)
  end
|
189 |
||
190 |
(* Print a term as SML source for the generated module.
   module: optional structure name used to qualify Const, C<c>, c<c>, app and Abs.
   arity_of / toplevel_arity_of: arity lookups; arguments beyond the toplevel
     arity are passed as thunks (fn () => ...).
   pattern_var_count / pattern_lazy_var_count: Var indices beyond the local
     binders denote pattern variables x0..x(count-1), counted in reverse; the
     last pattern_lazy_var_count of them are thunks and are printed as (x ()).
   Fixes: an arity-0 constant applied to arguments no longer drops those
   arguments, and Computed nodes raise Compile instead of recursing forever. *)
fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count =
  let
    fun str x = string_of_int x
    fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^ ")" else s
    val module_prefix = (case module of NONE => "" | SOME s => s ^ ".")
    (* apply f to the remaining arguments one at a time via the generated app *)
    fun print_apps d f [] = f
      | print_apps d f (a :: args) =
          print_apps d (module_prefix ^ "app " ^ protect_blank f ^ " " ^ protect_blank (print_term d a)) args
    (* collect the argument spine of an application, then print its head *)
    and print_call d (App (a, b)) args = print_call d a (b :: args)
      | print_call d (Const c) args =
          (case arity_of c of
            NONE => print_apps d (module_prefix ^ "Const " ^ str c) args
          | SOME 0 => print_apps d (module_prefix ^ "C" ^ str c) args
          | SOME a =>
              let
                val len = length args
              in
                if a <= len then
                  let
                    val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a)
                    val _ = if strict_a > a then raise Compile "strict" else ()
                    val s = module_prefix ^ "c" ^ str c ^ concat (map (fn t => " " ^ protect_blank (print_term d t)) (List.take (args, strict_a)))
                    val s = s ^ concat (map (fn t => " (fn () => " ^ print_term d t ^ ")") (List.drop (List.take (args, a), strict_a)))
                  in
                    print_apps d s (List.drop (args, a))
                  end
                else
                  (* too few arguments: eta-expand up to the full arity *)
                  let
                    fun mk_apps n t = if n = 0 then t else mk_apps (n - 1) (App (t, Var (n - 1)))
                    fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n - 1) (Abs t)
                    fun append_args [] t = t
                      | append_args (x :: xs) t = append_args xs (App (t, x))
                  in
                    print_term d (mk_lambdas (a - len) (mk_apps (a - len) (nlift 0 (a - len) (append_args args (Const c)))))
                  end
              end)
      | print_call d t args = print_apps d (print_term d t) args
    and print_term d (Var x) =
          if x < d then
            "b" ^ str (d - x - 1)
          else
            let
              val n = pattern_var_count - (x - d) - 1
              val x = "x" ^ str n
            in
              if n < pattern_var_count - pattern_lazy_var_count then x
              else "(" ^ x ^ " ())"
            end
      | print_term d (Abs c) = module_prefix ^ "Abs (fn b" ^ str d ^ " => " ^ print_term (d + 1) c ^ ")"
      | print_term _ (Computed _) = raise Compile "print_term: Computed terms cannot be printed"
      | print_term d t = print_call d t []
  in
    print_term 0
  end
|
244 |
||
245 |
(* section n = [0, 1, ..., n-1].  Built in linear time; the previous version
   appended each element at the end of the list, which is quadratic.
   Raises Size for negative n (the old version never terminated). *)
fun section n = List.tabulate (n, fn i => i)
|
246 |
||
247 |
(* Print one rule as a clause "pattern = body" of the generated function
   c<c> (or c<c>_<gnum> when gnum > 0).  If the rule has guards, the body is an
   if-expression that falls through to c<c>_<gnum+1>.  Returns the group number
   for the next rule together with the printed clause. *)
fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) =
  let
    fun str x = Int.toString x
    fun const_name top c =
      (if top then "c" else "C") ^ str c ^ (if top andalso gnum > 0 then "_" ^ str gnum else "")
    (* returns the next free pattern-variable index and the printed pattern *)
    fun print_pattern _ n PVar = (n + 1, "x" ^ str n)
      | print_pattern top n (PConst (c, [])) = (n, const_name top c)
      | print_pattern top n (PConst (c, args)) = print_args 0 top (n, const_name top c) args
    (* toplevel arguments are bound as a<i>; nested arguments form a tuple *)
    and print_args _ top (n, acc) [] = if top then (n, acc) else (n, acc ^ ")")
      | print_args i top (n, acc) (q :: qs) =
          let
            val (n', s) = print_pattern false n q
            val acc' =
              if top then acc ^ " (a" ^ str i ^ " as (" ^ s ^ "))"
              else if i = 0 then acc ^ " (" ^ s
              else acc ^ ", " ^ s
          in
            print_args (i + 1) top (n', acc') qs
          end
    val c = (case p of PConst (c, _) => c | _ => raise Match)
    val (n, pattern) = print_pattern true 0 p
    val full_arity = the (arity_of c)
    val lazy_vars = full_arity - the (toplevel_arity_of c)
    fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm
    fun print_guard (Guard (a, b)) = "term_eq (" ^ print_tm a ^ ") (" ^ print_tm b ^ ")"
    val else_branch = "c" ^ str c ^ "_" ^ str (gnum + 1) ^ concat (map (fn i => " a" ^ str i) (section full_arity))
    val body =
      (case guards of
        [] => print_tm t
      | g :: gs =>
          "if (" ^ print_guard g ^ ")" ^ concat (map (fn g => " andalso (" ^ print_guard g ^ ")") gs) ^
          " then (" ^ print_tm t ^ ") else " ^ else_branch)
    val next_gnum = if null guards then gnum else gnum + 1
  in
    (next_gnum, pattern ^ " = " ^ body)
  end
|
283 |
||
284 |
(* Group rules by the head constant of their pattern.  fold_rev visits the
   rules last-to-first while consing, so each group keeps the original order. *)
fun group_rules rules =
  let
    fun add_rule (r as (_, PConst (c, _), _)) groups =
          Inttab.update (c, r :: Option.getOpt (Inttab.lookup groups c, [])) groups
      | add_rule _ _ = raise Compile "internal error group_rules"
  in
    fold_rev add_rule rules Inttab.empty
  end
|
296 |
||
297 |
(* Generate the SML source of a structure `name` implementing the rewrite rules.
   `code` is a link token that the generated export function passes back with
   its result.  Returns (arity table, toplevel arity table, table of inlined
   parameterless rewrites, generated source text). *)
fun sml_prog name code rules =
  let
    (* Output is accumulated as a reversed list of chunks and concatenated once
       at the end.  The previous version appended each chunk to one growing
       string, copying the whole buffer on every write (quadratic). *)
    val chunks = ref ([] : string list)
    fun write s = (chunks := s :: !chunks)
    fun writeln s = (write s; write "\n")
    fun writelist [] = ()
      | writelist (s::ss) = (writeln s; writelist ss)
    fun str i = Int.toString i
    (* preprocessing: inline parameterless rewrites, eta-expand rules to
       maximal arity, then group the rules by head constant *)
    val (inlinetab, rules) = inline_rules rules
    val (arity, toplevel_arity, rules) = adjust_rules rules
    val rules = group_rules rules
    val constants = Inttab.keys arity
    fun arity_of c = Inttab.lookup arity c
    fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
    fun rep_str s n = concat (rep n s)
    fun indexed s n = s^(str n)
    fun string_of_tuple [] = ""
      | string_of_tuple (x::xs) = "("^x^(concat (map (fn s => ", "^s) xs))^")"
    fun string_of_args [] = ""
      | string_of_args (x::xs) = x^(concat (map (fn s => " "^s) xs))
    (* fallback clause of c<c>[_gnum]: build the constructor term, or raise
       AM_SML.Run when lazy arguments remain unresolved *)
    fun default_case gnum c =
      let
        val leftargs = concat (map (indexed " x") (section (the (arity_of c))))
        val rightargs = section (the (arity_of c))
        val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa)
        val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs
        val right = (indexed "C" c)^" "^(string_of_tuple xs)
        val message = "(\"unresolved lazy call: "^(string_of_int c)^", \"^(makestring x"^(string_of_int (strict_args - 1))^"))"
        val right = if strict_args < the (arity_of c) then "raise AM_SML.Run "^message else right
      in
        (indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right
      end
    (* clauses of the generated eval for constant c applied to arity..0 arguments *)
    fun eval_rules c =
      let
        val arity = the (arity_of c)
        val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa)
        fun eval_rule n =
          let
            val sc = string_of_int c
            val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc)
            fun arg i =
              let
                val x = indexed "x" i
                val x = if i < n then "(eval bounds "^x^")" else x
                val x = if i < strict_arity then x else "(fn () => "^x^")"
              in
                x
              end
            val right = "c"^sc^" "^(string_of_args (map arg (section arity)))
            val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right
            val right = if arity > 0 then right else "C"^sc
          in
            " | eval bounds ("^left^") = "^right
          end
      in
        map eval_rule (rev (section (arity + 1)))
      end
    (* clause of convert_computed for a fully applied constant c *)
    fun convert_computed_rules (c: int) : string list =
      let
        val arity = the (arity_of c)
        fun eval_rule () =
          let
            val sc = string_of_int c
            val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section arity) ("AbstractMachine.Const "^sc)
            fun arg i = "(convert_computed "^(indexed "x" i)^")"
            val right = "C"^sc^" "^(string_of_tuple (map arg (section arity)))
            val right = if arity > 0 then right else "C"^sc
          in
            " | convert_computed ("^left^") = "^right
          end
      in
        [eval_rule ()]
      end
    fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else ""
    val _ = writelist [
      "structure "^name^" = struct",
      "",
      "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)",
      " "^(concat (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)),
      ""]
    fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")"
    fun make_term_eq c = " | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^
      (case the (arity_of c) of
        0 => "true"
      | n =>
        let
          val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n)
          val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs))
        in
          eq^(concat eqs)
        end)
    val _ = writelist [
      "fun term_eq (Const c1) (Const c2) = (c1 = c2)",
      " | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"]
    val _ = writelist (map make_term_eq constants)
    val _ = writelist [
      " | term_eq _ _ = false",
      ""
      ]
    val _ = writelist [
      "fun app (Abs a) b = a b",
      " | app a b = App (a, b)",
      ""]
    fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else [])
    fun writefundecl [] = ()
      | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => " | "^s) xs)))
    (* clause groups for constant c: a new group c<c>_<g> starts after each guarded rule *)
    fun list_group c = (case Inttab.lookup rules c of
        NONE => [defcase 0 c]
      | SOME rs =>
        let
          val rs =
            fold
              (fn r =>
                fn rs =>
                  let
                    val (gnum, l, rs) =
                      (case rs of
                        [] => (0, [], [])
                      | (gnum, l)::rs => (gnum, l, rs))
                    val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r
                  in
                    if gnum' = gnum then
                      (gnum, r::l)::rs
                    else
                      let
                        val args = concat (map (fn i => " a"^(str i)) (section (the (arity_of c))))
                        fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args
                        val s = gnumc (gnum) ^ " = " ^ gnumc (gnum')
                      in
                        (gnum', [])::(gnum, s::r::l)::rs
                      end
                  end)
              rs []
          val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs)
        in
          rev (map (fn z => rev (snd z)) rs)
        end)
    val _ = map (fn z => (map writefundecl z; writeln "")) (map list_group constants)
    val _ = writelist [
      "fun convert (Const i) = AM_SML.Const i",
      " | convert (App (a, b)) = AM_SML.App (convert a, convert b)",
      " | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""]
    fun make_convert c =
      let
        val args = map (indexed "a") (section (the (arity_of c)))
        val leftargs =
          (case args of
            [] => ""
          | (x::xs) => "("^x^(concat (map (fn s => ", "^s) xs))^")")
        val args = map (indexed "convert a") (section (the (arity_of c)))
        val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c))
      in
        " | convert (C"^(str c)^" "^leftargs^") = "^right
      end
    val _ = writelist (map make_convert constants)
    val _ = writelist [
      "",
      "fun convert_computed (AbstractMachine.Abs b) = raise AM_SML.Run \"no abstraction in convert_computed allowed\"",
      " | convert_computed (AbstractMachine.Var i) = raise AM_SML.Run \"no bound variables in convert_computed allowed\""]
    val _ = map (writelist o convert_computed_rules) constants
    val _ = writelist [
      " | convert_computed (AbstractMachine.Const c) = Const c",
      " | convert_computed (AbstractMachine.App (a, b)) = App (convert_computed a, convert_computed b)",
      " | convert_computed (AbstractMachine.Computed a) = raise AM_SML.Run \"no nesting in convert_computed allowed\""]
    val _ = writelist [
      "",
      "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)",
      " | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"]
    val _ = map (writelist o eval_rules) constants
    val _ = writelist [
      " | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)",
      " | eval bounds (AbstractMachine.Const c) = Const c",
      " | eval bounds (AbstractMachine.Computed t) = convert_computed t"]
    val _ = writelist [
      "",
      "fun export term = AM_SML.save_result (\""^code^"\", convert term)",
      "",
      "val _ = AM_SML.set_compiled_rewriter (fn t => (convert (eval [] t)))",
      "",
      "end"]
  in
    (arity, toplevel_arity, inlinetab, concat (rev (!chunks)))
  end
|
483 |
||
484 |
(* Identifier for a generated module: current time in microseconds followed by
   a per-session counter. *)
val guid_counter = ref 0

fun get_guid () =
  let
    val n = !guid_counter
  in
    guid_counter := n + 1;
    LargeInt.toString (Time.toMicroseconds (Time.now ())) ^ string_of_int n
  end
|
492 |
||
493 |
||
494 |
(* Write the string s to the file named by the path string name. *)
fun writeTextFile name s =
  let
    val path = Path.explode name
  in
    File.write path s
  end

(* Evaluate the ML source text src via use_text in ML_Context.name_space. *)
fun use_source src =
  use_text ML_Context.name_space (1, "") Output.ml_output false src
23663 | 497 |
|
25520 | 498 |
(* Generate the SML program for the rules eqs, optionally dump it to
   !dump_output, and load it.  cache_patterns and const_arity are not used by
   this implementation.  Raises Compile if the loaded module did not register
   its rewriter. *)
fun compile cache_patterns const_arity eqs =
  let
    val guid = get_guid ()
    val code = Real.toString (random ())
    val module = "AMSML_" ^ guid
    val (arity, toplevel_arity, inlinetab, source) = sml_prog module code eqs
    val _ = Option.app (fn path => writeTextFile path source) (!dump_output)
    (* reset first so a stale rewriter from an earlier module cannot be picked up *)
    val _ = compiled_rewriter := NONE
    val _ = use_source source
  in
    case !compiled_rewriter of
      SOME f => (module, code, arity, toplevel_arity, inlinetab, f)
    | NONE => raise Compile "broken link to compiled function"
  end
|
512 |
||
513 |
||
514 |
(* Alternative evaluation path: instead of calling compiled_fun, print the term
   as SML source, write it to Gencode_call.ML, and evaluate it through the
   generated module's export function.
   Raises Run if t is not closed, if t contains Computed nodes (previously a bare
   Match), or if the result link is missing or carries the wrong code. *)
fun run' (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t =
  let
    val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
    (* replace inlined parameterless constants by their right-hand sides *)
    fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t)
      | inline (Var i) = Var i
      | inline (App (a, b)) = App (inline a, inline b)
      | inline (Abs m) = Abs (inline m)
      | inline (Computed _) = raise Run "run': Computed terms are not supported"
    val t = beta (inline t)
    fun arity_of c = Inttab.lookup arity c
    fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
    val term = print_term NONE arity_of toplevel_arity_of 0 0 t
    val source = "local open "^module^" in val _ = export ("^term^") end"
    val _ = writeTextFile "Gencode_call.ML" source
    val _ = clear_result ()
    val _ = use_source source
  in
    case !saved_result of
      NONE => raise Run "broken link to compiled code"
    | SOME (code', t) => (clear_result (); if code' = code then t else raise Run "link to compiled code was hijacked")
  end
|
534 |
||
535 |
(* Evaluate the closed term t with the compiled rewriter.  Inlined
   parameterless constants are first replaced by their right-hand sides, then
   the term is beta-normalised.  Raises Run if t has free variables. *)
fun run (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t =
  let
    fun expand (Const c) = Option.getOpt (Inttab.lookup inlinetab c, Const c)
      | expand (Var i) = Var i
      | expand (App (a, b)) = App (expand a, expand b)
      | expand (Abs m) = Abs (expand m)
      | expand (Computed t') = Computed t'
  in
    if check_freevars 0 t then compiled_fun (beta (expand t))
    else raise Run ("can only compute closed terms")
  end
|
546 |
||
547 |
(* Accepts a compiled program and does nothing with it. *)
fun discard _ = ()
|
548 |
||
549 |
end |