src/HOL/Matrix/fspmlp.ML
author wenzelm
Fri Dec 17 17:43:54 2010 +0100 (2010-12-17)
changeset 41229 d797baa3d57c
parent 37788 261c61fabc98
child 46533 faf233c4a404
permissions -rw-r--r--
replaced command 'nonterminals' by slightly modernized version 'nonterminal';
     1 (*  Title:      HOL/Matrix/fspmlp.ML
     2     Author:     Steven Obua
     3 *)
     4 
     5 signature FSPMLP =
     6 sig
     7     type linprog
     8     type vector = FloatSparseMatrixBuilder.vector
     9     type matrix = FloatSparseMatrixBuilder.matrix
    10 
    11     val y : linprog -> term
    12     val A : linprog -> term * term
    13     val b : linprog -> term
    14     val c : linprog -> term * term
    15     val r12 : linprog -> term * term
    16 
    17     exception Load of string
    18 
    19     val load : string -> int -> bool -> linprog
    20 end
    21 
    22 structure Fspmlp : FSPMLP =
    23 struct
    24 
    25 type vector = FloatSparseMatrixBuilder.vector
    26 type matrix = FloatSparseMatrixBuilder.matrix
    27 
    28 type linprog = term * (term * term) * term * (term * term) * (term * term)
    29 
    30 fun y (c1, c2, c3, c4, _) = c1
    31 fun A (c1, c2, c3, c4, _) = c2
    32 fun b (c1, c2, c3, c4, _) = c3
    33 fun c (c1, c2, c3, c4, _) = c4
    34 fun r12 (c1, c2, c3, c4, c6) = c6
    35 
    36 structure CplexFloatSparseMatrixConverter =
    37 MAKE_CPLEX_MATRIX_CONVERTER(structure cplex = Cplex and matrix_builder = FloatSparseMatrixBuilder);
    38 
    39 datatype bound_type = LOWER | UPPER
    40 
    41 fun intbound_ord ((i1: int, b1),(i2,b2)) =
    42     if i1 < i2 then LESS
    43     else if i1 = i2 then
    44         (if b1 = b2 then EQUAL else if b1=LOWER then LESS else GREATER)
    45     else GREATER
    46 
    47 structure Inttab = Table(type key = int val ord = (rev_order o int_ord));
    48 
    49 structure VarGraph = Table(type key = int*bound_type val ord = intbound_ord);
    50 (* key -> (float option) * (int -> (float * (((float * float) * key) list)))) *)
    51 (* dest_key -> (sure_bound * (row_index -> (row_bound * (((coeff_lower * coeff_upper) * src_key) list)))) *)
    52 
    53 exception Internal of string;
    54 
    55 fun add_row_bound g dest_key row_index row_bound =
    56     let
    57         val x =
    58             case VarGraph.lookup g dest_key of
    59                 NONE => (NONE, Inttab.update (row_index, (row_bound, [])) Inttab.empty)
    60               | SOME (sure_bound, f) =>
    61                 (sure_bound,
    62                  case Inttab.lookup f row_index of
    63                      NONE => Inttab.update (row_index, (row_bound, [])) f
    64                    | SOME _ => raise (Internal "add_row_bound"))
    65     in
    66         VarGraph.update (dest_key, x) g
    67     end
    68 
    69 fun update_sure_bound g (key as (_, btype)) bound =
    70     let
    71         val x =
    72             case VarGraph.lookup g key of
    73                 NONE => (SOME bound, Inttab.empty)
    74               | SOME (NONE, f) => (SOME bound, f)
    75               | SOME (SOME old_bound, f) =>
    76                 (SOME ((case btype of
    77                             UPPER => Float.min
    78                           | LOWER => Float.max)
    79                            old_bound bound), f)
    80     in
    81         VarGraph.update (key, x) g
    82     end
    83 
    84 fun get_sure_bound g key =
    85     case VarGraph.lookup g key of
    86         NONE => NONE
    87       | SOME (sure_bound, _) => sure_bound
    88 
    89 (*fun get_row_bound g key row_index =
    90     case VarGraph.lookup g key of
    91         NONE => NONE
    92       | SOME (sure_bound, f) =>
    93         (case Inttab.lookup f row_index of
    94              NONE => NONE
    95            | SOME (row_bound, _) => (sure_bound, row_bound))*)
    96 
    97 fun add_edge g src_key dest_key row_index coeff =
    98     case VarGraph.lookup g dest_key of
    99         NONE => raise (Internal "add_edge: dest_key not found")
   100       | SOME (sure_bound, f) =>
   101         (case Inttab.lookup f row_index of
   102              NONE => raise (Internal "add_edge: row_index not found")
   103            | SOME (row_bound, sources) =>
   104              VarGraph.update (dest_key, (sure_bound, Inttab.update (row_index, (row_bound, (coeff, src_key) :: sources)) f)) g)
   105 
   106 fun split_graph g =
   107   let
   108     fun split (key, (sure_bound, _)) (r1, r2) = case sure_bound
   109      of NONE => (r1, r2)
   110       | SOME bound =>  (case key
   111          of (u, UPPER) => (r1, Inttab.update (u, bound) r2)
   112           | (u, LOWER) => (Inttab.update (u, bound) r1, r2))
   113   in VarGraph.fold split g (Inttab.empty, Inttab.empty) end
   114 
   115 fun it2list t = Inttab.fold cons t [];
   116 
   117 (* If safe is true, termination is guaranteed, but the sure bounds may be not optimal (relative to the algorithm).
   118    If safe is false, termination is not guaranteed, but on termination the sure bounds are optimal (relative to the algorithm) *)
   119 fun propagate_sure_bounds safe names g =
   120     let
   121         (* returns NONE if no new sure bound could be calculated, otherwise the new sure bound is returned *)
   122         fun calc_sure_bound_from_sources g (key as (_, btype)) =
   123             let
   124                 fun mult_upper x (lower, upper) =
   125                     if Float.sign x = LESS then
   126                         Float.mult x lower
   127                     else
   128                         Float.mult x upper
   129 
   130                 fun mult_lower x (lower, upper) =
   131                     if Float.sign x = LESS then
   132                         Float.mult x upper
   133                     else
   134                         Float.mult x lower
   135 
   136                 val mult_btype = case btype of UPPER => mult_upper | LOWER => mult_lower
   137 
   138                 fun calc_sure_bound (row_index, (row_bound, sources)) sure_bound =
   139                     let
   140                         fun add_src_bound (coeff, src_key) sum =
   141                             case sum of
   142                                 NONE => NONE
   143                               | SOME x =>
   144                                 (case get_sure_bound g src_key of
   145                                      NONE => NONE
   146                                    | SOME src_sure_bound => SOME (Float.add x (mult_btype src_sure_bound coeff)))
   147                     in
   148                         case fold add_src_bound sources (SOME row_bound) of
   149                             NONE => sure_bound
   150                           | new_sure_bound as (SOME new_bound) =>
   151                             (case sure_bound of
   152                                  NONE => new_sure_bound
   153                                | SOME old_bound =>
   154                                  SOME (case btype of
   155                                            UPPER => Float.min old_bound new_bound
   156                                          | LOWER => Float.max old_bound new_bound))
   157                     end
   158             in
   159                 case VarGraph.lookup g key of
   160                     NONE => NONE
   161                   | SOME (sure_bound, f) =>
   162                     let
   163                         val x = Inttab.fold calc_sure_bound f sure_bound
   164                     in
   165                         if x = sure_bound then NONE else x
   166                     end
   167                 end
   168 
   169         fun propagate (key, _) (g, b) =
   170             case calc_sure_bound_from_sources g key of
   171                 NONE => (g,b)
   172               | SOME bound => (update_sure_bound g key bound,
   173                                if safe then
   174                                    case get_sure_bound g key of
   175                                        NONE => true
   176                                      | _ => b
   177                                else
   178                                    true)
   179 
   180         val (g, b) = VarGraph.fold propagate g (g, false)
   181     in
   182         if b then propagate_sure_bounds safe names g else g
   183     end
   184 
   185 exception Load of string;
   186 
   187 val empty_spvec = @{term "Nil :: real spvec"};
   188 fun cons_spvec x xs = @{term "Cons :: nat * real => real spvec => real spvec"} $ x $ xs;
   189 val empty_spmat = @{term "Nil :: real spmat"};
   190 fun cons_spmat x xs = @{term "Cons :: nat * real spvec => real spmat => real spmat"} $ x $ xs;
   191 
   192 fun calcr safe_propagation xlen names prec A b =
   193     let
   194         val empty = Inttab.empty
   195 
   196         fun instab t i x = Inttab.update (i, x) t
   197 
   198         fun isnegstr x = String.isPrefix "-" x
   199         fun negstr x = if isnegstr x then String.extract (x, 1, NONE) else "-"^x
   200 
   201         fun test_1 (lower, upper) =
   202             if lower = upper then
   203                 (if Float.eq (lower, (~1, 0)) then ~1
   204                  else if Float.eq (lower, (1, 0)) then 1
   205                  else 0)
   206             else 0
   207 
   208         fun calcr (row_index, a) g =
   209             let
   210                 val b =  FloatSparseMatrixBuilder.v_elem_at b row_index
   211                 val (_, b2) = FloatArith.approx_decstr_by_bin prec (case b of NONE => "0" | SOME b => b)
   212                 val approx_a = FloatSparseMatrixBuilder.v_fold (fn (i, s) => fn l =>
   213                                                                    (i, FloatArith.approx_decstr_by_bin prec s)::l) a []
   214 
   215                 fun fold_dest_nodes (dest_index, dest_value) g =
   216                     let
   217                         val dest_test = test_1 dest_value
   218                     in
   219                         if dest_test = 0 then
   220                             g
   221                         else let
   222                                 val (dest_key as (_, dest_btype), row_bound) =
   223                                     if dest_test = ~1 then
   224                                         ((dest_index, LOWER), Float.neg b2)
   225                                     else
   226                                         ((dest_index, UPPER), b2)
   227 
   228                                 fun fold_src_nodes (src_index, src_value as (src_lower, src_upper)) g =
   229                                     if src_index = dest_index then g
   230                                     else
   231                                         let
   232                                             val coeff = case dest_btype of
   233                                                             UPPER => (Float.neg src_upper, Float.neg src_lower)
   234                                                           | LOWER => src_value
   235                                         in
   236                                             if Float.sign src_lower = LESS then
   237                                                 add_edge g (src_index, UPPER) dest_key row_index coeff
   238                                             else
   239                                                 add_edge g (src_index, LOWER) dest_key row_index coeff
   240                                         end
   241                             in
   242                                 fold fold_src_nodes approx_a (add_row_bound g dest_key row_index row_bound)
   243                             end
   244                     end
   245             in
   246                 case approx_a of
   247                     [] => g
   248                   | [(u, a)] =>
   249                     let
   250                         val atest = test_1 a
   251                     in
   252                         if atest = ~1 then
   253                             update_sure_bound g (u, LOWER) (Float.neg b2)
   254                         else if atest = 1 then
   255                             update_sure_bound g (u, UPPER) b2
   256                         else
   257                             g
   258                     end
   259                   | _ => fold fold_dest_nodes approx_a g
   260             end
   261 
   262         val g = FloatSparseMatrixBuilder.m_fold calcr A VarGraph.empty
   263 
   264         val g = propagate_sure_bounds safe_propagation names g
   265 
   266         val (r1, r2) = split_graph g
   267 
   268         fun add_row_entry m index f vname value =
   269             let
   270                 val v = (case value of 
   271                              SOME value => FloatSparseMatrixBuilder.mk_spvec_entry 0 value
   272                            | NONE => FloatSparseMatrixBuilder.mk_spvec_entry' 0 (f $ (Var ((vname,0), HOLogic.realT))))
   273                 val vec = cons_spvec v empty_spvec
   274             in
   275                 cons_spmat (FloatSparseMatrixBuilder.mk_spmat_entry index vec) m
   276             end
   277 
   278         fun abs_estimate i r1 r2 =
   279             if i = 0 then
   280                 let val e = empty_spmat in (e, e) end
   281             else
   282                 let
   283                     val index = xlen-i
   284                     val (r12_1, r12_2) = abs_estimate (i-1) r1 r2
   285                     val b1 = Inttab.lookup r1 index
   286                     val b2 = Inttab.lookup r2 index
   287                 in
   288                     (add_row_entry r12_1 index @{term "lbound :: real => real"} ((names index)^"l") b1, 
   289                      add_row_entry r12_2 index @{term "ubound :: real => real"} ((names index)^"u") b2)
   290                 end
   291 
   292         val (r1, r2) = abs_estimate xlen r1 r2
   293 
   294     in
   295         (r1, r2)
   296     end
   297 
   298 fun load filename prec safe_propagation =
   299     let
   300         val prog = Cplex.load_cplexFile filename
   301         val prog = Cplex.elim_nonfree_bounds prog
   302         val prog = Cplex.relax_strict_ineqs prog
   303         val (maximize, c, A, b, (xlen, names, _)) = CplexFloatSparseMatrixConverter.convert_prog prog                       
   304         val (r1, r2) = calcr safe_propagation xlen names prec A b
   305         val _ = if maximize then () else raise Load "sorry, cannot handle minimization problems"
   306         val (dualprog, indexof) = FloatSparseMatrixBuilder.dual_cplexProg c A b
   307         val results = Cplex.solve dualprog
   308         val (optimal,v) = CplexFloatSparseMatrixConverter.convert_results results indexof
   309         (*val A = FloatSparseMatrixBuilder.cut_matrix v NONE A*)
   310         fun id x = x
   311         val v = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 v
   312         val b = FloatSparseMatrixBuilder.transpose_matrix (FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 b)
   313         val c = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 c
   314         val (y1, _) = FloatSparseMatrixBuilder.approx_matrix prec Float.positive_part v
   315         val A = FloatSparseMatrixBuilder.approx_matrix prec id A
   316         val (_,b2) = FloatSparseMatrixBuilder.approx_matrix prec id b
   317         val c = FloatSparseMatrixBuilder.approx_matrix prec id c
   318     in
   319         (y1, A, b2, c, (r1, r2))
   320     end handle CplexFloatSparseMatrixConverter.Converter s => (raise (Load ("Converter: "^s)))
   321 
   322 end