// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#ifndef FEDERATED_EXAMPLE_DATABASE_H_
6#define FEDERATED_EXAMPLE_DATABASE_H_
7
8#include <bits/stdint-intn.h>
9#include <stdint.h>
10
11#include <memory>
12#include <string>
13#include <unordered_map>
14#include <unordered_set>
15
16#include <base/files/file_path.h>
17#include <base/optional.h>
18#include <base/time/time.h>
19#include <sqlite3.h>
20
21namespace federated {
22
23// Example objects stored in corresponding `client_name` tables.
24// An example represents a training example of federated computation.
25struct ExampleRecord {
26 int64_t id;
27 std::string client_name;
28 std::string serialized_example;
29 base::Time timestamp;
30};
31
32// Provides access to example database.
33// Example usage:
34// Construct and initialize:
35// ExampleDatabase db(db_path, kTestClients);
36// if(!db.Init() || !db.IsOpen() || !db.CheckIntegrity()) {
37// // Error handling
38// }
39//
40// Insert an example:
41// ExampleRecord example_record;
42// example_record.client_name = client_name;
43// example_record.serialized_example = serialized_example;
44// example_record.timestamp = base::Time::Now();
45//
46// db.InsertExample(example_record);
47//
48// Query examples:
49// int limit = 100;
// if (!db.PrepareStreamingForClient(client_name, limit)) {
51// // Error handling
52// } else {
53// // Call GetNextStreamedRecord() repeatedly until it returns
54// // a base::nullopt, then CloseStreaming();
55// while (true) {
56// auto maybe_example_record = db.GetNextStreamedRecord();
57// if (maybe_example_record == base::nullopt) {
58// // end of iterator
59// break;
60// } else {
61// ExampleRecord record = maybe_example_record.value();
62// }
63// }
64// db.CloseStreaming();
65// }
66//
// Delete examples with id <= the given id from table `client_name`:
68// // Keeps track of last_seen_id when querying examples.
69// if (!db.DeleteExamplesWithSmallerIdForClient(client_name, id)) {
70// // Error handling
71// }
72// See example_database_test.cc and storage_manager_impl.cc for more details.
73
74class ExampleDatabase {
75 public:
76 struct StmtGroup {
77 sqlite3_stmt* stmt_for_streaming = nullptr;
78 sqlite3_stmt* stmt_for_insert = nullptr;
79 sqlite3_stmt* stmt_for_delete = nullptr;
80 sqlite3_stmt* stmt_for_check = nullptr;
81 // Finalizes the stmts, required before disconnecting the db.
82 void Finalize();
83 };
84
85 // Creates an instance to talk to the database file at `db_path`. Init() must
86 // be called to establish connection.
87 explicit ExampleDatabase(const base::FilePath& db_path,
88 const std::unordered_set<std::string>& clients);
89 ExampleDatabase(const ExampleDatabase&) = delete;
90 ExampleDatabase& operator=(const ExampleDatabase&) = delete;
91
92 virtual ~ExampleDatabase();
93
94 // Initializes database connection. Must be called before any other queries.
95 // Returns true if no error occurred.
96 virtual bool Init();
97 // Returns true if the database connection is open.
98 virtual bool IsOpen() const;
99 // Closes database connection. Returns true if no error occurred.
100 virtual bool Close();
101 // Runs sqlite built-in integrity check. Returns true if no error is found.
102 virtual bool CheckIntegrity() const;
103 // Inserts example into database. Returns true if no error occurred.
104 virtual bool InsertExample(const ExampleRecord& example_record);
105
106 // Streaming examples with sqlite3_step, return true if table of client_name
107 // has more than a minimum number of examples and the stmt binds values
108 // successfully. The minimum number now is kMinExampleCount = 1 defined in
109 // utils.h/cc.
110 virtual bool PrepareStreamingForClient(const std::string& client_name,
111 const int32_t limit);
112 virtual base::Optional<ExampleRecord> GetNextStreamedRecord();
113 virtual void CloseStreaming();
114
115 // Deletes examples with id <= given id from client table. Returns true if no
116 // error occurred.
117 virtual bool DeleteExamplesWithSmallerIdForClient(
118 const std::string& client_name, const int64_t id);
119
120 private:
121 // Typedef of sqlite3_exec callback, see sqlite doc:
122 // https://sqlite.org/c3ref/exec.html.
123 using SqliteCallback = int (*)(void* /*data*/,
124 int /*count*/,
125 char** /*row*/,
126 char** /*names*/);
127
128 // Sqlite error code and error message.
129 struct ExecResult {
130 int code;
131 std::string error_msg;
132 };
133
134 // Returns true if the client's table exists.
135 bool ClientTableExists(const std::string& client_name) const;
136 // Returns true if the client's table is created without error.
137 bool CreateClientTable(const std::string& client_name) const;
138
139 // Returns the count of examples in the client's table.
140 int32_t ExampleCountOfClientTable(const std::string& client_name);
141
142 // Executes sql.
143 ExecResult ExecSql(const std::string& sql) const;
144 ExecResult ExecSql(const std::string& sql,
145 SqliteCallback callback,
146 void* data) const;
147
148 const base::FilePath db_path_;
149 std::unique_ptr<sqlite3, decltype(&sqlite3_close)> db_;
150
151 // The set of registered client names.
152 std::unordered_set<std::string> clients_;
153 // Mapping client_name to sqlite prepared statement objects for streaming,
154 // inserting and deleting examples.
155 std::unordered_map<std::string, StmtGroup> stmts_;
156 // The client with open example stream.
157 std::string current_streaming_client_;
158 // Whether there is an open streaming.
159 bool streaming_open_ = false;
160 // Whether current streaming hits SQLITE_DONE, to early return in
161 // GetNextStreamedRecord if current streaming already ends but is not closed.
162 // Relationship between streaming_open_ and end_of_streaming_:
163 // streaming_open_ && !end_of_streaming_: safe to call GetNextStreamedRecord
164 // streaming_open_ && end_of_streaming_: should call CloseStreaming
165 // !streaming_open_ && !end_of_streaming_: ready to PrepareStreamingForClient
166 // !streaming_open_ && end_of_streaming_: invalid, should never happen
167 bool end_of_streaming_ = false;
168};
169
170} // namespace federated
171
172#endif // FEDERATED_EXAMPLE_DATABASE_H_