|
(* Instantiate the schematic variables of the cterm ct according to inst
   (a pair of type and term instantiation lists, as taken by Thm.instantiate).
   This is done by instantiating the reflexivity theorem ct == ct and reading
   off its left-hand side. *)
fun inst_cterm inst ct =
  let val refl_inst = Thm.instantiate inst (reflexive ct)
  in fst (Drule.dest_equals (Thm.cprop_of refl_inst)) end;
|
(* Instantiate only the schematic type variables of a cterm; the term
   instantiation passed to inst_cterm is empty. *)
fun tyinst_cterm tyinst ct = inst_cterm (tyinst, []) ct;
|
4 |
|
(* Debugging accumulator: SimprocsCodegen.lookup adds every term for which it
   finds no matching entry. *)
val bla : term list ref = ref [];
|
6 |
|
7 (******************************************************) |
|
8 (* Code generator for equational proofs *) |
|
9 (******************************************************) |
|
(* Convert a theorem whose proposition has the shape  _ $ (_ $ lhs $ rhs)
   (presumably Trueprop (lhs = rhs)) into the meta-equation lhs == rhs.
   It does this by explicitly instantiating eq_reflection with lhs, rhs and
   the type of lhs, then discharging its premise with the given theorem. *)
fun my_mk_meta_eq thm =
  let
    val (_, eqn) = Thm.dest_comb (cprop_of thm);          (* strip the outer constant *)
    val (eq_with_lhs, r) = Thm.dest_comb eqn;
    val l = snd (Thm.dest_comb eq_with_lhs);
    val rule = Drule.instantiate' [Some (ctyp_of_term l)] [Some l, Some r] eq_reflection
  in
    Thm.implies_elim rule thm
  end;
|
18 |
|
(* SimprocsCodegen: generates ML source text for simplification functions
   from groups of equational theorems and loads it with use_text.
   - The generated code takes cterms apart with Thm.dest_comb, Thm.dest_abs
     and Thm.dest_ctyp.
   - It instantiates the stored theorems with Thm.instantiate or
     Drule.instantiate.
   - It chains the results with Thm.transitive and Thm.combination. *)
structure SimprocsCodegen =
struct

(* Theorems referenced by the generated code.  use_simprocs_code stores them
   here, and the generated code reads them back from
   !SimprocsCodegen.simp_thms. *)
val simp_thms = ref ([] : thm list);

(* Enclose pretty items in parentheses when b is true, otherwise just block them. *)
fun parens b = if b then Pretty.enclose "(" ")" else Pretty.block;

(* Pretty-print the ML declaration "val <xs> = <ps>;".  f wraps the
   comma-separated names; its boolean argument is true when more than one
   name is bound. *)
fun gen_mk_val f xs ps = Pretty.block ([Pretty.str "val ",
  f (length xs > 1) (flat
    (separate [Pretty.str ",", Pretty.brk 1] (map (single o Pretty.str) xs))),
  Pretty.str " =", Pretty.brk 1] @ ps @ [Pretty.str ";"]);

(* "val x = ...;" or, for several names, a tuple pattern "val (x, y) = ...;" *)
val mk_val = gen_mk_val parens;
(* List pattern, always bracketed: "val [x, y] = ...;" *)
val mk_vall = gen_mk_val (K (Pretty.enclose "[" "]"));

(* Append a prime to names that are reserved ML identifiers. *)
fun rename s = if s mem ThmDatabase.ml_reserved then s ^ "'" else s;

(* Suggested ML identifier for the cterm corresponding to term t. *)
fun mk_decomp_name (Var ((s, i), _)) = rename (if i=0 then s else s ^ string_of_int i)
  | mk_decomp_name (Const (s, _)) = rename (Codegen.mk_id (Sign.base_name s))
  | mk_decomp_name _ = "ct";

(* Generate ML code that takes apart the cterm named v, whose term is
   expected to have the shape of pattern t.
     cn : whether constants are recorded as bindings (Vars always are)
     vs : identifiers already in use
     bs : (term, ML name) bindings collected so far
     ps : ML declarations generated so far
   A term already in bs is not decomposed again.  For applications, the
   generated dest_comb is discarded when no new binding results. *)
