blob: 1cb475c2ed484ae013bd7461fd9f00ed785fa0c1 [file] [log] [blame]
aluebs@webrtc.org0c39e912014-12-18 22:22:04 +00001/*
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
12#define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
13
#include <algorithm>
#include <cmath>
#include <complex>
#include <sstream>
#include <string>
#include <vector>

#include "webrtc/base/checks.h"
#include "webrtc/base/constructormagic.h"
#include "webrtc/modules/audio_processing/channel_buffer.h"
#include "webrtc/system_wrappers/interface/scoped_ptr.h"
22
23namespace {
24
25// Wrappers to get around the compiler warning resulting from the fact that
26// there's no std::sqrt overload for ints. We cast all non-complex types to
27// a double for the sqrt method.
28template <typename T>
29T sqrt_wrapper(T x) {
30 return sqrt(static_cast<double>(x));
31}
32
33template <typename S>
34std::complex<S> sqrt_wrapper(std::complex<S> x) {
35 return sqrt(x);
36}
37} // namespace
38
39namespace webrtc {
40
41// Matrix is a class for doing standard matrix operations on 2 dimensional
42// matrices of any size. Results of matrix operations are stored in the
43// calling object. Function overloads exist for both in-place (the calling
44// object is used as both an operand and the result) and out-of-place (all
45// operands are passed in as parameters) operations. If operand dimensions
46// mismatch, the program crashes. Out-of-place operations change the size of
47// the calling object, if necessary, before operating.
48//
49// 'In-place' operations that inherently change the size of the matrix (eg.
50// Transpose, Multiply on different-sized matrices) must make temporary copies
51// (|scratch_elements_| and |scratch_data_|) of existing data to complete the
52// operations.
53//
54// The data is stored contiguously. Data can be accessed internally as a flat
55// array, |data_|, or as an array of row pointers, |elements_|, but is
56// available to users only as an array of row pointers through |elements()|.
57// Memory for storage is allocated when a matrix is resized only if the new
58// size overflows capacity. Memory needed temporarily for any operations is
59// similarly resized only if the new size overflows capacity.
60//
61// If you pass in storage through the ctor, that storage is copied into the
62// matrix. TODO(claguna): albeit tricky, allow for data to be referenced
63// instead of copied, and owned by the user.
64template <typename T>
65class Matrix {
66 public:
67 Matrix() : num_rows_(0), num_columns_(0) {}
68
69 // Allocates space for the elements and initializes all values to zero.
70 Matrix(int num_rows, int num_columns)
71 : num_rows_(num_rows), num_columns_(num_columns) {
72 Resize();
73 scratch_data_.resize(num_rows_ * num_columns_);
74 scratch_elements_.resize(num_rows_);
75 }
76
77 // Copies |data| into the new Matrix.
78 Matrix(const T* data, int num_rows, int num_columns)
79 : num_rows_(num_rows), num_columns_(num_columns) {
80 CopyFrom(data, num_rows, num_columns);
81 scratch_data_.resize(num_rows_ * num_columns_);
82 scratch_elements_.resize(num_rows_);
83 }
84
85 virtual ~Matrix() {}
86
87 // Deep copy an existing matrix.
88 void CopyFrom(const Matrix& other) {
89 CopyFrom(&other.data_[0], other.num_rows_, other.num_columns_);
90 }
91
92 // Copy |data| into the Matrix. The current data is lost.
93 void CopyFrom(const T* const data, int num_rows, int num_columns) {
94 num_rows_ = num_rows;
95 num_columns_ = num_columns;
96 size_t size = num_rows_ * num_columns_;
97
98 data_.assign(data, data + size);
99 elements_.resize(num_rows_);
100 for (int i = 0; i < num_rows_; ++i) {
101 elements_[i] = &data_[i * num_columns_];
102 }
103 }
104
105 Matrix& CopyFromColumn(const T* const* src, int column_index, int num_rows) {
106 Resize(1, num_rows);
107 for (int i = 0; i < num_columns_; ++i) {
108 data_[i] = src[i][column_index];
109 }
110
111 return *this;
112 }
113
114 void Resize(int num_rows, int num_columns) {
115 num_rows_ = num_rows;
116 num_columns_ = num_columns;
117 Resize();
118 }
119
120 // Accessors and mutators.
121 int num_rows() const { return num_rows_; }
122 int num_columns() const { return num_columns_; }
123 T* const* elements() { return &elements_[0]; }
124 const T* const* elements() const { return &elements_[0]; }
125
126 T Trace() {
127 CHECK_EQ(num_rows_, num_columns_);
128
129 T trace = 0;
130 for (int i = 0; i < num_rows_; ++i) {
131 trace += elements_[i][i];
132 }
133 return trace;
134 }
135
136 // Matrix Operations. Returns *this to support method chaining.
137 Matrix& Transpose() {
138 CopyDataToScratch();
139 std::swap(num_rows_, num_columns_);
140 Resize();
141 return Transpose(scratch_elements());
142 }
143
144 Matrix& Transpose(const Matrix& operand) {
145 CHECK_EQ(operand.num_rows_, num_columns_);
146 CHECK_EQ(operand.num_columns_, num_rows_);
147
148 return Transpose(operand.elements());
149 }
150
151 template <typename S>
152 Matrix& Scale(const S& scalar) {
153 for (size_t i = 0; i < data_.size(); ++i) {
154 data_[i] *= scalar;
155 }
156
157 return *this;
158 }
159
160 template <typename S>
161 Matrix& Scale(const Matrix& operand, const S& scalar) {
162 CopyFrom(operand);
163 return Scale(scalar);
164 }
165
166 Matrix& Add(const Matrix& operand) {
167 CHECK_EQ(num_rows_, operand.num_rows_);
168 CHECK_EQ(num_columns_, operand.num_columns_);
169
170 for (size_t i = 0; i < data_.size(); ++i) {
171 data_[i] += operand.data_[i];
172 }
173
174 return *this;
175 }
176
177 Matrix& Add(const Matrix& lhs, const Matrix& rhs) {
178 CopyFrom(lhs);
179 return Add(rhs);
180 }
181
182 Matrix& Subtract(const Matrix& operand) {
183 CHECK_EQ(num_rows_, operand.num_rows_);
184 CHECK_EQ(num_columns_, operand.num_columns_);
185
186 for (size_t i = 0; i < data_.size(); ++i) {
187 data_[i] -= operand.data_[i];
188 }
189
190 return *this;
191 }
192
193 Matrix& Subtract(const Matrix& lhs, const Matrix& rhs) {
194 CopyFrom(lhs);
195 return Subtract(rhs);
196 }
197
198 Matrix& PointwiseMultiply(const Matrix& operand) {
199 CHECK_EQ(num_rows_, operand.num_rows_);
200 CHECK_EQ(num_columns_, operand.num_columns_);
201
202 for (size_t i = 0; i < data_.size(); ++i) {
203 data_[i] *= operand.data_[i];
204 }
205
206 return *this;
207 }
208
209 Matrix& PointwiseMultiply(const Matrix& lhs, const Matrix& rhs) {
210 CopyFrom(lhs);
211 return PointwiseMultiply(rhs);
212 }
213
214 Matrix& PointwiseDivide(const Matrix& operand) {
215 CHECK_EQ(num_rows_, operand.num_rows_);
216 CHECK_EQ(num_columns_, operand.num_columns_);
217
218 for (size_t i = 0; i < data_.size(); ++i) {
219 data_[i] /= operand.data_[i];
220 }
221
222 return *this;
223 }
224
225 Matrix& PointwiseDivide(const Matrix& lhs, const Matrix& rhs) {
226 CopyFrom(lhs);
227 return PointwiseDivide(rhs);
228 }
229
230 Matrix& PointwiseSquareRoot() {
231 for (size_t i = 0; i < data_.size(); ++i) {
232 data_[i] = sqrt_wrapper(data_[i]);
233 }
234
235 return *this;
236 }
237
238 Matrix& PointwiseSquareRoot(const Matrix& operand) {
239 CopyFrom(operand);
240 return PointwiseSquareRoot();
241 }
242
243 Matrix& PointwiseAbsoluteValue() {
244 for (size_t i = 0; i < data_.size(); ++i) {
245 data_[i] = abs(data_[i]);
246 }
247
248 return *this;
249 }
250
251 Matrix& PointwiseAbsoluteValue(const Matrix& operand) {
252 CopyFrom(operand);
253 return PointwiseAbsoluteValue();
254 }
255
256 Matrix& PointwiseSquare() {
257 for (size_t i = 0; i < data_.size(); ++i) {
258 data_[i] *= data_[i];
259 }
260
261 return *this;
262 }
263
264 Matrix& PointwiseSquare(const Matrix& operand) {
265 CopyFrom(operand);
266 return PointwiseSquare();
267 }
268
269 Matrix& Multiply(const Matrix& lhs, const Matrix& rhs) {
270 CHECK_EQ(lhs.num_columns_, rhs.num_rows_);
271 CHECK_EQ(num_rows_, lhs.num_rows_);
272 CHECK_EQ(num_columns_, rhs.num_columns_);
273
274 return Multiply(lhs.elements(), rhs.num_rows_, rhs.elements());
275 }
276
277 Matrix& Multiply(const Matrix& rhs) {
278 CHECK_EQ(num_columns_, rhs.num_rows_);
279
280 CopyDataToScratch();
281 num_columns_ = rhs.num_columns_;
282 Resize();
283 return Multiply(scratch_elements(), rhs.num_rows_, rhs.elements());
284 }
285
286 std::string ToString() const {
287 std::ostringstream ss;
288 ss << std::endl << "Matrix" << std::endl;
289
290 for (int i = 0; i < num_rows_; ++i) {
291 for (int j = 0; j < num_columns_; ++j) {
292 ss << elements_[i][j] << " ";
293 }
294 ss << std::endl;
295 }
296 ss << std::endl;
297
298 return ss.str();
299 }
300
301 protected:
302 void SetNumRows(const int num_rows) { num_rows_ = num_rows; }
303 void SetNumColumns(const int num_columns) { num_columns_ = num_columns; }
304 T* data() { return &data_[0]; }
305 const T* data() const { return &data_[0]; }
306 const T* const* scratch_elements() const { return &scratch_elements_[0]; }
307
308 // Resize the matrix. If an increase in capacity is required, the current
309 // data is lost.
310 void Resize() {
311 size_t size = num_rows_ * num_columns_;
312 data_.resize(size);
313 elements_.resize(num_rows_);
314
315 for (int i = 0; i < num_rows_; ++i) {
316 elements_[i] = &data_[i * num_columns_];
317 }
318 }
319
320 // Copies data_ into scratch_data_ and updates scratch_elements_ accordingly.
321 void CopyDataToScratch() {
322 scratch_data_ = data_;
323 scratch_elements_.resize(num_rows_);
324
325 for (int i = 0; i < num_rows_; ++i) {
326 scratch_elements_[i] = &scratch_data_[i * num_columns_];
327 }
328 }
329
330 private:
331 int num_rows_;
332 int num_columns_;
333 std::vector<T> data_;
334 std::vector<T*> elements_;
335
336 // Stores temporary copies of |data_| and |elements_| for in-place operations
337 // where referring to original data is necessary.
338 std::vector<T> scratch_data_;
339 std::vector<T*> scratch_elements_;
340
341 // Helpers for Transpose and Multiply operations that unify in-place and
342 // out-of-place solutions.
343 Matrix& Transpose(const T* const* src) {
344 for (int i = 0; i < num_rows_; ++i) {
345 for (int j = 0; j < num_columns_; ++j) {
346 elements_[i][j] = src[j][i];
347 }
348 }
349
350 return *this;
351 }
352
353 Matrix& Multiply(const T* const* lhs, int num_rows_rhs, const T* const* rhs) {
354 for (int row = 0; row < num_rows_; ++row) {
355 for (int col = 0; col < num_columns_; ++col) {
356 T cur_element = 0;
357 for (int i = 0; i < num_rows_rhs; ++i) {
358 cur_element += lhs[row][i] * rhs[i][col];
359 }
360
361 elements_[row][col] = cur_element;
362 }
363 }
364
365 return *this;
366 }
367
368 DISALLOW_COPY_AND_ASSIGN(Matrix);
369};
370
371} // namespace webrtc
372
373#endif // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_