23663
|
1 |
(*  Title:      Pure/Tools/am_sml.ML
    ID:         $Id$
    Author:     Steven Obua

    TODO: "parameterless rewrite cannot be used in pattern": in many cases such
    a rewrite CAN be used inside a pattern, and these cases should be handled
    properly; right now, all of them raise an exception.
*)
|
|
9 |
|
|
10 |
(* Interface of the abstract machine that compiles rewrite rules to ML code.
   Besides the generic ABSTRACT_MACHINE operations it exports three hooks that
   the generated ML structure calls back into when it is loaded and run. *)
signature AM_SML =
sig
    include ABSTRACT_MACHINE

    (* Uncurried List.nth, used by generated code to look up bound variables. *)
    val list_nth : 'a list * int -> 'a

    (* Called by a freshly loaded generated structure to register its rewriter. *)
    val set_compiled_rewriter : (term -> term) -> unit

    (* Called by generated code to hand back a (link code, result term) pair. *)
    val save_result : (string * term) -> unit
end
|
|
17 |
|
|
18 |
structure AM_SML : AM_SML = struct
|
|
19 |
|
|
20 |
open AbstractMachine;
|
|
21 |
|
|
22 |
(* Handle of a compiled rule set, in the order returned by compile. *)
type program =
    string                 (* name of the generated ML structure *)
    * string               (* random link code identifying results of this program *)
    * (int Inttab.table)   (* maximal arity of every constant *)
    * (int Inttab.table)   (* maximal toplevel pattern arity of every function constant *)
    * (term Inttab.table)  (* parameterless rewrites that were inlined *)
    * (term -> term)       (* the compiled rewriting function *)
|
|
23 |
|
|
24 |
(* Mailbox through which generated code (via save_result) returns a
   (link code, result term) pair; it is read and reset by run'. *)
val saved_result : (string * term) option ref = ref NONE

fun save_result (result : string * term) = saved_result := SOME result

fun clear_result () = saved_result := NONE
|
|
28 |
|
|
29 |
(* Uncurried List.nth, exported so that generated code can look up the value of
   a bound variable: AM_SML.list_nth (bounds, i).  Raises Subscript when i is
   out of range. *)
fun list_nth (xs, i) = List.nth (xs, i)
|
|
32 |
|
|
33 |
(* Slot into which a freshly loaded generated structure registers its rewriting
   function (through set_compiled_rewriter); compile reads it after loading. *)
val compiled_rewriter : (term -> term) option ref = ref NONE

fun set_compiled_rewriter (rewriter : term -> term) = compiled_rewriter := SOME rewriter
|
|
36 |
|
|
37 |
(* True iff the term is built only from constants and applications,
   i.e. it contains neither variables nor abstractions. *)
fun importable t =
    (case t of
         Var _ => false
       | Abs _ => false
       | Const _ => true
       | App (f, x) => importable f andalso importable x)
|
|
41 |
|
|
42 |
(* check_freevars free t holds iff every loose de Bruijn index in t is below
   free, i.e. only the indices 0 .. free-1 may occur unbound.  Consequently
   check_freevars 0 t holds exactly when t is closed. *)
fun check_freevars free t =
    (case t of
         Var x => x < free
       | Const _ => true
       | App (u, v) => check_freevars free u andalso check_freevars free v
       | Abs body => check_freevars (free + 1) body)
|
|
48 |
|
|
49 |
(* Number of pattern variables (PVar leaves) occurring in a pattern. *)
fun count_patternvars p =
    (case p of
         PVar => 1
       | PConst (_, subpatterns) =>
             List.foldl (fn (sub, total) => total + count_patternvars sub) 0 subpatterns)
|
|
52 |
|
|
53 |
(* Records that constant `code` occurs with `a` arguments: a new constant is
   inserted with arity a, an existing entry is raised to a if a is larger. *)
fun update_arity arity code a =
    (case Inttab.lookup arity code of
         SOME a' => if a <= a' then arity else Inttab.update (code, a) arity
       | NONE => Inttab.update_new (code, a) arity)
|
|
57 |
|
|
58 |
(* Records, for every constant occurring anywhere in a pattern, the number of
   arguments it is given there, keeping the maximum per constant. *)
fun collect_pattern_arity p arity =
    (case p of
         PVar => arity
       | PConst (c, subpatterns) =>
             fold collect_pattern_arity subpatterns (update_arity arity c (length subpatterns)))
|
|
61 |
|
|
62 |
(* Records the number of arguments of the head constant of a rule pattern (its
   toplevel arity), keeping the maximum per constant.  A bare variable is not a
   valid rule pattern at this point. *)
fun collect_pattern_toplevel_arity p arity =
    (case p of
         PConst (c, args) => update_arity arity c (length args)
       | PVar => raise Compile "internal error: collect_pattern_toplevel_arity")
|
|
65 |
|
|
66 |
local
    (* depth = number of arguments the current subterm is applied to *)
    fun walk depth t arity =
        (case t of
             Var _ => arity
           | Const c => update_arity arity c depth
           | Abs body => walk 0 body arity
           | App (f, x) => walk 0 x (walk (depth + 1) f arity))
in
    (* Records, for every constant in t, the largest number of arguments it is
       applied to, keeping the maximum per constant in the arity table. *)
    fun collect_term_arity t arity = walk 0 t arity
end
|
|
74 |
|
|
75 |
(* Adds the constant arities of both sides of a guard to the arity table. *)
fun collect_guard_arity (Guard (lhs, rhs)) arity =
    let
        val arity' = collect_term_arity lhs arity
    in
        collect_term_arity rhs arity'
    end
|
|
76 |
|
|
77 |
|
|
78 |
(* rep n x is the list of n copies of x; a negative n is an internal error. *)
fun rep n x =
    if n < 0 then raise Compile "internal error: rep"
    else List.tabulate (n, fn _ => x)
|
|
79 |
|
|
80 |
(* Beta-normalises a de Bruijn term.  A redex (Abs body) arg is contracted by
   lifting arg by one level, substituting it for index 0 in body and lowering
   the result again.  The head of an application is normalised first, so that
   redexes created by reduction are contracted as well. *)
fun beta t =
    (case t of
         Const c => Const c
       | Var i => Var i
       | App (Abs body, arg) => beta (unlift 0 (subst 0 body (lift 0 arg)))
       | App (f, arg) =>
             (case beta f of
                  Abs body => beta (App (Abs body, arg))
                | f' => App (f', beta arg))
       | Abs body => Abs (beta body))

