50259
|
1 |
(* Public interface of the Sledgehammer Isar proof shrinker. *)
signature SLEDGEHAMMER_SHRINK =
sig
  type isar_step = Sledgehammer_Isar_Reconstruct.isar_step

  (* shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
                  isar_shrink proof
     Returns the shrunk step list together with a (flag, time) pair that
     accumulates the preplay times of the resulting steps. *)
  val shrink_proof :
    bool -> Proof.context -> string -> string -> bool -> Time.time -> real
    -> isar_step list -> isar_step list * (bool * Time.time)
end
|
|
8 |
|
|
9 |
structure Sledgehammer_Shrink (* : SLEDGEHAMMER_SHRINK *) =
struct

open Sledgehammer_Isar_Reconstruct

(* Parameters *)

(* Multiplicative slack: the time limit for preplaying a merged step is this
   factor times the (slackened) sum of the preplay times of its two parts. *)
val merge_timeout_slack = 1.2

(* Additive slack, applied before scaling. Without it, two steps whose measured
   CPU time is (close to) zero would yield a merged time limit of (close to)
   zero, so their merge would spuriously time out. *)
val merge_timeout_slack_time = seconds 0.005


(* Data structures, orders *)

(* Orders labels by their second component (with int_ord) first and their
   first component (with fast_string_ord) second; each label is swapped before
   comparison. *)
val label_ord = prod_ord int_ord fast_string_ord o pairself swap

structure Label_Table = Table(
  type key = label
  val ord = label_ord)


(* Timing *)

(* A preplay time paired with a flag. The flag is set when some contributing
   preplay did not succeed; in that case its full timeout was counted instead
   (see sum_up_time). *)
type ext_time = bool * Time.time

fun ext_time_add (b1, t1) (b2, t2) : ext_time = (b1 orelse b2, t1+t2)

val no_time = (false, seconds 0.0)

(* Runs "tac arg" under the time limit "timeout". Returns SOME of the CPU time
   consumed, or NONE if the call timed out or raised any other exception.
   Interrupts are re-raised instead of being swallowed, so that an interrupted
   preplay is not mistaken for a failed one and shrinking stops promptly. *)
