Line data Source code
1 : /*
2 : * File: FBCache.cpp
3 : * Author: Lorenzo Stella, Pantelis Sopasakis
4 : *
5 : * Created on October 2, 2015
6 : *
7 : * ForBES is free software: you can redistribute it and/or modify
8 : * it under the terms of the GNU Lesser General Public License as published by
9 : * the Free Software Foundation, either version 3 of the License, or
10 : * (at your option) any later version.
11 : *
12 : * ForBES is distributed in the hope that it will be useful,
13 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 : * GNU Lesser General Public License for more details.
16 : *
17 : * You should have received a copy of the GNU Lesser General Public License
18 : * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
19 : */
20 :
21 : #include "FBCache.h"
22 : #include "LinearOperator.h"
23 :
24 : #include <cmath>
25 : #include <limits>
26 : #include <complex>
27 :
28 : #define FB_CACHE_RELATIVE_TOL 1e-6
29 : #define FB_CACHE_ABSOLUTE_TOL 1e-14
30 :
31 : /**
32 : * Verify whether \c a and \c b are close to one another. Checks whether
33 : * \f[
34 : * |a-b|\leq \max (\epsilon_r \max(a,b), \epsilon_a),
35 : * \f]
36 : * where \f$\epsilon_r = 10^{-6}\f$ is the relative tolerance and \f$\epsilon_a=10^{-14}\f$
37 : * is the absolute tolerance.
38 : *
39 : *
40 : * @param a scalar
41 : * @param b another scalar
42 : * @return \c true iff \c a is approximately equal to \c b.
43 : */
44 189629 : inline bool is_close(const double a, const double b) {
45 379258 : return abs(a - b) <= std::max(
46 379258 : FB_CACHE_RELATIVE_TOL * std::max(std::abs(a), std::abs(b)),
47 758516 : FB_CACHE_ABSOLUTE_TOL);
48 : }
49 :
50 81145 : void FBCache::reset(int status) {
51 81145 : if (status < m_status) m_status = status;
52 81145 : }
53 :
54 26900 : void FBCache::reset() {
55 26900 : FBCache::reset(STATUS_NONE);
56 26900 : }
57 :
58 906 : FBCache::FBCache(FBProblem & p, Matrix & x, double gamma) :
59 : m_prob(p),
60 : m_x(&x),
61 906 : m_gamma(gamma) {
62 906 : reset(STATUS_NONE);
63 :
64 : // get dimensions of things
65 906 : size_t m_x_rows = m_x->getNrows();
66 906 : size_t m_x_cols = m_x->getNcols();
67 : size_t m_res1_rows, m_res1_cols;
68 : size_t m_res2_rows, m_res2_cols;
69 :
70 906 : if (m_prob.d1() != NULL) {
71 302 : m_res1_rows = m_prob.d1()->getNrows();
72 302 : m_res1_cols = m_prob.d1()->getNcols();
73 604 : } else if (m_prob.L1() != NULL) {
74 0 : m_res1_rows = m_prob.L1()->dimensionOut().first;
75 0 : m_res1_cols = m_prob.L1()->dimensionOut().second;
76 : } else {
77 604 : m_res1_rows = m_x_rows;
78 604 : m_res1_cols = m_x_cols;
79 : }
80 906 : if (m_prob.d2() != NULL) {
81 201 : m_res2_rows = m_prob.d2()->getNrows();
82 201 : m_res2_cols = m_prob.d2()->getNcols();
83 705 : } else if (m_prob.L2() != NULL) {
84 0 : m_res2_rows = m_prob.L2()->dimensionOut().first;
85 0 : m_res2_cols = m_prob.L2()->dimensionOut().second;
86 : } else {
87 705 : m_res2_rows = m_x_rows;
88 705 : m_res2_cols = m_x_cols;
89 : }
90 :
91 : // allocate memory for residuals and gradients (where needed)
92 906 : if (m_prob.f1() != NULL) {
93 704 : m_res1x = new Matrix(m_res1_rows, m_res1_cols);
94 704 : m_gradf1x = new Matrix(m_res1_rows, m_res1_cols);
95 : } else {
96 202 : m_res1x = NULL;
97 202 : m_gradf1x = NULL;
98 : }
99 906 : if (m_prob.f2() != NULL) {
100 202 : m_res2x = new Matrix(m_res2_rows, m_res2_cols);
101 202 : m_gradf2x = new Matrix(m_res2_rows, m_res2_cols);
102 : } else {
103 704 : m_res2x = NULL;
104 704 : m_gradf2x = NULL;
105 : }
106 :
107 906 : m_gradfx = new Matrix(m_x_rows, m_x_cols);
108 906 : m_z = new Matrix(m_x_rows, m_x_cols);
109 906 : m_y = new Matrix(m_x_rows, m_x_cols);
110 906 : m_FPRx = new Matrix(m_x_rows, m_x_cols);
111 906 : m_gradFBEx = new Matrix(m_x_rows, m_x_cols);
112 :
113 906 : m_FBEx = std::numeric_limits<double>::infinity();
114 906 : m_sqnormFPRx = std::numeric_limits<double>::infinity();
115 :
116 906 : m_f1x = 0.0;
117 906 : m_f2x = 0.0;
118 906 : m_linx = 0.0;
119 906 : m_fx = 0.0;
120 906 : m_gz = 0.0;
121 :
122 906 : m_cached_grad_f2 = false;
123 906 : }
124 :
125 81132 : int FBCache::update_eval_f(bool order_grad_f2) {
126 :
127 81132 : if (!m_cached_grad_f2 && order_grad_f2) {
128 : // If gradf2x has not been computed previously, but now should be
129 : // computed, then set the status to STATUS_NONE, so that all
130 : // computations are performed from the beginning (both f2x and gradf2x).
131 81126 : m_status = STATUS_NONE;
132 : }
133 :
134 81132 : if (m_status >= STATUS_EVALF) {
135 0 : return ForBESUtils::STATUS_OK;
136 : }
137 :
138 81132 : if (m_prob.f1() != NULL) {
139 61530 : if (m_prob.L1() != NULL) {
140 38127 : *m_res1x = m_prob.L1()->call(*m_x);
141 : } else {
142 23403 : *m_res1x = *m_x;
143 : }
144 61530 : if (m_prob.d1() != NULL) {
145 38127 : *m_res1x += *(m_prob.d1());
146 : }
147 61530 : int status = m_prob.f1()->call(*m_res1x, m_f1x, *m_gradf1x);
148 61530 : if (ForBESUtils::STATUS_OK != status) {
149 0 : return status;
150 : }
151 : }
152 :
153 81132 : if (m_prob.f2() != NULL) {
154 19602 : if (m_prob.L2() != NULL) {
155 19601 : *m_res2x = m_prob.L2()->call(*m_x);
156 : } else {
157 1 : *m_res2x = *m_x;
158 : }
159 19602 : if (m_prob.d2() != NULL) {
160 19601 : *m_res2x += *(m_prob.d2());
161 : }
162 : int status =
163 : order_grad_f2
164 19600 : ? m_prob.f2()->call(*m_res2x, m_f2x, *m_gradf2x)
165 39202 : : m_prob.f2()->call(*m_res2x, m_f2x);
166 19602 : m_cached_grad_f2 = order_grad_f2;
167 19602 : if (ForBESUtils::STATUS_OK != status) {
168 0 : return status;
169 : }
170 : }
171 :
172 81132 : if (m_prob.lin() != NULL) {
173 0 : m_linx = ((*m_prob.lin()) * (*m_x))[0];
174 : }
175 :
176 81132 : m_fx = m_f1x + m_f2x + m_linx;
177 81132 : m_status = STATUS_EVALF;
178 :
179 81132 : return ForBESUtils::STATUS_OK;
180 : }
181 :
182 81140 : int FBCache::update_forward_step(double gamma) {
183 81140 : bool is_gamma_the_same = is_close(gamma, m_gamma);
184 81140 : if (!is_gamma_the_same) reset(STATUS_EVALF);
185 :
186 :
187 81140 : if (m_status >= STATUS_FORWARD) {
188 1 : if (is_gamma_the_same) return ForBESUtils::STATUS_OK;
189 0 : *m_y = *m_x;
190 0 : Matrix::add(*m_y, -gamma, *m_gradfx, 1.0);
191 0 : m_gamma = gamma;
192 0 : return ForBESUtils::STATUS_OK;
193 : }
194 :
195 : int status;
196 :
197 81139 : if (m_status < STATUS_EVALF) {
198 81126 : m_cached_grad_f2 = false;
199 81126 : status = update_eval_f(true);
200 81126 : if (!ForBESUtils::is_status_ok(status)) return status;
201 : }
202 :
203 81139 : if (m_prob.f1() != NULL) {
204 61535 : if (m_prob.L1()) {
205 38129 : Matrix d_gradfx = m_prob.L1()->callAdjoint(*m_gradf1x);
206 38129 : *m_gradfx = d_gradfx;
207 : } else {
208 23406 : *m_gradfx = *m_gradf1x;
209 : }
210 : }
211 :
212 81139 : if (m_prob.f2() != NULL) {
213 19604 : if (!m_cached_grad_f2) {
214 2 : status = m_prob.f2()->call(*m_res2x, m_f2x, *m_gradf2x);
215 2 : if (!ForBESUtils::is_status_ok(status)) return status;
216 : // now gradf2x has been computed:
217 2 : m_cached_grad_f2 = true;
218 : }
219 19604 : if (m_prob.L2() != NULL) {
220 19602 : Matrix d_gradfx = m_prob.L2()->callAdjoint(*m_gradf2x);
221 19602 : if (m_prob.f1() != NULL) *m_gradfx += d_gradfx;
222 19602 : else *m_gradfx = d_gradfx;
223 : } else {
224 2 : if (m_prob.f1() != NULL) *m_gradfx += *m_gradf2x;
225 2 : else *m_gradfx = *m_gradf2x;
226 : }
227 : }
228 :
229 81139 : if (m_prob.lin()) {
230 0 : if (m_prob.f1() != NULL || m_prob.f2() != NULL) {
231 0 : *m_gradfx += (*m_prob.lin());
232 : } else {
233 0 : *m_gradfx = *m_prob.lin();
234 : }
235 : }
236 :
237 81139 : *m_y = *m_x;
238 81139 : Matrix::add(*m_y, -gamma, *m_gradfx, 1.0);
239 :
240 81139 : m_gamma = gamma;
241 81139 : m_status = STATUS_FORWARD;
242 :
243 81139 : return ForBESUtils::STATUS_OK;
244 : }
245 :
246 108465 : int FBCache::update_forward_backward_step(double gamma) {
247 : int status;
248 108465 : if (!is_close(gamma, m_gamma)) {
249 0 : reset(STATUS_EVALF);
250 : }
251 108465 : if (m_status >= STATUS_FORWARDBACKWARD) {
252 27327 : return ForBESUtils::STATUS_OK;
253 : }
254 81138 : if (m_status < STATUS_FORWARD) {
255 81126 : status = update_forward_step(gamma);
256 81126 : if (!ForBESUtils::is_status_ok(status)) {
257 0 : return status;
258 : }
259 : }
260 :
261 81138 : status = m_prob.g()->callProx(*m_y, gamma, *m_z, m_gz);
262 81138 : if (!ForBESUtils::is_status_ok(status)) {
263 0 : return status;
264 : }
265 81138 : *m_FPRx = (*m_x - *m_z);
266 81138 : m_sqnormFPRx = std::pow(m_FPRx->norm_fro_sq(), 2);
267 81138 : m_gamma = gamma;
268 81138 : m_status = STATUS_FORWARDBACKWARD;
269 :
270 81138 : return ForBESUtils::STATUS_OK;
271 : }
272 :
273 12 : int FBCache::update_eval_FBE(double gamma) {
274 12 : if (!is_close(gamma, m_gamma)) {
275 0 : reset(STATUS_EVALF);
276 : }
277 :
278 12 : if (m_status >= STATUS_FBE) {
279 0 : return ForBESUtils::STATUS_OK;
280 : }
281 :
282 12 : if (m_status < STATUS_FORWARDBACKWARD) {
283 0 : int status = update_forward_backward_step(gamma);
284 0 : if (!ForBESUtils::is_status_ok(status)) {
285 0 : return status;
286 : }
287 : }
288 :
289 12 : Matrix innprox_mat = (*m_FPRx) * (*m_gradfx);
290 12 : double innprod = innprox_mat[0];
291 :
292 12 : m_FBEx = m_fx + m_gz - innprod + 0.5 / m_gamma*m_sqnormFPRx;
293 12 : m_gamma = gamma;
294 12 : m_status = STATUS_FBE;
295 :
296 12 : return ForBESUtils::STATUS_OK;
297 : }
298 :
299 12 : int FBCache::update_grad_FBE(double gamma) {
300 12 : if (!is_close(gamma, m_gamma)) {
301 0 : reset(STATUS_EVALF);
302 : }
303 :
304 12 : if (m_status >= STATUS_GRAD_FBE) {
305 0 : return ForBESUtils::STATUS_OK;
306 : }
307 :
308 12 : if (m_status < STATUS_FORWARDBACKWARD) {
309 0 : int status = update_forward_backward_step(gamma);
310 0 : if (!ForBESUtils::is_status_ok(status)) {
311 0 : return status;
312 : }
313 : }
314 :
315 12 : *m_gradFBEx = *m_FPRx;
316 :
317 12 : if (m_prob.f1() != NULL) {
318 8 : if (m_prob.L1() != NULL) {
319 4 : Matrix v1(m_prob.L1()->dimensionOut());
320 4 : v1 = m_prob.L1()->call(*m_FPRx);
321 8 : Matrix v2(m_prob.L1()->dimensionOut());
322 4 : m_prob.f1()->hessianProduct(*m_res1x, v1, v2);
323 8 : Matrix v3 = m_prob.L1()->callAdjoint(v2);
324 8 : Matrix::add(*m_gradFBEx, -1.0, v3, 1.0 / gamma);
325 : } else {
326 4 : Matrix v1(m_x->getNrows(), m_x->getNcols());
327 4 : m_prob.f1()->hessianProduct(*m_x, *m_FPRx, v1);
328 4 : Matrix::add(*m_gradFBEx, -1.0, v1, 1.0 / gamma);
329 : }
330 : }
331 :
332 12 : if (m_prob.f2() != NULL) {
333 4 : if (m_prob.L2() != NULL) {
334 2 : Matrix v1(m_prob.L2()->dimensionOut());
335 2 : v1 = m_prob.L2()->call(*m_FPRx);
336 4 : Matrix v2(m_prob.L2()->dimensionOut());
337 2 : m_prob.f2()->hessianProduct(*m_res2x, v1, v2);
338 4 : Matrix v3 = m_prob.L2()->callAdjoint(v2);
339 2 : if (m_prob.f1() != NULL) Matrix::add(*m_gradFBEx, -1.0, v3, 1.0);
340 4 : else Matrix::add(*m_gradFBEx, -1.0, v3, 1.0 / gamma);
341 : } else {
342 2 : Matrix v1(m_x->getNrows(), m_x->getNcols());
343 2 : m_prob.f2()->hessianProduct(*m_x, *m_FPRx, v1);
344 2 : if (m_prob.f1() != NULL) Matrix::add(*m_gradFBEx, -1.0, v1, 1.0);
345 2 : else Matrix::add(*m_gradFBEx, -1.0, v1, 1.0 / gamma);
346 : }
347 : }
348 :
349 12 : m_gamma = gamma;
350 12 : m_status = STATUS_GRAD_FBE;
351 :
352 12 : return ForBESUtils::STATUS_OK;
353 : }
354 :
355 53326 : void FBCache::set_point(Matrix& x) {
356 53326 : *m_x = x;
357 53326 : reset(STATUS_NONE);
358 53326 : }
359 :
360 108525 : Matrix * FBCache::get_point() {
361 108525 : return m_x;
362 : }
363 :
364 12 : double FBCache::get_eval_FBE(double gamma) {
365 12 : update_eval_FBE(gamma);
366 12 : return m_FBEx;
367 : }
368 :
369 12 : Matrix * FBCache::get_grad_FBE(double gamma) {
370 12 : update_grad_FBE(gamma);
371 12 : return m_gradFBEx;
372 : }
373 :
374 6 : double FBCache::get_eval_f() {
375 6 : update_eval_f(false);
376 6 : return m_fx;
377 : }
378 :
379 14 : Matrix* FBCache::get_forward_step(double gamma) {
380 14 : update_forward_step(gamma);
381 14 : return m_y;
382 : }
383 :
384 54238 : Matrix* FBCache::get_forward_backward_step(double gamma) {
385 54238 : update_forward_backward_step(gamma);
386 54238 : return m_z;
387 : }
388 :
389 1 : Matrix* FBCache::get_fpr() {
390 1 : update_forward_backward_step(m_gamma);
391 1 : return m_FPRx;
392 : }
393 :
394 54226 : double FBCache::get_norm_fpr() {
395 54226 : update_forward_backward_step(m_gamma);
396 54226 : return sqrt(m_sqnormFPRx);
397 : }
398 :
399 910 : FBCache::~FBCache() {
400 905 : if (m_z != NULL) {
401 905 : delete m_z;
402 905 : m_z = NULL;
403 : }
404 905 : if (m_y != NULL) {
405 905 : delete m_y;
406 905 : m_y = NULL;
407 : }
408 905 : if (m_res2x != NULL) {
409 202 : delete m_res2x;
410 202 : m_res2x = NULL;
411 : }
412 905 : if (m_gradf2x != NULL) {
413 202 : delete m_gradf2x;
414 202 : m_gradf2x = NULL;
415 : }
416 905 : if (m_res1x != NULL) {
417 703 : delete m_res1x;
418 703 : m_res1x = NULL;
419 : }
420 905 : if (m_gradf1x != NULL) {
421 703 : delete m_gradf1x;
422 703 : m_gradf1x = NULL;
423 : }
424 905 : if (m_gradfx != NULL) {
425 905 : delete m_gradfx;
426 905 : m_gradfx = NULL;
427 : }
428 905 : if (m_FPRx != NULL) {
429 905 : delete m_FPRx;
430 905 : m_FPRx = NULL;
431 : }
432 905 : if (m_gradFBEx != NULL) {
433 905 : delete m_gradFBEx;
434 905 : m_gradFBEx = NULL;
435 : }
436 922 : }
|