1 (* Title: HOL/Tools/exhaustive_generators.ML |
|
2 Author: Lukas Bulwahn, TU Muenchen |
|
3 |
|
4 Exhaustive generators for various types. |
|
5 *) |
|
6 |
|
(* Interface of the exhaustive quickcheck generators.  The put_* functions are
   the context slots passed by name to Code_Runtime.dynamic_value_strict when
   compiling generator expressions. *)
signature EXHAUSTIVE_GENERATORS =
sig
  (* compilation of (batches of) generator expressions *)
  val compile_generator_expr:
    Proof.context -> term -> int -> term list option * Quickcheck.report option
  val compile_generator_exprs:
    Proof.context -> term list -> (int -> term list option) list

  (* context slots for the compiled generators *)
  val put_counterexample:
    (unit -> int -> term list option) -> Proof.context -> Proof.context
  val put_counterexample_batch:
    (unit -> (int -> term list option) list) -> Proof.context -> Proof.context

  (* configuration options *)
  val smart_quantifier: bool Config.T
  val quickcheck_pretty: bool Config.T

  (* theory setup *)
  val setup: theory -> theory
end;
|
21 |
|
22 structure Exhaustive_Generators : EXHAUSTIVE_GENERATORS = |
|
23 struct |
|
24 |
|
(* static options *)

(* Selects how instantiate_exhaustive_datatype introduces the generators:
   true  = via the function package, with a termination proof;
   false = as constants defined to be "undefined", whose equations are then
           asserted with Skip_Proof.cheat_tac. *)
val define_foundationally : bool = false

(* dynamic options *)

(* Attrib.config_bool returns the configuration handle together with the theory
   setup function that declares it; both options default to true. *)
local
  fun on_by_default _ = true
in
  val (smart_quantifier, setup_smart_quantifier) =
    Attrib.config_bool "quickcheck_smart_quantifier" on_by_default

  val (quickcheck_pretty, setup_quickcheck_pretty) =
    Attrib.config_bool "quickcheck_pretty" on_by_default
end
|
36 |
|
(** general term functions **)

(* mk_measure f: the term "Wellfounded.measure f" for f of type T => nat.
   The val pattern raises Bind if f does not have such a type. *)
fun mk_measure f =
  let
    val fT as Type ("fun", [T, @{typ nat}]) = fastype_of f
    val relT = HOLogic.mk_prodT (T, T) --> @{typ bool}
  in
    Const (@{const_name Wellfounded.measure}, fT --> relT) $ f
  end

(* mk_sumcases rT f T: for a sum type T, combines the results for both
   summands with SumTree.mk_sumcase (result type rT), recursing into nested
   sums; any other type T is handed to f. *)
fun mk_sumcases rT f T =
  (case T of
    Type (@{type_name Sum_Type.sum}, [TL, TR]) =>
      SumTree.mk_sumcase TL TR rT (mk_sumcases rT f TL) (mk_sumcases rT f TR)
  | _ => f T)

(* the constant "undefined" at type T *)
val mk_undefined = curry Const @{const_name undefined}
|
58 |
|
59 |
|
(** abstract syntax **)

(* termifyT T: the pair type T * (unit => Code_Evaluation.term), i.e. a value
   together with a thunk yielding its term representation *)
fun termifyT T = HOLogic.mk_prodT (T, @{typ "unit => Code_Evaluation.term"});

(* the size variable "i" of the generated equations, its predecessor i - 1,
   and the guard i > 0 *)
val size = @{term "i :: code_numeral"}
val size_pred = @{term "(i :: code_numeral) - 1"}
val size_ge_zero = @{term "(i :: code_numeral) > 0"}

(* the test-function variable "f", consuming a termified value of type T *)
fun test_function T =
  let
    val argT = termifyT T
  in
    Free ("f", argT --> @{typ "term list option"})
  end

(* mk_none_continuation (x, y): "Quickcheck_Exhaustive.orelse x y".
   The val pattern raises Bind unless x has an option type. *)
fun mk_none_continuation (x, y) =
  let
    val optT as Type (@{type_name "option"}, [_]) = fastype_of x
    val orelse_const =
      Const (@{const_name "Quickcheck_Exhaustive.orelse"}, optT --> optT --> optT)
  in
    orelse_const $ x $ y
  end
