alanlxl | 9d26c1c | 2020-08-21 13:42:36 +1000 | [diff] [blame] | 1 | // Copyright 2021 The Chromium OS Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style license that can be |
| 3 | // found in the LICENSE file. |
| 4 | |
| 5 | #include "federated/example_database.h" |
| 6 | |
| 7 | #include <cinttypes> |
| 8 | #include <string> |
| 9 | #include <unordered_set> |
| 10 | |
| 11 | #include <base/files/file_path.h> |
| 12 | #include <base/logging.h> |
| 13 | #include <base/optional.h> |
| 14 | #include <base/strings/stringprintf.h> |
| 15 | #include <base/strings/string_number_conversions.h> |
| 16 | #include <base/strings/string_util.h> |
| 17 | #include <bits/stdint-intn.h> |
| 18 | #include <sqlite3.h> |
| 19 | |
| 20 | #include "federated/utils.h" |
| 21 | |
| 22 | namespace federated { |
| 23 | |
| 24 | namespace { |
| 25 | |
| 26 | // Used in CheckIntegrity to extract state code and result string from sql exec. |
| 27 | int IntegrityCheckCallback(void* data, int count, char** row, char** names) { |
| 28 | CHECK(data); |
| 29 | CHECK(row); |
| 30 | auto* integrity_result = static_cast<std::string*>(data); |
| 31 | if (!row[0]) { |
| 32 | LOG(ERROR) << "Integrity check returned null"; |
| 33 | return SQLITE_ERROR; |
| 34 | } |
| 35 | integrity_result->assign(row[0]); |
| 36 | return SQLITE_OK; |
| 37 | } |
| 38 | |
| 39 | // Used in ClientTableExists to extract state code and table_count from SQL |
| 40 | // exec. |
| 41 | int ClientTableExistsCallback(void* data, int count, char** row, char** names) { |
| 42 | CHECK(data); |
| 43 | CHECK(row); |
| 44 | auto* table_count = static_cast<int*>(data); |
| 45 | if (!row[0] || !base::StringToInt(row[0], table_count)) { |
| 46 | LOG(ERROR) << "TableExist check returned invalid data"; |
| 47 | return SQLITE_ERROR; |
| 48 | } |
| 49 | return SQLITE_OK; |
| 50 | } |
| 51 | |
| 52 | // Prepare sqlite statement group for the given table. Statements (stmt) |
| 53 | // are compiled sql that can bind values to its parameters (`?` in the |
| 54 | // sql string). Table name must be assigned in stmt (not configurable), |
| 55 | // so we must prepare stmt group for each client. |
| 56 | bool PrepareStatements(sqlite3* const db, |
| 57 | const std::string& client_name, |
| 58 | ExampleDatabase::StmtGroup* stmt_group) { |
| 59 | std::string sql = base::StringPrintf( |
| 60 | "SELECT id, example FROM %s ORDER BY id LIMIT ?;", client_name.c_str()); |
| 61 | int result = sqlite3_prepare_v2(db, sql.c_str(), -1, |
| 62 | &stmt_group->stmt_for_streaming, nullptr); |
| 63 | if (result != SQLITE_OK) { |
| 64 | LOG(ERROR) |
| 65 | << "Failed to prepare sqlite statement stmt_for_streaming for client " |
| 66 | << client_name << " with error message:" << sqlite3_errmsg(db); |
| 67 | stmt_group->stmt_for_streaming = nullptr; |
| 68 | return false; |
| 69 | } |
| 70 | |
| 71 | sql = base::StringPrintf("INSERT INTO %s (example, timestamp) VALUES (?, ?);", |
| 72 | client_name.c_str()); |
| 73 | result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt_group->stmt_for_insert, |
| 74 | nullptr); |
| 75 | if (result != SQLITE_OK) { |
| 76 | LOG(ERROR) |
| 77 | << "Failed to prepare sqlite statement stmt_for_insert for client " |
| 78 | << client_name << " with error message:" << sqlite3_errmsg(db); |
| 79 | stmt_group->stmt_for_insert = nullptr; |
| 80 | return false; |
| 81 | } |
| 82 | |
| 83 | sql = |
| 84 | base::StringPrintf("DELETE FROM %s WHERE id <= ?;", client_name.c_str()); |
| 85 | result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt_group->stmt_for_delete, |
| 86 | nullptr); |
| 87 | if (result != SQLITE_OK) { |
| 88 | LOG(ERROR) |
| 89 | << "Failed to prepare sqlite statement stmt_for_delete for client " |
| 90 | << client_name << " with error message:" << sqlite3_errmsg(db); |
| 91 | stmt_group->stmt_for_delete = nullptr; |
| 92 | return false; |
| 93 | } |
| 94 | |
| 95 | sql = base::StringPrintf("SELECT COUNT(*) FROM %s;", client_name.c_str()); |
| 96 | result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt_group->stmt_for_check, |
| 97 | nullptr); |
| 98 | if (result != SQLITE_OK) { |
| 99 | LOG(ERROR) |
| 100 | << "Failed to prepare sqlite statement stmt_for_check for client " |
| 101 | << client_name << " with error message:" << sqlite3_errmsg(db); |
| 102 | stmt_group->stmt_for_check = nullptr; |
| 103 | return false; |
| 104 | } |
| 105 | |
| 106 | return true; |
| 107 | } |
| 108 | |
| 109 | } // namespace |
| 110 | |
| 111 | void ExampleDatabase::StmtGroup::Finalize() { |
| 112 | // Per https://www.sqlite.org/c3ref/finalize.html, it's harmless to finalize a |
| 113 | // nullptr. |
| 114 | sqlite3_finalize(stmt_for_streaming); |
| 115 | sqlite3_finalize(stmt_for_insert); |
| 116 | sqlite3_finalize(stmt_for_delete); |
| 117 | sqlite3_finalize(stmt_for_check); |
| 118 | } |
| 119 | |
| 120 | ExampleDatabase::ExampleDatabase(const base::FilePath& db_path, |
| 121 | const std::unordered_set<std::string>& clients) |
| 122 | : db_path_(db_path), db_(nullptr, nullptr), clients_(clients) { |
| 123 | for (const auto& client : clients_) { |
| 124 | DCHECK(!client.empty()) << "Client name cannot be empty"; |
| 125 | stmts_.emplace(client, StmtGroup()); |
| 126 | } |
| 127 | } |
| 128 | |
| 129 | ExampleDatabase::~ExampleDatabase() { |
| 130 | Close(); |
| 131 | } |
| 132 | |
| 133 | bool ExampleDatabase::Init() { |
| 134 | sqlite3* db_ptr; |
| 135 | int result = sqlite3_open(db_path_.MaybeAsASCII().c_str(), &db_ptr); |
| 136 | db_ = std::unique_ptr<sqlite3, decltype(&sqlite3_close)>(db_ptr, |
| 137 | &sqlite3_close); |
| 138 | if (result != SQLITE_OK) { |
| 139 | LOG(ERROR) << "Failed to connect to database: " << result; |
| 140 | db_ = nullptr; |
| 141 | return false; |
| 142 | } |
| 143 | |
| 144 | for (const auto& client : clients_) { |
| 145 | if ((!ClientTableExists(client) && !CreateClientTable(client)) || |
| 146 | !PrepareStatements(db_.get(), client, &stmts_[client])) { |
| 147 | LOG(ERROR) << "Failed to prepare table for client " << client; |
| 148 | Close(); |
| 149 | |
| 150 | return false; |
| 151 | } |
| 152 | } |
| 153 | |
| 154 | return true; |
| 155 | } |
| 156 | |
| 157 | bool ExampleDatabase::IsOpen() const { |
| 158 | return db_.get() != nullptr; |
| 159 | } |
| 160 | |
| 161 | bool ExampleDatabase::Close() { |
| 162 | if (!db_) |
| 163 | return true; |
| 164 | |
| 165 | for (const auto& client : clients_) { |
| 166 | stmts_[client].Finalize(); |
| 167 | } |
| 168 | |
| 169 | // If the database is successfully closed, db_ pointer must be released. |
| 170 | // Otherwise sqlite3_close will be called again on already released db_ |
| 171 | // pointer by the destructor, which will result in undefined behavior. |
| 172 | int result = sqlite3_close(db_.get()); |
| 173 | if (result != SQLITE_OK) { |
| 174 | // This should never happen |
| 175 | LOG(ERROR) << "sqlite3_close returns error code: " << result; |
| 176 | return false; |
| 177 | } |
| 178 | |
| 179 | db_.release(); |
| 180 | return true; |
| 181 | } |
| 182 | |
| 183 | bool ExampleDatabase::CheckIntegrity() const { |
| 184 | // Integrity_check(N) returns a single row and a single column with string |
| 185 | // "ok" if there is no error. Otherwise a maximum of N rows are returned |
| 186 | // with each row representing a single error. |
| 187 | std::string integrity_result; |
| 188 | ExecResult result = ExecSql("PRAGMA integrity_check(1)", |
| 189 | IntegrityCheckCallback, &integrity_result); |
| 190 | if (result.code != SQLITE_OK) { |
| 191 | LOG(ERROR) << "Failed to check integrity: (" << result.code << ") " |
| 192 | << result.error_msg; |
| 193 | return false; |
| 194 | } |
| 195 | |
| 196 | return integrity_result == "ok"; |
| 197 | } |
| 198 | |
| 199 | bool ExampleDatabase::InsertExample(const ExampleRecord& example_record) { |
| 200 | // The table for example_record.client_name must exist. |
| 201 | const auto& client_name = example_record.client_name; |
| 202 | if (clients_.find(client_name) == clients_.end()) { |
| 203 | LOG(ERROR) << "Unregistered client_name '" << client_name << "'."; |
| 204 | return false; |
| 205 | } |
| 206 | |
| 207 | auto* stmt = stmts_[client_name].stmt_for_insert; |
| 208 | DCHECK(stmt); |
| 209 | |
| 210 | sqlite3_clear_bindings(stmt); |
| 211 | if (sqlite3_bind_blob(stmt, 1, example_record.serialized_example.c_str(), |
| 212 | example_record.serialized_example.length(), |
| 213 | nullptr) == SQLITE_OK && |
| 214 | sqlite3_bind_int64(stmt, 2, example_record.timestamp.ToJavaTime()) == |
| 215 | SQLITE_OK && |
| 216 | sqlite3_step(stmt) == SQLITE_DONE) { |
| 217 | sqlite3_reset(stmt); |
| 218 | return true; |
| 219 | } |
| 220 | |
| 221 | LOG(ERROR) << "Failed to insert an example to table " |
| 222 | << example_record.client_name; |
| 223 | sqlite3_reset(stmt); |
| 224 | |
| 225 | return false; |
| 226 | } |
| 227 | |
| 228 | // Streaming examples with sqlite3_step. |
| 229 | bool ExampleDatabase::PrepareStreamingForClient(const std::string& client_name, |
| 230 | const int32_t limit) { |
| 231 | if (streaming_open_) { |
| 232 | LOG(ERROR) << "The previous streaming for client " |
| 233 | << current_streaming_client_ |
| 234 | << "is still open, call CloseStreaming() first."; |
| 235 | return false; |
| 236 | } |
| 237 | |
| 238 | if (clients_.find(client_name) == clients_.end()) { |
| 239 | LOG(ERROR) << "Unregistered client_name '" << client_name << "'."; |
| 240 | return false; |
| 241 | } |
| 242 | |
| 243 | int32_t example_count = ExampleCountOfClientTable(client_name); |
| 244 | if (example_count < kMinExampleCount) { |
| 245 | DVLOG(1) << "Client '" << client_name << "' example_count " << example_count |
| 246 | << " doesn't meet the minimum requirement " << kMinExampleCount; |
| 247 | return false; |
| 248 | } |
| 249 | |
| 250 | auto* stmt = stmts_[client_name].stmt_for_streaming; |
| 251 | DCHECK(stmt); |
| 252 | |
| 253 | if (sqlite3_stmt_busy(stmt)) { |
| 254 | LOG(WARNING) << "An unexpected streaming already exists with sql='" |
| 255 | << sqlite3_expanded_sql(stmt) << "', cancelling it now."; |
| 256 | } |
| 257 | // Resets the prepared statement anyway. |
| 258 | sqlite3_reset(stmt); |
| 259 | |
| 260 | sqlite3_clear_bindings(stmt); |
| 261 | if (sqlite3_bind_int(stmt, 1, limit) != SQLITE_OK) { |
| 262 | LOG(ERROR) << "Failed to bind limit to stmt_for_streaming of client " |
| 263 | << client_name; |
| 264 | sqlite3_reset(stmt); |
| 265 | return false; |
| 266 | } |
| 267 | |
| 268 | streaming_open_ = true; |
| 269 | end_of_streaming_ = false; |
| 270 | current_streaming_client_ = client_name; |
| 271 | return true; |
| 272 | } |
| 273 | |
| 274 | base::Optional<ExampleRecord> ExampleDatabase::GetNextStreamedRecord() { |
| 275 | if (!streaming_open_) { |
| 276 | LOG(ERROR) << "No open streaming, call PrepareStreamingForClient first"; |
| 277 | return base::nullopt; |
| 278 | } |
| 279 | |
| 280 | if (clients_.find(current_streaming_client_) == clients_.end()) { |
| 281 | LOG(ERROR) << "Unregistered client_name '" << current_streaming_client_ |
| 282 | << "'."; |
| 283 | return base::nullopt; |
| 284 | } |
| 285 | |
| 286 | if (end_of_streaming_) { |
| 287 | LOG(ERROR) << "The streaming already hit SQLITE_DONE but not closed " |
| 288 | "properly, please call CloseStreaming() first."; |
| 289 | return base::nullopt; |
| 290 | } |
| 291 | |
| 292 | auto* stmt = stmts_[current_streaming_client_].stmt_for_streaming; |
| 293 | DCHECK(stmt); |
| 294 | |
| 295 | int code = sqlite3_step(stmt); |
| 296 | if (code == SQLITE_DONE) { |
| 297 | end_of_streaming_ = true; |
| 298 | return base::nullopt; |
| 299 | } |
| 300 | if (code != SQLITE_ROW) { |
| 301 | LOG(ERROR) << "Error when executing sqlite3_step."; |
| 302 | return base::nullopt; |
| 303 | } |
| 304 | |
| 305 | int64_t id = sqlite3_column_int64(stmt, 0); |
| 306 | const unsigned char* example_buffer = |
| 307 | reinterpret_cast<const unsigned char*>(sqlite3_column_blob(stmt, 1)); |
| 308 | const int example_buffer_len = sqlite3_column_bytes(stmt, 1); |
| 309 | if (id <= 0 || !example_buffer || example_buffer_len <= 0) { |
| 310 | LOG(ERROR) << "Failed to extract example from stmt_for_streaming"; |
| 311 | return base::nullopt; |
| 312 | } |
| 313 | ExampleRecord example_record; |
| 314 | example_record.id = id; |
| 315 | example_record.serialized_example = |
| 316 | std::string(example_buffer, example_buffer + example_buffer_len); |
| 317 | return example_record; |
| 318 | } |
| 319 | |
| 320 | void ExampleDatabase::CloseStreaming() { |
| 321 | if (!streaming_open_) { |
| 322 | LOG(ERROR) << "No open streaming to close"; |
| 323 | return; |
| 324 | } |
| 325 | if (clients_.find(current_streaming_client_) == clients_.end()) { |
| 326 | LOG(ERROR) << "Unregistered client_name '" << current_streaming_client_ |
| 327 | << "'."; |
| 328 | return; |
| 329 | } |
| 330 | |
| 331 | auto* stmt = stmts_[current_streaming_client_].stmt_for_streaming; |
| 332 | DCHECK(stmt); |
| 333 | sqlite3_reset(stmt); |
| 334 | current_streaming_client_ = std::string(); |
| 335 | streaming_open_ = false; |
| 336 | end_of_streaming_ = false; |
| 337 | } |
| 338 | |
| 339 | bool ExampleDatabase::DeleteExamplesWithSmallerIdForClient( |
| 340 | const std::string& client_name, const int64_t id) { |
| 341 | if (clients_.find(client_name) == clients_.end()) { |
| 342 | LOG(ERROR) << "Unregistered client_name '" << client_name << "'."; |
| 343 | return false; |
| 344 | } |
| 345 | |
| 346 | auto* stmt = stmts_[client_name].stmt_for_delete; |
| 347 | DCHECK(stmt); |
| 348 | |
| 349 | sqlite3_clear_bindings(stmt); |
| 350 | if (sqlite3_bind_int64(stmt, 1, id) == SQLITE_OK && |
| 351 | sqlite3_step(stmt) == SQLITE_DONE) { |
| 352 | int delete_count = sqlite3_changes(db_.get()); |
| 353 | |
| 354 | sqlite3_reset(stmt); |
| 355 | |
| 356 | if (delete_count <= 0) { |
| 357 | LOG(ERROR) << "Client " << client_name |
| 358 | << " does not have examples with id <= " << id; |
| 359 | return false; |
| 360 | } |
| 361 | return true; |
| 362 | } |
| 363 | LOG(ERROR) << "Error in delete examples from table " << client_name |
| 364 | << " with id <= " << id; |
| 365 | |
| 366 | sqlite3_reset(stmt); |
| 367 | |
| 368 | return false; |
| 369 | } |
| 370 | |
| 371 | bool ExampleDatabase::ClientTableExists(const std::string& client_name) const { |
| 372 | int table_count = 0; |
| 373 | const std::string sql = base::StringPrintf( |
| 374 | "SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = " |
| 375 | "'%s';", |
| 376 | client_name.c_str()); |
| 377 | ExecResult result = ExecSql(sql, ClientTableExistsCallback, &table_count); |
| 378 | |
| 379 | if (result.code != SQLITE_OK) { |
| 380 | LOG(ERROR) << "Failed to call ClientTableExists for client " << client_name |
| 381 | << " with ExecResult: (" << result.code << ") " |
| 382 | << result.error_msg; |
| 383 | return false; |
| 384 | } |
| 385 | |
| 386 | if (table_count <= 0) |
| 387 | return false; |
| 388 | |
| 389 | DCHECK(table_count == 1) << "There should be only one table with name " |
| 390 | << client_name; |
| 391 | |
| 392 | return true; |
| 393 | } |
| 394 | |
| 395 | bool ExampleDatabase::CreateClientTable(const std::string& client_name) const { |
| 396 | const std::string sql = base::StringPrintf( |
| 397 | "CREATE TABLE %s (" |
| 398 | " id INTEGER PRIMARY KEY AUTOINCREMENT" |
| 399 | " NOT NULL," |
| 400 | " example BLOB NOT NULL," |
| 401 | " timestamp INTEGER NOT NULL" |
| 402 | ")", |
| 403 | client_name.c_str()); |
| 404 | ExecResult result = ExecSql(sql); |
| 405 | if (result.code != SQLITE_OK) { |
| 406 | LOG(ERROR) << "Failed to create table " << client_name << ": (" |
| 407 | << result.code << ") " << result.error_msg; |
| 408 | return false; |
| 409 | } |
| 410 | return true; |
| 411 | } |
| 412 | |
| 413 | int32_t ExampleDatabase::ExampleCountOfClientTable( |
| 414 | const std::string& client_name) { |
| 415 | auto* stmt = stmts_[client_name].stmt_for_check; |
| 416 | DCHECK(stmt); |
| 417 | sqlite3_reset(stmt); |
| 418 | |
| 419 | int code = sqlite3_step(stmt); |
| 420 | if (code != SQLITE_ROW) { |
| 421 | LOG(ERROR) |
| 422 | << "Error when executing sqlite3_step in ExampleCountOfClientTable."; |
| 423 | return 0; |
| 424 | } |
| 425 | |
| 426 | int count = sqlite3_column_int(stmt, 0); |
| 427 | return count; |
| 428 | } |
| 429 | |
| 430 | ExampleDatabase::ExecResult ExampleDatabase::ExecSql( |
| 431 | const std::string& sql) const { |
| 432 | return ExecSql(sql, nullptr, nullptr); |
| 433 | } |
| 434 | |
| 435 | ExampleDatabase::ExecResult ExampleDatabase::ExecSql(const std::string& sql, |
| 436 | SqliteCallback callback, |
| 437 | void* data) const { |
| 438 | char* error_msg = nullptr; |
| 439 | int result = sqlite3_exec(db_.get(), sql.c_str(), callback, data, &error_msg); |
| 440 | // According to sqlite3_exec() documentation, error_msg points to memory |
| 441 | // allocated by sqlite3_malloc(), which must be freed by sqlite3_free(). |
| 442 | std::string error_msg_str; |
| 443 | if (error_msg) { |
| 444 | error_msg_str.assign(error_msg); |
| 445 | sqlite3_free(error_msg); |
| 446 | } |
| 447 | return {result, error_msg_str}; |
| 448 | } |
| 449 | |
| 450 | } // namespace federated |