Line data Source code
1 : #include "FBSplittingFast.h"
2 : #include "FBProblem.h"
3 : #include "FBStoppingRelative.h"
4 : #include "MatrixFactory.h"
5 : #include "MatrixOperator.h"
6 : #include "TestFBSplittingFast.h"
7 :
8 : // #include <iostream>
9 :
10 : #define DOUBLES_EQUAL_DELTA 1e-4
11 : #define MAXIT 1000
12 : #define TOLERANCE 1e-6
13 :
14 1 : CPPUNIT_TEST_SUITE_REGISTRATION(TestFBSplittingFast);
15 :
16 3 : TestFBSplittingFast::TestFBSplittingFast() {
17 3 : }
18 :
19 6 : TestFBSplittingFast::~TestFBSplittingFast() {
20 6 : }
21 :
22 3 : void TestFBSplittingFast::setUp() {
23 3 : }
24 :
25 3 : void TestFBSplittingFast::tearDown() {
26 3 : }
27 :
28 1 : void TestFBSplittingFast::testBoxQP_small() {
29 1 : size_t n = 4;
30 : // problem data
31 : double data_Q[] = {
32 : 7, 2, -2, -1,
33 : 2, 3, 0, -1,
34 : -2, 0, 3, -1,
35 : -1, -1, -1, 1
36 1 : };
37 : double data_q[] = {
38 : 1, 2, 3, 4
39 1 : };
40 1 : double gamma = 0.1;
41 1 : double lb = -1;
42 1 : double ub = +1;
43 : // starting points
44 1 : double data_x1[] = {+0.5, +1.2, -0.7, -1.1};
45 1 : double data_x2[] = {-1.0, -1.0, -1.0, -1.0};
46 : // reference results
47 1 : double ref_xstar[] = {-0.352941176470588, -0.764705882352941, -1.000000000000000, -1.000000000000000};
48 :
49 1 : Matrix * Q = new Matrix(n, n, data_Q);
50 1 : Matrix * q = new Matrix(n, 1, data_q);
51 : Matrix * x0;
52 1 : Matrix xstar;
53 1 : Function * f = new Quadratic(*Q, *q);
54 1 : Function * g = new IndBox(lb, ub);
55 2 : FBProblem prob = FBProblem(*f, *g);
56 2 : FBStoppingRelative sc = FBStoppingRelative(TOLERANCE);
57 : FBSplitting * solver;
58 :
59 : // test FB operations starting from x1
60 1 : size_t repeat = 100;
61 101 : for (size_t r = 0; r < repeat; r++) {
62 100 : x0 = new Matrix(n, 1, data_x1);
63 100 : solver = new FBSplittingFast(prob, *x0, gamma, sc, MAXIT);
64 100 : solver->run();
65 100 : xstar = solver->getSolution();
66 : //cout << "*** iters (fast) : " << solver->getIt() << endl;
67 100 : _ASSERT(solver->getIt() < MAXIT);
68 500 : for (int i=0; i < n; i++) {
69 400 : CPPUNIT_ASSERT_DOUBLES_EQUAL(ref_xstar[i], xstar.get(i, 0), DOUBLES_EQUAL_DELTA);
70 : }
71 100 : delete x0;
72 100 : delete solver;
73 :
74 : // test FB operations starting from x2
75 100 : x0 = new Matrix(n, 1, data_x2);
76 100 : solver = new FBSplittingFast(prob, *x0, gamma, sc, MAXIT);
77 100 : solver->run();
78 100 : xstar = solver->getSolution();
79 : //cout << "*** iters (fast) : " << solver->getIt() << endl;
80 100 : _ASSERT(solver->getIt() < MAXIT);
81 500 : for (int i=0; i < n; i++) {
82 400 : CPPUNIT_ASSERT_DOUBLES_EQUAL(ref_xstar[i], xstar.get(i, 0), DOUBLES_EQUAL_DELTA);
83 : }
84 100 : delete x0;
85 100 : delete solver;
86 : }
87 :
88 1 : delete Q;
89 1 : delete q;
90 1 : delete f;
91 2 : delete g;
92 1 : }
93 :
94 1 : void TestFBSplittingFast::testLasso_small() {
95 1 : size_t n = 5;
96 1 : size_t m = 4;
97 : // problem data
98 : double data_A[] = {
99 : 1, 2, -1, -1,
100 : -2, -1, 0, -1,
101 : 3, 0, 4, -1,
102 : -4, -1, -3, 1,
103 : 5, 3, 2, 3
104 1 : };
105 : double data_minusb[] = {
106 : -1, -2, -3, -4
107 1 : };
108 : /*
109 : * WARNING: data_w is not used anywhere...
110 : */
111 : double data_w[] = {
112 : 1, 1, 1, 1
113 1 : };
114 1 : double gamma = 0.01;
115 : // starting points
116 1 : double data_x1[] = {0, 0, 0, 0, 0};
117 : // reference results
118 1 : double ref_xstar[] = {-0.010238907849511, 0, 0, 0, 0.511945392491421};
119 :
120 1 : Matrix * A = new Matrix(m, n, data_A);
121 1 : Matrix * minusb = new Matrix(m, 1, data_minusb);
122 : Matrix * x0;
123 1 : Matrix xstar;
124 1 : Function * f = new QuadraticLoss();
125 1 : LinearOperator * OpA = new MatrixOperator(*A);
126 1 : Function * g = new Norm1(5.0);
127 2 : FBProblem prob = FBProblem(*f, *OpA, *minusb, *g);
128 2 : FBStoppingRelative sc = FBStoppingRelative(TOLERANCE);
129 : FBSplitting * solver;
130 :
131 1 : size_t repeat = 200;
132 201 : for (size_t r = 0; r < repeat; r++) {
133 : // test FB operations starting from x1
134 200 : x0 = new Matrix(n, 1, data_x1);
135 200 : solver = new FBSplittingFast(prob, *x0, gamma, sc, MAXIT);
136 200 : solver->run();
137 200 : xstar = solver->getSolution();
138 : //cout << "*** iters (fast) : " << solver->getIt() << endl;
139 200 : _ASSERT(solver->getIt() < MAXIT);
140 1200 : for (int i=0; i < n; i++) {
141 1000 : CPPUNIT_ASSERT_DOUBLES_EQUAL(ref_xstar[i], xstar.get(i, 0), DOUBLES_EQUAL_DELTA);
142 : }
143 200 : delete x0;
144 200 : delete solver;
145 : }
146 :
147 1 : delete A;
148 1 : delete minusb;
149 1 : delete OpA;
150 1 : delete f;
151 2 : delete g;
152 1 : }
153 :
154 1 : void TestFBSplittingFast::testSparseLogReg_small() {
155 1 : size_t n = 5;
156 1 : size_t m = 4;
157 : // problem data
158 : double data_A[] = {
159 : 1, 2, -1, -1,
160 : -2, -1, 0, -1,
161 : 3, 0, 4, -1,
162 : -4, -1, -3, 1,
163 : 5, 3, 2, 3
164 1 : };
165 : double data_minusb[] = {
166 : -1, 1, -1, 1
167 1 : };
168 1 : double gamma = 0.1;
169 : // starting points
170 1 : double data_x1[] = {0, 0, 0, 0, 0};
171 : // reference results
172 1 : double ref_xstar[] = {0.0, 0.0, 0.215341883018748, 0.0, 0.675253988559914};
173 :
174 1 : Matrix * A = new Matrix(m, n, data_A);
175 1 : Matrix * minusb = new Matrix(m, 1, data_minusb);
176 :
177 1 : Function * f = new LogLogisticLoss();
178 1 : LinearOperator * OpA = new MatrixOperator(*A);
179 1 : Function * g = new Norm1(1.0);
180 1 : FBProblem prob(*f, *OpA, *minusb, *g);
181 2 : FBStoppingRelative sc(TOLERANCE);
182 : FBSplitting * solver;
183 :
184 1 : size_t repeat = 100;
185 101 : for (size_t r = 0; r < repeat; r++) {
186 : // test FB operations starting from x1
187 100 : Matrix * x0 = new Matrix(n, 1, data_x1);
188 100 : solver = new FBSplittingFast(prob, *x0, gamma, sc, MAXIT);
189 100 : solver->run();
190 100 : Matrix xstar = solver->getSolution();
191 : //cout << "*** iters (fast) : " << solver->getIt() << endl << flush;
192 100 : _ASSERT(solver->getIt() < MAXIT);
193 600 : for (int i=0; i < n; i++) {
194 500 : CPPUNIT_ASSERT_DOUBLES_EQUAL(ref_xstar[i], xstar.get(i, 0), DOUBLES_EQUAL_DELTA);
195 : }
196 100 : delete x0;
197 100 : delete solver;
198 100 : }
199 :
200 1 : delete A;
201 1 : delete minusb;
202 1 : delete OpA;
203 1 : delete f;
204 2 : delete g;
205 4 : }
|