blob: 7ced72266efd0ad11cb4105428f137ad98f36de8 [file] [log] [blame]
Honglin Yu7b6c1192020-09-16 10:07:17 +10001// Copyright 2021 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
#include "ml/process.h"

#include <errno.h>
#include <pwd.h>
#include <signal.h>
#include <sys/mount.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sysexits.h>
#include <unistd.h>

#include <string>
#include <utility>

#include <base/logging.h>
#include <base/process/process_metrics.h>
#include <base/strings/string_number_conversions.h>
#include <base/strings/string_util.h>
#include <libminijail.h>
#include <mojo/core/embedder/embedder.h>
#include <mojo/public/cpp/platform/platform_channel.h>
#include <mojo/public/cpp/system/invitation.h>
#include <scoped_minijail.h>

#include "ml/daemon.h"
#include "ml/machine_learning_service_impl.h"
29
30namespace ml {
31
32namespace {
// Command-line switch whose value is the number of the inherited file
// descriptor carrying a worker's end of the Mojo bootstrap channel.
constexpr char kMojoBootstrapFdSwitchName[] = "mojo-bootstrap-fd";

// Name under which the primordial message pipe is attached to (and extracted
// from) the Mojo invitation exchanged between control and worker processes.
constexpr char kInternalMojoPrimordialPipeName[] = "cros_ml";

// Executable launched inside a minijail for every worker process.
constexpr char kMlServiceBinaryPath[] = "/usr/bin/ml_service";

// Effective uid the control process switches to before bootstrapping D-Bus
// (see Process::ControlProcessRun).
constexpr uid_t kMlServiceDBusUid = 20177;
40
// Returns the path of the seccomp policy file used to sandbox the worker
// process that serves `model_name`.
std::string GetSeccompPolicyPath(const std::string& model_name) {
  std::string policy_path("/usr/share/policy/ml_service-");
  policy_path.append(model_name);
  policy_path.append("-seccomp.policy");
  return policy_path;
}
44
45std::string GetArgumentForWorkerProcess(int fd) {
46 std::string fd_argv = kMojoBootstrapFdSwitchName;
47 return "--" + fd_argv + "=" + std::to_string(fd);
48}
49
50void InternalPrimordialMojoPipeDisconnectHandler(pid_t child_pid) {
51 Process::GetInstance()->UnregisterWorkerProcess(child_pid);
52 // Reap the worker process.
53 int status;
54 pid_t ret_pid = waitpid(child_pid, &status, 0);
55 DCHECK(ret_pid == child_pid);
56 // TODO(https://crbug.com/1202545): report WEXITSTATUS(status) to UMA.
57 DVLOG(1) << "Worker process (" << child_pid << ") exits with status "
58 << WEXITSTATUS(status);
59}
60} // namespace
61
62// static
63Process* Process::GetInstance() {
64 // This is thread-safe.
65 static base::NoDestructor<Process> instance;
66 return instance.get();
67}
68
69int Process::Run(int argc, char* argv[]) {
70 // Parses the command line and determines the process type.
71 base::CommandLine command_line(argc, argv);
72 std::string mojo_fd_string =
73 command_line.GetSwitchValueASCII(kMojoBootstrapFdSwitchName);
74
75 if (mojo_fd_string.empty()) {
76 process_type_ = Type::kControl;
77 } else {
78 process_type_ = Type::kWorker;
79 }
80
81 if (!command_line.GetArgs().empty()) {
82 LOG(ERROR) << "Unexpected command line arguments: "
83 << base::JoinString(command_line.GetArgs(), "\t");
84 return ExitCode::kUnexpectedCommandLine;
85 }
86
87 if (process_type_ == Type::kControl) {
88 ControlProcessRun();
89 } else {
90 // The process type is either "control" or "worker".
91 DCHECK(GetType() == Type::kWorker);
92 const auto is_valid_fd_str =
93 base::StringToInt(mojo_fd_string, &mojo_bootstrap_fd_);
94 DCHECK(is_valid_fd_str) << "Invalid mojo bootstrap fd";
95 WorkerProcessRun();
96 }
97
98 return ExitCode::kSuccess;
99}
100
101Process::Type Process::GetType() {
102 return process_type_;
103}
104
105bool Process::SpawnWorkerProcessAndGetPid(const mojo::PlatformChannel& channel,
106 const std::string& model_name,
107 pid_t* worker_pid) {
108 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
109 DCHECK(worker_pid != nullptr);
110 // Should only be called by the control process.
111 DCHECK(process_type_ == Type::kControl)
112 << "Should only be called by the control process";
113
114 // Start the process.
115 ScopedMinijail jail(minijail_new());
116
117 minijail_namespace_ipc(jail.get());
118 minijail_namespace_uts(jail.get());
119 minijail_namespace_net(jail.get());
120 minijail_namespace_cgroups(jail.get());
121 minijail_namespace_pids(jail.get());
122 minijail_namespace_vfs(jail.get());
123
124 std::string seccomp_policy_path = GetSeccompPolicyPath(model_name);
125 minijail_parse_seccomp_filters(jail.get(), seccomp_policy_path.c_str());
126 minijail_use_seccomp_filter(jail.get());
127
128 std::string fd_argv = kMojoBootstrapFdSwitchName;
129 // Use GetFD instead of TakeFD to non-destructively obtain the fd.
130 fd_argv = GetArgumentForWorkerProcess(
131 channel.remote_endpoint().platform_handle().GetFD().get());
132
133 std::string mlservice_binary_path(kMlServiceBinaryPath);
134
135 char* const argv[3] = {&mlservice_binary_path[0], &fd_argv[0], nullptr};
136
137 // TODO(https://crbug.com/1202545): report the failure.
138 if (minijail_run_pid(jail.get(), kMlServiceBinaryPath, argv, worker_pid) !=
139 0) {
140 LOG(DFATAL) << "Failed to spawn worker process for " << model_name;
141 return false;
142 }
143
144 return true;
145}
146
147mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService>&
148Process::SendMojoInvitationAndGetRemote(pid_t worker_pid,
149 mojo::PlatformChannel channel,
150 const std::string& model_name) {
151 // Send the Mojo invitation to the worker process.
152 mojo::OutgoingInvitation invitation;
153 mojo::ScopedMessagePipeHandle pipe =
154 invitation.AttachMessagePipe(kInternalMojoPrimordialPipeName);
155
156 mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService>
157 remote(mojo::PendingRemote<
158 chromeos::machine_learning::mojom::MachineLearningService>(
159 std::move(pipe), 0u /* version */));
160
161 mojo::OutgoingInvitation::Send(std::move(invitation), worker_pid,
162 channel.TakeLocalEndpoint());
163
164 remote.set_disconnect_handler(
165 base::BindOnce(InternalPrimordialMojoPipeDisconnectHandler, worker_pid));
166
167 DCHECK(worker_pid_info_map_.find(worker_pid) == worker_pid_info_map_.end())
168 << "Worker pid already exists";
169
170 WorkerInfo worker_info;
171 worker_info.remote = std::move(remote);
172 worker_info.process_metrics =
173 base::ProcessMetrics::CreateProcessMetrics(worker_pid);
174 // Baseline the CPU usage counter in `process_metrics` to be zero as of now.
175 worker_info.process_metrics->GetPlatformIndependentCPUUsage();
176
177 worker_pid_info_map_[worker_pid] = std::move(worker_info);
178
179 return worker_pid_info_map_[worker_pid].remote;
180}
181
182void Process::UnregisterWorkerProcess(pid_t pid) {
183 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
184 const auto iter = worker_pid_info_map_.find(pid);
185 DCHECK(iter != worker_pid_info_map_.end()) << "Pid is not registered";
186 worker_pid_info_map_.erase(iter);
187}
188
189Process::Process() : process_type_(Type::kUnset), mojo_bootstrap_fd_(-1) {}
190Process::~Process() = default;
191
192void Process::ControlProcessRun() {
193 // We need to set euid to kMlServiceDBusUid to bootstrap DBus. Otherwise, DBus
194 // will block us because our euid inside of the userns is 0 but is 20106
195 // outside of the userns.
196 if (seteuid(kMlServiceDBusUid) != 0) {
197 // TODO(https://crbug.com/1202545): report this error to UMA.
198 LOG(ERROR) << "Unable to change effective uid to " << kMlServiceDBusUid;
199 exit(EX_OSERR);
200 }
201
202 ml::Daemon daemon;
203 daemon.Run();
204}
205
206void Process::WorkerProcessRun() {
207 brillo::BaseMessageLoop message_loop;
208 message_loop.SetAsCurrent();
209 DETACH_FROM_SEQUENCE(sequence_checker_);
210
211 mojo::core::Init();
212 mojo::core::ScopedIPCSupport ipc_support(
213 base::ThreadTaskRunnerHandle::Get(),
214 mojo::core::ScopedIPCSupport::ShutdownPolicy::FAST);
215 mojo::IncomingInvitation invitation =
216 mojo::IncomingInvitation::Accept(mojo::PlatformChannelEndpoint(
217 mojo::PlatformHandle(base::ScopedFD(mojo_bootstrap_fd_))));
218 mojo::ScopedMessagePipeHandle pipe =
219 invitation.ExtractMessagePipe(kInternalMojoPrimordialPipeName);
220 // The worker process exits if it disconnects with the control process.
221 // This can be important because in the control process's disconnect handler
222 // function we will use waitpid to wait for this process to finish. So
223 // the exit here will make sure that the waitpid in control process
224 // won't hang.
225 MachineLearningServiceImpl machine_learning_service_impl(
226 mojo::PendingReceiver<
227 chromeos::machine_learning::mojom::MachineLearningService>(
228 std::move(pipe)),
229 message_loop.QuitClosure());
230 message_loop.Run();
231}
232
233const std::unordered_map<pid_t, Process::WorkerInfo>&
234Process::GetWorkerPidInfoMap() {
235 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
236 return worker_pid_info_map_;
237}
238
239} // namespace ml