|
76 |
|
(** datatypes **)

(* constructing exhaustive generator instances on datatypes *)

(* raised by mk_equations (mk_aux_call) and handled in
   instantiate_exhaustive_datatype *)
exception FUNCTION_TYPE;

(* prefix of the generated generator constant names *)
val exhaustiveN = "exhaustive";

local
  val resultT = @{typ "Code_Evaluation.term list option"}
  (* type of a test function over termified values of type T *)
  fun testerT T = termifyT T --> resultT
in
  (* exhaustive :: (T * (unit => term) => term list option) => code_numeral => term list option *)
  fun exhaustiveT T = testerT T --> @{typ code_numeral} --> resultT

  (* check_all :: (T * (unit => term) => term list option) => term list option *)
  fun check_allT T = testerT T --> resultT
end
|
89 |
|
(* mk_equations thy descr vs tycos exhaustives (Ts, Us): builds the defining
   equations of the exhaustive generators, one per type in Ts @ Us, as
   propositions

     exhaustive_k f i =
       (if i > 0 then <orelse-chain over the constructors> else None)

   where exhaustives are the generator terms for Ts @ Us (same order),
   f is test_function T and i is size.  For each constructor c with arguments
   x1 ... xn, the expression enumerates every argument at size i - 1 and
   finally calls f on (c x1 ... xn, %_. <Code_Evaluation.term of c x1 ... xn>).
   Argument positions reported through atyp use the class constant
   Quickcheck_Exhaustive.exhaustive; positions reported through dtyp use the
   k-th element of exhaustives.
   Raises FUNCTION_TYPE (in mk_aux_call) when dtyp reports a non-empty fTs,
   presumably a recursive occurrence under a function type.
   NOTE(review): thy and tycos are not used in this body. *)
fun mk_equations thy descr vs tycos exhaustives (Ts, Us) =
  let
    (* enumerate an argument of type T with the class-level generator at size
       i - 1; t is the continuation body, under two new binders (value, term thunk) *)
    fun mk_call T =
      let
        val exhaustive = Const (@{const_name "Quickcheck_Exhaustive.exhaustive_class.exhaustive"}, exhaustiveT T)
      in
        (T, (fn t => exhaustive $
          (HOLogic.split_const (T, @{typ "unit => Code_Evaluation.term"}, @{typ "Code_Evaluation.term list option"})
          $ absdummy (T, absdummy (@{typ "unit => Code_Evaluation.term"}, t))) $ size_pred))
      end
    (* same as mk_call, but through the k-th generator being defined *)
    fun mk_aux_call fTs (k, _) (tyco, Ts) =
      let
        val T = Type (tyco, Ts)
        val _ = if not (null fTs) then raise FUNCTION_TYPE else ()
      in
        (T, (fn t => nth exhaustives k $
          (HOLogic.split_const (T, @{typ "unit => Code_Evaluation.term"}, @{typ "Code_Evaluation.term list option"})
            $ absdummy (T, absdummy (@{typ "unit => Code_Evaluation.term"}, t))) $ size_pred))
      end
    (* expression for one constructor c of simpleT; xs pairs each argument type
       with the enumeration wrapper produced by mk_call / mk_aux_call *)
    fun mk_consexpr simpleT (c, xs) =
      let
        val (Ts, fns) = split_list xs
        val constr = Const (c, Ts ---> simpleT)
        (* Under the 2n binders introduced by the wrappers (value, thunk per
           argument; outermost = first argument), the values are the odd
           indices, listed here in argument order... *)
        val bounds = map (fn x => Bound (2 * x + 1)) (((length xs) - 1) downto 0)
        (* ...and the term thunks the even ones.  NOTE(review): unused; see term *)
        val term_bounds = map (fn x => Bound (2 * x)) (((length xs) - 1) downto 0)
        val Eval_App = Const ("Code_Evaluation.App", HOLogic.termT --> HOLogic.termT --> HOLogic.termT)
        val Eval_Const = Const ("Code_Evaluation.Const", HOLogic.literalT --> @{typ typerep} --> HOLogic.termT)
        (* App (... App (Const c, t1 ()) ...) (tn ()).  This is placed under
           absdummy (unit, _) below, which adds one binder, so each odd index
           in bounds then denotes the term thunk of the same argument.
           NOTE(review): relies on absdummy not shifting loose bounds. *)
        val term = fold (fn u => fn t => Eval_App $ t $ (u $ @{term "()"}))
          bounds (Eval_Const $ HOLogic.mk_literal c $ HOLogic.mk_typerep (Ts ---> simpleT))
        (* f (c x1 ... xn, %_. term) *)
        val start_term = test_function simpleT $
        (HOLogic.pair_const simpleT @{typ "unit => Code_Evaluation.term"}
          $ (list_comb (constr, bounds)) $ absdummy (@{typ unit}, term))
      (* wrap with the argument enumerations; the first argument's is outermost *)
      in fold_rev (fn f => fn t => f t) fns start_term end
    (* if i > 0 then e1 orelse ... orelse en else None *)
    fun mk_rhs exprs =
      @{term "If :: bool => term list option => term list option => term list option"}
        $ size_ge_zero $ (foldr1 mk_none_continuation exprs) $ @{term "None :: term list option"}
    val rhss =
      Datatype_Aux.interpret_construction descr vs
        { atyp = mk_call, dtyp = mk_aux_call }
      |> (map o apfst) Type
      |> map (fn (T, cs) => map (mk_consexpr T) cs)
      |> map mk_rhs
    (* exhaustive_k f i *)
    val lhss = map2 (fn t => fn T => t $ test_function T $ size) exhaustives (Ts @ Us);
    val eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhss ~~ rhss)
  in
    eqs
  end
