Line data Source code
1 : /*
2 : * File: TestConjugateFunction.cpp
3 : * Author: Pantelis Sopasakis
4 : *
5 : * Created on Nov 7, 2015, 4:10:56 PM
6 : *
7 : * ForBES is free software: you can redistribute it and/or modify
8 : * it under the terms of the GNU Lesser General Public License as published by
9 : * the Free Software Foundation, either version 3 of the License, or
10 : * (at your option) any later version.
11 : *
12 : * ForBES is distributed in the hope that it will be useful,
13 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 : * GNU Lesser General Public License for more details.
16 : *
17 : * You should have received a copy of the GNU Lesser General Public License
18 : * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
19 : */
20 :
21 : #include "TestConjugateFunction.h"
22 : #include "ConjugateFunction.h"
23 : #include <cmath>
24 :
25 : #define FORBES_TEST_UTILS
26 : #include "ForBES.h"
27 :
28 1 : CPPUNIT_TEST_SUITE_REGISTRATION(TestConjugateFunction);
29 :
30 7 : TestConjugateFunction::TestConjugateFunction() {
31 7 : }
32 :
33 14 : TestConjugateFunction::~TestConjugateFunction() {
34 14 : }
35 :
36 7 : void TestConjugateFunction::setUp() {
37 7 : }
38 :
39 7 : void TestConjugateFunction::tearDown() {
40 7 : }
41 :
42 1 : void TestConjugateFunction::testCall() {
43 1 : Function *f = new QuadraticLoss(1.55628554);
44 1 : Function *f_conj = new ConjugateFunction(*f);
45 :
46 1 : Matrix x = MatrixFactory::MakeRandomMatrix(10, 1, 3.0, 5.0);
47 :
48 : double f_star;
49 : double f_conj_val;
50 :
51 : int status;
52 :
53 1 : _ASSERT(f->category().defines_f());
54 :
55 1 : status = f->callConj(x, f_star);
56 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
57 :
58 1 : status = f_conj->call(x, f_conj_val);
59 1 : _ASSERT(f_conj->category().defines_f());
60 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
61 :
62 1 : _ASSERT_NUM_EQ(f_star, f_conj_val, 1e-8);
63 :
64 1 : delete f;
65 1 : delete f_conj;
66 1 : }
67 :
68 1 : void TestConjugateFunction::testCall2() {
69 :
70 1 : int n = 8;
71 1 : int s = 4;
72 1 : Matrix Q = MatrixFactory::MakeRandomMatrix(n, n, 0.0, 1.0, Matrix::MATRIX_DENSE);
73 2 : Matrix A = MatrixFactory::MakeRandomMatrix(s, n, 0.0, -5.0, Matrix::MATRIX_DENSE);
74 :
75 2 : Matrix q = MatrixFactory::MakeRandomMatrix(n, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
76 2 : Matrix b = MatrixFactory::MakeRandomMatrix(s, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
77 :
78 : QuadOverAffine * qoa;
79 1 : qoa = new QuadOverAffine(Q, q, A, b);
80 1 : _ASSERT_NEQ(NULL, qoa);
81 :
82 2 : Matrix y = MatrixFactory::MakeRandomMatrix(n, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
83 1 : double fstar = 0.0;
84 2 : Matrix grad;
85 1 : int status = qoa->callConj(y, fstar, grad);
86 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
87 1 : _ASSERT_NOT(std::abs(fstar) < 1e-7);
88 :
89 1 : Function * f_conjugate = new ConjugateFunction(*qoa);
90 1 : double fstar2 = 0.0;
91 2 : Matrix grad2;
92 1 : _ASSERT(f_conjugate->category().defines_grad());
93 1 : f_conjugate -> call(y, fstar2, grad2);
94 :
95 1 : _ASSERT_NUM_EQ(fstar, fstar2, 1e-8);
96 1 : _ASSERT_EQ(grad, grad2);
97 :
98 1 : _ASSERT_OK(delete qoa);
99 2 : _ASSERT_OK(delete f_conjugate);
100 :
101 :
102 1 : }
103 :
104 1 : void TestConjugateFunction::testCallConj() {
105 1 : size_t n = 10;
106 1 : size_t nnz_Q = 20;
107 1 : Matrix Qsp = MatrixFactory::MakeRandomSparse(n, n, nnz_Q, 0.0, 1.0);
108 :
109 1 : Function *F = new Quadratic(Qsp);
110 :
111 2 : Matrix x = MatrixFactory::MakeRandomMatrix(n, 1, 3.0, 1.5, Matrix::MATRIX_DENSE);
112 1 : double fval = -1.0;
113 1 : double fval2 = -1.0;
114 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, F->call(x, fval));
115 1 : _ASSERT(fval > 0);
116 :
117 1 : double f_exp = Qsp.quad(x);
118 1 : const double tol = 1e-10;
119 1 : _ASSERT_NUM_EQ(f_exp, fval, tol);
120 :
121 :
122 1 : Function * F_conj = new ConjugateFunction(*F);
123 1 : _ASSERT(F_conj->category().defines_conjugate());
124 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, F_conj->callConj(x, fval2));
125 1 : _ASSERT(fval2 > 0);
126 1 : _ASSERT_NUM_EQ(fval, fval2, tol);
127 :
128 1 : _ASSERT_OK(delete F);
129 2 : _ASSERT_OK(delete F_conj);
130 1 : }
131 :
132 1 : void TestConjugateFunction::testCallConj2() {
133 1 : const size_t n = 10;
134 1 : Matrix x = MatrixFactory::MakeRandomMatrix(n, 1, 0.5, 2.0);
135 1 : const double delta = 0.2;
136 1 : Function * huber = new HuberLoss(delta);
137 :
138 : double f;
139 : double f2;
140 2 : Matrix grad(n, 1);
141 2 : Matrix grad2(n, 1);
142 :
143 1 : _ASSERT(huber->category().defines_f());
144 1 : int status = huber->call(x, f, grad);
145 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
146 :
147 1 : Function * huber_conj = new ConjugateFunction(*huber);
148 1 : _ASSERT(huber_conj->category().defines_conjugate());
149 1 : _ASSERT(huber_conj->category().defines_conjugate_grad());
150 1 : status = huber_conj -> callConj(x, f2, grad2);
151 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
152 :
153 1 : const double tol = 1e-9;
154 2 : _ASSERT_NUM_EQ(f, f2, tol);
155 1 : }
156 :
157 1 : void TestConjugateFunction::testCallProx() {
158 1 : Function * elastic = new ElasticNet(2.5, 1.3);
159 1 : Function * elastic_conj = new ConjugateFunction(*elastic);
160 :
161 1 : const size_t n = 9;
162 1 : double xdata[n] = {-1.0, -3.0, 7.5, 2.0, -1.0, -1.0, 5.0, 2.0, -5.0};
163 1 : const double gamma = 1.6;
164 1 : const double prox_expected_data[n] = {0.0, -0.1840, 1.0840, 0.0, 0.0, 0.0, 0.5840, 0.0, -0.5840};
165 :
166 1 : Matrix x(n, 1, xdata);
167 2 : Matrix prox_expected(n, 1, prox_expected_data);
168 2 : Matrix prox(n, 1);
169 :
170 : double f_at_prox;
171 1 : const double f_at_prox_expected = 5.5305800;
172 1 : const double tol = 1e-12;
173 1 : _ASSERT(elastic->category().defines_prox());
174 1 : int status = elastic->callProx(x, gamma, prox, f_at_prox);
175 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
176 1 : _ASSERT_NUM_EQ(f_at_prox_expected, f_at_prox, tol);
177 1 : _ASSERT_EQ(prox_expected, prox);
178 :
179 1 : status = elastic->callProx(x, gamma, prox);
180 1 : _ASSERT_EQ(prox_expected, prox);
181 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
182 :
183 2 : Matrix prox_conj(n, 1);
184 1 : _ASSERT(elastic_conj->category().defines_prox());
185 1 : status = elastic_conj -> callProx(x, gamma, prox_conj);
186 1 : _ASSERT_EQ(ForBESUtils::STATUS_OK, status);
187 :
188 :
189 :
190 1 : delete elastic;
191 2 : delete elastic_conj;
192 1 : }
193 :
194 1 : void TestConjugateFunction::testCategory() {
195 1 : int n = 8;
196 1 : int s = 4;
197 1 : Matrix Q = MatrixFactory::MakeRandomMatrix(n, n, 0.0, 1.0, Matrix::MATRIX_DENSE);
198 2 : Matrix A = MatrixFactory::MakeRandomMatrix(s, n, 0.0, -5.0, Matrix::MATRIX_DENSE);
199 :
200 2 : Matrix q = MatrixFactory::MakeRandomMatrix(n, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
201 2 : Matrix b = MatrixFactory::MakeRandomMatrix(s, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
202 :
203 1 : QuadOverAffine * qoa = new QuadOverAffine(Q, q, A, b);
204 1 : ConjugateFunction * qoa_conj = new ConjugateFunction(*qoa);
205 :
206 :
207 2 : FunctionOntologicalClass meta = qoa_conj->category();
208 1 : _ASSERT(meta.defines_f());
209 1 : _ASSERT(meta.defines_grad());
210 1 : _ASSERT_NOT(meta.defines_prox());
211 :
212 1 : delete qoa;
213 2 : delete qoa_conj;
214 1 : }
215 :
216 1 : void TestConjugateFunction::testCategory2() {
217 1 : Quadratic f;
218 2 : ConjugateFunction f_star(f);
219 1 : _ASSERT(f_star.category().is_conjugate_quadratic());
220 :
221 : // and now the converse...
222 1 : int n = 8;
223 1 : int s = 4;
224 2 : Matrix Q = MatrixFactory::MakeRandomMatrix(n, n, 0.0, 1.0, Matrix::MATRIX_DENSE);
225 2 : Matrix A = MatrixFactory::MakeRandomMatrix(s, n, 0.0, -5.0, Matrix::MATRIX_DENSE);
226 :
227 2 : Matrix q = MatrixFactory::MakeRandomMatrix(n, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
228 2 : Matrix b = MatrixFactory::MakeRandomMatrix(s, 1, 0.0, 1.0, Matrix::MATRIX_DENSE);
229 :
230 2 : QuadOverAffine g(Q, q, A, b);
231 2 : ConjugateFunction g_star(g);
232 2 : _ASSERT(g_star.category().is_quadratic());
233 4 : }
234 :
|