fun take_time timeout tac arg =
  let val timing = Timing.start () in
    (TimeLimit.timeLimit timeout tac arg;
     Timing.result timing |> #cpu |> SOME)
    handle exn => if Exn.is_interrupt exn then reraise exn else NONE
  end

(* Forces and adds up a vector of lazy preplay results. A failed preplay (NONE)
   contributes "timeout" to the total and sets the flag. *)
fun sum_up_time timeout =
  Vector.foldl
    ((fn (SOME t, (b, ts)) => (b, t+ts)
       | (NONE, (_, ts)) => (true, ts+timeout)) o apfst Lazy.force)
    no_time


(* clean vector interface *)

fun get i v = Vector.sub (v, i)
fun replace x i v = Vector.update (v, i, x)
fun update f i v = replace (get i v |> f) i v

(* Like fold_index, but over a vector: f receives (index, element). *)
fun v_fold_index f v s =
  Vector.foldl (fn (x, (i, s)) => (i+1, f (i, x) s)) (0, s) v |> snd


(* Queue interface to table *)

(* Removes one value stored under "key" and returns it with the updated table.
   The key must be present (hd raises Empty otherwise). *)
fun pop tab key =
  let val v = hd (Inttab.lookup_list tab key) in
    (v, Inttab.remove_list (op =) (key, v) tab)
  end

(* Pops a value stored under the largest key; the table must be non-empty. *)
fun pop_max tab = pop tab (the (Inttab.max_key tab))

fun add_list tab xs = fold (Inttab.insert_list (op =)) xs tab


(* Main function for shrinking proofs *)

(* Shrinks an Isar proof by greedily merging a By_Metis step into the unique
   (By_Metis) step that references it. A merge is kept only if the merged step
   can be preplayed within a slackened time limit. Nested case splits are
   shrunk recursively.

   debug           - value of Metis_Tactic.verbose during preplay
   type_enc,
   lam_trans       - forwarded to Metis_Tactic.metis_tac
   preplay         - if false, metis is never run and every step is taken to
                     need zero time
   preplay_timeout - time limit for preplaying an original step; also the
                     amount charged for a failed preplay
   isar_shrink     - shrink factor: at each level, merging stops once the count
                     starting at metis_steps_top_level proof has dropped to
                     (that count / isar_shrink), rounded

   Returns the new step list and the accumulated ext_time. *)
fun shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
                 isar_shrink proof =
  let
    fun shrink_top_level top_level ctxt proof =
      let
        (* proof vector; steps that were merged away become NONE *)
        val proof_vect = proof |> map SOME |> Vector.fromList

        val n = metis_steps_top_level proof
        val n_target = Real.fromInt n / isar_shrink |> Real.round

        (* table for mapping from (top-level-)label to proof position *)
        fun update_table (i, Assume (label, _)) =
              Label_Table.update_new (label, i)
          | update_table (i, Prove (_, label, _, _)) =
              Label_Table.update_new (label, i)
          | update_table _ = I
        val label_index_table = fold_index update_table proof Label_Table.empty

        (* proof references: positions of the top-level steps a step refers to,
           including references from inside nested case splits; labels not
           found in the table are dropped *)
        fun refs (Prove (_, _, _, By_Metis (lfs, _))) =
              maps (the_list o Label_Table.lookup label_index_table) lfs
          | refs (Prove (_, _, _, Case_Split (cases, (lfs, _)))) =
              maps (the_list o Label_Table.lookup label_index_table) lfs
              @ maps (maps refs) cases
          | refs _ = []

        (* refed_by_vect[k] = positions of the steps that reference step k *)
        val refed_by_vect =
          Vector.tabulate (Vector.length proof_vect, (fn _ => []))
          |> fold_index (fn (i, step) => fold (update (cons i)) (refs step)) proof
          |> Vector.map rev (* after rev, indices are sorted in ascending order *)

        (* candidates for elimination, use table as priority queue (greedy
           algorithm): step i is a candidate if it is referenced by exactly one
           step j and both are By_Metis steps; its priority is the size of its
           term *)
        fun add_if_cand proof_vect (i, [j]) =
              (case (the (get i proof_vect), the (get j proof_vect)) of
                (Prove (_, _, t, By_Metis _), Prove (_, _, _, By_Metis _)) =>
                  cons (Term.size_of_term t, i)
              | _ => I)
          | add_if_cand _ _ = I
        val cand_tab =
          v_fold_index (add_if_cand proof_vect) refed_by_vect []
          |> Inttab.make_list

        (* Metis Preplaying: returns a thunk that preplays a By_Metis step with
           the given time limit (see take_time); any other step, or any step
           when preplay is off, yields a thunk returning SOME 0.
           NOTE(review): the fact lookup below runs when try_metis is applied,
           outside take_time, so a failing lookup (e.g. the_single on a fact
           with several theorems) raises instead of yielding NONE. *)
        fun try_metis timeout (Prove (_, _, t, By_Metis fact_names)) =
              if not preplay then (fn () => SOME (seconds 0.0)) else
                let
                  val facts =
                    fact_names
                    |>> map string_for_label
                    |> op @
                    |> map (the_single o thms_of_name ctxt) (* FIXME: maps (the o thms_of_name ctxt) *)
                  val goal =
                    Goal.prove (Config.put Metis_Tactic.verbose debug ctxt) [] [] t
                  fun tac {context = ctxt, prems = _} =
                    Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
                in
                  take_time timeout (fn () => goal tac)
                end
          | try_metis _ _ = (fn () => SOME (seconds 0.0))

        (* Lazy metis time vector, cache: each original step is preplayed at
           most once, and only when its time is first needed *)
        val metis_time =
          Vector.map (Lazy.lazy o try_metis preplay_timeout o the) proof_vect

        (* Merging: folds step s1 (label1) into the step s2 that references it.
           The result keeps s2's label and term, drops s2's reference to label1,
           and takes the union of both steps' local and global facts. Qualifiers
           common to both are kept, Ultimately is added if either step has it,
           and Show is added if s2 has it; insert avoids duplicating a
           qualifier that is already present. *)
        fun merge (Prove (qs1, label1, _, By_Metis (lfs1, gfs1)))
                  (Prove (qs2, label2, t, By_Metis (lfs2, gfs2))) =
          let
            val qs = inter (op =) qs1 qs2 (* FIXME: Is this correct? *)
              |> member (op =) (union (op =) qs1 qs2) Ultimately
                   ? insert (op =) Ultimately
              |> member (op =) qs2 Show ? insert (op =) Show
            val ls = remove (op =) label1 lfs2 |> union (op =) lfs1
            val ss = union (op =) gfs1 gfs2
          in Prove (qs, label2, t, By_Metis (ls, ss)) end

        (* Tries to merge step i into step j. Succeeds only if both original
           steps preplay and the merged step preplays within the slackened sum
           of their times. On success, returns SOME merged step and a time
           vector where i costs 0 and j costs the merged preplay time;
           otherwise returns NONE and the time vector unchanged. *)
        fun try_merge metis_time (s1, i) (s2, j) =
          (case get i metis_time |> Lazy.force of
            NONE => (NONE, metis_time)
          | SOME t1 =>
            (case get j metis_time |> Lazy.force of
              NONE => (NONE, metis_time)
            | SOME t2 =>
              let
                val s12 = merge s1 s2
                val timeout =
                  Time.+ (Time.+ (t1, t2), merge_timeout_slack_time)
                  |> Time.toReal |> curry Real.* merge_timeout_slack
                  |> Time.fromReal
              in
                case try_metis timeout s12 () of
                  NONE => (NONE, metis_time)
                | some_t12 =>
                  (SOME s12, metis_time
                             |> replace (seconds 0.0 |> SOME |> Lazy.value) i
                             |> replace (Lazy.value some_t12) j)
              end))

        (* Greedy main loop: pops the candidate with the largest priority,
           tries to merge it into its unique referrer and, on success, redirects
           the references of the removed step to the merge target and enqueues
           the steps it referenced as potential new candidates. It stops when
           the queue is empty, when n' has dropped to n_target, or (top level
           only) when the proof has fewer than three steps; merged-away entries
           stay in the vector as NONE, so its length never changes. *)
        fun merge_steps metis_time proof_vect refed_by cand_tab n' =
          if Inttab.is_empty cand_tab
            orelse n' <= n_target
            orelse (top_level andalso Vector.length proof_vect < 3)
          then
            (Vector.foldr
               (fn (NONE, proof) => proof | (SOME s, proof) => s :: proof)
               [] proof_vect,
             sum_up_time preplay_timeout metis_time)
          else
            let
              val (i, cand_tab) = pop_max cand_tab
              val j = get i refed_by |> the_single
              val s1 = get i proof_vect |> the
              val s2 = get j proof_vect |> the
            in
              case try_merge metis_time (s1, i) (s2, j) of
                (NONE, metis_time) =>
                  merge_steps metis_time proof_vect refed_by cand_tab n'
              | (s, metis_time) =>
                let
                  val refs = refs s1
                  val refed_by = refed_by |> fold
                    (update (Ord_List.remove int_ord i #> Ord_List.insert int_ord j)) refs
                  val new_candidates =
                    fold (add_if_cand proof_vect)
                      (map (fn i => (i, get i refed_by)) refs) []
                  val cand_tab = add_list cand_tab new_candidates
                  val proof_vect = proof_vect |> replace NONE i |> replace s j
                in
                  merge_steps metis_time proof_vect refed_by cand_tab (n' - 1)
                end
            end
      in
        merge_steps metis_time proof_vect refed_by_vect cand_tab n
      end

    (* Shrinks one proof level: first the case splits inside it, then the
       level itself. Top-level labels are made available as (unchecked)
       theorems in the context so that metis can use them during preplay. *)
    fun shrink_proof' top_level ctxt proof =
      let
        (* Enrich context with top-level facts *)
        val thy = Proof_Context.theory_of ctxt
        fun enrich_ctxt (Assume (label, t)) ctxt =
              Proof_Context.put_thms false
                (string_for_label label, SOME [Skip_Proof.make_thm thy t]) ctxt
          | enrich_ctxt (Prove (_, label, t, _)) ctxt =
              Proof_Context.put_thms false
                (string_for_label label, SOME [Skip_Proof.make_thm thy t]) ctxt
          | enrich_ctxt _ ctxt = ctxt
        val rich_ctxt = fold enrich_ctxt proof ctxt

        (* Shrink case_splits and top-level *)
        val ((proof, top_level_time), lower_level_time) =
          proof |> shrink_case_splits rich_ctxt
                |>> shrink_top_level top_level rich_ctxt
      in
        (proof, ext_time_add lower_level_time top_level_time)
      end

    (* Recursively shrinks every case of every Case_Split step (as non-top-level
       proofs) and adds up the times reported for them. *)
    and shrink_case_splits ctxt proof =
      let
        fun shrink_and_collect_time shrink candidates =
          let fun f_m cand time = shrink cand ||> ext_time_add time
          in fold_map f_m candidates no_time end
        val shrink_case_split = shrink_and_collect_time (shrink_proof' false ctxt)
        fun shrink (Prove (qs, lbl, t, Case_Split (cases, facts))) =
              let val (cases, time) = shrink_case_split cases
              in (Prove (qs, lbl, t, Case_Split (cases, facts)), time) end
          | shrink step = (step, no_time)
      in
        shrink_and_collect_time shrink proof
      end
  in
    shrink_proof' true ctxt proof
  end

end
|