|
1 (* Title: HOL/Matrix/cplex/fspmlp.ML |
|
2 Author: Steven Obua |
|
3 *) |
|
4 |
|
signature FSPMLP =
sig
  (* Float sparse matrix types re-exported from FloatSparseMatrixBuilder. *)
  type vector = FloatSparseMatrixBuilder.vector
  type matrix = FloatSparseMatrixBuilder.matrix

  (* Abstract result of [load]; inspect it through the projections below. *)
  type linprog

  (* Approximated (positive part of the) dual solution vector. *)
  val y : linprog -> term
  (* Approximated constraint matrix; NOTE(review): the pair is presumably a
     lower/upper approximation as returned by approx_matrix — confirm. *)
  val A : linprog -> term * term
  (* Approximated right-hand side vector (second component of its approximation). *)
  val b : linprog -> term
  (* Approximated objective vector, as a pair returned by approx_matrix. *)
  val c : linprog -> term * term
  (* Pair of matrices with per-variable lower ("lbound") and upper ("ubound") estimates. *)
  val r12 : linprog -> term * term

  (* Raised by [load] for unsupported problems or converter failures. *)
  exception Load of string

  (* load filename prec safe_propagation *)
  val load : string -> int -> bool -> linprog
end
|
21 |
|
structure Fspmlp : FSPMLP =
struct

type vector = FloatSparseMatrixBuilder.vector
type matrix = FloatSparseMatrixBuilder.matrix

(* Components in the order assembled by [load]: (y, A, b, c, (r1, r2)). *)
type linprog = term * (term * term) * term * (term * term) * (term * term)

fun y (y1, _, _, _, _) = y1
fun A (_, a, _, _, _) = a
fun b (_, _, b2, _, _) = b2
fun c (_, _, _, c12, _) = c12
fun r12 (_, _, _, _, r) = r

structure CplexFloatSparseMatrixConverter =
  MAKE_CPLEX_MATRIX_CONVERTER(structure cplex = Cplex and matrix_builder = FloatSparseMatrixBuilder);

datatype bound_type = LOWER | UPPER

(* Lexicographic order on (variable index, bound type) keys: by index first,
   and for equal indices LOWER sorts before UPPER. *)
fun intbound_ord ((i1: int, b1), (i2, b2)) =
  if i1 < i2 then LESS
  else if i1 = i2 then
    (if b1 = b2 then EQUAL else if b1 = LOWER then LESS else GREATER)
  else GREATER

(* Integer-keyed tables; keys are ordered descending (rev_order o int_ord). *)
structure Inttab = Table(type key = int val ord = (rev_order o int_ord));

structure VarGraph = Table(type key = int*bound_type val ord = intbound_ord);
(* key -> (float option) * (int -> (float * (((float * float) * key) list)))) *)
(* dest_key -> (sure_bound * (row_index -> (row_bound * (((coeff_lower * coeff_upper) * src_key) list)))) *)

exception Internal of string;

(* Registers row [row_index] with bound [row_bound] (and no sources yet) under
   [dest_key], creating the node without a sure bound if it does not exist.
   Raises Internal "add_row_bound" if the row is already registered there. *)
fun add_row_bound g dest_key row_index row_bound =
  let
    val x =
      case VarGraph.lookup g dest_key of
        NONE => (NONE, Inttab.update (row_index, (row_bound, [])) Inttab.empty)
      | SOME (sure_bound, f) =>
          (sure_bound,
           case Inttab.lookup f row_index of
             NONE => Inttab.update (row_index, (row_bound, [])) f
           | SOME _ => raise (Internal "add_row_bound"))
  in
    VarGraph.update (dest_key, x) g
  end

(* Records [bound] as sure bound of [key]. If one already exists, the tighter of
   the two is kept: the minimum for UPPER keys, the maximum for LOWER keys. *)
fun update_sure_bound g (key as (_, btype)) bound =
  let
    val x =
      case VarGraph.lookup g key of
        NONE => (SOME bound, Inttab.empty)
      | SOME (NONE, f) => (SOME bound, f)
      | SOME (SOME old_bound, f) =>
          (SOME ((case btype of
                    UPPER => Float.min
                  | LOWER => Float.max)
                 old_bound bound), f)
  in
    VarGraph.update (key, x) g
  end

(* The sure bound of [key], or NONE if the key is absent or has none. *)
fun get_sure_bound g key =
  case VarGraph.lookup g key of
    NONE => NONE
  | SOME (sure_bound, _) => sure_bound

(*fun get_row_bound g key row_index =
  case VarGraph.lookup g key of
    NONE => NONE
  | SOME (sure_bound, f) =>
      (case Inttab.lookup f row_index of
         NONE => NONE
       | SOME (row_bound, _) => (sure_bound, row_bound))*)

(* Adds the source edge ([coeff], [src_key]) to row [row_index] of [dest_key].
   The destination node and the row must already exist (see add_row_bound);
   otherwise Internal is raised. *)
