blob: 84d992b4193518eced14fb2bb35036fa6d1ff28a [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110013#include <base/files/file.h>
14#include <base/files/file_util.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100015#include <tensorflow/lite/model.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110016#include <unicode/putil.h>
17#include <unicode/udata.h>
18#include <utils/memory/mmap.h>
Michael Martisa74af932018-08-13 16:52:36 +100019
charleszhao17777f92020-04-23 12:53:11 +100020#include "ml/handwriting.h"
charleszhao05c5a4a2020-06-09 16:49:54 +100021#include "ml/handwriting_path.h"
charleszhao17777f92020-04-23 12:53:11 +100022#include "ml/handwriting_recognizer_impl.h"
Michael Martisa74af932018-08-13 16:52:36 +100023#include "ml/model_impl.h"
charleszhao17777f92020-04-23 12:53:11 +100024#include "ml/mojom/handwriting_recognizer.mojom.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090025#include "ml/mojom/model.mojom.h"
Honglin Yuf33dce32019-12-05 15:10:39 +110026#include "ml/text_classifier_impl.h"
Michael Martisa74af932018-08-13 16:52:36 +100027
Andrew Moylanff6be512018-07-03 11:05:01 +100028namespace ml {
29
Michael Martisa74af932018-08-13 16:52:36 +100030namespace {
31
Honglin Yu0ed72352019-08-27 17:42:01 +100032using ::chromeos::machine_learning::mojom::BuiltinModelId;
33using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
34using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr;
charleszhao17777f92020-04-23 12:53:11 +100035using ::chromeos::machine_learning::mojom::HandwritingRecognizerRequest;
charleszhao05c5a4a2020-06-09 16:49:54 +100036using ::chromeos::machine_learning::mojom::HandwritingRecognizerSpec;
37using ::chromeos::machine_learning::mojom::HandwritingRecognizerSpecPtr;
Michael Martisa74af932018-08-13 16:52:36 +100038using ::chromeos::machine_learning::mojom::LoadModelResult;
Andrew Moylan2fb80af2020-07-08 10:52:08 +100039using ::chromeos::machine_learning::mojom::MachineLearningServiceRequest;
Michael Martisa74af932018-08-13 16:52:36 +100040using ::chromeos::machine_learning::mojom::ModelRequest;
Michael Martisa74af932018-08-13 16:52:36 +100041
42constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
Andrew Moylan79b34a42020-07-08 11:13:11 +100043// Base name for UMA metrics related to model loading (`LoadBuiltinModel`,
44// `LoadFlatBufferModel`, `LoadTextClassifier` or LoadHandwritingModel).
Honglin Yu6adafcd2019-07-22 13:48:11 +100045constexpr char kMetricsRequestName[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100046
Honglin Yuf33dce32019-12-05 15:10:39 +110047constexpr char kTextClassifierModelFile[] =
48 "mlservice-model-text_classifier_en-v706.fb";
49
50constexpr char kIcuDataFilePath[] = "/opt/google/chrome/icudtl.dat";
51
Michael Martisa74af932018-08-13 16:52:36 +100052} // namespace
53
Andrew Moylanff6be512018-07-03 11:05:01 +100054MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100055 mojo::ScopedMessagePipeHandle pipe,
56 base::Closure connection_error_handler,
57 const std::string& model_dir)
Honglin Yuf33dce32019-12-05 15:10:39 +110058 : icu_data_(nullptr),
59 text_classifier_model_filename_(kTextClassifierModelFile),
60 builtin_model_metadata_(GetBuiltinModelMetadata()),
Michael Martisa74af932018-08-13 16:52:36 +100061 model_dir_(model_dir),
hscham68867652020-01-06 11:40:47 +090062 binding_(this,
63 mojo::InterfaceRequest<
64 chromeos::machine_learning::mojom::MachineLearningService>(
Honglin Yuf33dce32019-12-05 15:10:39 +110065 std::move(pipe))) {
Andrew Moylanff6be512018-07-03 11:05:01 +100066 binding_.set_connection_error_handler(std::move(connection_error_handler));
67}
68
Michael Martisa74af932018-08-13 16:52:36 +100069MachineLearningServiceImpl::MachineLearningServiceImpl(
70 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
71 : MachineLearningServiceImpl(std::move(pipe),
72 std::move(connection_error_handler),
73 kSystemModelDir) {}
74
Honglin Yuf33dce32019-12-05 15:10:39 +110075void MachineLearningServiceImpl::SetTextClassifierModelFilenameForTesting(
76 const std::string& filename) {
77 text_classifier_model_filename_ = filename;
78}
79
Andrew Moylan2fb80af2020-07-08 10:52:08 +100080void MachineLearningServiceImpl::Clone(MachineLearningServiceRequest request) {
81 clone_bindings_.AddBinding(this, std::move(request));
82}
83
Honglin Yu0ed72352019-08-27 17:42:01 +100084void MachineLearningServiceImpl::LoadBuiltinModel(
85 BuiltinModelSpecPtr spec,
86 ModelRequest request,
Qijiang Fan5d381a02020-04-19 23:42:37 +090087 LoadBuiltinModelCallback callback) {
Honglin Yu0ed72352019-08-27 17:42:01 +100088 // Unsupported models do not have metadata entries.
89 const auto metadata_lookup = builtin_model_metadata_.find(spec->id);
90 if (metadata_lookup == builtin_model_metadata_.end()) {
Honglin Yua81145a2019-09-23 15:20:13 +100091 LOG(WARNING) << "LoadBuiltinModel requested for unsupported model ID "
92 << spec->id << ".";
Qijiang Fan5d381a02020-04-19 23:42:37 +090093 std::move(callback).Run(LoadModelResult::MODEL_SPEC_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +100094 RecordModelSpecificationErrorEvent();
95 return;
96 }
97
98 const BuiltinModelMetadata& metadata = metadata_lookup->second;
99
100 DCHECK(!metadata.metrics_model_name.empty());
101
102 RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
103 kMetricsRequestName);
104 request_metrics.StartRecordingPerformanceMetrics();
105
106 // Attempt to load model.
107 const std::string model_path = model_dir_ + metadata.model_file;
108 std::unique_ptr<tflite::FlatBufferModel> model =
109 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
110 if (model == nullptr) {
111 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900112 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +1000113 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
114 return;
115 }
116
Honglin Yuc0cef102020-01-17 15:26:01 +1100117 ModelImpl::Create(metadata.required_inputs, metadata.required_outputs,
118 std::move(model), std::move(request),
119 metadata.metrics_model_name);
Honglin Yu0ed72352019-08-27 17:42:01 +1000120
Qijiang Fan5d381a02020-04-19 23:42:37 +0900121 std::move(callback).Run(LoadModelResult::OK);
Honglin Yu0ed72352019-08-27 17:42:01 +1000122
123 request_metrics.FinishRecordingPerformanceMetrics();
124 request_metrics.RecordRequestEvent(LoadModelResult::OK);
125}
126
127void MachineLearningServiceImpl::LoadFlatBufferModel(
128 FlatBufferModelSpecPtr spec,
129 ModelRequest request,
Qijiang Fan5d381a02020-04-19 23:42:37 +0900130 LoadFlatBufferModelCallback callback) {
Honglin Yu0ed72352019-08-27 17:42:01 +1000131 DCHECK(!spec->metrics_model_name.empty());
132
133 RequestMetrics<LoadModelResult> request_metrics(spec->metrics_model_name,
134 kMetricsRequestName);
135 request_metrics.StartRecordingPerformanceMetrics();
136
Andrew Moylan79b34a42020-07-08 11:13:11 +1000137 // Take the ownership of the content of `model_string` because `ModelImpl` has
Honglin Yu0ed72352019-08-27 17:42:01 +1000138 // to hold the memory.
139 auto model_string_impl =
140 std::make_unique<std::string>(std::move(spec->model_string));
141
142 std::unique_ptr<tflite::FlatBufferModel> model =
143 tflite::FlatBufferModel::BuildFromBuffer(model_string_impl->c_str(),
144 model_string_impl->length());
145 if (model == nullptr) {
146 LOG(ERROR) << "Failed to load model string of metric name: "
147 << spec->metrics_model_name << "'.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900148 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +1000149 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
150 return;
151 }
152
Honglin Yuc0cef102020-01-17 15:26:01 +1100153 ModelImpl::Create(
Honglin Yu0ed72352019-08-27 17:42:01 +1000154 std::map<std::string, int>(spec->inputs.begin(), spec->inputs.end()),
155 std::map<std::string, int>(spec->outputs.begin(), spec->outputs.end()),
156 std::move(model), std::move(model_string_impl), std::move(request),
157 spec->metrics_model_name);
158
Qijiang Fan5d381a02020-04-19 23:42:37 +0900159 std::move(callback).Run(LoadModelResult::OK);
Honglin Yu0ed72352019-08-27 17:42:01 +1000160
161 request_metrics.FinishRecordingPerformanceMetrics();
162 request_metrics.RecordRequestEvent(LoadModelResult::OK);
163}
164
Honglin Yuf33dce32019-12-05 15:10:39 +1100165void MachineLearningServiceImpl::LoadTextClassifier(
166 chromeos::machine_learning::mojom::TextClassifierRequest request,
167 LoadTextClassifierCallback callback) {
168 RequestMetrics<LoadModelResult> request_metrics("TextClassifier",
169 kMetricsRequestName);
170 request_metrics.StartRecordingPerformanceMetrics();
171
172 // Attempt to load model.
173 std::string model_path = model_dir_ + text_classifier_model_filename_;
174 auto scoped_mmap =
175 std::make_unique<libtextclassifier3::ScopedMmap>(model_path);
176 if (!scoped_mmap->handle().ok()) {
177 LOG(ERROR) << "Failed to load the text classifier model file '"
178 << model_path << "'.";
179 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
180 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
181 return;
182 }
183
184 // Create the TextClassifier.
185 if (!TextClassifierImpl::Create(&scoped_mmap, std::move(request))) {
186 LOG(ERROR) << "Failed to create TextClassifierImpl object.";
187 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
188 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
189 return;
190 }
191
192 // initialize the icu library.
193 InitIcuIfNeeded();
194
195 std::move(callback).Run(LoadModelResult::OK);
196
197 request_metrics.FinishRecordingPerformanceMetrics();
198 request_metrics.RecordRequestEvent(LoadModelResult::OK);
199}
200
charleszhao17777f92020-04-23 12:53:11 +1000201void MachineLearningServiceImpl::LoadHandwritingModel(
202 HandwritingRecognizerRequest request,
203 LoadHandwritingModelCallback callback) {
charleszhao05c5a4a2020-06-09 16:49:54 +1000204 // Use english as default language.
205 LoadHandwritingModelWithSpec(HandwritingRecognizerSpec::New("en"),
206 std::move(request), std::move(callback));
207}
208
209void MachineLearningServiceImpl::LoadHandwritingModelWithSpec(
210 HandwritingRecognizerSpecPtr spec,
211 HandwritingRecognizerRequest request,
212 LoadHandwritingModelCallback callback) {
charleszhao17777f92020-04-23 12:53:11 +1000213 RequestMetrics<LoadModelResult> request_metrics("HandwritingModel",
214 kMetricsRequestName);
215 request_metrics.StartRecordingPerformanceMetrics();
216
217 // Load HandwritingLibrary.
218 auto* const hwr_library = ml::HandwritingLibrary::GetInstance();
219
220 if (hwr_library->GetStatus() ==
221 ml::HandwritingLibrary::Status::kNotSupported) {
222 LOG(ERROR) << "Initialize ml::HandwritingLibrary with error "
223 << static_cast<int>(hwr_library->GetStatus());
224
225 std::move(callback).Run(LoadModelResult::FEATURE_NOT_SUPPORTED_ERROR);
226 request_metrics.RecordRequestEvent(
227 LoadModelResult::FEATURE_NOT_SUPPORTED_ERROR);
228 return;
229 }
230
231 if (hwr_library->GetStatus() != ml::HandwritingLibrary::Status::kOk) {
232 LOG(ERROR) << "Initialize ml::HandwritingLibrary with error "
233 << static_cast<int>(hwr_library->GetStatus());
234
235 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
236 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
237 return;
238 }
239
charleszhao05c5a4a2020-06-09 16:49:54 +1000240 if (!GetModelPaths(spec.Clone()).has_value()) {
241 LOG(ERROR) << "LoadHandwritingRecognizer is not called because language "
242 "code is not supported.";
243
244 std::move(callback).Run(LoadModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
245 request_metrics.RecordRequestEvent(
246 LoadModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
247 return;
248 }
249
charleszhao17777f92020-04-23 12:53:11 +1000250 // Create HandwritingRecognizer.
charleszhao05c5a4a2020-06-09 16:49:54 +1000251 if (!HandwritingRecognizerImpl::Create(std::move(spec), std::move(request))) {
charleszhao17777f92020-04-23 12:53:11 +1000252 LOG(ERROR) << "LoadHandwritingRecognizer returned false.";
253 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
254 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
255 return;
256 }
257
258 std::move(callback).Run(LoadModelResult::OK);
259 request_metrics.FinishRecordingPerformanceMetrics();
260 request_metrics.RecordRequestEvent(LoadModelResult::OK);
261}
262
Honglin Yuf33dce32019-12-05 15:10:39 +1100263void MachineLearningServiceImpl::InitIcuIfNeeded() {
264 if (icu_data_ == nullptr) {
265 // Need to load the data file again.
266 int64_t file_size;
267 const base::FilePath icu_data_file_path(kIcuDataFilePath);
268 CHECK(base::GetFileSize(icu_data_file_path, &file_size));
269 icu_data_ = new char[file_size];
270 CHECK(base::ReadFile(icu_data_file_path, icu_data_,
271 static_cast<int>(file_size)) == file_size);
272 // Init the Icu library.
273 UErrorCode err = U_ZERO_ERROR;
274 udata_setCommonData(reinterpret_cast<void*>(icu_data_), &err);
275 DCHECK(err == U_ZERO_ERROR);
276 // Never try to load Icu data from files.
277 udata_setFileAccess(UDATA_ONLY_PACKAGES, &err);
278 }
279}
280
Andrew Moylanff6be512018-07-03 11:05:01 +1000281} // namespace ml