|
137 |
|
(* foundational definition with the function package *)

(* i > 0 ==> nat_of (i - 1) < nat_of i; used as a simp rule by termination_tac *)
val less_int_pred =
  @{lemma "i > 0 ==> Code_Numeral.nat_of ((i :: code_numeral) - 1) < Code_Numeral.nat_of i" by auto}

(* Code_Numeral.nat_of o snd, where snd is taken at type T => code_numeral *)
fun mk_single_measure T =
  let
    val snd_const = Const (@{const_name "Product_Type.snd"}, T --> @{typ "code_numeral"})
  in
    HOLogic.mk_comp (@{term "Code_Numeral.nat_of"}, snd_const)
  end

(* From a relation type T (a set of pairs), take the argument type, dispatch over
   its sum structure with mk_single_measure, and wrap the result in measure. *)
fun mk_termination_measure T =
  T
  |> HOLogic.dest_setT
  |> HOLogic.dest_prodT
  |> fst
  |> mk_sumcases @{typ nat} mk_single_measure
  |> mk_measure

(* termination proof: relation = the measure above, well-foundedness by
   wf_measure, then repeated simplification of the remaining goals *)
fun termination_tac ctxt =
  let
    val simps =
      [@{thm in_measure}, @{thm o_def}, @{thm snd_conv}, @{thm nat_mono_iff}, less_int_pred]
      @ @{thms sum.cases}
    val simp_tac = Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simps)
  in
    Function_Relation.relation_tac ctxt mk_termination_measure 1
    THEN rtac @{thm wf_measure} 1
    THEN REPEAT_DETERM (simp_tac 1)
  end

(* pattern completeness followed by auto *)
fun pat_completeness_auto ctxt =
  let
    val css = clasimpset_of ctxt
  in
    Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac css
  end
|
162 |
|
163 |
|
(* creating the instances *)

(* instantiate_exhaustive_datatype config descr vs tycos prfx (names, auxnames) (Ts, Us) thy:
   instantiates class exhaustive for tycos.  It introduces one generator constant
   "exhaustive_<name>" per entry of names @ auxnames (typed exhaustiveT over Ts @ Us),
   whose equations come from mk_equations.
   - define_foundationally: the constants are defined with the function package and
     proved terminating with termination_tac.
   - otherwise: the constants are defined as "undefined", their equations are
     asserted via Skip_Proof.cheat_tac, and they are noted as
     "<prfx>.exhaustive_<name>.simps" with the code, simp and nitpick_simp
     attributes.
   If mk_equations raises FUNCTION_TYPE, a message is printed and thy is returned
   unchanged. *)
