src/HOL/Matrix_LP/fspmlp.ML
author blanchet
Tue Nov 07 15:16:42 2017 +0100 (21 months ago)
changeset 67022 49309fe530fd
parent 47455 26315a545e26
child 69597 ff784d5a5bfb
permissions -rw-r--r--
more robust parsing for THF proofs (esp. polymorphic Leo-III proofs)
     1 (*  Title:      HOL/Matrix_LP/fspmlp.ML
     2     Author:     Steven Obua
     3 *)
     4 
     5 signature FSPMLP =
     6 sig
     7     type linprog
     8     type vector = FloatSparseMatrixBuilder.vector
     9     type matrix = FloatSparseMatrixBuilder.matrix
    10 
    11     val y : linprog -> term
    12     val A : linprog -> term * term
    13     val b : linprog -> term
    14     val c : linprog -> term * term
    15     val r12 : linprog -> term * term
    16 
    17     exception Load of string
    18 
    19     val load : string -> int -> bool -> linprog
    20 end
    21 
    22 structure Fspmlp : FSPMLP =
    23 struct
    24 
    25 type vector = FloatSparseMatrixBuilder.vector
    26 type matrix = FloatSparseMatrixBuilder.matrix
    27 
    28 type linprog = term * (term * term) * term * (term * term) * (term * term)
    29 
    30 fun y (c1, _, _, _, _) = c1
    31 fun A (_, c2, _, _, _) = c2
    32 fun b (_, _, c3, _, _) = c3
    33 fun c (_, _, _, c4, _) = c4
    34 fun r12 (_, _, _, _, c6) = c6
    35 
    36 structure CplexFloatSparseMatrixConverter =
    37 MAKE_CPLEX_MATRIX_CONVERTER(structure cplex = Cplex and matrix_builder = FloatSparseMatrixBuilder);
    38 
    39 datatype bound_type = LOWER | UPPER
    40 
    41 fun intbound_ord ((i1: int, b1),(i2,b2)) =
    42     if i1 < i2 then LESS
    43     else if i1 = i2 then
    44         (if b1 = b2 then EQUAL else if b1=LOWER then LESS else GREATER)
    45     else GREATER
    46 
    47 structure Inttab = Table(type key = int val ord = (rev_order o int_ord));
    48 
    49 structure VarGraph = Table(type key = int*bound_type val ord = intbound_ord);
    50 (* key -> (float option) * (int -> (float * (((float * float) * key) list)))) *)
    51 (* dest_key -> (sure_bound * (row_index -> (row_bound * (((coeff_lower * coeff_upper) * src_key) list)))) *)
    52 
    53 exception Internal of string;
    54 
    55 fun add_row_bound g dest_key row_index row_bound =
    56     let
    57         val x =
    58             case VarGraph.lookup g dest_key of
    59                 NONE => (NONE, Inttab.update (row_index, (row_bound, [])) Inttab.empty)
    60               | SOME (sure_bound, f) =>
    61                 (sure_bound,
    62                  case Inttab.lookup f row_index of
    63                      NONE => Inttab.update (row_index, (row_bound, [])) f
    64                    | SOME _ => raise (Internal "add_row_bound"))
    65     in
    66         VarGraph.update (dest_key, x) g
    67     end
    68 
    69 fun update_sure_bound g (key as (_, btype)) bound =
    70     let
    71         val x =
    72             case VarGraph.lookup g key of
    73                 NONE => (SOME bound, Inttab.empty)
    74               | SOME (NONE, f) => (SOME bound, f)
    75               | SOME (SOME old_bound, f) =>
    76                 (SOME ((case btype of
    77                             UPPER => Float.min
    78                           | LOWER => Float.max)
    79                            old_bound bound), f)
    80     in
    81         VarGraph.update (key, x) g
    82     end
    83 
    84 fun get_sure_bound g key =
    85     case VarGraph.lookup g key of
    86         NONE => NONE
    87       | SOME (sure_bound, _) => sure_bound
    88 
    89 (*fun get_row_bound g key row_index =
    90     case VarGraph.lookup g key of
    91         NONE => NONE
    92       | SOME (sure_bound, f) =>
    93         (case Inttab.lookup f row_index of
    94              NONE => NONE
    95            | SOME (row_bound, _) => (sure_bound, row_bound))*)
    96 
    97 fun add_edge g src_key dest_key row_index coeff =
    98     case VarGraph.lookup g dest_key of
    99         NONE => raise (Internal "add_edge: dest_key not found")
   100       | SOME (sure_bound, f) =>
   101         (case Inttab.lookup f row_index of
   102              NONE => raise (Internal "add_edge: row_index not found")
   103            | SOME (row_bound, sources) =>
   104              VarGraph.update (dest_key, (sure_bound, Inttab.update (row_index, (row_bound, (coeff, src_key) :: sources)) f)) g)
   105 
   106 fun split_graph g =
   107   let
   108     fun split (key, (sure_bound, _)) (r1, r2) = case sure_bound
   109      of NONE => (r1, r2)
   110       | SOME bound =>  (case key
   111          of (u, UPPER) => (r1, Inttab.update (u, bound) r2)
   112           | (u, LOWER) => (Inttab.update (u, bound) r1, r2))
   113   in VarGraph.fold split g (Inttab.empty, Inttab.empty) end
   114 
   115 (* If safe is true, termination is guaranteed, but the sure bounds may be not optimal (relative to the algorithm).
   116    If safe is false, termination is not guaranteed, but on termination the sure bounds are optimal (relative to the algorithm) *)
   117 fun propagate_sure_bounds safe names g =
   118     let
   119         (* returns NONE if no new sure bound could be calculated, otherwise the new sure bound is returned *)
   120         fun calc_sure_bound_from_sources g (key as (_, btype)) =
   121             let
   122                 fun mult_upper x (lower, upper) =
   123                     if Float.sign x = LESS then
   124                         Float.mult x lower
   125                     else
   126                         Float.mult x upper
   127 
   128                 fun mult_lower x (lower, upper) =
   129                     if Float.sign x = LESS then
   130                         Float.mult x upper
   131                     else
   132                         Float.mult x lower
   133 
   134                 val mult_btype = case btype of UPPER => mult_upper | LOWER => mult_lower
   135 
   136                 fun calc_sure_bound (_, (row_bound, sources)) sure_bound =
   137                     let
   138                         fun add_src_bound (coeff, src_key) sum =
   139                             case sum of
   140                                 NONE => NONE
   141                               | SOME x =>
   142                                 (case get_sure_bound g src_key of
   143                                      NONE => NONE
   144                                    | SOME src_sure_bound => SOME (Float.add x (mult_btype src_sure_bound coeff)))
   145                     in
   146                         case fold add_src_bound sources (SOME row_bound) of
   147                             NONE => sure_bound
   148                           | new_sure_bound as (SOME new_bound) =>
   149                             (case sure_bound of
   150                                  NONE => new_sure_bound
   151                                | SOME old_bound =>
   152                                  SOME (case btype of
   153                                            UPPER => Float.min old_bound new_bound
   154                                          | LOWER => Float.max old_bound new_bound))
   155                     end
   156             in
   157                 case VarGraph.lookup g key of
   158                     NONE => NONE
   159                   | SOME (sure_bound, f) =>
   160                     let
   161                         val x = Inttab.fold calc_sure_bound f sure_bound
   162                     in
   163                         if x = sure_bound then NONE else x
   164                     end
   165                 end
   166 
   167         fun propagate (key, _) (g, b) =
   168             case calc_sure_bound_from_sources g key of
   169                 NONE => (g,b)
   170               | SOME bound => (update_sure_bound g key bound,
   171                                if safe then
   172                                    case get_sure_bound g key of
   173                                        NONE => true
   174                                      | _ => b
   175                                else
   176                                    true)
   177 
   178         val (g, b) = VarGraph.fold propagate g (g, false)
   179     in
   180         if b then propagate_sure_bounds safe names g else g
   181     end
   182 
   183 exception Load of string;
   184 
   185 val empty_spvec = @{term "Nil :: real spvec"};
   186 fun cons_spvec x xs = @{term "Cons :: nat * real => real spvec => real spvec"} $ x $ xs;
   187 val empty_spmat = @{term "Nil :: real spmat"};
   188 fun cons_spmat x xs = @{term "Cons :: nat * real spvec => real spmat => real spmat"} $ x $ xs;
   189 
   190 fun calcr safe_propagation xlen names prec A b =
   191     let
   192         fun test_1 (lower, upper) =
   193             if lower = upper then
   194                 (if Float.eq (lower, (~1, 0)) then ~1
   195                  else if Float.eq (lower, (1, 0)) then 1
   196                  else 0)
   197             else 0
   198 
   199         fun calcr (row_index, a) g =
   200             let
   201                 val b =  FloatSparseMatrixBuilder.v_elem_at b row_index
   202                 val (_, b2) = FloatArith.approx_decstr_by_bin prec (case b of NONE => "0" | SOME b => b)
   203                 val approx_a = FloatSparseMatrixBuilder.v_fold (fn (i, s) => fn l =>
   204                                                                    (i, FloatArith.approx_decstr_by_bin prec s)::l) a []
   205 
   206                 fun fold_dest_nodes (dest_index, dest_value) g =
   207                     let
   208                         val dest_test = test_1 dest_value
   209                     in
   210                         if dest_test = 0 then
   211                             g
   212                         else let
   213                                 val (dest_key as (_, dest_btype), row_bound) =
   214                                     if dest_test = ~1 then
   215                                         ((dest_index, LOWER), Float.neg b2)
   216                                     else
   217                                         ((dest_index, UPPER), b2)
   218 
   219                                 fun fold_src_nodes (src_index, src_value as (src_lower, src_upper)) g =
   220                                     if src_index = dest_index then g
   221                                     else
   222                                         let
   223                                             val coeff = case dest_btype of
   224                                                             UPPER => (Float.neg src_upper, Float.neg src_lower)
   225                                                           | LOWER => src_value
   226                                         in
   227                                             if Float.sign src_lower = LESS then
   228                                                 add_edge g (src_index, UPPER) dest_key row_index coeff
   229                                             else
   230                                                 add_edge g (src_index, LOWER) dest_key row_index coeff
   231                                         end
   232                             in
   233                                 fold fold_src_nodes approx_a (add_row_bound g dest_key row_index row_bound)
   234                             end
   235                     end
   236             in
   237                 case approx_a of
   238                     [] => g
   239                   | [(u, a)] =>
   240                     let
   241                         val atest = test_1 a
   242                     in
   243                         if atest = ~1 then
   244                             update_sure_bound g (u, LOWER) (Float.neg b2)
   245                         else if atest = 1 then
   246                             update_sure_bound g (u, UPPER) b2
   247                         else
   248                             g
   249                     end
   250                   | _ => fold fold_dest_nodes approx_a g
   251             end
   252 
   253         val g = FloatSparseMatrixBuilder.m_fold calcr A VarGraph.empty
   254 
   255         val g = propagate_sure_bounds safe_propagation names g
   256 
   257         val (r1, r2) = split_graph g
   258 
   259         fun add_row_entry m index f vname value =
   260             let
   261                 val v = (case value of 
   262                              SOME value => FloatSparseMatrixBuilder.mk_spvec_entry 0 value
   263                            | NONE => FloatSparseMatrixBuilder.mk_spvec_entry' 0 (f $ (Var ((vname,0), HOLogic.realT))))
   264                 val vec = cons_spvec v empty_spvec
   265             in
   266                 cons_spmat (FloatSparseMatrixBuilder.mk_spmat_entry index vec) m
   267             end
   268 
   269         fun abs_estimate i r1 r2 =
   270             if i = 0 then
   271                 let val e = empty_spmat in (e, e) end
   272             else
   273                 let
   274                     val index = xlen-i
   275                     val (r12_1, r12_2) = abs_estimate (i-1) r1 r2
   276                     val b1 = Inttab.lookup r1 index
   277                     val b2 = Inttab.lookup r2 index
   278                 in
   279                     (add_row_entry r12_1 index @{term "lbound :: real => real"} ((names index)^"l") b1, 
   280                      add_row_entry r12_2 index @{term "ubound :: real => real"} ((names index)^"u") b2)
   281                 end
   282 
   283         val (r1, r2) = abs_estimate xlen r1 r2
   284 
   285     in
   286         (r1, r2)
   287     end
   288 
   289 fun load filename prec safe_propagation =
   290     let
   291         val prog = Cplex.load_cplexFile filename
   292         val prog = Cplex.elim_nonfree_bounds prog
   293         val prog = Cplex.relax_strict_ineqs prog
   294         val (maximize, c, A, b, (xlen, names, _)) = CplexFloatSparseMatrixConverter.convert_prog prog                       
   295         val (r1, r2) = calcr safe_propagation xlen names prec A b
   296         val _ = if maximize then () else raise Load "sorry, cannot handle minimization problems"
   297         val (dualprog, indexof) = FloatSparseMatrixBuilder.dual_cplexProg c A b
   298         val results = Cplex.solve dualprog
   299         val (_, v) = CplexFloatSparseMatrixConverter.convert_results results indexof
   300         (*val A = FloatSparseMatrixBuilder.cut_matrix v NONE A*)
   301         fun id x = x
   302         val v = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 v
   303         val b = FloatSparseMatrixBuilder.transpose_matrix (FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 b)
   304         val c = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 c
   305         val (y1, _) = FloatSparseMatrixBuilder.approx_matrix prec Float.positive_part v
   306         val A = FloatSparseMatrixBuilder.approx_matrix prec id A
   307         val (_,b2) = FloatSparseMatrixBuilder.approx_matrix prec id b
   308         val c = FloatSparseMatrixBuilder.approx_matrix prec id c
   309     in
   310         (y1, A, b2, c, (r1, r2))
   311     end handle CplexFloatSparseMatrixConverter.Converter s => (raise (Load ("Converter: "^s)))
   312 
   313 end