blob: 4ec70b8887093c8370b6c071b76a7532b84f72d1 [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#ifndef ML_PROCESS_H_
6#define ML_PROCESS_H_
7
#include <unistd.h>

#include <memory>
#include <string>
#include <unordered_map>

#include <base/macros.h>
#include <base/no_destructor.h>
#include <base/process/process_metrics.h>
#include <base/sequence_checker.h>
#include <mojo/public/cpp/bindings/remote.h>
#include <mojo/public/cpp/platform/platform_channel.h>

#include "ml/machine_learning_service_impl.h"

22namespace ml {
23
24// A singleton class to store the global process information and provide
25// process management functions.
26// Usage: access the global instance by calling `Process::GetInstance()`.
27class Process {
28 public:
29 // The type of a process.
30 enum class Type {
31 kUnset = 0,
32 kControl = 1,
33 kWorker = 2,
34 };
35
36 // The exit code of a process.
37 enum ExitCode : int {
38 kSuccess = 0,
39 // Only for worker process used when its mojo connection with the control
40 // process breaks.
41 kWorkerDisconnectWithControl = 1,
42 kInvalidProcessType = 2,
43 kUnexpectedCommandLine = 3,
44 };
45
46 // The worker process info, containing object to contact and measure worker
47 // process in the control process.
48 struct WorkerInfo {
49 // The Mojo remote to call the worker process's `MachineLearningService`
50 // bindings.
51 mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService>
52 remote;
53 // The process metrics object of the worker process.
54 std::unique_ptr<base::ProcessMetrics> process_metrics;
55 };
56
57 static Process* GetInstance();
58
59 int Run(int argc, char* argv[]);
60
61 // Gets the process type of current process.
62 Type GetType();
63
64 // Returns true if the worker process has been started successfully and the
65 // worker's pid is stored in `worker_pid`. Otherwise returns false and
66 // `worker_pid` is unchanged.
67 // The argument `model_name` has two usages:
68 // - it used in logging (like `metrics_model_name`).
69 // - it also determines which seccomp policy list to use in sandboxing the
70 // worker process.
71 bool SpawnWorkerProcessAndGetPid(const mojo::PlatformChannel& channel,
72 const std::string& model_name,
73 pid_t* worker_pid);
74
75 // Returns a reference of the remote of the worker process. The remote is hold
76 // in the `worker_pid_info_map_` object.
77 mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService>&
78 SendMojoInvitationAndGetRemote(pid_t worker_pid,
79 mojo::PlatformChannel channel,
80 const std::string& model_name);
81
82 // Removes a worker process from metadata. This does not terminate the
83 // worker process.
84 void UnregisterWorkerProcess(pid_t pid);
85
86 const std::unordered_map<pid_t, WorkerInfo>& GetWorkerPidInfoMap();
87
88 private:
89 friend base::NoDestructor<Process>;
90
91 Process();
92 ~Process();
93
94 // Can only be called by the control process.
95 void ControlProcessRun();
96
97 // Can only be called by the worker process.
98 // Input: the file descriptor used to bootstrap Mojo connection.
99 void WorkerProcessRun();
100
101 // The type of current process.
102 Type process_type_;
103
104 // The file descriptor to bootstrap the mojo connection of current process.
105 // Only meaningful for worker process.
106 int mojo_bootstrap_fd_;
107
108 // The map from pid to the info of worker processes. Only meaningful for
109 // control process.
110 std::unordered_map<pid_t, WorkerInfo> worker_pid_info_map_;
111
112 // Mainly used for guarding `worker_pid_info_map_`.
113 SEQUENCE_CHECKER(sequence_checker_);
114};
115
116} // namespace ml
117
118#endif // ML_PROCESS_H_