1 (* Title: HOL/Tools/Sledgehammer/sledgehammer_shrink.ML |
|
2 Author: Jasmin Blanchette, TU Muenchen |
|
3 Author: Steffen Juilf Smolka, TU Muenchen |
|
4 |
|
5 Shrinking of reconstructed Isar proofs. |
|
6 *) |
|
7 |
|
(* Public interface of the proof shrinker. *)
signature SLEDGEHAMMER_SHRINK =
sig
  (* Re-exported abbreviations for the step and timing types. *)
  type isar_step = Sledgehammer_Proof.isar_step
  type preplay_time = Sledgehammer_Preplay.preplay_time

  (* shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
       isar_shrink proof

     Returns the shrunk proof together with a pair consisting of a flag
     telling whether a metis preplay call failed with an exception, and
     the accumulated preplay time of the resulting proof. *)
  val shrink_proof :
    bool -> Proof.context -> string -> string -> bool -> Time.time option
    -> real -> isar_step list -> isar_step list * (bool * preplay_time)
end
|
16 |
|
17 structure Sledgehammer_Shrink : SLEDGEHAMMER_SHRINK = |
|
18 struct |
|
19 |
|
20 open Sledgehammer_Util |
|
21 open Sledgehammer_Proof |
|
22 open Sledgehammer_Preplay |
|
23 |
|
(* Parameters *)

(* Factor applied to the summed preplay times of two steps to obtain the
   timeout granted to the preplay of their merged step. *)
val merge_timeout_slack = 1.2
|
26 |
|
(* Data structures, orders *)

(* Order on labels: both labels are swapped component-wise, then compared
   first with int_ord and, on ties, with fast_string_ord. *)
fun label_ord (l1, l2) = prod_ord int_ord fast_string_ord (swap l1, swap l2)
|
(* Tables keyed by labels, ordered by label_ord. *)
structure Label_Table =
  Table(
    type key = label
    val ord = label_ord)
|
32 |
|
(* clean vector interface *)

(* Element of vector vec at position idx; the index comes first so that
   the function can be partially applied in pipelines. *)
fun get idx vec = Vector.sub (vec, idx)
|
(* Copy of vector vec in which position idx holds the value new. *)
fun replace new idx vec = Vector.update (vec, idx, new)
|
(* Copy of vector v in which the element at position i is replaced by the
   result of applying f to it. *)
fun update f i v = Vector.update (v, i, f (Vector.sub (v, i)))
|
(* Maps f over the (index, element) pairs of v, in ascending index order,
   and collects the results in a new vector. *)
fun v_map_index f v = Vector.mapi f v
|
(* Left fold over v that also passes each element's index: f receives the
   pair (index, element) and the state threaded through from s. *)
fun v_fold_index f v s =
  Vector.foldli (fn (i, x, acc) => f (i, x) acc) s v
|
40 |
|
(* Queue interface to table *)

(* Takes the first entry stored under key out of the list-valued table.
   Returns that entry and the table without it. *)
fun pop tab key =
  let
    val first = hd (Inttab.lookup_list tab key)
    val remaining = Inttab.remove_list (op =) (key, first) tab
  in
    (first, remaining)
  end
|
(* Pops an entry stored under the largest key of the table. *)
fun pop_max tab =
  let val largest = the (Inttab.max_key tab)
  in pop tab largest end
|
(* Inserts every (key, value) pair of xs into the list-valued table. *)
fun add_list tab xs = tab |> fold (Inttab.insert_list (op =)) xs
|
48 |
|
(* Main function for shrinking proofs.

   debug            if true, exceptions raised while preplaying a step are
                    re-raised instead of being recorded as a metis failure
   ctxt             proof context; enriched with the facts of all labelled
                    steps before preplaying
   type_enc,
   lam_trans        passed on unchanged to the metis preplay functions
   preplay          if false, metis is never run and all times are zero
   preplay_timeout  optional timeout per preplayed step (default: 60 s)
   isar_shrink      shrink factor; merging stops once the number of
                    top-level metis steps is at most the original number
                    divided by this factor (rounded)
   proof            the proof to shrink

   Result: the shrunk proof, paired with (metis failure flag, preplay time). *)
