LCOV - code coverage report
Current view: top level - source - FBCache.cpp (source / functions) Hit Total Coverage
Test: LibForBES Unit Tests Lines: 220 248 88.7 %
Date: 2016-04-18 Functions: 22 22 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*
       2             :  * File:   FBCache.cpp
       3             :  * Author: Lorenzo Stella, Pantelis Sopasakis
       4             :  *
       5             :  * Created on October 2, 2015
       6             :  *
       7             :  * ForBES is free software: you can redistribute it and/or modify
       8             :  * it under the terms of the GNU Lesser General Public License as published by
       9             :  * the Free Software Foundation, either version 3 of the License, or
      10             :  * (at your option) any later version.
      11             :  *
      12             :  * ForBES is distributed in the hope that it will be useful,
      13             :  * but WITHOUT ANY WARRANTY; without even the implied warranty of
      14             :  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
      15             :  * GNU Lesser General Public License for more details.
      16             :  *
      17             :  * You should have received a copy of the GNU Lesser General Public License
      18             :  * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
      19             :  */
      20             : 
      21             : #include "FBCache.h"
      22             : #include "LinearOperator.h"
      23             : 
      24             : #include <cmath>
      25             : #include <limits>
      26             : #include <complex>
      27             : 
      28             : #define FB_CACHE_RELATIVE_TOL 1e-6
      29             : #define FB_CACHE_ABSOLUTE_TOL 1e-14
      30             : 
      31             : /**
      32             :  * Verify whether \c a and \c b are close to one another. Checks whether 
      33             :  * \f[
      34             :  *  |a-b|\leq \max (\epsilon_r \max(a,b), \epsilon_a),
      35             :  * \f]
      36             :  * where \f$\epsilon_r = 10^{-6}\f$ is the relative tolerance and \f$\epsilon_a=10^{-14}\f$
      37             :  * is the absolute tolerance.
      38             :  * 
      39             :  * 
      40             :  * @param a scalar
      41             :  * @param b another scalar
      42             :  * @return \c true iff \c a is approximately equal to \c b.
      43             :  */
      44      189629 : inline bool is_close(const double a, const double b) {
      45      379258 :     return abs(a - b) <= std::max(
      46      379258 :             FB_CACHE_RELATIVE_TOL * std::max(std::abs(a), std::abs(b)),
      47      758516 :             FB_CACHE_ABSOLUTE_TOL);
      48             : }
      49             : 
      50       81145 : void FBCache::reset(int status) {
      51       81145 :     if (status < m_status) m_status = status;
      52       81145 : }
      53             : 
      54       26900 : void FBCache::reset() {
      55       26900 :     FBCache::reset(STATUS_NONE);
      56       26900 : }
      57             : 
      58         906 : FBCache::FBCache(FBProblem & p, Matrix & x, double gamma) :
      59             : m_prob(p),
      60             : m_x(&x),
      61         906 : m_gamma(gamma) {
      62         906 :     reset(STATUS_NONE);
      63             : 
      64             :     // get dimensions of things
      65         906 :     size_t m_x_rows = m_x->getNrows();
      66         906 :     size_t m_x_cols = m_x->getNcols();
      67             :     size_t m_res1_rows, m_res1_cols;
      68             :     size_t m_res2_rows, m_res2_cols;
      69             : 
      70         906 :     if (m_prob.d1() != NULL) {
      71         302 :         m_res1_rows = m_prob.d1()->getNrows();
      72         302 :         m_res1_cols = m_prob.d1()->getNcols();
      73         604 :     } else if (m_prob.L1() != NULL) {
      74           0 :         m_res1_rows = m_prob.L1()->dimensionOut().first;
      75           0 :         m_res1_cols = m_prob.L1()->dimensionOut().second;
      76             :     } else {
      77         604 :         m_res1_rows = m_x_rows;
      78         604 :         m_res1_cols = m_x_cols;
      79             :     }
      80         906 :     if (m_prob.d2() != NULL) {
      81         201 :         m_res2_rows = m_prob.d2()->getNrows();
      82         201 :         m_res2_cols = m_prob.d2()->getNcols();
      83         705 :     } else if (m_prob.L2() != NULL) {
      84           0 :         m_res2_rows = m_prob.L2()->dimensionOut().first;
      85           0 :         m_res2_cols = m_prob.L2()->dimensionOut().second;
      86             :     } else {
      87         705 :         m_res2_rows = m_x_rows;
      88         705 :         m_res2_cols = m_x_cols;
      89             :     }
      90             : 
      91             :     // allocate memory for residuals and gradients (where needed)
      92         906 :     if (m_prob.f1() != NULL) {
      93         704 :         m_res1x = new Matrix(m_res1_rows, m_res1_cols);
      94         704 :         m_gradf1x = new Matrix(m_res1_rows, m_res1_cols);
      95             :     } else {
      96         202 :         m_res1x = NULL;
      97         202 :         m_gradf1x = NULL;
      98             :     }
      99         906 :     if (m_prob.f2() != NULL) {
     100         202 :         m_res2x = new Matrix(m_res2_rows, m_res2_cols);
     101         202 :         m_gradf2x = new Matrix(m_res2_rows, m_res2_cols);
     102             :     } else {
     103         704 :         m_res2x = NULL;
     104         704 :         m_gradf2x = NULL;
     105             :     }
     106             : 
     107         906 :     m_gradfx = new Matrix(m_x_rows, m_x_cols);
     108         906 :     m_z = new Matrix(m_x_rows, m_x_cols);
     109         906 :     m_y = new Matrix(m_x_rows, m_x_cols);
     110         906 :     m_FPRx = new Matrix(m_x_rows, m_x_cols);
     111         906 :     m_gradFBEx = new Matrix(m_x_rows, m_x_cols);
     112             : 
     113         906 :     m_FBEx = std::numeric_limits<double>::infinity();
     114         906 :     m_sqnormFPRx = std::numeric_limits<double>::infinity();
     115             : 
     116         906 :     m_f1x = 0.0;
     117         906 :     m_f2x = 0.0;
     118         906 :     m_linx = 0.0;
     119         906 :     m_fx = 0.0;
     120         906 :     m_gz = 0.0;
     121             : 
     122         906 :     m_cached_grad_f2 = false;
     123         906 : }
     124             : 
     125       81132 : int FBCache::update_eval_f(bool order_grad_f2) {
     126             : 
     127       81132 :     if (!m_cached_grad_f2 && order_grad_f2) {
     128             :         // If gradf2x has not been computed previously, but now should be
     129             :         // computed, then set the status to STATUS_NONE, so that all 
     130             :         // computations are performed from the beginning (both f2x and gradf2x).
     131       81126 :         m_status = STATUS_NONE;
     132             :     }
     133             : 
     134       81132 :     if (m_status >= STATUS_EVALF) {
     135           0 :         return ForBESUtils::STATUS_OK;
     136             :     }
     137             : 
     138       81132 :     if (m_prob.f1() != NULL) {
     139       61530 :         if (m_prob.L1() != NULL) {
     140       38127 :             *m_res1x = m_prob.L1()->call(*m_x);
     141             :         } else {
     142       23403 :             *m_res1x = *m_x;
     143             :         }
     144       61530 :         if (m_prob.d1() != NULL) {
     145       38127 :             *m_res1x += *(m_prob.d1());
     146             :         }
     147       61530 :         int status = m_prob.f1()->call(*m_res1x, m_f1x, *m_gradf1x);
     148       61530 :         if (ForBESUtils::STATUS_OK != status) {
     149           0 :             return status;
     150             :         }
     151             :     }
     152             : 
     153       81132 :     if (m_prob.f2() != NULL) {
     154       19602 :         if (m_prob.L2() != NULL) {
     155       19601 :             *m_res2x = m_prob.L2()->call(*m_x);
     156             :         } else {
     157           1 :             *m_res2x = *m_x;
     158             :         }
     159       19602 :         if (m_prob.d2() != NULL) {
     160       19601 :             *m_res2x += *(m_prob.d2());
     161             :         }
     162             :         int status =
     163             :                 order_grad_f2
     164       19600 :                 ? m_prob.f2()->call(*m_res2x, m_f2x, *m_gradf2x)
     165       39202 :                 : m_prob.f2()->call(*m_res2x, m_f2x);
     166       19602 :         m_cached_grad_f2 = order_grad_f2;
     167       19602 :         if (ForBESUtils::STATUS_OK != status) {
     168           0 :             return status;
     169             :         }
     170             :     }
     171             : 
     172       81132 :     if (m_prob.lin() != NULL) {
     173           0 :         m_linx = ((*m_prob.lin()) * (*m_x))[0];
     174             :     }
     175             : 
     176       81132 :     m_fx = m_f1x + m_f2x + m_linx;
     177       81132 :     m_status = STATUS_EVALF;
     178             : 
     179       81132 :     return ForBESUtils::STATUS_OK;
     180             : }
     181             : 
     182       81140 : int FBCache::update_forward_step(double gamma) {
     183       81140 :     bool is_gamma_the_same = is_close(gamma, m_gamma);
     184       81140 :     if (!is_gamma_the_same) reset(STATUS_EVALF);
     185             : 
     186             : 
     187       81140 :     if (m_status >= STATUS_FORWARD) {
     188           1 :         if (is_gamma_the_same) return ForBESUtils::STATUS_OK;
     189           0 :         *m_y = *m_x;
     190           0 :         Matrix::add(*m_y, -gamma, *m_gradfx, 1.0);
     191           0 :         m_gamma = gamma;
     192           0 :         return ForBESUtils::STATUS_OK;
     193             :     }
     194             : 
     195             :     int status;
     196             : 
     197       81139 :     if (m_status < STATUS_EVALF) {
     198       81126 :         m_cached_grad_f2 = false;
     199       81126 :         status = update_eval_f(true);
     200       81126 :         if (!ForBESUtils::is_status_ok(status)) return status;
     201             :     }
     202             : 
     203       81139 :     if (m_prob.f1() != NULL) {
     204       61535 :         if (m_prob.L1()) {
     205       38129 :             Matrix d_gradfx = m_prob.L1()->callAdjoint(*m_gradf1x);
     206       38129 :             *m_gradfx = d_gradfx;
     207             :         } else {
     208       23406 :             *m_gradfx = *m_gradf1x;
     209             :         }
     210             :     }
     211             : 
     212       81139 :     if (m_prob.f2() != NULL) {
     213       19604 :         if (!m_cached_grad_f2) {
     214           2 :             status = m_prob.f2()->call(*m_res2x, m_f2x, *m_gradf2x);
     215           2 :             if (!ForBESUtils::is_status_ok(status)) return status;
     216             :             // now gradf2x has been computed:
     217           2 :             m_cached_grad_f2 = true;
     218             :         }
     219       19604 :         if (m_prob.L2() != NULL) {
     220       19602 :             Matrix d_gradfx = m_prob.L2()->callAdjoint(*m_gradf2x);
     221       19602 :             if (m_prob.f1() != NULL) *m_gradfx += d_gradfx;
     222       19602 :             else *m_gradfx = d_gradfx;
     223             :         } else {
     224           2 :             if (m_prob.f1() != NULL) *m_gradfx += *m_gradf2x;
     225           2 :             else *m_gradfx = *m_gradf2x;
     226             :         }
     227             :     }
     228             : 
     229       81139 :     if (m_prob.lin()) {
     230           0 :         if (m_prob.f1() != NULL || m_prob.f2() != NULL) {
     231           0 :             *m_gradfx += (*m_prob.lin());
     232             :         } else {
     233           0 :             *m_gradfx = *m_prob.lin();
     234             :         }
     235             :     }
     236             : 
     237       81139 :     *m_y = *m_x;
     238       81139 :     Matrix::add(*m_y, -gamma, *m_gradfx, 1.0);
     239             : 
     240       81139 :     m_gamma = gamma;
     241       81139 :     m_status = STATUS_FORWARD;
     242             : 
     243       81139 :     return ForBESUtils::STATUS_OK;
     244             : }
     245             : 
     246      108465 : int FBCache::update_forward_backward_step(double gamma) {
     247             :     int status;
     248      108465 :     if (!is_close(gamma, m_gamma)) {
     249           0 :         reset(STATUS_EVALF);
     250             :     }
     251      108465 :     if (m_status >= STATUS_FORWARDBACKWARD) {
     252       27327 :         return ForBESUtils::STATUS_OK;
     253             :     }
     254       81138 :     if (m_status < STATUS_FORWARD) {
     255       81126 :         status = update_forward_step(gamma);
     256       81126 :         if (!ForBESUtils::is_status_ok(status)) {
     257           0 :             return status;
     258             :         }
     259             :     }
     260             : 
     261       81138 :     status = m_prob.g()->callProx(*m_y, gamma, *m_z, m_gz);
     262       81138 :     if (!ForBESUtils::is_status_ok(status)) {
     263           0 :         return status;
     264             :     }
     265       81138 :     *m_FPRx = (*m_x - *m_z);
     266       81138 :     m_sqnormFPRx = std::pow(m_FPRx->norm_fro_sq(), 2);
     267       81138 :     m_gamma = gamma;
     268       81138 :     m_status = STATUS_FORWARDBACKWARD;
     269             : 
     270       81138 :     return ForBESUtils::STATUS_OK;
     271             : }
     272             : 
     273          12 : int FBCache::update_eval_FBE(double gamma) {
     274          12 :     if (!is_close(gamma, m_gamma)) {
     275           0 :         reset(STATUS_EVALF);
     276             :     }
     277             : 
     278          12 :     if (m_status >= STATUS_FBE) {
     279           0 :         return ForBESUtils::STATUS_OK;
     280             :     }
     281             : 
     282          12 :     if (m_status < STATUS_FORWARDBACKWARD) {
     283           0 :         int status = update_forward_backward_step(gamma);
     284           0 :         if (!ForBESUtils::is_status_ok(status)) {
     285           0 :             return status;
     286             :         }
     287             :     }
     288             : 
     289          12 :     Matrix innprox_mat = (*m_FPRx) * (*m_gradfx);
     290          12 :     double innprod = innprox_mat[0];
     291             : 
     292          12 :     m_FBEx = m_fx + m_gz - innprod + 0.5 / m_gamma*m_sqnormFPRx;
     293          12 :     m_gamma = gamma;
     294          12 :     m_status = STATUS_FBE;
     295             : 
     296          12 :     return ForBESUtils::STATUS_OK;
     297             : }
     298             : 
     299          12 : int FBCache::update_grad_FBE(double gamma) {
     300          12 :     if (!is_close(gamma, m_gamma)) {
     301           0 :         reset(STATUS_EVALF);
     302             :     }
     303             : 
     304          12 :     if (m_status >= STATUS_GRAD_FBE) {
     305           0 :         return ForBESUtils::STATUS_OK;
     306             :     }
     307             : 
     308          12 :     if (m_status < STATUS_FORWARDBACKWARD) {
     309           0 :         int status = update_forward_backward_step(gamma);
     310           0 :         if (!ForBESUtils::is_status_ok(status)) {
     311           0 :             return status;
     312             :         }
     313             :     }
     314             : 
     315          12 :     *m_gradFBEx = *m_FPRx;
     316             : 
     317          12 :     if (m_prob.f1() != NULL) {
     318           8 :         if (m_prob.L1() != NULL) {
     319           4 :             Matrix v1(m_prob.L1()->dimensionOut());
     320           4 :             v1 = m_prob.L1()->call(*m_FPRx);
     321           8 :             Matrix v2(m_prob.L1()->dimensionOut());
     322           4 :             m_prob.f1()->hessianProduct(*m_res1x, v1, v2);
     323           8 :             Matrix v3 = m_prob.L1()->callAdjoint(v2);
     324           8 :             Matrix::add(*m_gradFBEx, -1.0, v3, 1.0 / gamma);
     325             :         } else {
     326           4 :             Matrix v1(m_x->getNrows(), m_x->getNcols());
     327           4 :             m_prob.f1()->hessianProduct(*m_x, *m_FPRx, v1);
     328           4 :             Matrix::add(*m_gradFBEx, -1.0, v1, 1.0 / gamma);
     329             :         }
     330             :     }
     331             : 
     332          12 :     if (m_prob.f2() != NULL) {
     333           4 :         if (m_prob.L2() != NULL) {
     334           2 :             Matrix v1(m_prob.L2()->dimensionOut());
     335           2 :             v1 = m_prob.L2()->call(*m_FPRx);
     336           4 :             Matrix v2(m_prob.L2()->dimensionOut());
     337           2 :             m_prob.f2()->hessianProduct(*m_res2x, v1, v2);
     338           4 :             Matrix v3 = m_prob.L2()->callAdjoint(v2);
     339           2 :             if (m_prob.f1() != NULL) Matrix::add(*m_gradFBEx, -1.0, v3, 1.0);
     340           4 :             else Matrix::add(*m_gradFBEx, -1.0, v3, 1.0 / gamma);
     341             :         } else {
     342           2 :             Matrix v1(m_x->getNrows(), m_x->getNcols());
     343           2 :             m_prob.f2()->hessianProduct(*m_x, *m_FPRx, v1);
     344           2 :             if (m_prob.f1() != NULL) Matrix::add(*m_gradFBEx, -1.0, v1, 1.0);
     345           2 :             else Matrix::add(*m_gradFBEx, -1.0, v1, 1.0 / gamma);
     346             :         }
     347             :     }
     348             : 
     349          12 :     m_gamma = gamma;
     350          12 :     m_status = STATUS_GRAD_FBE;
     351             : 
     352          12 :     return ForBESUtils::STATUS_OK;
     353             : }
     354             : 
     355       53326 : void FBCache::set_point(Matrix& x) {
     356       53326 :     *m_x = x;
     357       53326 :     reset(STATUS_NONE);
     358       53326 : }
     359             : 
     360      108525 : Matrix * FBCache::get_point() {
     361      108525 :     return m_x;
     362             : }
     363             : 
     364          12 : double FBCache::get_eval_FBE(double gamma) {
     365          12 :     update_eval_FBE(gamma);
     366          12 :     return m_FBEx;
     367             : }
     368             : 
     369          12 : Matrix * FBCache::get_grad_FBE(double gamma) {
     370          12 :     update_grad_FBE(gamma);
     371          12 :     return m_gradFBEx;
     372             : }
     373             : 
     374           6 : double FBCache::get_eval_f() {
     375           6 :     update_eval_f(false);
     376           6 :     return m_fx;
     377             : }
     378             : 
     379          14 : Matrix* FBCache::get_forward_step(double gamma) {
     380          14 :     update_forward_step(gamma);
     381          14 :     return m_y;
     382             : }
     383             : 
     384       54238 : Matrix* FBCache::get_forward_backward_step(double gamma) {
     385       54238 :     update_forward_backward_step(gamma);
     386       54238 :     return m_z;
     387             : }
     388             : 
     389           1 : Matrix* FBCache::get_fpr() {
     390           1 :     update_forward_backward_step(m_gamma);
     391           1 :     return m_FPRx;
     392             : }
     393             : 
     394       54226 : double FBCache::get_norm_fpr() {
     395       54226 :     update_forward_backward_step(m_gamma);
     396       54226 :     return sqrt(m_sqnormFPRx);
     397             : }
     398             : 
     399         910 : FBCache::~FBCache() {
     400         905 :     if (m_z != NULL) {
     401         905 :         delete m_z;
     402         905 :         m_z = NULL;
     403             :     }
     404         905 :     if (m_y != NULL) {
     405         905 :         delete m_y;
     406         905 :         m_y = NULL;
     407             :     }
     408         905 :     if (m_res2x != NULL) {
     409         202 :         delete m_res2x;
     410         202 :         m_res2x = NULL;
     411             :     }
     412         905 :     if (m_gradf2x != NULL) {
     413         202 :         delete m_gradf2x;
     414         202 :         m_gradf2x = NULL;
     415             :     }
     416         905 :     if (m_res1x != NULL) {
     417         703 :         delete m_res1x;
     418         703 :         m_res1x = NULL;
     419             :     }
     420         905 :     if (m_gradf1x != NULL) {
     421         703 :         delete m_gradf1x;
     422         703 :         m_gradf1x = NULL;
     423             :     }
     424         905 :     if (m_gradfx != NULL) {
     425         905 :         delete m_gradfx;
     426         905 :         m_gradfx = NULL;
     427             :     }
     428         905 :     if (m_FPRx != NULL) {
     429         905 :         delete m_FPRx;
     430         905 :         m_FPRx = NULL;
     431             :     }
     432         905 :     if (m_gradFBEx != NULL) {
     433         905 :         delete m_gradFBEx;
     434         905 :         m_gradFBEx = NULL;
     435             :     }
     436         922 : }

Generated by: LCOV version 1.10