fun decomp_term_code cn ((vs, bs, ps), (v, t)) =
  if exists (equal t o fst) bs then (vs, bs, ps)
  else (case t of
      Var _ => (vs, bs @ [(t, v)], ps)
    | Const _ => (vs, if cn then bs @ [(t, v)] else bs, ps)
    | Bound _ => (vs, bs, ps)
    | Abs (s, T, t) =>
        let
          val v1 = variant vs s;
          val v2 = variant (v1 :: vs) (mk_decomp_name t)
        in
          (* the bound variable is recorded as Free (s, T), bound to v1 *)
          decomp_term_code cn ((v1 :: v2 :: vs,
            bs @ [(Free (s, T), v1)],
            ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_abs", Pretty.brk 1,
              Pretty.str "None", Pretty.brk 1, Pretty.str v]]), (v2, t))
        end
    | t $ u =>
        let
          val v1 = variant vs (mk_decomp_name t);
          val v2 = variant (v1 :: vs) (mk_decomp_name u);
          val (vs', bs', ps') = decomp_term_code cn ((v1 :: v2 :: vs, bs,
            ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_comb", Pretty.brk 1,
              Pretty.str v]]), (v1, t));
          val (vs'', bs'', ps'') = decomp_term_code cn ((vs', bs', ps'), (v2, u))
        in
          (* nothing new was bound: the dest_comb is useless, drop it *)
          if bs'' = bs then (vs, bs, ps) else (vs'', bs'', ps'')
        end);

(* Drop the first character of a string; used on type variable names,
   presumably to remove their leading quote. *)
val strip_tv = implode o tl o explode;

(* Suggested ML identifier for the ctyp corresponding to type T. *)
fun mk_decomp_tname (TVar ((s, i), _)) =
      strip_tv ((if i=0 then s else s ^ string_of_int i) ^ "T")
  | mk_decomp_tname (Type (s, _)) = Codegen.mk_id (Sign.base_name s) ^ "T"
  | mk_decomp_tname _ = "cT";

(* Type analogue of decomp_term_code.  Generate code taking apart the ctyp
   named v (via Thm.dest_ctyp) so that every schematic type variable is bound
   to an ML name.  bs holds (indexname, ML name) pairs.  A type constructor
   whose arguments yield no new binding produces no code. *)
fun decomp_type_code ((vs, bs, ps), (v, TVar (ixn, _))) =
      if exists (equal ixn o fst) bs then (vs, bs, ps)
      else (vs, bs @ [(ixn, v)], ps)
  | decomp_type_code ((vs, bs, ps), (v, Type (_, Ts))) =
      let
        val vs' = variantlist (map mk_decomp_tname Ts, vs);
        val (vs'', bs', ps') =
          foldl decomp_type_code ((vs @ vs', bs, ps @
            [mk_vall vs' [Pretty.str "Thm.dest_ctyp", Pretty.brk 1,
              Pretty.str v]]), vs' ~~ Ts)
      in
        if bs' = bs then (vs, bs, ps) else (vs'', bs', ps')
      end
  | decomp_type_code (x, (_, TFree _)) =
      (* fixed type variables need no binding; without this clause the
         function raised Match on any type containing a TFree *)
      x;

(* Bind a fresh variant of name s to the code "dest v", then decompose x
   from it.  The whole binding is dropped when decomposition adds nothing
   to bs. *)
fun gen_mk_bindings s dest decomp ((vs, bs, ps), (v, x)) =
  let
    val s' = variant vs s;
    val (vs', bs', ps') = decomp ((s' :: vs, bs, ps @
      [mk_val [s'] (dest v)]), (s', x))
  in
    if bs' = bs then (vs, bs, ps) else (vs', bs', ps')
  end;

(* ct = cprop_of <thm>, then decompose the proposition (constants included). *)
val mk_term_bindings = gen_mk_bindings "ct"
  (fn s => [Pretty.str "cprop_of", Pretty.brk 1, Pretty.str s])
  (decomp_term_code true);

(* cT = Thm.ctyp_of_term <cterm>, then decompose its type. *)
val mk_type_bindings = gen_mk_bindings "cT"
  (fn s => [Pretty.str "Thm.ctyp_of_term", Pretty.brk 1, Pretty.str s])
  decomp_type_code;

(* ML pattern over terms matching the shape of t.
   - Constants become Const ("name", _).
   - Applications become $-patterns, parenthesised when b is true.
   - Anything else becomes _. *)
fun pretty_pattern b (Const (s, _)) = Pretty.block [Pretty.str "Const",
      Pretty.brk 1, Pretty.str ("(\"" ^ s ^ "\", _)")]
  | pretty_pattern b (t as _ $ _) = parens b
      (flat (separate [Pretty.str " $", Pretty.brk 1]
        (map (single o pretty_pattern true) (op :: (strip_comb t)))))
  | pretty_pattern b _ = Pretty.str "_";

(* Constants (with their types) occurring in t, without duplicates. *)
fun term_consts' t = foldl_aterms
  (fn (cs, c as Const _) => c ins cs | (cs, _) => cs) ([], t);

(* Left-nested applications of the binary ML function s: s (s p q1) q2 ...
   Inner applications are always parenthesised, the outermost only if b. *)
fun mk_apps s b p [] = p
  | mk_apps s b p (q :: qs) =
      mk_apps s b (parens (b orelse not (null qs))
        [Pretty.str s, Pretty.brk 1, p, Pretty.brk 1, q]) qs;

(* "val <eq> = Thm.reflexive <ct>;" *)
fun mk_refleq eq ct = mk_val [eq] [Pretty.str ("Thm.reflexive " ^ ct)];

(* One instantiation-list entry: (("s", i), s') *)
fun mk_tyinst ((s, i), s') =
  Pretty.block [Pretty.str ("((" ^ quote s ^ ","), Pretty.brk 1,
    Pretty.str (string_of_int i ^ "),"), Pretty.brk 1,
    Pretty.str (s' ^ ")")];

(* Code for the cterm named s with the type variables of t instantiated
   through tyinst_cterm, using the ctyp names in ty_bs.  Just s if t has no
   type variables. *)
fun inst_ty b ty_bs t s = (case term_tvars t of
    [] => Pretty.str s
  | Ts => parens b [Pretty.str "tyinst_cterm", Pretty.brk 1,
      Pretty.list "[" "]" (map (fn (ixn, _) => mk_tyinst
        (ixn, the (assoc (ty_bs, ixn)))) Ts),
      Pretty.brk 1, Pretty.str s]);

(* Code rebuilding the cterm of t from already-bound cterms.
     ts   : atomic terms -> ML names
     xs   : ML names of the enclosing abstractions' variables (Bound i -> nth)
     vals : bindings; extended with type-instantiated copies of polymorphic
            constants and abstraction variables *)
fun mk_cterm_code b ty_bs ts xs (vals, t $ u) =
      let
        val (vals', p1) = mk_cterm_code true ty_bs ts xs (vals, t);
        val (vals'', p2) = mk_cterm_code true ty_bs ts xs (vals', u)
      in
        (vals'', parens b [Pretty.str "Thm.capply", Pretty.brk 1,
          p1, Pretty.brk 1, p2])
      end
  | mk_cterm_code b ty_bs ts xs (vals, Abs (s, T, t)) =
      let
        val u = Free (s, T);
        val Some s' = assoc (ts, u);
        val p = Pretty.str s';
        val (vals', p') = mk_cterm_code true ty_bs ts (p :: xs)
          (if null (typ_tvars T) then vals
           else vals @ [(u, (("", s'), [mk_val [s'] [inst_ty true ty_bs u s']]))], t)
      in (vals',
        parens b [Pretty.str "Thm.cabs", Pretty.brk 1, p, Pretty.brk 1, p'])
      end
  | mk_cterm_code b ty_bs ts xs (vals, Bound i) = (vals, nth_elem (i, xs))
  | mk_cterm_code b ty_bs ts xs (vals, t) = (case assoc (vals, t) of
        None =>
          let val Some s = assoc (ts, t)
          in (if is_Const t andalso not (null (term_tvars t)) then
                vals @ [(t, (("", s), [mk_val [s] [inst_ty true ty_bs t s]]))]
              else vals, Pretty.str s)
          end
      | Some ((_, s), _) => (vals, Pretty.str s));

(* Table from each datatype's case combinator name (head of the lhs of its
   first case rewrite) to its case rewrites as meta-equations. *)
fun get_cases sg =
  Symtab.foldl (fn (tab, (k, {case_rewrites, ...})) => Symtab.update_new
    ((fst (dest_Const (head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop
      (prop_of (hd case_rewrites))))))), map my_mk_meta_eq case_rewrites), tab))
    (Symtab.empty, DatatypePackage.get_datatypes_sg sg);

(* For a case rewrite  case_c f1 ... fn (C x1 ... xk) == ...  return
   ([f1, ..., fn], "C", [x1, ..., xk], C x1 ... xk). *)
fun decomp_case th =
  let
    val (lhs, _) = Logic.dest_equals (prop_of th);
    val (f, ts) = strip_comb lhs;
    val (us, u) = split_last ts;
    val (Const (s, _), vs) = strip_comb u
  in (us, s, vs, u) end;

(* Rename the Vars of t whose names clash with vs; returns the extended name
   list and the renamed term.  NOTE: this shadows the string-level rename
   above for everything defined after this point, and is not called within
   this structure. *)
fun rename vs t =
  let
    fun mk_subst ((vs, subs), Var ((s, i), T)) =
      let val s' = variant vs s
      in if s = s' then (vs, subs)
        else (s' :: vs, ((s, i), Var ((s', i), T)) :: subs)
      end;
    val (vs', subs) = foldl mk_subst ((vs, []), term_vars t)
  in (vs', subst_Vars subs t) end;

(* True iff t equals u after instantiating u's type variables by matching
   the type of u against the type of t. *)
fun is_instance sg t u = t = subst_TVars_Vartab
  (Type.typ_match (Sign.tsig_of sg) (Vartab.empty,
    (fastype_of u, fastype_of t))) u handle Type.TYPE_MATCH => false;

(*
fun lookup sg fs t = apsome snd (Library.find_first
  (is_instance sg t o fst) fs);
*)

(* Entry of fs whose key term t is a type instance of.  As a debugging side
   effect, terms without an entry are added to the global ref bla. *)
fun lookup sg fs t = (case Library.find_first (is_instance sg t o fst) fs of
    None => (bla := (t ins !bla); None)
  | Some (_, x) => Some x);

(* True iff no constant in t has an entry in fs (t is uninterpreted). *)
fun unint sg fs t = forall (is_none o lookup sg fs) (term_consts' t);

(* "<s> <xs> in <ys> end" with indentation i; s is "let" or "local". *)
fun mk_let s i xs ys =
  Pretty.blk (0, [Pretty.blk (i, separate Pretty.fbrk (Pretty.str s :: xs)),
    Pretty.fbrk,
    Pretty.blk (i, ([Pretty.str "in", Pretty.fbrk] @ ys)),
    Pretty.fbrk, Pretty.str "end"]);

(*****************************************************************************)
(* Generate bindings for simplifying term t                                  *)
(* mkeq:     whether to generate a reflexivity theorem for uninterpreted     *)
(*           terms                                                            *)
(* case_tab: case combinators -> case rewrites (see get_cases)               *)
(* fs:       interpreted functions, (term, (ML name, _, _)) entries          *)
(* ts:       atomic terms -> ML names of their cterms                        *)
(* ty_bs:    type variables -> ML names of their ctyps                       *)
(* thm_bs:   terms -> ML names of cterms bound from the stored theorems      *)
(* vs:       used identifiers                                                *)
(* vals:     list of entries (t, ((eq, ct), ps)) where                       *)
(*             eq: name of equational theorem ("" if none)                    *)
(*             ct: name of simplified cterm                                  *)
(*             ps: ML code for creating the above two items                  *)
(* Result:   ((vs', vals'), (eq, ct)) for t                                  *)
(*****************************************************************************)

fun mk_simpl_code sg case_tab mkeq fs ts ty_bs thm_bs ((vs, vals), t) =
  (case assoc (vals, t) of
    Some ((eq, ct), ps) => (* binding already generated *)
      if mkeq andalso eq="" then
        let val eq' = variant vs "eq"
        in ((eq' :: vs, overwrite (vals,
          (t, ((eq', ct), ps @ [mk_refleq eq' ct])))), (eq', ct))
        end
      else ((vs, vals), (eq, ct))
  | None => (case assoc (ts, t) of
      Some v => (* atomic term *)
        let val xs = if not (null (term_tvars t)) andalso is_Const t then
            [mk_val [v] [inst_ty false ty_bs t v]] else []
        in
          if mkeq then
            let val eq = variant vs "eq"
            in ((eq :: vs, vals @
              [(t, ((eq, v), xs @ [mk_refleq eq v]))]), (eq, v))
            end
          else ((vs, if null xs then vals else vals @
            [(t, (("", v), xs))]), ("", v))
        end
    | None => (* complex term *)
        let val (f as Const (cname, _), us) = strip_comb t
        in case Symtab.lookup (case_tab, cname) of
            Some cases => (* case expression *)
              let
                val (us', u) = split_last us;
                val b = unint sg fs u;
                (* simplify the scrutinee u first *)
                val ((vs1, vals1), (eq, ct)) =
                  mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs ((vs, vals), u);
                val xs = variantlist (replicate (length us') "f", vs1);
                val (vals2, ps) = foldl_map
                  (mk_cterm_code false ty_bs ts []) (vals1, us');
                val fvals = map (fn (x, p) => mk_val [x] [p]) (xs ~~ ps);
                val uT = fastype_of u;
                val (us'', _, _, u') = decomp_case (hd cases);
                val (vs2, ty_bs', ty_vals) = mk_type_bindings
                  (mk_type_bindings ((vs1 @ xs, [], []),
                    (hd xs, fastype_of (hd us''))), (ct, fastype_of u'));
                val insts1 = map mk_tyinst ty_bs';
                (* bindings before index i are shared by all case branches *)
                val i = length vals2;

                (* one "pattern => let ... in (ct, eq) end" branch per
                   constructor case rewrite *)
                fun mk_case_code ((vs, vals), (f, (name, eqn))) =
                  let
                    val (fvs, cname, cvs, _) = decomp_case eqn;
                    val Ts = binder_types (fastype_of f);
                    val ys = variantlist (map (fst o fst o dest_Var) cvs, vs);
                    val cvs' = map Var (map (rpair 0) ys ~~ Ts);
                    val rs = cvs' ~~ cvs;
                    val lhs = list_comb (Const (cname, Ts ---> uT), cvs');
                    val rhs = foldl betapply (f, cvs');
                    val (vs', tm_bs, tm_vals) = decomp_term_code false
                      ((vs @ ys, [], []), (ct, lhs));
                    val ((vs'', all_vals), (eq', ct')) = mk_simpl_code sg case_tab
                      false fs (tm_bs @ ts) ty_bs thm_bs ((vs', vals), rhs);
                    val (old_vals, eq_vals) = splitAt (i, all_vals);
                    val vs''' = vs @ filter (fn v => exists
                      (fn (_, ((v', _), _)) => v = v') old_vals) (vs'' \\ vs');
                    val insts2 = map (fn (t, s) => Pretty.block [Pretty.str "(",
                      inst_ty false ty_bs' t (the (assoc (thm_bs, t))), Pretty.str ",",
                      Pretty.brk 1, Pretty.str (s ^ ")")]) ((fvs ~~ xs) @
                        (map (fn (v, s) => (the (assoc (rs, v)), s)) tm_bs));
                    val eq'' = if null insts1 andalso null insts2 then Pretty.str name
                      else parens (eq' <> "") [Pretty.str
                        (if null cvs then "Thm.instantiate" else "Drule.instantiate"),
                        Pretty.brk 1, Pretty.str "(", Pretty.list "[" "]" insts1,
                        Pretty.str ",", Pretty.brk 1, Pretty.list "[" "]" insts2,
                        Pretty.str ")", Pretty.brk 1, Pretty.str name];
                    val eq''' = if eq' = "" then eq'' else
                      Pretty.block [Pretty.str "Thm.transitive", Pretty.brk 1,
                        eq'', Pretty.brk 1, Pretty.str eq']
                  in
                    ((vs''', old_vals), Pretty.block [pretty_pattern false lhs,
                      Pretty.str " =>",
                      Pretty.brk 1, mk_let "let" 2 (tm_vals @ flat (map (snd o snd) eq_vals))
                        [Pretty.str ("(" ^ ct' ^ ","), Pretty.brk 1, eq''', Pretty.str ")"]])
                  end;

                val case_names = map (fn i => Sign.base_name cname ^ "_" ^
                  string_of_int i) (1 upto length cases);
                val ((vs3, vals3), case_ps) = foldl_map mk_case_code
                  ((vs2, vals2), us' ~~ (case_names ~~ cases));
                val eq' = variant vs3 "eq";
                val ct' = variant (eq' :: vs3) "ct";
                val eq'' = variant (eq' :: ct' :: vs3) "eq";
                val case_vals =
                  fvals @ ty_vals @
                  [mk_val [ct', eq'] ([Pretty.str "(case", Pretty.brk 1,
                    Pretty.str ("term_of " ^ ct ^ " of"), Pretty.brk 1] @
                    flat (separate [Pretty.brk 1, Pretty.str "| "]
                      (map single case_ps)) @ [Pretty.str ")"])]
              in
                if b then
                  ((eq' :: ct' :: vs3, vals3 @
                    [(t, ((eq', ct'), case_vals))]), (eq', ct'))
                else
                  (* the scrutinee was rewritten by eq: combine with it first *)
                  let val ((vs4, vals4), (_, ctcase)) = mk_simpl_code sg case_tab false
                    fs ts ty_bs thm_bs ((eq' :: eq'' :: ct' :: vs3, vals3), f)
                  in
                    ((vs4, vals4 @ [(t, ((eq'', ct'), case_vals @
                      [mk_val [eq''] [Pretty.str "Thm.transitive", Pretty.brk 1,
                        Pretty.str "(Thm.combination", Pretty.brk 1,
                        Pretty.str "(Thm.reflexive", Pretty.brk 1,
                        mk_apps "Thm.capply" true (Pretty.str ctcase)
                          (map Pretty.str xs),
                        Pretty.str ")", Pretty.brk 1, Pretty.str (eq ^ ")"),
                        Pretty.brk 1, Pretty.str eq']]))]), (eq'', ct'))
                  end
              end

          | None =>
              let
                val b = forall (unint sg fs) us;
                val (q, eqs) = foldl_map
                  (mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs) ((vs, vals), us);
                val ((vs', vals'), (eqf, ctf)) = if is_some (lookup sg fs f) andalso b
                  then (q, ("", ""))
                  else mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs (q, f);
                val ct = variant vs' "ct";
                val eq = variant (ct :: vs') "eq";
                val ctv = mk_val [ct] [mk_apps "Thm.capply" false
                  (Pretty.str ctf) (map (Pretty.str o snd) eqs)];
                fun combp b = mk_apps "Thm.combination" b
                  (Pretty.str eqf) (map (Pretty.str o fst) eqs)
              in
                case (lookup sg fs f, b) of
                  (None, true) => (* completely uninterpreted *)
                    if mkeq then ((ct :: eq :: vs', vals' @
                      [(t, ((eq, ct), [ctv, mk_refleq eq ct]))]), (eq, ct))
                    else ((ct :: vs', vals' @ [(t, (("", ct), [ctv]))]), ("", ct))
                | (None, false) => (* function uninterpreted *)
                    ((eq :: ct :: vs', vals' @
                      [(t, ((eq, ct), [ctv, mk_val [eq] [combp false]]))]), (eq, ct))
                | (Some (s, _, _), true) => (* arguments uninterpreted *)
                    ((eq :: ct :: vs', vals' @
                      [(t, ((eq, ct), [mk_val [ct, eq] (separate (Pretty.brk 1)
                        (Pretty.str s :: map (Pretty.str o snd) eqs))]))]), (eq, ct))
                | (Some (s, _, _), false) => (* function and arguments interpreted *)
                    let val eq' = variant (eq :: ct :: vs') "eq"
                    in ((eq' :: eq :: ct :: vs', vals' @ [(t, ((eq', ct),
                      [mk_val [ct, eq] (separate (Pretty.brk 1)
                        (Pretty.str s :: map (Pretty.str o snd) eqs)),
                       mk_val [eq'] [Pretty.str "Thm.transitive", Pretty.brk 1,
                         combp true, Pretty.brk 1, Pretty.str eq]]))]), (eq', ct))
                    end
              end
        end));

(* Left- and right-hand side of a meta-equation theorem. *)
fun lhs_of thm = fst (Logic.dest_equals (prop_of thm));
fun rhs_of thm = snd (Logic.dest_equals (prop_of thm));

(* Code for one group fs' of mutually recursive functions, as
   "local <theorem bindings> in fun ... and ... end".
     fs  : all interpreted functions (see mk_simpl_code)
     fs' : entries (head term, (fname, d, eqns)) of this group; when d is
           true the function gets a "handle Match => reflexive" fallback *)
fun mk_funs_code sg case_tab fs fs' =
  let
    (* case rewrites of all case combinators occurring in the equations *)
    val case_thms = mapfilter (fn s => (case Symtab.lookup (case_tab, s) of
        None => None
      | Some thms => Some (unsuffix "_case" (Sign.base_name s) ^ ".cases",
          map (fn i => Sign.base_name s ^ "_" ^ string_of_int i)
            (1 upto length thms) ~~ thms)))
      (foldr add_term_consts (map (prop_of o snd)
        (flat (map (#3 o snd) fs')), []));
    val case_vals = map (fn (s, cs) => mk_vall (map fst cs)
      [Pretty.str "map my_mk_meta_eq", Pretty.brk 1,
       Pretty.str ("(thms \"" ^ s ^ "\")")]) case_thms;
    val (vs, thm_bs, thm_vals) = foldl mk_term_bindings (([], [], []),
      flat (map (map (apsnd prop_of) o #3 o snd) fs') @
      map (apsnd prop_of) (flat (map snd case_thms)));

    fun mk_fun_code (prfx, (fname, d, eqns)) =
      let
        val (f, ts) = strip_comb (lhs_of (snd (hd eqns)));
        val args = variantlist (replicate (length ts) "ct", vs);
        val (vs', ty_bs, ty_vals) = foldl mk_type_bindings
          ((vs @ args, [], []), args ~~ map fastype_of ts);
        val insts1 = map mk_tyinst ty_bs;

        (* one match branch per defining equation *)
        fun mk_eqn_code (name, eqn) =
          let
            val (_, argts) = strip_comb (lhs_of eqn);
            val (vs'', tm_bs, tm_vals) = foldl (decomp_term_code false)
              ((vs', [], []), args ~~ argts);
            val ((_, eq_vals), (eq, ct)) = mk_simpl_code sg case_tab false fs
              (tm_bs @ filter_out (is_Var o fst) thm_bs) ty_bs thm_bs
              ((vs'', []), rhs_of eqn);
            val insts2 = map (fn (t, s) => Pretty.block [Pretty.str "(",
              inst_ty false ty_bs t (the (assoc (thm_bs, t))), Pretty.str ",", Pretty.brk 1,
              Pretty.str (s ^ ")")]) tm_bs;
            val eq' = if null insts1 andalso null insts2 then Pretty.str name
              else parens (eq <> "") [Pretty.str "Thm.instantiate",
                Pretty.brk 1, Pretty.str "(", Pretty.list "[" "]" insts1,
                Pretty.str ",", Pretty.brk 1, Pretty.list "[" "]" insts2,
                Pretty.str ")", Pretty.brk 1, Pretty.str name];
            val eq'' = if eq = "" then eq' else
              Pretty.block [Pretty.str "Thm.transitive", Pretty.brk 1,
                eq', Pretty.brk 1, Pretty.str eq]
          in
            Pretty.block [parens (length argts > 1)
                (Pretty.commas (map (pretty_pattern false) argts)),
              Pretty.str " =>",
              Pretty.brk 1, mk_let "let" 2 (ty_vals @ tm_vals @ flat (map (snd o snd) eq_vals))
                [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1, eq'', Pretty.str ")"]]
          end;

        (* optional fallback: rebuild the call and return it with reflexivity *)
        val default = if d then
            let
              val Some s = assoc (thm_bs, f);
              val ct = variant vs' "ct"
            in [Pretty.brk 1, Pretty.str "handle", Pretty.brk 1,
              Pretty.str "Match =>", Pretty.brk 1, mk_let "let" 2
                (ty_vals @ (if null (term_tvars f) then [] else
                  [mk_val [s] [inst_ty false ty_bs f s]]) @
                  [mk_val [ct] [mk_apps "Thm.capply" false (Pretty.str s)
                    (map Pretty.str args)]])
                [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1,
                  Pretty.str "Thm.reflexive", Pretty.brk 1, Pretty.str (ct ^ ")")]]
            end
          else []
      in
        ("and ", Pretty.block (separate (Pretty.brk 1)
          (Pretty.str (prfx ^ fname) :: map Pretty.str args) @
          [Pretty.str " =", Pretty.brk 1, Pretty.str "(case", Pretty.brk 1,
           Pretty.list "(" ")" (map (fn s => Pretty.str ("term_of " ^ s)) args),
           Pretty.str " of", Pretty.brk 1] @
          flat (separate [Pretty.brk 1, Pretty.str "| "]
            (map (single o mk_eqn_code) eqns)) @ [Pretty.str ")"] @ default))
      end;

    (* the first function is introduced with "fun", the rest with "and" *)
    val (_, decls) = foldl_map mk_fun_code ("fun ", map snd fs')
  in
    mk_let "local" 2 (case_vals @ thm_vals) (separate Pretty.fbrk decls)
  end;

(* Code for all equation groups eqns, given as (fname, d, [(thm name, thm)]).
   - A call graph is built with edges from each group to the groups whose
     head term matches a constant of its equations, plus an artificial root "".
   - Strongly connected components become separate "fun ... and ..."
     declarations.
   - Components are ordered by the reversed Graph.all_succs order from the
     root, presumably so callees come first -- TODO confirm. *)
fun mk_simprocs_code sg eqns =
  let
    val case_tab = get_cases sg;
    fun get_head th = head_of (fst (Logic.dest_equals (prop_of th)));
    fun attach_term (x as (_, _, (_, th) :: _)) = (get_head th, x);
    val eqns' = map attach_term eqns;
    fun mk_node (s, _, (_, th) :: _) = (s, get_head th);
    fun mk_edges (s, _, ths) = map (pair s) (distinct
      (mapfilter (fn t => apsome #1 (lookup sg eqns' t))
        (flat (map (term_consts' o prop_of o snd) ths))));
    val gr = foldr (uncurry Graph.add_edge)
      (map (pair "" o #1) eqns @ flat (map mk_edges eqns),
       foldr (uncurry Graph.new_node)
         (("", Bound 0) :: map mk_node eqns, Graph.empty));
    val keys = rev (Graph.all_succs gr [""] \ "");
    fun gr_ord (x :: _, y :: _) =
      int_ord (find_index (equal x) keys, find_index (equal y) keys);
    val scc = map (fn xs => filter (fn (_, (s, _, _)) => s mem xs) eqns')
      (sort gr_ord (Graph.strong_conn gr \ [""]));
  in
    flat (separate [Pretty.str ";", Pretty.fbrk, Pretty.str " ", Pretty.fbrk]
      (map (fn eqns'' => [mk_funs_code sg case_tab eqns' eqns'']) scc)) @
    [Pretty.str ";", Pretty.fbrk]
  end;

(* Generate the simplification code for eqns and evaluate it.
   - Every theorem is named simp_thm_<i>, numbered from 1 across all groups.
   - The theorems are stored in simp_thms; the generated code binds those
     names from !SimprocsCodegen.simp_thms.
   - The resulting text is passed to use_text. *)
fun use_simprocs_code sg eqns =
  let
    fun attach_name (i, x) = (i+1, ("simp_thm_" ^ string_of_int i, x));
    fun attach_names (i, (s, b, eqs)) =
      let val (i', eqs') = foldl_map attach_name (i, eqs)
      in (i', (s, b, eqs')) end;
    val (_, eqns') = foldl_map attach_names (1, eqns);
    val (names, thms) = split_list (flat (map #3 eqns'));
    val s = setmp print_mode [] Pretty.string_of
      (mk_let "local" 2 [mk_vall names [Pretty.str "!SimprocsCodegen.simp_thms"]]
        (mk_simprocs_code sg eqns'))
  in
    (simp_thms := thms; use_text Context.ml_output false s)
  end;

end;