blob: 518273682b071151a6aeefcfca36cf83f2befe86 [file] [log] [blame]
Andrew Moylanff6be512018-07-03 11:05:01 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100013#include <tensorflow/lite/model.h>
Michael Martisa74af932018-08-13 16:52:36 +100014
15#include "ml/model_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090016#include "ml/mojom/model.mojom.h"
Michael Martisa74af932018-08-13 16:52:36 +100017
Andrew Moylanff6be512018-07-03 11:05:01 +100018namespace ml {
19
Michael Martisa74af932018-08-13 16:52:36 +100020namespace {
21
Honglin Yu0ed72352019-08-27 17:42:01 +100022using ::chromeos::machine_learning::mojom::BuiltinModelId;
23using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
24using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr;
Michael Martisa74af932018-08-13 16:52:36 +100025using ::chromeos::machine_learning::mojom::LoadModelResult;
26using ::chromeos::machine_learning::mojom::ModelId;
27using ::chromeos::machine_learning::mojom::ModelRequest;
28using ::chromeos::machine_learning::mojom::ModelSpecPtr;
29
30constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
alanlxlcb1f8562018-11-01 15:16:11 +110031// Base name for UMA metrics related to LoadModel requests
Honglin Yu6adafcd2019-07-22 13:48:11 +100032constexpr char kMetricsRequestName[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100033
34// To avoid passing a lambda as a base::Closure.
35void DeleteModelImpl(const ModelImpl* const model_impl) {
36 delete model_impl;
37}
38
39} // namespace
40
Andrew Moylanff6be512018-07-03 11:05:01 +100041MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100042 mojo::ScopedMessagePipeHandle pipe,
43 base::Closure connection_error_handler,
44 const std::string& model_dir)
45 : model_metadata_(GetModelMetadata()),
Honglin Yu0ed72352019-08-27 17:42:01 +100046 builtin_model_metadata_(GetBuiltinModelMetadata()),
Michael Martisa74af932018-08-13 16:52:36 +100047 model_dir_(model_dir),
48 binding_(this, std::move(pipe)) {
Andrew Moylanff6be512018-07-03 11:05:01 +100049 binding_.set_connection_error_handler(std::move(connection_error_handler));
50}
51
Michael Martisa74af932018-08-13 16:52:36 +100052MachineLearningServiceImpl::MachineLearningServiceImpl(
53 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
54 : MachineLearningServiceImpl(std::move(pipe),
55 std::move(connection_error_handler),
56 kSystemModelDir) {}
57
Honglin Yu0ed72352019-08-27 17:42:01 +100058// TODO(crbug.com/990619): Remove this once clients migrate to
59// |LoadBuiltinModel|.
Michael Martisa74af932018-08-13 16:52:36 +100060void MachineLearningServiceImpl::LoadModel(ModelSpecPtr spec,
61 ModelRequest request,
62 const LoadModelCallback& callback) {
Andrew Moylan195a6f52019-05-16 20:57:32 +100063 // Unsupported models do not have metadata entries.
Michael Martisa74af932018-08-13 16:52:36 +100064 const auto metadata_lookup = model_metadata_.find(spec->id);
65 if (metadata_lookup == model_metadata_.end()) {
Andrew Moylan195a6f52019-05-16 20:57:32 +100066 LOG(WARNING) << "LoadModel requested for unsupported model ID " << spec->id
67 << ".";
68 callback.Run(LoadModelResult::MODEL_SPEC_ERROR);
Honglin Yu6adafcd2019-07-22 13:48:11 +100069 RecordModelSpecificationErrorEvent();
Michael Martisa74af932018-08-13 16:52:36 +100070 return;
71 }
Honglin Yu6adafcd2019-07-22 13:48:11 +100072
Michael Martisa74af932018-08-13 16:52:36 +100073 const ModelMetadata& metadata = metadata_lookup->second;
74
Honglin Yu6adafcd2019-07-22 13:48:11 +100075 DCHECK(!metadata.metrics_model_name.empty());
76
77 RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
78 kMetricsRequestName);
79 request_metrics.StartRecordingPerformanceMetrics();
80
Michael Martisa74af932018-08-13 16:52:36 +100081 // Attempt to load model.
82 const std::string model_path = model_dir_ + metadata.model_file;
83 std::unique_ptr<tflite::FlatBufferModel> model =
84 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
85 if (model == nullptr) {
86 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
87 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +110088 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
Michael Martisa74af932018-08-13 16:52:36 +100089 return;
90 }
91
92 // Use a connection error handler to strongly bind |model_impl| to |request|.
Honglin Yu6adafcd2019-07-22 13:48:11 +100093 ModelImpl* const model_impl = new ModelImpl(
94 metadata.required_inputs, metadata.required_outputs, std::move(model),
95 std::move(request), metadata.metrics_model_name);
Michael Martisa74af932018-08-13 16:52:36 +100096 model_impl->set_connection_error_handler(
97 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
98 callback.Run(LoadModelResult::OK);
Honglin Yu6adafcd2019-07-22 13:48:11 +100099
alanlxlcb1f8562018-11-01 15:16:11 +1100100 request_metrics.FinishRecordingPerformanceMetrics();
101 request_metrics.RecordRequestEvent(LoadModelResult::OK);
Andrew Moylanff6be512018-07-03 11:05:01 +1000102}
103
Honglin Yu0ed72352019-08-27 17:42:01 +1000104void MachineLearningServiceImpl::LoadBuiltinModel(
105 BuiltinModelSpecPtr spec,
106 ModelRequest request,
107 const LoadBuiltinModelCallback& callback) {
108 // Unsupported models do not have metadata entries.
109 const auto metadata_lookup = builtin_model_metadata_.find(spec->id);
110 if (metadata_lookup == builtin_model_metadata_.end()) {
111 LOG(WARNING) << "LoadModel requested for unsupported model ID " << spec->id
112 << ".";
113 callback.Run(LoadModelResult::MODEL_SPEC_ERROR);
114 RecordModelSpecificationErrorEvent();
115 return;
116 }
117
118 const BuiltinModelMetadata& metadata = metadata_lookup->second;
119
120 DCHECK(!metadata.metrics_model_name.empty());
121
122 RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
123 kMetricsRequestName);
124 request_metrics.StartRecordingPerformanceMetrics();
125
126 // Attempt to load model.
127 const std::string model_path = model_dir_ + metadata.model_file;
128 std::unique_ptr<tflite::FlatBufferModel> model =
129 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
130 if (model == nullptr) {
131 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
132 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
133 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
134 return;
135 }
136
137 // Use a connection error handler to strongly bind |model_impl| to |request|.
138 ModelImpl* const model_impl = new ModelImpl(
139 metadata.required_inputs, metadata.required_outputs, std::move(model),
140 std::move(request), metadata.metrics_model_name);
141
142 model_impl->set_connection_error_handler(
143 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
144 callback.Run(LoadModelResult::OK);
145
146 request_metrics.FinishRecordingPerformanceMetrics();
147 request_metrics.RecordRequestEvent(LoadModelResult::OK);
148}
149
150void MachineLearningServiceImpl::LoadFlatBufferModel(
151 FlatBufferModelSpecPtr spec,
152 ModelRequest request,
153 const LoadFlatBufferModelCallback& callback) {
154 DCHECK(!spec->metrics_model_name.empty());
155
156 RequestMetrics<LoadModelResult> request_metrics(spec->metrics_model_name,
157 kMetricsRequestName);
158 request_metrics.StartRecordingPerformanceMetrics();
159
160 // Take the ownership of the content of |model_string| because |ModelImpl| has
161 // to hold the memory.
162 auto model_string_impl =
163 std::make_unique<std::string>(std::move(spec->model_string));
164
165 std::unique_ptr<tflite::FlatBufferModel> model =
166 tflite::FlatBufferModel::BuildFromBuffer(model_string_impl->c_str(),
167 model_string_impl->length());
168 if (model == nullptr) {
169 LOG(ERROR) << "Failed to load model string of metric name: "
170 << spec->metrics_model_name << "'.";
171 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
172 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
173 return;
174 }
175
176 // Use a connection error handler to strongly bind |model_impl| to |request|.
177 ModelImpl* model_impl = new ModelImpl(
178 std::map<std::string, int>(spec->inputs.begin(), spec->inputs.end()),
179 std::map<std::string, int>(spec->outputs.begin(), spec->outputs.end()),
180 std::move(model), std::move(model_string_impl), std::move(request),
181 spec->metrics_model_name);
182
183 model_impl->set_connection_error_handler(
184 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
185 callback.Run(LoadModelResult::OK);
186
187 request_metrics.FinishRecordingPerformanceMetrics();
188 request_metrics.RecordRequestEvent(LoadModelResult::OK);
189}
190
Andrew Moylanff6be512018-07-03 11:05:01 +1000191} // namespace ml