blob: 4022261863fb6f7c0d2bd2f463ff61f460c56f24 [file] [log] [blame]
Jason Jeremy Iman63fd8152021-02-01 05:28:04 +09001// Copyright 2021 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "dns-proxy/resolver.h"
6
7#include <utility>
8#include <vector>
9
10#include <base/test/task_environment.h>
11#include <base/time/time.h>
12#include <gmock/gmock.h>
13#include <gtest/gtest.h>
14
15#include "dns-proxy/ares_client.h"
16#include "dns-proxy/doh_curl_client.h"
17
18using testing::_;
19using testing::Return;
20
21namespace dns_proxy {
22namespace {
23const std::vector<std::string> kTestNameServers{"8.8.8.8"};
24const std::vector<std::string> kTestDoHProviders{
25 "https://dns.google/dns-query"};
26constexpr base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(3);
27constexpr int32_t kMaxNumRetries = 1;
28
29class MockDoHCurlClient : public DoHCurlClient {
30 public:
31 MockDoHCurlClient() : DoHCurlClient(kTimeout, kDefaultMaxConcurrentQueries) {}
32 ~MockDoHCurlClient() = default;
33
34 MOCK_METHOD4(
35 Resolve,
36 bool(const char* msg, int len, const QueryCallback& callback, void* ctx));
37
38 MOCK_METHOD1(SetNameServers,
39 void(const std::vector<std::string>& name_servers));
40 MOCK_METHOD1(SetDoHProviders,
41 void(const std::vector<std::string>& doh_providers));
42};
43
44class MockAresClient : public AresClient {
45 public:
46 MockAresClient()
47 : AresClient(kTimeout, kMaxNumRetries, kDefaultMaxConcurrentQueries) {}
48 ~MockAresClient() = default;
49
50 MOCK_METHOD4(Resolve,
51 bool(const unsigned char* msg,
52 size_t len,
53 const QueryCallback& callback,
54 void* ctx));
55
56 MOCK_METHOD1(SetNameServers,
57 void(const std::vector<std::string>& name_servers));
58};
59
60} // namespace
61
62class ResolverTest : public testing::Test {
63 protected:
64 void SetUp() override {
65 std::unique_ptr<MockAresClient> scoped_ares_client(new MockAresClient());
66 std::unique_ptr<MockDoHCurlClient> scoped_curl_client(
67 new MockDoHCurlClient());
68 ares_client_ = scoped_ares_client.get();
69 curl_client_ = scoped_curl_client.get();
70 resolver_ = std::make_unique<Resolver>(std::move(scoped_ares_client),
71 std::move(scoped_curl_client));
72 }
73
74 base::test::TaskEnvironment task_environment_;
75
76 MockAresClient* ares_client_;
77 MockDoHCurlClient* curl_client_;
78 std::unique_ptr<Resolver> resolver_;
79};
80
81TEST_F(ResolverTest, SetNameServers) {
82 EXPECT_CALL(*ares_client_, SetNameServers(kTestNameServers)).Times(1);
83 EXPECT_CALL(*curl_client_, SetNameServers(kTestNameServers)).Times(1);
84 EXPECT_CALL(*curl_client_, SetDoHProviders(kTestDoHProviders)).Times(1);
85
86 resolver_->SetNameServers(kTestNameServers);
87 resolver_->SetDoHProviders(kTestDoHProviders);
88}
89
90TEST_F(ResolverTest, Resolve_DNSDoHServers) {
91 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);
92 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).WillOnce(Return(true));
93
94 resolver_->SetNameServers(kTestNameServers);
95 resolver_->SetDoHProviders(kTestDoHProviders);
96
97 Resolver::SocketFd sock_fd(SOCK_STREAM, 0);
98 resolver_->Resolve(&sock_fd);
99}
100
101TEST_F(ResolverTest, Resolve_DNSServers) {
102 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(1);
103 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
104
105 resolver_->SetNameServers(kTestNameServers);
106
107 Resolver::SocketFd sock_fd(SOCK_STREAM, 0);
108 resolver_->Resolve(&sock_fd);
109}
110
111TEST_F(ResolverTest, Resolve_DNSDoHServersFallback) {
112 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(1);
113 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
114
115 resolver_->SetNameServers(kTestNameServers);
116 resolver_->SetDoHProviders(kTestDoHProviders);
117
118 Resolver::SocketFd sock_fd(SOCK_STREAM, 0);
119 resolver_->Resolve(&sock_fd, true);
120}
121
122TEST_F(ResolverTest, CurlResult_CURLFail) {
123 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(1);
124 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
125
126 resolver_->SetNameServers(kTestNameServers);
127 resolver_->SetDoHProviders(kTestDoHProviders);
128
129 Resolver::SocketFd sock_fd(SOCK_STREAM, 0);
130 DoHCurlClient::CurlResult res(CURLE_COULDNT_CONNECT, 0 /* http_code */,
131 0 /* timeout */);
132 resolver_->HandleCurlResult(static_cast<void*>(&sock_fd), res, nullptr, 0);
133 task_environment_.RunUntilIdle();
134}
135
136TEST_F(ResolverTest, CurlResult_HTTPError) {
137 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(1);
138 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
139
140 resolver_->SetNameServers(kTestNameServers);
141 resolver_->SetDoHProviders(kTestDoHProviders);
142
143 Resolver::SocketFd sock_fd(SOCK_STREAM, 0);
144 DoHCurlClient::CurlResult res(CURLE_OK, 403 /* http_code */, 0 /* timeout */);
145 resolver_->HandleCurlResult(static_cast<void*>(&sock_fd), res, nullptr, 0);
146 task_environment_.RunUntilIdle();
147}
148
149TEST_F(ResolverTest, CurlResult_SuccessNoRetry) {
150 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);
151 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
152
153 resolver_->SetNameServers(kTestNameServers);
154 resolver_->SetDoHProviders(kTestDoHProviders);
155
156 Resolver::SocketFd* sock_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
157 DoHCurlClient::CurlResult res(CURLE_OK, 200 /* http_code */, 0 /* timeout */);
158 resolver_->HandleCurlResult(static_cast<void*>(sock_fd), res, nullptr, 0);
159 task_environment_.RunUntilIdle();
160}
161
162TEST_F(ResolverTest, CurlResult_FailNoRetry) {
163 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);
164 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
165
166 resolver_->SetNameServers(kTestNameServers);
167 resolver_->SetDoHProviders(kTestDoHProviders, true /* always_on */);
168
169 Resolver::SocketFd* sock_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
170 DoHCurlClient::CurlResult res1(CURLE_OUT_OF_MEMORY, 200 /* http_code */,
171 0 /* timeout */);
172 resolver_->HandleCurlResult(static_cast<void*>(sock_fd), res1, nullptr, 0);
173 task_environment_.RunUntilIdle();
174
175 // |sock_fd| should be freed by now.
176 sock_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
177 DoHCurlClient::CurlResult res2(CURLE_OK, 403 /* http_code */,
178 0 /* timeout */);
179 resolver_->HandleCurlResult(static_cast<void*>(sock_fd), res2, nullptr, 0);
180 task_environment_.RunUntilIdle();
181}
182
183TEST_F(ResolverTest, CurlResult_FailTooManyRetries) {
184 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);
185 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
186
187 resolver_->SetNameServers(kTestNameServers);
188 resolver_->SetDoHProviders(kTestDoHProviders);
189
190 Resolver::SocketFd* sock_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
191 sock_fd->num_retries = INT_MAX;
192 DoHCurlClient::CurlResult res(CURLE_OK, 429 /* http_code */, 0 /* timeout */);
193 resolver_->HandleCurlResult(static_cast<void*>(sock_fd), res, nullptr, 0);
194 task_environment_.RunUntilIdle();
195}
196
197TEST_F(ResolverTest, HandleAresResult_Success) {
198 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);
199 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
200
201 resolver_->SetNameServers(kTestNameServers);
202
203 Resolver::SocketFd* sock_fd = new Resolver::SocketFd(SOCK_DGRAM, 0);
204 resolver_->HandleAresResult(static_cast<void*>(sock_fd), ARES_SUCCESS,
205 nullptr, 0);
206}
207
208TEST_F(ResolverTest, HandleAresResult_Fail) {
209 EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);
210 EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
211
212 resolver_->SetNameServers(kTestNameServers);
213
214 Resolver::SocketFd* sock_fd = new Resolver::SocketFd(SOCK_DGRAM, 0);
215 resolver_->HandleAresResult(static_cast<void*>(sock_fd), ARES_SUCCESS,
216 nullptr, 0);
217}
218} // namespace dns_proxy