blob: 706d212d174abe9d76ce20039a0ba83e3fe842dd [file] [log] [blame]
/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10
#include "rtc_tools/frame_analyzer/linear_least_squares.h"

#include <cmath>
#include <cstdint>
#include <numeric>
#include <utility>
#include <valarray>
#include <vector>

#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
18
19namespace webrtc {
20namespace test {
21
// Matrix stored column by column: the outer valarray holds the columns and
// each inner valarray holds the elements of one column. Code in this file
// expects every column of a matrix to have the same length.
template <typename T>
using Matrix = std::valarray<std::valarray<T>>;
24
25namespace {
26
// Returns the dot product of |a| and |b| (which must have equal length),
// accumulated in type R. Each element product is computed in R as well, not in
// T's promoted type. Without this, two uint16_t operands would be multiplied
// as int and could overflow (undefined behavior) before reaching the R sum.
template <typename R, typename T>
R DotProduct(const std::valarray<T>& a, const std::valarray<T>& b) {
  RTC_CHECK_EQ(a.size(), b.size());
  return std::inner_product(
      std::begin(a), std::end(a), std::begin(b), R(0),
      [](const R& sum, const R& term) -> R { return sum + term; },
      [](const T& lhs, const T& rhs) -> R {
        return static_cast<R>(lhs) * static_cast<R>(rhs);
      });
}
32
33// Calculates a^T * b.
34template <typename R, typename T>
35Matrix<R> MatrixMultiply(const Matrix<T>& a, const Matrix<T>& b) {
36 Matrix<R> result(std::valarray<R>(a.size()), b.size());
37 for (size_t i = 0; i < a.size(); ++i) {
38 for (size_t j = 0; j < b.size(); ++j)
39 result[j][i] = DotProduct<R>(a[i], b[j]);
40 }
41
42 return result;
43}
44
45template <typename T>
46Matrix<T> Transpose(const Matrix<T>& matrix) {
47 if (matrix.size() == 0)
48 return Matrix<T>();
49 const size_t rows = matrix.size();
50 const size_t columns = matrix[0].size();
51 Matrix<T> result(std::valarray<T>(rows), columns);
52
53 for (size_t i = 0; i < rows; ++i) {
54 for (size_t j = 0; j < columns; ++j)
55 result[j][i] = matrix[i][j];
56 }
57
58 return result;
59}
60
// Returns a copy of |v| with every element static_cast from T to R.
template <typename R, typename T>
std::valarray<R> ConvertTo(const std::valarray<T>& v) {
  std::valarray<R> converted(v.size());
  size_t index = 0;
  for (const T& element : v)
    converted[index++] = static_cast<R>(element);
  return converted;
}
69
70// Convert valarray Matrix from type T to type R.
71template <typename R, typename T>
72Matrix<R> ConvertTo(const Matrix<T>& mat) {
73 Matrix<R> result(mat.size());
74 for (size_t i = 0; i < mat.size(); ++i)
75 result[i] = ConvertTo<R>(mat[i]);
76 return result;
77}
78
79// Convert from valarray Matrix back to the more conventional std::vector.
80template <typename T>
81std::vector<std::vector<T>> ToVectorMatrix(const Matrix<T>& m) {
82 std::vector<std::vector<T>> result;
83 for (const std::valarray<T>& v : m)
84 result.emplace_back(std::begin(v), std::end(v));
85 return result;
86}
87
88// Create a valarray Matrix from a conventional std::vector.
89template <typename T>
90Matrix<T> FromVectorMatrix(const std::vector<std::vector<T>>& mat) {
91 Matrix<T> result(mat.size());
92 for (size_t i = 0; i < mat.size(); ++i)
93 result[i] = std::valarray<T>(mat[i].data(), mat[i].size());
94 return result;
95}
96
97// Returns |matrix_to_invert|^-1 * |right_hand_matrix|. |matrix_to_invert| must
98// have square size.
99Matrix<double> GaussianElimination(Matrix<double> matrix_to_invert,
100 Matrix<double> right_hand_matrix) {
101 // |n| is the width/height of |matrix_to_invert|.
102 const size_t n = matrix_to_invert.size();
103 // Make sure |matrix_to_invert| has square size.
104 for (const std::valarray<double>& column : matrix_to_invert)
105 RTC_CHECK_EQ(n, column.size());
106 // Make sure |right_hand_matrix| has correct size.
107 for (const std::valarray<double>& column : right_hand_matrix)
108 RTC_CHECK_EQ(n, column.size());
109
110 // Transpose the matrices before and after so that we can perform Gaussian
111 // elimination on the columns instead of the rows, since that is easier with
112 // our representation.
113 matrix_to_invert = Transpose(matrix_to_invert);
114 right_hand_matrix = Transpose(right_hand_matrix);
115
116 // Loop over the diagonal of |matrix_to_invert| and perform column reduction.
117 // Column reduction is a sequence of elementary column operations that is
118 // performed on both |matrix_to_invert| and |right_hand_matrix| until
119 // |matrix_to_invert| has been transformed to the identity matrix.
120 for (size_t diagonal_index = 0; diagonal_index < n; ++diagonal_index) {
121 // Make sure the diagonal element has the highest absolute value by
122 // swapping columns if necessary.
123 for (size_t column = diagonal_index + 1; column < n; ++column) {
124 if (std::abs(matrix_to_invert[column][diagonal_index]) >
125 std::abs(matrix_to_invert[diagonal_index][diagonal_index])) {
126 std::swap(matrix_to_invert[column], matrix_to_invert[diagonal_index]);
127 std::swap(right_hand_matrix[column], right_hand_matrix[diagonal_index]);
128 }
129 }
130
131 // Reduce the diagonal element to be 1, by dividing the column with that
132 // value. If the diagonal element is 0, it means the system of equations has
133 // many solutions, and in that case we will return an arbitrary solution.
134 if (matrix_to_invert[diagonal_index][diagonal_index] == 0.0) {
135 RTC_LOG(LS_WARNING) << "Matrix is not invertible, ignoring.";
136 continue;
137 }
138 const double diagonal_element =
139 matrix_to_invert[diagonal_index][diagonal_index];
140 matrix_to_invert[diagonal_index] /= diagonal_element;
141 right_hand_matrix[diagonal_index] /= diagonal_element;
142
143 // Eliminate the other entries in row |diagonal_index| by making them zero.
144 for (size_t column = 0; column < n; ++column) {
145 if (column == diagonal_index)
146 continue;
147 const double row_element = matrix_to_invert[column][diagonal_index];
148 matrix_to_invert[column] -=
149 row_element * matrix_to_invert[diagonal_index];
150 right_hand_matrix[column] -=
151 row_element * right_hand_matrix[diagonal_index];
152 }
153 }
154
155 // Transpose the result before returning it, explained in comment above.
156 return Transpose(right_hand_matrix);
157}
158
159} // namespace
160
// The constructor and destructor are defaulted. They are defined out-of-line
// here rather than inline in the class declaration.
IncrementalLinearLeastSquares::IncrementalLinearLeastSquares() = default;
IncrementalLinearLeastSquares::~IncrementalLinearLeastSquares() = default;
163
164void IncrementalLinearLeastSquares::AddObservations(
165 const std::vector<std::vector<uint8_t>>& x,
166 const std::vector<std::vector<uint8_t>>& y) {
167 if (x.empty() || y.empty())
168 return;
169 // Make sure all columns are the same size.
170 const size_t n = x[0].size();
171 for (const std::vector<uint8_t>& column : x)
172 RTC_CHECK_EQ(n, column.size());
173 for (const std::vector<uint8_t>& column : y)
174 RTC_CHECK_EQ(n, column.size());
175
176 // We will multiply the uint8_t values together, so we need to expand to a
177 // type that can safely store those values, i.e. uint16_t.
178 const Matrix<uint16_t> unpacked_x = ConvertTo<uint16_t>(FromVectorMatrix(x));
179 const Matrix<uint16_t> unpacked_y = ConvertTo<uint16_t>(FromVectorMatrix(y));
180
181 const Matrix<uint64_t> xx = MatrixMultiply<uint64_t>(unpacked_x, unpacked_x);
182 const Matrix<uint64_t> xy = MatrixMultiply<uint64_t>(unpacked_x, unpacked_y);
183 if (sum_xx && sum_xy) {
184 *sum_xx += xx;
185 *sum_xy += xy;
186 } else {
187 sum_xx = xx;
188 sum_xy = xy;
189 }
190}
191
192std::vector<std::vector<double>>
193IncrementalLinearLeastSquares::GetBestSolution() const {
194 RTC_CHECK(sum_xx && sum_xy) << "No observations have been added";
195 return ToVectorMatrix(GaussianElimination(ConvertTo<double>(*sum_xx),
196 ConvertTo<double>(*sum_xy)));
197}
198
199} // namespace test
200} // namespace webrtc