Line data Source code
1 : /*
2 : * File: IndBall2.cpp
3 : * Author: Pantelis Sopasakis
4 : *
5 : * Created on November 3, 2015, 4:18 PM
6 : *
7 : * ForBES is free software: you can redistribute it and/or modify
8 : * it under the terms of the GNU Lesser General Public License as published by
9 : * the Free Software Foundation, either version 3 of the License, or
10 : * (at your option) any later version.
11 : *
12 : * ForBES is distributed in the hope that it will be useful,
13 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 : * GNU Lesser General Public License for more details.
16 : *
17 : * You should have received a copy of the GNU Lesser General Public License
18 : * along with ForBES. If not, see <http://www.gnu.org/licenses/>.
19 : */
20 :
21 : #include "IndBall2.h"
22 : #include <cmath>
23 :
void check_rho(double rho);

/**
 * Validates the radius of a Euclidean ball.
 *
 * A radius of zero is accepted: the ball then degenerates to its centre.
 *
 * @param rho candidate radius
 * @throws std::invalid_argument if rho is negative or NaN
 */
void check_rho(double rho) {
    // Written as !(rho >= 0) instead of (rho < 0) so that NaN is rejected as
    // well: every ordered comparison involving NaN evaluates to false.
    if (!(rho >= 0.0)) {
        throw std::invalid_argument("rho must be non-negative");
    }
}
31 :
32 1 : IndBall2::IndBall2() {
33 1 : m_rho = 1;
34 1 : m_xc = NULL;
35 1 : m_is_xc_zero = true;
36 1 : }
37 :
38 2 : IndBall2::IndBall2(double rho) {
39 1 : check_rho(rho);
40 0 : m_rho = rho;
41 0 : m_xc = NULL;
42 0 : m_is_xc_zero = true;
43 0 : }
44 :
45 1 : IndBall2::IndBall2(double rho, Matrix& c) {
46 1 : check_rho(rho);
47 1 : m_rho = rho;
48 1 : m_xc = &c;
49 1 : m_is_xc_zero = false;
50 1 : }
51 :
52 4 : IndBall2::~IndBall2() {
53 : // nothing to delete
54 4 : }
55 :
56 2 : double IndBall2::norm_div(const Matrix& x) {
57 2 : double ndiv = 0.0;
58 2 : if (m_is_xc_zero) {
59 0 : for (size_t i = 0; i < x.getNrows(); i++) {
60 0 : ndiv += std::pow(x[i], 2);
61 : }
62 : } else {
63 8 : for (size_t i = 0; i < x.getNrows(); i++) {
64 6 : ndiv += std::pow(x[i] - m_xc->get(i), 2);
65 : }
66 : }
67 2 : ndiv = std::sqrt(ndiv);
68 2 : return ndiv;
69 : }
70 :
71 2 : int IndBall2::callProx(Matrix& x, double gamma, Matrix& prox) {
72 2 : if (!x.isColumnVector()) {
73 0 : throw std::invalid_argument("x must be a vector");
74 : }
75 2 : double normDiv = norm_div(x);
76 2 : if (normDiv <= m_rho) {
77 0 : prox = x;
78 : } else {
79 2 : double alpha = m_rho / normDiv;
80 8 : for (size_t i = 0; i < x.getNrows(); i++) {
81 12 : prox[i] = (m_is_xc_zero ? 0.0 : m_xc->get(i))
82 12 : + alpha * (x[i] - (m_is_xc_zero ? 0.0 : m_xc->get(i)));
83 : }
84 : }
85 2 : return ForBESUtils::STATUS_OK;
86 : }
87 :
88 1 : int IndBall2::callProx(Matrix& x, double gamma, Matrix& prox, double& f_at_prox) {
89 1 : callProx(x, gamma, prox);
90 1 : f_at_prox = 0.0;
91 1 : return ForBESUtils::STATUS_OK;
92 : }
93 :
94 4 : FunctionOntologicalClass IndBall2::category() {
95 4 : FunctionOntologicalClass meta("IndBall2");
96 4 : meta.add_superclass(FunctionOntologyRegistry::indicator());
97 4 : meta.set_defines_prox(true);
98 4 : return meta;
99 3 : }
|