blob: 7d8e3d77e86b9d8fc41d63a4249b525304013e82 [file] [log] [blame]
deadbeefcbecd352015-09-23 11:50:27 -07001/*
2 * Copyright 2009 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11#ifndef WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_
12#define WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_
13
14#include <map>
15#include <string>
16#include <vector>
17
18#include "webrtc/p2p/base/transport.h"
19#include "webrtc/p2p/base/transportchannel.h"
20#include "webrtc/p2p/base/transportcontroller.h"
21#include "webrtc/p2p/base/transportchannelimpl.h"
22#include "webrtc/base/bind.h"
23#include "webrtc/base/buffer.h"
24#include "webrtc/base/fakesslidentity.h"
25#include "webrtc/base/messagequeue.h"
26#include "webrtc/base/sigslot.h"
27#include "webrtc/base/sslfingerprint.h"
28#include "webrtc/base/thread.h"
29
30namespace cricket {
31
32class FakeTransport;
33
34struct PacketMessageData : public rtc::MessageData {
35 PacketMessageData(const char* data, size_t len) : packet(data, len) {}
36 rtc::Buffer packet;
37};
38
39// Fake transport channel class, which can be passed to anything that needs a
40// transport channel. Can be informed of another FakeTransportChannel via
41// SetDestination.
42// TODO(hbos): Move implementation to .cc file, this and other classes in file.
43class FakeTransportChannel : public TransportChannelImpl,
44 public rtc::MessageHandler {
45 public:
46 explicit FakeTransportChannel(Transport* transport,
47 const std::string& name,
48 int component)
49 : TransportChannelImpl(name, component),
50 transport_(transport),
51 dtls_fingerprint_("", nullptr, 0) {}
52 ~FakeTransportChannel() { Reset(); }
53
Peter Boström0c4e06b2015-10-07 12:23:21 +020054 uint64_t IceTiebreaker() const { return tiebreaker_; }
deadbeefcbecd352015-09-23 11:50:27 -070055 IceMode remote_ice_mode() const { return remote_ice_mode_; }
56 const std::string& ice_ufrag() const { return ice_ufrag_; }
57 const std::string& ice_pwd() const { return ice_pwd_; }
58 const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; }
59 const std::string& remote_ice_pwd() const { return remote_ice_pwd_; }
60 const rtc::SSLFingerprint& dtls_fingerprint() const {
61 return dtls_fingerprint_;
62 }
63
64 // If async, will send packets by "Post"-ing to message queue instead of
65 // synchronously "Send"-ing.
66 void SetAsync(bool async) { async_ = async; }
67
68 Transport* GetTransport() override { return transport_; }
69
70 TransportChannelState GetState() const override {
71 if (connection_count_ == 0) {
72 return had_connection_ ? TransportChannelState::STATE_FAILED
73 : TransportChannelState::STATE_INIT;
74 }
75
76 if (connection_count_ == 1) {
77 return TransportChannelState::STATE_COMPLETED;
78 }
79
80 return TransportChannelState::STATE_CONNECTING;
81 }
82
83 void SetIceRole(IceRole role) override { role_ = role; }
84 IceRole GetIceRole() const override { return role_; }
Peter Boström0c4e06b2015-10-07 12:23:21 +020085 void SetIceTiebreaker(uint64_t tiebreaker) override {
deadbeefcbecd352015-09-23 11:50:27 -070086 tiebreaker_ = tiebreaker;
87 }
88 void SetIceCredentials(const std::string& ice_ufrag,
89 const std::string& ice_pwd) override {
90 ice_ufrag_ = ice_ufrag;
91 ice_pwd_ = ice_pwd;
92 }
93 void SetRemoteIceCredentials(const std::string& ice_ufrag,
94 const std::string& ice_pwd) override {
95 remote_ice_ufrag_ = ice_ufrag;
96 remote_ice_pwd_ = ice_pwd;
97 }
98
99 void SetRemoteIceMode(IceMode mode) override { remote_ice_mode_ = mode; }
100 bool SetRemoteFingerprint(const std::string& alg,
Peter Boström0c4e06b2015-10-07 12:23:21 +0200101 const uint8_t* digest,
deadbeefcbecd352015-09-23 11:50:27 -0700102 size_t digest_len) override {
103 dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len);
104 return true;
105 }
106 bool SetSslRole(rtc::SSLRole role) override {
107 ssl_role_ = role;
108 return true;
109 }
110 bool GetSslRole(rtc::SSLRole* role) const override {
111 *role = ssl_role_;
112 return true;
113 }
114
115 void Connect() override {
116 if (state_ == STATE_INIT) {
117 state_ = STATE_CONNECTING;
118 }
119 }
120
121 void MaybeStartGathering() override {
122 if (gathering_state_ == kIceGatheringNew) {
123 gathering_state_ = kIceGatheringGathering;
124 SignalGatheringState(this);
125 }
126 }
127
128 IceGatheringState gathering_state() const override {
129 return gathering_state_;
130 }
131
132 void Reset() {
133 if (state_ != STATE_INIT) {
134 state_ = STATE_INIT;
135 if (dest_) {
136 dest_->state_ = STATE_INIT;
137 dest_->dest_ = nullptr;
138 dest_ = nullptr;
139 }
140 }
141 }
142
143 void SetWritable(bool writable) { set_writable(writable); }
144
145 void SetDestination(FakeTransportChannel* dest) {
146 if (state_ == STATE_CONNECTING && dest) {
147 // This simulates the delivery of candidates.
148 dest_ = dest;
149 dest_->dest_ = this;
150 if (local_cert_ && dest_->local_cert_) {
151 do_dtls_ = true;
152 dest_->do_dtls_ = true;
153 NegotiateSrtpCiphers();
154 }
155 state_ = STATE_CONNECTED;
156 dest_->state_ = STATE_CONNECTED;
157 set_writable(true);
158 dest_->set_writable(true);
159 } else if (state_ == STATE_CONNECTED && !dest) {
160 // Simulates loss of connectivity, by asymmetrically forgetting dest_.
161 dest_ = nullptr;
162 state_ = STATE_CONNECTING;
163 set_writable(false);
164 }
165 }
166
167 void SetConnectionCount(size_t connection_count) {
168 size_t old_connection_count = connection_count_;
169 connection_count_ = connection_count;
170 if (connection_count)
171 had_connection_ = true;
172 if (connection_count_ < old_connection_count)
173 SignalConnectionRemoved(this);
174 }
175
176 void SetCandidatesGatheringComplete() {
177 if (gathering_state_ != kIceGatheringComplete) {
178 gathering_state_ = kIceGatheringComplete;
179 SignalGatheringState(this);
180 }
181 }
182
183 void SetReceiving(bool receiving) { set_receiving(receiving); }
184
honghaiz1f429e32015-09-28 07:57:34 -0700185 void SetIceConfig(const IceConfig& config) override {
186 receiving_timeout_ = config.receiving_timeout_ms;
187 gather_continually_ = config.gather_continually;
deadbeefcbecd352015-09-23 11:50:27 -0700188 }
189
190 int receiving_timeout() const { return receiving_timeout_; }
honghaiz1f429e32015-09-28 07:57:34 -0700191 bool gather_continually() const { return gather_continually_; }
deadbeefcbecd352015-09-23 11:50:27 -0700192
193 int SendPacket(const char* data,
194 size_t len,
195 const rtc::PacketOptions& options,
196 int flags) override {
197 if (state_ != STATE_CONNECTED) {
198 return -1;
199 }
200
201 if (flags != PF_SRTP_BYPASS && flags != 0) {
202 return -1;
203 }
204
205 PacketMessageData* packet = new PacketMessageData(data, len);
206 if (async_) {
207 rtc::Thread::Current()->Post(this, 0, packet);
208 } else {
209 rtc::Thread::Current()->Send(this, 0, packet);
210 }
211 return static_cast<int>(len);
212 }
213 int SetOption(rtc::Socket::Option opt, int value) override { return true; }
214 bool GetOption(rtc::Socket::Option opt, int* value) override { return true; }
215 int GetError() override { return 0; }
216
217 void AddRemoteCandidate(const Candidate& candidate) override {
218 remote_candidates_.push_back(candidate);
219 }
220 const Candidates& remote_candidates() const { return remote_candidates_; }
221
222 void OnMessage(rtc::Message* msg) override {
223 PacketMessageData* data = static_cast<PacketMessageData*>(msg->pdata);
224 dest_->SignalReadPacket(dest_, data->packet.data<char>(),
225 data->packet.size(), rtc::CreatePacketTime(0), 0);
226 delete data;
227 }
228
229 bool SetLocalCertificate(
230 const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
231 local_cert_ = certificate;
232 return true;
233 }
234
235 void SetRemoteSSLCertificate(rtc::FakeSSLCertificate* cert) {
236 remote_cert_ = cert;
237 }
238
239 bool IsDtlsActive() const override { return do_dtls_; }
240
241 bool SetSrtpCiphers(const std::vector<std::string>& ciphers) override {
242 srtp_ciphers_ = ciphers;
243 return true;
244 }
245
Guo-wei Shieh456696a2015-09-30 21:48:54 -0700246 bool GetSrtpCryptoSuite(std::string* cipher) override {
deadbeefcbecd352015-09-23 11:50:27 -0700247 if (!chosen_srtp_cipher_.empty()) {
248 *cipher = chosen_srtp_cipher_;
249 return true;
250 }
251 return false;
252 }
253
Guo-wei Shieh6caafbe2015-10-05 12:43:27 -0700254 bool GetSslCipherSuite(int* cipher) override { return false; }
deadbeefcbecd352015-09-23 11:50:27 -0700255
256 rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const {
257 return local_cert_;
258 }
259
260 bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override {
261 if (!remote_cert_)
262 return false;
263
264 *cert = remote_cert_->GetReference();
265 return true;
266 }
267
268 bool ExportKeyingMaterial(const std::string& label,
Peter Boström0c4e06b2015-10-07 12:23:21 +0200269 const uint8_t* context,
deadbeefcbecd352015-09-23 11:50:27 -0700270 size_t context_len,
271 bool use_context,
Peter Boström0c4e06b2015-10-07 12:23:21 +0200272 uint8_t* result,
deadbeefcbecd352015-09-23 11:50:27 -0700273 size_t result_len) override {
274 if (!chosen_srtp_cipher_.empty()) {
275 memset(result, 0xff, result_len);
276 return true;
277 }
278
279 return false;
280 }
281
282 void NegotiateSrtpCiphers() {
283 for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin();
284 it1 != srtp_ciphers_.end(); ++it1) {
285 for (std::vector<std::string>::const_iterator it2 =
286 dest_->srtp_ciphers_.begin();
287 it2 != dest_->srtp_ciphers_.end(); ++it2) {
288 if (*it1 == *it2) {
289 chosen_srtp_cipher_ = *it1;
290 dest_->chosen_srtp_cipher_ = *it2;
291 return;
292 }
293 }
294 }
295 }
296
297 bool GetStats(ConnectionInfos* infos) override {
298 ConnectionInfo info;
299 infos->clear();
300 infos->push_back(info);
301 return true;
302 }
303
304 void set_ssl_max_protocol_version(rtc::SSLProtocolVersion version) {
305 ssl_max_version_ = version;
306 }
307 rtc::SSLProtocolVersion ssl_max_protocol_version() const {
308 return ssl_max_version_;
309 }
310
311 private:
312 enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED };
313 Transport* transport_;
314 FakeTransportChannel* dest_ = nullptr;
315 State state_ = STATE_INIT;
316 bool async_ = false;
317 Candidates remote_candidates_;
318 rtc::scoped_refptr<rtc::RTCCertificate> local_cert_;
319 rtc::FakeSSLCertificate* remote_cert_ = nullptr;
320 bool do_dtls_ = false;
321 std::vector<std::string> srtp_ciphers_;
322 std::string chosen_srtp_cipher_;
323 int receiving_timeout_ = -1;
honghaiz1f429e32015-09-28 07:57:34 -0700324 bool gather_continually_ = false;
deadbeefcbecd352015-09-23 11:50:27 -0700325 IceRole role_ = ICEROLE_UNKNOWN;
Peter Boström0c4e06b2015-10-07 12:23:21 +0200326 uint64_t tiebreaker_ = 0;
deadbeefcbecd352015-09-23 11:50:27 -0700327 std::string ice_ufrag_;
328 std::string ice_pwd_;
329 std::string remote_ice_ufrag_;
330 std::string remote_ice_pwd_;
331 IceMode remote_ice_mode_ = ICEMODE_FULL;
332 rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_10;
333 rtc::SSLFingerprint dtls_fingerprint_;
334 rtc::SSLRole ssl_role_ = rtc::SSL_CLIENT;
335 size_t connection_count_ = 0;
336 IceGatheringState gathering_state_ = kIceGatheringNew;
337 bool had_connection_ = false;
338};
339
340// Fake transport class, which can be passed to anything that needs a Transport.
341// Can be informed of another FakeTransport via SetDestination (low-tech way
342// of doing candidates)
343class FakeTransport : public Transport {
344 public:
345 typedef std::map<int, FakeTransportChannel*> ChannelMap;
346
347 explicit FakeTransport(const std::string& name) : Transport(name, nullptr) {}
348
349 // Note that we only have a constructor with the allocator parameter so it can
350 // be wrapped by a DtlsTransport.
351 FakeTransport(const std::string& name, PortAllocator* allocator)
352 : Transport(name, nullptr) {}
353
354 ~FakeTransport() { DestroyAllChannels(); }
355
356 const ChannelMap& channels() const { return channels_; }
357
358 // If async, will send packets by "Post"-ing to message queue instead of
359 // synchronously "Send"-ing.
360 void SetAsync(bool async) { async_ = async; }
361 void SetDestination(FakeTransport* dest) {
362 dest_ = dest;
363 for (const auto& kv : channels_) {
364 kv.second->SetLocalCertificate(certificate_);
365 SetChannelDestination(kv.first, kv.second);
366 }
367 }
368
369 void SetWritable(bool writable) {
370 for (const auto& kv : channels_) {
371 kv.second->SetWritable(writable);
372 }
373 }
374
375 void SetLocalCertificate(
376 const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override {
377 certificate_ = certificate;
378 }
379 bool GetLocalCertificate(
380 rtc::scoped_refptr<rtc::RTCCertificate>* certificate) override {
381 if (!certificate_)
382 return false;
383
384 *certificate = certificate_;
385 return true;
386 }
387
388 bool GetSslRole(rtc::SSLRole* role) const override {
389 if (channels_.empty()) {
390 return false;
391 }
392 return channels_.begin()->second->GetSslRole(role);
393 }
394
395 bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override {
396 ssl_max_version_ = version;
397 for (const auto& kv : channels_) {
398 kv.second->set_ssl_max_protocol_version(ssl_max_version_);
399 }
400 return true;
401 }
402 rtc::SSLProtocolVersion ssl_max_protocol_version() const {
403 return ssl_max_version_;
404 }
405
406 using Transport::local_description;
407 using Transport::remote_description;
408
409 protected:
410 TransportChannelImpl* CreateTransportChannel(int component) override {
411 if (channels_.find(component) != channels_.end()) {
412 return nullptr;
413 }
414 FakeTransportChannel* channel =
415 new FakeTransportChannel(this, name(), component);
416 channel->set_ssl_max_protocol_version(ssl_max_version_);
417 channel->SetAsync(async_);
418 SetChannelDestination(component, channel);
419 channels_[component] = channel;
420 return channel;
421 }
422
423 void DestroyTransportChannel(TransportChannelImpl* channel) override {
424 channels_.erase(channel->component());
425 delete channel;
426 }
427
428 private:
429 FakeTransportChannel* GetFakeChannel(int component) {
430 auto it = channels_.find(component);
431 return (it != channels_.end()) ? it->second : nullptr;
432 }
433
434 void SetChannelDestination(int component, FakeTransportChannel* channel) {
435 FakeTransportChannel* dest_channel = nullptr;
436 if (dest_) {
437 dest_channel = dest_->GetFakeChannel(component);
438 if (dest_channel) {
439 dest_channel->SetLocalCertificate(dest_->certificate_);
440 }
441 }
442 channel->SetDestination(dest_channel);
443 }
444
445 // Note, this is distinct from the Channel map owned by Transport.
446 // This map just tracks the FakeTransportChannels created by this class.
447 // It's mainly needed so that we can access a FakeTransportChannel directly,
448 // even if wrapped by a DtlsTransportChannelWrapper.
449 ChannelMap channels_;
450 FakeTransport* dest_ = nullptr;
451 bool async_ = false;
452 rtc::scoped_refptr<rtc::RTCCertificate> certificate_;
453 rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_10;
454};
455
456// Fake TransportController class, which can be passed into a BaseChannel object
457// for test purposes. Can be connected to other FakeTransportControllers via
458// Connect().
459//
460// This fake is unusual in that for the most part, it's implemented with the
461// real TransportController code, but with fake TransportChannels underneath.
462class FakeTransportController : public TransportController {
463 public:
464 FakeTransportController()
465 : TransportController(rtc::Thread::Current(),
466 rtc::Thread::Current(),
467 nullptr),
468 fail_create_channel_(false) {}
469
470 explicit FakeTransportController(IceRole role)
471 : TransportController(rtc::Thread::Current(),
472 rtc::Thread::Current(),
473 nullptr),
474 fail_create_channel_(false) {
475 SetIceRole(role);
476 }
477
478 explicit FakeTransportController(rtc::Thread* worker_thread)
479 : TransportController(rtc::Thread::Current(), worker_thread, nullptr),
480 fail_create_channel_(false) {}
481
482 FakeTransportController(rtc::Thread* worker_thread, IceRole role)
483 : TransportController(rtc::Thread::Current(), worker_thread, nullptr),
484 fail_create_channel_(false) {
485 SetIceRole(role);
486 }
487
488 FakeTransport* GetTransport_w(const std::string& transport_name) {
489 return static_cast<FakeTransport*>(
490 TransportController::GetTransport_w(transport_name));
491 }
492
493 void Connect(FakeTransportController* dest) {
494 worker_thread()->Invoke<void>(
495 rtc::Bind(&FakeTransportController::Connect_w, this, dest));
496 }
497
498 TransportChannel* CreateTransportChannel_w(const std::string& transport_name,
499 int component) override {
500 if (fail_create_channel_) {
501 return nullptr;
502 }
503 return TransportController::CreateTransportChannel_w(transport_name,
504 component);
505 }
506
507 void set_fail_channel_creation(bool fail_channel_creation) {
508 fail_create_channel_ = fail_channel_creation;
509 }
510
511 protected:
512 Transport* CreateTransport_w(const std::string& transport_name) override {
513 return new FakeTransport(transport_name);
514 }
515
516 void Connect_w(FakeTransportController* dest) {
517 // Simulate the exchange of candidates.
518 ConnectChannels_w();
519 dest->ConnectChannels_w();
520 for (auto& kv : transports()) {
521 FakeTransport* transport = static_cast<FakeTransport*>(kv.second);
522 transport->SetDestination(dest->GetTransport_w(kv.first));
523 }
524 }
525
526 void ConnectChannels_w() {
527 for (auto& kv : transports()) {
528 FakeTransport* transport = static_cast<FakeTransport*>(kv.second);
529 transport->ConnectChannels();
530 transport->MaybeStartGathering();
531 }
532 }
533
534 private:
535 bool fail_create_channel_;
536};
537
538} // namespace cricket
539
540#endif // WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_