Line data Source code
1 : /*
2 : * File: ElasticNet.cpp
3 : * Author: Pantelis Sopasakis
4 : *
5 : * Created on October 28, 2015, 7:43 PM
6 : *
7 : * ForBES is free software: you can redistribute it and/or modify
8 : * it under the terms of the GNU Lesser General Public License as published by
9 : * the Free Software Foundation, either version 3 of the License, or
10 : * (at your option) any later version.
11 : *
12 : * ForBES is distributed in the hope that it will be useful,
13 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 : * GNU Lesser General Public License for more details.
16 : *
17 : * You should have received a copy of the GNU Lesser General Public License
18 : * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
19 : */
20 :
#include "ElasticNet.h"
#include <algorithm>
#include <cmath>
#include <stdexcept>
23 :
24 6 : ElasticNet::ElasticNet(double lambda, double mu) : Function(), m_mu(mu), m_lambda(lambda) {
25 6 : }
26 :
27 12 : ElasticNet::~ElasticNet() {
28 12 : }
29 :
30 3 : int ElasticNet::call(Matrix& x, double& f) {
31 : //LCOV_EXCL_START
32 : if (!x.isColumnVector()) {
33 : throw std::invalid_argument("x must be a column-vector");
34 : }
35 : //LCOV_EXCL_STOP
36 3 : f = 0.0;
37 18 : for (size_t i = 0; i < x.getNrows(); i++) {
38 : double xi;
39 15 : xi = x[i];
40 15 : f += m_mu * std::abs(xi) + (m_lambda / 2.0) * std::pow(xi, 2);
41 : }
42 3 : return ForBESUtils::STATUS_OK;
43 : }
44 :
45 4 : int ElasticNet::callProx(Matrix& x, double gamma, Matrix& prox, double& g_at_prox) {
46 : //LCOV_EXCL_START
47 : if (!x.isColumnVector()) {
48 : throw std::invalid_argument("x must be a column-vector");
49 : }
50 : //LCOV_EXCL_STOP
51 4 : double gm = gamma * m_mu;
52 4 : double alpha = 1 + m_lambda * gamma; // alpha > 0 [assuming gamma>0 and lambda>0].
53 4 : g_at_prox = 0.0;
54 28 : for (size_t i = 0; i < x.getNrows(); i++) {
55 : double xi;
56 : double yi;
57 24 : xi = x[i];
58 24 : yi = max(0.0, abs(xi) - gm) / alpha;
59 24 : prox[i] = (xi < 0 ? -1 : 1) * yi;
60 24 : g_at_prox += m_mu * yi + (m_lambda / 2.0) * std::pow(yi, 2);
61 : }
62 4 : return ForBESUtils::STATUS_OK;
63 : }
64 :
65 5 : int ElasticNet::callProx(Matrix& x, double gamma, Matrix& prox) {
66 : //LCOV_EXCL_START
67 : if (!x.isColumnVector()) {
68 : throw std::invalid_argument("x must be a column-vector");
69 : }
70 : //LCOV_EXCL_STOP
71 5 : double gm = gamma * m_mu;
72 5 : double alpha = 1 + m_lambda * gamma; // alpha > 0 [assuming gamma>0 and lambda>0].
73 38 : for (size_t i = 0; i < x.getNrows(); i++) {
74 : double xi;
75 33 : xi = x[i];
76 33 : prox[i] = (xi < 0 ? -1 : 1) * max(0.0, abs(xi) - gm) / alpha;
77 : }
78 5 : return ForBESUtils::STATUS_OK;
79 : }
80 :
81 6 : FunctionOntologicalClass ElasticNet::category() {
82 6 : FunctionOntologicalClass ont("ElasticNet");
83 6 : ont.set_defines_conjugate(false);
84 6 : ont.set_defines_conjugate_grad(false);
85 6 : ont.set_defines_f(true);
86 6 : ont.set_defines_grad(false);
87 6 : ont.set_defines_prox(true);
88 6 : ont.add_superclass(FunctionOntologyRegistry::function());
89 6 : return ont;
90 9 : }
|