1 (* Author: Lukas Bulwahn, TU Muenchen |
|
2 |
|
3 Preprocessing functions to predicates |
|
4 *) |
|
5 |
|
signature PREDICATE_COMPILE_FUN =
sig
  (* registers the preprocessing oracle in the theory and stores it for rewrite_intro *)
  val setup_oracle : theory -> theory

  (* looks up the predicate name recorded for a function constant name, if any *)
  val pred_of_function : theory -> string -> string option

  (* takes (constant name, equations) pairs; yields (predicate name, introduction rules)
     pairs together with the extended theory, or [] with the unchanged theory *)
  val define_predicates :
    (string * thm list) list -> theory -> (string * thm list) list * theory

  (* rewrites one introduction rule into a list of rules; the results come from the oracle *)
  val rewrite_intro : theory -> thm -> thm list
end;
|
13 |
|
14 structure Predicate_Compile_Fun : PREDICATE_COMPILE_FUN = |
|
15 struct |
|
16 |
|
17 |
|
18 (* Oracle for preprocessing *) |
|
19 |
|
(* Slot for the registered oracle (its name and its certifying function); starts out as NONE. *)
val oracle : (string * (cterm -> thm)) option Unsynchronized.ref =
  Unsynchronized.ref NONE;
|
21 |
|
(* Yields the certifying function stored in the oracle slot; fails while the slot is NONE. *)
fun the_oracle () =
  (case ! oracle of
    SOME (_, certify) => certify
  | NONE => error "Oracle is not setup")
|
26 |
|
(* Adds the oracle "pred_compile_preprocessing" to the theory and records it in the oracle slot. *)
fun setup_oracle thy =
  let
    val (ora, thy') =
      Thm.add_oracle (Binding.name "pred_compile_preprocessing", I) thy
    val _ = oracle := SOME ora
  in
    thy'
  end
|
29 |
|
30 |
|
(* true exactly for types built with the "fun" constructor applied to two arguments *)
fun is_funtype T =
  (case T of
    Type ("fun", [_, _]) => true
  | _ => false);
|
33 |
|
(* true exactly for types of the form Type _ *)
fun is_Type T =
  (case T of
    Type _ => true
  | _ => false)
|
36 |
|
(* returns true if t is an application of a datatype constructor *)
(* which would consequently be split *)
(* else false *)
|
40 (* |
|
41 fun is_constructor thy t = |
|
42 if (is_Type (fastype_of t)) then |
|
43 (case DatatypePackage.get_datatype thy ((fst o dest_Type o fastype_of) t) of |
|
44 NONE => false |
|
45 | SOME info => (let |
|
46 val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info) |
|
47 val (c, _) = strip_comb t |
|
48 in (case c of |
|
49 Const (name, _) => name mem_string constr_consts |
|
50 | _ => false) end)) |
|
51 else false |
|
52 *) |
|
53 |
|
(* must be exported in code.ML *)
(* true iff Code.get_datatype_of_constr yields SOME for the constant name c *)
fun is_constr thy c = is_some (Code.get_datatype_of_constr thy c);
|
56 |
|
(* Theory data: table from the name of a function constant to the name of the
   inductive predicate that was defined for it *)
structure Pred_Compile_Preproc = TheoryDataFun
(
  type T = string Symtab.table;
  val empty = Symtab.empty;
  val copy = I;
  val extend = I;
  (* entries for the same key must be equal *)
  fun merge _ (tab1, tab2) = Symtab.merge (op =) (tab1, tab2);
)
|
66 |
|
(* the predicate name recorded for the function constant name, if there is one *)
fun pred_of_function thy name =
  let
    val table = Pred_Compile_Preproc.get thy
  in
    Symtab.lookup table name
  end
|
68 |
|
(* does the table of the theory have an entry for this constant name? *)
fun defined thy name = Symtab.defined (Pred_Compile_Preproc.get thy) name
|
70 |
|
71 |
|
(* Function types whose result is not bool become predicate types: the result type
   is appended to the argument types and the new result is bool.  Everything else
   is returned unchanged. *)
