blob: d2f259192a1e5998fe391f43b8c03f783ef8b04f [file] [log] [blame]
alanlxl9d26c1c2020-08-21 13:42:36 +10001// Copyright 2021 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "federated/example_database.h"
6
7#include <cinttypes>
8#include <string>
9#include <unordered_set>
10
11#include <base/files/file_path.h>
12#include <base/logging.h>
13#include <base/optional.h>
14#include <base/strings/stringprintf.h>
15#include <base/strings/string_number_conversions.h>
16#include <base/strings/string_util.h>
17#include <bits/stdint-intn.h>
18#include <sqlite3.h>
19
20#include "federated/utils.h"
21
22namespace federated {
23
24namespace {
25
26// Used in CheckIntegrity to extract state code and result string from sql exec.
27int IntegrityCheckCallback(void* data, int count, char** row, char** names) {
28 CHECK(data);
29 CHECK(row);
30 auto* integrity_result = static_cast<std::string*>(data);
31 if (!row[0]) {
32 LOG(ERROR) << "Integrity check returned null";
33 return SQLITE_ERROR;
34 }
35 integrity_result->assign(row[0]);
36 return SQLITE_OK;
37}
38
39// Used in ClientTableExists to extract state code and table_count from SQL
40// exec.
41int ClientTableExistsCallback(void* data, int count, char** row, char** names) {
42 CHECK(data);
43 CHECK(row);
44 auto* table_count = static_cast<int*>(data);
45 if (!row[0] || !base::StringToInt(row[0], table_count)) {
46 LOG(ERROR) << "TableExist check returned invalid data";
47 return SQLITE_ERROR;
48 }
49 return SQLITE_OK;
50}
51
52// Prepare sqlite statement group for the given table. Statements (stmt)
53// are compiled sql that can bind values to its parameters (`?` in the
54// sql string). Table name must be assigned in stmt (not configurable),
55// so we must prepare stmt group for each client.
56bool PrepareStatements(sqlite3* const db,
57 const std::string& client_name,
58 ExampleDatabase::StmtGroup* stmt_group) {
59 std::string sql = base::StringPrintf(
60 "SELECT id, example FROM %s ORDER BY id LIMIT ?;", client_name.c_str());
61 int result = sqlite3_prepare_v2(db, sql.c_str(), -1,
62 &stmt_group->stmt_for_streaming, nullptr);
63 if (result != SQLITE_OK) {
64 LOG(ERROR)
65 << "Failed to prepare sqlite statement stmt_for_streaming for client "
66 << client_name << " with error message:" << sqlite3_errmsg(db);
67 stmt_group->stmt_for_streaming = nullptr;
68 return false;
69 }
70
71 sql = base::StringPrintf("INSERT INTO %s (example, timestamp) VALUES (?, ?);",
72 client_name.c_str());
73 result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt_group->stmt_for_insert,
74 nullptr);
75 if (result != SQLITE_OK) {
76 LOG(ERROR)
77 << "Failed to prepare sqlite statement stmt_for_insert for client "
78 << client_name << " with error message:" << sqlite3_errmsg(db);
79 stmt_group->stmt_for_insert = nullptr;
80 return false;
81 }
82
83 sql =
84 base::StringPrintf("DELETE FROM %s WHERE id <= ?;", client_name.c_str());
85 result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt_group->stmt_for_delete,
86 nullptr);
87 if (result != SQLITE_OK) {
88 LOG(ERROR)
89 << "Failed to prepare sqlite statement stmt_for_delete for client "
90 << client_name << " with error message:" << sqlite3_errmsg(db);
91 stmt_group->stmt_for_delete = nullptr;
92 return false;
93 }
94
95 sql = base::StringPrintf("SELECT COUNT(*) FROM %s;", client_name.c_str());
96 result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt_group->stmt_for_check,
97 nullptr);
98 if (result != SQLITE_OK) {
99 LOG(ERROR)
100 << "Failed to prepare sqlite statement stmt_for_check for client "
101 << client_name << " with error message:" << sqlite3_errmsg(db);
102 stmt_group->stmt_for_check = nullptr;
103 return false;
104 }
105
106 return true;
107}
108
109} // namespace
110
111void ExampleDatabase::StmtGroup::Finalize() {
112 // Per https://www.sqlite.org/c3ref/finalize.html, it's harmless to finalize a
113 // nullptr.
114 sqlite3_finalize(stmt_for_streaming);
115 sqlite3_finalize(stmt_for_insert);
116 sqlite3_finalize(stmt_for_delete);
117 sqlite3_finalize(stmt_for_check);
118}
119
120ExampleDatabase::ExampleDatabase(const base::FilePath& db_path,
121 const std::unordered_set<std::string>& clients)
122 : db_path_(db_path), db_(nullptr, nullptr), clients_(clients) {
123 for (const auto& client : clients_) {
124 DCHECK(!client.empty()) << "Client name cannot be empty";
125 stmts_.emplace(client, StmtGroup());
126 }
127}
128
129ExampleDatabase::~ExampleDatabase() {
130 Close();
131}
132
133bool ExampleDatabase::Init() {
134 sqlite3* db_ptr;
135 int result = sqlite3_open(db_path_.MaybeAsASCII().c_str(), &db_ptr);
136 db_ = std::unique_ptr<sqlite3, decltype(&sqlite3_close)>(db_ptr,
137 &sqlite3_close);
138 if (result != SQLITE_OK) {
139 LOG(ERROR) << "Failed to connect to database: " << result;
140 db_ = nullptr;
141 return false;
142 }
143
144 for (const auto& client : clients_) {
145 if ((!ClientTableExists(client) && !CreateClientTable(client)) ||
146 !PrepareStatements(db_.get(), client, &stmts_[client])) {
147 LOG(ERROR) << "Failed to prepare table for client " << client;
148 Close();
149
150 return false;
151 }
152 }
153
154 return true;
155}
156
157bool ExampleDatabase::IsOpen() const {
158 return db_.get() != nullptr;
159}
160
161bool ExampleDatabase::Close() {
162 if (!db_)
163 return true;
164
165 for (const auto& client : clients_) {
166 stmts_[client].Finalize();
167 }
168
169 // If the database is successfully closed, db_ pointer must be released.
170 // Otherwise sqlite3_close will be called again on already released db_
171 // pointer by the destructor, which will result in undefined behavior.
172 int result = sqlite3_close(db_.get());
173 if (result != SQLITE_OK) {
174 // This should never happen
175 LOG(ERROR) << "sqlite3_close returns error code: " << result;
176 return false;
177 }
178
179 db_.release();
180 return true;
181}
182
183bool ExampleDatabase::CheckIntegrity() const {
184 // Integrity_check(N) returns a single row and a single column with string
185 // "ok" if there is no error. Otherwise a maximum of N rows are returned
186 // with each row representing a single error.
187 std::string integrity_result;
188 ExecResult result = ExecSql("PRAGMA integrity_check(1)",
189 IntegrityCheckCallback, &integrity_result);
190 if (result.code != SQLITE_OK) {
191 LOG(ERROR) << "Failed to check integrity: (" << result.code << ") "
192 << result.error_msg;
193 return false;
194 }
195
196 return integrity_result == "ok";
197}
198
199bool ExampleDatabase::InsertExample(const ExampleRecord& example_record) {
200 // The table for example_record.client_name must exist.
201 const auto& client_name = example_record.client_name;
202 if (clients_.find(client_name) == clients_.end()) {
203 LOG(ERROR) << "Unregistered client_name '" << client_name << "'.";
204 return false;
205 }
206
207 auto* stmt = stmts_[client_name].stmt_for_insert;
208 DCHECK(stmt);
209
210 sqlite3_clear_bindings(stmt);
211 if (sqlite3_bind_blob(stmt, 1, example_record.serialized_example.c_str(),
212 example_record.serialized_example.length(),
213 nullptr) == SQLITE_OK &&
214 sqlite3_bind_int64(stmt, 2, example_record.timestamp.ToJavaTime()) ==
215 SQLITE_OK &&
216 sqlite3_step(stmt) == SQLITE_DONE) {
217 sqlite3_reset(stmt);
218 return true;
219 }
220
221 LOG(ERROR) << "Failed to insert an example to table "
222 << example_record.client_name;
223 sqlite3_reset(stmt);
224
225 return false;
226}
227
228// Streaming examples with sqlite3_step.
229bool ExampleDatabase::PrepareStreamingForClient(const std::string& client_name,
230 const int32_t limit) {
231 if (streaming_open_) {
232 LOG(ERROR) << "The previous streaming for client "
233 << current_streaming_client_
234 << "is still open, call CloseStreaming() first.";
235 return false;
236 }
237
238 if (clients_.find(client_name) == clients_.end()) {
239 LOG(ERROR) << "Unregistered client_name '" << client_name << "'.";
240 return false;
241 }
242
243 int32_t example_count = ExampleCountOfClientTable(client_name);
244 if (example_count < kMinExampleCount) {
245 DVLOG(1) << "Client '" << client_name << "' example_count " << example_count
246 << " doesn't meet the minimum requirement " << kMinExampleCount;
247 return false;
248 }
249
250 auto* stmt = stmts_[client_name].stmt_for_streaming;
251 DCHECK(stmt);
252
253 if (sqlite3_stmt_busy(stmt)) {
254 LOG(WARNING) << "An unexpected streaming already exists with sql='"
255 << sqlite3_expanded_sql(stmt) << "', cancelling it now.";
256 }
257 // Resets the prepared statement anyway.
258 sqlite3_reset(stmt);
259
260 sqlite3_clear_bindings(stmt);
261 if (sqlite3_bind_int(stmt, 1, limit) != SQLITE_OK) {
262 LOG(ERROR) << "Failed to bind limit to stmt_for_streaming of client "
263 << client_name;
264 sqlite3_reset(stmt);
265 return false;
266 }
267
268 streaming_open_ = true;
269 end_of_streaming_ = false;
270 current_streaming_client_ = client_name;
271 return true;
272}
273
274base::Optional<ExampleRecord> ExampleDatabase::GetNextStreamedRecord() {
275 if (!streaming_open_) {
276 LOG(ERROR) << "No open streaming, call PrepareStreamingForClient first";
277 return base::nullopt;
278 }
279
280 if (clients_.find(current_streaming_client_) == clients_.end()) {
281 LOG(ERROR) << "Unregistered client_name '" << current_streaming_client_
282 << "'.";
283 return base::nullopt;
284 }
285
286 if (end_of_streaming_) {
287 LOG(ERROR) << "The streaming already hit SQLITE_DONE but not closed "
288 "properly, please call CloseStreaming() first.";
289 return base::nullopt;
290 }
291
292 auto* stmt = stmts_[current_streaming_client_].stmt_for_streaming;
293 DCHECK(stmt);
294
295 int code = sqlite3_step(stmt);
296 if (code == SQLITE_DONE) {
297 end_of_streaming_ = true;
298 return base::nullopt;
299 }
300 if (code != SQLITE_ROW) {
301 LOG(ERROR) << "Error when executing sqlite3_step.";
302 return base::nullopt;
303 }
304
305 int64_t id = sqlite3_column_int64(stmt, 0);
306 const unsigned char* example_buffer =
307 reinterpret_cast<const unsigned char*>(sqlite3_column_blob(stmt, 1));
308 const int example_buffer_len = sqlite3_column_bytes(stmt, 1);
309 if (id <= 0 || !example_buffer || example_buffer_len <= 0) {
310 LOG(ERROR) << "Failed to extract example from stmt_for_streaming";
311 return base::nullopt;
312 }
313 ExampleRecord example_record;
314 example_record.id = id;
315 example_record.serialized_example =
316 std::string(example_buffer, example_buffer + example_buffer_len);
317 return example_record;
318}
319
320void ExampleDatabase::CloseStreaming() {
321 if (!streaming_open_) {
322 LOG(ERROR) << "No open streaming to close";
323 return;
324 }
325 if (clients_.find(current_streaming_client_) == clients_.end()) {
326 LOG(ERROR) << "Unregistered client_name '" << current_streaming_client_
327 << "'.";
328 return;
329 }
330
331 auto* stmt = stmts_[current_streaming_client_].stmt_for_streaming;
332 DCHECK(stmt);
333 sqlite3_reset(stmt);
334 current_streaming_client_ = std::string();
335 streaming_open_ = false;
336 end_of_streaming_ = false;
337}
338
339bool ExampleDatabase::DeleteExamplesWithSmallerIdForClient(
340 const std::string& client_name, const int64_t id) {
341 if (clients_.find(client_name) == clients_.end()) {
342 LOG(ERROR) << "Unregistered client_name '" << client_name << "'.";
343 return false;
344 }
345
346 auto* stmt = stmts_[client_name].stmt_for_delete;
347 DCHECK(stmt);
348
349 sqlite3_clear_bindings(stmt);
350 if (sqlite3_bind_int64(stmt, 1, id) == SQLITE_OK &&
351 sqlite3_step(stmt) == SQLITE_DONE) {
352 int delete_count = sqlite3_changes(db_.get());
353
354 sqlite3_reset(stmt);
355
356 if (delete_count <= 0) {
357 LOG(ERROR) << "Client " << client_name
358 << " does not have examples with id <= " << id;
359 return false;
360 }
361 return true;
362 }
363 LOG(ERROR) << "Error in delete examples from table " << client_name
364 << " with id <= " << id;
365
366 sqlite3_reset(stmt);
367
368 return false;
369}
370
371bool ExampleDatabase::ClientTableExists(const std::string& client_name) const {
372 int table_count = 0;
373 const std::string sql = base::StringPrintf(
374 "SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = "
375 "'%s';",
376 client_name.c_str());
377 ExecResult result = ExecSql(sql, ClientTableExistsCallback, &table_count);
378
379 if (result.code != SQLITE_OK) {
380 LOG(ERROR) << "Failed to call ClientTableExists for client " << client_name
381 << " with ExecResult: (" << result.code << ") "
382 << result.error_msg;
383 return false;
384 }
385
386 if (table_count <= 0)
387 return false;
388
389 DCHECK(table_count == 1) << "There should be only one table with name "
390 << client_name;
391
392 return true;
393}
394
395bool ExampleDatabase::CreateClientTable(const std::string& client_name) const {
396 const std::string sql = base::StringPrintf(
397 "CREATE TABLE %s ("
398 " id INTEGER PRIMARY KEY AUTOINCREMENT"
399 " NOT NULL,"
400 " example BLOB NOT NULL,"
401 " timestamp INTEGER NOT NULL"
402 ")",
403 client_name.c_str());
404 ExecResult result = ExecSql(sql);
405 if (result.code != SQLITE_OK) {
406 LOG(ERROR) << "Failed to create table " << client_name << ": ("
407 << result.code << ") " << result.error_msg;
408 return false;
409 }
410 return true;
411}
412
413int32_t ExampleDatabase::ExampleCountOfClientTable(
414 const std::string& client_name) {
415 auto* stmt = stmts_[client_name].stmt_for_check;
416 DCHECK(stmt);
417 sqlite3_reset(stmt);
418
419 int code = sqlite3_step(stmt);
420 if (code != SQLITE_ROW) {
421 LOG(ERROR)
422 << "Error when executing sqlite3_step in ExampleCountOfClientTable.";
423 return 0;
424 }
425
426 int count = sqlite3_column_int(stmt, 0);
427 return count;
428}
429
430ExampleDatabase::ExecResult ExampleDatabase::ExecSql(
431 const std::string& sql) const {
432 return ExecSql(sql, nullptr, nullptr);
433}
434
435ExampleDatabase::ExecResult ExampleDatabase::ExecSql(const std::string& sql,
436 SqliteCallback callback,
437 void* data) const {
438 char* error_msg = nullptr;
439 int result = sqlite3_exec(db_.get(), sql.c_str(), callback, data, &error_msg);
440 // According to sqlite3_exec() documentation, error_msg points to memory
441 // allocated by sqlite3_malloc(), which must be freed by sqlite3_free().
442 std::string error_msg_str;
443 if (error_msg) {
444 error_msg_str.assign(error_msg);
445 sqlite3_free(error_msg);
446 }
447 return {result, error_msg_str};
448}
449
450} // namespace federated