blob: efcf64c9208605aa49cb88577274aa1e8b358259 [file] [log] [blame]
Andrew Moylanff6be512018-07-03 11:05:01 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110013#include <base/files/file.h>
14#include <base/files/file_util.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100015#include <tensorflow/lite/model.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110016#include <unicode/putil.h>
17#include <unicode/udata.h>
18#include <utils/memory/mmap.h>
Michael Martisa74af932018-08-13 16:52:36 +100019
20#include "ml/model_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090021#include "ml/mojom/model.mojom.h"
Honglin Yuf33dce32019-12-05 15:10:39 +110022#include "ml/text_classifier_impl.h"
Michael Martisa74af932018-08-13 16:52:36 +100023
Andrew Moylanff6be512018-07-03 11:05:01 +100024namespace ml {
25
Michael Martisa74af932018-08-13 16:52:36 +100026namespace {
27
Honglin Yu0ed72352019-08-27 17:42:01 +100028using ::chromeos::machine_learning::mojom::BuiltinModelId;
29using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
30using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr;
Michael Martisa74af932018-08-13 16:52:36 +100031using ::chromeos::machine_learning::mojom::LoadModelResult;
Michael Martisa74af932018-08-13 16:52:36 +100032using ::chromeos::machine_learning::mojom::ModelRequest;
Michael Martisa74af932018-08-13 16:52:36 +100033
34constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
Honglin Yua81145a2019-09-23 15:20:13 +100035// Base name for UMA metrics related to model loading (either |LoadBuiltinModel|
36// or |LoadFlatBufferModel|) requests
Honglin Yu6adafcd2019-07-22 13:48:11 +100037constexpr char kMetricsRequestName[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100038
Honglin Yuf33dce32019-12-05 15:10:39 +110039constexpr char kTextClassifierModelFile[] =
40 "mlservice-model-text_classifier_en-v706.fb";
41
42constexpr char kIcuDataFilePath[] = "/opt/google/chrome/icudtl.dat";
43
Michael Martisa74af932018-08-13 16:52:36 +100044} // namespace
45
Andrew Moylanff6be512018-07-03 11:05:01 +100046MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100047 mojo::ScopedMessagePipeHandle pipe,
48 base::Closure connection_error_handler,
49 const std::string& model_dir)
Honglin Yuf33dce32019-12-05 15:10:39 +110050 : icu_data_(nullptr),
51 text_classifier_model_filename_(kTextClassifierModelFile),
52 builtin_model_metadata_(GetBuiltinModelMetadata()),
Michael Martisa74af932018-08-13 16:52:36 +100053 model_dir_(model_dir),
hscham68867652020-01-06 11:40:47 +090054 binding_(this,
55 mojo::InterfaceRequest<
56 chromeos::machine_learning::mojom::MachineLearningService>(
Honglin Yuf33dce32019-12-05 15:10:39 +110057 std::move(pipe))) {
Andrew Moylanff6be512018-07-03 11:05:01 +100058 binding_.set_connection_error_handler(std::move(connection_error_handler));
59}
60
Michael Martisa74af932018-08-13 16:52:36 +100061MachineLearningServiceImpl::MachineLearningServiceImpl(
62 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
63 : MachineLearningServiceImpl(std::move(pipe),
64 std::move(connection_error_handler),
65 kSystemModelDir) {}
66
Honglin Yuf33dce32019-12-05 15:10:39 +110067void MachineLearningServiceImpl::SetTextClassifierModelFilenameForTesting(
68 const std::string& filename) {
69 text_classifier_model_filename_ = filename;
70}
71
Honglin Yu0ed72352019-08-27 17:42:01 +100072void MachineLearningServiceImpl::LoadBuiltinModel(
73 BuiltinModelSpecPtr spec,
74 ModelRequest request,
Qijiang Fan5d381a02020-04-19 23:42:37 +090075 LoadBuiltinModelCallback callback) {
Honglin Yu0ed72352019-08-27 17:42:01 +100076 // Unsupported models do not have metadata entries.
77 const auto metadata_lookup = builtin_model_metadata_.find(spec->id);
78 if (metadata_lookup == builtin_model_metadata_.end()) {
Honglin Yua81145a2019-09-23 15:20:13 +100079 LOG(WARNING) << "LoadBuiltinModel requested for unsupported model ID "
80 << spec->id << ".";
Qijiang Fan5d381a02020-04-19 23:42:37 +090081 std::move(callback).Run(LoadModelResult::MODEL_SPEC_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +100082 RecordModelSpecificationErrorEvent();
83 return;
84 }
85
86 const BuiltinModelMetadata& metadata = metadata_lookup->second;
87
88 DCHECK(!metadata.metrics_model_name.empty());
89
90 RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
91 kMetricsRequestName);
92 request_metrics.StartRecordingPerformanceMetrics();
93
94 // Attempt to load model.
95 const std::string model_path = model_dir_ + metadata.model_file;
96 std::unique_ptr<tflite::FlatBufferModel> model =
97 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
98 if (model == nullptr) {
99 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900100 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +1000101 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
102 return;
103 }
104
Honglin Yuc0cef102020-01-17 15:26:01 +1100105 ModelImpl::Create(metadata.required_inputs, metadata.required_outputs,
106 std::move(model), std::move(request),
107 metadata.metrics_model_name);
Honglin Yu0ed72352019-08-27 17:42:01 +1000108
Qijiang Fan5d381a02020-04-19 23:42:37 +0900109 std::move(callback).Run(LoadModelResult::OK);
Honglin Yu0ed72352019-08-27 17:42:01 +1000110
111 request_metrics.FinishRecordingPerformanceMetrics();
112 request_metrics.RecordRequestEvent(LoadModelResult::OK);
113}
114
115void MachineLearningServiceImpl::LoadFlatBufferModel(
116 FlatBufferModelSpecPtr spec,
117 ModelRequest request,
Qijiang Fan5d381a02020-04-19 23:42:37 +0900118 LoadFlatBufferModelCallback callback) {
Honglin Yu0ed72352019-08-27 17:42:01 +1000119 DCHECK(!spec->metrics_model_name.empty());
120
121 RequestMetrics<LoadModelResult> request_metrics(spec->metrics_model_name,
122 kMetricsRequestName);
123 request_metrics.StartRecordingPerformanceMetrics();
124
125 // Take the ownership of the content of |model_string| because |ModelImpl| has
126 // to hold the memory.
127 auto model_string_impl =
128 std::make_unique<std::string>(std::move(spec->model_string));
129
130 std::unique_ptr<tflite::FlatBufferModel> model =
131 tflite::FlatBufferModel::BuildFromBuffer(model_string_impl->c_str(),
132 model_string_impl->length());
133 if (model == nullptr) {
134 LOG(ERROR) << "Failed to load model string of metric name: "
135 << spec->metrics_model_name << "'.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900136 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
Honglin Yu0ed72352019-08-27 17:42:01 +1000137 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
138 return;
139 }
140
Honglin Yuc0cef102020-01-17 15:26:01 +1100141 ModelImpl::Create(
Honglin Yu0ed72352019-08-27 17:42:01 +1000142 std::map<std::string, int>(spec->inputs.begin(), spec->inputs.end()),
143 std::map<std::string, int>(spec->outputs.begin(), spec->outputs.end()),
144 std::move(model), std::move(model_string_impl), std::move(request),
145 spec->metrics_model_name);
146
Qijiang Fan5d381a02020-04-19 23:42:37 +0900147 std::move(callback).Run(LoadModelResult::OK);
Honglin Yu0ed72352019-08-27 17:42:01 +1000148
149 request_metrics.FinishRecordingPerformanceMetrics();
150 request_metrics.RecordRequestEvent(LoadModelResult::OK);
151}
152
Honglin Yuf33dce32019-12-05 15:10:39 +1100153void MachineLearningServiceImpl::LoadTextClassifier(
154 chromeos::machine_learning::mojom::TextClassifierRequest request,
155 LoadTextClassifierCallback callback) {
156 RequestMetrics<LoadModelResult> request_metrics("TextClassifier",
157 kMetricsRequestName);
158 request_metrics.StartRecordingPerformanceMetrics();
159
160 // Attempt to load model.
161 std::string model_path = model_dir_ + text_classifier_model_filename_;
162 auto scoped_mmap =
163 std::make_unique<libtextclassifier3::ScopedMmap>(model_path);
164 if (!scoped_mmap->handle().ok()) {
165 LOG(ERROR) << "Failed to load the text classifier model file '"
166 << model_path << "'.";
167 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
168 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
169 return;
170 }
171
172 // Create the TextClassifier.
173 if (!TextClassifierImpl::Create(&scoped_mmap, std::move(request))) {
174 LOG(ERROR) << "Failed to create TextClassifierImpl object.";
175 std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
176 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
177 return;
178 }
179
180 // initialize the icu library.
181 InitIcuIfNeeded();
182
183 std::move(callback).Run(LoadModelResult::OK);
184
185 request_metrics.FinishRecordingPerformanceMetrics();
186 request_metrics.RecordRequestEvent(LoadModelResult::OK);
187}
188
189void MachineLearningServiceImpl::InitIcuIfNeeded() {
190 if (icu_data_ == nullptr) {
191 // Need to load the data file again.
192 int64_t file_size;
193 const base::FilePath icu_data_file_path(kIcuDataFilePath);
194 CHECK(base::GetFileSize(icu_data_file_path, &file_size));
195 icu_data_ = new char[file_size];
196 CHECK(base::ReadFile(icu_data_file_path, icu_data_,
197 static_cast<int>(file_size)) == file_size);
198 // Init the Icu library.
199 UErrorCode err = U_ZERO_ERROR;
200 udata_setCommonData(reinterpret_cast<void*>(icu_data_), &err);
201 DCHECK(err == U_ZERO_ERROR);
202 // Never try to load Icu data from files.
203 udata_setFileAccess(UDATA_ONLY_PACKAGES, &err);
204 }
205}
206
207
Andrew Moylanff6be512018-07-03 11:05:01 +1000208} // namespace ml