fun instantiate_exhaustive_datatype config descr vs tycos prfx (names, auxnames) (Ts, Us) thy =
  let
    val _ = Datatype_Aux.message config "Creating exhaustive generators ...";
    val exhaustivesN = map (prefix (exhaustiveN ^ "_")) (names @ auxnames);
  in
    thy
    |> Class.instantiation (tycos, vs, @{sort exhaustive})
    |> (if define_foundationally then
      let
        val exhaustives = map2 (fn name => fn T => Free (name, exhaustiveT T)) exhaustivesN (Ts @ Us)
        val eqs = mk_equations thy descr vs tycos exhaustives (Ts, Us)
      in
        Function.add_function
          (map (fn (name, T) =>
              Syntax.no_syn (Binding.conceal (Binding.name name), SOME (exhaustiveT T)))
                (exhaustivesN ~~ (Ts @ Us)))
            (map (pair (apfst Binding.conceal Attrib.empty_binding)) eqs)
            Function_Common.default_config pat_completeness_auto
        #> snd
        #> Local_Theory.restore
        #> (fn lthy => Function.prove_termination NONE (termination_tac lthy) lthy)
        #> snd
      end
    else
      fold_map (fn (name, T) => Local_Theory.define
          ((Binding.conceal (Binding.name name), NoSyn),
            (apfst Binding.conceal Attrib.empty_binding, mk_undefined (exhaustiveT T)))
        #> apfst fst) (exhaustivesN ~~ (Ts @ Us))
      #> (fn (exhaustives, lthy) =>
        let
          val eqs_t = mk_equations thy descr vs tycos exhaustives (Ts, Us)
          (* the equations are not proved here: cheat_tac accepts them as given *)
          val eqs = map (fn eq => Goal.prove lthy ["f", "i"] [] eq
            (fn _ => Skip_Proof.cheat_tac (ProofContext.theory_of lthy))) eqs_t
        in
          fold (fn (name, eq) => Local_Theory.note
            ((Binding.conceal (Binding.qualify true prfx
               (Binding.qualify true name (Binding.name "simps"))),
               Code.add_default_eqn_attrib :: map (Attrib.internal o K)
                 [Simplifier.simp_add, Nitpick_Simps.add]), [eq]) #> snd) (exhaustivesN ~~ eqs) lthy
        end))
    |> Class.prove_instantiation_exit (K (Class.intro_classes_tac []))
  end handle FUNCTION_TYPE =>
    (Datatype_Aux.message config
      "Creation of exhaustive generators failed because the datatype contains a function type";
    thy)
|
211 |
|
(** building and compiling generator expressions **)

(* Context slot for a single compiled generator.  Reading it before anything
   was stored fails with the error "Counterexample". *)
structure Counterexample = Proof_Data
(
  type T = unit -> int -> term list option
  (* FIXME avoid user error with non-user text *)
  fun init _ () = error "Counterexample"
);

(* Context slot for a batch of compiled generators; same failure behaviour. *)
structure Counterexample_Batch = Proof_Data
(
  type T = unit -> (int -> term list option) list
  (* FIXME avoid user error with non-user text *)
  fun init _ () = error "Counterexample"
);

(* writers for the two slots (exported through the signature) *)
val put_counterexample = Counterexample.put;
val put_counterexample_batch = Counterexample_Batch.put;

(* code-generation target used for compilation *)
val target = "Quickcheck";
|
231 |
|
(* mk_smart_generator_expr ctxt t: for t = %x1 ... xn. A1 --> ... --> Am --> C
   builds "%depth. <test>".  Each variable is enumerated only once it is first
   mentioned: the variables new in A1 are quantified, then "if A1 then <rest> else
   None", and so on.  At the conclusion the remaining variables are quantified and
   the test is "if C then None else Some [t_x1 (), ..., t_xn ()]".  Every test is
   wrapped in Quickcheck_Exhaustive.catch_match (..., None).  Variables whose type
   is in sort enum are enumerated with check_all; all others use exhaustive at
   size depth. *)
fun mk_smart_generator_expr ctxt t =
  let
    val thy = ProofContext.theory_of ctxt
    val ((vnames, Ts), t') = apfst split_list (strip_abs t)
    (* fresh names for the depth parameter, the variables, and their term thunks *)
    val ([depth_name], ctxt') = Variable.variant_fixes ["depth"] ctxt
    val (names, ctxt'') = Variable.variant_fixes vnames ctxt'
    val (term_names, ctxt''') = Variable.variant_fixes (map (prefix "t_") vnames) ctxt''
    val depth = Free (depth_name, @{typ code_numeral})
    val frees = map2 (curry Free) names Ts
    val term_vars = map (fn n => Free (n, @{typ "unit => term"})) term_names
    (* A1 --> ... --> Am --> C  ~>  ([A1, ..., Am], C) *)
    fun strip_imp (Const(@{const_name HOL.implies},_) $ A $ B) = apfst (cons A) (strip_imp B)
      | strip_imp A = ([], A)
    (* the body with its bound variables replaced by the fresh Frees *)
    val (assms, concl) = strip_imp (subst_bounds (rev frees, t'))
    (* the counterexample: term representations of all variables, in order *)
    val terms = HOLogic.mk_list @{typ term} (map (fn v => v $ @{term "()"}) term_vars)
    (* bind (free, term_var) around t by enumerating values of free's type *)
    fun mk_exhaustive_closure (free as Free (_, T), term_var) t =
      if Sign.of_sort thy (T, @{sort enum}) then
        Const (@{const_name "Quickcheck_Exhaustive.check_all_class.check_all"}, check_allT T)
          $ (HOLogic.split_const (T, @{typ "unit => term"}, @{typ "term list option"})
            $ lambda free (lambda term_var t))
      else
        Const (@{const_name "Quickcheck_Exhaustive.exhaustive_class.exhaustive"}, exhaustiveT T)
          $ (HOLogic.split_const (T, @{typ "unit => term"}, @{typ "term list option"})
            $ lambda free (lambda term_var t)) $ depth
    (* variable name -> (Free, term thunk Free); raises Option for any other name.
       NOTE(review): presumes every free name in assms/concl is one of names *)
    fun lookup v = the (AList.lookup (op =) (names ~~ (frees ~~ term_vars)) v)
    val none_t = @{term "None :: term list option"}
    (* catch_match (if cond then then_t else else_t) None *)
    fun mk_safe_if (cond, then_t, else_t) =
      @{term "Quickcheck_Exhaustive.catch_match :: term list option => term list option => term list option"} $
        (@{term "If :: bool => term list option => term list option => term list option"}
          $ cond $ then_t $ else_t) $ none_t;
    (* bound_vars: names already quantified by enclosing closures *)
    fun mk_test_term bound_vars assms =
      let
        fun vars_of t = subtract (op =) bound_vars (Term.add_free_names t [])
        val (vars, check) =
          case assms of [] =>
            (vars_of concl, (concl, none_t, @{term "Some :: term list => term list option"} $ terms))
          | assm :: assms =>
            (vars_of assm, (assm, mk_test_term (union (op =) (vars_of assm) bound_vars) assms, none_t))
      in
        fold_rev mk_exhaustive_closure (map lookup vars) (mk_safe_if check)
      end
  in lambda depth (mk_test_term [] assms) end
|
273 |
|
(* mk_generator_expr ctxt t: for t = %x1 ... xn. P, builds
   "%d. exhaustive (%(x1, t1). ... exhaustive (%(xn, tn). <check>) d ...) d",
   where <check> = catch_match (if t x1 ... xn then None
                                else Some [t1 (), ..., tn ()]) None.
   Every variable is enumerated up to depth d.  ctxt is unused; it is kept so the
   signature matches mk_smart_generator_expr. *)
fun mk_generator_expr ctxt t =
  let
    val Ts = (map snd o fst o strip_abs) t;
    val bound_max = length Ts - 1;
    (* per variable i: (index of its value in <check>, index of its term thunk in
       <check>, index of d at the i-th closure, type).  Each closure adds two
       binders, so d is Bound (2 * i) at the i-th closure. *)
    val bounds = map_index (fn (i, ty) =>
      (2 * (bound_max - i) + 1, 2 * (bound_max - i), 2 * i, ty)) Ts;
    val result = list_comb (t, map (fn (i, _, _, _) => Bound i) bounds);
    val terms = HOLogic.mk_list @{typ term} (map (fn (_, i, _, _) => Bound i $ @{term "()"}) bounds);
    val check =
      @{term "Quickcheck_Exhaustive.catch_match :: term list option => term list option => term list option"} $
        (@{term "If :: bool => term list option => term list option => term list option"}
          $ result $ @{term "None :: term list option"} $ (@{term "Some :: term list => term list option"} $ terms))
      $ @{term "None :: term list option"};
    fun mk_exhaustive_closure (_, _, i, T) t =
      Const (@{const_name "Quickcheck_Exhaustive.exhaustive_class.exhaustive"}, exhaustiveT T)
        $ (HOLogic.split_const (T, @{typ "unit => term"}, @{typ "term list option"})
        $ absdummy (T, absdummy (@{typ "unit => term"}, t))) $ Bound i
  in Abs ("d", @{typ code_numeral}, fold_rev mk_exhaustive_closure bounds check) end
|
293 |
|
(** post-processing of function terms **)

(* fun_upd t0 t1 t2, i.e. t0(t1 := t2)  ~>  (t0, (t1, t2)); raises TERM otherwise *)
fun dest_fun_upd (Const (@{const_name fun_upd}, _) $ t0 $ t1 $ t2) = (t0, (t1, t2))
  | dest_fun_upd t = raise TERM ("dest_fun_upd", [t])

(* t(t1 := t2) for t :: T1 => T2 *)
fun mk_fun_upd T1 T2 (t1, t2) t =
  Const (@{const_name fun_upd}, (T1 --> T2) --> T1 --> T2 --> T1 --> T2) $ t $ t1 $ t2

(* peel off a chain of updates down to an abstraction: returns the update pairs,
   outermost first, and the innermost Abs.  Raises TERM if the chain ends in
   anything but an Abs. *)
fun dest_fun_upds t =
  case try dest_fun_upd t of
    NONE =>
      (case t of
        Abs (_, _, _) => ([], t)
      | _ => raise TERM ("dest_fun_upds", [t]))
  | SOME (t0, (t1, t2)) => apfst (cons (t1, t2)) (dest_fun_upds t0)

(* inverse of dest_fun_upds *)
fun make_fun_upds T1 T2 (tps, t) = fold_rev (mk_fun_upd T1 T2) tps t

(* the set of all t1 paired with True (pairs with False are skipped);
   raises TERM on any other value *)
fun make_set T1 [] = Const (@{const_abbrev Set.empty}, T1 --> @{typ bool})
  | make_set T1 ((_, @{const False}) :: tps) = make_set T1 tps
  | make_set T1 ((t1, @{const True}) :: tps) =
    Const (@{const_name insert}, T1 --> (T1 --> @{typ bool}) --> T1 --> @{typ bool})
      $ t1 $ (make_set T1 tps)
  | make_set T1 ((_, t) :: _) = raise TERM ("make_set", [t])

(* UNIV minus the set of all t1 paired with False; raises TERM on a value other
   than True/False (previously an unhandled Match) *)
fun make_coset T [] = Const (@{const_abbrev UNIV}, T --> @{typ bool})
  | make_coset T tps =
    let
      val U = T --> @{typ bool}
      fun invert @{const False} = @{const True}
        | invert @{const True} = @{const False}
        | invert t = raise TERM ("make_coset", [t])
    in
      Const (@{const_name "Groups.minus_class.minus"}, U --> U --> U)
        $ Const (@{const_abbrev UNIV}, U) $ make_set T (map (apsnd invert) tps)
    end

(* a map built from the update pairs, dropping pairs whose value is None *)
fun make_map T1 T2 [] = Const (@{const_abbrev Map.empty}, T1 --> T2)
  | make_map T1 T2 ((_, Const (@{const_name None}, _)) :: tps) = make_map T1 T2 tps
  | make_map T1 T2 ((t1, t2) :: tps) = mk_fun_upd T1 T2 (t1, t2) (make_map T1 T2 tps)
|
333 |
|
(* post_process_term t: rewrites function values in a counterexample term into
   readable form, recursively in constructor arguments:
   - bool-valued update chains over %_. False (or %_. undefined) become sets
     {a, ...} of the True-updated points;
   - bool-valued update chains over %_. True become cosets UNIV - {b, ...}
     of the False-updated points;
   - option-valued update chains over %_. None (or %_. undefined) become maps
     with the None updates dropped;
   - other update chains are rebuilt as they are.
   Terms whose head is not a constant are returned unchanged. *)
fun post_process_term t =
  let
    fun map_Abs f t =
      case t of Abs (x, T, t') => Abs (x, T, f t') | _ => raise TERM ("map_Abs", [t])
    fun process_args t =
      (case strip_comb t of
        (c as Const (_, _), ts) => list_comb (c, map post_process_term ts)
      | _ => t)
  in
    case fastype_of t of
      Type (@{type_name fun}, [T1, T2]) =>
        (case try dest_fun_upds t of
          SOME (tps, t) =>
            (map (pairself post_process_term) tps, map_Abs post_process_term t)
            |> (case T2 of
              @{typ bool} =>
                (case t of
                   (* default False: only the True updates are members *)
                   Abs(_, _, @{const False}) => fst #> rev #> make_set T1
                   (* default True: everything except the False updates *)
                 | Abs(_, _, @{const True}) => fst #> rev #> make_coset T1
                 | Abs(_, _, Const (@{const_name undefined}, _)) => fst #> rev #> make_set T1
                 | _ => raise TERM ("post_process_term", [t]))
            | Type (@{type_name option}, _) =>
                (case t of
                  Abs(_, _, Const(@{const_name None}, _)) => fst #> make_map T1 T2
                | Abs(_, _, Const (@{const_name undefined}, _)) => fst #> make_map T1 T2
                | _ => make_fun_upds T1 T2)
            | _ => make_fun_upds T1 T2)
        | NONE => process_args t)
    | _ => process_args t
  end
|
362 |
|
(** generator compilation **)
|
364 |
|
(* compile_generator_expr ctxt t: compiles the generator for t (smart or plain,
   depending on smart_quantifier) and returns a function from size to the
   counterexample found (if any), paired with no report.  If quickcheck_pretty
   is set, the counterexample terms go through post_process_term. *)
fun compile_generator_expr ctxt t =
  let
    val thy = ProofContext.theory_of ctxt
    val mk_expr =
      if Config.get ctxt smart_quantifier then mk_smart_generator_expr else mk_generator_expr
    val compile =
      Code_Runtime.dynamic_value_strict
        (Counterexample.get, put_counterexample, "Exhaustive_Generators.put_counterexample")
        thy (SOME target) (fn proc => fn g => g #> (Option.map o map) proc) (mk_expr ctxt t) []
    fun prettify res =
      if Config.get ctxt quickcheck_pretty then Option.map (map post_process_term) res else res
  in
    fn size => (prettify (compile size), NONE)
  end;
|
378 |
|
(* compile_generator_exprs ctxt ts: batch version of compile_generator_expr.
   All generators are compiled in a single evaluation and returned as one
   size -> counterexample function per input term.  Like the single version,
   it post-processes results only when quickcheck_pretty is set (previously it
   always did, regardless of the option). *)
fun compile_generator_exprs ctxt ts =
  let
    val thy = ProofContext.theory_of ctxt
    val mk_expr =
      if Config.get ctxt smart_quantifier then mk_smart_generator_expr else mk_generator_expr
    val ts' = map (mk_expr ctxt) ts;
    val compiles = Code_Runtime.dynamic_value_strict
      (Counterexample_Batch.get, put_counterexample_batch,
        "Exhaustive_Generators.put_counterexample_batch")
      thy (SOME target) (fn proc => map (fn g => g #> (Option.map o map) proc))
      (HOLogic.mk_list @{typ "code_numeral => term list option"} ts') [];
    val post_process =
      if Config.get ctxt quickcheck_pretty then Option.map (map post_process_term) else I
  in
    map (fn compile => compile #> post_process) compiles
  end;
|
393 |
|
394 |
|
(** setup **)

(* Registers the exhaustive class instantiation as a datatype interpretation,
   declares both configuration options, and registers the single and batch
   generators under the name "exhaustive". *)
val setup =
  let
    val generator_name = "exhaustive"
    val interpret_datatypes =
      Datatype.interpretation
        (Quickcheck_Generators.ensure_sort_datatype
          (@{sort exhaustive}, instantiate_exhaustive_datatype))
  in
    interpret_datatypes
    #> setup_smart_quantifier
    #> setup_quickcheck_pretty
    #> Context.theory_map (Quickcheck.add_generator (generator_name, compile_generator_expr))
    #> Context.theory_map (Quickcheck.add_batch_generator (generator_name, compile_generator_exprs))
  end;
|
404 |
|
405 end; |
|