blob: bf9778b3b0d12f26ae498cfde8009a71c5bed96d [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110013#include <base/files/file.h>
14#include <base/files/file_util.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100015#include <tensorflow/lite/model.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110016#include <unicode/putil.h>
17#include <unicode/udata.h>
18#include <utils/memory/mmap.h>
Michael Martisa74af932018-08-13 16:52:36 +100019
charleszhao17777f92020-04-23 12:53:11 +100020#include "ml/handwriting.h"
21#include "ml/handwriting_recognizer_impl.h"
Michael Martisa74af932018-08-13 16:52:36 +100022#include "ml/model_impl.h"
charleszhao17777f92020-04-23 12:53:11 +100023#include "ml/mojom/handwriting_recognizer.mojom.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090024#include "ml/mojom/model.mojom.h"
Honglin Yud2204272020-08-26 14:21:37 +100025#include "ml/mojom/soda.mojom.h"
26#include "ml/soda_recognizer_impl.h"
Honglin Yuf33dce32019-12-05 15:10:39 +110027#include "ml/text_classifier_impl.h"
Michael Martisa74af932018-08-13 16:52:36 +100028
Andrew Moylanff6be512018-07-03 11:05:01 +100029namespace ml {
30
Michael Martisa74af932018-08-13 16:52:36 +100031namespace {
32
Honglin Yu0ed72352019-08-27 17:42:01 +100033using ::chromeos::machine_learning::mojom::BuiltinModelId;
34using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
35using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr;
Andrew Moylanb481af72020-07-09 15:22:00 +100036using ::chromeos::machine_learning::mojom::HandwritingRecognizer;
charleszhao05c5a4a2020-06-09 16:49:54 +100037using ::chromeos::machine_learning::mojom::HandwritingRecognizerSpec;
38using ::chromeos::machine_learning::mojom::HandwritingRecognizerSpecPtr;
Charles Zhao6d467e62020-08-31 10:02:03 +100039using ::chromeos::machine_learning::mojom::LoadHandwritingModelResult;
Michael Martisa74af932018-08-13 16:52:36 +100040using ::chromeos::machine_learning::mojom::LoadModelResult;
Andrew Moylanb481af72020-07-09 15:22:00 +100041using ::chromeos::machine_learning::mojom::MachineLearningService;
42using ::chromeos::machine_learning::mojom::Model;
Honglin Yud2204272020-08-26 14:21:37 +100043using ::chromeos::machine_learning::mojom::SodaClient;
44using ::chromeos::machine_learning::mojom::SodaConfigPtr;
45using ::chromeos::machine_learning::mojom::SodaRecognizer;
Andrew Moylanb481af72020-07-09 15:22:00 +100046using ::chromeos::machine_learning::mojom::TextClassifier;
Michael Martisa74af932018-08-13 16:52:36 +100047
48constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
Andrew Moylan79b34a42020-07-08 11:13:11 +100049// Base name for UMA metrics related to model loading (`LoadBuiltinModel`,
50// `LoadFlatBufferModel`, `LoadTextClassifier` or LoadHandwritingModel).
Honglin Yu6adafcd2019-07-22 13:48:11 +100051constexpr char kMetricsRequestName[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100052
Honglin Yuf33dce32019-12-05 15:10:39 +110053constexpr char kIcuDataFilePath[] = "/opt/google/chrome/icudtl.dat";
54
Michael Martisa74af932018-08-13 16:52:36 +100055} // namespace
56
Andrew Moylanff6be512018-07-03 11:05:01 +100057MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100058 mojo::ScopedMessagePipeHandle pipe,
Andrew Moylanb481af72020-07-09 15:22:00 +100059 base::Closure disconnect_handler,
Michael Martisa74af932018-08-13 16:52:36 +100060 const std::string& model_dir)
Honglin Yuf33dce32019-12-05 15:10:39 +110061 : icu_data_(nullptr),
Honglin Yuf33dce32019-12-05 15:10:39 +110062 builtin_model_metadata_(GetBuiltinModelMetadata()),
Michael Martisa74af932018-08-13 16:52:36 +100063 model_dir_(model_dir),
Andrew Moylanb481af72020-07-09 15:22:00 +100064 receiver_(this,
65 mojo::InterfaceRequest<
66 chromeos::machine_learning::mojom::MachineLearningService>(
67 std::move(pipe))) {
68 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Andrew Moylanff6be512018-07-03 11:05:01 +100069}
70
Michael Martisa74af932018-08-13 16:52:36 +100071MachineLearningServiceImpl::MachineLearningServiceImpl(
Charles Zhaod4fb7b62020-08-25 17:21:58 +100072 mojo::ScopedMessagePipeHandle pipe,
73 base::Closure disconnect_handler,
74 dbus::Bus* bus)
Andrew Moylanb481af72020-07-09 15:22:00 +100075 : MachineLearningServiceImpl(
Charles Zhaod4fb7b62020-08-25 17:21:58 +100076 std::move(pipe), std::move(disconnect_handler), kSystemModelDir) {
77 if (bus) {
78 dlcservice_client_ = std::make_unique<DlcserviceClient>(bus);
79 }
80}
Michael Martisa74af932018-08-13 16:52:36 +100081
Andrew Moylanb481af72020-07-09 15:22:00 +100082void MachineLearningServiceImpl::Clone(
83 mojo::PendingReceiver<MachineLearningService> receiver) {
84 clone_receivers_.Add(this, std::move(receiver));
Andrew Moylan2fb80af2020-07-08 10:52:08 +100085}
86
Honglin Yu0ed72352019-08-27 17:42:01 +100087void MachineLearningServiceImpl::LoadBuiltinModel(
88 BuiltinModelSpecPtr spec,
Andrew Moylanb481af72020-07-09 15:22:00 +100089 mojo::PendingReceiver<Model> receiver,
Qijiang Fan5d381a02020-04-19 23:42:37 +090090 LoadBuiltinModelCallback callback) {
Honglin Yu0ed72352019-08-27 17:42:01 +100091 // Unsupported models do not have metadata entries.
92 const auto metadata_lookup = builtin_model_metadata_.find(spec->id);
93 if (metadata_lookup == builtin_model_metadata_.end()) {
Honglin Yua81145a2019-09-23 15:20:13 +100094 LOG(WARNING) << "LoadBuiltinModel requested for unsupported model ID "
95 << spec->id << ".";
Qijiang Fan5d381a02020-04-19 23:42:37 +090096 std::move(callback).Run(LoadModelResult::MODEL_SPEC_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +100097 RecordModelSpecificationErrorEvent();
98 return;
99 }
100
101 const BuiltinModelMetadata& metadata = metadata_lookup->second;
102
103 DCHECK(!metadata.metrics_model_name.empty());
104
charleszhao5a7050e2020-07-14 15:21:41 +1000105 RequestMetrics request_metrics(metadata.metrics_model_name,
106 kMetricsRequestName);
Honglin Yu0ed72352019-08-27 17:42:01 +1000107 request_metrics.StartRecordingPerformanceMetrics();
108
109 // Attempt to load model.
110 const std::string model_path = model_dir_ + metadata.model_file;
111 std::unique_ptr<tflite::FlatBufferModel> model =
112 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
113 if (model == nullptr) {
114 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900115 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +1000116 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
117 return;
118 }
119
Honglin Yuc0cef102020-01-17 15:26:01 +1100120 ModelImpl::Create(metadata.required_inputs, metadata.required_outputs,
Andrew Moylanb481af72020-07-09 15:22:00 +1000121 std::move(model), std::move(receiver),
Honglin Yuc0cef102020-01-17 15:26:01 +1100122 metadata.metrics_model_name);
Honglin Yu0ed72352019-08-27 17:42:01 +1000123
Qijiang Fan5d381a02020-04-19 23:42:37 +0900124 std::move(callback).Run(LoadModelResult::OK);
Honglin Yu0ed72352019-08-27 17:42:01 +1000125
126 request_metrics.FinishRecordingPerformanceMetrics();
127 request_metrics.RecordRequestEvent(LoadModelResult::OK);
128}
129
130void MachineLearningServiceImpl::LoadFlatBufferModel(
131 FlatBufferModelSpecPtr spec,
Andrew Moylanb481af72020-07-09 15:22:00 +1000132 mojo::PendingReceiver<Model> receiver,
Qijiang Fan5d381a02020-04-19 23:42:37 +0900133 LoadFlatBufferModelCallback callback) {
Honglin Yu0ed72352019-08-27 17:42:01 +1000134 DCHECK(!spec->metrics_model_name.empty());
135
charleszhao5a7050e2020-07-14 15:21:41 +1000136 RequestMetrics request_metrics(spec->metrics_model_name, kMetricsRequestName);
Honglin Yu0ed72352019-08-27 17:42:01 +1000137 request_metrics.StartRecordingPerformanceMetrics();
138
Andrew Moylan79b34a42020-07-08 11:13:11 +1000139 // Take the ownership of the content of `model_string` because `ModelImpl` has
Honglin Yu0ed72352019-08-27 17:42:01 +1000140 // to hold the memory.
141 auto model_string_impl =
142 std::make_unique<std::string>(std::move(spec->model_string));
143
144 std::unique_ptr<tflite::FlatBufferModel> model =
145 tflite::FlatBufferModel::BuildFromBuffer(model_string_impl->c_str(),
146 model_string_impl->length());
147 if (model == nullptr) {
148 LOG(ERROR) << "Failed to load model string of metric name: "
149 << spec->metrics_model_name << "'.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900150 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +1000151 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
152 return;
153 }
154
Honglin Yuc0cef102020-01-17 15:26:01 +1100155 ModelImpl::Create(
Honglin Yu0ed72352019-08-27 17:42:01 +1000156 std::map<std::string, int>(spec->inputs.begin(), spec->inputs.end()),
157 std::map<std::string, int>(spec->outputs.begin(), spec->outputs.end()),
Andrew Moylanb481af72020-07-09 15:22:00 +1000158 std::move(model), std::move(model_string_impl), std::move(receiver),
Honglin Yu0ed72352019-08-27 17:42:01 +1000159 spec->metrics_model_name);
160
Qijiang Fan5d381a02020-04-19 23:42:37 +0900161 std::move(callback).Run(LoadModelResult::OK);
Honglin Yu0ed72352019-08-27 17:42:01 +1000162
163 request_metrics.FinishRecordingPerformanceMetrics();
164 request_metrics.RecordRequestEvent(LoadModelResult::OK);
165}
166
Honglin Yuf33dce32019-12-05 15:10:39 +1100167void MachineLearningServiceImpl::LoadTextClassifier(
Andrew Moylanb481af72020-07-09 15:22:00 +1000168 mojo::PendingReceiver<TextClassifier> receiver,
Honglin Yuf33dce32019-12-05 15:10:39 +1100169 LoadTextClassifierCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000170 RequestMetrics request_metrics("TextClassifier", kMetricsRequestName);
Honglin Yuf33dce32019-12-05 15:10:39 +1100171 request_metrics.StartRecordingPerformanceMetrics();
172
Honglin Yuf33dce32019-12-05 15:10:39 +1100173 // Create the TextClassifier.
Honglin Yu3f99ff12020-10-15 00:40:11 +1100174 if (!TextClassifierImpl::Create(std::move(receiver))) {
Honglin Yuf33dce32019-12-05 15:10:39 +1100175 LOG(ERROR) << "Failed to create TextClassifierImpl object.";
176 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
177 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
178 return;
179 }
180
181 // initialize the icu library.
182 InitIcuIfNeeded();
183
184 std::move(callback).Run(LoadModelResult::OK);
185
186 request_metrics.FinishRecordingPerformanceMetrics();
187 request_metrics.RecordRequestEvent(LoadModelResult::OK);
188}
189
Charles Zhao6d467e62020-08-31 10:02:03 +1000190void LoadHandwritingModelFromDir(
191 HandwritingRecognizerSpecPtr spec,
192 mojo::PendingReceiver<HandwritingRecognizer> receiver,
193 MachineLearningServiceImpl::LoadHandwritingModelCallback callback,
194 const std::string& root_path) {
195 RequestMetrics request_metrics("HandwritingModel", kMetricsRequestName);
196 request_metrics.StartRecordingPerformanceMetrics();
197
198 // Returns error if root_path is empty.
199 if (root_path.empty()) {
200 std::move(callback).Run(LoadHandwritingModelResult::DLC_GET_PATH_ERROR);
201 request_metrics.RecordRequestEvent(
202 LoadHandwritingModelResult::DLC_GET_PATH_ERROR);
203 return;
204 }
205
206 // Load HandwritingLibrary.
207 auto* const hwr_library = ml::HandwritingLibrary::GetInstance(root_path);
208
209 if (hwr_library->GetStatus() != ml::HandwritingLibrary::Status::kOk) {
210 LOG(ERROR) << "Initialize ml::HandwritingLibrary with error "
211 << static_cast<int>(hwr_library->GetStatus());
212
213 switch (hwr_library->GetStatus()) {
214 case ml::HandwritingLibrary::Status::kLoadLibraryFailed: {
215 std::move(callback).Run(
216 LoadHandwritingModelResult::LOAD_NATIVE_LIB_ERROR);
217 request_metrics.RecordRequestEvent(
218 LoadHandwritingModelResult::LOAD_NATIVE_LIB_ERROR);
219 return;
220 }
221 case ml::HandwritingLibrary::Status::kFunctionLookupFailed: {
222 std::move(callback).Run(
223 LoadHandwritingModelResult::LOAD_FUNC_PTR_ERROR);
224 request_metrics.RecordRequestEvent(
225 LoadHandwritingModelResult::LOAD_FUNC_PTR_ERROR);
226 return;
227 }
228 default: {
229 std::move(callback).Run(LoadHandwritingModelResult::LOAD_MODEL_ERROR);
230 request_metrics.RecordRequestEvent(
231 LoadHandwritingModelResult::LOAD_MODEL_ERROR);
232 return;
233 }
234 }
235 }
236
237 // Create HandwritingRecognizer.
238 if (!HandwritingRecognizerImpl::Create(std::move(spec),
239 std::move(receiver))) {
240 LOG(ERROR) << "LoadHandwritingRecognizer returned false.";
241 std::move(callback).Run(LoadHandwritingModelResult::LOAD_MODEL_FILES_ERROR);
242 request_metrics.RecordRequestEvent(
243 LoadHandwritingModelResult::LOAD_MODEL_FILES_ERROR);
244 return;
245 }
246
247 std::move(callback).Run(LoadHandwritingModelResult::OK);
248 request_metrics.FinishRecordingPerformanceMetrics();
249 request_metrics.RecordRequestEvent(LoadHandwritingModelResult::OK);
250}
251
252void MachineLearningServiceImpl::LoadHandwritingModel(
253 chromeos::machine_learning::mojom::HandwritingRecognizerSpecPtr spec,
254 mojo::PendingReceiver<
255 chromeos::machine_learning::mojom::HandwritingRecognizer> receiver,
256 LoadHandwritingModelCallback callback) {
257 // If handwriting is installed on rootfs, load it from there.
258 if (ml::HandwritingLibrary::IsUseLibHandwritingEnabled()) {
259 LoadHandwritingModelFromDir(
260 std::move(spec), std::move(receiver), std::move(callback),
261 ml::HandwritingLibrary::kHandwritingDefaultModelDir);
262 return;
263 }
264
265 // If handwriting is installed as DLC, get the dir and subsequently load it
266 // from there.
267 if (ml::HandwritingLibrary::IsUseLibHandwritingDlcEnabled()) {
268 dlcservice_client_->GetDlcRootPath(
269 "libhandwriting",
270 base::BindOnce(&LoadHandwritingModelFromDir, std::move(spec),
271 std::move(receiver), std::move(callback)));
272 return;
273 }
274
275 // If handwriting is not on rootfs and not in DLC, this function should not
276 // be called.
277 LOG(ERROR) << "Calling LoadHandwritingModel without Handwriting enabled "
278 "should never happen.";
279 std::move(callback).Run(LoadHandwritingModelResult::LOAD_MODEL_ERROR);
charleszhao05c5a4a2020-06-09 16:49:54 +1000280}
281
282void MachineLearningServiceImpl::LoadHandwritingModelWithSpec(
283 HandwritingRecognizerSpecPtr spec,
Andrew Moylanb481af72020-07-09 15:22:00 +1000284 mojo::PendingReceiver<HandwritingRecognizer> receiver,
Charles Zhaoc882eb02020-07-27 10:02:35 +1000285 LoadHandwritingModelWithSpecCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000286 RequestMetrics request_metrics("HandwritingModel", kMetricsRequestName);
charleszhao17777f92020-04-23 12:53:11 +1000287 request_metrics.StartRecordingPerformanceMetrics();
288
289 // Load HandwritingLibrary.
290 auto* const hwr_library = ml::HandwritingLibrary::GetInstance();
291
292 if (hwr_library->GetStatus() ==
293 ml::HandwritingLibrary::Status::kNotSupported) {
294 LOG(ERROR) << "Initialize ml::HandwritingLibrary with error "
295 << static_cast<int>(hwr_library->GetStatus());
296
297 std::move(callback).Run(LoadModelResult::FEATURE_NOT_SUPPORTED_ERROR);
298 request_metrics.RecordRequestEvent(
299 LoadModelResult::FEATURE_NOT_SUPPORTED_ERROR);
300 return;
301 }
302
303 if (hwr_library->GetStatus() != ml::HandwritingLibrary::Status::kOk) {
304 LOG(ERROR) << "Initialize ml::HandwritingLibrary with error "
305 << static_cast<int>(hwr_library->GetStatus());
306
307 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
308 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
309 return;
310 }
311
312 // Create HandwritingRecognizer.
Andrew Moylanb481af72020-07-09 15:22:00 +1000313 if (!HandwritingRecognizerImpl::Create(std::move(spec),
314 std::move(receiver))) {
charleszhao17777f92020-04-23 12:53:11 +1000315 LOG(ERROR) << "LoadHandwritingRecognizer returned false.";
316 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
317 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
318 return;
319 }
320
321 std::move(callback).Run(LoadModelResult::OK);
322 request_metrics.FinishRecordingPerformanceMetrics();
323 request_metrics.RecordRequestEvent(LoadModelResult::OK);
324}
325
Honglin Yud2204272020-08-26 14:21:37 +1000326void MachineLearningServiceImpl::LoadSpeechRecognizer(
327 SodaConfigPtr config,
328 mojo::PendingRemote<SodaClient> soda_client,
329 mojo::PendingReceiver<SodaRecognizer> soda_recognizer,
330 LoadSpeechRecognizerCallback callback) {
331 RequestMetrics request_metrics("Soda", kMetricsRequestName);
332 request_metrics.StartRecordingPerformanceMetrics();
333
334 // Create the SodaRecognizer.
335 if (!SodaRecognizerImpl::Create(std::move(config), std::move(soda_client),
336 std::move(soda_recognizer))) {
337 LOG(ERROR) << "Failed to create SodaRecognizerImpl object.";
338 // TODO(robsc): it may be better that SODA has its specific enum values to
339 // return, similar to handwriting. So before we finalize the impl of SODA
340 // Mojo API, we may revisit this return value.
341 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
342 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
343 return;
344 }
345
346 std::move(callback).Run(LoadModelResult::OK);
347
348 request_metrics.FinishRecordingPerformanceMetrics();
349 request_metrics.RecordRequestEvent(LoadModelResult::OK);
350}
351
Honglin Yuf33dce32019-12-05 15:10:39 +1100352void MachineLearningServiceImpl::InitIcuIfNeeded() {
353 if (icu_data_ == nullptr) {
354 // Need to load the data file again.
355 int64_t file_size;
356 const base::FilePath icu_data_file_path(kIcuDataFilePath);
357 CHECK(base::GetFileSize(icu_data_file_path, &file_size));
358 icu_data_ = new char[file_size];
359 CHECK(base::ReadFile(icu_data_file_path, icu_data_,
360 static_cast<int>(file_size)) == file_size);
361 // Init the Icu library.
362 UErrorCode err = U_ZERO_ERROR;
363 udata_setCommonData(reinterpret_cast<void*>(icu_data_), &err);
364 DCHECK(err == U_ZERO_ERROR);
365 // Never try to load Icu data from files.
366 udata_setFileAccess(UDATA_ONLY_PACKAGES, &err);
367 }
368}
369
Andrew Moylanff6be512018-07-03 11:05:01 +1000370} // namespace ml