author | haftmann |
Tue, 18 Sep 2007 07:36:10 +0200 | |
changeset 24619 | c2e6a0f8c30b |
parent 24612 | d1b315bdb8d7 |
child 24634 | 38db11874724 |
permissions | -rw-r--r-- |
24590 | 1 |
(* Title: Tools/nbe.ML |
24155 | 2 |
ID: $Id$ |
3 |
Authors: Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen |
|
4 |
||
5 |
Evaluation mechanisms for normalization by evaluation. |
|
6 |
*) |
|
7 |
||
8 |
(* |
|
9 |
FIXME: |
|
10 |
- get rid of BVar (?) - it is only used for terms to be evaluated, not for functions |
|
11 |
- proper purge operation - preliminary for... |
|
12 |
- really incremental code generation |
|
13 |
*) |
|
14 |
||
15 |
signature NBE = |
|
16 |
sig |
|
17 |
datatype Univ = |
|
24381 | 18 |
Const of string * Univ list (*named (uninterpreted) constants*) |
24155 | 19 |
| Free of string * Univ list |
20 |
| BVar of int * Univ list |
|
21 |
| Abs of (int * (Univ list -> Univ)) * Univ list; |
|
22 |
val free: string -> Univ list -> Univ (*free (uninterpreted) variables*) |
|
23 |
val abs: int -> (Univ list -> Univ) -> Univ list -> Univ |
|
24 |
(*abstractions as functions*) |
|
25 |
val app: Univ -> Univ -> Univ (*explicit application*) |
|
26 |
||
24590 | 27 |
val univs_ref: (unit -> Univ list) ref |
24423
ae9cd0e92423
overloaded definitions accompanied by explicit constants
haftmann
parents:
24381
diff
changeset
|
28 |
val lookup_fun: string -> Univ |
24155 | 29 |
|
30 |
val normalization_conv: cterm -> thm |
|
31 |
||
32 |
val trace: bool ref |
|
33 |
val setup: theory -> theory |
|
34 |
end; |
|
35 |
||
36 |
structure Nbe: NBE = |
|
37 |
struct |
|
38 |
||
39 |
(* generic non-sense *) |
|
40 |
||
41 |
val trace = ref false; |
|
42 |
(*if tracing is switched on, emit the message "f x" on the tracing channel;
  the value x is passed through unchanged in either case*)
fun tracing f x =
  let val _ = if ! trace then Output.tracing (f x) else ()
  in x end;
|
43 |
||
44 |
||
45 |
(** the semantical universe **) |
|
46 |
||
47 |
(* |
|
48 |
Functions are given by their semantical function value. To avoid |
|
49 |
trouble with the ML-type system, these functions have the most |
|
50 |
generic type, that is "Univ list -> Univ". The calling convention is |
|
51 |
that the arguments come as a list, the last argument first. In |
|
52 |
other words, a function call that usually would look like |
|
53 |
||
54 |
f x_1 x_2 ... x_n or f(x_1,x_2, ..., x_n) |
|
55 |
||
56 |
would be in our convention called as |
|
57 |
||
58 |
f [x_n,..,x_2,x_1] |
|
59 |
||
60 |
Moreover, to handle functions that are still waiting for some |
|
61 |
arguments we have additionally a list of arguments collected so far |
|
62 |
and the number of arguments we're still waiting for. |
|
63 |
*) |
|
64 |
||
65 |
datatype Univ = |
|
24381 | 66 |
Const of string * Univ list (*named (uninterpreted) constants*) |
24155 | 67 |
| Free of string * Univ list (*free variables*) |
68 |
| BVar of int * Univ list (*bound named variables*) |
|
69 |
| Abs of (int * (Univ list -> Univ)) * Univ list |
|
24381 | 70 |
(*abstractions as closures*); |
24155 | 71 |
|
72 |
(* constructor functions *) |
|
73 |
||
74 |
val free = curry Free; |
|
75 |
fun abs n f ts = Abs ((n, f), ts); |
|
76 |
(*explicit application: supply one further argument to a semantic value;
  collected arguments are stored with the most recent argument first*)
