src/HOL/Matrix/fspmlp.ML
changeset 37764 3489daf839d5
parent 32960 69916a850301
child 37788 261c61fabc98
equal deleted inserted replaced
37759:00ff97087ab5 37764:3489daf839d5
       
     1 (*  Title:      HOL/Matrix/fspmlp.ML
       
     2     Author:     Steven Obua
       
     3 *)
       
     4 
       
     signature FSPMLP =
sig
  (* Matrix representations shared with FloatSparseMatrixBuilder. *)
  type vector = FloatSparseMatrixBuilder.vector
  type matrix = FloatSparseMatrixBuilder.matrix

  (* Abstract result of load. *)
  type linprog

  (* Raised by load, e.g. for minimization problems or converter failures. *)
  exception Load of string

  (* load filename prec safe_propagation:
     reads the CPLEX file filename and approximates its data with precision prec.
     safe_propagation selects the bound-propagation mode: true guarantees
     termination, false may tighten bounds further but need not terminate. *)
  val load : string -> int -> bool -> linprog

  (* Component selectors of a loaded program. *)
  val y : linprog -> term
  val A : linprog -> term * term
  val b : linprog -> term
  val c : linprog -> term * term
  val r12 : linprog -> term * term
end
       
    21 
       
    structure Fspmlp : FSPMLP =
struct

type vector = FloatSparseMatrixBuilder.vector
type matrix = FloatSparseMatrixBuilder.matrix

(* A loaded program, as assembled at the end of load: (y, A, b, c, (r1, r2)).
   The selectors below project the individual components. *)
type linprog = term * (term * term) * term * (term * term) * (term * term)

fun y (y1, _, _, _, _) = y1
fun A (_, a, _, _, _) = a
fun b (_, _, b2, _, _) = b2
fun c (_, _, _, c1, _) = c1
fun r12 (_, _, _, _, r) = r

structure CplexFloatSparseMatrixConverter =
  MAKE_CPLEX_MATRIX_CONVERTER(structure cplex = Cplex and matrix_builder = FloatSparseMatrixBuilder);

datatype bound_type = LOWER | UPPER

(* Keeps the tighter of two bounds of the given kind: the smaller one for
   UPPER bounds, the larger one for LOWER bounds.  Used as
   tighten btype old_bound new_bound. *)
fun tighten UPPER = Float.min
  | tighten LOWER = Float.max

(* Lexicographic order on (variable index, bound type); for equal indices
   LOWER sorts before UPPER. *)
fun intbound_ord ((i1 : int, b1), (i2, b2)) =
  (case int_ord (i1, i2) of
    EQUAL =>
      if b1 = b2 then EQUAL
      else if b1 = LOWER then LESS
      else GREATER
  | ord => ord)

(* Integer-keyed tables; keys are ordered by the reversed integer order. *)
structure Inttab = Table(type key = int val ord = (rev_order o int_ord));

structure VarGraph = Table(type key = int * bound_type val ord = intbound_ord);
(* key -> (float option) * (int -> (float * (((float * float) * key) list))) *)
(* dest_key -> (sure_bound * (row_index -> (row_bound * (((coeff_lower * coeff_upper) * src_key) list)))) *)

exception Internal of string;

(* Registers row row_index with bound row_bound and no sources yet for
   dest_key, creating the node if necessary.
   Raises Internal if that row is already registered for dest_key. *)
fun add_row_bound g dest_key row_index row_bound =
  let
    val entry = (row_bound, [])
    val node =
      (case VarGraph.lookup g dest_key of
        NONE => (NONE, Inttab.update (row_index, entry) Inttab.empty)
      | SOME (sure_bound, rows) =>
          (case Inttab.lookup rows row_index of
            NONE => (sure_bound, Inttab.update (row_index, entry) rows)
          | SOME _ => raise Internal "add_row_bound"))
  in VarGraph.update (dest_key, node) g end

(* Records bound as a sure bound of key, keeping the tighter one if key
   already has a sure bound; creates the node if it is missing. *)
fun update_sure_bound g (key as (_, btype)) bound =
  let
    val node =
      (case VarGraph.lookup g key of
        NONE => (SOME bound, Inttab.empty)
      | SOME (NONE, rows) => (SOME bound, rows)
      | SOME (SOME old_bound, rows) => (SOME (tighten btype old_bound bound), rows))
  in VarGraph.update (key, node) g end

(* Current sure bound of key; NONE if key is absent or has no sure bound. *)
fun get_sure_bound g key =
  (case VarGraph.lookup g key of
    NONE => NONE
  | SOME (sure_bound, _) => sure_bound)

(* Prepends the source (coeff, src_key) to row row_index of dest_key.
   Raises Internal if dest_key or that row has not been registered via
   add_row_bound. *)