fun transform_ho_typ T =
  (case T of
    Type ("fun", _) =>
      let
        val (argTs, resT) = strip_type T
      in
        if resT = @{typ "bool"} then T
        else (argTs @ [resT]) ---> HOLogic.boolT
      end
  | _ => T)
|
77 |
|
(* A free variable of function type gets its type transformed by transform_ho_typ;
   any other term of function type is rejected; terms of other types pass through. *)
fun transform_ho_arg arg =
  (case (fastype_of arg, arg) of
    (T as Type ("fun", _), Free (name, _)) => Free (name, transform_ho_typ T)
  | (Type ("fun", _), _) => error "I am surprised"
  | _ => arg)
|
85 |
|
(* Type of the predicate for a function of type T: the argument types (each passed
   through transform_ho_typ) followed by the result type of T, with result bool. *)
fun pred_type T =
  let
    val (argTs, resT) = strip_type T
  in
    (map transform_ho_typ argTs @ [resT]) ---> HOLogic.boolT
  end;
|
93 |
|
(* FIXME: create new predicate name -- does not avoid nameclashing *)
(* The free variable standing for the predicate of constant f: its name is the base
   name of f followed by "P"; its type is the type of f if that already ends in
   bool, and pred_type of it otherwise. *)
fun pred_of f =
  let
    val (name, T) = dest_Const f
    val predT = if body_type T = @{typ bool} then T else pred_type T
  in
    Free (Long_Name.base_name name ^ "P", predT)
  end
|
104 |
|
(* Turns a parameter term into a predicate parameter.
   - a free variable is translated by lookup_pred;
   - a term that already has predicate type is kept;
   - otherwise the leading abstractions are stripped, the bound variables are
     replaced by fresh free variables, and the result is the abstraction (over
     those variables and a fresh result variable) of "body = result". *)