fun app head arg =
  (case head of
    Abs ((1, f), args) => f (arg :: args)  (*last missing argument: invoke the closure*)
  | Abs ((n, f), args) => Abs ((n - 1, f), arg :: args)  (*still waiting for arguments*)
  | Const (name, args) => Const (name, arg :: args)
  | Free (name, args) => Free (name, arg :: args)
  | BVar (idx, args) => BVar (idx, arg :: args));
|
81 |
||
82 |
(* global functions store *) |
|
83 |
||
84 |
structure Nbe_Functions = CodeDataFun |
|
85 |
(struct |
|
86 |
type T = Univ Symtab.table; |
|
87 |
val empty = Symtab.empty; |
|
88 |
fun merge _ = Symtab.merge (K true); |
|
89 |
fun purge _ _ _ = Symtab.empty; |
|
90 |
end); |
|
91 |
||
92 |
(* sandbox communication *) |
|
93 |
||
24590 | 94 |
val univs_ref = ref (fn () => [] : Univ list); |
24155 | 95 |
|
96 |
local |
|
97 |
||
98 |
val tab_ref = ref NONE : Univ Symtab.table option ref; |
|
99 |
||
100 |
in |
|
101 |
||
102 |
(*look up the semantic value of the global function named s in the table
  currently installed in tab_ref (compile_univs installs it around the
  evaluation of generated code); fails with a descriptive error if no table
  is installed or if s is not present in the table*)
fun lookup_fun s =
  (case ! tab_ref of
    NONE => error "lookup_fun: no function table installed"
  | SOME tab =>
      (case Symtab.lookup tab s of
        SOME univ => univ
      | NONE => error ("lookup_fun: unknown function " ^ quote s)));
|
105 |
||
106 |
(*compile the generated ML text raw_s, which is expected to assign a delayed
  list of semantic values to Nbe.univs_ref, and pair the resulting values
  with the names cs; tab is made available to the generated code through
  lookup_fun for the duration of the compilation only*)
fun compile_univs tab ([], _) = []
  | compile_univs tab (cs, raw_s) =
      let
        (*reset the communication reference so that stale results are never picked up*)
        val _ = univs_ref := (fn () => []);
        val s = "Nbe.univs_ref := " ^ raw_s;
        val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
        val _ = tab_ref := SOME tab;
        (*make sure the table is uninstalled again even if compilation fails*)
        val _ =
          (use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
              Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
            (!trace) s)
          handle exn => (tab_ref := NONE; raise exn);
        val _ = tab_ref := NONE;
        (*an empty result means the generated code did not deliver any values*)
        val univs = case !univs_ref () of [] => error "compile_univs" | univs => univs;
      in cs ~~ univs end;
119 |
||
120 |
end; (*local*) |
|
121 |
||
122 |
||
123 |
(** assembling and compiling ML code from terms **) |
|
124 |
||
125 |
(* abstract ML syntax *) |
|
126 |
||
127 |
infix 9 `$` `$$`;

(*parenthesized ML application of e1 to a single argument e2*)
fun e1 `$` e2 = String.concat ["(", e1, " ", e2, ")"];

(*parenthesized ML application of e to the curried arguments es*)
fun e `$$` es = String.concat ["(", e, " ", space_implode " " es, ")"];
|
24590 | 130 |
(*ML text of the abstraction "(fn v => e)"*)
fun ml_abs v e = String.concat ["(fn ", v, " => ", e, ")"];
24155 | 131 |
|
132 |
(*ML text of the value binding "val v = s"*)
fun ml_Val v s = String.concat ["val ", v, " = ", s];
|
133 |
(*ML text of a parenthesized case expression over t; cs are (pattern, body) pairs*)
fun ml_cases t cs =
  let
    fun clause (pat, body) = pat ^ " => " ^ body;
    val clauses = space_implode " | " (map clause cs);
  in "(case " ^ t ^ " of " ^ clauses ^ ")" end;
