@@ -70,39 +70,50 @@ template <typename T> static auto makeB() {
70
70
const index_t m = 8 ;
71
71
const index_t n = 2 ;
72
72
tensor_t <T, 2 > B = make_tensor<T>({m, n});
73
- B (0 , 0 ) = static_cast <T>(1 ); B (0 , 1 ) = static_cast <T>(2 );
74
- B (1 , 0 ) = static_cast <T>(3 ); B (1 , 1 ) = static_cast <T>(4 );
75
- B (2 , 0 ) = static_cast <T>(5 ); B (2 , 1 ) = static_cast <T>(6 );
76
- B (3 , 0 ) = static_cast <T>(7 ); B (3 , 1 ) = static_cast <T>(8 );
77
- B (4 , 0 ) = static_cast <T>(9 ); B (4 , 1 ) = static_cast <T>(10 );
78
- B (5 , 0 ) = static_cast <T>(11 ); B (5 , 1 ) = static_cast <T>(12 );
79
- B (6 , 0 ) = static_cast <T>(13 ); B (6 , 1 ) = static_cast <T>(14 );
80
- B (7 , 0 ) = static_cast <T>(15 ); B (7 , 1 ) = static_cast <T>(16 );
73
+ B (0 , 0 ) = static_cast <T>(1 );
74
+ B (0 , 1 ) = static_cast <T>(2 );
75
+ B (1 , 0 ) = static_cast <T>(3 );
76
+ B (1 , 1 ) = static_cast <T>(4 );
77
+ B (2 , 0 ) = static_cast <T>(5 );
78
+ B (2 , 1 ) = static_cast <T>(6 );
79
+ B (3 , 0 ) = static_cast <T>(7 );
80
+ B (3 , 1 ) = static_cast <T>(8 );
81
+ B (4 , 0 ) = static_cast <T>(9 );
82
+ B (4 , 1 ) = static_cast <T>(10 );
83
+ B (5 , 0 ) = static_cast <T>(11 );
84
+ B (5 , 1 ) = static_cast <T>(12 );
85
+ B (6 , 0 ) = static_cast <T>(13 );
86
+ B (6 , 1 ) = static_cast <T>(14 );
87
+ B (7 , 0 ) = static_cast <T>(15 );
88
+ B (7 , 1 ) = static_cast <T>(16 );
81
89
return B;
82
90
}
83
91
84
92
template <typename T> static auto makeE () {
85
93
const index_t m = 4 ;
86
94
const index_t n = 2 ;
87
95
tensor_t <T, 2 > E = make_tensor<T>({m, n});
88
- E (0 , 0 ) = static_cast <T>(7 ); E (0 , 1 ) = static_cast <T>(10 );
89
- E (1 , 0 ) = static_cast <T>(45 ); E (1 , 1 ) = static_cast <T>(48 );
90
- E (2 , 0 ) = static_cast <T>(52 ); E (2 , 1 ) = static_cast <T>(56 );
91
- E (3 , 0 ) = static_cast <T>(144 ); E (3 , 1 ) = static_cast <T>(162 );
96
+ E (0 , 0 ) = static_cast <T>(7 );
97
+ E (0 , 1 ) = static_cast <T>(10 );
98
+ E (1 , 0 ) = static_cast <T>(45 );
99
+ E (1 , 1 ) = static_cast <T>(48 );
100
+ E (2 , 0 ) = static_cast <T>(52 );
101
+ E (2 , 1 ) = static_cast <T>(56 );
102
+ E (3 , 0 ) = static_cast <T>(144 );
103
+ E (3 , 1 ) = static_cast <T>(162 );
92
104
return E;
93
105
}
94
106
95
107
template <typename T> class MatmulSparseTest : public ::testing::Test {
96
108
protected:
97
109
using GTestType = cuda::std::tuple_element_t <0 , T>;
98
110
using GExecType = cuda::std::tuple_element_t <1 , T>;
99
- void SetUp () override {
100
- CheckTestTypeSupport<GTestType>();
101
- }
111
+ void SetUp () override { CheckTestTypeSupport<GTestType>(); }
102
112
float thresh = 0 .001f ;
103
113
};
104
114
105
- template <typename T> class MatmulSparseTestsAll : public MatmulSparseTest <T> { };
115
+ template <typename T>
116
+ class MatmulSparseTestsAll : public MatmulSparseTest <T> {};
106
117
107
118
TYPED_TEST_SUITE (MatmulSparseTestsAll, MatXFloatNonComplexHalfTypesCUDAExec);
108
119
@@ -136,9 +147,8 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCOO) {
136
147
for (index_t j = 0 ; j < n; j++) {
137
148
if constexpr (is_complex_v<TestType>) {
138
149
ASSERT_NEAR (O (i, j).real (), E (i, j).real (), this ->thresh );
139
- ASSERT_NEAR (O (i, j).imag (), E (i,j ).imag (), this ->thresh );
140
- }
141
- else {
150
+ ASSERT_NEAR (O (i, j).imag (), E (i, j).imag (), this ->thresh );
151
+ } else {
142
152
ASSERT_NEAR (O (i, j), E (i, j), this ->thresh );
143
153
}
144
154
}
@@ -154,9 +164,8 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCOO) {
154
164
for (index_t j = 0 ; j < n; j++) {
155
165
if constexpr (is_complex_v<TestType>) {
156
166
ASSERT_NEAR (TO (j, i).real (), E (i, j).real (), this ->thresh );
157
- ASSERT_NEAR (TO (j, i).imag (), E (i,j ).imag (), this ->thresh );
158
- }
159
- else {
167
+ ASSERT_NEAR (TO (j, i).imag (), E (i, j).imag (), this ->thresh );
168
+ } else {
160
169
ASSERT_NEAR (TO (j, i), E (i, j), this ->thresh );
161
170
}
162
171
}
@@ -180,7 +189,8 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCSR) {
180
189
const auto n = B.Size (1 );
181
190
182
191
// Convert dense A to sparse S.
183
- auto S = experimental::make_zero_tensor_csr<TestType, index_t , index_t >({m, k});
192
+ auto S =
193
+ experimental::make_zero_tensor_csr<TestType, index_t , index_t >({m, k});
184
194
(S = dense2sparse (A)).run (exec);
185
195
ASSERT_EQ (S.Nse (), 7 );
186
196
ASSERT_EQ (S.posSize (1 ), m + 1 );
@@ -195,9 +205,8 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCSR) {
195
205
for (index_t j = 0 ; j < n; j++) {
196
206
if constexpr (is_complex_v<TestType>) {
197
207
ASSERT_NEAR (O (i, j).real (), E (i, j).real (), this ->thresh );
198
- ASSERT_NEAR (O (i, j).imag (), E (i,j ).imag (), this ->thresh );
199
- }
200
- else {
208
+ ASSERT_NEAR (O (i, j).imag (), E (i, j).imag (), this ->thresh );
209
+ } else {
201
210
ASSERT_NEAR (O (i, j), E (i, j), this ->thresh );
202
211
}
203
212
}
@@ -221,7 +230,8 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCSC) {
221
230
const auto n = B.Size (1 );
222
231
223
232
// Convert dense A to sparse S.
224
- auto S = experimental::make_zero_tensor_csc<TestType, index_t , index_t >({m, k});
233
+ auto S =
234
+ experimental::make_zero_tensor_csc<TestType, index_t , index_t >({m, k});
225
235
(S = dense2sparse (A)).run (exec);
226
236
ASSERT_EQ (S.Nse (), 7 );
227
237
ASSERT_EQ (S.posSize (1 ), k + 1 );
@@ -236,9 +246,8 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCSC) {
236
246
for (index_t j = 0 ; j < n; j++) {
237
247
if constexpr (is_complex_v<TestType>) {
238
248
ASSERT_NEAR (O (i, j).real (), E (i, j).real (), this ->thresh );
239
- ASSERT_NEAR (O (i, j).imag (), E (i,j ).imag (), this ->thresh );
240
- }
241
- else {
249
+ ASSERT_NEAR (O (i, j).imag (), E (i, j).imag (), this ->thresh );
250
+ } else {
242
251
ASSERT_NEAR (O (i, j), E (i, j), this ->thresh );
243
252
}
244
253
}
@@ -256,9 +265,8 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCSC) {
256
265
for (index_t j = 0 ; j < n; j++) {
257
266
if constexpr (is_complex_v<TestType>) {
258
267
ASSERT_NEAR ((O (i, j) - C5).real (), E (i, j).real (), this ->thresh );
259
- ASSERT_NEAR ((O (i, j) - C5).imag (), E (i,j ).imag (), this ->thresh );
260
- }
261
- else {
268
+ ASSERT_NEAR ((O (i, j) - C5).imag (), E (i, j).imag (), this ->thresh );
269
+ } else {
262
270
ASSERT_NEAR (O (i, j) - C5, E (i, j), this ->thresh );
263
271
}
264
272
}
0 commit comments