|
1 (* Title: HOL/Tools/function_package/scnp_reconstruct.ML |
|
2 Author: Armin Heller, TU Muenchen |
|
3 Author: Alexander Krauss, TU Muenchen |
|
4 |
|
5 Proof reconstruction for SCNP |
|
6 *) |
|
7 |
|
8 signature SCNP_RECONSTRUCT = |
|
9 sig |
|
10 |
|
11 val decomp_scnp : ScnpSolve.label list -> Proof.context -> method |
|
12 |
|
13 val setup : theory -> theory |
|
14 |
|
15 datatype multiset_setup = |
|
16 Multiset of |
|
17 { |
|
18 msetT : typ -> typ, |
|
19 mk_mset : typ -> term list -> term, |
|
20 mset_regroup_conv : int list -> conv, |
|
21 mset_member_tac : int -> int -> tactic, |
|
22 mset_nonempty_tac : int -> tactic, |
|
23 mset_pwleq_tac : int -> tactic, |
|
24 set_of_simps : thm list, |
|
25 smsI' : thm, |
|
26 wmsI2'' : thm, |
|
27 wmsI1 : thm, |
|
28 reduction_pair : thm |
|
29 } |
|
30 |
|
31 |
|
32 val multiset_setup : multiset_setup -> theory -> theory |
|
33 |
|
34 end |
|
35 |
|
36 structure ScnpReconstruct : SCNP_RECONSTRUCT = |
|
37 struct |
|
38 |
|
39 val PROFILE = FundefCommon.PROFILE |
|
40 fun TRACE x = if ! FundefCommon.profile then Output.tracing x else () |
|
41 |
|
42 open ScnpSolve |
|
43 |
|
44 val natT = HOLogic.natT |
|
45 val nat_pairT = HOLogic.mk_prodT (natT, natT) |
|
46 |
|
47 (* Theory dependencies *) |
|
48 |
|
49 datatype multiset_setup = |
|
50 Multiset of |
|
51 { |
|
52 msetT : typ -> typ, |
|
53 mk_mset : typ -> term list -> term, |
|
54 mset_regroup_conv : int list -> conv, |
|
55 mset_member_tac : int -> int -> tactic, |
|
56 mset_nonempty_tac : int -> tactic, |
|
57 mset_pwleq_tac : int -> tactic, |
|
58 set_of_simps : thm list, |
|
59 smsI' : thm, |
|
60 wmsI2'' : thm, |
|
61 wmsI1 : thm, |
|
62 reduction_pair : thm |
|
63 } |
|
64 |
|
65 structure MultisetSetup = TheoryDataFun |
|
66 ( |
|
67 type T = multiset_setup option |
|
68 val empty = NONE |
|
69 val copy = I; |
|
70 val extend = I; |
|
71 fun merge _ (v1, v2) = if is_some v2 then v2 else v1 |
|
72 ) |
|
73 |
|
74 val multiset_setup = MultisetSetup.put o SOME |
|
75 |
|
76 fun undef x = error "undef" |
|
77 fun get_multiset_setup thy = MultisetSetup.get thy |
|
78 |> the_default (Multiset |
|
79 { msetT = undef, mk_mset=undef, |
|
80 mset_regroup_conv=undef, mset_member_tac = undef, |
|
81 mset_nonempty_tac = undef, mset_pwleq_tac = undef, |
|
82 set_of_simps = [],reduction_pair = refl, |
|
83 smsI'=refl, wmsI2''=refl, wmsI1=refl }) |
|
84 |
|
85 fun order_rpair _ MAX = @{thm max_rpair_set} |
|
86 | order_rpair msrp MS = msrp |
|
87 | order_rpair _ MIN = @{thm min_rpair_set} |
|
88 |
|
89 fun ord_intros_max true = |
|
90 (@{thm smax_emptyI}, @{thm smax_insertI}) |
|
91 | ord_intros_max false = |
|
92 (@{thm wmax_emptyI}, @{thm wmax_insertI}) |
|
93 fun ord_intros_min true = |
|
94 (@{thm smin_emptyI}, @{thm smin_insertI}) |
|
95 | ord_intros_min false = |
|
96 (@{thm wmin_emptyI}, @{thm wmin_insertI}) |
|
97 |
|
98 fun gen_probl D cs = |
|
99 let |
|
100 val n = Termination.get_num_points D |
|
101 val arity = length o Termination.get_measures D |
|
102 fun measure p i = nth (Termination.get_measures D p) i |
|
103 |
|
104 fun mk_graph c = |
|
105 let |
|
106 val (_, p, _, q, _, _) = Termination.dest_call D c |
|
107 |
|
108 fun add_edge i j = |
|
109 case Termination.get_descent D c (measure p i) (measure q j) |
|
110 of SOME (Termination.Less _) => cons (i, GTR, j) |
|
111 | SOME (Termination.LessEq _) => cons (i, GEQ, j) |
|
112 | _ => I |
|
113 |
|
114 val edges = |
|
115 fold_product add_edge (0 upto arity p - 1) (0 upto arity q - 1) [] |
|
116 in |
|
117 G (p, q, edges) |
|
118 end |
|
119 in |
|
120 GP (map arity (0 upto n - 1), map mk_graph cs) |
|
121 end |
|
122 |
|
123 (* General reduction pair application *) |
|
124 fun rem_inv_img ctxt = |
|
125 let |
|
126 val unfold_tac = LocalDefs.unfold_tac ctxt |
|
127 in |
|
128 rtac @{thm subsetI} 1 |
|
129 THEN etac @{thm CollectE} 1 |
|
130 THEN REPEAT (etac @{thm exE} 1) |
|
131 THEN unfold_tac @{thms inv_image_def} |
|
132 THEN rtac @{thm CollectI} 1 |
|
133 THEN etac @{thm conjE} 1 |
|
134 THEN etac @{thm ssubst} 1 |
|
135 THEN unfold_tac (@{thms split_conv} @ @{thms triv_forall_equality} |
|
136 @ @{thms Sum_Type.sum_cases}) |
|
137 end |
|
138 |
|
139 (* Sets *) |
|
140 |
|
141 val setT = HOLogic.mk_setT |
|
142 |
|
143 fun mk_set T [] = Const (@{const_name "{}"}, setT T) |
|
144 | mk_set T (x :: xs) = |
|
145 Const (@{const_name insert}, T --> setT T --> setT T) $ |
|
146 x $ mk_set T xs |
|
147 |
|
148 fun set_member_tac m i = |
|
149 if m = 0 then rtac @{thm insertI1} i |
|
150 else rtac @{thm insertI2} i THEN set_member_tac (m - 1) i |
|
151 |
|
152 val set_nonempty_tac = rtac @{thm insert_not_empty} |
|
153 |
|
154 fun set_finite_tac i = |
|
155 rtac @{thm finite.emptyI} i |
|
156 ORELSE (rtac @{thm finite.insertI} i THEN (fn st => set_finite_tac i st)) |
|
157 |
|
158 |
|
159 (* Reconstruction *) |
|
160 |
|
161 fun reconstruct_tac ctxt D cs (gp as GP (_, gs)) certificate = |
|
162 let |
|
163 val thy = ProofContext.theory_of ctxt |
|
164 val Multiset |
|
165 { msetT, mk_mset, |
|
166 mset_regroup_conv, mset_member_tac, |
|
167 mset_nonempty_tac, mset_pwleq_tac, set_of_simps, |
|
168 smsI', wmsI2'', wmsI1, reduction_pair=ms_rp } |
|
169 = get_multiset_setup thy |
|
170 |
|
171 fun measure_fn p = nth (Termination.get_measures D p) |
|
172 |
|
173 fun get_desc_thm cidx m1 m2 bStrict = |
|
174 case Termination.get_descent D (nth cs cidx) m1 m2 |
|
175 of SOME (Termination.Less thm) => |
|
176 if bStrict then thm |
|
177 else (thm COMP (Thm.lift_rule (cprop_of thm) @{thm less_imp_le})) |
|
178 | SOME (Termination.LessEq (thm, _)) => |
|
179 if not bStrict then thm |
|
180 else sys_error "get_desc_thm" |
|
181 | _ => sys_error "get_desc_thm" |
|
182 |
|
183 val (label, lev, sl, covering) = certificate |
|
184 |
|
185 fun prove_lev strict g = |
|
186 let |
|
187 val G (p, q, el) = nth gs g |
|
188 |
|
189 fun less_proof strict (j, b) (i, a) = |
|
190 let |
|
191 val tag_flag = b < a orelse (not strict andalso b <= a) |
|
192 |
|
193 val stored_thm = |
|
194 get_desc_thm g (measure_fn p i) (measure_fn q j) |
|
195 (not tag_flag) |
|
196 |> Conv.fconv_rule (Thm.beta_conversion true) |
|
197 |
|
198 val rule = if strict |
|
199 then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1} |
|
200 else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1} |
|
201 in |
|
202 rtac rule 1 THEN PRIMITIVE (Thm.elim_implies stored_thm) |
|
203 THEN (if tag_flag then arith_tac ctxt 1 else all_tac) |
|
204 end |
|
205 |
|
206 fun steps_tac MAX strict lq lp = |
|
207 let |
|
208 val (empty, step) = ord_intros_max strict |
|
209 in |
|
210 if length lq = 0 |
|
211 then rtac empty 1 THEN set_finite_tac 1 |
|
212 THEN (if strict then set_nonempty_tac 1 else all_tac) |
|
213 else |
|
214 let |
|
215 val (j, b) :: rest = lq |
|
216 val (i, a) = the (covering g strict j) |
|
217 fun choose xs = set_member_tac (Library.find_index (curry op = (i, a)) xs) 1 |
|
218 val solve_tac = choose lp THEN less_proof strict (j, b) (i, a) |
|
219 in |
|
220 rtac step 1 THEN solve_tac THEN steps_tac MAX strict rest lp |
|
221 end |
|
222 end |
|
223 | steps_tac MIN strict lq lp = |
|
224 let |
|
225 val (empty, step) = ord_intros_min strict |
|
226 in |
|
227 if length lp = 0 |
|
228 then rtac empty 1 |
|
229 THEN (if strict then set_nonempty_tac 1 else all_tac) |
|
230 else |
|
231 let |
|
232 val (i, a) :: rest = lp |
|
233 val (j, b) = the (covering g strict i) |
|
234 fun choose xs = set_member_tac (Library.find_index (curry op = (j, b)) xs) 1 |
|
235 val solve_tac = choose lq THEN less_proof strict (j, b) (i, a) |
|
236 in |
|
237 rtac step 1 THEN solve_tac THEN steps_tac MIN strict lq rest |
|
238 end |
|
239 end |
|
240 | steps_tac MS strict lq lp = |
|
241 let |
|
242 fun get_str_cover (j, b) = |
|
243 if is_some (covering g true j) then SOME (j, b) else NONE |
|
244 fun get_wk_cover (j, b) = the (covering g false j) |
|
245 |
|
246 val qs = lq \\ map_filter get_str_cover lq |
|
247 val ps = map get_wk_cover qs |
|
248 |
|
249 fun indices xs ys = map (fn y => Library.find_index (curry op = y) xs) ys |
|
250 val iqs = indices lq qs |
|
251 val ips = indices lp ps |
|
252 |
|
253 local open Conv in |
|
254 fun t_conv a C = |
|
255 params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt |
|
256 val goal_rewrite = |
|
257 t_conv arg1_conv (mset_regroup_conv iqs) |
|
258 then_conv t_conv arg_conv (mset_regroup_conv ips) |
|
259 end |
|
260 in |
|
261 CONVERSION goal_rewrite 1 |
|
262 THEN (if strict then rtac smsI' 1 |
|
263 else if qs = lq then rtac wmsI2'' 1 |
|
264 else rtac wmsI1 1) |
|
265 THEN mset_pwleq_tac 1 |
|
266 THEN EVERY (map2 (less_proof false) qs ps) |
|
267 THEN (if strict orelse qs <> lq |
|
268 then LocalDefs.unfold_tac ctxt set_of_simps |
|
269 THEN steps_tac MAX true (lq \\ qs) (lp \\ ps) |
|
270 else all_tac) |
|
271 end |
|
272 in |
|
273 rem_inv_img ctxt |
|
274 THEN steps_tac label strict (nth lev q) (nth lev p) |
|
275 end |
|
276 |
|
277 val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (mk_set, setT) |
|
278 |
|
279 fun tag_pair p (i, tag) = |
|
280 HOLogic.pair_const natT natT $ |
|
281 (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag |
|
282 |
|
283 fun pt_lev (p, lm) = Abs ("x", Termination.get_types D p, |
|
284 mk_set nat_pairT (map (tag_pair p) lm)) |
|
285 |
|
286 val level_mapping = |
|
287 map_index pt_lev lev |
|
288 |> Termination.mk_sumcases D (setT nat_pairT) |
|
289 |> cterm_of thy |
|
290 in |
|
291 PROFILE "Proof Reconstruction" |
|
292 (CONVERSION (Conv.arg_conv (Conv.arg_conv (FundefLib.regroup_union_conv sl))) 1 |
|
293 THEN (rtac @{thm reduction_pair_lemma} 1) |
|
294 THEN (rtac @{thm rp_inv_image_rp} 1) |
|
295 THEN (rtac (order_rpair ms_rp label) 1) |
|
296 THEN PRIMITIVE (instantiate' [] [SOME level_mapping]) |
|
297 THEN unfold_tac @{thms rp_inv_image_def} (simpset_of thy) |
|
298 THEN LocalDefs.unfold_tac ctxt |
|
299 (@{thms split_conv} @ @{thms fst_conv} @ @{thms snd_conv}) |
|
300 THEN REPEAT (SOMEGOAL (resolve_tac [@{thm Un_least}, @{thm empty_subsetI}])) |
|
301 THEN EVERY (map (prove_lev true) sl) |
|
302 THEN EVERY (map (prove_lev false) ((0 upto length cs - 1) \\ sl))) |
|
303 end |
|
304 |
|
305 |
|
306 |
|
307 local open Termination in |
|
308 fun print_cell (SOME (Less _)) = "<" |
|
309 | print_cell (SOME (LessEq _)) = "\<le>" |
|
310 | print_cell (SOME (None _)) = "-" |
|
311 | print_cell (SOME (False _)) = "-" |
|
312 | print_cell (NONE) = "?" |
|
313 |
|
314 fun print_error ctxt D = CALLS (fn (cs, i) => |
|
315 let |
|
316 val np = get_num_points D |
|
317 val ms = map (get_measures D) (0 upto np - 1) |
|
318 val tys = map (get_types D) (0 upto np - 1) |
|
319 fun index xs = (1 upto length xs) ~~ xs |
|
320 fun outp s t f xs = map (fn (x, y) => s ^ Int.toString x ^ t ^ f y ^ "\n") xs |
|
321 val ims = index (map index ms) |
|
322 val _ = Output.tracing (concat (outp "fn #" ":\n" (concat o outp "\tmeasure #" ": " (Syntax.string_of_term ctxt)) ims)) |
|
323 fun print_call (k, c) = |
|
324 let |
|
325 val (_, p, _, q, _, _) = dest_call D c |
|
326 val _ = Output.tracing ("call table for call #" ^ Int.toString k ^ ": fn " ^ |
|
327 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1)) |
|
328 val caller_ms = nth ms p |
|
329 val callee_ms = nth ms q |
|
330 val entries = map (fn x => map (pair x) (callee_ms)) (caller_ms) |
|
331 fun print_ln (i : int, l) = concat (Int.toString i :: " " :: map (enclose " " " " o print_cell o (uncurry (get_descent D c))) l) |
|
332 val _ = Output.tracing (concat (Int.toString (p + 1) ^ "|" ^ Int.toString (q + 1) ^ |
|
333 " " :: map (enclose " " " " o Int.toString) (1 upto length callee_ms)) ^ "\n" |
|
334 ^ cat_lines (map print_ln ((1 upto (length entries)) ~~ entries))) |
|
335 in |
|
336 true |
|
337 end |
|
338 fun list_call (k, c) = |
|
339 let |
|
340 val (_, p, _, q, _, _) = dest_call D c |
|
341 val _ = Output.tracing ("call #" ^ (Int.toString k) ^ ": fn " ^ |
|
342 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1) ^ "\n" ^ |
|
343 (Syntax.string_of_term ctxt c)) |
|
344 in true end |
|
345 val _ = forall list_call ((1 upto length cs) ~~ cs) |
|
346 val _ = forall print_call ((1 upto length cs) ~~ cs) |
|
347 in |
|
348 all_tac |
|
349 end) |
|
350 end |
|
351 |
|
352 |
|
353 fun single_scnp_tac use_tags orders ctxt cont err_cont D = Termination.CALLS (fn (cs, i) => |
|
354 let |
|
355 val gp = gen_probl D cs |
|
356 (* val _ = TRACE ("SCNP instance: " ^ makestring gp)*) |
|
357 val certificate = generate_certificate use_tags orders gp |
|
358 (* val _ = TRACE ("Certificate: " ^ makestring certificate)*) |
|
359 |
|
360 val ms_configured = is_some (MultisetSetup.get (ProofContext.theory_of ctxt)) |
|
361 in |
|
362 case certificate |
|
363 of NONE => err_cont D i |
|
364 | SOME cert => |
|
365 if not ms_configured andalso #1 cert = MS |
|
366 then err_cont D i |
|
367 else SELECT_GOAL (reconstruct_tac ctxt D cs gp cert) i |
|
368 THEN (rtac @{thm wf_empty} i ORELSE cont D i) |
|
369 end) |
|
370 |
|
371 fun decomp_scnp_tac orders autom_tac ctxt err_cont = |
|
372 let |
|
373 open Termination |
|
374 val derive_diag = Descent.derive_diag ctxt autom_tac |
|
375 val derive_all = Descent.derive_all ctxt autom_tac |
|
376 val decompose = Decompose.decompose_tac ctxt autom_tac |
|
377 val scnp_no_tags = single_scnp_tac false orders ctxt |
|
378 val scnp_full = single_scnp_tac true orders ctxt |
|
379 |
|
380 fun first_round c e = |
|
381 derive_diag (REPEAT scnp_no_tags c e) |
|
382 |
|
383 val second_round = |
|
384 REPEAT (fn c => fn e => decompose (scnp_no_tags c c) e) |
|
385 |
|
386 val third_round = |
|
387 derive_all oo |
|
388 REPEAT (fn c => fn e => |
|
389 scnp_full (decompose c c) e) |
|
390 |
|
391 fun Then s1 s2 c e = s1 (s2 c c) (s2 c e) |
|
392 |
|
393 val strategy = Then (Then first_round second_round) third_round |
|
394 |
|
395 in |
|
396 TERMINATION ctxt (strategy err_cont err_cont) |
|
397 end |
|
398 |
|
399 fun decomp_scnp orders ctxt = |
|
400 let |
|
401 val extra_simps = FundefCommon.TerminationSimps.get ctxt |
|
402 val autom_tac = auto_tac (local_clasimpset_of ctxt addsimps2 extra_simps) |
|
403 in |
|
404 Method.SIMPLE_METHOD |
|
405 (TRY (FundefCommon.apply_termination_rule ctxt 1) |
|
406 THEN TRY Termination.wf_union_tac |
|
407 THEN |
|
408 (rtac @{thm wf_empty} 1 |
|
409 ORELSE decomp_scnp_tac orders autom_tac ctxt (print_error ctxt) 1)) |
|
410 end |
|
411 |
|
412 |
|
413 (* Method setup *) |
|
414 |
|
415 val orders = |
|
416 (Scan.repeat1 |
|
417 ((Args.$$$ "max" >> K MAX) || |
|
418 (Args.$$$ "min" >> K MIN) || |
|
419 (Args.$$$ "ms" >> K MS)) |
|
420 || Scan.succeed [MAX, MS, MIN]) |
|
421 |
|
422 val setup = Method.add_method |
|
423 ("sizechange", Method.sectioned_args (Scan.lift orders) clasimp_modifiers decomp_scnp, |
|
424 "termination prover with graph decomposition and the NP subset of size change termination") |
|
425 |
|
426 end |