(*  Title:      Pure/Tools/compute.ML
    ID:         $Id$
    Author:     Steven Obua
*)

(* Evaluation of terms by rewriting on an abstract machine.  A computer is
   built from a list of theorems; it can then evaluate certified terms,
   either to a plain term or to a theorem produced via the "compute" oracle. *)
signature COMPUTE = sig

  (* Backend that runs the compiled rewrite rules. *)
  datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML

  (* Compiled rule set together with the context it was built in. *)
  type computer

  (* Construction; Make is raised when a theorem cannot be turned into a
     rewrite rule. *)
  exception Make of string
  val make : machine -> theory -> thm list -> computer

  (* Evaluation; the (int -> string) argument supplies the name of the bound
     variable introduced at each abstraction depth of the result.  Compute
     is raised during proof-producing evaluation, e.g. for dangling sort
     hypotheses. *)
  exception Compute of string
  val compute : computer -> (int -> string) -> cterm -> term

  (* Proof-producing evaluation: a theorem equating the input term with the
     computed term; rewrite uses a default naming of bound variables. *)
  val rewrite_param : computer -> (int -> string) -> cterm -> thm
  val rewrite : computer -> cterm -> thm

  (* Context the computer was built in: its theory, the meta-hypotheses of
     its theorems, and the sort hypotheses that have no witness. *)
  val theory_of : computer -> theory
  val hyps_of : computer -> term list
  val shyps_of : computer -> sort list

  (* Pass the computer's compiled program to its backend's discard. *)
  val discard : computer -> unit

  (* Theory setup that registers the "compute" oracle. *)
  val setup : theory -> theory

end
|
|
|
29 |
|
|
23663
|
30 |
structure Compute :> COMPUTE = struct

datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML

(* Terms are mapped to integer codes.  An encoding is a bijection between the
   atoms (Var/Free/Const) occurring in rules and terms and machine constants. *)
structure Encode :>
sig
  type encoding
  val empty : encoding
  val insert : term -> encoding -> int * encoding
  val lookup_code : term -> encoding -> int option
  val lookup_term : int -> encoding -> term option
  val remove_code : int -> encoding -> encoding
  val remove_term : term -> encoding -> encoding
  val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a
end
=
struct

(* (next unused code, term -> code, code -> term); every operation below keeps
   the two tables inverse to each other. *)
type encoding = int * (int Termtab.table) * (term Inttab.table)

val empty = (0, Termtab.empty, Inttab.empty)

(* Code of t, allocating the next fresh code if t has not been seen yet. *)
fun insert t (e as (count, term2int, int2term)) =
  (case Termtab.lookup term2int t of
    NONE => (count, (count+1, Termtab.update_new (t, count) term2int, Inttab.update_new (count, t) int2term))
  | SOME code => (code, e))

fun lookup_code t (_, term2int, _) = Termtab.lookup term2int t

fun lookup_term c (_, _, int2term) = Inttab.lookup int2term c

fun remove_code c (e as (count, term2int, int2term)) =
  (case lookup_term c e of NONE => e | SOME t => (count, Termtab.delete t term2int, Inttab.delete c int2term))

fun remove_term t (e as (count, term2int, int2term)) =
  (case lookup_code t e of NONE => e | SOME c => (count, Termtab.delete t term2int, Inttab.delete c int2term))

fun fold f (_, term2int, _) = Termtab.fold f term2int

end

exception Make of string;
exception Compute of string;

local
  (* An atom becomes a machine constant whose code remembers the full typed
     atom, so that infer_types can restore it later. *)
  fun make_constant t encoding =
    let
      val (code, encoding) = Encode.insert t encoding
    in
      (encoding, AbstractMachine.Const code)
    end
in

(* Translate a term into an untyped machine term.  The encoding is threaded
   through and grows with every atom not encoded before; Bound indices are
   kept as machine variables. *)
