blob: 96681bda64948ead3a2dc254085ab8202bd4f310 [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/process.h"
6
7#include <utility>
8
9#include <signal.h>
10#include <sysexits.h>
11#include <sys/mount.h>
12#include <sys/types.h>
13#include <sys/wait.h>
14#include <unistd.h>
15#include <pwd.h>
16
17#include <base/logging.h>
18#include <base/process/process_metrics.h>
19#include <base/strings/string_number_conversions.h>
20#include <base/strings/string_util.h>
21#include <libminijail.h>
22#include <mojo/core/embedder/embedder.h>
23#include <mojo/public/cpp/platform/platform_channel.h>
24#include <mojo/public/cpp/system/invitation.h>
25#include <scoped_minijail.h>
26
27#include "ml/daemon.h"
28#include "ml/machine_learning_service_impl.h"
29
30namespace ml {
31
namespace {

// Command-line switch through which a worker process receives the file
// descriptor used to bootstrap its Mojo connection.
constexpr char kMojoBootstrapFdSwitchName[] = "mojo-bootstrap-fd";

// Name of the message pipe attached to the Mojo invitation exchanged between
// the control process and a worker process.
constexpr char kInternalMojoPrimordialPipeName[] = "cros_ml";

// Path of the ml_service executable used unless overridden for testing.
constexpr char kDefaultMlServiceBinaryPath[] = "/usr/bin/ml_service";

// Effective uid the control process switches to before bootstrapping D-Bus.
constexpr uid_t kMlServiceDBusUid = 20177;

// Builds the path of the seccomp policy file associated with `model_name`.
std::string GetSeccompPolicyPath(const std::string& model_name) {
  std::string policy_path("/usr/share/policy/ml_service-");
  policy_path += model_name;
  policy_path += "-seccomp.policy";
  return policy_path;
}

// Builds the "--<switch>=<fd>" command-line argument that hands the Mojo
// bootstrap file descriptor `fd` to a worker process.
std::string GetArgumentForWorkerProcess(int fd) {
  std::string argument("--");
  argument += kMojoBootstrapFdSwitchName;
  argument += "=";
  argument += std::to_string(fd);
  return argument;
}

}  // namespace
51
52// static
53Process* Process::GetInstance() {
54 // This is thread-safe.
55 static base::NoDestructor<Process> instance;
56 return instance.get();
57}
58
59int Process::Run(int argc, char* argv[]) {
60 // Parses the command line and determines the process type.
61 base::CommandLine command_line(argc, argv);
62 std::string mojo_fd_string =
63 command_line.GetSwitchValueASCII(kMojoBootstrapFdSwitchName);
64
65 if (mojo_fd_string.empty()) {
66 process_type_ = Type::kControl;
67 } else {
68 process_type_ = Type::kWorker;
69 }
70
71 if (!command_line.GetArgs().empty()) {
72 LOG(ERROR) << "Unexpected command line arguments: "
73 << base::JoinString(command_line.GetArgs(), "\t");
74 return ExitCode::kUnexpectedCommandLine;
75 }
76
77 if (process_type_ == Type::kControl) {
78 ControlProcessRun();
79 } else {
80 // The process type is either "control" or "worker".
81 DCHECK(GetType() == Type::kWorker);
82 const auto is_valid_fd_str =
83 base::StringToInt(mojo_fd_string, &mojo_bootstrap_fd_);
84 DCHECK(is_valid_fd_str) << "Invalid mojo bootstrap fd";
85 WorkerProcessRun();
86 }
87
88 return ExitCode::kSuccess;
89}
90
91Process::Type Process::GetType() {
92 return process_type_;
93}
94
95bool Process::SpawnWorkerProcessAndGetPid(const mojo::PlatformChannel& channel,
96 const std::string& model_name,
97 pid_t* worker_pid) {
98 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
99 DCHECK(worker_pid != nullptr);
100 // Should only be called by the control process.
Honglin Yuf3c47b32021-05-06 18:38:37 +1000101 DCHECK(IsControlProcess()) << "Should only be called by the control process";
Honglin Yu7b6c1192020-09-16 10:07:17 +1000102
103 // Start the process.
104 ScopedMinijail jail(minijail_new());
105
106 minijail_namespace_ipc(jail.get());
107 minijail_namespace_uts(jail.get());
108 minijail_namespace_net(jail.get());
109 minijail_namespace_cgroups(jail.get());
Honglin Yu7b6c1192020-09-16 10:07:17 +1000110
Honglin Yuf3c47b32021-05-06 18:38:37 +1000111 // The following sandboxing makes unit test crash so we do not use them in
112 // unit tests.
113 if (process_type_ != Type::kControlForTest) {
114 minijail_namespace_pids(jail.get());
115 minijail_namespace_vfs(jail.get());
116 std::string seccomp_policy_path = GetSeccompPolicyPath(model_name);
117 minijail_parse_seccomp_filters(jail.get(), seccomp_policy_path.c_str());
118 minijail_use_seccomp_filter(jail.get());
119 }
Honglin Yu7b6c1192020-09-16 10:07:17 +1000120
121 std::string fd_argv = kMojoBootstrapFdSwitchName;
122 // Use GetFD instead of TakeFD to non-destructively obtain the fd.
123 fd_argv = GetArgumentForWorkerProcess(
124 channel.remote_endpoint().platform_handle().GetFD().get());
Honglin Yuf3c47b32021-05-06 18:38:37 +1000125 char* const argv[3] = {&ml_service_path_[0], &fd_argv[0], nullptr};
Honglin Yu7b6c1192020-09-16 10:07:17 +1000126
127 // TODO(https://crbug.com/1202545): report the failure.
Honglin Yuf3c47b32021-05-06 18:38:37 +1000128 if (minijail_run_pid(jail.get(), &ml_service_path_[0], argv, worker_pid) !=
Honglin Yu7b6c1192020-09-16 10:07:17 +1000129 0) {
130 LOG(DFATAL) << "Failed to spawn worker process for " << model_name;
131 return false;
132 }
133
134 return true;
135}
136
137mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService>&
138Process::SendMojoInvitationAndGetRemote(pid_t worker_pid,
139 mojo::PlatformChannel channel,
140 const std::string& model_name) {
141 // Send the Mojo invitation to the worker process.
142 mojo::OutgoingInvitation invitation;
143 mojo::ScopedMessagePipeHandle pipe =
144 invitation.AttachMessagePipe(kInternalMojoPrimordialPipeName);
145
146 mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService>
147 remote(mojo::PendingRemote<
148 chromeos::machine_learning::mojom::MachineLearningService>(
149 std::move(pipe), 0u /* version */));
150
151 mojo::OutgoingInvitation::Send(std::move(invitation), worker_pid,
152 channel.TakeLocalEndpoint());
153
154 remote.set_disconnect_handler(
Honglin Yuf3c47b32021-05-06 18:38:37 +1000155 base::BindOnce(&Process::InternalPrimordialMojoPipeDisconnectHandler,
156 base::Unretained(this), worker_pid));
Honglin Yu7b6c1192020-09-16 10:07:17 +1000157
158 DCHECK(worker_pid_info_map_.find(worker_pid) == worker_pid_info_map_.end())
159 << "Worker pid already exists";
160
161 WorkerInfo worker_info;
162 worker_info.remote = std::move(remote);
163 worker_info.process_metrics =
164 base::ProcessMetrics::CreateProcessMetrics(worker_pid);
165 // Baseline the CPU usage counter in `process_metrics` to be zero as of now.
166 worker_info.process_metrics->GetPlatformIndependentCPUUsage();
167
168 worker_pid_info_map_[worker_pid] = std::move(worker_info);
169
170 return worker_pid_info_map_[worker_pid].remote;
171}
172
173void Process::UnregisterWorkerProcess(pid_t pid) {
174 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
175 const auto iter = worker_pid_info_map_.find(pid);
176 DCHECK(iter != worker_pid_info_map_.end()) << "Pid is not registered";
177 worker_pid_info_map_.erase(iter);
178}
179
Honglin Yuf3c47b32021-05-06 18:38:37 +1000180Process::Process()
181 : process_type_(Type::kUnset),
182 mojo_bootstrap_fd_(-1),
183 ml_service_path_(kDefaultMlServiceBinaryPath) {}
Honglin Yu7b6c1192020-09-16 10:07:17 +1000184Process::~Process() = default;
185
186void Process::ControlProcessRun() {
187 // We need to set euid to kMlServiceDBusUid to bootstrap DBus. Otherwise, DBus
188 // will block us because our euid inside of the userns is 0 but is 20106
189 // outside of the userns.
190 if (seteuid(kMlServiceDBusUid) != 0) {
191 // TODO(https://crbug.com/1202545): report this error to UMA.
192 LOG(ERROR) << "Unable to change effective uid to " << kMlServiceDBusUid;
193 exit(EX_OSERR);
194 }
195
196 ml::Daemon daemon;
197 daemon.Run();
198}
199
200void Process::WorkerProcessRun() {
201 brillo::BaseMessageLoop message_loop;
202 message_loop.SetAsCurrent();
203 DETACH_FROM_SEQUENCE(sequence_checker_);
Honglin Yu7b6c1192020-09-16 10:07:17 +1000204 mojo::core::Init();
205 mojo::core::ScopedIPCSupport ipc_support(
206 base::ThreadTaskRunnerHandle::Get(),
207 mojo::core::ScopedIPCSupport::ShutdownPolicy::FAST);
208 mojo::IncomingInvitation invitation =
209 mojo::IncomingInvitation::Accept(mojo::PlatformChannelEndpoint(
210 mojo::PlatformHandle(base::ScopedFD(mojo_bootstrap_fd_))));
211 mojo::ScopedMessagePipeHandle pipe =
212 invitation.ExtractMessagePipe(kInternalMojoPrimordialPipeName);
213 // The worker process exits if it disconnects with the control process.
214 // This can be important because in the control process's disconnect handler
215 // function we will use waitpid to wait for this process to finish. So
216 // the exit here will make sure that the waitpid in control process
217 // won't hang.
218 MachineLearningServiceImpl machine_learning_service_impl(
219 mojo::PendingReceiver<
220 chromeos::machine_learning::mojom::MachineLearningService>(
221 std::move(pipe)),
222 message_loop.QuitClosure());
223 message_loop.Run();
224}
225
226const std::unordered_map<pid_t, Process::WorkerInfo>&
227Process::GetWorkerPidInfoMap() {
228 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
229 return worker_pid_info_map_;
230}
231
Honglin Yuf3c47b32021-05-06 18:38:37 +1000232void Process::SetTypeForTesting(Type type) {
233 process_type_ = type;
234}
235
236void Process::SetMlServicePathForTesting(const std::string& path) {
237 ml_service_path_ = path;
238}
239
240void Process::SetBeforeExitWorkerDisconnectHandlerHookForTesting(
241 base::RepeatingClosure hook) {
242 before_exit_worker_disconnect_handler_hook_ = std::move(hook);
243}
244
245bool Process::IsControlProcess() {
246 return process_type_ == Type::kControl ||
247 process_type_ == Type::kControlForTest;
248}
249
250bool Process::IsWorkerProcess() {
251 return process_type_ == Type::kWorker ||
252 process_type_ == Type::kSingleProcessForTest;
253}
254
255void Process::InternalPrimordialMojoPipeDisconnectHandler(pid_t child_pid) {
256 UnregisterWorkerProcess(child_pid);
257 // Reap the worker process.
258 int status;
259 pid_t ret_pid = waitpid(child_pid, &status, 0);
260 DCHECK(ret_pid == child_pid);
261 // TODO(https://crbug.com/1202545): report WEXITSTATUS(status) to UMA.
262 DVLOG(1) << "Worker process (" << child_pid << ") exits with status "
263 << WEXITSTATUS(status);
264
265 // Call the hooks used in testing.
266 if (process_type_ == Type::kControlForTest &&
267 !before_exit_worker_disconnect_handler_hook_.is_null()) {
268 before_exit_worker_disconnect_handler_hook_.Run();
269 }
270}
271
Honglin Yu7b6c1192020-09-16 10:07:17 +1000272} // namespace ml