LCOV - code coverage report
Current view: top level - source/tests - TestConjugateFunction.cpp (source / functions) Hit Total Coverage
Test: LibForBES Unit Tests Lines: 144 144 100.0 %
Date: 2016-04-18 Functions: 14 14 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*
       2             :  * File:   TestConjugateFunction.cpp
       3             :  * Author: Pantelis Sopasakis
       4             :  *
       5             :  * Created on Nov 7, 2015, 4:10:56 PM
       6             :  * 
       7             :  * ForBES is free software: you can redistribute it and/or modify
       8             :  * it under the terms of the GNU Lesser General Public License as published by
       9             :  * the Free Software Foundation, either version 3 of the License, or
      10             :  * (at your option) any later version.
      11             :  *  
      12             :  * ForBES is distributed in the hope that it will be useful,
      13             :  * but WITHOUT ANY WARRANTY; without even the implied warranty of
      14             :  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
      15             :  * GNU Lesser General Public License for more details.
      16             :  * 
      17             :  * You should have received a copy of the GNU Lesser General Public License
      18             :  * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
      19             :  */
      20             : 
      21             : #include "TestConjugateFunction.h"
      22             : #include "ConjugateFunction.h"
      23             : #include <cmath>
      24             : 
      25             : #define FORBES_TEST_UTILS
      26             : #include "ForBES.h"
      27             : 
