|
(*  Title:      Pure/HOL/inductive_codegen.ML
    ID:         $Id$
    Author:     Stefan Berghofer
    Copyright   2000  TU Muenchen

Code generator for inductive predicates.
*)
|
8 |
|
(* Interface of the code generator for inductive predicates: the only
   export is the list of theory setup functions to be run when the
   theory containing this module is set up. *)
signature INDUCTIVE_CODEGEN =
sig
  val setup : (theory -> theory) list
end;
|
13 |
|
(* Code generator for predicates (sets) introduced by the inductive package.
   For each inductive definition it infers the admissible modes (which
   argument positions may be inputs) and emits one ML function per mode
   that enumerates solutions as a sequence. *)
structure InductiveCodegen : INDUCTIVE_CODEGEN =
struct

open Codegen;

(* Not raised in this module: used as a dynamically typed value (type exn)
   to store the inferred modes and product factors of a definition in its
   code generator graph node (written by mk_ind_def, read by get_modes). *)
exception Modes of (string * int list list) list * (string * int list list) list;

(* A premise of an introduction rule:
   - Prem (name, args, params): membership in the predicate `name`, with
     the element tuple split into its components and the predicate's
     parameters; the pseudo-predicate "eq" is used for equations t = u;
   - Sidecond t: any other premise, compiled as a boolean test. *)
datatype indprem = Prem of string * term list * term list
                 | Sidecond of term;

(* Positions of all Pair constructors in a term, each given as a reversed
   path of 1s (first component) and 2s (second component); p is the path
   of the term itself. *)
fun prod_factors p (Const ("Pair", _) $ t $ u) =
      p :: prod_factors (1 :: p) t @ prod_factors (2 :: p) u
  | prod_factors p _ = [];

(* Split t into its components along the Pair positions ps (as produced by
   prod_factors). Fails if a pair is expected at some position but t is
   not built with Pair there. *)
fun split_prod p ps t =
  if p mem ps then
    (case t of
      Const ("Pair", _) $ t1 $ t2 =>
        split_prod (1 :: p) ps t1 @ split_prod (2 :: p) ps t2
    | _ => error "Inconsistent use of products")
  else [t];

(* Render the tuple shape described by ps, e.g. "(_, (_, _))". *)
fun string_of_factors p ps =
  if p mem ps then
    "(" ^ string_of_factors (1 :: p) ps ^ ", " ^ string_of_factors (2 :: p) ps ^ ")"
  else "_";

(**** check if a term contains only constructor functions ****)

(* is_constrt thy t holds if t is built only from Vars and fully applied
   constructors of datatypes registered in thy. Partially applying to thy
   builds the constructor table once; reuse the resulting function. *)
