blob: 6c84ae94982200e595d2e11cafa2b9457fdff428 [file] [log] [blame]
Andrew Moylanff6be512018-07-03 11:05:01 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110013#include <base/files/file.h>
14#include <base/files/file_util.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100015#include <tensorflow/lite/model.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110016#include <unicode/putil.h>
17#include <unicode/udata.h>
18#include <utils/memory/mmap.h>
Michael Martisa74af932018-08-13 16:52:36 +100019
charleszhao17777f92020-04-23 12:53:11 +100020#include "ml/handwriting.h"
21#include "ml/handwriting_recognizer_impl.h"
Michael Martisa74af932018-08-13 16:52:36 +100022#include "ml/model_impl.h"
charleszhao17777f92020-04-23 12:53:11 +100023#include "ml/mojom/handwriting_recognizer.mojom.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090024#include "ml/mojom/model.mojom.h"
Honglin Yuf33dce32019-12-05 15:10:39 +110025#include "ml/text_classifier_impl.h"
Michael Martisa74af932018-08-13 16:52:36 +100026
Andrew Moylanff6be512018-07-03 11:05:01 +100027namespace ml {
28
Michael Martisa74af932018-08-13 16:52:36 +100029namespace {
30
Honglin Yu0ed72352019-08-27 17:42:01 +100031using ::chromeos::machine_learning::mojom::BuiltinModelId;
32using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
33using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr;
charleszhao17777f92020-04-23 12:53:11 +100034using ::chromeos::machine_learning::mojom::HandwritingRecognizerRequest;
Michael Martisa74af932018-08-13 16:52:36 +100035using ::chromeos::machine_learning::mojom::LoadModelResult;
Michael Martisa74af932018-08-13 16:52:36 +100036using ::chromeos::machine_learning::mojom::ModelRequest;
Michael Martisa74af932018-08-13 16:52:36 +100037
38constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
charleszhao17777f92020-04-23 12:53:11 +100039// Base name for UMA metrics related to model loading (|LoadBuiltinModel|,
40// |LoadFlatBufferModel|, |LoadTextClassifier| or LoadHandwritingModel).
Honglin Yu6adafcd2019-07-22 13:48:11 +100041constexpr char kMetricsRequestName[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100042
Honglin Yuf33dce32019-12-05 15:10:39 +110043constexpr char kTextClassifierModelFile[] =
44 "mlservice-model-text_classifier_en-v706.fb";
45
46constexpr char kIcuDataFilePath[] = "/opt/google/chrome/icudtl.dat";
47
Michael Martisa74af932018-08-13 16:52:36 +100048} // namespace
49
Andrew Moylanff6be512018-07-03 11:05:01 +100050MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100051 mojo::ScopedMessagePipeHandle pipe,
52 base::Closure connection_error_handler,
53 const std::string& model_dir)
Honglin Yuf33dce32019-12-05 15:10:39 +110054 : icu_data_(nullptr),
55 text_classifier_model_filename_(kTextClassifierModelFile),
56 builtin_model_metadata_(GetBuiltinModelMetadata()),
Michael Martisa74af932018-08-13 16:52:36 +100057 model_dir_(model_dir),
hscham68867652020-01-06 11:40:47 +090058 binding_(this,
59 mojo::InterfaceRequest<
60 chromeos::machine_learning::mojom::MachineLearningService>(
Honglin Yuf33dce32019-12-05 15:10:39 +110061 std::move(pipe))) {
Andrew Moylanff6be512018-07-03 11:05:01 +100062 binding_.set_connection_error_handler(std::move(connection_error_handler));
63}
64
Michael Martisa74af932018-08-13 16:52:36 +100065MachineLearningServiceImpl::MachineLearningServiceImpl(
66 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
67 : MachineLearningServiceImpl(std::move(pipe),
68 std::move(connection_error_handler),
69 kSystemModelDir) {}
70
Honglin Yuf33dce32019-12-05 15:10:39 +110071void MachineLearningServiceImpl::SetTextClassifierModelFilenameForTesting(
72 const std::string& filename) {
73 text_classifier_model_filename_ = filename;
74}
75
Honglin Yu0ed72352019-08-27 17:42:01 +100076void MachineLearningServiceImpl::LoadBuiltinModel(
77 BuiltinModelSpecPtr spec,
78 ModelRequest request,
Qijiang Fan5d381a02020-04-19 23:42:37 +090079 LoadBuiltinModelCallback callback) {
Honglin Yu0ed72352019-08-27 17:42:01 +100080 // Unsupported models do not have metadata entries.
81 const auto metadata_lookup = builtin_model_metadata_.find(spec->id);
82 if (metadata_lookup == builtin_model_metadata_.end()) {
Honglin Yua81145a2019-09-23 15:20:13 +100083 LOG(WARNING) << "LoadBuiltinModel requested for unsupported model ID "
84 << spec->id << ".";
Qijiang Fan5d381a02020-04-19 23:42:37 +090085 std::move(callback).Run(LoadModelResult::MODEL_SPEC_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +100086 RecordModelSpecificationErrorEvent();
87 return;
88 }
89
90 const BuiltinModelMetadata& metadata = metadata_lookup->second;
91
92 DCHECK(!metadata.metrics_model_name.empty());
93
94 RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
95 kMetricsRequestName);
96 request_metrics.StartRecordingPerformanceMetrics();
97
98 // Attempt to load model.
99 const std::string model_path = model_dir_ + metadata.model_file;
100 std::unique_ptr<tflite::FlatBufferModel> model =
101 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
102 if (model == nullptr) {
103 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900104 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +1000105 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
106 return;
107 }
108
Honglin Yuc0cef102020-01-17 15:26:01 +1100109 ModelImpl::Create(metadata.required_inputs, metadata.required_outputs,
110 std::move(model), std::move(request),
111 metadata.metrics_model_name);
Honglin Yu0ed72352019-08-27 17:42:01 +1000112
Qijiang Fan5d381a02020-04-19 23:42:37 +0900113 std::move(callback).Run(LoadModelResult::OK);
Honglin Yu0ed72352019-08-27 17:42:01 +1000114
115 request_metrics.FinishRecordingPerformanceMetrics();
116 request_metrics.RecordRequestEvent(LoadModelResult::OK);
117}
118
119void MachineLearningServiceImpl::LoadFlatBufferModel(
120 FlatBufferModelSpecPtr spec,
121 ModelRequest request,
Qijiang Fan5d381a02020-04-19 23:42:37 +0900122 LoadFlatBufferModelCallback callback) {
Honglin Yu0ed72352019-08-27 17:42:01 +1000123 DCHECK(!spec->metrics_model_name.empty());
124
125 RequestMetrics<LoadModelResult> request_metrics(spec->metrics_model_name,
126 kMetricsRequestName);
127 request_metrics.StartRecordingPerformanceMetrics();
128
129 // Take the ownership of the content of |model_string| because |ModelImpl| has
130 // to hold the memory.
131 auto model_string_impl =
132 std::make_unique<std::string>(std::move(spec->model_string));
133
134 std::unique_ptr<tflite::FlatBufferModel> model =
135 tflite::FlatBufferModel::BuildFromBuffer(model_string_impl->c_str(),
136 model_string_impl->length());
137 if (model == nullptr) {
138 LOG(ERROR) << "Failed to load model string of metric name: "
139 << spec->metrics_model_name << "'.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900140 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +1000141 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
142 return;
143 }
144
Honglin Yuc0cef102020-01-17 15:26:01 +1100145 ModelImpl::Create(
Honglin Yu0ed72352019-08-27 17:42:01 +1000146 std::map<std::string, int>(spec->inputs.begin(), spec->inputs.end()),
147 std::map<std::string, int>(spec->outputs.begin(), spec->outputs.end()),
148 std::move(model), std::move(model_string_impl), std::move(request),
149 spec->metrics_model_name);
150
Qijiang Fan5d381a02020-04-19 23:42:37 +0900151 std::move(callback).Run(LoadModelResult::OK);
Honglin Yu0ed72352019-08-27 17:42:01 +1000152
153 request_metrics.FinishRecordingPerformanceMetrics();
154 request_metrics.RecordRequestEvent(LoadModelResult::OK);
155}
156
Honglin Yuf33dce32019-12-05 15:10:39 +1100157void MachineLearningServiceImpl::LoadTextClassifier(
158 chromeos::machine_learning::mojom::TextClassifierRequest request,
159 LoadTextClassifierCallback callback) {
160 RequestMetrics<LoadModelResult> request_metrics("TextClassifier",
161 kMetricsRequestName);
162 request_metrics.StartRecordingPerformanceMetrics();
163
164 // Attempt to load model.
165 std::string model_path = model_dir_ + text_classifier_model_filename_;
166 auto scoped_mmap =
167 std::make_unique<libtextclassifier3::ScopedMmap>(model_path);
168 if (!scoped_mmap->handle().ok()) {
169 LOG(ERROR) << "Failed to load the text classifier model file '"
170 << model_path << "'.";
171 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
172 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
173 return;
174 }
175
176 // Create the TextClassifier.
177 if (!TextClassifierImpl::Create(&scoped_mmap, std::move(request))) {
178 LOG(ERROR) << "Failed to create TextClassifierImpl object.";
179 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
180 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
181 return;
182 }
183
184 // initialize the icu library.
185 InitIcuIfNeeded();
186
187 std::move(callback).Run(LoadModelResult::OK);
188
189 request_metrics.FinishRecordingPerformanceMetrics();
190 request_metrics.RecordRequestEvent(LoadModelResult::OK);
191}
192
charleszhao17777f92020-04-23 12:53:11 +1000193void MachineLearningServiceImpl::LoadHandwritingModel(
194 HandwritingRecognizerRequest request,
195 LoadHandwritingModelCallback callback) {
196 RequestMetrics<LoadModelResult> request_metrics("HandwritingModel",
197 kMetricsRequestName);
198 request_metrics.StartRecordingPerformanceMetrics();
199
200 // Load HandwritingLibrary.
201 auto* const hwr_library = ml::HandwritingLibrary::GetInstance();
202
203 if (hwr_library->GetStatus() ==
204 ml::HandwritingLibrary::Status::kNotSupported) {
205 LOG(ERROR) << "Initialize ml::HandwritingLibrary with error "
206 << static_cast<int>(hwr_library->GetStatus());
207
208 std::move(callback).Run(LoadModelResult::FEATURE_NOT_SUPPORTED_ERROR);
209 request_metrics.RecordRequestEvent(
210 LoadModelResult::FEATURE_NOT_SUPPORTED_ERROR);
211 return;
212 }
213
214 if (hwr_library->GetStatus() != ml::HandwritingLibrary::Status::kOk) {
215 LOG(ERROR) << "Initialize ml::HandwritingLibrary with error "
216 << static_cast<int>(hwr_library->GetStatus());
217
218 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
219 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
220 return;
221 }
222
223 // Create HandwritingRecognizer.
224 if (!HandwritingRecognizerImpl::Create(std::move(request))) {
225 LOG(ERROR) << "LoadHandwritingRecognizer returned false.";
226 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
227 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
228 return;
229 }
230
231 std::move(callback).Run(LoadModelResult::OK);
232 request_metrics.FinishRecordingPerformanceMetrics();
233 request_metrics.RecordRequestEvent(LoadModelResult::OK);
234}
235
Honglin Yuf33dce32019-12-05 15:10:39 +1100236void MachineLearningServiceImpl::InitIcuIfNeeded() {
237 if (icu_data_ == nullptr) {
238 // Need to load the data file again.
239 int64_t file_size;
240 const base::FilePath icu_data_file_path(kIcuDataFilePath);
241 CHECK(base::GetFileSize(icu_data_file_path, &file_size));
242 icu_data_ = new char[file_size];
243 CHECK(base::ReadFile(icu_data_file_path, icu_data_,
244 static_cast<int>(file_size)) == file_size);
245 // Init the Icu library.
246 UErrorCode err = U_ZERO_ERROR;
247 udata_setCommonData(reinterpret_cast<void*>(icu_data_), &err);
248 DCHECK(err == U_ZERO_ERROR);
249 // Never try to load Icu data from files.
250 udata_setFileAccess(UDATA_ONLY_PACKAGES, &err);
251 }
252}
253
Andrew Moylanff6be512018-07-03 11:05:01 +1000254} // namespace ml