src/HOL/Matrix/cplex/fspmlp.ML
author wenzelm
Thu Jul 09 22:01:41 2009 +0200 (2009-07-09)
changeset 31971 8c1b845ed105
parent 24630 351a308ab58d
child 32960 69916a850301
permissions -rw-r--r--
renamed functor TableFun to Table, and GraphFun to Graph;
obua@16784
     1
(*  Title:      HOL/Matrix/cplex/fspmlp.ML
obua@16784
     2
    Author:     Steven Obua
obua@16784
     3
*)
obua@16784
     4
haftmann@22951
     5
signature FSPMLP =
sig
  (* Abstract result of load; inspect it with the projections below. *)
  type linprog
  type vector = FloatSparseMatrixBuilder.vector
  type matrix = FloatSparseMatrixBuilder.matrix

  (* Projections of a linprog. *)
  val y : linprog -> term
  val A : linprog -> term * term
  val b : linprog -> term
  val c : linprog -> term * term
  val r12 : linprog -> term * term

  (* Raised by load when the problem cannot be handled. *)
  exception Load of string

  (* load filename prec safe_propagation *)
  val load : string -> int -> bool -> linprog
end
obua@16784
    21
haftmann@22964
    22
(* Fspmlp: loads a CPLEX LP file (maximization problems only), solves its dual with
   Cplex.solve and packages the results, approximated to a given precision, as HOL terms
   (see load).  It also derives per-variable lower/upper bounds (sure bounds) by
   propagating the row constraints of the problem (see calcr). *)
structure Fspmlp : FSPMLP =
struct

type vector = FloatSparseMatrixBuilder.vector
type matrix = FloatSparseMatrixBuilder.matrix

(* (y, A, b, c, (r1, r2)) exactly as assembled by load *)
type linprog = term * (term * term) * term * (term * term) * (term * term)

fun y (x, _, _, _, _) = x
fun A (_, x, _, _, _) = x
fun b (_, _, x, _, _) = x
fun c (_, _, _, x, _) = x
fun r12 (_, _, _, _, x) = x

structure CplexFloatSparseMatrixConverter =
  MAKE_CPLEX_MATRIX_CONVERTER(structure cplex = Cplex and matrix_builder = FloatSparseMatrixBuilder);

datatype bound_type = LOWER | UPPER

(* LOWER sorts before UPPER *)
fun bound_type_ord (LOWER, UPPER) = LESS
  | bound_type_ord (UPPER, LOWER) = GREATER
  | bound_type_ord _ = EQUAL

(* Lexicographic order on (variable index, bound kind). *)
fun intbound_ord ((i1: int, b1), (i2, b2)) =
  (case int_ord (i1, i2) of
    EQUAL => bound_type_ord (b1, b2)
  | other => other)

(* Integer-keyed tables in reversed (descending) key order. *)
structure Inttab = Table(type key = int val ord = (rev_order o int_ord));

(* Bound information per (variable index, bound kind) key. *)
structure VarGraph = Table(type key = int*bound_type val ord = intbound_ord);

(* Layout of a VarGraph entry:
     key -> (float option) * (int -> (float * (((float * float) * key) list)))
   that is,
     dest_key -> (sure_bound * (row_index -> (row_bound * (((coeff_lower * coeff_upper) * src_key) list)))) *)

exception Internal of string;

(* Registers row row_index, with right-hand side row_bound and no sources yet, for
   dest_key.  Raises Internal if that row is already registered for the key. *)
fun add_row_bound g dest_key row_index row_bound =
  let
    val entry =
      (case VarGraph.lookup g dest_key of
        NONE => (NONE, Inttab.update (row_index, (row_bound, [])) Inttab.empty)
      | SOME (sure_bound, f) =>
          (sure_bound,
            (case Inttab.lookup f row_index of
              NONE => Inttab.update (row_index, (row_bound, [])) f
            | SOME _ => raise (Internal "add_row_bound"))))
  in VarGraph.update (dest_key, entry) g end

(* Tightens the sure bound of key with bound: the minimum is kept for UPPER keys and the
   maximum for LOWER keys.  Creates the entry if key is absent. *)
fun update_sure_bound g (key as (_, btype)) bound =
  let
    val entry =
      (case VarGraph.lookup g key of
        NONE => (SOME bound, Inttab.empty)
      | SOME (NONE, f) => (SOME bound, f)
      | SOME (SOME old_bound, f) =>
          (SOME ((case btype of UPPER => Float.min | LOWER => Float.max) old_bound bound), f))
  in VarGraph.update (key, entry) g end

(* The sure bound currently recorded for key, if any. *)
fun get_sure_bound g key =
  (case VarGraph.lookup g key of
    NONE => NONE
  | SOME (sure_bound, _) => sure_bound)

(* Adds the source (coeff, src_key) to row row_index of dest_key.  Raises Internal unless
   that row was registered earlier with add_row_bound. *)
