blob: 0599ba474b6f7111fbb60fba4de047523ffb4460 [file] [log] [blame]
Honglin Yu7b6c1192020-09-16 10:07:17 +10001// Copyright 2021 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/process.h"
6
7#include <utility>
8
9#include <signal.h>
10#include <sysexits.h>
11#include <sys/mount.h>
12#include <sys/types.h>
13#include <sys/wait.h>
14#include <unistd.h>
15#include <pwd.h>
16
17#include <base/logging.h>
18#include <base/process/process_metrics.h>
19#include <base/strings/string_number_conversions.h>
20#include <base/strings/string_util.h>
21#include <libminijail.h>
22#include <mojo/core/embedder/embedder.h>
23#include <mojo/public/cpp/platform/platform_channel.h>
24#include <mojo/public/cpp/system/invitation.h>
25#include <scoped_minijail.h>
26
27#include "ml/daemon.h"
28#include "ml/machine_learning_service_impl.h"
Honglin Yu21616692021-05-14 11:20:22 +100029#include "ml/request_metrics.h"
30#include "ml/time_metrics.h"
Honglin Yu7b6c1192020-09-16 10:07:17 +100031
32namespace ml {
33
namespace {

// Command-line switch carrying the mojo bootstrap fd. Process::Run treats a
// process started with a non-empty value for this switch as a worker.
constexpr char kMojoBootstrapFdSwitchName[] = "mojo-bootstrap-fd";

// Name under which the primordial message pipe is attached to (control side)
// and extracted from (worker side) the mojo invitation.
constexpr char kInternalMojoPrimordialPipeName[] = "cros_ml";

// Binary executed for worker processes unless overridden for testing.
constexpr char kDefaultMlServiceBinaryPath[] = "/usr/bin/ml_service";

// Effective uid the control process switches to before running the daemon.
constexpr uid_t kMlServiceDBusUid = 20177;

// Builds the path of the seccomp policy file used for the worker process that
// serves `model_name`.
std::string GetSeccompPolicyPath(const std::string& model_name) {
  std::string policy_path = "/usr/share/policy/ml_service-";
  policy_path += model_name;
  policy_path += "-seccomp.policy";
  return policy_path;
}

// Builds the command-line argument ("--<switch>=<fd>") that tells a worker
// process which fd to use for bootstrapping mojo.
std::string GetArgumentForWorkerProcess(int fd) {
  std::string argument = "--";
  argument += kMojoBootstrapFdSwitchName;
  argument += "=";
  argument += std::to_string(fd);
  return argument;
}

}  // namespace
53
54// static
55Process* Process::GetInstance() {
56 // This is thread-safe.
57 static base::NoDestructor<Process> instance;
58 return instance.get();
59}
60
61int Process::Run(int argc, char* argv[]) {
62 // Parses the command line and determines the process type.
63 base::CommandLine command_line(argc, argv);
64 std::string mojo_fd_string =
65 command_line.GetSwitchValueASCII(kMojoBootstrapFdSwitchName);
66
67 if (mojo_fd_string.empty()) {
68 process_type_ = Type::kControl;
69 } else {
70 process_type_ = Type::kWorker;
71 }
72
73 if (!command_line.GetArgs().empty()) {
74 LOG(ERROR) << "Unexpected command line arguments: "
75 << base::JoinString(command_line.GetArgs(), "\t");
76 return ExitCode::kUnexpectedCommandLine;
77 }
78
79 if (process_type_ == Type::kControl) {
80 ControlProcessRun();
81 } else {
82 // The process type is either "control" or "worker".
83 DCHECK(GetType() == Type::kWorker);
84 const auto is_valid_fd_str =
85 base::StringToInt(mojo_fd_string, &mojo_bootstrap_fd_);
86 DCHECK(is_valid_fd_str) << "Invalid mojo bootstrap fd";
87 WorkerProcessRun();
88 }
89
90 return ExitCode::kSuccess;
91}
92
93Process::Type Process::GetType() {
94 return process_type_;
95}
96
97bool Process::SpawnWorkerProcessAndGetPid(const mojo::PlatformChannel& channel,
98 const std::string& model_name,
99 pid_t* worker_pid) {
100 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
101 DCHECK(worker_pid != nullptr);
102 // Should only be called by the control process.
Honglin Yuf3c47b32021-05-06 18:38:37 +1000103 DCHECK(IsControlProcess()) << "Should only be called by the control process";
Honglin Yu7b6c1192020-09-16 10:07:17 +1000104
105 // Start the process.
106 ScopedMinijail jail(minijail_new());
107
108 minijail_namespace_ipc(jail.get());
109 minijail_namespace_uts(jail.get());
110 minijail_namespace_net(jail.get());
111 minijail_namespace_cgroups(jail.get());
Honglin Yu7b6c1192020-09-16 10:07:17 +1000112
Honglin Yuf3c47b32021-05-06 18:38:37 +1000113 // The following sandboxing makes unit test crash so we do not use them in
114 // unit tests.
115 if (process_type_ != Type::kControlForTest) {
116 minijail_namespace_pids(jail.get());
117 minijail_namespace_vfs(jail.get());
118 std::string seccomp_policy_path = GetSeccompPolicyPath(model_name);
119 minijail_parse_seccomp_filters(jail.get(), seccomp_policy_path.c_str());
120 minijail_use_seccomp_filter(jail.get());
121 }
Honglin Yu7b6c1192020-09-16 10:07:17 +1000122
123 std::string fd_argv = kMojoBootstrapFdSwitchName;
124 // Use GetFD instead of TakeFD to non-destructively obtain the fd.
125 fd_argv = GetArgumentForWorkerProcess(
126 channel.remote_endpoint().platform_handle().GetFD().get());
Honglin Yuf3c47b32021-05-06 18:38:37 +1000127 char* const argv[3] = {&ml_service_path_[0], &fd_argv[0], nullptr};
Honglin Yu7b6c1192020-09-16 10:07:17 +1000128
Honglin Yuf3c47b32021-05-06 18:38:37 +1000129 if (minijail_run_pid(jail.get(), &ml_service_path_[0], argv, worker_pid) !=
Honglin Yu7b6c1192020-09-16 10:07:17 +1000130 0) {
Honglin Yu21616692021-05-14 11:20:22 +1000131 RecordProcessErrorEvent(ProcessError::kSpawnWorkerProcessFailed);
Honglin Yu7b6c1192020-09-16 10:07:17 +1000132 LOG(DFATAL) << "Failed to spawn worker process for " << model_name;
133 return false;
134 }
135
136 return true;
137}
138
139mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService>&
140Process::SendMojoInvitationAndGetRemote(pid_t worker_pid,
141 mojo::PlatformChannel channel,
142 const std::string& model_name) {
143 // Send the Mojo invitation to the worker process.
144 mojo::OutgoingInvitation invitation;
145 mojo::ScopedMessagePipeHandle pipe =
146 invitation.AttachMessagePipe(kInternalMojoPrimordialPipeName);
147
148 mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService>
149 remote(mojo::PendingRemote<
150 chromeos::machine_learning::mojom::MachineLearningService>(
151 std::move(pipe), 0u /* version */));
152
153 mojo::OutgoingInvitation::Send(std::move(invitation), worker_pid,
154 channel.TakeLocalEndpoint());
155
156 remote.set_disconnect_handler(
Honglin Yuf3c47b32021-05-06 18:38:37 +1000157 base::BindOnce(&Process::InternalPrimordialMojoPipeDisconnectHandler,
158 base::Unretained(this), worker_pid));
Honglin Yu7b6c1192020-09-16 10:07:17 +1000159
160 DCHECK(worker_pid_info_map_.find(worker_pid) == worker_pid_info_map_.end())
161 << "Worker pid already exists";
162
163 WorkerInfo worker_info;
164 worker_info.remote = std::move(remote);
165 worker_info.process_metrics =
166 base::ProcessMetrics::CreateProcessMetrics(worker_pid);
167 // Baseline the CPU usage counter in `process_metrics` to be zero as of now.
168 worker_info.process_metrics->GetPlatformIndependentCPUUsage();
169
170 worker_pid_info_map_[worker_pid] = std::move(worker_info);
171
172 return worker_pid_info_map_[worker_pid].remote;
173}
174
175void Process::UnregisterWorkerProcess(pid_t pid) {
176 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
177 const auto iter = worker_pid_info_map_.find(pid);
178 DCHECK(iter != worker_pid_info_map_.end()) << "Pid is not registered";
179 worker_pid_info_map_.erase(iter);
180}
181
Honglin Yuf3c47b32021-05-06 18:38:37 +1000182Process::Process()
183 : process_type_(Type::kUnset),
184 mojo_bootstrap_fd_(-1),
185 ml_service_path_(kDefaultMlServiceBinaryPath) {}
Honglin Yu7b6c1192020-09-16 10:07:17 +1000186Process::~Process() = default;
187
188void Process::ControlProcessRun() {
189 // We need to set euid to kMlServiceDBusUid to bootstrap DBus. Otherwise, DBus
190 // will block us because our euid inside of the userns is 0 but is 20106
191 // outside of the userns.
192 if (seteuid(kMlServiceDBusUid) != 0) {
Honglin Yu21616692021-05-14 11:20:22 +1000193 RecordProcessErrorEvent(ProcessError::kChangeEuidToMlServiceDBusFailed);
Honglin Yu7b6c1192020-09-16 10:07:17 +1000194 LOG(ERROR) << "Unable to change effective uid to " << kMlServiceDBusUid;
195 exit(EX_OSERR);
196 }
197
198 ml::Daemon daemon;
199 daemon.Run();
200}
201
202void Process::WorkerProcessRun() {
203 brillo::BaseMessageLoop message_loop;
204 message_loop.SetAsCurrent();
205 DETACH_FROM_SEQUENCE(sequence_checker_);
Honglin Yu7b6c1192020-09-16 10:07:17 +1000206 mojo::core::Init();
207 mojo::core::ScopedIPCSupport ipc_support(
208 base::ThreadTaskRunnerHandle::Get(),
209 mojo::core::ScopedIPCSupport::ShutdownPolicy::FAST);
Honglin Yu21616692021-05-14 11:20:22 +1000210 mojo::IncomingInvitation invitation;
211 {
Honglin Yuc083b1a2021-05-26 09:24:24 +1000212 WallTimeMetric walltime_metric(
Honglin Yu21616692021-05-14 11:20:22 +1000213 "MachineLearningService.WorkerProcessAcceptMojoConnectionTime");
214 invitation = mojo::IncomingInvitation::Accept(mojo::PlatformChannelEndpoint(
215 mojo::PlatformHandle(base::ScopedFD(mojo_bootstrap_fd_))));
216 }
Honglin Yu7b6c1192020-09-16 10:07:17 +1000217 mojo::ScopedMessagePipeHandle pipe =
218 invitation.ExtractMessagePipe(kInternalMojoPrimordialPipeName);
219 // The worker process exits if it disconnects with the control process.
220 // This can be important because in the control process's disconnect handler
221 // function we will use waitpid to wait for this process to finish. So
222 // the exit here will make sure that the waitpid in control process
223 // won't hang.
224 MachineLearningServiceImpl machine_learning_service_impl(
225 mojo::PendingReceiver<
226 chromeos::machine_learning::mojom::MachineLearningService>(
227 std::move(pipe)),
228 message_loop.QuitClosure());
229 message_loop.Run();
230}
231
232const std::unordered_map<pid_t, Process::WorkerInfo>&
233Process::GetWorkerPidInfoMap() {
234 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
235 return worker_pid_info_map_;
236}
237
Honglin Yuf3c47b32021-05-06 18:38:37 +1000238void Process::SetTypeForTesting(Type type) {
239 process_type_ = type;
240}
241
242void Process::SetMlServicePathForTesting(const std::string& path) {
243 ml_service_path_ = path;
244}
245
246void Process::SetBeforeExitWorkerDisconnectHandlerHookForTesting(
247 base::RepeatingClosure hook) {
248 before_exit_worker_disconnect_handler_hook_ = std::move(hook);
249}
250
251bool Process::IsControlProcess() {
252 return process_type_ == Type::kControl ||
253 process_type_ == Type::kControlForTest;
254}
255
256bool Process::IsWorkerProcess() {
257 return process_type_ == Type::kWorker ||
258 process_type_ == Type::kSingleProcessForTest;
259}
260
261void Process::InternalPrimordialMojoPipeDisconnectHandler(pid_t child_pid) {
Honglin Yuc083b1a2021-05-26 09:24:24 +1000262 WallTimeMetric walltime_metric(
263 "MachineLearningService.WorkerProcessCleanUpTime");
Honglin Yu21616692021-05-14 11:20:22 +1000264
Honglin Yuf3c47b32021-05-06 18:38:37 +1000265 UnregisterWorkerProcess(child_pid);
266 // Reap the worker process.
267 int status;
268 pid_t ret_pid = waitpid(child_pid, &status, 0);
269 DCHECK(ret_pid == child_pid);
Honglin Yu21616692021-05-14 11:20:22 +1000270 int exit_status = WEXITSTATUS(status);
271 if (exit_status != 0) {
272 RecordWorkerProcessExitStatus(WEXITSTATUS(status));
273 }
Honglin Yuf3c47b32021-05-06 18:38:37 +1000274
275 // Call the hooks used in testing.
276 if (process_type_ == Type::kControlForTest &&
277 !before_exit_worker_disconnect_handler_hook_.is_null()) {
278 before_exit_worker_disconnect_handler_hook_.Run();
279 }
280}
281
Honglin Yu7b6c1192020-09-16 10:07:17 +1000282} // namespace ml