blob: 990f6a4a1b7ac6d74d907a368db6793a5fce230b [file] [log] [blame]
aluebs@webrtc.org0c39e912014-12-18 22:22:04 +00001/*
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
12#define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
13
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <sstream>
#include <string>
#include <vector>

#include "webrtc/base/checks.h"
#include "webrtc/base/constructormagic.h"
#include "webrtc/base/scoped_ptr.h"
aluebs@webrtc.org0c39e912014-12-18 22:22:04 +000022
namespace {

// Wrappers to get around the compiler warning resulting from the fact that
// there's no std::sqrt overload for ints. We cast all non-complex types to
// a double for the sqrt method and convert the result back to T explicitly,
// so e.g. an int input yields the truncated square root.
//
// NOTE(review): an anonymous namespace in a header gives every translation
// unit its own copy of these helpers, while the Matrix<T> template (in
// namespace webrtc) calls them unqualified -- a latent ODR hazard. Consider
// moving them into a named internal namespace together with their caller.
template <typename T>
T sqrt_wrapper(T x) {
  return static_cast<T>(std::sqrt(static_cast<double>(x)));
}

// std::sqrt has a native std::complex overload, so no cast is needed here.
// Partial ordering selects this overload over the generic one for complex
// arguments.
template <typename S>
std::complex<S> sqrt_wrapper(std::complex<S> x) {
  return std::sqrt(x);
}

}  // namespace
38
39namespace webrtc {
40
41// Matrix is a class for doing standard matrix operations on 2 dimensional
42// matrices of any size. Results of matrix operations are stored in the
43// calling object. Function overloads exist for both in-place (the calling
44// object is used as both an operand and the result) and out-of-place (all
45// operands are passed in as parameters) operations. If operand dimensions
46// mismatch, the program crashes. Out-of-place operations change the size of
47// the calling object, if necessary, before operating.
48//
49// 'In-place' operations that inherently change the size of the matrix (eg.
50// Transpose, Multiply on different-sized matrices) must make temporary copies
51// (|scratch_elements_| and |scratch_data_|) of existing data to complete the
52// operations.
53//
54// The data is stored contiguously. Data can be accessed internally as a flat
55// array, |data_|, or as an array of row pointers, |elements_|, but is
56// available to users only as an array of row pointers through |elements()|.
57// Memory for storage is allocated when a matrix is resized only if the new
58// size overflows capacity. Memory needed temporarily for any operations is
59// similarly resized only if the new size overflows capacity.
60//
61// If you pass in storage through the ctor, that storage is copied into the
62// matrix. TODO(claguna): albeit tricky, allow for data to be referenced
63// instead of copied, and owned by the user.
64template <typename T>
65class Matrix {
66 public:
67 Matrix() : num_rows_(0), num_columns_(0) {}
68
69 // Allocates space for the elements and initializes all values to zero.
70 Matrix(int num_rows, int num_columns)
71 : num_rows_(num_rows), num_columns_(num_columns) {
72 Resize();
73 scratch_data_.resize(num_rows_ * num_columns_);
74 scratch_elements_.resize(num_rows_);
75 }
76
77 // Copies |data| into the new Matrix.
78 Matrix(const T* data, int num_rows, int num_columns)
aluebs@webrtc.org661af502015-02-19 19:02:17 +000079 : num_rows_(0), num_columns_(0) {
aluebs@webrtc.org0c39e912014-12-18 22:22:04 +000080 CopyFrom(data, num_rows, num_columns);
81 scratch_data_.resize(num_rows_ * num_columns_);
82 scratch_elements_.resize(num_rows_);
83 }
84
85 virtual ~Matrix() {}
86
87 // Deep copy an existing matrix.
88 void CopyFrom(const Matrix& other) {
89 CopyFrom(&other.data_[0], other.num_rows_, other.num_columns_);
90 }
91
92 // Copy |data| into the Matrix. The current data is lost.
93 void CopyFrom(const T* const data, int num_rows, int num_columns) {
aluebs@webrtc.org661af502015-02-19 19:02:17 +000094 Resize(num_rows, num_columns);
95 memcpy(&data_[0], data, num_rows_ * num_columns_ * sizeof(data_[0]));
aluebs@webrtc.org0c39e912014-12-18 22:22:04 +000096 }
97
98 Matrix& CopyFromColumn(const T* const* src, int column_index, int num_rows) {
99 Resize(1, num_rows);
100 for (int i = 0; i < num_columns_; ++i) {
101 data_[i] = src[i][column_index];
102 }
103
104 return *this;
105 }
106
107 void Resize(int num_rows, int num_columns) {
aluebs@webrtc.org661af502015-02-19 19:02:17 +0000108 if (num_rows != num_rows_ || num_columns != num_columns_) {
109 num_rows_ = num_rows;
110 num_columns_ = num_columns;
111 Resize();
112 }
aluebs@webrtc.org0c39e912014-12-18 22:22:04 +0000113 }
114
115 // Accessors and mutators.
116 int num_rows() const { return num_rows_; }
117 int num_columns() const { return num_columns_; }
118 T* const* elements() { return &elements_[0]; }
119 const T* const* elements() const { return &elements_[0]; }
120
121 T Trace() {
122 CHECK_EQ(num_rows_, num_columns_);
123
124 T trace = 0;
125 for (int i = 0; i < num_rows_; ++i) {
126 trace += elements_[i][i];
127 }
128 return trace;
129 }
130
131 // Matrix Operations. Returns *this to support method chaining.
132 Matrix& Transpose() {
133 CopyDataToScratch();
aluebs@webrtc.org661af502015-02-19 19:02:17 +0000134 Resize(num_columns_, num_rows_);
aluebs@webrtc.org0c39e912014-12-18 22:22:04 +0000135 return Transpose(scratch_elements());
136 }
137
138 Matrix& Transpose(const Matrix& operand) {
139 CHECK_EQ(operand.num_rows_, num_columns_);
140 CHECK_EQ(operand.num_columns_, num_rows_);
141
142 return Transpose(operand.elements());
143 }
144
145 template <typename S>
146 Matrix& Scale(const S& scalar) {
147 for (size_t i = 0; i < data_.size(); ++i) {
148 data_[i] *= scalar;
149 }
150
151 return *this;
152 }
153
154 template <typename S>
155 Matrix& Scale(const Matrix& operand, const S& scalar) {
156 CopyFrom(operand);
157 return Scale(scalar);
158 }
159
160 Matrix& Add(const Matrix& operand) {
161 CHECK_EQ(num_rows_, operand.num_rows_);
162 CHECK_EQ(num_columns_, operand.num_columns_);
163
164 for (size_t i = 0; i < data_.size(); ++i) {
165 data_[i] += operand.data_[i];
166 }
167
168 return *this;
169 }
170
171 Matrix& Add(const Matrix& lhs, const Matrix& rhs) {
172 CopyFrom(lhs);
173 return Add(rhs);
174 }
175
176 Matrix& Subtract(const Matrix& operand) {
177 CHECK_EQ(num_rows_, operand.num_rows_);
178 CHECK_EQ(num_columns_, operand.num_columns_);
179
180 for (size_t i = 0; i < data_.size(); ++i) {
181 data_[i] -= operand.data_[i];
182 }
183
184 return *this;
185 }
186
187 Matrix& Subtract(const Matrix& lhs, const Matrix& rhs) {
188 CopyFrom(lhs);
189 return Subtract(rhs);
190 }
191
192 Matrix& PointwiseMultiply(const Matrix& operand) {
193 CHECK_EQ(num_rows_, operand.num_rows_);
194 CHECK_EQ(num_columns_, operand.num_columns_);
195
196 for (size_t i = 0; i < data_.size(); ++i) {
197 data_[i] *= operand.data_[i];
198 }
199
200 return *this;
201 }
202
203 Matrix& PointwiseMultiply(const Matrix& lhs, const Matrix& rhs) {
204 CopyFrom(lhs);
205 return PointwiseMultiply(rhs);
206 }
207
208 Matrix& PointwiseDivide(const Matrix& operand) {
209 CHECK_EQ(num_rows_, operand.num_rows_);
210 CHECK_EQ(num_columns_, operand.num_columns_);
211
212 for (size_t i = 0; i < data_.size(); ++i) {
213 data_[i] /= operand.data_[i];
214 }
215
216 return *this;
217 }
218
219 Matrix& PointwiseDivide(const Matrix& lhs, const Matrix& rhs) {
220 CopyFrom(lhs);
221 return PointwiseDivide(rhs);
222 }
223
224 Matrix& PointwiseSquareRoot() {
225 for (size_t i = 0; i < data_.size(); ++i) {
226 data_[i] = sqrt_wrapper(data_[i]);
227 }
228
229 return *this;
230 }
231
232 Matrix& PointwiseSquareRoot(const Matrix& operand) {
233 CopyFrom(operand);
234 return PointwiseSquareRoot();
235 }
236
237 Matrix& PointwiseAbsoluteValue() {
238 for (size_t i = 0; i < data_.size(); ++i) {
239 data_[i] = abs(data_[i]);
240 }
241
242 return *this;
243 }
244
245 Matrix& PointwiseAbsoluteValue(const Matrix& operand) {
246 CopyFrom(operand);
247 return PointwiseAbsoluteValue();
248 }
249
250 Matrix& PointwiseSquare() {
251 for (size_t i = 0; i < data_.size(); ++i) {
252 data_[i] *= data_[i];
253 }
254
255 return *this;
256 }
257
258 Matrix& PointwiseSquare(const Matrix& operand) {
259 CopyFrom(operand);
260 return PointwiseSquare();
261 }
262
263 Matrix& Multiply(const Matrix& lhs, const Matrix& rhs) {
264 CHECK_EQ(lhs.num_columns_, rhs.num_rows_);
265 CHECK_EQ(num_rows_, lhs.num_rows_);
266 CHECK_EQ(num_columns_, rhs.num_columns_);
267
268 return Multiply(lhs.elements(), rhs.num_rows_, rhs.elements());
269 }
270
271 Matrix& Multiply(const Matrix& rhs) {
272 CHECK_EQ(num_columns_, rhs.num_rows_);
273
274 CopyDataToScratch();
aluebs@webrtc.org661af502015-02-19 19:02:17 +0000275 Resize(num_rows_, rhs.num_columns_);
aluebs@webrtc.org0c39e912014-12-18 22:22:04 +0000276 return Multiply(scratch_elements(), rhs.num_rows_, rhs.elements());
277 }
278
279 std::string ToString() const {
280 std::ostringstream ss;
281 ss << std::endl << "Matrix" << std::endl;
282
283 for (int i = 0; i < num_rows_; ++i) {
284 for (int j = 0; j < num_columns_; ++j) {
285 ss << elements_[i][j] << " ";
286 }
287 ss << std::endl;
288 }
289 ss << std::endl;
290
291 return ss.str();
292 }
293
294 protected:
295 void SetNumRows(const int num_rows) { num_rows_ = num_rows; }
296 void SetNumColumns(const int num_columns) { num_columns_ = num_columns; }
297 T* data() { return &data_[0]; }
298 const T* data() const { return &data_[0]; }
299 const T* const* scratch_elements() const { return &scratch_elements_[0]; }
300
301 // Resize the matrix. If an increase in capacity is required, the current
302 // data is lost.
303 void Resize() {
304 size_t size = num_rows_ * num_columns_;
305 data_.resize(size);
306 elements_.resize(num_rows_);
307
308 for (int i = 0; i < num_rows_; ++i) {
309 elements_[i] = &data_[i * num_columns_];
310 }
311 }
312
313 // Copies data_ into scratch_data_ and updates scratch_elements_ accordingly.
314 void CopyDataToScratch() {
315 scratch_data_ = data_;
316 scratch_elements_.resize(num_rows_);
317
318 for (int i = 0; i < num_rows_; ++i) {
319 scratch_elements_[i] = &scratch_data_[i * num_columns_];
320 }
321 }
322
323 private:
324 int num_rows_;
325 int num_columns_;
326 std::vector<T> data_;
327 std::vector<T*> elements_;
328
329 // Stores temporary copies of |data_| and |elements_| for in-place operations
330 // where referring to original data is necessary.
331 std::vector<T> scratch_data_;
332 std::vector<T*> scratch_elements_;
333
334 // Helpers for Transpose and Multiply operations that unify in-place and
335 // out-of-place solutions.
336 Matrix& Transpose(const T* const* src) {
337 for (int i = 0; i < num_rows_; ++i) {
338 for (int j = 0; j < num_columns_; ++j) {
339 elements_[i][j] = src[j][i];
340 }
341 }
342
343 return *this;
344 }
345
346 Matrix& Multiply(const T* const* lhs, int num_rows_rhs, const T* const* rhs) {
347 for (int row = 0; row < num_rows_; ++row) {
348 for (int col = 0; col < num_columns_; ++col) {
349 T cur_element = 0;
350 for (int i = 0; i < num_rows_rhs; ++i) {
351 cur_element += lhs[row][i] * rhs[i][col];
352 }
353
354 elements_[row][col] = cur_element;
355 }
356 }
357
358 return *this;
359 }
360
361 DISALLOW_COPY_AND_ASSIGN(Matrix);
362};
363
364} // namespace webrtc
365
366#endif // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_