fun shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
    isar_shrink proof =
  let
    (* 60 seconds seems like a good interpretation of "no timeout" *)
    val preplay_timeout = the_default (seconds 60.0) preplay_timeout

    (* Bookkeeping for failing metis preplay calls *)
    local
      open Unsynchronized
      val failed = ref false
    in
      (* Wraps a preplay thunk: interrupts are always re-raised, other
         exceptions only in debug mode; otherwise the failure is recorded
         and some_preplay_time is returned in place of a real time. *)
      fun handle_metis_fail try_metis () =
        try_metis ()
        handle exn =>
          if Exn.is_interrupt exn orelse debug then reraise exn
          else (failed := true; some_preplay_time)

      (* Once a failure has been recorded, lazy times that have not been
         computed yet are not forced any more. *)
      fun get_time lazy_time =
        if not (!failed) orelse Lazy.is_finished lazy_time
        then Lazy.force lazy_time
        else some_preplay_time

      fun metis_fail () = !failed
    end

    (* Shrink proof on top level - do not shrink case splits *)
    fun shrink_top_level on_top_level ctxt proof =
      let
        (* proof vector; merged-away steps become NONE later on *)
        val init_steps = Vector.fromList (map SOME proof)
        val n = Vector.length init_steps
        val n_metis = metis_steps_top_level proof
        val target_n_metis = Real.round (Real.fromInt n_metis / isar_shrink)

        (* table for mapping from (top-level-)label to proof position *)
        fun label_of (Assume (l, _)) = SOME l
          | label_of (Obtain (_, _, l, _, _)) = SOME l
          | label_of (Prove (_, l, _, _)) = SOME l
          | label_of _ = NONE
        fun register (pos, step) =
          (case label_of step of
            SOME l => Label_Table.update_new (l, pos)
          | NONE => I)
        val label_index_table = fold_index register proof Label_Table.empty
        fun lookup_indices labels =
          map_filter (Label_Table.lookup label_index_table) labels

        (* positions of the top-level steps a step refers to, including
           references made from inside case splits and subblocks *)
        fun refs_of (Obtain (_, _, _, _, By_Metis (lfs, _))) =
              lookup_indices lfs
          | refs_of (Prove (_, _, _, By_Metis (lfs, _))) = lookup_indices lfs
          | refs_of (Prove (_, _, _, Case_Split (cases, (lfs, _)))) =
              lookup_indices lfs @ maps (maps refs_of) cases
          | refs_of (Prove (_, _, _, Subblock subproof)) =
              maps refs_of subproof
          | refs_of _ = []

        (* for every position, the positions of the steps referring to it;
           after rev, indices are sorted in ascending order *)
        fun note_referrer (pos, step) = fold (update (cons pos)) (refs_of step)
        val init_refed_by =
          Vector.map rev
            (fold_index note_referrer proof (Vector.tabulate (n, fn _ => [])))

        (* candidates for elimination, use table as priority queue (greedy
           algorithm): a metis-proved Prove step referenced by exactly one
           metis-proved Prove or Obtain step, keyed by the size of its term *)
        fun accepts_merge (Prove (_, _, _, By_Metis _)) = true
          | accepts_merge (Obtain (_, _, _, _, By_Metis _)) = true
          | accepts_merge _ = false
        fun add_if_cand steps (i, [j]) =
              let
                val referenced = the (get i steps)
                val referrer = the (get j steps)
              in
                (case referenced of
                  Prove (_, _, t, By_Metis _) =>
                    if accepts_merge referrer
                    then cons (Term.size_of_term t, i)
                    else I
                | _ => I)
              end
          | add_if_cand _ _ = I
        val init_cands =
          Inttab.make_list
            (v_fold_index (add_if_cand init_steps) init_refed_by [])

        (* cache metis preplay times in lazy time vector *)
        val lazy_time_of =
          if preplay then
            let
              val play =
                try_metis debug type_enc lam_trans ctxt preplay_timeout
            in
              fn (pos, step_opt) =>
                let
                  val step = the step_opt
                  (* step at the preceding position, if there is one *)
                  val succedent = try (get (pos - 1) #> the) init_steps
                in
                  Lazy.lazy (handle_metis_fail (play (succedent, step)))
                end
            end
          else fn _ => Lazy.value zero_preplay_time
        val init_metis_time = v_map_index lazy_time_of init_steps

        fun sum_up_time lazy_times =
          Vector.foldl
            (fn (lazy_time, total) =>
              add_preplay_time (get_time lazy_time) total)
            zero_preplay_time lazy_times

        (* Merging: the facts of the first step are added to the second
           step, whose reference to the first step's label is dropped *)
        fun merge (Prove (_, label1, _, By_Metis (lfs1, gfs1))) step2 =
              let
                fun combine (lfs2, gfs2) =
                  By_Metis
                    (union (op =) lfs1 (remove (op =) label1 lfs2),
                     union (op =) gfs1 gfs2)
              in
                (case step2 of
                  Prove (qs2, label2, t, By_Metis facts2) =>
                    Prove (qs2, label2, t, combine facts2)
                | Obtain (qs2, xs, label2, t, By_Metis facts2) =>
                    Obtain (qs2, xs, label2, t, combine facts2)
                | _ => error "sledgehammer_shrink: unmergeable Isar steps")
              end
          | merge _ _ = error "sledgehammer_shrink: unmergeable Isar steps"

        (* Yields (SOME merged step, updated times) on success and
           (NONE, unchanged times) if one of the two steps or the merged
           step has a time whose first component is true. *)
        fun try_merge metis_time (s1, i) (s2, j) =
          if preplay then
            let
              val give_up = (NONE, metis_time)
              val (flag1, t1) = Lazy.force (get i metis_time)
            in
              if flag1 then give_up
              else
                let
                  val (flag2, t2) = Lazy.force (get j metis_time)
                in
                  if flag2 then give_up
                  else
                    let
                      val merged = merge s1 s2
                      val timeout =
                        time_mult merge_timeout_slack (Time.+ (t1, t2))
                      val merged_time as (flag12, _) =
                        try_metis_quietly debug type_enc lam_trans ctxt
                          timeout (NONE, merged) ()
                    in
                      if flag12 then give_up
                      else
                        (SOME merged,
                         metis_time
                         |> replace (Lazy.value zero_preplay_time) i
                         |> replace (Lazy.value merged_time) j)
                    end
                end
            end
          else (SOME (merge s1 s2), metis_time)

        (* Greedy loop: repeatedly pop a candidate with maximal key and try
           to merge it into the single step referring to it. *)
        fun merge_steps metis_time steps refed_by cands n' n_metis' =
          let
            val finished =
              Inttab.is_empty cands
              orelse n_metis' <= target_n_metis
              orelse (on_top_level andalso n' < 3)
          in
            if finished then
              (Vector.foldr
                 (fn (step_opt, acc) =>
                   (case step_opt of NONE => acc | SOME s => s :: acc))
                 [] steps,
               sum_up_time metis_time)
            else
              let
                val (i, cands') = pop_max cands
                val j = the_single (get i refed_by)
                val s1 = the (get i steps)
                val s2 = the (get j steps)
              in
                (case try_merge metis_time (s1, i) (s2, j) of
                  (NONE, metis_time') =>
                    merge_steps metis_time' steps refed_by cands' n' n_metis'
                | (merged, metis_time') =>
                    let
                      (* steps used by s1 are now used by position j *)
                      val premises = refs_of s1
                      val redirect =
                        Ord_List.remove int_ord i #> Ord_List.insert int_ord j
                      val refed_by' = fold (update redirect) premises refed_by
                      (* candidates are computed against the steps as they
                         were before this merge *)
                      val fresh_cands =
                        fold (add_if_cand steps)
                          (map (fn k => (k, get k refed_by')) premises) []
                      val cands'' = add_list cands' fresh_cands
                      val steps' =
                        steps |> replace NONE i |> replace merged j
                    in
                      merge_steps metis_time' steps' refed_by' cands''
                        (n' - 1) (n_metis' - 1)
                    end)
              end
          end
      in
        merge_steps init_metis_time init_steps init_refed_by init_cands
          n n_metis
      end

    fun do_proof on_top_level ctxt proof =
      let
        (* Enrich context with top-level facts *)
        val thy = Proof_Context.theory_of ctxt
        (* TODO: add Skolem variables to context? *)
        fun labelled_fact (Assume (l, t)) = SOME (l, t)
          | labelled_fact (Obtain (_, _, l, t, _)) = SOME (l, t)
          | labelled_fact (Prove (_, l, t, _)) = SOME (l, t)
          | labelled_fact _ = NONE
        fun enrich_with_step step =
          (case labelled_fact step of
            SOME (l, t) =>
              Proof_Context.put_thms false
                (string_for_label l, SOME [Skip_Proof.make_thm thy t])
          | NONE => I)
        val rich_ctxt = fold enrich_with_step proof ctxt

        (* Shrink subproofs (case splits and subblocks) first, then the
           top level of this proof *)
        val (with_shrunk_subproofs, lower_level_time) =
          do_subproof rich_ctxt proof
        val (shrunk, top_level_time) =
          shrink_top_level on_top_level rich_ctxt with_shrunk_subproofs
      in
        (shrunk, add_preplay_time lower_level_time top_level_time)
      end

    and do_subproof ctxt proof =
      let
        (* applies shrink_one to every item and adds up the times *)
        fun shrink_all shrink_one items =
          fold_map
            (fn item => fn elapsed =>
              let val (item', time) = shrink_one item
              in (item', add_preplay_time elapsed time) end)
            items zero_preplay_time
        fun shrink_subproofs subproofs =
          shrink_all (do_proof false ctxt) subproofs
        fun shrink_step (Prove (qs, l, t, Case_Split (cases, facts))) =
              let val (cases', time) = shrink_subproofs cases
              in (Prove (qs, l, t, Case_Split (cases', facts)), time) end
          | shrink_step (Prove (qs, l, t, Subblock subproof)) =
              let val ([subproof'], time) = shrink_subproofs [subproof]
              in (Prove (qs, l, t, Subblock subproof'), time) end
          | shrink_step step = (step, zero_preplay_time)
      in
        shrink_all shrink_step proof
      end

    (* the failure flag is read only after the whole proof was processed *)
    val (shrunk_proof, total_time) = do_proof true ctxt proof
  in
    (shrunk_proof, (metis_fail (), total_time))
  end
|
251 |
|
252 end |
|