Enable injection of a custom NetEqFactory into PeerConnectionFactory.

Injecting both a custom NetEqFactory and an AudioDecoderFactory is not
supported, in that case the AudioDecoderFactory should be wrapped inside
the NetEqFactory.

Bug: webrtc:11005
Change-Id: I4e311eb1bfa03c91bca587d70540e81829f881c9
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/158720
Commit-Queue: Ivo Creusen <ivoc@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29673}
diff --git a/api/BUILD.gn b/api/BUILD.gn
index 3321999..2ebf6e6 100644
--- a/api/BUILD.gn
+++ b/api/BUILD.gn
@@ -180,6 +180,7 @@
     "crypto:frame_decryptor_interface",
     "crypto:frame_encryptor_interface",
     "crypto:options",
+    "neteq:neteq_api",
     "rtc_event_log",
     "task_queue",
     "transport:bitrate_settings",
diff --git a/api/peer_connection_interface.h b/api/peer_connection_interface.h
index 55cc593..7567ab1 100644
--- a/api/peer_connection_interface.h
+++ b/api/peer_connection_interface.h
@@ -85,6 +85,7 @@
 #include "api/fec_controller.h"
 #include "api/jsep.h"
 #include "api/media_stream_interface.h"
+#include "api/neteq/neteq_factory.h"
 #include "api/network_state_predictor.h"
 #include "api/packet_socket_factory.h"
 #include "api/rtc_error.h"
@@ -1318,6 +1319,7 @@
       network_state_predictor_factory;
   std::unique_ptr<NetworkControllerFactoryInterface> network_controller_factory;
   std::unique_ptr<MediaTransportFactory> media_transport_factory;
+  std::unique_ptr<NetEqFactory> neteq_factory;
 };
 
 // PeerConnectionFactoryInterface is the factory interface used for creating
diff --git a/audio/BUILD.gn b/audio/BUILD.gn
index 927c948..e64b76f 100644
--- a/audio/BUILD.gn
+++ b/audio/BUILD.gn
@@ -50,6 +50,7 @@
     "../api/crypto:frame_decryptor_interface",
     "../api/crypto:frame_encryptor_interface",
     "../api/crypto:options",
+    "../api/neteq:neteq_api",
     "../api/rtc_event_log",
     "../api/task_queue",
     "../api/transport/media:media_transport_interface",
diff --git a/audio/audio_receive_stream.cc b/audio/audio_receive_stream.cc
index c6291c7..e1041be 100644
--- a/audio/audio_receive_stream.cc
+++ b/audio/audio_receive_stream.cc
@@ -70,13 +70,15 @@
     Clock* clock,
     webrtc::AudioState* audio_state,
     ProcessThread* module_process_thread,
+    NetEqFactory* neteq_factory,
     const webrtc::AudioReceiveStream::Config& config,
     RtcEventLog* event_log) {
   RTC_DCHECK(audio_state);
   internal::AudioState* internal_audio_state =
       static_cast<internal::AudioState*>(audio_state);
   return voe::CreateChannelReceive(
-      clock, module_process_thread, internal_audio_state->audio_device_module(),
+      clock, module_process_thread, neteq_factory,
+      internal_audio_state->audio_device_module(),
       config.media_transport_config, config.rtcp_send_transport, event_log,
       config.rtp.local_ssrc, config.rtp.remote_ssrc,
       config.jitter_buffer_max_packets, config.jitter_buffer_fast_accelerate,
@@ -91,6 +93,7 @@
     RtpStreamReceiverControllerInterface* receiver_controller,
     PacketRouter* packet_router,
     ProcessThread* module_process_thread,
+    NetEqFactory* neteq_factory,
     const webrtc::AudioReceiveStream::Config& config,
     const rtc::scoped_refptr<webrtc::AudioState>& audio_state,
     webrtc::RtcEventLog* event_log)
@@ -103,6 +106,7 @@
                          CreateChannelReceive(clock,
                                               audio_state.get(),
                                               module_process_thread,
+                                              neteq_factory,
                                               config,
                                               event_log)) {}
 
