blob: 32cb481594d5ea8d743a5300693ff95584bf5c83 [file] [log] [blame]
Magnus Jedvert1927dfa2018-09-11 12:56:06 +02001/*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "rtc_tools/frame_analyzer/linear_least_squares.h"
12
Yves Gerey3e707812018-11-28 16:47:49 +010013#include <math.h>
Jonas Olssona4d87372019-07-05 19:08:33 +020014
Yves Gerey3e707812018-11-28 16:47:49 +010015#include <cstdint>
16#include <cstdlib>
17#include <functional>
Magnus Jedvert1927dfa2018-09-11 12:56:06 +020018#include <numeric>
Yves Gerey3e707812018-11-28 16:47:49 +010019#include <type_traits>
Magnus Jedvert1927dfa2018-09-11 12:56:06 +020020#include <utility>
21
22#include "rtc_base/checks.h"
23#include "rtc_base/logging.h"
24
25namespace webrtc {
26namespace test {
27
// Dense matrix stored as a valarray of valarrays. By convention in this file
// each inner valarray holds one column (callers iterate the inner elements as
// |column|), so m[c][r] is the entry in column c, row r. All inner valarrays
// are assumed to have the same length.
template <class T>
using Matrix = std::valarray<std::valarray<T>>;
30
31namespace {
32
// Returns the dot product of |a| and |b|, which must have equal length.
// Each element is converted to the result type R before multiplying, so the
// products are computed in R rather than in T (or the type T promotes to).
// Without this, T = uint16_t operands are promoted to int, where e.g.
// 65535 * 65535 overflows (undefined behavior). Likewise, T = uint32_t
// products would silently wrap before being widened to R.
template <typename R, typename T>
R DotProduct(const std::valarray<T>& a, const std::valarray<T>& b) {
  RTC_CHECK_EQ(a.size(), b.size());
  return std::inner_product(
      std::begin(a), std::end(a), std::begin(b), R(0), std::plus<R>(),
      [](const T& lhs, const T& rhs) -> R {
        return static_cast<R>(lhs) * static_cast<R>(rhs);
      });
}
37}
38
39// Calculates a^T * b.
40template <typename R, typename T>
41Matrix<R> MatrixMultiply(const Matrix<T>& a, const Matrix<T>& b) {
42 Matrix<R> result(std::valarray<R>(a.size()), b.size());
43 for (size_t i = 0; i < a.size(); ++i) {
44 for (size_t j = 0; j < b.size(); ++j)
45 result[j][i] = DotProduct<R>(a[i], b[j]);
46 }
47
48 return result;
49}
50
51template <typename T>
52Matrix<T> Transpose(const Matrix<T>& matrix) {
53 if (matrix.size() == 0)
54 return Matrix<T>();
55 const size_t rows = matrix.size();
56 const size_t columns = matrix[0].size();
57 Matrix<T> result(std::valarray<T>(rows), columns);
58
59 for (size_t i = 0; i < rows; ++i) {
60 for (size_t j = 0; j < columns; ++j)
61 result[j][i] = matrix[i][j];
62 }
63
64 return result;
65}
66
// Returns a copy of |v| in which every element has been static_cast from
// type T to type R.
template <typename R, typename T>
std::valarray<R> ConvertTo(const std::valarray<T>& v) {
  std::valarray<R> converted(v.size());
  size_t index = 0;
  for (const T& element : v) {
    converted[index] = static_cast<R>(element);
    ++index;
  }
  return converted;
}
74}
75
76// Convert valarray Matrix from type T to type R.
77template <typename R, typename T>
78Matrix<R> ConvertTo(const Matrix<T>& mat) {
79 Matrix<R> result(mat.size());
80 for (size_t i = 0; i < mat.size(); ++i)
81 result[i] = ConvertTo<R>(mat[i]);
82 return result;
83}
84
85// Convert from valarray Matrix back to the more conventional std::vector.
86template <typename T>
87std::vector<std::vector<T>> ToVectorMatrix(const Matrix<T>& m) {
88 std::vector<std::vector<T>> result;
89 for (const std::valarray<T>& v : m)
90 result.emplace_back(std::begin(v), std::end(v));
91 return result;
92}
93
94// Create a valarray Matrix from a conventional std::vector.
95template <typename T>
96Matrix<T> FromVectorMatrix(const std::vector<std::vector<T>>& mat) {
97 Matrix<T> result(mat.size());
98 for (size_t i = 0; i < mat.size(); ++i)
99 result[i] = std::valarray<T>(mat[i].data(), mat[i].size());
100 return result;
101}
102
103// Returns |matrix_to_invert|^-1 * |right_hand_matrix|. |matrix_to_invert| must
104// have square size.
105Matrix<double> GaussianElimination(Matrix<double> matrix_to_invert,
106 Matrix<double> right_hand_matrix) {
107 // |n| is the width/height of |matrix_to_invert|.
108 const size_t n = matrix_to_invert.size();
109 // Make sure |matrix_to_invert| has square size.
110 for (const std::valarray<double>& column : matrix_to_invert)
111 RTC_CHECK_EQ(n, column.size());
112 // Make sure |right_hand_matrix| has correct size.
113 for (const std::valarray<double>& column : right_hand_matrix)
114 RTC_CHECK_EQ(n, column.size());
115
116 // Transpose the matrices before and after so that we can perform Gaussian
117 // elimination on the columns instead of the rows, since that is easier with
118 // our representation.
119 matrix_to_invert = Transpose(matrix_to_invert);
120 right_hand_matrix = Transpose(right_hand_matrix);
121
122 // Loop over the diagonal of |matrix_to_invert| and perform column reduction.
123 // Column reduction is a sequence of elementary column operations that is
124 // performed on both |matrix_to_invert| and |right_hand_matrix| until
125 // |matrix_to_invert| has been transformed to the identity matrix.
126 for (size_t diagonal_index = 0; diagonal_index < n; ++diagonal_index) {
127 // Make sure the diagonal element has the highest absolute value by
128 // swapping columns if necessary.
129 for (size_t column = diagonal_index + 1; column < n; ++column) {
130 if (std::abs(matrix_to_invert[column][diagonal_index]) >
131 std::abs(matrix_to_invert[diagonal_index][diagonal_index])) {
132 std::swap(matrix_to_invert[column], matrix_to_invert[diagonal_index]);
133 std::swap(right_hand_matrix[column], right_hand_matrix[diagonal_index]);
134 }
135 }
136
137 // Reduce the diagonal element to be 1, by dividing the column with that
138 // value. If the diagonal element is 0, it means the system of equations has
139 // many solutions, and in that case we will return an arbitrary solution.
140 if (matrix_to_invert[diagonal_index][diagonal_index] == 0.0) {
141 RTC_LOG(LS_WARNING) << "Matrix is not invertible, ignoring.";
142 continue;
143 }
144 const double diagonal_element =
145 matrix_to_invert[diagonal_index][diagonal_index];
146 matrix_to_invert[diagonal_index] /= diagonal_element;
147 right_hand_matrix[diagonal_index] /= diagonal_element;
148
149 // Eliminate the other entries in row |diagonal_index| by making them zero.
150 for (size_t column = 0; column < n; ++column) {
151 if (column == diagonal_index)
152 continue;
153 const double row_element = matrix_to_invert[column][diagonal_index];
154 matrix_to_invert[column] -=
155 row_element * matrix_to_invert[diagonal_index];
156 right_hand_matrix[column] -=
157 row_element * right_hand_matrix[diagonal_index];
158 }
159 }
160
161 // Transpose the result before returning it, explained in comment above.
162 return Transpose(right_hand_matrix);
163}
164
165} // namespace
166
167IncrementalLinearLeastSquares::IncrementalLinearLeastSquares() = default;
168IncrementalLinearLeastSquares::~IncrementalLinearLeastSquares() = default;
169
170void IncrementalLinearLeastSquares::AddObservations(
171 const std::vector<std::vector<uint8_t>>& x,
172 const std::vector<std::vector<uint8_t>>& y) {
173 if (x.empty() || y.empty())
174 return;
175 // Make sure all columns are the same size.
176 const size_t n = x[0].size();
177 for (const std::vector<uint8_t>& column : x)
178 RTC_CHECK_EQ(n, column.size());
179 for (const std::vector<uint8_t>& column : y)
180 RTC_CHECK_EQ(n, column.size());
181
182 // We will multiply the uint8_t values together, so we need to expand to a
183 // type that can safely store those values, i.e. uint16_t.
184 const Matrix<uint16_t> unpacked_x = ConvertTo<uint16_t>(FromVectorMatrix(x));
185 const Matrix<uint16_t> unpacked_y = ConvertTo<uint16_t>(FromVectorMatrix(y));
186
187 const Matrix<uint64_t> xx = MatrixMultiply<uint64_t>(unpacked_x, unpacked_x);
188 const Matrix<uint64_t> xy = MatrixMultiply<uint64_t>(unpacked_x, unpacked_y);
189 if (sum_xx && sum_xy) {
190 *sum_xx += xx;
191 *sum_xy += xy;
192 } else {
193 sum_xx = xx;
194 sum_xy = xy;
195 }
196}
197
198std::vector<std::vector<double>>
199IncrementalLinearLeastSquares::GetBestSolution() const {
200 RTC_CHECK(sum_xx && sum_xy) << "No observations have been added";
201 return ToVectorMatrix(GaussianElimination(ConvertTo<double>(*sum_xx),
202 ConvertTo<double>(*sum_xy)));
203}
204
205} // namespace test
206} // namespace webrtc