blob: 9248e2b216e6c49c08f419df6a40c95f55078666 [file] [log] [blame]
alanlxl9d26c1c2020-08-21 13:42:36 +10001// Copyright 2021 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "federated/example_database.h"
6
#include <memory>
#include <string>
#include <unordered_set>
9
10#include <base/files/file_path.h>
11#include <base/files/file_util.h>
12#include <base/files/scoped_temp_dir.h>
13#include <base/strings/stringprintf.h>
14#include <base/time/time.h>
15#include <base/optional.h>
16#include <gtest/gtest.h>
17#include <sqlite3.h>
18
19#include "federated/example_database_test_utils.h"
20#include "federated/utils.h"
21
22namespace federated {
namespace {

// Client names handed to every ExampleDatabase constructed by these tests.
const std::unordered_set<std::string> kTestClients{"test_client_1",
                                                   "test_client_2"};

}  // namespace
27
28class ExampleDatabaseTest : public testing::Test {
29 public:
30 ExampleDatabaseTest(const ExampleDatabaseTest&) = delete;
31 ExampleDatabaseTest& operator=(const ExampleDatabaseTest&) = delete;
32
33 ExampleDatabase* get_db() const { return db_; }
34 const base::FilePath& temp_path() const { return temp_dir_.GetPath(); }
35
36 // Prepares a database, table test_client_1 has 100 examples (id from 1
37 // to 100), table test_client_2 is created by db_->Init() and is empty.
38 bool CreateExampleDatabaseAndInitialize() {
39 const base::FilePath db_path =
40 temp_dir_.GetPath().Append(kDatabaseFileName);
41 if (CreateDatabaseForTesting(db_path) != SQLITE_OK) {
42 LOG(ERROR) << "Failed to create database file.";
43 return false;
44 }
45
46 db_ = new ExampleDatabase(db_path, kTestClients);
47 if (!db_->Init() || !db_->IsOpen() || !db_->CheckIntegrity()) {
48 LOG(ERROR) << "Failed to initialize or check integrity of db_.";
49 return false;
50 }
51
52 return true;
53 }
54
55 protected:
56 ExampleDatabaseTest() = default;
57
58 void SetUp() override { ASSERT_TRUE(temp_dir_.CreateUniqueTempDir()); }
59 void TearDown() override {
60 ASSERT_TRUE(temp_dir_.Delete());
61 if (db_)
62 delete db_;
63 }
64
65 private:
66 base::ScopedTempDir temp_dir_;
67 ExampleDatabase* db_ = nullptr;
68};
69
70// This test runs the same steps as CreateExampleDatabaseAndInitialize, but
71// checks step by step.
72TEST_F(ExampleDatabaseTest, CreateDatabase) {
73 // Prepares a database file.
74 base::FilePath db_path = temp_path().Append(kDatabaseFileName);
75 ASSERT_EQ(CreateDatabaseForTesting(db_path), SQLITE_OK);
76 EXPECT_TRUE(base::PathExists(db_path));
77
78 // Creates the db instance.
79 ExampleDatabase db(db_path, kTestClients);
80
81 // Initializes the db and checks integrity.
82 EXPECT_TRUE(db.Init());
83 EXPECT_TRUE(db.IsOpen());
84 EXPECT_TRUE(db.CheckIntegrity());
85
86 // Closes it.
87 EXPECT_TRUE(db.Close());
88}
89
90TEST_F(ExampleDatabaseTest, DatabaseQueryFromNonEmptyTable) {
91 ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
92 ExampleDatabase& db = *get_db();
93
94 std::string client_name = "test_client_1";
95 // Table test_client_1 has 100 records, limit=50 can return 50 examples.
96 int limit = 50;
97 EXPECT_TRUE(db.PrepareStreamingForClient(client_name, limit));
98 int64_t count = 0;
99 while (true) { // id from 1 to 50.
100 auto maybe_record = db.GetNextStreamedRecord();
101 if (maybe_record == base::nullopt)
102 break;
103 EXPECT_EQ(maybe_record.value().id, count + 1);
104 EXPECT_EQ(maybe_record.value().serialized_example,
105 base::StringPrintf("example_%zu", count + 1));
106 count++;
107 }
108 EXPECT_EQ(count, 50);
109 db.CloseStreaming();
110
111 // Limit=150 only returns 100 examples, that's all in the table.
112 limit = 150;
113 EXPECT_TRUE(db.PrepareStreamingForClient(client_name, limit));
114 count = 0;
115 while (true) { // id from 1 to 100.
116 auto maybe_record = db.GetNextStreamedRecord();
117 if (maybe_record == base::nullopt)
118 break;
119 EXPECT_EQ(maybe_record.value().id, count + 1);
120 EXPECT_EQ(maybe_record.value().serialized_example,
121 base::StringPrintf("example_%zu", count + 1));
122 count++;
123 }
124 EXPECT_EQ(count, 100);
125 db.CloseStreaming();
126
127 EXPECT_TRUE(db.Close());
128}
129
130TEST_F(ExampleDatabaseTest, PrepareStreamingFailure) {
131 ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
132 ExampleDatabase& db = *get_db();
133
134 int limit = 100;
135 // Table test_client_2 is empty, returns false on PrepareStreamingForClient
136 std::string client_name = "test_client_2";
137 EXPECT_FALSE(db.PrepareStreamingForClient(client_name, limit));
138
139 // Table test_client_3 doesn't exist, returns false on
140 // PrepareStreamingForClient.
141 client_name = "test_client_3";
142 EXPECT_FALSE(db.PrepareStreamingForClient(client_name, limit));
143
144 EXPECT_TRUE(db.Close());
145}
146
147// Test that example_database can handle some illegal query operations.
148TEST_F(ExampleDatabaseTest, UnexpectedQuery) {
149 ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
150 ExampleDatabase& db = *get_db();
151
152 // Insert an example to table test_client_2.
153 ExampleRecord record = {-1 /*placeholder for id*/, "test_client_2",
154 "manually_inserted_example", base::Time::Now()};
155 EXPECT_TRUE(db.InsertExample(record));
156
157 std::string client_name = "test_client_1";
158
159 // Calls GetNextStreamedRecord without PrepareStreamingForClient, gets
160 // base::nullopt.
161 EXPECT_EQ(db.GetNextStreamedRecord(), base::nullopt);
162
163 // Calls GetNextStreamedRecord after CloseStreaming, gets base::nullopt.
164 EXPECT_TRUE(db.PrepareStreamingForClient(client_name, 10));
165 db.CloseStreaming();
166 EXPECT_EQ(db.GetNextStreamedRecord(), base::nullopt);
167
168 // Calls GetNextStreamedRecord after it already returned base::nullopt.
169 EXPECT_TRUE(db.PrepareStreamingForClient(client_name, 10));
170 int count = 0;
171 while (db.GetNextStreamedRecord() != base::nullopt)
172 count++;
173
174 EXPECT_EQ(count, 10);
175 EXPECT_EQ(db.GetNextStreamedRecord(), base::nullopt);
176 db.CloseStreaming();
177
178 // A subsequent PrepareStreamingForClient call before CloseStreaming() will
179 // fail and have no influnce on the existing streaming.
180 EXPECT_TRUE(db.PrepareStreamingForClient("test_client_1", 10));
181 count = 0;
182 for (size_t i = 0; i < 5; i++) {
183 EXPECT_NE(db.GetNextStreamedRecord(), base::nullopt);
184 count++;
185 }
186
187 EXPECT_FALSE(db.PrepareStreamingForClient("test_client_2", 50));
188
189 while (db.GetNextStreamedRecord() != base::nullopt)
190 count++;
191
192 EXPECT_EQ(count, 10);
193
194 // Subsequent PrepareStreamingForClient fails as long as the previous
195 // streaming is not closed by CloseStreaming, even if it already hit the end.
196 EXPECT_FALSE(db.PrepareStreamingForClient("test_client_2", 50));
197 db.CloseStreaming();
198 EXPECT_TRUE(db.PrepareStreamingForClient("test_client_2", 50));
199}
200
201TEST_F(ExampleDatabaseTest, InsertExample) {
202 ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
203 ExampleDatabase& db = *get_db();
204
205 ExampleRecord record;
206 record.serialized_example = "manually_inserted_example";
207 record.timestamp = base::Time::Now();
208
209 // Before inserting, table test_client_2 is empty, returns false on
210 // PrepareStreamingForClient.
211 std::string client_name = "test_client_2";
212 record.client_name = client_name;
213
214 EXPECT_FALSE(db.PrepareStreamingForClient(client_name, 100));
215
216 EXPECT_TRUE(db.InsertExample(record));
217
218 // After inserting, GetNextStreamedRecord will succeed for 1 time.
219 EXPECT_TRUE(db.PrepareStreamingForClient(client_name, 100));
220 int64_t count = 0;
221 while (db.GetNextStreamedRecord() != base::nullopt)
222 count++;
223
224 EXPECT_EQ(count, 1);
225 db.CloseStreaming();
226
227 // Fails to insert into a non-existing table;
228 client_name = "test_client_3";
229 record.client_name = client_name;
230
231 EXPECT_FALSE(db.InsertExample(record));
232
233 EXPECT_TRUE(db.Close());
234}
235
236TEST_F(ExampleDatabaseTest, DeleteExamples) {
237 ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
238 ExampleDatabase& db = *get_db();
239
240 std::string client_name = "test_client_1";
241 // Delete examples with id <= 30 from table test_client_1;
242 EXPECT_TRUE(db.DeleteExamplesWithSmallerIdForClient(client_name, 30));
243
244 EXPECT_TRUE(db.PrepareStreamingForClient(client_name, 100));
245 int64_t count = 0;
246 while (true) { // id from 31 to 100.
247 auto maybe_record = db.GetNextStreamedRecord();
248 if (maybe_record == base::nullopt)
249 break;
250 EXPECT_EQ(maybe_record.value().id, count + 31);
251 EXPECT_EQ(maybe_record.value().serialized_example,
252 base::StringPrintf("example_%zu", count + 31));
253 count++;
254 }
255 EXPECT_EQ(count, 70);
256 db.CloseStreaming();
257
258 // No examples with id <= 20 now, returns false;
259 EXPECT_FALSE(db.DeleteExamplesWithSmallerIdForClient(client_name, 20));
260
261 client_name = "test_client_2";
262 // Delete examples from an empty table, returns false;
263 EXPECT_FALSE(db.DeleteExamplesWithSmallerIdForClient(client_name, 100));
264
265 client_name = "test_client_3";
266 // Delete examples from a non-existing table, returns false;
267 EXPECT_FALSE(db.DeleteExamplesWithSmallerIdForClient(client_name, 100));
268
269 EXPECT_TRUE(db.Close());
270}
271
272} // namespace federated