fun mk_param thy lookup_pred t =
  (case t of
    Free _ => lookup_pred t
  | _ =>
      let
        val _ = tracing ("called param with " ^ (Syntax.string_of_term_global thy t))
      in
        if Predicate_Compile_Aux.is_predT (fastype_of t) then t
        else
          let
            val (bound_vars, body) = strip_abs t
            val body_frees = Term.add_free_names body []
            (* names for the bound variables that do not clash with frees of the body *)
            val fresh_names = Name.variant_list body_frees (map fst bound_vars)
            val frees = map2 (curry Free) fresh_names (map snd bound_vars)
            val body' = subst_bounds (rev frees, body)
            val resname = Name.variant (fresh_names @ body_frees) "res"
            val resvar = Free (resname, body_type (fastype_of body'))
          in
            fold_rev lambda (frees @ [resvar]) (HOLogic.mk_eq (body', resvar))
          end
      end)
|
128 (* creates the list of premises for every intro rule *) |
|
129 (* theory -> term -> (string list, term list list) *) |
|
130 |
|
(* Splits a code equation (a meta-equality) into ((head of lhs, arguments of lhs), rhs);
   schematic variables are turned into free ones first. *)
fun dest_code_eqn eqn =
  let
    val prop = Logic.unvarify (Thm.prop_of eqn)
    val (lhs, rhs) = Logic.dest_equals prop
  in
    (strip_comb lhs, rhs)
  end;
|
135 |
|
(* prints a type relative to the theory in which this file is loaded *)
fun string_of_typ T = T |> Syntax.string_of_typ_global @{theory}
|
137 |
|
(* renders the raw constructor structure of a term as a string *)
fun string_of_term (Const (c, T)) =
      String.concat ["Const (", c, ", ", string_of_typ T, ")"]
  | string_of_term (Free (c, T)) =
      String.concat ["Free (", c, ", ", string_of_typ T, ")"]
  | string_of_term (Var ((c, i), T)) =
      String.concat ["Var ((", c, ", ", string_of_int i, "), ", string_of_typ T, ")"]
  | string_of_term (Bound i) = "Bound " ^ string_of_int i
  | string_of_term (Abs (x, T, body)) =
      String.concat ["Abs (", x, ", ", string_of_typ T, ", ", string_of_term body, ")"]
  | string_of_term (t1 $ t2) =
      String.concat ["(", string_of_term t1, ") $ (", string_of_term t2, ")"]
|
146 |
|
(* Number of parameters of the inductive predicate `name`, read off its raw
   induction rule; fails if the inductive package does not know the name. *)
fun ind_package_get_nparams thy name =
  let
    val ctxt = ProofContext.init thy
  in
    (case try (Inductive.the_inductive ctxt) name of
      NONE => error ("No such predicate: " ^ quote name)
    | SOME (_, result) => length (Inductive.params_of (#raw_induct result)))
  end
|
151 |
|
(* TODO: does not work with higher order functions yet *)
(* Builds the meta-equation  (r = func a1 ... an) == pred a1 ... an r  where the
   variables a1 ... an and r are fresh with respect to the frees of func and pred. *)
fun mk_rewr_eq (func, pred) =
  let
    val (argTs, resT) = strip_type (fastype_of func)
    val used_frees =
      Term.fold_aterms (fn Free (x, _) => insert (op =) x | _ => I) (func $ pred) []
    val nctxt = Name.make_context used_frees
    val (argnames, nctxt') = Name.variants (replicate (length argTs) "a") nctxt
    val ([resname], _) = Name.variants ["r"] nctxt'
    val args = map Free (argnames ~~ argTs)
    val res = Free (resname, resT)
    val lhs = HOLogic.mk_eq (res, list_comb (func, args))
    val rhs = list_comb (pred, args @ [res])
  in
    Logic.mk_equals (lhs, rhs)
  end;
|
165 |
|
(* true for the constant names nat_case and list_case only *)
fun has_split_rule_cname c =
  member (op =) [@{const_name "nat_case"}, @{const_name "list_case"}] c
|
169 |
|
(* true for the constants nat_case and list_case only; thy is not consulted *)
fun has_split_rule_term thy t =
  (case t of
    Const (c, _) => c = @{const_name "nat_case"} orelse c = @{const_name "list_case"}
  | _ => false)
|
173 |
|
(* like has_split_rule_term, but additionally accepts the constants If and Let *)
fun has_split_rule_term' thy t =
  let
    val is_if_or_let =
      (case t of
        Const (c, _) => member (op =) [@{const_name "If"}, @{const_name "Let"}] c
      | _ => false)
  in
    is_if_or_let orelse has_split_rule_term thy t
  end
|
177 |
|
(* Resolves the split theorem with iffD2 and then unfolds with the symmetric
   versions of atomize_conjL, atomize_all and atomize_imp. *)
fun prepare_split_thm ctxt split_thm =
  let
    val unatomize =
      [@{thm atomize_conjL[symmetric]}, @{thm atomize_all[symmetric]},
        @{thm atomize_imp[symmetric]}]
  in
    LocalDefs.unfold ctxt unatomize (split_thm RS @{thm iffD2})
  end
|
182 |
|
(* For a constant whose last name component ends in "_case", fetches the theorem
   named like the constant with that last component replaced by "split".
   Yields NONE for other terms, other names, or if fetching the theorem fails
   with ERROR. *)
fun find_split_thm thy (Const (name, _)) =
      let
        (* components of the name between "." separators; never empty *)
        fun split_name str =
          (case first_field "." str of
            NONE => [str]
          | SOME (field, rest) => field :: split_name rest)
        val (qualifier, base) = split_last (split_name name)
      in
        if String.isSuffix "_case" base then
          (SOME (PureThy.get_thm thy (space_implode "." (qualifier @ ["split"])))
            handle ERROR _ => NONE)
        else NONE
      end
  | find_split_thm _ _ = NONE
|
202 |
|
(* like find_split_thm, with fixed answers for the constants If and Let *)
fun find_split_thm' thy t =
  (case t of
    Const (@{const_name "If"}, _) => SOME @{thm split_if}
  | Const (@{const_name "Let"}, _) => SOME @{thm refl} (* TODO *)
  | _ => find_split_thm thy t)
|
206 |
|
(* pairs the variables and the body obtained from Term.strip_all_vars / Term.strip_all_body *)
fun strip_all t =
  let
    val vars = Term.strip_all_vars t
    val body = Term.strip_all_body t
  in
    (vars, body)
  end
|
208 |
|
(* Threads a state through the list from left to right, where each step f x state
   may yield several (result, state) alternatives.  All combinations of
   alternatives are returned, each as (results in list order, final state). *)
fun folds_map f xs y =
  let
    fun go acc_rev [] state = [(rev acc_rev, state)]
      | go acc_rev (x :: rest) state =
          maps (fn (x', state') => go (x' :: acc_rev) rest state') (f x state)
  in
    go [] xs y
  end;
|
217 |
|
(* Flattens the term t into a result term plus premises.

   Arguments:
     lookup_pred  maps a function term (Const or Free) to its predicate term
     get_nparams  number of parameters of a predicate term
     (names, prems)  names already in use and premises collected so far

   Result: a list of alternatives (result term, (names, premises)); a term
   headed by a constant with a split rule yields one or more alternatives per
   assumption of the prepared split theorem. *)
fun mk_prems thy (lookup_pred, get_nparams) t (names, prems) =
  let
    (* a constant is kept as it is if it is a constructor or lookup_pred fails on it *)
    fun keeps_const c const_t =
      is_constr thy c orelse is_none (try lookup_pred const_t)

    fun flatten (t as Const (c, _)) (names, prems) =
          [(if keeps_const c t then t else lookup_pred t, (names, prems))]
      | flatten (t as Free _) (names, prems) =
          [(lookup_pred t, (names, prems))]
      | flatten (t as Abs _) (names, prems) =
          if Predicate_Compile_Aux.is_predT (fastype_of t) then [(t, (names, prems))]
          else error "mk_prems': Abs "
      | flatten t (names, prems) =
          if Predicate_Compile_Aux.is_constrt thy t then [(t, (names, prems))]
          else if has_split_rule_term' thy (fst (strip_comb t)) then
            split_case t (names, prems)
          else application t (names, prems)

    (* t is headed by a constant with a split rule: continue with each assumption of
       the split theorem, instantiated by matching its conclusion against t *)
    and split_case t (names, prems) =
      let
        val (f, _) = strip_comb t
        val split_thm =
          prepare_split_thm (ProofContext.init thy) (the (find_split_thm' thy f))
        (* TODO: contextify things - the split theorem is used without importing it *)
        val (assms, concl) = Logic.strip_horn (Thm.prop_of split_thm)
        val (_, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl)
        val subst = Pattern.match thy (split_t, t) (Vartab.empty, Vartab.empty)
        fun prems_of_assm assm =
          let
            val (vTs, assm') = strip_all (Envir.beta_norm (Envir.subst_term subst assm))
            val var_names = Name.variant_list names (map fst vTs)
            val vars = map Free (var_names ~~ (map snd vTs))
            val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
            val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
          in
            flatten inner_t (var_names @ names, prems' @ prems)
          end
      in
        maps prems_of_assm assms
      end

    (* any other application: introduce a fresh result variable, flatten the
       arguments, and add one premise that relates arguments and result *)
    and application t (names, prems) =
      let
        val (f, args) = strip_comb t
        (* TODO: special procedure for higher-order functions: split arguments in
           simple types and function types *)
        val resname = Name.variant names "res"
        val resvar = Free (resname, body_type (fastype_of t))
        val names' = resname :: names
        (* flatten the given arguments; put the premise made from them in front *)
        fun with_flat_args args mk_prem =
          folds_map flatten args (names', prems)
          |> map (fn (argvs, (names'', prems')) => (names'', mk_prem argvs :: prems'))
        val results =
          (case f of
            Const (c, _) =>
              if keeps_const c f then
                with_flat_args args (fn argvs =>
                  HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs))))
              else
                let
                  val pred = lookup_pred f
                  val nparams = get_nparams pred
                  val (params, args') = chop nparams args
                  val params' = map (mk_param thy lookup_pred) params
                in
                  with_flat_args args' (fn argvs =>
                    HOLogic.mk_Trueprop (list_comb (pred, params' @ argvs @ [resvar])))
                end
          | Free _ =>
              let
                (* higher order argument call *)
                val pred = lookup_pred f
              in
                with_flat_args args (fn argvs =>
                  HOLogic.mk_Trueprop (list_comb (pred, argvs @ [resvar])))
              end
          | _ => error ("Invalid term: " ^ Syntax.string_of_term_global thy f))
      in
        map (pair resvar) results
      end
  in
    flatten t (names, prems)
  end;
|
306 |
|
(* assumption: mutual recursive predicates all have the same parameters. *)
(* Defines inductive predicates for the functions described by specs, a list of
   (constant name, code equations) pairs.

   Result: (predicate name, its introduction rules) pairs and the theory extended
   by the inductive definition and by the new table entries.  The result is
   ([], thy) with the unchanged theory if
     - every constant of specs already has a table entry, or
     - building the introduction rules raises an exception, or
     - the introduction rules cannot be certified (a message is traced). *)
fun define_predicates specs thy =
  let
    (* fetched once; used for the check below and by lookup_pred *)
    val table = Pred_Compile_Preproc.get thy
  in
    if forall (fn (const, _) => Symtab.defined table const) specs then ([], thy)
    else
      let
        val consts = map fst specs
        val eqns = maps snd specs
        (*val eqns = maps (Predicate_Compile_Preproc_Data.get_specification thy) consts*)
        (* create prednames *)
        val ((funs, argss), rhss) = map_split dest_code_eqn eqns |>> split_list
        val argss' = map (map transform_ho_arg) argss
        (* the arguments of function type become the parameters of the predicates *)
        val pnames =
          map dest_Free (distinct (op =) (maps (filter (is_funtype o fastype_of)) argss'))
        val pname_strings = map fst pnames
        val preds = map pred_of funs
        val prednames = map (fst o dest_Free) preds
        val funnames = map (fst o dest_Const) funs
        val fun_pred_names = funnames ~~ prednames

        (* mapping from term (Free or Const) to term: predicates already in the table
           win over the ones being defined here; unknown constants stay unchanged *)
        fun lookup_pred (Const (name, T)) =
              (case Symtab.lookup table name of
                SOME c => Const (c, pred_type T)
              | NONE =>
                  (case AList.lookup (op =) fun_pred_names name of
                    SOME f => Free (f, pred_type T)
                  | NONE => Const (name, T)))
          | lookup_pred (Free (name, T)) =
              if member (op =) pname_strings name then Free (name, transform_ho_typ T)
              else Free (name, T)
          | lookup_pred t =
              error ("lookup function is not defined for " ^ Syntax.string_of_term_global thy t)

        (* mapping from term (predicate term, not function term!) to int *)
        fun get_nparams (Const (name, _)) =
              the_default 0 (try (ind_package_get_nparams thy) name)
          | get_nparams (Free (name, _)) =
              if member (op =) prednames name then length pnames else 0
          | get_nparams t =
              error ("No parameters for " ^ (Syntax.string_of_term_global thy t))

        (* create intro rules *)
        fun mk_intros ((func, pred), (args, rhs)) =
          if body_type (fastype_of func) = @{typ bool} then
            (*TODO: preprocess predicate definition of rhs *)
            [Logic.list_implies
              ([HOLogic.mk_Trueprop rhs], HOLogic.mk_Trueprop (list_comb (pred, args)))]
          else
            let
              val names = Term.add_free_names rhs []
            in
              mk_prems thy (lookup_pred, get_nparams) rhs (names, [])
              |> map (fn (resultt, (_, prems)) =>
                Logic.list_implies
                  (prems, HOLogic.mk_Trueprop (list_comb (pred, args @ [resultt]))))
            end
      in
        (* any exception raised while building the rules means: nothing is defined *)
        (case try (maps mk_intros) ((funs ~~ preds) ~~ (argss' ~~ rhss)) of
          NONE => ([], thy)
        | SOME intr_ts =>
            if is_some (try (map (cterm_of thy)) intr_ts) then
              let
                val (ind_result, thy') =
                  Inductive.add_inductive_global (serial ())
                    {quiet_mode = false, verbose = false, kind = Thm.internalK,
                      alt_name = Binding.empty, coind = false, no_elim = false,
                      no_ind = false, skip_mono = false, fork_mono = false}
                    (map (fn (s, T) => ((Binding.name s, T), NoSyn))
                      (distinct (op =) (map dest_Free preds)))
                    pnames
                    (map (fn x => (Attrib.empty_binding, x)) intr_ts)
                    [] thy
                (* names of the predicate constants introduced by the inductive package *)
                val pred_consts = map (fst o dest_Const) (#preds ind_result)
                (* val rewr_thms = map mk_rewr_eq ((distinct (op =) funs) ~~ (#preds ind_result)) *)
                val result_specs =
                  map (fn predname =>
                    (predname,
                      filter (Predicate_Compile_Aux.is_intro predname) (#intrs ind_result)))
                    pred_consts
                (* add constants to my table *)
                val thy'' =
                  Pred_Compile_Preproc.map
                    (fold Symtab.update_new (consts ~~ pred_consts)) thy'
              in
                (result_specs, thy'')
              end
            else
              let
                val _ = tracing "Introduction rules of function_predicate are not welltyped"
              in
                ([], thy)
              end)
      end
  end
|
391 |
|
(* preprocessing intro rules - uses oracle *)

(* theory -> thm -> thm list *)
(* Rewrites an introduction rule: in every premise and in the conclusion, the
   arguments that are not of function type are flattened by mk_prems, which may
   add premises and may yield several alternatives.  Each resulting proposition
   is certified and turned into a theorem by the oracle. *)
fun rewrite_intro thy intro =
  let
    (* constants must have a predicate in the table; free variables stay as they are *)
    fun lookup_pred (Const (name, T)) =
          (case pred_of_function thy name of
            NONE => error ("Function " ^ name ^ " is not inductified")
          | SOME c => Const (c, pred_type T))
      | lookup_pred (t as Free _) = t
      | lookup_pred _ = error "lookup function is not defined!"

    fun get_nparams (Const (name, _)) =
          the_default 0 (try (ind_package_get_nparams thy) name)
      | get_nparams (Free _) = 0
      | get_nparams t = error ("No parameters for " ^ (Syntax.string_of_term_global thy t))

    val intro_t = Logic.unvarify (prop_of intro)
    val (prems, concl) = Logic.strip_horn intro_t
    val frees = map fst (Term.add_frees intro_t [])

    (* arguments of function type are passed through untouched *)
    fun flatten_arg arg =
      if is_funtype (fastype_of arg) then (fn x => [(arg, x)])
      else mk_prems thy (lookup_pred, get_nparams) arg

    (* yields alternatives (rewritten literal :: new premises, names in use) *)
    fun rewrite prem names =
      let
        val t = HOLogic.dest_Trueprop prem
        (* a negated literal is rewritten underneath the negation *)
        val (lit, mk_lit) =
          (case try HOLogic.dest_not t of
            NONE => (t, I)
          | SOME t' => (t', HOLogic.mk_not))
        val (P, args) = strip_comb lit
        fun rebuild (resargs, (names', prems')) =
          (HOLogic.mk_Trueprop (mk_lit (list_comb (P, resargs))) :: prems', names')
      in
        map rebuild (folds_map flatten_arg args (names, []))
      end

    (* rewrite the conclusion after the premises, continuing with their names *)
    fun finish (prems', frees') =
      rewrite concl frees'
      |> map (fn (concl' :: conclprems, _) =>
        Logic.list_implies (flat prems' @ conclprems, concl'))

    val intro_ts' = maps finish (folds_map rewrite prems frees)
  in
    let
      (* fetched once, even if there is nothing to certify *)
      val ora = the_oracle ()
    in
      map (fn t => Drule.standard (ora (cterm_of thy t))) intro_ts'
    end
  end;
|
436 |
|
437 end; |
|