fun add_edge g src_key dest_key row_index coeff =
  case VarGraph.lookup g dest_key of
    NONE => raise (Internal "add_edge: dest_key not found")
  | SOME (sure_bound, f) =>
      (case Inttab.lookup f row_index of
         NONE => raise (Internal "add_edge: row_index not found")
       | SOME (row_bound, sources) =>
           VarGraph.update (dest_key,
             (sure_bound, Inttab.update (row_index, (row_bound, (coeff, src_key) :: sources)) f)) g)

(* Splits the sure bounds of the graph into (lower bounds, upper bounds),
   both as tables from variable index to bound. Keys without a sure bound are skipped. *)
fun split_graph g =
  let
    fun split (key, (sure_bound, _)) (r1, r2) =
      case sure_bound of
        NONE => (r1, r2)
      | SOME bound =>
          (case key of
             (u, UPPER) => (r1, Inttab.update (u, bound) r2)
           | (u, LOWER) => (Inttab.update (u, bound) r1, r2))
  in
    VarGraph.fold split g (Inttab.empty, Inttab.empty)
  end

(* If safe is true, termination is guaranteed, but the sure bounds may be not optimal (relative to the algorithm).
   If safe is false, termination is not guaranteed, but on termination the sure bounds are optimal (relative to the algorithm) *)
fun propagate_sure_bounds safe g =
  let
    (* returns NONE if no new sure bound could be calculated, otherwise the new sure bound is returned *)
    fun calc_sure_bound_from_sources g (key as (_, btype)) =
      let
        fun mult_upper x (lower, upper) =
          if Float.sign x = LESS then
            Float.mult x lower
          else
            Float.mult x upper

        fun mult_lower x (lower, upper) =
          if Float.sign x = LESS then
            Float.mult x upper
          else
            Float.mult x lower

        val mult_btype = case btype of UPPER => mult_upper | LOWER => mult_lower

        (* Combines the bound derivable from one row (NONE if some source lacks a
           sure bound) with the best bound found so far. *)
        fun calc_sure_bound (row_index, (row_bound, sources)) sure_bound =
          let
            fun add_src_bound (coeff, src_key) sum =
              case sum of
                NONE => NONE
              | SOME x =>
                  (case get_sure_bound g src_key of
                     NONE => NONE
                   | SOME src_sure_bound => SOME (Float.add x (mult_btype src_sure_bound coeff)))
          in
            case fold add_src_bound sources (SOME row_bound) of
              NONE => sure_bound
            | new_sure_bound as (SOME new_bound) =>
                (case sure_bound of
                   NONE => new_sure_bound
                 | SOME old_bound =>
                     SOME (case btype of
                             UPPER => Float.min old_bound new_bound
                           | LOWER => Float.max old_bound new_bound))
          end
      in
        case VarGraph.lookup g key of
          NONE => NONE
        | SOME (sure_bound, f) =>
            let
              val x = Inttab.fold calc_sure_bound f sure_bound
            in
              if x = sure_bound then NONE else x
            end
      end

    (* In safe mode another round is requested only when a key that had no sure
       bound before obtains one; otherwise any change requests another round. *)
    fun propagate (key, _) (g, b) =
      case calc_sure_bound_from_sources g key of
        NONE => (g, b)
      | SOME bound =>
          (update_sure_bound g key bound,
           if safe then
             case get_sure_bound g key of
               NONE => true
             | _ => b
           else
             true)

    val (g, b) = VarGraph.fold propagate g (g, false)
  in
    if b then propagate_sure_bounds safe g else g
  end

exception Load of string;

val empty_spvec = @{term "Nil :: real spvec"};
fun cons_spvec x xs = @{term "Cons :: nat * real => real spvec => real spvec"} $ x $ xs;
val empty_spmat = @{term "Nil :: real spmat"};
fun cons_spmat x xs = @{term "Cons :: nat * real spvec => real spmat => real spmat"} $ x $ xs;

(* Derives sure lower/upper bounds for the variables from the rows of A and b,
   propagates them through the variable graph, and returns a pair of sparse
   matrices (lower, upper): one single-entry row per variable index 0 .. xlen-1.
   Variables without a computed bound get the symbolic entry
   lbound/ubound applied to a schematic variable named (names index ^ "l"/"u"). *)