fun add_edge g src_key dest_key row_index coeff =
  (case VarGraph.lookup g dest_key of
    NONE => raise Internal "add_edge: dest_key not found"
  | SOME (sure_bound, rows) =>
      (case Inttab.lookup rows row_index of
        NONE => raise Internal "add_edge: row_index not found"
      | SOME (row_bound, sources) =>
          let val rows' = Inttab.update (row_index, (row_bound, (coeff, src_key) :: sources)) rows
          in VarGraph.update (dest_key, (sure_bound, rows')) g end))

(* Splits the sure bounds of the graph into a pair of tables
   (lower bounds, upper bounds), keyed by variable index.
   Nodes without a sure bound are skipped. *)
fun split_graph g =
  let
    fun split (_, (NONE, _)) tabs = tabs
      | split ((u, UPPER), (SOME bound, _)) (lowers, uppers) =
          (lowers, Inttab.update (u, bound) uppers)
      | split ((u, LOWER), (SOME bound, _)) (lowers, uppers) =
          (Inttab.update (u, bound) lowers, uppers)
  in VarGraph.fold split g (Inttab.empty, Inttab.empty) end

(* Repeatedly recomputes sure bounds from the row sources until a pass
   changes nothing relevant.
   If safe is true, another pass is only requested when some node obtains a
   sure bound for the first time.  Sure bounds are never removed and the key
   set does not grow, so termination is guaranteed, but the sure bounds may
   not be optimal (relative to the algorithm).
   If safe is false, any change requests another pass.  Termination is not
   guaranteed, but on termination the sure bounds are optimal (relative to
   the algorithm). *)
fun propagate_sure_bounds safe g =
  let
    (* Largest value of x * k for k in [lower, upper] (assuming lower <= upper). *)
    fun mult_upper x (lower, upper) =
      if Float.sign x = LESS then Float.mult x lower else Float.mult x upper

    (* Smallest value of x * k for k in [lower, upper] (assuming lower <= upper). *)
    fun mult_lower x (lower, upper) =
      if Float.sign x = LESS then Float.mult x upper else Float.mult x lower

    (* Returns NONE if no new sure bound could be calculated for key,
       otherwise the new sure bound. *)
    fun calc_sure_bound_from_sources g (key as (_, btype)) =
      let
        val mult_btype = (case btype of UPPER => mult_upper | LOWER => mult_lower)

        (* Adds the contribution of one source.  The sum becomes NONE as soon
           as some source has no sure bound. *)
        fun add_src_bound (coeff, src_key) sum =
          (case sum of
            NONE => NONE
          | SOME x =>
              Option.map (fn src_bound => Float.add x (mult_btype src_bound coeff))
                (get_sure_bound g src_key))

        (* Bound derived from one row, combined with the best bound so far. *)
        fun calc_sure_bound (_, (row_bound, sources)) sure_bound =
          (case fold add_src_bound sources (SOME row_bound) of
            NONE => sure_bound
          | SOME new_bound =>
              (case sure_bound of
                NONE => SOME new_bound
              | SOME old_bound => SOME (tighten btype old_bound new_bound)))
      in
        (case VarGraph.lookup g key of
          NONE => NONE
        | SOME (sure_bound, rows) =>
            let val candidate = Inttab.fold calc_sure_bound rows sure_bound
            in if candidate = sure_bound then NONE else candidate end)
      end

    fun propagate (key, _) (g, changed) =
      (case calc_sure_bound_from_sources g key of
        NONE => (g, changed)
      | SOME bound =>
          let
            (* read before the update: did key lack a sure bound until now? *)
            val first_bound = is_none (get_sure_bound g key)
          in
            (update_sure_bound g key bound,
             if safe then first_bound orelse changed else true)
          end)

    val (g', changed) = VarGraph.fold propagate g (g, false)
  in
    if changed then propagate_sure_bounds safe g' else g'
  end

exception Load of string;

val empty_spvec = @{term "Nil :: real spvec"};
fun cons_spvec x xs = @{term "Cons :: nat * real => real spvec => real spvec"} $ x $ xs;
val empty_spmat = @{term "Nil :: real spmat"};
fun cons_spmat x xs = @{term "Cons :: nat * real spvec => real spmat => real spmat"} $ x $ xs;

(* Derives lower and upper bound matrices for the variables 0 .. xlen-1 from
   the rows of A and the right-hand side b.
   Every variable gets one entry in each matrix.  A variable with a sure
   bound gets that numeric bound.  Otherwise it gets the symbolic term
   lbound/ubound applied to a schematic variable named (names index)^"l"
   or (names index)^"u". *)
fun calcr safe_propagation xlen names prec A b =
  let
    (* ~1 or 1 if the interval is exactly the point -1 or 1, otherwise 0. *)
    fun test_1 (lower, upper) =
      if lower = upper then
        (if Float.eq (lower, (~1, 0)) then ~1
         else if Float.eq (lower, (1, 0)) then 1
         else 0)
      else 0

    (* Adds the information of row row_index (sparse vector a) to graph g. *)
    fun process_row (row_index, a) g =
      let
        val (_, b2) =
          FloatArith.approx_decstr_by_bin prec
            (the_default "0" (FloatSparseMatrixBuilder.v_elem_at b row_index))
        val approx_a =
          FloatSparseMatrixBuilder.v_fold
            (fn (i, s) => fn l => (i, FloatArith.approx_decstr_by_bin prec s) :: l) a []

        (* Only variables whose coefficient is exactly -1 or 1 become
           destinations.  -1 yields a LOWER node with row bound -b2, 1 yields
           an UPPER node with row bound b2.
           NOTE(review): presumably rows encode a . x <= b -- confirm. *)
        fun fold_dest_nodes (dest_index, dest_value) g =
          let val dest_test = test_1 dest_value in
            if dest_test = 0 then g
            else
              let
                val (dest_key as (_, dest_btype), row_bound) =
                  if dest_test = ~1 then ((dest_index, LOWER), Float.neg b2)
                  else ((dest_index, UPPER), b2)

                fun fold_src_nodes (src_index, src_value as (src_lower, src_upper)) g =
                  if src_index = dest_index then g
                  else
                    let
                      (* UPPER destinations see the negated coefficient interval *)
                      val coeff =
                        (case dest_btype of
                          UPPER => (Float.neg src_upper, Float.neg src_lower)
                        | LOWER => src_value)
                      (* negative coefficients depend on the source's upper bound *)
                      val src_btype = if Float.sign src_lower = LESS then UPPER else LOWER
                    in add_edge g (src_index, src_btype) dest_key row_index coeff end
              in fold fold_src_nodes approx_a (add_row_bound g dest_key row_index row_bound) end
          end
      in
        (case approx_a of
          [] => g
        | [(u, a)] =>
            (* a single-variable row bounds that variable directly *)
            (case test_1 a of
              ~1 => update_sure_bound g (u, LOWER) (Float.neg b2)
            | 1 => update_sure_bound g (u, UPPER) b2
            | _ => g)
        | _ => fold fold_dest_nodes approx_a g)
      end

    val g = FloatSparseMatrixBuilder.m_fold process_row A VarGraph.empty
    val g = propagate_sure_bounds safe_propagation g
    val (lowers, uppers) = split_graph g

    (* Prepends a one-entry row for index to matrix m: either the numeric
       bound value, or f applied to a fresh schematic real variable vname. *)
    fun add_row_entry m index f vname value =
      let
        val v =
          (case value of
            SOME value => FloatSparseMatrixBuilder.mk_spvec_entry 0 value
          | NONE => FloatSparseMatrixBuilder.mk_spvec_entry' 0 (f $ (Var ((vname, 0), HOLogic.realT))))
      in cons_spmat (FloatSparseMatrixBuilder.mk_spmat_entry index (cons_spvec v empty_spvec)) m end

    (* Builds the rows for indices xlen-i .. xlen-1 in ascending index order.
       The i <= 0 guard also ends the recursion for a negative xlen. *)
    fun abs_estimate i =
      if i <= 0 then (empty_spmat, empty_spmat)
      else
        let
          val index = xlen - i
          val (m1, m2) = abs_estimate (i - 1)
        in
          (add_row_entry m1 index @{term "lbound :: real => real"} (names index ^ "l")
             (Inttab.lookup lowers index),
           add_row_entry m2 index @{term "ubound :: real => real"} (names index ^ "u")
             (Inttab.lookup uppers index))
        end
  in
    abs_estimate xlen
  end

(* Loads the CPLEX file filename and assembles the linprog
   (y, A, b, c, (r1, r2)).
   Raises Load for minimization problems and for converter errors. *)
fun load filename prec safe_propagation =
  let
    val prog = Cplex.load_cplexFile filename
    val prog = Cplex.elim_nonfree_bounds prog
    val prog = Cplex.relax_strict_ineqs prog
    val (maximize, c, A, b, (xlen, names, _)) = CplexFloatSparseMatrixConverter.convert_prog prog
    (* Reject unsupported problems before the bound propagation in calcr,
       which need not terminate when safe_propagation is false. *)
    val _ = if maximize then () else raise Load "sorry, cannot handle minimization problems"
    val (r1, r2) = calcr safe_propagation xlen names prec A b
    val (dualprog, indexof) = FloatSparseMatrixBuilder.dual_cplexProg c A b
    val results = Cplex.solve dualprog
    val (_, v) = CplexFloatSparseMatrixConverter.convert_results results indexof
    val v = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 v
    val b = FloatSparseMatrixBuilder.transpose_matrix
      (FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 b)
    val c = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 c
    val (y1, _) = FloatSparseMatrixBuilder.approx_matrix prec Float.positive_part v
    val A = FloatSparseMatrixBuilder.approx_matrix prec I A
    val (_, b2) = FloatSparseMatrixBuilder.approx_matrix prec I b
    val c = FloatSparseMatrixBuilder.approx_matrix prec I c
  in
    (y1, A, b2, c, (r1, r2))
  end handle CplexFloatSparseMatrixConverter.Converter s => raise Load ("Converter: " ^ s)

end