diff --git a/audio/audio_receive_stream.h b/audio/audio_receive_stream.h
index 26bcf63..24dcbf2 100644
--- a/audio/audio_receive_stream.h
+++ b/audio/audio_receive_stream.h
@@ -15,6 +15,7 @@
 #include <vector>
 
 #include "api/audio/audio_mixer.h"
+#include "api/neteq/neteq_factory.h"
 #include "api/rtp_headers.h"
 #include "audio/audio_state.h"
 #include "call/audio_receive_stream.h"
@@ -47,6 +48,7 @@
                      RtpStreamReceiverControllerInterface* receiver_controller,
                      PacketRouter* packet_router,
                      ProcessThread* module_process_thread,
+                     NetEqFactory* neteq_factory,
                      const webrtc::AudioReceiveStream::Config& config,
                      const rtc::scoped_refptr<webrtc::AudioState>& audio_state,
                      webrtc::RtcEventLog* event_log);
diff --git a/audio/channel_receive.cc b/audio/channel_receive.cc
index 7fe41a1..e19a49d 100644
--- a/audio/channel_receive.cc
+++ b/audio/channel_receive.cc
@@ -72,11 +72,13 @@
 }
 
 AudioCodingModule::Config AcmConfig(
+    NetEqFactory* neteq_factory,
     rtc::scoped_refptr<AudioDecoderFactory> decoder_factory,
     absl::optional<AudioCodecPairId> codec_pair_id,
     size_t jitter_buffer_max_packets,
     bool jitter_buffer_fast_playout) {
   AudioCodingModule::Config acm_config;
+  acm_config.neteq_factory = neteq_factory;
   acm_config.decoder_factory = decoder_factory;
   acm_config.neteq_config.codec_pair_id = codec_pair_id;
   acm_config.neteq_config.max_packets_in_buffer = jitter_buffer_max_packets;
@@ -92,6 +94,7 @@
   // Used for receive streams.
   ChannelReceive(Clock* clock,
                  ProcessThread* module_process_thread,
+                 NetEqFactory* neteq_factory,
                  AudioDeviceModule* audio_device_module,
                  const MediaTransportConfig& media_transport_config,
                  Transport* rtcp_send_transport,
@@ -453,6 +456,7 @@
 ChannelReceive::ChannelReceive(
     Clock* clock,
     ProcessThread* module_process_thread,
+    NetEqFactory* neteq_factory,
     AudioDeviceModule* audio_device_module,
     const MediaTransportConfig& media_transport_config,
     Transport* rtcp_send_transport,
@@ -470,7 +474,8 @@
     : event_log_(rtc_event_log),
       rtp_receive_statistics_(ReceiveStatistics::Create(clock)),
       remote_ssrc_(remote_ssrc),
-      acm_receiver_(AcmConfig(decoder_factory,
+      acm_receiver_(AcmConfig(neteq_factory,
+                              decoder_factory,
                               codec_pair_id,
                               jitter_buffer_max_packets,
                               jitter_buffer_fast_playout)),
@@ -964,6 +969,7 @@
 std::unique_ptr<ChannelReceiveInterface> CreateChannelReceive(
     Clock* clock,
     ProcessThread* module_process_thread,
+    NetEqFactory* neteq_factory,
     AudioDeviceModule* audio_device_module,
     const MediaTransportConfig& media_transport_config,
     Transport* rtcp_send_transport,
@@ -979,9 +985,9 @@
     rtc::scoped_refptr<FrameDecryptorInterface> frame_decryptor,
     const webrtc::CryptoOptions& crypto_options) {
   return std::make_unique<ChannelReceive>(
-      clock, module_process_thread, audio_device_module, media_transport_config,
-      rtcp_send_transport, rtc_event_log, local_ssrc, remote_ssrc,
-      jitter_buffer_max_packets, jitter_buffer_fast_playout,
+      clock, module_process_thread, neteq_factory, audio_device_module,
+      media_transport_config, rtcp_send_transport, rtc_event_log, local_ssrc,
+      remote_ssrc, jitter_buffer_max_packets, jitter_buffer_fast_playout,
       jitter_buffer_min_delay_ms, jitter_buffer_enable_rtx_handling,
       decoder_factory, codec_pair_id, frame_decryptor, crypto_options);
 }
diff --git a/audio/channel_receive.h b/audio/channel_receive.h
index fb79dc2..3cab489 100644
--- a/audio/channel_receive.h
+++ b/audio/channel_receive.h
@@ -22,6 +22,7 @@
 #include "api/call/audio_sink.h"
 #include "api/call/transport.h"
 #include "api/crypto/crypto_options.h"
+#include "api/neteq/neteq_factory.h"
 #include "api/transport/media/media_transport_config.h"
 #include "api/transport/media/media_transport_interface.h"
 #include "api/transport/rtp/rtp_source.h"
@@ -143,6 +144,7 @@
 std::unique_ptr<ChannelReceiveInterface> CreateChannelReceive(
     Clock* clock,
     ProcessThread* module_process_thread,
+    NetEqFactory* neteq_factory,
     AudioDeviceModule* audio_device_module,
     const MediaTransportConfig& media_transport_config,
     Transport* rtcp_send_transport,
diff --git a/audio/test/media_transport_test.cc b/audio/test/media_transport_test.cc
index 9646039..134a37b 100644
--- a/audio/test/media_transport_test.cc
+++ b/audio/test/media_transport_test.cc
@@ -117,8 +117,8 @@
   webrtc::internal::AudioReceiveStream receive_stream(
       Clock::GetRealTimeClock(),
       /*receiver_controller=*/nullptr,
-      /*packet_router=*/nullptr, receive_process_thread.get(), receive_config,
-      audio_state, &null_event_log);
+      /*packet_router=*/nullptr, receive_process_thread.get(),
+      /*neteq_factory=*/nullptr, receive_config, audio_state, &null_event_log);
 
   // TODO(nisse): Update AudioSendStream to not require send_transport when a
   // MediaTransport is provided.
diff --git a/call/BUILD.gn b/call/BUILD.gn
index 07e3645..94bb6ce 100644
--- a/call/BUILD.gn
+++ b/call/BUILD.gn
@@ -42,6 +42,7 @@
     "../api/crypto:frame_decryptor_interface",
     "../api/crypto:frame_encryptor_interface",
     "../api/crypto:options",
+    "../api/neteq:neteq_api",
     "../api/task_queue",
     "../api/transport:bitrate_settings",
     "../api/transport:network_control",
diff --git a/call/call.cc b/call/call.cc
index 971ebbd..4402f18 100644
--- a/call/call.cc
+++ b/call/call.cc
@@ -684,7 +684,8 @@
       CreateRtcLogStreamConfig(config)));
   AudioReceiveStream* receive_stream = new AudioReceiveStream(
       clock_, &audio_receiver_controller_, transport_send_ptr_->packet_router(),
-      module_process_thread_.get(), config, config_.audio_state, event_log_);
+      module_process_thread_.get(), config_.neteq_factory, config,
+      config_.audio_state, event_log_);
   {
     WriteLockScoped write_lock(*receive_crit_);
     receive_rtp_config_.emplace(config.rtp.remote_ssrc,
diff --git a/call/call_config.h b/call/call_config.h
index 3129530..69d9e59 100644
--- a/call/call_config.h
+++ b/call/call_config.h
@@ -11,6 +11,7 @@
 #define CALL_CALL_CONFIG_H_
 
 #include "api/fec_controller.h"
+#include "api/neteq/neteq_factory.h"
 #include "api/network_state_predictor.h"
 #include "api/rtc_error.h"
 #include "api/task_queue/task_queue_factory.h"
@@ -56,6 +57,9 @@
 
   // Network controller factory to use for this call.
   NetworkControllerFactoryInterface* network_controller_factory = nullptr;
+
+  // NetEq factory to use for this call.
+  NetEqFactory* neteq_factory = nullptr;
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_coding/acm2/acm_receiver.cc b/modules/audio_coding/acm2/acm_receiver.cc
index 2723937..9783fc8 100644
--- a/modules/audio_coding/acm2/acm_receiver.cc
+++ b/modules/audio_coding/acm2/acm_receiver.cc
@@ -37,19 +37,28 @@
 namespace {
 
 std::unique_ptr<NetEq> CreateNetEq(
+    NetEqFactory* neteq_factory,
     const NetEq::Config& config,
     Clock* clock,
     const rtc::scoped_refptr<AudioDecoderFactory>& decoder_factory) {
-  CustomNetEqFactory neteq_factory(
+  RTC_CHECK((neteq_factory == nullptr) || (decoder_factory.get() == nullptr))
+      << "Either a NetEqFactory or a AudioDecoderFactory should be injected, "
+         "supplying both is not supported. Please wrap the AudioDecoderFactory "
+         "inside the NetEqFactory when using both.";
+  if (neteq_factory) {
+    return neteq_factory->CreateNetEq(config, clock);
+  }
+  CustomNetEqFactory custom_factory(
       decoder_factory, std::make_unique<DefaultNetEqControllerFactory>());
-  return neteq_factory.CreateNetEq(config, clock);
+  return custom_factory.CreateNetEq(config, clock);
 }
 
 }  // namespace
 
 AcmReceiver::AcmReceiver(const AudioCodingModule::Config& config)
     : last_audio_buffer_(new int16_t[AudioFrame::kMaxDataSizeSamples]),
-      neteq_(CreateNetEq(config.neteq_config,
+      neteq_(CreateNetEq(config.neteq_factory,
+                         config.neteq_config,
                          config.clock,
                          config.decoder_factory)),
       clock_(config.clock),
diff --git a/modules/audio_coding/include/audio_coding_module.h b/modules/audio_coding/include/audio_coding_module.h
index 05d9380..d8c9260 100644
--- a/modules/audio_coding/include/audio_coding_module.h
+++ b/modules/audio_coding/include/audio_coding_module.h
@@ -21,6 +21,7 @@
 #include "api/audio_codecs/audio_encoder.h"
 #include "api/function_view.h"
 #include "api/neteq/neteq.h"
+#include "api/neteq/neteq_factory.h"
 #include "modules/audio_coding/include/audio_coding_module_typedefs.h"
 #include "system_wrappers/include/clock.h"
 
@@ -68,6 +69,7 @@
     NetEq::Config neteq_config;
     Clock* clock;
     rtc::scoped_refptr<AudioDecoderFactory> decoder_factory;
+    NetEqFactory* neteq_factory = nullptr;
   };
 
   static AudioCodingModule* Create(const Config& config);
diff --git a/pc/peer_connection_factory.cc b/pc/peer_connection_factory.cc
index 0800718..a1a9f04 100644
--- a/pc/peer_connection_factory.cc
+++ b/pc/peer_connection_factory.cc
@@ -79,8 +79,8 @@
           std::move(dependencies.network_state_predictor_factory)),
       injected_network_controller_factory_(
           std::move(dependencies.network_controller_factory)),
-      media_transport_factory_(
-          std::move(dependencies.media_transport_factory)) {
+      media_transport_factory_(std::move(dependencies.media_transport_factory)),
+      neteq_factory_(std::move(dependencies.neteq_factory)) {
   if (!network_thread_) {
     owned_network_thread_ = rtc::Thread::CreateWithSocketServer();
     owned_network_thread_->SetName("pc_network_thread", nullptr);
@@ -371,6 +371,7 @@
   call_config.task_queue_factory = task_queue_factory_.get();
   call_config.network_state_predictor_factory =
       network_state_predictor_factory_.get();
+  call_config.neteq_factory = neteq_factory_.get();
 
   if (field_trial::IsEnabled("WebRTC-Bwe-InjectedCongestionController")) {
     RTC_LOG(LS_INFO) << "Using injected network controller factory";
diff --git a/pc/peer_connection_factory.h b/pc/peer_connection_factory.h
index 648a3af..5886dee 100644
--- a/pc/peer_connection_factory.h
+++ b/pc/peer_connection_factory.h
@@ -127,6 +127,7 @@
   std::unique_ptr<NetworkControllerFactoryInterface>
       injected_network_controller_factory_;
   std::unique_ptr<MediaTransportFactory> media_transport_factory_;
+  std::unique_ptr<NetEqFactory> neteq_factory_;
 };
 
 }  // namespace webrtc