fun is_constrt thy =
  let
    val cnstrs = flat (flat (map
      (map (fn (_, (_, _, cs)) => map (apsnd length) cs) o #descr o snd)
      (Symtab.dest (DatatypePackage.get_datatypes thy))));
    fun check t = (case strip_comb t of
        (Var _, []) => true
      | (Const (s, _), ts) => (case assoc (cnstrs, s) of
          None => false
        | Some i => length ts = i andalso forall check ts)
      | _ => false)
  in check end;

(**** check if a type is an equality type (i.e. doesn't contain fun) ****)

fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts
  | is_eqT _ = true;

(**** mode inference ****)

(* Names of the Vars of a term / of a list of terms (the latter without
   duplicates). *)
val term_vs = map (fst o fst o dest_Var) o term_vars;
val terms_vs = distinct o flat o (map term_vs);

(** collect all Vars in a term (with duplicates!) as (name, type) pairs **)
fun term_vTs t = map (apfst fst o dest_Var)
  (filter is_Var (foldl_aterms (op :: o Library.swap) ([], t)));

(* Positions (counting from i) of the terms whose variables all lie in vs. *)
fun known_args _ _ [] = []
  | known_args vs i (t :: ts) =
      if term_vs t subset vs then i :: known_args vs (i + 1) ts
      else known_args vs (i + 1) ts;

(* Partition a list into (elements at positions in is, the rest), with
   positions counted from i. *)
fun get_args _ _ [] = ([], [])
  | get_args is i (x :: xs) = (if i mem is then apfst else apsnd) (cons x)
      (get_args is (i + 1) xs);

(* Merge two lists of lists that are each ordered by decreasing length,
   keeping that order (ties taken from the left list first). *)
fun merge xs [] = xs
  | merge [] ys = ys
  | merge (x :: xs) (y :: ys) =
      if length x >= length y then x :: merge xs (y :: ys)
      else y :: merge (x :: xs) ys;

(* All subsets of {i, ..., j} as increasing lists, largest subsets first. *)
fun subsets i j =
  if i <= j then
    let val is = subsets (i + 1) j
    in merge (map (fn ks => i :: ks) is) is end
  else [[]];

(* Choose the first premise of ps that can be executed when the variables
   vs are bound, returning Some (premise, Some mode) or None.
   A Prem is executable in a mode if all its input arguments are known, its
   output arguments are constructor patterns, its parameters are bound, and
   every output variable that occurs twice or is already bound has an
   equality type (it will be compared with =). A Sidecond is executable once
   all of its variables are bound. *)
fun select_mode_prem thy modes vs ps =
  let
    val constrt = is_constrt thy;
    fun mode_ok (us, args) mode =
      let
        val (_, out_ts) = get_args mode 1 us;
        val vTs = flat (map term_vTs out_ts);
        val dupTs = map snd (duplicates vTs) @
          mapfilter (curry assoc vTs) vs;
      in
        mode subset known_args vs 1 us andalso
        forall constrt out_ts andalso
        terms_vs args subset vs andalso
        forall is_eqT dupTs
      end;
    fun select (Prem (s, us, args)) =
          find_first (mode_ok (us, args)) (the (assoc (modes, s)))
      | select (Sidecond t) =
          if term_vs t subset vs then Some [] else None;
  in
    find_first (is_some o snd) (ps ~~ map select ps)
  end;

(* Does the clause (ts, ps) admit `mode`? The premises must be executable in
   some order starting from the parameters and the variables of the
   constructor-pattern inputs, and afterwards every variable of the
   conclusion must be bound. Variables repeated in the inputs must have
   equality types. *)
fun check_mode_clause thy arg_vs modes mode (ts, ps) =
  let
    fun check_mode_prems vs [] = Some vs
      | check_mode_prems vs ps = (case select_mode_prem thy modes vs ps of
          None => None
        | Some (x, _) => check_mode_prems
            (case x of Prem (_, us, _) => vs union terms_vs us | _ => vs)
            (filter_out (equal x) ps));
    val (in_ts', _) = get_args mode 1 ts;
    val in_ts = filter (is_constrt thy) in_ts';
    val in_vs = terms_vs in_ts;
    val concl_vs = terms_vs ts
  in
    forall is_eqT (map snd (duplicates (flat (map term_vTs in_ts')))) andalso
    (case check_mode_prems (arg_vs union in_vs) ps of
       None => false
     | Some vs => concl_vs subset vs)
  end;

(* Keep those modes ms of predicate p that every clause of p admits. *)
fun check_modes_pred thy arg_vs preds modes (p, ms) =
  let val Some rs = assoc (preds, p)
  in (p, filter (fn m => forall (check_mode_clause thy arg_vs modes m) rs) ms) end;

(* Iterate f from x until the value no longer changes. *)
fun fixp f x =
  let val y = f x
  in if x = y then x else fixp f y end;

(* Start with every subset of argument positions as a candidate mode for
   each predicate and remove inadmissible modes until a fixpoint is
   reached; extra_modes are the fixed modes of other predicates. *)
fun infer_modes thy extra_modes arg_vs preds = fixp (fn modes =>
  map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes)
    (map (fn (s, (ts, _) :: _) => (s, subsets 1 (length ts))) preds);

(**** code generation ****)

(* Chain of equations between neighbouring names:
   ["a", "b", "c"] gives "a = b", "b = c"; fewer than two names give []. *)
fun mk_eq [] = []
  | mk_eq (x :: xs) =
      let
        fun mk_eqs _ [] = []
          | mk_eqs a (b :: cs) = Pretty.str (a ^ " = " ^ b) :: mk_eqs b cs
      in mk_eqs x xs end;

(* ML tuple "(x1, x2, ...)". *)
fun mk_tuple xs = Pretty.block (Pretty.str "(" ::
  flat (separate [Pretty.str ",", Pretty.brk 1] (map single xs)) @
  [Pretty.str ")"]);

(* Name for an occurrence of variable s: the first occurrence keeps s, later
   ones get a variant not in `names`; all names used for s are recorded in vs
   so that compile_match can require them to be equal. *)
fun mk_v ((names, vs), s) = (case assoc (vs, s) of
    None => ((names, (s, [s]) :: vs), s)
  | Some xs =>
      let val s' = variant names s
      in ((s' :: names, overwrite (vs, (s, s' :: xs))), s') end);

(* Rename repeated occurrences of Vars (with index 0) in a term apart,
   threading the (names, aliases) state through mk_v. *)
fun distinct_v (nvs, Var ((s, 0), T)) =
      apsnd (Var o rpair T o rpair 0) (mk_v (nvs, s))
  | distinct_v (nvs, t $ u) =
      let
        val (nvs', t') = distinct_v (nvs, t);
        val (nvs'', u') = distinct_v (nvs', u);
      in (nvs'', t' $ u') end
  | distinct_v x = x;

(* Generate
     (fn <out_ps> => if <eqs> then <success_p> else <fail_p> | _ => <fail_p>)
   where eqs are the alias equations from nvs plus eq_ps; the if-then-else is
   omitted when there are no equations. *)
fun compile_match nvs eq_ps out_ps success_p fail_p =
  let
    val eqs = flat (separate [Pretty.str " andalso", Pretty.brk 1]
      (map single (flat (map (mk_eq o snd) nvs) @ eq_ps)));
    val body =
      if null eqs then [success_p]
      else [Pretty.str "if ", Pretty.block eqs, Pretty.brk 1, Pretty.str "then ",
            success_p, Pretty.brk 1, Pretty.str "else ", fail_p];
  in
    Pretty.block
      [Pretty.str "(fn ", mk_tuple out_ps, Pretty.str " =>", Pretty.brk 1,
       Pretty.block body,
       Pretty.brk 1, Pretty.str "| _ => ", fail_p, Pretty.str ")"]
  end;

(* ML name of the function implementing predicate s in the given mode. *)
fun modename thy s mode = space_implode "_"
  (mk_const_id (sign_of thy) s :: map string_of_int mode);

(* Code for one clause (ts, ps) in `mode`:
     Seq.single inp :-> (fn <input pattern> => <premise 1> :-> ... )
   Premises are executed in the order chosen by select_mode_prem; each step
   pattern-matches the outputs of the previous one. Returns the updated code
   graph and the pretty-printed code. *)
fun compile_clause thy gr dep all_vs arg_vs modes mode (ts, ps) =
  let
    val constrt = is_constrt thy;

    (* Replace a non-constructor input term t by a fresh variable and record
       the equation <fresh> = t, which is checked after all premises ran. *)
    fun check_constrt ((names, eqs), t) =
      if constrt t then ((names, eqs), t)
      else
        let val s = variant names "x"
        in ((s :: names, (s, t) :: eqs), Var ((s, 0), fastype_of t)) end;

    (* Generate code for each term, threading the code generator graph. *)
    fun codegen_terms brack (gr, ts) =
      foldl_map (fn (gr, t) => invoke_codegen thy gr dep brack t) (gr, ts);

    (* Rename repeated variables in ts apart, given the bound variables vs. *)
    fun rename_apart names vs ts =
      foldl_map distinct_v ((names, map (fn x => (x, [x])) vs), ts);

    val (concl_in, concl_out) = get_args mode 1 ts;
    val ((all_vs', in_eqs), in_pats) =
      foldl_map check_constrt ((all_vs, []), concl_in);

    (* prev_out: terms to be matched against the previous step's result;
       vs: variables bound so far; names: names in use. *)
    fun compile_prems prev_out vs names gr [] =
          let
            val (gr2, out_ps) = codegen_terms false (gr, concl_out);
            val (gr3, eq_ps) = foldl_map (fn (gr, (s, t)) =>
              apsnd (Pretty.block o cons (Pretty.str (s ^ " = ")) o single)
                (invoke_codegen thy gr dep false t)) (gr2, in_eqs);
            val (nvs, prev_out') = rename_apart names vs prev_out;
            val (gr4, pat_ps) = codegen_terms false (gr3, prev_out');
          in
            (gr4, compile_match (snd nvs) eq_ps pat_ps
              (Pretty.block [Pretty.str "Seq.single", Pretty.brk 1, mk_tuple out_ps])
              (Pretty.str "Seq.empty"))
          end
      | compile_prems prev_out vs names gr ps =
          let
            val vs' = distinct (flat (vs :: map term_vs prev_out));
            val Some (p, Some mode') =
              select_mode_prem thy modes (arg_vs union vs') ps;
            val ps' = filter_out (equal p) ps;
            val (nvs, prev_out') = rename_apart names vs prev_out;
          in
            (case p of
              Prem (s, us, args) =>
                let
                  val (in_ts, out_ts) = get_args mode' 1 us;
                  val (gr1, in_ps) = codegen_terms false (gr, in_ts);
                  val (gr2, arg_ps) = codegen_terms true (gr1, args);
                  val (gr3, pat_ps) = codegen_terms false (gr2, prev_out');
                  val (gr4, rest) = compile_prems out_ts vs' (fst nvs) gr3 ps';
                in
                  (gr4, compile_match (snd nvs) [] pat_ps
                    (Pretty.block (separate (Pretty.brk 1)
                      (Pretty.str (modename thy s mode') :: arg_ps) @
                      [Pretty.brk 1, mk_tuple in_ps,
                       Pretty.str " :->", Pretty.brk 1, rest]))
                    (Pretty.str "Seq.empty"))
                end
            | Sidecond t =>
                let
                  val (gr1, side_p) = invoke_codegen thy gr dep true t;
                  val (gr2, pat_ps) = codegen_terms false (gr1, prev_out');
                  val (gr3, rest) = compile_prems [] vs' (fst nvs) gr2 ps';
                in
                  (gr3, compile_match (snd nvs) [] pat_ps
                    (Pretty.block [Pretty.str "?? ", side_p,
                      Pretty.str " :->", Pretty.brk 1, rest])
                    (Pretty.str "Seq.empty"))
                end)
          end;

    val (gr', prem_p) = compile_prems in_pats [] all_vs' gr ps;
  in
    (gr', Pretty.block [Pretty.str "Seq.single inp :->", Pretty.brk 1, prem_p])
  end;

(* Function definition "<prfx><name>_<mode> <params> inp = ..." whose body
   joins the code of all clauses with "++". Returns the graph together with
   the prefix for the next definition ("and "). *)
fun compile_pred thy gr dep prfx all_vs arg_vs modes s cls mode =
  let val (gr', cl_ps) = foldl_map (fn (gr, cl) =>
    compile_clause thy gr dep all_vs arg_vs modes mode cl) (gr, cls)
  in
    ((gr', "and "), Pretty.block
      ([Pretty.block (separate (Pretty.brk 1)
         (Pretty.str (prfx ^ modename thy s mode) :: map Pretty.str arg_vs) @
         [Pretty.str " inp ="]),
        Pretty.brk 1] @
       flat (separate [Pretty.str " ++", Pretty.brk 1] (map single cl_ps))))
  end;

(* One mutually recursive "fun ... and ..." declaration containing a
   function for every predicate and each of its modes. *)
fun compile_preds thy gr dep all_vs arg_vs modes preds =
  let val ((gr', _), prs) = foldl_map (fn ((gr, prfx), (s, cls)) =>
    foldl_map (fn ((gr', prfx'), mode) =>
      compile_pred thy gr' dep prfx' all_vs arg_vs modes s cls mode)
      ((gr, prfx), the (assoc (modes, s)))) ((gr, "fun "), preds)
  in
    (gr', space_implode "\n\n" (map Pretty.string_of (flat prs)) ^ ";\n\n")
  end;

(**** processing of introduction rules ****)

val string_of_mode = enclose "[" "]" o commas o map string_of_int;

fun print_modes modes = message ("Inferred modes:\n" ^
  space_implode "\n" (map (fn (s, ms) => s ^ ": " ^ commas (map
    string_of_mode ms)) modes));

fun print_factors factors = message ("Factors:\n" ^
  space_implode "\n" (map (fn (s, fs) => s ^ ": " ^ string_of_factors [] fs) factors));

(* Modes and factors stored in a graph node, or empty lists if the node
   holds none. *)
fun get_modes (Some (Modes x), _) = x
  | get_modes _ = ([], []);

(* Ensure code for the inductive definition with constants `names` and
   introduction rules intrs is in the graph gr, with an edge to dep. If its
   node already exists only the edge is added; otherwise the rules are
   translated into clauses, modes are inferred and the code is generated,
   recursively doing the same for other inductive definitions used in the
   premises. *)
fun mk_ind_def thy gr dep names intrs =
  let val ids = map (mk_const_id (sign_of thy)) names
  in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ =>
    let
      (* Translate one premise. Memberships in inductive sets become Prem,
         equations become Prem ("eq", ...), everything else is a Sidecond
         holding the premise with its outer judgment application removed. *)
      fun process_prem factors (gr, _ $ (t' as Const ("op :", _) $ t $ u)) =
            (case strip_comb u of
              (Const (name, _), args) =>
                (case InductivePackage.get_inductive thy name of
                  None => (gr, Sidecond t')
                | Some ({names=names', ...}, {intrs=intrs', ...}) =>
                    (if names = names' then gr
                     else mk_ind_def thy gr (hd ids) names' intrs',
                     Prem (name, split_prod []
                       (the (assoc (factors, name))) t, args)))
            | _ => (gr, Sidecond t'))
        | process_prem factors (gr, _ $ (Const ("op =", _) $ t $ u)) =
            (gr, Prem ("eq", [t, u], []))
        | process_prem factors (gr, _ $ t) = (gr, Sidecond t);

      (* Append the clause for rule intr to the clause list of the predicate
         in its conclusion. *)
      fun add_clause factors ((clauses, gr), intr) =
        let
          val _ $ (_ $ t $ u) = Logic.strip_imp_concl intr;
          val (Const (name, _), _) = strip_comb u;
          val (gr', prems) = foldl_map (process_prem factors)
            (gr, Logic.strip_imp_prems intr);
        in
          (overwrite (clauses, (name, if_none (assoc (clauses, name)) [] @
            [(split_prod [] (the (assoc (factors, name))) t, prems)])), gr')
        end;

      (* Per predicate, the Pair positions common to all of its uses. *)
      fun add_prod_factors (fs, x as _ $ (Const ("op :", _) $ t $ u)) =
            (case strip_comb u of
              (Const (name, _), _) =>
                let val f = prod_factors [] t
                in overwrite (fs, (name, f inter if_none (assoc (fs, name)) f)) end
            | _ => fs)
        | add_prod_factors (fs, _) = fs;

      val intrs' = map (rename_term o #prop o rep_thm o standard) intrs;
      val factors = foldl add_prod_factors ([], flat (map (fn t =>
        Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs'));
      val (clauses, gr') = foldl (add_clause factors) (([], Graph.add_edge (hd ids, dep)
        (Graph.new_node (hd ids, (None, "")) gr)), intrs');
      val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs');
      val (_, args) = strip_comb u;
      val arg_vs = flat (map term_vs args);
      val extra_modes = ("eq", [[1], [2], [1,2]]) :: (flat (map
        (fst o get_modes o Graph.get_node gr') (Graph.all_preds gr' [hd ids])));
      val modes = infer_modes thy extra_modes arg_vs clauses;
      val _ = print_modes modes;
      val _ = print_factors factors;
      val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs') arg_vs
        (modes @ extra_modes) clauses;
    in
      (Graph.map_node (hd ids) (K (Some (Modes (modes, factors)), s)) gr'')
    end
  end;

(* Code calling the function for the inductive set u on element t, or None if
   u is not an inductive set. For a query, components of t that are Vars or
   Frees become outputs and all others inputs; otherwise all components are
   inputs. Fails if the resulting mode was not inferred. *)
fun mk_ind_call thy gr dep t u is_query = (case strip_comb u of
    (Const (s, _), args) => (case InductivePackage.get_inductive thy s of
      None => None
    | Some ({names, ...}, {intrs, ...}) =>
        let
          fun mk_mode (((ts, mode), i), Var _) = ((ts, mode), i+1)
            | mk_mode (((ts, mode), i), Free _) = ((ts, mode), i+1)
            | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1);

          val gr1 = mk_ind_def thy gr dep names intrs;
          val (modes, factors) = pairself flat (ListPair.unzip
            (map (get_modes o Graph.get_node gr1) (Graph.all_preds gr1 [dep])));
          val ts = split_prod [] (the (assoc (factors, s))) t;
          val (ts', mode) =
            if is_query then fst (foldl mk_mode ((([], []), 1), ts))
            else (ts, 1 upto length ts);
          val _ = if mode mem the (assoc (modes, s)) then () else
            error ("No such mode for " ^ s ^ ": " ^ string_of_mode mode);
          val (gr2, in_ps) = foldl_map (fn (gr, t) =>
            invoke_codegen thy gr dep false t) (gr1, ts');
          val (gr3, arg_ps) = foldl_map (fn (gr, t) =>
            invoke_codegen thy gr dep true t) (gr2, args);
        in
          Some (gr3, Pretty.block (separate (Pretty.brk 1)
            (Pretty.str (modename thy s mode) :: arg_ps @ [mk_tuple in_ps])))
        end)
  | _ => None);

(* Code generator entry point: a membership t : u in an inductive set is
   compiled as "nonempty (<call>)"; query (t : u) yields the call itself. *)
fun inductive_codegen thy gr dep brack (Const ("op :", _) $ t $ u) =
      (case mk_ind_call thy gr dep t u false of
         None => None
       | Some (gr', call_p) => Some (gr', (if brack then parens else I)
           (Pretty.block [Pretty.str "nonempty (", call_p, Pretty.str ")"])))
  | inductive_codegen thy gr dep brack (Free ("query", _) $ (Const ("op :", _) $ t $ u)) =
      mk_ind_call thy gr dep t u true
  | inductive_codegen thy gr dep brack _ = None;

val setup = [add_codegen "inductive" inductive_codegen];

end;