Line data Source code
1 : /*
2 : * File: QuadraticLoss.cpp
3 : * Author: Pantelis Sopasakis
4 : *
5 : * Created on October 29, 2015, 5:47 PM
6 : *
7 : * ForBES is free software: you can redistribute it and/or modify
8 : * it under the terms of the GNU Lesser General Public License as published by
9 : * the Free Software Foundation, either version 3 of the License, or
10 : * (at your option) any later version.
11 : *
12 : * ForBES is distributed in the hope that it will be useful,
13 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 : * GNU Lesser General Public License for more details.
16 : *
17 : * You should have received a copy of the GNU Lesser General Public License
18 : * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
19 : */
20 :
21 : #include "QuadraticLoss.h"
22 :
23 5 : QuadraticLoss::QuadraticLoss() {
24 5 : m_is_uniform_weights = true;
25 5 : m_is_zero_p = true;
26 5 : m_uniform_w = 1.0;
27 5 : m_w = NULL;
28 5 : m_p = NULL;
29 5 : }
30 :
31 1 : QuadraticLoss::QuadraticLoss(double w) :
32 1 : Function(), m_uniform_w(w) {
33 1 : m_is_uniform_weights = true;
34 1 : m_is_zero_p = true;
35 1 : m_w = NULL;
36 1 : m_p = NULL;
37 1 : }
38 :
39 2 : QuadraticLoss::QuadraticLoss(Matrix& w, Matrix& p) :
40 2 : Function() {
41 2 : if (!w.isColumnVector() || !p.isColumnVector()) {
42 0 : throw std::invalid_argument("Arguments w and p must be column-vectors");
43 : }
44 2 : if (w.getNrows() != p.getNrows()) {
45 0 : throw std::invalid_argument("w and p must be of equal size");
46 : }
47 2 : m_is_uniform_weights = false;
48 2 : m_is_zero_p = false;
49 2 : m_w = &w;
50 2 : m_p = &p;
51 2 : m_uniform_w = 0.0;
52 2 : }
53 :
54 13 : QuadraticLoss::~QuadraticLoss() {
55 13 : }
56 :
57 1 : int QuadraticLoss::call(Matrix& x, double& f) {
58 1 : f = 0.0;
59 11 : for (size_t j = 0; j < x.getNrows(); j++) {
60 : double fi;
61 10 : fi = x[j];
62 10 : if (!m_is_zero_p) {
63 10 : fi -= m_p->getData()[j];
64 : }
65 10 : fi *= fi;
66 10 : if (!m_is_uniform_weights) {
67 10 : fi *= m_w->getData()[j];
68 : }
69 10 : f += fi;
70 : }
71 1 : if (m_is_uniform_weights) {
72 0 : f *= m_uniform_w;
73 : }
74 1 : f /= 2.0;
75 1 : return ForBESUtils::STATUS_OK;
76 : }
77 :
78 38129 : int QuadraticLoss::call(Matrix& x, double& f, Matrix& grad) {
79 38129 : f = 0.0;
80 190659 : for (size_t j = 0; j < x.getNrows(); j++) {
81 : double fi;
82 : double gi;
83 152530 : fi = x[j];
84 152530 : if (!m_is_zero_p) {
85 20 : fi -= m_p->getData()[j];
86 : }
87 152530 : gi = fi;
88 152530 : fi *= fi;
89 152530 : if (!m_is_uniform_weights) {
90 20 : double w = m_w->getData()[j];
91 20 : fi *= w;
92 20 : gi *= w;
93 : }
94 152530 : f += fi;
95 152530 : grad[j] = gi;
96 : }
97 38129 : if (m_is_uniform_weights) {
98 38127 : f *= m_uniform_w;
99 : }
100 38129 : f /= 2.0;
101 38129 : return ForBESUtils::STATUS_OK;
102 : }
103 :
104 3 : int QuadraticLoss::callConj(Matrix& x, double& f_star) {
105 3 : f_star = 0.0;
106 33 : for (size_t i = 0; i < x.getNrows(); i++) {
107 30 : f_star += x[i]*(2.0 * (m_is_zero_p ? 0.0 : m_p->get(i))
108 30 : + (x[i] / (m_is_uniform_weights ? m_uniform_w : m_w->get(i))));
109 : }
110 3 : f_star /= 2.0;
111 3 : return ForBESUtils::STATUS_OK;
112 : }
113 :
114 1 : int QuadraticLoss::callConj(Matrix& x, double& f_star, Matrix& grad) {
115 1 : f_star = 0.0;
116 11 : for (size_t i = 0; i < x.getNrows(); i++) {
117 : double gradi;
118 : double pi;
119 10 : pi = m_is_zero_p ? 0.0 : m_p->get(i);
120 10 : gradi = pi + x[i] / (m_is_uniform_weights ? m_uniform_w : m_w->get(i));
121 10 : grad.set(i, 0, gradi);
122 10 : f_star += x[i]*(gradi + pi);
123 : }
124 1 : f_star /= 2.0;
125 1 : return ForBESUtils::STATUS_OK;
126 : }
127 :
128 4 : int QuadraticLoss::hessianProduct(Matrix& x, Matrix& z, Matrix& Hz) {
129 4 : if (m_is_uniform_weights) {
130 4 : Hz = m_uniform_w * z;
131 : } else {
132 0 : m_w->toggle_diagonal();
133 0 : Hz = (*m_w)*z;
134 0 : m_w->toggle_diagonal();
135 : }
136 4 : return ForBESUtils::STATUS_OK;
137 : }
138 :
139 18 : FunctionOntologicalClass QuadraticLoss::category() {
140 18 : FunctionOntologicalClass quadLoss("QuadraticLoss");
141 18 : quadLoss.set_defines_f(true);
142 18 : quadLoss.set_defines_conjugate(true);
143 18 : quadLoss.set_defines_conjugate_grad(true);
144 18 : quadLoss.set_defines_grad(true);
145 18 : quadLoss.add_superclass(FunctionOntologyRegistry::loss());
146 18 : quadLoss.add_superclass(FunctionOntologyRegistry::quadratic());
147 18 : return quadLoss;
148 18 : }
|