blob: 0672b513f3c0fb6a3645ea26bfbe69e9ad41413b [file] [log] [blame]
peah5e79b292017-04-12 01:20:45 -07001/*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
Mirko Bonadei92ea95e2017-09-15 06:47:31 +020011#ifndef MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_
12#define MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_
peah5e79b292017-04-12 01:20:45 -070013
Mirko Bonadei71207422017-09-15 13:58:09 +020014#include "typedefs.h" // NOLINT(build/include)
peah5d153c72017-05-03 06:45:44 -070015#if defined(WEBRTC_HAS_NEON)
16#include <arm_neon.h>
17#endif
peah5e79b292017-04-12 01:20:45 -070018#if defined(WEBRTC_ARCH_X86_FAMILY)
19#include <emmintrin.h>
20#endif
21#include <math.h>
22#include <algorithm>
23#include <array>
24#include <functional>
25
Mirko Bonadei92ea95e2017-09-15 06:47:31 +020026#include "api/array_view.h"
27#include "modules/audio_processing/aec3/aec3_common.h"
28#include "rtc_base/checks.h"
peah5e79b292017-04-12 01:20:45 -070029
30namespace webrtc {
31namespace aec3 {
32
33// Provides optimizations for mathematical operations based on vectors.
34class VectorMath {
35 public:
36 explicit VectorMath(Aec3Optimization optimization)
37 : optimization_(optimization) {}
38
39 // Elementwise square root.
40 void Sqrt(rtc::ArrayView<float> x) {
41 switch (optimization_) {
42#if defined(WEBRTC_ARCH_X86_FAMILY)
43 case Aec3Optimization::kSse2: {
44 const int x_size = static_cast<int>(x.size());
45 const int vector_limit = x_size >> 2;
46
47 int j = 0;
48 for (; j < vector_limit * 4; j += 4) {
49 __m128 g = _mm_loadu_ps(&x[j]);
50 g = _mm_sqrt_ps(g);
51 _mm_storeu_ps(&x[j], g);
52 }
53
54 for (; j < x_size; ++j) {
55 x[j] = sqrtf(x[j]);
56 }
57 } break;
58#endif
peah5d153c72017-05-03 06:45:44 -070059#if defined(WEBRTC_HAS_NEON)
60 case Aec3Optimization::kNeon: {
61 const int x_size = static_cast<int>(x.size());
62 const int vector_limit = x_size >> 2;
63
64 int j = 0;
65 for (; j < vector_limit * 4; j += 4) {
66 float32x4_t g = vld1q_f32(&x[j]);
67#if !defined(WEBRTC_ARCH_ARM64)
68 float32x4_t y = vrsqrteq_f32(g);
69
70 // Code to handle sqrt(0).
71 // If the input to sqrtf() is zero, a zero will be returned.
72 // If the input to vrsqrteq_f32() is zero, positive infinity is
73 // returned.
74 const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000);
75 // check for divide by zero
76 const uint32x4_t div_by_zero =
77 vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(y));
78 // zero out the positive infinity results
79 y = vreinterpretq_f32_u32(
80 vandq_u32(vmvnq_u32(div_by_zero), vreinterpretq_u32_f32(y)));
81 // from arm documentation
82 // The Newton-Raphson iteration:
83 // y[n+1] = y[n] * (3 - d * (y[n] * y[n])) / 2)
84 // converges to (1/√d) if y0 is the result of VRSQRTE applied to d.
85 //
86 // Note: The precision did not improve after 2 iterations.
87 for (int i = 0; i < 2; i++) {
88 y = vmulq_f32(vrsqrtsq_f32(vmulq_f32(y, y), g), y);
89 }
90 // sqrt(g) = g * 1/sqrt(g)
91 g = vmulq_f32(g, y);
92#else
93 g = vsqrtq_f32(g);
94#endif
95 vst1q_f32(&x[j], g);
96 }
97
98 for (; j < x_size; ++j) {
99 x[j] = sqrtf(x[j]);
100 }
101 }
102#endif
103 break;
peah5e79b292017-04-12 01:20:45 -0700104 default:
105 std::for_each(x.begin(), x.end(), [](float& a) { a = sqrtf(a); });
106 }
107 }
108
109 // Elementwise vector multiplication z = x * y.
110 void Multiply(rtc::ArrayView<const float> x,
111 rtc::ArrayView<const float> y,
112 rtc::ArrayView<float> z) {
113 RTC_DCHECK_EQ(z.size(), x.size());
114 RTC_DCHECK_EQ(z.size(), y.size());
115 switch (optimization_) {
116#if defined(WEBRTC_ARCH_X86_FAMILY)
117 case Aec3Optimization::kSse2: {
118 const int x_size = static_cast<int>(x.size());
119 const int vector_limit = x_size >> 2;
120
121 int j = 0;
122 for (; j < vector_limit * 4; j += 4) {
123 const __m128 x_j = _mm_loadu_ps(&x[j]);
124 const __m128 y_j = _mm_loadu_ps(&y[j]);
125 const __m128 z_j = _mm_mul_ps(x_j, y_j);
126 _mm_storeu_ps(&z[j], z_j);
127 }
128
129 for (; j < x_size; ++j) {
130 z[j] = x[j] * y[j];
131 }
132 } break;
133#endif
peah5d153c72017-05-03 06:45:44 -0700134#if defined(WEBRTC_HAS_NEON)
135 case Aec3Optimization::kNeon: {
136 const int x_size = static_cast<int>(x.size());
137 const int vector_limit = x_size >> 2;
138
139 int j = 0;
140 for (; j < vector_limit * 4; j += 4) {
141 const float32x4_t x_j = vld1q_f32(&x[j]);
142 const float32x4_t y_j = vld1q_f32(&y[j]);
143 const float32x4_t z_j = vmulq_f32(x_j, y_j);
144 vst1q_f32(&z[j], z_j);
145 }
146
147 for (; j < x_size; ++j) {
148 z[j] = x[j] * y[j];
149 }
150 } break;
151#endif
peah5e79b292017-04-12 01:20:45 -0700152 default:
153 std::transform(x.begin(), x.end(), y.begin(), z.begin(),
154 std::multiplies<float>());
155 }
156 }
157
158 // Elementwise vector accumulation z += x.
159 void Accumulate(rtc::ArrayView<const float> x, rtc::ArrayView<float> z) {
160 RTC_DCHECK_EQ(z.size(), x.size());
161 switch (optimization_) {
162#if defined(WEBRTC_ARCH_X86_FAMILY)
163 case Aec3Optimization::kSse2: {
164 const int x_size = static_cast<int>(x.size());
165 const int vector_limit = x_size >> 2;
166
167 int j = 0;
168 for (; j < vector_limit * 4; j += 4) {
169 const __m128 x_j = _mm_loadu_ps(&x[j]);
170 __m128 z_j = _mm_loadu_ps(&z[j]);
171 z_j = _mm_add_ps(x_j, z_j);
172 _mm_storeu_ps(&z[j], z_j);
173 }
174
175 for (; j < x_size; ++j) {
176 z[j] += x[j];
177 }
178 } break;
179#endif
peah5d153c72017-05-03 06:45:44 -0700180#if defined(WEBRTC_HAS_NEON)
181 case Aec3Optimization::kNeon: {
182 const int x_size = static_cast<int>(x.size());
183 const int vector_limit = x_size >> 2;
184
185 int j = 0;
186 for (; j < vector_limit * 4; j += 4) {
187 const float32x4_t x_j = vld1q_f32(&x[j]);
188 float32x4_t z_j = vld1q_f32(&z[j]);
189 z_j = vaddq_f32(z_j, x_j);
190 vst1q_f32(&z[j], z_j);
191 }
192
193 for (; j < x_size; ++j) {
194 z[j] += x[j];
195 }
196 } break;
197#endif
peah5e79b292017-04-12 01:20:45 -0700198 default:
199 std::transform(x.begin(), x.end(), z.begin(), z.begin(),
200 std::plus<float>());
201 }
202 }
203
204 private:
205 Aec3Optimization optimization_;
206};
207
208} // namespace aec3
209
210} // namespace webrtc
211
Mirko Bonadei92ea95e2017-09-15 06:47:31 +0200212#endif // MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_