Line data Source code
1 : /*
2 : * File: Quadratic.cpp
3 : * Author: Pantelis Sopasakis
4 : *
5 : * Created on July 9, 2015, 3:36 AM
6 : *
7 : * ForBES is free software: you can redistribute it and/or modify
8 : * it under the terms of the GNU Lesser General Public License as published by
9 : * the Free Software Foundation, either version 3 of the License, or
10 : * (at your option) any later version.
11 : *
12 : * ForBES is distributed in the hope that it will be useful,
13 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 : * GNU Lesser General Public License for more details.
16 : *
17 : * You should have received a copy of the GNU Lesser General Public License
18 : * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
19 : */
20 :
21 : #include "Quadratic.h"
22 : #include "MatrixFactory.h"
23 : #include "CGSolver.h"
24 :
25 : using namespace std;
26 :
27 7 : Quadratic::Quadratic() {
28 7 : m_is_Q_eye = true;
29 7 : m_is_q_zero = true;
30 7 : m_solver = NULL;
31 7 : m_Q = NULL;
32 7 : m_q = NULL;
33 7 : }
34 :
35 12 : Quadratic::Quadratic(Matrix& QQ) {
36 12 : m_Q = &QQ;
37 12 : m_is_Q_eye = false;
38 12 : m_is_q_zero = true;
39 12 : m_solver = NULL;
40 12 : m_q = NULL;
41 12 : }
42 :
43 11 : Quadratic::Quadratic(Matrix& QQ, Matrix& qq) {
44 11 : m_q = &qq;
45 11 : m_Q = &QQ;
46 11 : m_solver = NULL;
47 11 : m_is_Q_eye = false;
48 11 : m_is_q_zero = false;
49 11 : }
50 :
51 83 : Quadratic::~Quadratic() {
52 30 : if (m_solver != NULL) {
53 3 : delete m_solver;
54 : }
55 53 : }
56 :
57 1 : void Quadratic::setQ(Matrix& Q) {
58 1 : m_is_Q_eye = false;
59 1 : this->m_Q = &Q;
60 1 : this->m_solver = NULL;
61 1 : }
62 :
63 3 : void Quadratic::setq(Matrix& q) {
64 3 : m_is_q_zero = false;
65 3 : this->m_q = &q;
66 3 : }
67 :
68 11 : int Quadratic::call(Matrix& x, double& f) {
69 11 : if (!m_is_Q_eye) {
70 10 : if (m_is_q_zero) {
71 7 : f = m_Q->quad(x);
72 : } else {
73 3 : f = m_Q->quad(x, *m_q);
74 : }
75 : } else {
76 1 : f = static_cast<Matrix> (x * x)[0];
77 1 : if (!m_is_q_zero) {
78 0 : f += static_cast<Matrix> ((*m_q) * x)[0];
79 : }
80 1 : f = f / 2.0;
81 : }
82 11 : return ForBESUtils::STATUS_OK;
83 : }
84 :
85 23764 : int Quadratic::call(Matrix& x, double& f, Matrix& grad) {
86 23764 : int statusComputeGrad = computeGradient(x, grad); // compute the gradient of f at x (grad)
87 23764 : if (statusComputeGrad != ForBESUtils::STATUS_OK) {
88 0 : return statusComputeGrad;
89 : }
90 : // f = (1/2)*(grad+q)'*x
91 23764 : f = static_cast<Matrix> ((m_is_q_zero ? grad : grad + (*m_q)) * x)[0] / 2;
92 23764 : return ForBESUtils::STATUS_OK;
93 : }
94 :
95 7 : int Quadratic::hessianProduct(Matrix& x, Matrix& z, Matrix& Hz) {
96 7 : Hz = (*m_Q) * z;
97 7 : return ForBESUtils::STATUS_OK;
98 : }
99 :
100 7 : int Quadratic::callConj(Matrix& y, double& f_star) {
101 7 : Matrix g;
102 7 : int status = callConj(y, f_star, g);
103 7 : return status;
104 : }
105 :
106 9 : int Quadratic::callConj(Matrix& y, double& f_star, Matrix& g) {
107 9 : Matrix z = (m_is_q_zero || m_q == NULL) ? y : y - *m_q; // z = y
108 9 : if (m_is_Q_eye || m_Q == NULL) {
109 1 : g = z;
110 1 : f_star = static_cast<Matrix> (z * z)[0];
111 1 : return ForBESUtils::STATUS_OK;
112 : }
113 8 : if (m_Q != NULL && Matrix::MATRIX_DIAGONAL == m_Q->getType()) {
114 : /* Q is diagonal */
115 1 : g = z;
116 1 : f_star = 0.0;
117 5 : for (size_t i = 0; i < z.getNrows(); i++) {
118 4 : g[i] /= m_Q->get(i, i);
119 4 : f_star += z[i] * g[i];
120 : }
121 1 : return ForBESUtils::STATUS_OK;
122 : }
123 :
124 7 : if (m_solver == NULL) {
125 4 : m_solver = new CholeskyFactorization(*m_Q);
126 4 : int status = m_solver->factorize();
127 4 : if (0 != status) {
128 1 : return ForBESUtils::STATUS_NUMERICAL_PROBLEMS;
129 : }
130 : }
131 :
132 6 : m_solver->solve(z, g); // Q*g = z OR g = Q \ z
133 6 : f_star = static_cast<Matrix> (z * g)[0]; // fstar = z' *g
134 6 : return ForBESUtils::STATUS_OK;
135 : }
136 :
137 23764 : int Quadratic::computeGradient(Matrix& x, Matrix& grad) {
138 23764 : if (m_is_Q_eye) {
139 352 : grad = x;
140 : } else {
141 23412 : grad = (*m_Q) * x;
142 : }
143 23764 : if (!m_is_q_zero) {
144 23411 : grad += *m_q;
145 : }
146 23764 : return ForBESUtils::STATUS_OK;
147 : }
148 :
149 14 : FunctionOntologicalClass Quadratic::category() {
150 14 : return FunctionOntologyRegistry::quadratic();
151 : }
152 :
153 2 : int Quadratic::callProx(Matrix& v, double gamma, Matrix& prox) {
154 :
155 2 : Matrix v_gamma_b = v;
156 2 : if (!m_is_q_zero) Matrix::add(v_gamma_b, -gamma, *m_q, 1.0);
157 :
158 :
159 :
160 : // (I+gamma Q)^{-1}(v-gamma q)
161 : int status;
162 2 : if (!m_is_Q_eye) {
163 : // If Q is not I, we need to create a CGSolver for (I + gamma Q)
164 1 : Matrix Q_tilde(*m_Q);
165 1 : size_t n = m_Q->getNrows();
166 1 : static Matrix Eye = MatrixFactory::MakeIdentity(n, 1.0);
167 1 : Q_tilde *= gamma;
168 1 : Q_tilde += Eye;
169 :
170 2 : MatrixOperator Q_tilde_op(Q_tilde);
171 2 : Matrix P(n, n, Matrix::MATRIX_DIAGONAL);
172 4 : for (size_t i = 0; i < n; i++) {
173 3 : P[i] = 1 / Q_tilde.get(i, i);
174 : }
175 2 : MatrixOperator P_op(P);
176 2 : CGSolver solver(Q_tilde_op, P_op, 1e-6, 1500);
177 1 : status = solver.solve(v_gamma_b, prox);
178 2 : if (!ForBESUtils::is_status_ok(status)) return status;
179 : } else {
180 : // Q = I
181 : // (v-gamma q)/(1+gamma)
182 1 : v_gamma_b *= (1. / (1. + gamma));
183 1 : prox = v_gamma_b;
184 1 : return ForBESUtils::STATUS_OK;
185 : }
186 1 : return ForBESUtils::STATUS_OK;
187 :
188 27 : }
189 :
190 :
|