(* subst x m s replaces the occurrences of loose index x in m by s, lifting s
   whenever it is pushed under a binder. *)
and subst x m s =
    (case m of
         Const c => Const c
       | Var i => if i = x then s else Var i
       | App (f, a) => App (subst x f s, subst x a s)
       | Abs body => Abs (subst (x + 1) body (lift 0 s)))

(* lift level m increments every index >= level in m. *)
and lift level m =
    (case m of
         Const c => Const c
       | Var i => if i < level then Var i else Var (i + 1)
       | App (f, a) => App (lift level f, lift level a)
       | Abs body => Abs (lift (level + 1) body))

(* unlift level m decrements every index >= level in m. *)
and unlift level m =
    (case m of
         Const c => Const c
       | Var i => if i < level then Var i else Var (i - 1)
       | App (f, a) => App (unlift level f, unlift level a)
       | Abs body => Abs (unlift (level + 1) body))
|
|
100 |
|
|
101 |
(* nlift level n m adds n to every index >= level in m (n-fold lift). *)
fun nlift level n m =
    (case m of
         Var i => if i < level then Var i else Var (i + n)
       | Const c => Const c
       | App (f, a) => App (nlift level n f, nlift level n a)
       | Abs body => Abs (nlift (level + 1) n body))
|
|
105 |
|
|
106 |
(* subst_const (c, s) m replaces every occurrence of constant c in m by s.
   s is inserted unchanged, also below binders (no index shifting); the caller
   in inline_rules only passes closed replacement terms. *)
fun subst_const (c, s) m =
    (case m of
         Const c' => if c = c' then s else Const c'
       | Var i => Var i
       | App (f, a) => App (subst_const (c, s) f, subst_const (c, s) a)
       | Abs body => Abs (subst_const (c, s) body))
|
|
110 |
|
|
111 |
(* Remove all rules that are just parameterless rewrites. This is necessary because SML does not allow functions with no parameters.
   Each parameterless rewrite c = r must have a closed right hand side r, must
   not be guarded and must not mention c itself (which also catches cycles
   through several such rewrites once they are inlined into each other).
   Every occurrence of c in the guards and right hand sides of the remaining
   rules is replaced by r; c may not occur in any pattern.
   Returns (table mapping each inlined constant to its replacement term,
   remaining rules).  Raises Compile when one of these conditions is violated
   or when c also has rules that differ from its parameterless rewrite. *)