fun add_edge g src_key dest_key row_index coeff =
  (case VarGraph.lookup g dest_key of
    NONE => raise (Internal "add_edge: dest_key not found")
  | SOME (sure_bound, f) =>
      (case Inttab.lookup f row_index of
        NONE => raise (Internal "add_edge: row_index not found")
      | SOME (row_bound, sources) =>
          VarGraph.update (dest_key,
            (sure_bound, Inttab.update (row_index, (row_bound, (coeff, src_key) :: sources)) f)) g))

(* Splits the known sure bounds into (lower bounds, upper bounds) tables keyed by
   variable index. *)
fun split_graph g =
  let
    fun split (key, (sure_bound, _)) (r1, r2) =
      (case sure_bound of
        NONE => (r1, r2)
      | SOME bound =>
          (case key of
            (u, UPPER) => (r1, Inttab.update (u, bound) r2)
          | (u, LOWER) => (Inttab.update (u, bound) r1, r2)))
  in VarGraph.fold split g (Inttab.empty, Inttab.empty) end

(* Repeatedly derives sure bounds from rows whose sources all have sure bounds, one pass
   over all keys at a time.
   If safe is true, termination is guaranteed, but the sure bounds may not be optimal
   (relative to the algorithm).  Another pass only runs when some key received its first
   sure bound.
   If safe is false, termination is not guaranteed, but on termination the sure bounds are
   optimal (relative to the algorithm).  Another pass runs whenever any bound changed. *)
