LCOV - code coverage report
Current view: top level - source - Quadratic.cpp (source / functions) Hit Total Coverage
Test: LibForBES Unit Tests Lines: 107 109 98.2 %
Date: 2016-04-18 Functions: 17 17 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /* 
       2             :  * File:   Quadratic.cpp
       3             :  * Author: Pantelis Sopasakis
       4             :  * 
       5             :  * Created on July 9, 2015, 3:36 AM
       6             :  * 
       7             :  * ForBES is free software: you can redistribute it and/or modify
       8             :  * it under the terms of the GNU Lesser General Public License as published by
       9             :  * the Free Software Foundation, either version 3 of the License, or
      10             :  * (at your option) any later version.
      11             :  *  
      12             :  * ForBES is distributed in the hope that it will be useful,
      13             :  * but WITHOUT ANY WARRANTY; without even the implied warranty of
      14             :  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
      15             :  * GNU Lesser General Public License for more details.
      16             :  * 
      17             :  * You should have received a copy of the GNU Lesser General Public License
      18             :  * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
      19             :  */
      20             : 
      21             : #include "Quadratic.h"
      22             : #include "MatrixFactory.h"
      23             : #include "CGSolver.h"
      24             : 
      25             : using namespace std;
      26             : 
      27           7 : Quadratic::Quadratic() {
      28           7 :     m_is_Q_eye = true;
      29           7 :     m_is_q_zero = true;
      30           7 :     m_solver = NULL;
      31           7 :     m_Q = NULL;
      32           7 :     m_q = NULL;
      33           7 : }
      34             : 
      35          12 : Quadratic::Quadratic(Matrix& QQ) {
      36          12 :     m_Q = &QQ;
      37          12 :     m_is_Q_eye = false;
      38          12 :     m_is_q_zero = true;
      39          12 :     m_solver = NULL;
      40          12 :     m_q = NULL;
      41          12 : }
      42             : 
      43          11 : Quadratic::Quadratic(Matrix& QQ, Matrix& qq) {
      44          11 :     m_q = &qq;
      45          11 :     m_Q = &QQ;
      46          11 :     m_solver = NULL;
      47          11 :     m_is_Q_eye = false;
      48          11 :     m_is_q_zero = false;
      49          11 : }
      50             : 
      51          83 : Quadratic::~Quadratic() {
      52          30 :     if (m_solver != NULL) {
      53           3 :         delete m_solver;
      54             :     }
      55          53 : }
      56             : 
      57           1 : void Quadratic::setQ(Matrix& Q) {
      58           1 :     m_is_Q_eye = false;
      59           1 :     this->m_Q = &Q;
      60           1 :     this->m_solver = NULL;
      61           1 : }
      62             : 
      63           3 : void Quadratic::setq(Matrix& q) {
      64           3 :     m_is_q_zero = false;
      65           3 :     this->m_q = &q;
      66           3 : }
      67             : 
      68          11 : int Quadratic::call(Matrix& x, double& f) {
      69          11 :     if (!m_is_Q_eye) {
      70          10 :         if (m_is_q_zero) {
      71           7 :             f = m_Q->quad(x);
      72             :         } else {
      73           3 :             f = m_Q->quad(x, *m_q);
      74             :         }
      75             :     } else {
      76           1 :         f = static_cast<Matrix> (x * x)[0];
      77           1 :         if (!m_is_q_zero) {
      78           0 :             f += static_cast<Matrix> ((*m_q) * x)[0];
      79             :         }
      80           1 :         f = f / 2.0;
      81             :     }
      82          11 :     return ForBESUtils::STATUS_OK;
      83             : }
      84             : 
      85       23764 : int Quadratic::call(Matrix& x, double& f, Matrix& grad) {
      86       23764 :     int statusComputeGrad = computeGradient(x, grad); // compute the gradient of f at x (grad)
      87       23764 :     if (statusComputeGrad != ForBESUtils::STATUS_OK) {
      88           0 :         return statusComputeGrad;
      89             :     }
      90             :     // f = (1/2)*(grad+q)'*x
      91       23764 :     f = static_cast<Matrix> ((m_is_q_zero ? grad : grad + (*m_q)) * x)[0] / 2;
      92       23764 :     return ForBESUtils::STATUS_OK;
      93             : }
      94             : 
      95           7 : int Quadratic::hessianProduct(Matrix& x, Matrix& z, Matrix& Hz) {
      96           7 :     Hz = (*m_Q) * z;
      97           7 :     return ForBESUtils::STATUS_OK;
      98             : }
      99             : 
     100           7 : int Quadratic::callConj(Matrix& y, double& f_star) {
     101           7 :     Matrix g;
     102           7 :     int status = callConj(y, f_star, g);
     103           7 :     return status;
     104             : }
     105             : 
     106           9 : int Quadratic::callConj(Matrix& y, double& f_star, Matrix& g) {
     107           9 :     Matrix z = (m_is_q_zero || m_q == NULL) ? y : y - *m_q; // z = y    
     108           9 :     if (m_is_Q_eye || m_Q == NULL) {
     109           1 :         g = z;
     110           1 :         f_star = static_cast<Matrix> (z * z)[0];
     111           1 :         return ForBESUtils::STATUS_OK;
     112             :     }
     113           8 :     if (m_Q != NULL && Matrix::MATRIX_DIAGONAL == m_Q->getType()) {
     114             :         /* Q is diagonal */
     115           1 :         g = z;
     116           1 :         f_star = 0.0;
     117           5 :         for (size_t i = 0; i < z.getNrows(); i++) {
     118           4 :             g[i] /= m_Q->get(i, i);
     119           4 :             f_star += z[i] * g[i];
     120             :         }
     121           1 :         return ForBESUtils::STATUS_OK;
     122             :     }
     123             : 
     124           7 :     if (m_solver == NULL) {
     125           4 :         m_solver = new CholeskyFactorization(*m_Q);
     126           4 :         int status = m_solver->factorize();
     127           4 :         if (0 != status) {
     128           1 :             return ForBESUtils::STATUS_NUMERICAL_PROBLEMS;
     129             :         }
     130             :     }
     131             : 
     132           6 :     m_solver->solve(z, g); // Q*g = z   OR  g = Q \ z
     133           6 :     f_star = static_cast<Matrix> (z * g)[0]; // fstar = z' *g 
     134           6 :     return ForBESUtils::STATUS_OK;
     135             : }
     136             : 
     137       23764 : int Quadratic::computeGradient(Matrix& x, Matrix& grad) {
     138       23764 :     if (m_is_Q_eye) {
     139         352 :         grad = x;
     140             :     } else {
     141       23412 :         grad = (*m_Q) * x;
     142             :     }
     143       23764 :     if (!m_is_q_zero) {
     144       23411 :         grad += *m_q;
     145             :     }
     146       23764 :     return ForBESUtils::STATUS_OK;
     147             : }
     148             : 
     149          14 : FunctionOntologicalClass Quadratic::category() {
     150          14 :     return FunctionOntologyRegistry::quadratic();
     151             : }
     152             : 
     153           2 : int Quadratic::callProx(Matrix& v, double gamma, Matrix& prox) {
     154             : 
     155           2 :     Matrix v_gamma_b = v;
     156           2 :     if (!m_is_q_zero) Matrix::add(v_gamma_b, -gamma, *m_q, 1.0);
     157             : 
     158             : 
     159             : 
     160             :     // (I+gamma Q)^{-1}(v-gamma q)
     161             :     int status;
     162           2 :     if (!m_is_Q_eye) {
     163             :         // If Q is not I, we need to create a CGSolver for (I + gamma Q)
     164           1 :         Matrix Q_tilde(*m_Q);
     165           1 :         size_t n = m_Q->getNrows();
     166           1 :         static Matrix Eye = MatrixFactory::MakeIdentity(n, 1.0);
     167           1 :         Q_tilde *= gamma;
     168           1 :         Q_tilde += Eye;
     169             : 
     170           2 :         MatrixOperator Q_tilde_op(Q_tilde);
     171           2 :         Matrix P(n, n, Matrix::MATRIX_DIAGONAL);
     172           4 :         for (size_t i = 0; i < n; i++) {
     173           3 :             P[i] = 1 / Q_tilde.get(i, i);
     174             :         }
     175           2 :         MatrixOperator P_op(P);
     176           2 :         CGSolver solver(Q_tilde_op, P_op, 1e-6, 1500);
     177           1 :         status = solver.solve(v_gamma_b, prox);
     178           2 :         if (!ForBESUtils::is_status_ok(status)) return status;
     179             :     } else {
     180             :         // Q = I
     181             :         // (v-gamma q)/(1+gamma)
     182           1 :         v_gamma_b *= (1. / (1. + gamma));
     183           1 :         prox = v_gamma_b;
     184           1 :         return ForBESUtils::STATUS_OK;
     185             :     }
     186           1 :     return ForBESUtils::STATUS_OK;
     187             : 
     188          27 : }
     189             : 
     190             : 

Generated by: LCOV version 1.10