fun inline_rules rules =
    let
        fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b
          | term_contains_const c (Abs m) = term_contains_const c m
          | term_contains_const c (Var i) = false
          | term_contains_const c (Const c') = (c = c')
        (* First rule of the form c = r (no arguments), validated. *)
        fun find_rewrite [] = NONE
          | find_rewrite ((prems, PConst (c, []), r) :: _) =
                if check_freevars 0 r then
                    if term_contains_const c r then
                        raise Compile "parameterless rewrite is caught in cycle"
                    else if not (null prems) then
                        raise Compile "parameterless rewrite may not be guarded"
                    else
                        SOME (c, r)
                (* NOTE(review): only r is checked for unbound variables here,
                   although the message below also mentions guards. *)
                else raise Compile "unbound variable on right hand side or guards of rule"
          | find_rewrite (_ :: rules) = find_rewrite rules
        (* Drops every copy of the rewrite c = r; any other rule with head c
           (arguments, a different right hand side, or guards) is an error. *)
        fun remove_rewrite (c,r) [] = []
          | remove_rewrite (cr as (c,r)) ((rule as (prems', PConst (c', args), r'))::rules) =
                (if c = c' then
                     if null args andalso r = r' andalso null (prems') then
                         remove_rewrite cr rules
                     else raise Compile "incompatible parameterless rewrites found"
                 else
                     rule :: (remove_rewrite cr rules))
          | remove_rewrite cr (r::rs) = r::(remove_rewrite cr rs)
        fun pattern_contains_const c (PConst (c', args)) = (c = c' orelse exists (pattern_contains_const c) args)
          | pattern_contains_const c (PVar) = false
        (* Substitutes r for c in the guards and the right hand side of a rule;
           occurrences of c inside the pattern are not supported. *)
        fun inline_rewrite (ct as (c, _)) (prems, p, r) =
            if pattern_contains_const c p then
                raise Compile "parameterless rewrite cannot be used in pattern"
            else (map (fn (Guard (a,b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r)
        (* Repeatedly inlines the first remaining parameterless rewrite; the
           replacement terms collected so far get c substituted as well. *)
        fun inline inlined rules =
            (case find_rewrite rules of
                 NONE => (Inttab.make inlined, rules)
               | SOME ct =>
                     let
                         val rules = map (inline_rewrite ct) (remove_rewrite ct rules)
                         val inlined = ct :: (map (fn (c', r) => (c', subst_const ct r)) inlined)
                     in
                         inline inlined rules
                     end)
    in
        inline [] rules
    end
|
|
157 |
|
|
158 |
|
|
159 |
(*
  Prepares the rules for code generation.  Returns (arity, toplevel_arity, rules')
  where
  - arity maps every constant to the maximal number of arguments it is given
    anywhere in the rules (patterns, right hand sides and guards);
  - toplevel_arity maps every function constant to the maximal number of
    arguments it takes at the head of a rule pattern;
  - rules' are the rules whose head pattern has been padded with fresh pattern
    variables up to the constant's maximal arity (right hand side and guards
    are applied to the new variables accordingly), and which are then beta
    reduced.
  Raises Compile for variable patterns, rules without parameters, unbound
  variables, and constants inside patterns that do not have maximal arity.
*)
fun adjust_rules rules =
    let
        val arity =
            fold (fn (prems, p, t) => fn arity =>
                     fold collect_guard_arity prems
                          (collect_term_arity t (collect_pattern_arity p arity)))
                 rules Inttab.empty
        val toplevel_arity =
            fold (fn (_, p, _) => fn arity => collect_pattern_toplevel_arity p arity)
                 rules Inttab.empty
        fun arity_of c = the (Inttab.lookup arity c)
        (* Constants inside a pattern become datatype constructors of fixed
           arity in the generated code, so each of them must occur with exactly
           its maximal arity.  This is checked at every nesting depth: a deeper
           constant with too few arguments would otherwise only show up as a
           type error in the generated ML code. *)
        fun adjust_pattern PVar = PVar
          | adjust_pattern (PConst (c, args)) =
                if length args <> arity_of c then
                    raise Compile ("Constant inside pattern must have maximal arity")
                else
                    PConst (c, map adjust_pattern args)
        fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable")
          | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters")
          | adjust_rule (rule as (prems, p as PConst (c, args), t)) =
                let
                    val patternvars_counted = count_patternvars p
                    fun check_fv t = check_freevars patternvars_counted t
                    val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else ()
                    val _ = if not (forall (fn (Guard (a, b)) => check_fv a andalso check_fv b) prems)
                            then raise Compile ("unbound variables in guards") else ()
                    val args = map adjust_pattern args
                    val len = length args
                    val arity = arity_of c
                    val lift = nlift 0
                    (* adjust_tm n t applies t to Var (n-1), ..., Var 0 (in this order) *)
                    fun adjust_tm n t = if n = 0 then t else adjust_tm (n - 1) (App (t, Var (n - 1)))
                    (* the n pattern variables appended at the end of the pattern get
                       the de Bruijn indices n-1 .. 0, so the existing ones move up by n *)
                    fun adjust_term n t = adjust_tm n (lift n t)
                    fun adjust_guard n (Guard (a, b)) = Guard (adjust_term n a, adjust_term n b)
                in
                    if len = arity then
                        rule
                    else if arity >= len then
                        (map (adjust_guard (arity - len)) prems,
                         PConst (c, args @ (rep (arity - len) PVar)),
                         adjust_term (arity - len) t)
                    else
                        raise Compile "internal error in adjust_rule"
                end
        fun beta_guard (Guard (a, b)) = Guard (beta a, beta b)
        fun beta_rule (prems, p, t) =
            ((map beta_guard prems, p, beta t) handle Match => raise Compile "beta_rule")
    in
        (arity, toplevel_arity, map (beta_rule o adjust_rule) rules)
    end
|
|
198 |
|
|
199 |
(* Returns a function that prints a term as ML source for the generated code,
   starting at binder depth 0.
   module: if SOME s, every generated identifier is qualified with "s.".
   arity_of / toplevel_arity_of: arity lookups of the compiled program.
   pattern_var_count: number of pattern variables x0 .. x(count-1) in scope;
     loose de Bruijn indices refer to them, the last variable being index 0.
   pattern_lazy_var_count: how many of the trailing pattern variables are bound
     to thunks; references to them are printed as "(x ())".
   Bound variables introduced by Abs are printed as b0, b1, ... by depth.
   Raises Compile "strict" if a toplevel arity exceeds the arity. *)
fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count =
    let
        fun str x = string_of_int x
        (* parenthesise s if it contains blanks, so it can be used as an argument *)
        fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^ ")" else s
        val module_prefix = (case module of NONE => "" | SOME s => s ^ ".")
        (* generic application through the generated "app" function *)
        fun print_apps d f [] = f
          | print_apps d f (a :: args) =
                print_apps d (module_prefix ^ "app " ^ (protect_blank f) ^ " " ^ (protect_blank (print_term d a))) args
        (* collects the arguments of an application spine and prints the head *)
        and print_call d (App (a, b)) args = print_call d a (b :: args)
          | print_call d (Const c) args =
                (case arity_of c of
                     NONE => print_apps d (module_prefix ^ "Const " ^ (str c)) args
                     (* A constant of arity 0 is the nullary constructor C<c>.  Any
                        arguments it is applied to in the term are kept and
                        applied generically, as in the SOME a case below, instead
                        of being dropped. *)
                   | SOME 0 => print_apps d (module_prefix ^ "C" ^ (str c)) args
                   | SOME a =>
                         let
                             val len = length args
                         in
                             if a <= len then
                                 let
                                     (* the first strict_a arguments are passed evaluated,
                                        the following ones up to a as thunks *)
                                     val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a)
                                     val _ = if strict_a > a then raise Compile "strict" else ()
                                     val s = module_prefix ^ "c" ^ (str c) ^
                                             (concat (map (fn t => " " ^ (protect_blank (print_term d t)))
                                                          (List.take (args, strict_a))))
                                     val s = s ^ (concat (map (fn t => " (fn () => " ^ print_term d t ^ ")")
                                                              (List.drop (List.take (args, a), strict_a))))
                                 in
                                     print_apps d s (List.drop (args, a))
                                 end
                             else
                                 (* under-applied: eta-expand with a - len fresh bound
                                    variables and print the resulting abstraction *)
                                 let
                                     fun mk_apps n t = if n = 0 then t else mk_apps (n - 1) (App (t, Var (n - 1)))
                                     fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n - 1) (Abs t)
                                     fun append_args [] t = t
                                       | append_args (c :: cs) t = append_args cs (App (t, c))
                                 in
                                     print_term d (mk_lambdas (a - len)
                                                     (mk_apps (a - len) (nlift 0 (a - len) (append_args args (Const c)))))
                                 end
                         end)
          | print_call d t args = print_apps d (print_term d t) args
        and print_term d (Var x) =
                if x < d then
                    "b" ^ (str (d - x - 1))
                else
                    let
                        val n = pattern_var_count - (x - d) - 1
                        val x = "x" ^ (str n)
                    in
                        if n < pattern_var_count - pattern_lazy_var_count then
                            x
                        else
                            "(" ^ x ^ " ())"
                    end
          | print_term d (Abs c) = module_prefix ^ "Abs (fn b" ^ (str d) ^ " => " ^ (print_term (d + 1) c) ^ ")"
          | print_term d t = print_call d t []
    in
        print_term 0
    end
|
|
253 |
|
|
254 |
(* section n = [0, 1, ..., n-1] (the empty list for n = 0).
   Built in linear time with List.tabulate; a negative n raises Size. *)
fun section n = List.tabulate (n, fn i => i)
|
|
255 |
|
|
256 |
(* Prints one adjusted rule (guards, pattern, rhs) as a clause of the generated
   function for its head constant c.  The clause head is "c<c> ..." (or
   "c<c>_<gnum> ..." when gnum > 0), and every toplevel argument is bound as
   a<i> via "as", so that a failing guard can pass the arguments on unchanged.
   Pattern variables are numbered x0, x1, ... from left to right; nested
   constants are printed as the datatype constructors C<c>.
   Returns (gnum', clause text): gnum' = gnum + 1 if the rule has guards (the
   else branch then calls c<c>_<gnum+1>), otherwise gnum' = gnum. *)
fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) =
    let
        fun str x = Int.toString x
        (* print_pattern top n p: n is the index of the next pattern variable;
           returns the updated index and the printed pattern *)
        fun print_pattern top n PVar = (n+1, "x"^(str n))
          | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else ""))
          | print_pattern top n (PConst (c, args)) =
                let
                    val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")
                    val (n, s) = print_pattern_list 0 top (n, f) args
                in
                    (n, s)
                end
        (* remaining arguments: curried " (a<i> as (...))" at the top level,
           ", ..." inside a constructor tuple, which is closed at the end *)
        and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")")
          | print_pattern_list' counter top (n, p) (t::ts) =
                let
                    val (n, t) = print_pattern false n t
                in
                    print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts
                end
        (* first argument; only called with a non-empty list, the empty case is
           handled by the PConst (c, []) clause of print_pattern *)
        and print_pattern_list counter top (n, p) (t::ts) =
                let
                    val (n, t) = print_pattern false n t
                in
                    print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts
                end
        val c = (case p of PConst (c, _) => c | _ => raise Match)
        val (n, pattern) = print_pattern true 0 p
        (* arguments beyond the toplevel arity are passed to c<c> as thunks *)
        val lazy_vars = the (arity_of c) - the (toplevel_arity_of c)
        fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm
        fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")"
        (* continue with the next clause group, passing the arguments unchanged *)
        val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(concat (map (fn i => " a"^(str i)) (section (the (arity_of c)))))
        fun print_guards t [] = print_tm t
          | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(concat (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch
    in
        (if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards))
    end
|
|
292 |
|
|
293 |
(* Groups the rules by the constant at the head of their pattern; the rules of
   each constant keep their original relative order. *)
fun group_rules rules =
    let
        fun insert (rule as (_, PConst (c, _), _)) tab =
                let
                    val previous = (case Inttab.lookup tab c of SOME rs => rs | NONE => [])
                in
                    Inttab.update (c, rule :: previous) tab
                end
          | insert _ _ = raise Compile "internal error group_rules"
    in
        fold_rev insert rules Inttab.empty
    end
|
|
305 |
|
|
306 |
(* Generates the source text of an ML structure `name` implementing the rules.
   The rules are first preprocessed: parameterless rewrites are inlined
   (inline_rules), patterns are padded to maximal arity (adjust_rules), and the
   rules are grouped by head constant.  The emitted structure contains
   - datatype Term with one constructor C<c> per constant,
   - term_eq, the structural equality used for guards,
   - app, which applies Abs closures and otherwise builds App,
   - per constant a function c<c> (plus c<c>_<g> continuations for guarded
     rules), declared with "and" so it is mutually recursive with app,
   - convert, translating Term back to AM_SML.Const / AM_SML.App terms,
   - eval, translating AbstractMachine terms into Term by calling the c<c>,
   - export, which passes `code` and a converted result to AM_SML.save_result,
   - a toplevel call registering fn t => convert (eval [] t) through
     AM_SML.set_compiled_rewriter.
   Returns (arity, toplevel_arity, inlinetab, source text). *)
fun sml_prog name code rules =
    let
        (* Output is accumulated as a reversed list of chunks and concatenated
           once at the end; appending to a single growing string would copy the
           whole buffer on every write (quadratic in the program size). *)
        val chunks = ref ([] : string list)
        fun write s = chunks := s :: !chunks
        fun writeln s = (write s; write "\n")
        fun writelist [] = ()
          | writelist (s :: ss) = (writeln s; writelist ss)
        fun str i = Int.toString i
        val (inlinetab, rules) = inline_rules rules
        val (arity, toplevel_arity, rules) = adjust_rules rules
        val rules = group_rules rules
        val constants = Inttab.keys arity
        fun arity_of c = Inttab.lookup arity c
        fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
        fun rep_str s n = concat (rep n s)
        fun indexed s n = s ^ (str n)
        fun string_of_tuple [] = ""
          | string_of_tuple (x :: xs) = "(" ^ x ^ (concat (map (fn s => ", " ^ s) xs)) ^ ")"
        fun string_of_args [] = ""
          | string_of_args (x :: xs) = x ^ (concat (map (fn s => " " ^ s) xs))
        (* Fallback clause of c<c>[_gnum], rebuilding the constructor term when no
           rule applies.  If some arguments are lazy the call cannot be rebuilt;
           the generated code then calls print on the last strict argument and
           raises AM_SML.Run. *)
        fun default_case gnum c =
            let
                val leftargs = concat (map (indexed " x") (section (the (arity_of c))))
                val rightargs = section (the (arity_of c))
                val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa)
                val xs = map (fn n => if n < strict_args then "x" ^ (str n) else "x" ^ (str n) ^ "()") rightargs
                val right = (indexed "C" c) ^ " " ^ (string_of_tuple xs)
                val debug_lazy = "(print x" ^ (string_of_int (strict_args - 1)) ^ ";"
                val right =
                    if strict_args < the (arity_of c) then
                        debug_lazy ^ "raise AM_SML.Run \"unresolved lazy call: " ^ (string_of_int c) ^ "\")"
                    else
                        right
            in
                (indexed "c" c) ^ (if gnum > 0 then "_" ^ (str gnum) else "") ^ leftargs ^ " = " ^ right
            end
        (* eval clauses for constant c applied to n = arity, ..., 0 arguments;
           missing arguments are abstracted with Abs, and arguments beyond the
           strict (toplevel) arity are passed as thunks. *)
        fun eval_rules c =
            let
                val arity = the (arity_of c)
                val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa)
                fun eval_rule n =
                    let
                        val sc = string_of_int c
                        val left =
                            fold (fn i => fn s => "AbstractMachine.App (" ^ s ^ (indexed ", x" i) ^ ")")
                                 (section n) ("AbstractMachine.Const " ^ sc)
                        fun arg i =
                            let
                                val x = indexed "x" i
                                val x = if i < n then "(eval bounds " ^ x ^ ")" else x
                                val x = if i < strict_arity then x else "(fn () => " ^ x ^ ")"
                            in
                                x
                            end
                        val right = "c" ^ sc ^ " " ^ (string_of_args (map arg (section arity)))
                        val right =
                            fold_rev (fn i => fn s => "Abs (fn " ^ (indexed "x" i) ^ " => " ^ s ^ ")")
                                     (List.drop (section arity, n)) right
                        val right = if arity > 0 then right else "C" ^ sc
                    in
                        " | eval bounds (" ^ left ^ ") = " ^ right
                    end
            in
                map eval_rule (rev (section (arity + 1)))
            end
        fun mk_constr_type_args n = if n > 0 then " of Term " ^ (rep_str " * Term" (n - 1)) else ""
        val _ = writelist [
            "structure " ^ name ^ " = struct",
            "",
            "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)",
            " " ^ (concat (map (fn c => " | C" ^ (str c) ^ (mk_constr_type_args (the (arity_of c)))) constants)),
            ""]
        fun make_constr c argprefix =
            "(C" ^ (str c) ^ " " ^ (string_of_tuple (map (fn i => argprefix ^ (str i)) (section (the (arity_of c))))) ^ ")"
        fun make_term_eq c =
            " | term_eq " ^ (make_constr c "a") ^ " " ^ (make_constr c "b") ^ " = " ^
            (case the (arity_of c) of
                 0 => "true"
               | n =>
                     let
                         val eqs = map (fn i => "term_eq a" ^ (str i) ^ " b" ^ (str i)) (section n)
                         val (eq, eqs) = (List.hd eqs, map (fn s => " andalso " ^ s) (List.tl eqs))
                     in
                         eq ^ (concat eqs)
                     end)
        val _ = writelist [
            "fun term_eq (Const c1) (Const c2) = (c1 = c2)",
            " | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"]
        val _ = writelist (map make_term_eq constants)
        val _ = writelist [
            " | term_eq _ _ = false",
            ""]
        val _ = writelist [
            "fun app (Abs a) b = a b",
            " | app a b = App (a, b)",
            ""]
        fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else [])
        (* each clause group continues the "fun app" declaration with "and" *)
        fun writefundecl [] = ()
          | writefundecl (x :: xs) = writelist ((("and " ^ x) :: (map (fn s => " | " ^ s) xs)))
        (* Clause groups of constant c: a guarded rule closes the current group
           with a catch-all clause forwarding to the next group c<c>_<g>; the last
           group ends with the default case. *)
        fun list_group c =
            (case Inttab.lookup rules c of
                 NONE => [defcase 0 c]
               | SOME rs =>
                     let
                         val rs =
                             fold
                                 (fn r => fn rs =>
                                     let
                                         val (gnum, l, rs) =
                                             (case rs of
                                                  [] => (0, [], [])
                                                | (gnum, l) :: rs => (gnum, l, rs))
                                         val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r
                                     in
                                         if gnum' = gnum then
                                             (gnum, r :: l) :: rs
                                         else
                                             let
                                                 val args = concat (map (fn i => " a" ^ (str i)) (section (the (arity_of c))))
                                                 fun gnumc g =
                                                     if g > 0 then "c" ^ (str c) ^ "_" ^ (str g) ^ args
                                                     else "c" ^ (str c) ^ args
                                                 val s = gnumc (gnum) ^ " = " ^ gnumc (gnum')
                                             in
                                                 (gnum', []) :: (gnum, s :: r :: l) :: rs
                                             end
                                     end)
                                 rs []
                         val rs =
                             (case rs of
                                  [] => [(0, defcase 0 c)]
                                | (gnum, l) :: rs => (gnum, (defcase gnum c) @ l) :: rs)
                     in
                         rev (map (fn z => rev (snd z)) rs)
                     end)
        val _ = List.app (fn z => (List.app writefundecl z; writeln "")) (map list_group constants)
        val _ = writelist [
            "fun convert (Const i) = AM_SML.Const i",
            " | convert (App (a, b)) = AM_SML.App (convert a, convert b)",
            " | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""]
        fun make_convert c =
            let
                val args = map (indexed "a") (section (the (arity_of c)))
                val leftargs =
                    (case args of
                         [] => ""
                       | (x :: xs) => "(" ^ x ^ (concat (map (fn s => ", " ^ s) xs)) ^ ")")
                val args = map (indexed "convert a") (section (the (arity_of c)))
                val right = fold (fn x => fn s => "AM_SML.App (" ^ s ^ ", " ^ x ^ ")") args ("AM_SML.Const " ^ (str c))
            in
                " | convert (C" ^ (str c) ^ " " ^ leftargs ^ ") = " ^ right
            end
        val _ = writelist (map make_convert constants)
        val _ = writelist [
            "",
            "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)",
            " | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"]
        val _ = List.app (writelist o eval_rules) constants
        val _ = writelist [
            " | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)",
            " | eval bounds (AbstractMachine.Const c) = Const c"]
        val _ = writelist [
            "",
            "fun export term = AM_SML.save_result (\"" ^ code ^ "\", convert term)",
            "",
            "val _ = AM_SML.set_compiled_rewriter (fn t => convert (eval [] t))",
            "",
            "end"]
    in
        (arity, toplevel_arity, inlinetab, String.concat (rev (!chunks)))
    end
|
|
465 |
|
|
466 |
(* Running counter that keeps identifiers returned by get_guid distinct. *)
val guid_counter = ref 0

(* Returns a fresh identifier: the current time in microseconds followed by
   the current value of guid_counter, which is then incremented. *)
fun get_guid () =
    let
        val count = !guid_counter
        val () = guid_counter := !guid_counter + 1
        val micros = Time.toMicroseconds (Time.now ())
    in
        LargeInt.toString micros ^ string_of_int count
    end
|
|
474 |
|
|
475 |
|
|
476 |
(* Writes the string s to the file whose name is parsed with Path.explode. *)
fun writeTextFile name s =
    let
        val path = Path.explode name
    in
        File.write path s
    end
|
|
477 |
|
|
478 |
(* Loads ML source text into the running ML session via use_text, with
   Output.ml_output as the output channel (the false flag presumably selects
   non-verbose mode -- confirm against use_text). *)
fun use_source text = use_text "" Output.ml_output false text
|
|
479 |
|
|
480 |
(* Compiles a rule set.  The rules are translated into an ML structure with a
   unique name and loaded with use_source; when loaded, the structure registers
   its rewriter through set_compiled_rewriter.  Returns the program handle
   (module, code, arity, toplevel_arity, inlinetab, rewriter).
   Raises Compile if loading did not register a rewriter. *)
fun compile eqs =
    let
        val module = "AMSML_" ^ get_guid ()
        val code = Real.toString (random ())
        val (arity, toplevel_arity, inlinetab, source) = sml_prog module code eqs
        (* to inspect the generated code: writeTextFile "Gencode.ML" source *)
        val _ = compiled_rewriter := NONE
        val _ = use_source source
    in
        (case !compiled_rewriter of
             SOME f => (module, code, arity, toplevel_arity, inlinetab, f)
           | NONE => raise Compile "broken link to compiled function")
    end
|
|
494 |
|
|
495 |
|
|
496 |
(* Alternative to run that goes through the ML compiler instead of calling the
   compiled rewriter directly.  The closed term t has its parameterless
   constants inlined and is beta reduced.  It is then printed as ML source that
   calls export inside the generated structure, and that source is loaded with
   use_source.  The result comes back through saved_result and is accepted only
   if it carries this program's link code.
   Raises Run for open terms, when no result arrives, or when the link code
   does not match. *)
fun run' (module, code, arity, toplevel_arity, inlinetab, _) t =
    let
        val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
        fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t)
          | inline (Var i) = Var i
          | inline (App (a, b)) = App (inline a, inline b)
          | inline (Abs m) = Abs (inline m)
        val t = beta (inline t)
        fun arity_of c = Inttab.lookup arity c
        fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
        val term = print_term NONE arity_of toplevel_arity_of 0 0 t
        val source = "local open " ^ module ^ " in val _ = export (" ^ term ^ ") end"
        (* Debug aid, disabled as in compile: writing into the current directory on
           every evaluation may fail (e.g. read-only directory) and leaves files
           behind.  Re-enable with: writeTextFile "Gencode_call.ML" source *)
        val _ = clear_result ()
        val _ = use_source source
    in
        (case !saved_result of
             NONE => raise Run "broken link to compiled code"
           | SOME (code', t) =>
                 (clear_result ();
                  if code' = code then t else raise Run "link to compiled code was hijacked"))
    end
|
|
516 |
|
|
517 |
(* Evaluates a closed term with the program's compiled rewriter.  Constants that
   were inlined at compile time are expanded, and the term is beta normalised
   before it is passed to compiled_fun.
   Raises Run for terms with loose bound variables. *)
fun run (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t =
    let
        fun expand tm =
            (case tm of
                 Const c => (case Inttab.lookup inlinetab c of SOME r => r | NONE => Const c)
               | Var i => Var i
               | App (f, x) => App (expand f, expand x)
               | Abs body => Abs (expand body))
    in
        if check_freevars 0 t then compiled_fun (beta (expand t))
        else raise Run ("can only compute closed terms")
    end
|
|
527 |
|
|
528 |
(* Discarding a program is a no-op for this machine. *)
fun discard _ = ()
|
|
529 |
|
|
530 |
end
|