blob: 255331b672ed4ecd3c2f95dd981afc51ac87ef4f [file] [log] [blame]
peah5e79b292017-04-12 01:20:45 -07001/*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
Mirko Bonadei92ea95e2017-09-15 06:47:31 +020011#ifndef MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_
12#define MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_
peah5e79b292017-04-12 01:20:45 -070013
Niels Möllera12c42a2018-07-25 16:05:48 +020014// Defines WEBRTC_ARCH_X86_FAMILY, used below.
15#include "rtc_base/system/arch.h"
16
peah5d153c72017-05-03 06:45:44 -070017#if defined(WEBRTC_HAS_NEON)
18#include <arm_neon.h>
19#endif
peah5e79b292017-04-12 01:20:45 -070020#if defined(WEBRTC_ARCH_X86_FAMILY)
21#include <emmintrin.h>
22#endif
23#include <math.h>
24#include <algorithm>
25#include <array>
26#include <functional>
27
Mirko Bonadei92ea95e2017-09-15 06:47:31 +020028#include "api/array_view.h"
29#include "modules/audio_processing/aec3/aec3_common.h"
30#include "rtc_base/checks.h"
peah5e79b292017-04-12 01:20:45 -070031
32namespace webrtc {
33namespace aec3 {
34
35// Provides optimizations for mathematical operations based on vectors.
36class VectorMath {
37 public:
38 explicit VectorMath(Aec3Optimization optimization)
39 : optimization_(optimization) {}
40
41 // Elementwise square root.
42 void Sqrt(rtc::ArrayView<float> x) {
43 switch (optimization_) {
44#if defined(WEBRTC_ARCH_X86_FAMILY)
45 case Aec3Optimization::kSse2: {
46 const int x_size = static_cast<int>(x.size());
47 const int vector_limit = x_size >> 2;
48
49 int j = 0;
50 for (; j < vector_limit * 4; j += 4) {
51 __m128 g = _mm_loadu_ps(&x[j]);
52 g = _mm_sqrt_ps(g);
53 _mm_storeu_ps(&x[j], g);
54 }
55
56 for (; j < x_size; ++j) {
57 x[j] = sqrtf(x[j]);
58 }
59 } break;
60#endif
peah5d153c72017-05-03 06:45:44 -070061#if defined(WEBRTC_HAS_NEON)
62 case Aec3Optimization::kNeon: {
63 const int x_size = static_cast<int>(x.size());
64 const int vector_limit = x_size >> 2;
65
66 int j = 0;
67 for (; j < vector_limit * 4; j += 4) {
68 float32x4_t g = vld1q_f32(&x[j]);
69#if !defined(WEBRTC_ARCH_ARM64)
70 float32x4_t y = vrsqrteq_f32(g);
71
72 // Code to handle sqrt(0).
73 // If the input to sqrtf() is zero, a zero will be returned.
74 // If the input to vrsqrteq_f32() is zero, positive infinity is
75 // returned.
76 const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000);
77 // check for divide by zero
78 const uint32x4_t div_by_zero =
79 vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(y));
80 // zero out the positive infinity results
81 y = vreinterpretq_f32_u32(
82 vandq_u32(vmvnq_u32(div_by_zero), vreinterpretq_u32_f32(y)));
83 // from arm documentation
84 // The Newton-Raphson iteration:
85 // y[n+1] = y[n] * (3 - d * (y[n] * y[n])) / 2)
86 // converges to (1/√d) if y0 is the result of VRSQRTE applied to d.
87 //
88 // Note: The precision did not improve after 2 iterations.
89 for (int i = 0; i < 2; i++) {
90 y = vmulq_f32(vrsqrtsq_f32(vmulq_f32(y, y), g), y);
91 }
92 // sqrt(g) = g * 1/sqrt(g)
93 g = vmulq_f32(g, y);
94#else
95 g = vsqrtq_f32(g);
96#endif
97 vst1q_f32(&x[j], g);
98 }
99
100 for (; j < x_size; ++j) {
101 x[j] = sqrtf(x[j]);
102 }
103 }
104#endif
105 break;
peah5e79b292017-04-12 01:20:45 -0700106 default:
107 std::for_each(x.begin(), x.end(), [](float& a) { a = sqrtf(a); });
108 }
109 }
110
111 // Elementwise vector multiplication z = x * y.
112 void Multiply(rtc::ArrayView<const float> x,
113 rtc::ArrayView<const float> y,
114 rtc::ArrayView<float> z) {
115 RTC_DCHECK_EQ(z.size(), x.size());
116 RTC_DCHECK_EQ(z.size(), y.size());
117 switch (optimization_) {
118#if defined(WEBRTC_ARCH_X86_FAMILY)
119 case Aec3Optimization::kSse2: {
120 const int x_size = static_cast<int>(x.size());
121 const int vector_limit = x_size >> 2;
122
123 int j = 0;
124 for (; j < vector_limit * 4; j += 4) {
125 const __m128 x_j = _mm_loadu_ps(&x[j]);
126 const __m128 y_j = _mm_loadu_ps(&y[j]);
127 const __m128 z_j = _mm_mul_ps(x_j, y_j);
128 _mm_storeu_ps(&z[j], z_j);
129 }
130
131 for (; j < x_size; ++j) {
132 z[j] = x[j] * y[j];
133 }
134 } break;
135#endif
peah5d153c72017-05-03 06:45:44 -0700136#if defined(WEBRTC_HAS_NEON)
137 case Aec3Optimization::kNeon: {
138 const int x_size = static_cast<int>(x.size());
139 const int vector_limit = x_size >> 2;
140
141 int j = 0;
142 for (; j < vector_limit * 4; j += 4) {
143 const float32x4_t x_j = vld1q_f32(&x[j]);
144 const float32x4_t y_j = vld1q_f32(&y[j]);
145 const float32x4_t z_j = vmulq_f32(x_j, y_j);
146 vst1q_f32(&z[j], z_j);
147 }
148
149 for (; j < x_size; ++j) {
150 z[j] = x[j] * y[j];
151 }
152 } break;
153#endif
peah5e79b292017-04-12 01:20:45 -0700154 default:
155 std::transform(x.begin(), x.end(), y.begin(), z.begin(),
156 std::multiplies<float>());
157 }
158 }
159
160 // Elementwise vector accumulation z += x.
161 void Accumulate(rtc::ArrayView<const float> x, rtc::ArrayView<float> z) {
162 RTC_DCHECK_EQ(z.size(), x.size());
163 switch (optimization_) {
164#if defined(WEBRTC_ARCH_X86_FAMILY)
165 case Aec3Optimization::kSse2: {
166 const int x_size = static_cast<int>(x.size());
167 const int vector_limit = x_size >> 2;
168
169 int j = 0;
170 for (; j < vector_limit * 4; j += 4) {
171 const __m128 x_j = _mm_loadu_ps(&x[j]);
172 __m128 z_j = _mm_loadu_ps(&z[j]);
173 z_j = _mm_add_ps(x_j, z_j);
174 _mm_storeu_ps(&z[j], z_j);
175 }
176
177 for (; j < x_size; ++j) {
178 z[j] += x[j];
179 }
180 } break;
181#endif
peah5d153c72017-05-03 06:45:44 -0700182#if defined(WEBRTC_HAS_NEON)
183 case Aec3Optimization::kNeon: {
184 const int x_size = static_cast<int>(x.size());
185 const int vector_limit = x_size >> 2;
186
187 int j = 0;
188 for (; j < vector_limit * 4; j += 4) {
189 const float32x4_t x_j = vld1q_f32(&x[j]);
190 float32x4_t z_j = vld1q_f32(&z[j]);
191 z_j = vaddq_f32(z_j, x_j);
192 vst1q_f32(&z[j], z_j);
193 }
194
195 for (; j < x_size; ++j) {
196 z[j] += x[j];
197 }
198 } break;
199#endif
peah5e79b292017-04-12 01:20:45 -0700200 default:
201 std::transform(x.begin(), x.end(), z.begin(), z.begin(),
202 std::plus<float>());
203 }
204 }
205
206 private:
207 Aec3Optimization optimization_;
208};
209
210} // namespace aec3
211
212} // namespace webrtc
213
Mirko Bonadei92ea95e2017-09-15 06:47:31 +0200214#endif // MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_