// Register this fixture with CppUnit's default test-factory registry so the
// test runner picks up every test method declared for TestConjugateFunction.
CPPUNIT_TEST_SUITE_REGISTRATION(TestConjugateFunction);
      29             : 
      30           7 : TestConjugateFunction::TestConjugateFunction() {
      31           7 : }
      32             : 
      33          14 : TestConjugateFunction::~TestConjugateFunction() {
      34          14 : }
      35             : 
      36           7 : void TestConjugateFunction::setUp() {
      37           7 : }
      38             : 
      39           7 : void TestConjugateFunction::tearDown() {
      40           7 : }
      41             : 
      42           1 : void TestConjugateFunction::testCall() {
      43           1 :     Function *f = new QuadraticLoss(1.55628554);
      44           1 :     Function *f_conj = new ConjugateFunction(*f);
      45             : 
      46           1 :     Matrix x = MatrixFactory::MakeRandomMatrix(10, 1, 3.0, 5.0);
      47             : 
      48             :     double f_star;
      49             :     double f_conj_val;
      50             : 
      51             :     int status;
      52             : 
      53           1 :     _ASSERT(f->category().defines_f());
      54             : 
      55           1 :     status = f->callConj(x, f_star);
      56           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
      57             : 
      58           1 :     status = f_conj->call(x, f_conj_val);
      59           1 :     _ASSERT(f_conj->category().defines_f());
      60           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
      61             : 
      62           1 :     _ASSERT_NUM_EQ(f_star, f_conj_val, 1e-8);
      63             : 
      64           1 :     delete f;
      65           1 :     delete f_conj;
      66           1 : }
      67             : 
      68           1 : void TestConjugateFunction::testCall2() {
      69             : 
      70           1 :     int n = 8;
      71           1 :     int s = 4;
      72           1 :     Matrix Q = MatrixFactory::MakeRandomMatrix(n, n, 0.0, 1.0, Matrix::MATRIX_DENSE);
      73           2 :     Matrix A = MatrixFactory::MakeRandomMatrix(s, n, 0.0, -5.0, Matrix::MATRIX_DENSE);
      74             : 
      75           2 :     Matrix q = MatrixFactory::MakeRandomMatrix(n, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
      76           2 :     Matrix b = MatrixFactory::MakeRandomMatrix(s, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
      77             : 
      78             :     QuadOverAffine * qoa;
      79           1 :     qoa = new QuadOverAffine(Q, q, A, b);
      80           1 :     _ASSERT_NEQ(NULL, qoa);
      81             : 
      82           2 :     Matrix y = MatrixFactory::MakeRandomMatrix(n, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
      83           1 :     double fstar = 0.0;
      84           2 :     Matrix grad;
      85           1 :     int status = qoa->callConj(y, fstar, grad);
      86           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
      87           1 :     _ASSERT_NOT(std::abs(fstar) < 1e-7);
      88             : 
      89           1 :     Function * f_conjugate = new ConjugateFunction(*qoa);
      90           1 :     double fstar2 = 0.0;
      91           2 :     Matrix grad2;
      92           1 :     _ASSERT(f_conjugate->category().defines_grad());
      93           1 :     f_conjugate -> call(y, fstar2, grad2);
      94             : 
      95           1 :     _ASSERT_NUM_EQ(fstar, fstar2, 1e-8);
      96           1 :     _ASSERT_EQ(grad, grad2);
      97             : 
      98           1 :     _ASSERT_OK(delete qoa);
      99           2 :     _ASSERT_OK(delete f_conjugate);
     100             : 
     101             : 
     102           1 : }
     103             : 
     104           1 : void TestConjugateFunction::testCallConj() {
     105           1 :     size_t n = 10;
     106           1 :     size_t nnz_Q = 20;
     107           1 :     Matrix Qsp = MatrixFactory::MakeRandomSparse(n, n, nnz_Q, 0.0, 1.0);
     108             : 
     109           1 :     Function *F = new Quadratic(Qsp);
     110             : 
     111           2 :     Matrix x = MatrixFactory::MakeRandomMatrix(n, 1, 3.0, 1.5, Matrix::MATRIX_DENSE);
     112           1 :     double fval = -1.0;
     113           1 :     double fval2 = -1.0;
     114           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, F->call(x, fval));
     115           1 :     _ASSERT(fval > 0);
     116             : 
     117           1 :     double f_exp = Qsp.quad(x);
     118           1 :     const double tol = 1e-10;
     119           1 :     _ASSERT_NUM_EQ(f_exp, fval, tol);
     120             : 
     121             : 
     122           1 :     Function * F_conj = new ConjugateFunction(*F);
     123           1 :     _ASSERT(F_conj->category().defines_conjugate());
     124           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, F_conj->callConj(x, fval2));
     125           1 :     _ASSERT(fval2 > 0);
     126           1 :     _ASSERT_NUM_EQ(fval, fval2, tol);
     127             : 
     128           1 :     _ASSERT_OK(delete F);
     129           2 :     _ASSERT_OK(delete F_conj);
     130           1 : }
     131             : 
     132           1 : void TestConjugateFunction::testCallConj2() {
     133           1 :     const size_t n = 10;
     134           1 :     Matrix x = MatrixFactory::MakeRandomMatrix(n, 1, 0.5, 2.0);
     135           1 :     const double delta = 0.2;
     136           1 :     Function * huber = new HuberLoss(delta);
     137             : 
     138             :     double f;
     139             :     double f2;
     140           2 :     Matrix grad(n, 1);
     141           2 :     Matrix grad2(n, 1);
     142             : 
     143           1 :     _ASSERT(huber->category().defines_f());
     144           1 :     int status = huber->call(x, f, grad);
     145           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
     146             : 
     147           1 :     Function * huber_conj = new ConjugateFunction(*huber);
     148           1 :     _ASSERT(huber_conj->category().defines_conjugate());
     149           1 :     _ASSERT(huber_conj->category().defines_conjugate_grad());
     150           1 :     status = huber_conj -> callConj(x, f2, grad2);
     151           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
     152             : 
     153           1 :     const double tol = 1e-9;
     154           2 :     _ASSERT_NUM_EQ(f, f2, tol);
     155           1 : }
     156             : 
     157           1 : void TestConjugateFunction::testCallProx() {
     158           1 :     Function * elastic = new ElasticNet(2.5, 1.3);
     159           1 :     Function * elastic_conj = new ConjugateFunction(*elastic);
     160             : 
     161           1 :     const size_t n = 9;
     162           1 :     double xdata[n] = {-1.0, -3.0, 7.5, 2.0, -1.0, -1.0, 5.0, 2.0, -5.0};
     163           1 :     const double gamma = 1.6;
     164           1 :     const double prox_expected_data[n] = {0.0, -0.1840, 1.0840, 0.0, 0.0, 0.0, 0.5840, 0.0, -0.5840};
     165             : 
     166           1 :     Matrix x(n, 1, xdata);
     167           2 :     Matrix prox_expected(n, 1, prox_expected_data);
     168           2 :     Matrix prox(n, 1);
     169             : 
     170             :     double f_at_prox;
     171           1 :     const double f_at_prox_expected = 5.5305800;
     172           1 :     const double tol = 1e-12;
     173           1 :     _ASSERT(elastic->category().defines_prox());
     174           1 :     int status = elastic->callProx(x, gamma, prox, f_at_prox);
     175           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
     176           1 :     _ASSERT_NUM_EQ(f_at_prox_expected, f_at_prox, tol);
     177           1 :     _ASSERT_EQ(prox_expected, prox);
     178             : 
     179           1 :     status = elastic->callProx(x, gamma, prox);
     180           1 :     _ASSERT_EQ(prox_expected, prox);
     181           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
     182             : 
     183           2 :     Matrix prox_conj(n, 1);
     184           1 :     _ASSERT(elastic_conj->category().defines_prox());
     185           1 :     status = elastic_conj -> callProx(x, gamma, prox_conj);
     186           1 :     _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
     187             : 
     188             : 
     189             : 
     190           1 :     delete elastic;
     191           2 :     delete elastic_conj;
     192           1 : }
     193             : 
     194           1 : void TestConjugateFunction::testCategory() {
     195           1 :     int n = 8;
     196           1 :     int s = 4;
     197           1 :     Matrix Q = MatrixFactory::MakeRandomMatrix(n, n, 0.0, 1.0, Matrix::MATRIX_DENSE);
     198           2 :     Matrix A = MatrixFactory::MakeRandomMatrix(s, n, 0.0, -5.0, Matrix::MATRIX_DENSE);
     199             : 
     200           2 :     Matrix q = MatrixFactory::MakeRandomMatrix(n, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
     201           2 :     Matrix b = MatrixFactory::MakeRandomMatrix(s, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
     202             : 
     203           1 :     QuadOverAffine * qoa = new QuadOverAffine(Q, q, A, b);
     204           1 :     ConjugateFunction * qoa_conj = new ConjugateFunction(*qoa);
     205             : 
     206             : 
     207           2 :     FunctionOntologicalClass meta = qoa_conj->category();
     208           1 :     _ASSERT(meta.defines_f());
     209           1 :     _ASSERT(meta.defines_grad());
     210           1 :     _ASSERT_NOT(meta.defines_prox());
     211             : 
     212           1 :     delete qoa;
     213           2 :     delete qoa_conj;
     214           1 : }
     215             : 
     216           1 : void TestConjugateFunction::testCategory2() {
     217           1 :     Quadratic f;
     218           2 :     ConjugateFunction f_star(f);
     219           1 :     _ASSERT(f_star.category().is_conjugate_quadratic());
     220             : 
     221             :     // and now the converse...
     222           1 :     int n = 8;
     223           1 :     int s = 4;
     224           2 :     Matrix Q = MatrixFactory::MakeRandomMatrix(n, n, 0.0, 1.0, Matrix::MATRIX_DENSE);
     225           2 :     Matrix A = MatrixFactory::MakeRandomMatrix(s, n, 0.0, -5.0, Matrix::MATRIX_DENSE);
     226             : 
     227           2 :     Matrix q = MatrixFactory::MakeRandomMatrix(n, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
     228           2 :     Matrix b = MatrixFactory::MakeRandomMatrix(s, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
     229             : 
     230           2 :     QuadOverAffine g(Q, q, A, b);
     231           2 :     ConjugateFunction g_star(g);
     232           2 :     _ASSERT(g_star.category().is_quadratic());
     233           4 : }
     234             : 

Generated by: LCOV version 1.10