|
135 |
(*ML text of a let expression with declarations ds (one per line) and body e*)
fun ml_Let ds e = String.concat ["let\n", space_implode "\n" ds, " in ", e, " end"];
|
136 |
||
137 |
(*ML text of a list literal with elements es*)
fun ml_list es = enclose "[" "]" (commas es);
|
138 |
||
24590 | 139 |
val ml_delay = ml_abs "()" |
140 |
||
24155 | 141 |
(*ML text of a group of definitions: a single definition consisting of one
  equation without arguments becomes a plain "val" binding; otherwise the
  definitions are rendered as mutually recursive "fun ... and ..." clauses;
  each definition is a name with a list of (argument patterns, rhs) equations*)
fun ml_fundefs ([(name, [([], e)])]) =
      "val " ^ name ^ " = " ^ e ^ "\n"
  | ml_fundefs (eqs :: eqss) =
      let
        (*one equation "name arg1 ... argn = rhs"*)
        fun clause name (args, rhs) = name ^ " " ^ space_implode " " args ^ " = " ^ rhs;
        (*all equations of one function, introduced by the given keyword*)
        fun fundef keyword (name, clauses) =
          keyword ^ space_implode "\n | " (map (clause name) clauses);
        val defs = fundef "fun " eqs :: map (fundef "and ") eqss;
      in space_implode "\n" defs ^ "\n" end;
|
154 |
||
155 |
(* nbe specific syntax *) |
|
156 |
||
157 |
local
  (*names of the Nbe structure as referred to from generated code*)
  fun qualified name = "Nbe." ^ name;
  val name_const = qualified "Const";
  val name_free = qualified "free";
  val name_abs = qualified "abs";
  val name_app = qualified "app";
  val name_lookup_fun = qualified "lookup_fun";
in

(*uninterpreted constant c applied to the argument expressions ts*)
fun nbe_const c ts =
  let val arg = enclose "(" ")" (ML_Syntax.print_string c ^ ", " ^ ml_list ts)
  in name_const `$` arg end;

(*ML identifier for the function belonging to constant c: dots become underscores*)
fun nbe_fun c =
  let fun subst_dot "." = "_" | subst_dot s = s
  in "c_" ^ translate_string subst_dot c end;

(*free variable v without arguments*)
fun nbe_free v = name_free `$$` [ML_Syntax.print_string v, ml_list []];

(*ML identifier for the bound variable v*)
fun nbe_bound v = "v_" ^ v;

(*nested explicit applications of e; the last element of es is applied innermost*)
fun nbe_apps e [] = e
  | nbe_apps e (s :: es) = name_app `$$` [nbe_apps e es, s];

(*wrap function expression f as a value expecting n arguments;
  for n = 0 the function is called on the empty argument list right away*)
fun nbe_abss n f =
  if n = 0 then f `$` ml_list []
  else name_abs `$$` [string_of_int n, f, ml_list []];

(*binding of the ML identifier for c to the value obtained via lookup_fun*)
fun nbe_lookup c =
  let val lookup = name_lookup_fun `$` ML_Syntax.print_string c
  in ml_Val (nbe_fun c) lookup end;

val nbe_value = "value";

end;
|
182 |
||
24219 | 183 |
open BasicCodeThingol; |
24155 | 184 |
|
185 |
(* greetings to Tarski *) |
|
186 |
||
187 |
(*translate an intermediate term into ML text computing its semantic value;
  is_fun tells whether a constant is bound to a function value, num_args
  gives the argument count of the functions that are currently being defined;
  translated argument lists are kept in reverse order (last argument first)*)
