Line data Source code
1 : /*
2 : * File: HingeLoss.cpp
3 : * Author: Pantelis Sopasakis
4 : *
5 : * Created on October 29, 2015, 10:49 PM
6 : *
7 : * ForBES is free software: you can redistribute it and/or modify
8 : * it under the terms of the GNU Lesser General Public License as published by
9 : * the Free Software Foundation, either version 3 of the License, or
10 : * (at your option) any later version.
11 : *
12 : * ForBES is distributed in the hope that it will be useful,
13 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 : * GNU Lesser General Public License for more details.
16 : *
17 : * You should have received a copy of the GNU Lesser General Public License
18 : * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
19 : */
20 :
21 : #include "HingeLoss.h"
22 :
23 6 : HingeLoss::HingeLoss(Matrix& b, double mu) :
24 6 : Function(), m_mu(mu) {
25 6 : m_b = &b;
26 6 : }
27 :
28 1 : HingeLoss::HingeLoss(Matrix& b) :
29 1 : Function() {
30 1 : m_b = &b;
31 1 : m_mu = 1.0;
32 1 : }
33 :
34 14 : HingeLoss::~HingeLoss() {
35 14 : }
36 :
37 6 : int HingeLoss::call(Matrix& x, double& f) {
38 6 : if (!x.isColumnVector()) {
39 0 : throw std::invalid_argument("x must be a column-vector");
40 : }
41 6 : f = 0.0;
42 435 : for (size_t i = 0; i < x.getNrows(); i++) {
43 429 : double si = 1 - m_b->get(i) * x[i];
44 429 : if (si > 0) {
45 21 : f += si;
46 : }
47 : }
48 6 : f *= m_mu;
49 6 : return ForBESUtils::STATUS_OK;
50 : }
51 :
52 3 : int HingeLoss::callProx(Matrix& x, double gamma, Matrix& prox) {
53 3 : if (!x.isColumnVector()) {
54 0 : throw std::invalid_argument("x must be a column-vector");
55 : }
56 3 : double gm = gamma*m_mu;
57 12 : for (size_t i = 0; i < x.getNrows(); i++) {
58 : double bi;
59 : double bxi;
60 9 : bi = m_b->get(i);
61 9 : bxi = bi * x[i];
62 9 : if (bxi < 1) {
63 2 : prox[i] = bi * std::min(1.0, bxi + gm);
64 : } else {
65 7 : prox[i] = x[i];
66 : }
67 : }
68 3 : return ForBESUtils::STATUS_OK;
69 : }
70 :
71 3 : int HingeLoss::callProx(Matrix& x, double gamma, Matrix& prox, double& f_at_prox) {
72 3 : if (!x.isColumnVector()) {
73 0 : throw std::invalid_argument("x must be a column-vector");
74 : }
75 3 : double gm = gamma*m_mu;
76 3 : f_at_prox = 0.0;
77 12 : for (size_t i = 0; i < x.getNrows(); i++) {
78 : double si;
79 : double pi;
80 : double bi;
81 : double bxi;
82 9 : bi = m_b->get(i);
83 9 : bxi = bi * x[i];
84 9 : if (bxi < 1) {
85 2 : pi = bi * std::min(1.0, bxi + gm);
86 : } else {
87 7 : pi = x[i];
88 : }
89 9 : si = 1 - m_b->get(i) * pi;
90 9 : if (si > 0) {
91 2 : f_at_prox += si;
92 : }
93 9 : prox[i] = pi;
94 : }
95 3 : f_at_prox *= m_mu;
96 3 : return ForBESUtils::STATUS_OK;
97 : }
98 :
99 3 : FunctionOntologicalClass HingeLoss::category() {
100 3 : FunctionOntologicalClass hingeLoss("HingeLoss");
101 3 : hingeLoss.set_defines_f(true);
102 3 : hingeLoss.set_defines_grad(true);
103 3 : hingeLoss.set_defines_prox(true);
104 3 : hingeLoss.add_superclass(FunctionOntologyRegistry::loss());
105 3 : return hingeLoss;
106 6 : }
|