fun propagate_sure_bounds safe g =
  let
    (* returns NONE if no new sure bound could be calculated, otherwise the new sure bound *)
    fun calc_sure_bound_from_sources g (key as (_, btype)) =
      let
        (* x times the coefficient interval (lower, upper): the end giving the upper resp.
           lower value depends on the sign of x *)
        fun mult_upper x (lower, upper) =
          if Float.sign x = LESS then Float.mult x lower else Float.mult x upper
        fun mult_lower x (lower, upper) =
          if Float.sign x = LESS then Float.mult x upper else Float.mult x lower
        val mult_btype = (case btype of UPPER => mult_upper | LOWER => mult_lower)

        fun calc_sure_bound (_, (row_bound, sources)) sure_bound =
          let
            (* NONE as soon as one source has no sure bound *)
            fun add_src_bound (coeff, src_key) sum =
              (case sum of
                NONE => NONE
              | SOME x =>
                  (case get_sure_bound g src_key of
                    NONE => NONE
                  | SOME src_sure_bound => SOME (Float.add x (mult_btype src_sure_bound coeff))))
          in
            (case fold add_src_bound sources (SOME row_bound) of
              NONE => sure_bound
            | new_sure_bound as (SOME new_bound) =>
                (case sure_bound of
                  NONE => new_sure_bound
                | SOME old_bound =>
                    SOME (case btype of
                      UPPER => Float.min old_bound new_bound
                    | LOWER => Float.max old_bound new_bound)))
          end
      in
        (case VarGraph.lookup g key of
          NONE => NONE
        | SOME (sure_bound, f) =>
            let val x = Inttab.fold calc_sure_bound f sure_bound
            in if x = sure_bound then NONE else x end)
      end

    fun propagate (key, _) (g, changed) =
      (case calc_sure_bound_from_sources g key of
        NONE => (g, changed)
      | SOME bound =>
          (update_sure_bound g key bound,
            (* in safe mode only a key's first sure bound counts as progress;
               g here is still the graph before the update *)
            if safe then
              (case get_sure_bound g key of
                NONE => true
              | SOME _ => changed)
            else true))

    val (g', changed) = VarGraph.fold propagate g (g, false)
  in
    if changed then propagate_sure_bounds safe g' else g'
  end

exception Load of string;

val empty_spvec = @{term "Nil :: real spvec"};
fun cons_spvec x xs = @{term "Cons :: nat * real => real spvec => real spvec"} $ x $ xs;
val empty_spmat = @{term "Nil :: real spmat"};
fun cons_spmat x xs = @{term "Cons :: nat * real spvec => real spmat => real spmat"} $ x $ xs;

(* Builds the r12 component of a linprog: a pair of spmat terms (lower, upper) with one row
   per variable index 0 .. xlen-1.  A row holds the sure bound found by propagation when
   there is one.  Otherwise it holds lbound resp. ubound applied to a schematic variable
   named names index with suffix l resp. u.
   Sure bounds are seeded from the rows of A and the right-hand sides b, approximated with
   precision prec.  A row with a single coefficient of exactly 1 or ~1 bounds that variable
   directly.  In longer rows, every variable with coefficient exactly 1 or ~1 becomes the
   destination of edges from the other variables of the row. *)
fun calcr safe_propagation xlen names prec A b =
  let
    (* ~1 or 1 if the interval (lower, upper) is exactly that value, 0 otherwise *)
    fun test_1 (lower, upper) =
      if lower = upper then
        (if Float.eq (lower, (~1, 0)) then ~1
         else if Float.eq (lower, (1, 0)) then 1
         else 0)
      else 0

    fun add_row (row_index, a) g =
      let
        val b_entry = FloatSparseMatrixBuilder.v_elem_at b row_index
        val (_, b2) = FloatArith.approx_decstr_by_bin prec (case b_entry of NONE => "0" | SOME s => s)
        val approx_a = FloatSparseMatrixBuilder.v_fold (fn (i, s) => fn l =>
          (i, FloatArith.approx_decstr_by_bin prec s) :: l) a []

        fun fold_dest_nodes (dest_index, dest_value) g =
          let
            val dest_test = test_1 dest_value
          in
            if dest_test = 0 then g
            else
              let
                (* coefficient ~1 yields a LOWER bound from -b2, coefficient 1 an UPPER
                   bound from b2 (presumably rows encode A x <= b; confirm) *)
                val (dest_key as (_, dest_btype), row_bound) =
                  if dest_test = ~1 then ((dest_index, LOWER), Float.neg b2)
                  else ((dest_index, UPPER), b2)

                fun fold_src_nodes (src_index, src_value as (src_lower, src_upper)) g =
                  if src_index = dest_index then g
                  else
                    let
                      val coeff =
                        (case dest_btype of
                          UPPER => (Float.neg src_upper, Float.neg src_lower)
                        | LOWER => src_value)
                    in
                      if Float.sign src_lower = LESS then
                        add_edge g (src_index, UPPER) dest_key row_index coeff
                      else
                        add_edge g (src_index, LOWER) dest_key row_index coeff
                    end
              in
                fold fold_src_nodes approx_a (add_row_bound g dest_key row_index row_bound)
              end
          end
      in
        (case approx_a of
          [] => g
        | [(u, a_u)] =>
            let val atest = test_1 a_u in
              if atest = ~1 then update_sure_bound g (u, LOWER) (Float.neg b2)
              else if atest = 1 then update_sure_bound g (u, UPPER) b2
              else g
            end
        | _ => fold fold_dest_nodes approx_a g)
      end

    val g = FloatSparseMatrixBuilder.m_fold add_row A VarGraph.empty
    val g = propagate_sure_bounds safe_propagation g
    val (lower_bounds, upper_bounds) = split_graph g

    (* conses the single-entry row index onto m: the known bound, or f applied to the
       schematic real variable vname *)
    fun add_row_entry m index f vname value =
      let
        val v =
          (case value of
            SOME bound => FloatSparseMatrixBuilder.mk_spvec_entry 0 bound
          | NONE => FloatSparseMatrixBuilder.mk_spvec_entry' 0 (f $ (Var ((vname, 0), HOLogic.realT))))
        val vec = cons_spvec v empty_spvec
      in
        cons_spmat (FloatSparseMatrixBuilder.mk_spmat_entry index vec) m
      end

    (* rows for indices xlen-i .. xlen-1; the smallest index ends up at the head *)
    fun abs_estimate i =
      if i = 0 then (empty_spmat, empty_spmat)
      else
        let
          val index = xlen - i
          val (r12_1, r12_2) = abs_estimate (i - 1)
        in
          (add_row_entry r12_1 index @{term "lbound :: real => real"} (names index ^ "l")
             (Inttab.lookup lower_bounds index),
           add_row_entry r12_2 index @{term "ubound :: real => real"} (names index ^ "u")
             (Inttab.lookup upper_bounds index))
        end
  in
    abs_estimate xlen
  end

(* Loads the CPLEX LP file filename after eliminating non-free bounds and relaxing strict
   inequalities.  Only maximization problems are accepted.  The dual program is solved with
   Cplex.solve, and the result is the linprog (y, A, b, c, r12): matrices approximated with
   precision prec, plus the bound matrices from calcr.
   Raises Load for minimization problems and for converter failures. *)
fun load filename prec safe_propagation =
  let
    val prog = Cplex.load_cplexFile filename
    val prog = Cplex.elim_nonfree_bounds prog
    val prog = Cplex.relax_strict_ineqs prog
    val (maximize, c, A, b, (xlen, names, _)) = CplexFloatSparseMatrixConverter.convert_prog prog
    (* fail fast: reject minimization problems before any bound propagation *)
    val _ = if maximize then () else raise Load "sorry, cannot handle minimization problems"
    val (r1, r2) = calcr safe_propagation xlen names prec A b
    val (dualprog, indexof) = FloatSparseMatrixBuilder.dual_cplexProg c A b
    val results = Cplex.solve dualprog
    val (_, v) = CplexFloatSparseMatrixConverter.convert_results results indexof
    fun id x = x
    val v = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 v
    val b = FloatSparseMatrixBuilder.transpose_matrix
      (FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 b)
    val c = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 c
    val (y1, _) = FloatSparseMatrixBuilder.approx_matrix prec Float.positive_part v
    val A = FloatSparseMatrixBuilder.approx_matrix prec id A
    val (_, b2) = FloatSparseMatrixBuilder.approx_matrix prec id b
    val c = FloatSparseMatrixBuilder.approx_matrix prec id c
  in
    (y1, A, b2, c, (r1, r2))
  end handle CplexFloatSparseMatrixConverter.Converter s => raise (Load ("Converter: " ^ s))

end