fun assemble_iterm thy is_fun num_args =
  let
    fun of_iterm t =
      let val (head, args) = CodeThingol.unfold_app t
      in of_iapp head (rev (map of_iterm args)) end
    and of_iconst c ts =
      let
        (*call the function on exactly n arguments, apply the result to the remaining ones*)
        fun saturated n =
          let val (surplus, proper) = chop (length ts - n) ts
          in nbe_apps (nbe_fun c `$` ml_list proper) surplus end;
      in
        (case num_args c of
          SOME n => if length ts < n then nbe_const c ts else saturated n
        | NONE => if is_fun c then nbe_apps (nbe_fun c) ts else nbe_const c ts)
      end
    and of_iapp (IConst (c, _)) ts = of_iconst c ts
      | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
      | of_iapp ((v, _) `|-> body) ts =
          let val lam = ml_abs (ml_list [nbe_bound v]) (of_iterm body)
          in nbe_apps (nbe_abss 1 lam) ts end
      | of_iapp (ICase (((scrutinee, _), clauses), fallback)) ts =
          let
            val scrut = of_iterm scrutinee;
            (*the fallback term serves as catch-all clause*)
            val branches = map (pairself of_iterm) clauses @ [("_", of_iterm fallback)];
          in nbe_apps (ml_cases scrut branches) ts end
  in of_iterm end;
|
208 |
||
209 |
fun assemble_fun thy is_fun num_args (c, eqns) = |
|
210 |
let |
|
211 |
val assemble_arg = assemble_iterm thy (K false) (K NONE); |
|
212 |
val assemble_rhs = assemble_iterm thy is_fun num_args; |
|
213 |
fun assemble_eqn (args, rhs) = |
|
214 |
([ml_list (map assemble_arg (rev args))], assemble_rhs rhs); |
|
215 |
val default_params = map nbe_bound |
|
216 |
(Name.invent_list [] "a" ((the o num_args) c)); |
|
217 |
val default_eqn = ([ml_list default_params], nbe_const c default_params); |
|
218 |
in map assemble_eqn eqns @ [default_eqn] end; |
|
219 |
||
24347 | 220 |
fun assemble_eqnss thy is_fun ([], deps) = ([], "") |
221 |
| assemble_eqnss thy is_fun (eqnss, deps) = |
|
24155 | 222 |
let |
223 |
val cs = map fst eqnss; |
|
224 |
val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss; |
|
24219 | 225 |
val funs = fold (fold (CodeThingol.fold_constnames |
24155 | 226 |
(insert (op =))) o map snd o snd) eqnss []; |
227 |
val bind_funs = map nbe_lookup (filter is_fun funs); |
|
228 |
val bind_locals = ml_fundefs (map nbe_fun cs ~~ map |
|
229 |
(assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss); |
|
24590 | 230 |
val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args) |
231 |
|> ml_delay; |
|
24155 | 232 |
in (cs, ml_Let (bind_funs @ [bind_locals]) result) end; |
233 |
||
24381 | 234 |
fun assemble_eval thy is_fun (((vs, ty), t), deps) = |
24155 | 235 |
let |
24219 | 236 |
val funs = CodeThingol.fold_constnames (insert (op =)) t []; |
237 |
val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t []; |
|
24155 | 238 |
val bind_funs = map nbe_lookup (filter is_fun funs); |
239 |
val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)], |
|
240 |
assemble_iterm thy is_fun (K NONE) t)])]; |
|
24590 | 241 |
val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)] |
242 |
|> ml_delay; |
|
24155 | 243 |
in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end; |
244 |
||
24381 | 245 |
(*extract the defining equations of a code statement together with its
  dependencies; only function statements with at least one equation
  contribute, every other kind of statement yields NONE*)
fun eqns_of_stmt ((name, stmt), deps) =
  (case stmt of
    CodeThingol.Fun (_, []) => NONE
  | CodeThingol.Fun (_, eqns) => SOME ((name, map fst eqns), deps)
  | CodeThingol.Datatypecons _ => NONE
  | CodeThingol.Datatype _ => NONE
  | CodeThingol.Class _ => NONE
  | CodeThingol.Classrel _ => NONE
  | CodeThingol.Classop _ => NONE
  | CodeThingol.Classinst _ => NONE);
261 |
||
262 |
fun compile_stmts thy is_fun = |
|
263 |
map_filter eqns_of_stmt |
|
24347 | 264 |
#> split_list |
24155 | 265 |
#> assemble_eqnss thy is_fun |
266 |
#> compile_univs (Nbe_Functions.get thy); |
|
267 |
||
268 |
fun eval_term thy is_fun = |
|
269 |
assemble_eval thy is_fun |
|
270 |
#> compile_univs (Nbe_Functions.get thy) |
|
271 |
#> the_single |
|
272 |
#> snd; |
|
273 |
||
274 |
||
275 |
(** compilation and evaluation **) |
|
276 |
||
277 |
(* ensure global functions *) |
|
278 |
||
279 |
fun ensure_funs thy code = |
|
280 |
let |
|
281 |
fun compile' stmts tab = |
|
282 |
let |
|
283 |
val compiled = compile_stmts thy (Symtab.defined tab) stmts; |
|
284 |
in Nbe_Functions.change thy (fold Symtab.update compiled) end; |
|
285 |
val nbe_tab = Nbe_Functions.get thy; |
|
24347 | 286 |
val stmtss = rev (Graph.strong_conn code) |
287 |
|> (map o map_filter) (fn name => if Symtab.defined nbe_tab name |
|
288 |
then NONE |
|
289 |
else SOME ((name, Graph.get_node code name), Graph.imm_succs code name)) |
|
290 |
|> filter_out null |
|
24155 | 291 |
in fold compile' stmtss nbe_tab end; |
292 |
||
293 |
(* re-conversion *) |
|
294 |
||
295 |
fun term_of_univ thy t = |
|
296 |
let |
|
297 |
fun of_apps bounds (t, ts) = |
|
298 |
fold_map (of_univ bounds) ts |
|
299 |
#>> (fn ts' => list_comb (t, rev ts')) |
|
300 |
and of_univ bounds (Const (name, ts)) typidx = |
|
301 |
let |
|
24423
ae9cd0e92423
overloaded definitions accompanied by explicit constants
haftmann
parents:
24381
diff
changeset
|
302 |
val SOME c = CodeName.const_rev thy name; |
ae9cd0e92423
overloaded definitions accompanied by explicit constants
haftmann
parents:
24381
diff
changeset
|
303 |
val T = Code.default_typ thy c; |
24155 | 304 |
val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T; |
305 |
val typidx' = typidx + maxidx_of_typ T' + 1; |
|
306 |
in of_apps bounds (Term.Const (c, T'), ts) typidx' end |
|
307 |
| of_univ bounds (Free (name, ts)) typidx = |
|
308 |
of_apps bounds (Term.Free (name, dummyT), ts) typidx |
|
309 |
| of_univ bounds (BVar (name, ts)) typidx = |
|
310 |
of_apps bounds (Bound (bounds - name - 1), ts) typidx |
|
311 |
| of_univ bounds (t as Abs _) typidx = |
|
312 |
typidx |
|
313 |
|> of_univ (bounds + 1) (app t (BVar (bounds, []))) |
|
314 |
|-> (fn t' => pair (Term.Abs ("u", dummyT, t'))) |
|
315 |
in of_univ 0 t 0 |> fst end; |
|
316 |
||
317 |
(* evaluation with type reconstruction *) |
|
318 |
||
24381 | 319 |
fun eval thy code t vs_ty_t deps = |
24155 | 320 |
let |
24347 | 321 |
val ty = type_of t; |
24155 | 322 |
fun subst_Frees [] = I |
323 |
| subst_Frees inst = |
|
324 |
Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s) |
|
325 |
| t => t); |
|
326 |
val anno_vars = |
|
327 |
subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t [])) |
|
328 |
#> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t [])) |
|
24347 | 329 |
fun constrain t = |
24493
d4380e9b287b
replaced ProofContext.infer_types by general Syntax.check_terms;
wenzelm
parents:
24423
diff
changeset
|
330 |
singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain t ty); |
24155 | 331 |
fun check_tvars t = if null (Term.term_tvars t) then t else |
332 |
error ("Illegal schematic type variables in normalized term: " |
|
333 |
^ setmp show_types true (Sign.string_of_term thy) t); |
|
334 |
in |
|
24381 | 335 |
(vs_ty_t, deps) |
24155 | 336 |
|> eval_term thy (Symtab.defined (ensure_funs thy code)) |
337 |
|> term_of_univ thy |
|
338 |
|> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t) |
|
339 |
|> anno_vars |
|
340 |
|> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t) |
|
341 |
|> constrain |
|
342 |
|> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t) |
|
343 |
|> check_tvars |
|
24381 | 344 |
|> tracing (fn _ => "---\n") |
24155 | 345 |
end; |
346 |
||
347 |
(* evaluation oracle *) |
|
348 |
||
24381 | 349 |
exception Normalization of CodeThingol.code * term |
350 |
* (CodeThingol.typscheme * CodeThingol.iterm) * string list; |
|
24155 | 351 |
|
24381 | 352 |
fun normalization_oracle (thy, Normalization (code, t, vs_ty_t, deps)) = |
353 |
Logic.mk_equals (t, eval thy code t vs_ty_t deps); |
|
24155 | 354 |
|
24381 | 355 |
fun normalization_invoke thy code t vs_ty_t deps = |
356 |
Thm.invoke_oracle_i thy "HOL.normalization" (thy, Normalization (code, t, vs_ty_t, deps)); |
|
24283 | 357 |
(*FIXME get rid of hardwired theory name*) |
24155 | 358 |
|
359 |
fun normalization_conv ct = |
|
360 |
let |
|
361 |
val thy = Thm.theory_of_cterm ct; |
|
24381 | 362 |
fun conv code vs_ty_t deps ct = |
24155 | 363 |
let |
364 |
val t = Thm.term_of ct; |
|
24381 | 365 |
in normalization_invoke thy code t vs_ty_t deps end; |
24219 | 366 |
in CodePackage.eval_conv thy conv ct end; |
24155 | 367 |
|
368 |
(* evaluation command *) |
|
369 |
||
370 |
fun norm_print_term ctxt modes t = |
|
371 |
let |
|
372 |
val thy = ProofContext.theory_of ctxt; |
|
373 |
val ct = Thm.cterm_of thy t; |
|
374 |
val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct; |
|
375 |
val ty = Term.type_of t'; |
|
24612 | 376 |
val p = Library.setmp print_mode (modes @ print_mode_value ()) (fn () => |
24155 | 377 |
Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk, |
378 |
Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) (); |
|
379 |
in Pretty.writeln p end; |
|
380 |
||
381 |
||
382 |
(** Isar setup **) |
|
383 |
||
24508
c8b82fec6447
replaced ProofContext.read_term/prop by general Syntax.read_term/prop;
wenzelm
parents:
24493
diff
changeset
|
384 |
fun norm_print_term_cmd (modes, s) state = |
24155 | 385 |
let val ctxt = Toplevel.context_of state |
24508
c8b82fec6447
replaced ProofContext.read_term/prop by general Syntax.read_term/prop;
wenzelm
parents:
24493
diff
changeset
|
386 |
in norm_print_term ctxt modes (Syntax.read_term ctxt s) end; |
24155 | 387 |
|
388 |
val setup = Theory.add_oracle ("normalization", normalization_oracle) |
|
389 |
||
390 |
local structure P = OuterParse and K = OuterKeyword in

(*optional non-empty list of print mode names in parentheses; defaults to no modes*)
val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];

(*diagnostic command "normal_form": the argument is handed to
  norm_print_term_cmd, which reads it as a term, hence it is parsed with the
  term parser rather than the type parser*)
val nbeP =
  OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
    (opt_modes -- P.term >> (Toplevel.keep o norm_print_term_cmd));

val _ = OuterSyntax.add_parsers [nbeP];

end;
|
401 |
||
402 |
end; |