fun calcr safe_propagation xlen names prec A b =
  let
    (* ~1 if the interval is exactly the point (~1, 0), 1 if exactly (1, 0), else 0.
       NOTE(review): (m, e) presumably denotes m * 2^e in Float, i.e. -1 and 1. *)
    fun test_1 (lower, upper) =
      if lower = upper then
        (if Float.eq (lower, (~1, 0)) then ~1
         else if Float.eq (lower, (1, 0)) then 1
         else 0)
      else 0

    (* Adds the bound information contributed by row [row_index] (entries [a]) to g. *)
    fun process_row (row_index, a) g =
      let
        val b = FloatSparseMatrixBuilder.v_elem_at b row_index
        val (_, b2) = FloatArith.approx_decstr_by_bin prec (case b of NONE => "0" | SOME b => b)
        val approx_a = FloatSparseMatrixBuilder.v_fold (fn (i, s) => fn l =>
          (i, FloatArith.approx_decstr_by_bin prec s)::l) a []

        (* A coefficient of exactly 1 (resp. -1) makes this row an upper (resp. lower)
           bound candidate for that variable, with all other row entries as sources. *)
        fun fold_dest_nodes (dest_index, dest_value) g =
          let
            val dest_test = test_1 dest_value
          in
            if dest_test = 0 then
              g
            else
              let
                val (dest_key as (_, dest_btype), row_bound) =
                  if dest_test = ~1 then
                    ((dest_index, LOWER), Float.neg b2)
                  else
                    ((dest_index, UPPER), b2)

                fun fold_src_nodes (src_index, src_value as (src_lower, src_upper)) g =
                  if src_index = dest_index then g
                  else
                    let
                      val coeff = case dest_btype of
                          UPPER => (Float.neg src_upper, Float.neg src_lower)
                        | LOWER => src_value
                    in
                      if Float.sign src_lower = LESS then
                        add_edge g (src_index, UPPER) dest_key row_index coeff
                      else
                        add_edge g (src_index, LOWER) dest_key row_index coeff
                    end
              in
                fold fold_src_nodes approx_a (add_row_bound g dest_key row_index row_bound)
              end
          end
      in
        case approx_a of
          [] => g
        | [(u, a)] =>
            (* a single-entry row with coefficient +-1 is directly a sure bound *)
            let
              val atest = test_1 a
            in
              if atest = ~1 then
                update_sure_bound g (u, LOWER) (Float.neg b2)
              else if atest = 1 then
                update_sure_bound g (u, UPPER) b2
              else
                g
            end
        | _ => fold fold_dest_nodes approx_a g
      end

    val g = FloatSparseMatrixBuilder.m_fold process_row A VarGraph.empty

    val g = propagate_sure_bounds safe_propagation g

    val (r1, r2) = split_graph g

    (* Prepends to m the row [index] holding a single entry at column 0: the
       concrete bound if known, otherwise f applied to the schematic variable vname. *)
    fun add_row_entry m index f vname value =
      let
        val v = (case value of
                   SOME value => FloatSparseMatrixBuilder.mk_spvec_entry 0 value
                 | NONE => FloatSparseMatrixBuilder.mk_spvec_entry' 0 (f $ (Var ((vname,0), HOLogic.realT))))
        val vec = cons_spvec v empty_spvec
      in
        cons_spmat (FloatSparseMatrixBuilder.mk_spmat_entry index vec) m
      end

    (* Builds the rows for indices xlen-i .. xlen-1, in ascending index order. *)
    fun abs_estimate i =
      if i = 0 then
        (empty_spmat, empty_spmat)
      else
        let
          val index = xlen - i
          val (r12_1, r12_2) = abs_estimate (i - 1)
        in
          (add_row_entry r12_1 index @{term "lbound :: real => real"} (names index ^ "l") (Inttab.lookup r1 index),
           add_row_entry r12_2 index @{term "ubound :: real => real"} (names index ^ "u") (Inttab.lookup r2 index))
        end
  in
    abs_estimate xlen
  end

(* Loads a CPLEX file, solves its dual with Cplex.solve, and returns the
   approximated data (y, A, b, c, (r1, r2)) at precision [prec].
   Raises Load for minimization problems and for converter failures. *)
fun load filename prec safe_propagation =
  let
    val prog = Cplex.load_cplexFile filename
    val prog = Cplex.elim_nonfree_bounds prog
    val prog = Cplex.relax_strict_ineqs prog
    val (maximize, c, A, b, (xlen, names, _)) = CplexFloatSparseMatrixConverter.convert_prog prog
    (* Reject unsupported problems before the comparatively expensive bound computation. *)
    val _ = if maximize then () else raise Load "sorry, cannot handle minimization problems"
    val (r1, r2) = calcr safe_propagation xlen names prec A b
    val (dualprog, indexof) = FloatSparseMatrixBuilder.dual_cplexProg c A b
    val results = Cplex.solve dualprog
    val (optimal,v) = CplexFloatSparseMatrixConverter.convert_results results indexof
    (*val A = FloatSparseMatrixBuilder.cut_matrix v NONE A*)
    fun id x = x
    val v = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 v
    val b = FloatSparseMatrixBuilder.transpose_matrix (FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 b)
    val c = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 c
    val (y1, _) = FloatSparseMatrixBuilder.approx_matrix prec Float.positive_part v
    val A = FloatSparseMatrixBuilder.approx_matrix prec id A
    val (_,b2) = FloatSparseMatrixBuilder.approx_matrix prec id b
    val c = FloatSparseMatrixBuilder.approx_matrix prec id c
  in
    (y1, A, b2, c, (r1, r2))
  end handle CplexFloatSparseMatrixConverter.Converter s => (raise (Load ("Converter: "^s)))

end