fun remove_types encoding t =
  case t of
    Var _ => make_constant t encoding
  | Free _ => make_constant t encoding
  | Const _ => make_constant t encoding
  | Abs (_, _, body) =>
      let val (encoding, body') = remove_types encoding body in
        (encoding, AbstractMachine.Abs body')
      end
  | a $ b =>
      let
        val (encoding, a) = remove_types encoding a
        val (encoding, b) = remove_types encoding b
      in
        (encoding, AbstractMachine.App (a, b))
      end
  | Bound b => (encoding, AbstractMachine.Var b)

end

local
  fun type_of (Free (_, ty)) = ty
    | type_of (Const (_, ty)) = ty
    | type_of (Var (_, ty)) = ty
    | type_of _ = sys_error "infer_types: type_of error"
in

(* Turn a machine term back into a typed term of the expected type ty.
   Constants are decoded through encoding; an abstraction at depth level
   gets the name (naming level) and takes its domain from the expected
   function type. *)
fun infer_types naming encoding =
  let
    fun infer_term _ bounds _ (AbstractMachine.Var v) = (Bound v, List.nth (bounds, v))
      | infer_term _ bounds _ (AbstractMachine.Const code) =
          let
            val c = the (Encode.lookup_term code encoding)
          in
            (c, type_of c)
          end
      | infer_term level bounds _ (AbstractMachine.App (a, b)) =
          let
            val (a, aty) = infer_term level bounds NONE a
            val (adom, arange) =
              case aty of
                Type ("fun", [dom, range]) => (dom, range)
              | _ => sys_error "infer_types: function type expected"
            val (b, _) = infer_term level bounds (SOME adom) b
          in
            (a $ b, arange)
          end
      | infer_term level bounds (SOME (ty as Type ("fun", [dom, range]))) (AbstractMachine.Abs m) =
          let
            val (m, _) = infer_term (level+1) (dom::bounds) (SOME range) m
          in
            (Abs (naming level, dom, m), ty)
          end
      | infer_term _ _ _ (AbstractMachine.Abs _) =
          (* No expected type, or one that is not a function type; previously
             the latter case escaped as an uninformative Match. *)
          sys_error "infer_types: cannot infer type of abstraction"

    fun infer ty term =
      let
        val (term', _) = infer_term 0 [] (SOME ty) term
      in
        term'
      end
  in
    infer
  end

end

datatype prog =
    ProgBarras of AM_Interpreter.program
  | ProgBarrasC of AM_Compiler.program
  | ProgHaskell of AM_GHC.program
  | ProgSML of AM_SML.program

structure Sorttab = TableFun(type key = sort val ord = Term.sort_ord)

(* theory, term encoding, meta-hypotheses of the rules, sort hypotheses
   without witness, compiled program *)
datatype computer = Computer of theory_ref * Encode.encoding * term list * unit Sorttab.table * prog

(* The parts of a theorem that are needed to build a rule: hyps, shyps, prop. *)
datatype cthm = ComputeThm of term list * sort list * term

fun thm2cthm th =
  let
    val {hyps, prop, tpairs, shyps, ...} = Thm.rep_thm th
    val _ = if not (null tpairs) then raise Make "theorems may not contain tpairs" else ()
  in
    ComputeThm (hyps, shyps, prop)
  end

(* Build a computer for machine from raw_ths, transferred into thy.  Each
   theorem must be a meta-level equation lhs == rhs, optionally with
   meta-level equations as premises (guards); it becomes a rewrite rule whose
   pattern is taken from lhs.  Raises Make for theorems not of this form. *)
fun make machine thy raw_ths =
  let
    fun transfer (x:thm) = Thm.transfer thy x
    val ths = map (thm2cthm o Thm.strip_shyps o transfer) raw_ths

    (* Compile one theorem into a rule, threading the encoding and collecting
       its hyps and sort hyps into the accumulated tables. *)
    fun thm2rule (encoding, hyptable, shyptable) (ComputeThm (hyps, shyps, prop)) =
      let
        val hyptable = fold (fn h => Termtab.update (h, ())) hyps hyptable
        val shyptable = fold (fn sh => Sorttab.update (sh, ())) shyps shyptable
        val (prems, prop) = (Logic.strip_imp_prems prop, Logic.strip_imp_concl prop)
        val (a, b) = Logic.dest_equals prop
          handle TERM _ => raise (Make "theorems must be meta-level equations (with optional guards)")
        val a = Envir.eta_contract a
        val b = Envir.eta_contract b
        val prems = map Envir.eta_contract prems

        val (encoding, left) = remove_types encoding a
        val (encoding, right) = remove_types encoding b
        fun remove_types_of_guard encoding g =
          (let
             val (t1, t2) = Logic.dest_equals g
             val (encoding, t1) = remove_types encoding t1
             val (encoding, t2) = remove_types encoding t2
           in
             (encoding, AbstractMachine.Guard (t1, t2))
           end handle TERM _ => raise (Make "guards must be meta-level equations"))
        val (encoding, prems) =
          fold_rev (fn p => fn (encoding, ps) =>
              let val (e, p) = remove_types_of_guard encoding p in (e, p :: ps) end)
            prems (encoding, [])

        (* Build the pattern of a left-hand side.  Schematic variables become
           pattern variables, numbered in the order they are met (n is the
           next number, vars maps constant codes to numbers); a pattern
           variable may occur only once. *)
        fun make_pattern _ _ _ (AbstractMachine.Abs _) =
              raise (Make "no lambda abstractions allowed in pattern")
          | make_pattern _ _ _ (AbstractMachine.Var _) =
              raise (Make "no bound variables allowed in pattern")
          | make_pattern encoding n vars (AbstractMachine.Const code) =
              (case the (Encode.lookup_term code encoding) of
                Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar)
                          handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed"))
              | _ => (n, vars, AbstractMachine.PConst (code, [])))
          | make_pattern encoding n vars (AbstractMachine.App (a, b)) =
              let
                val (n, vars, pa) = make_pattern encoding n vars a
                val (n, vars, pb) = make_pattern encoding n vars b
              in
                case pa of
                  AbstractMachine.PVar =>
                    raise (Make "patterns may not start with a variable")
                | AbstractMachine.PConst (c, args) =>
                    (n, vars, AbstractMachine.PConst (c, args @ [pb]))
              end

        (* In principle, a check should be made here to see whether the (meta-)
           hyps contain any of the variables of the rule.  As it is, all
           variables of the rule are schematic and there are no schematic
           variables in meta-hyps, so this check can be left out. *)

        val (vcount, vars, pattern) = make_pattern encoding 0 Inttab.empty left
        val _ = (case pattern of
                   AbstractMachine.PVar =>
                     raise (Make "patterns may not start with a variable")
                 (* | AbstractMachine.PConst (_, []) =>
                     (print th; raise (Make "no parameter rewrite found"))*)
                 | _ => ())

        (* Replace the pattern variables on the right-hand side and in the
           guards by machine variables: pattern variable n becomes index
           vcount-n-1, shifted by the number of enclosing abstractions. *)
        fun rename _ _ (var as AbstractMachine.Var _) = var
          | rename level vars (c as AbstractMachine.Const code) =
              (case Inttab.lookup vars code of
                NONE => c
              | SOME n => AbstractMachine.Var (vcount-n-1+level))
          | rename level vars (AbstractMachine.App (a, b)) =
              AbstractMachine.App (rename level vars a, rename level vars b)
          | rename level vars (AbstractMachine.Abs m) =
              AbstractMachine.Abs (rename (level+1) vars m)

        fun rename_guard (AbstractMachine.Guard (a, b)) =
              AbstractMachine.Guard (rename 0 vars a, rename 0 vars b)
      in
        ((encoding, hyptable, shyptable), (map rename_guard prems, pattern, rename 0 vars right))
      end

    val ((encoding, hyptable, shyptable), rules) =
      fold_rev (fn th => fn (tables, rules) =>
          let val (tables, rule) = thm2rule tables th
          in (tables, rule :: rules) end)
        ths ((Encode.empty, Termtab.empty, Sorttab.empty), [])

    val prog =
      case machine of
        BARRAS => ProgBarras (AM_Interpreter.compile rules)
      | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile rules)
      | HASKELL => ProgHaskell (AM_GHC.compile rules)
      | SML => ProgSML (AM_SML.compile rules)

    (* Sort hypotheses that have a witness in thy are dropped; only the
       remaining ones are recorded in the computer. *)
    fun has_witness s = not (null (Sign.witness_sorts thy [] [s]))
    val shyptable = fold Sorttab.delete (filter has_witness (Sorttab.keys shyptable)) shyptable
  in
    Computer (Theory.self_ref thy, encoding, Termtab.keys hyptable, shyptable, prog)
  end

(* Run a phase of the computation; the label s is unused while timing output
   is switched off. *)
fun report s f = f () (*writeln s; timeit f*)

(* Evaluate ct on the computer's machine and turn the result back into a
   typed term of the same type as ct. *)
fun compute (Computer (rthy, encoding, _, _, prog)) naming ct =
  let
    fun run (ProgBarras p) = AM_Interpreter.run p
      | run (ProgBarrasC p) = AM_Compiler.run p
      | run (ProgHaskell p) = AM_GHC.run p
      | run (ProgSML p) = AM_SML.run p
    val {t = t, T = ty, thy = ctthy, ...} = rep_cterm ct
    (* Evaluated only for its compatibility check of the two theories; the
       merged theory itself is not needed. *)
    val _ = Theory.merge (Theory.deref rthy, ctthy)
    val (encoding, t) = report "remove_types" (fn () => remove_types encoding t)
    val t = report "run" (fn () => run prog t)
    val t = report "infer_types" (fn () => infer_types naming encoding ty t)
  in
    t
  end

fun discard (Computer (_, _, _, _, prog)) =
  (case prog of
    ProgBarras p => AM_Interpreter.discard p
  | ProgBarrasC p => AM_Compiler.discard p
  | ProgHaskell p => AM_GHC.discard p
  | ProgSML p => AM_SML.discard p)

fun theory_of (Computer (rthy, _, _, _, _)) = Theory.deref rthy
fun hyps_of (Computer (_, _, hyps, _, _)) = hyps
fun shyps_of (Computer (_, _, _, shyptable, _)) = Sorttab.keys (shyptable)
fun shyptab_of (Computer (_, _, _, shyptable, _)) = shyptable

fun default_naming i = "v_" ^ Int.toString i

(* Oracle argument: computer, naming of bound variables, term to evaluate. *)
exception Param of computer * (int -> string) * cterm;

(* Obtain "hyps ==> ct == t'" from the oracle and discharge each hypothesis
   with its assumption, so the resulting theorem carries hyps_of r as
   hypotheses. *)
fun rewrite_param r n ct =
  let
    val thy = theory_of_cterm ct
    val th = report "rewrite_param"
      (fn () => invoke_oracle_i thy "Compute_Oracle.compute" (thy, Param (r, n, ct)))
    val hyps = map (fn h => assume (cterm_of thy h)) (hyps_of r)
  in
    fold (fn h => fn p => implies_elim p h) hyps th
  end

fun rewrite r ct = rewrite_param r default_naming ct

(* theory setup *)

(* The oracle: proposition "hyps ==> ct == t'" for the computed t'.  Raises
   Compute if a recorded sort hypothesis does not occur in the sorts of the
   equation or the hypotheses. *)
fun compute_oracle (thy, Param (r, naming, ct)) =
      let
        val _ = Theory.assert_super (theory_of r) thy
        val t' = compute r naming ct
        val eq = Logic.mk_equals (term_of ct, t')
        val hyps = hyps_of r
        val shyptab = shyptab_of r
        fun delete s shyptab = Sorttab.delete s shyptab handle Sorttab.UNDEF _ => shyptab
        fun delete_term t shyptab = fold delete (Sorts.insert_term t []) shyptab
        val shyps = if Sorttab.is_empty shyptab then [] else Sorttab.keys (fold delete_term (eq::hyps) shyptab)
        val _ = if not (null shyps) then raise Compute ("dangling sort hypotheses: "^(makestring shyps)) else ()
      in
        fold_rev (fn hyp => fn p => Logic.mk_implies (hyp, p)) hyps eq
      end
  | compute_oracle _ = raise Match

val setup = (fn thy => (writeln "install oracle"; Theory.add_oracle ("compute", compute_oracle) thy))

end
|
|
23663
|
362 |
|