blob: fbe135ca10edc915ece267a08483f1abd73dbf91 [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/model_impl.h"
6
Andrew Moylan44c352f2020-11-04 15:19:46 +11007#include <algorithm>
Michael Martisa967f632018-08-10 10:39:00 +10008#include <utility>
9
10#include <base/bind.h>
11#include <base/bind_helpers.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100012#include <tensorflow/lite/context.h>
Alan Green55e16542020-05-11 14:06:46 +100013#include <tensorflow/lite/delegates/nnapi/nnapi_delegate.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100014#include <tensorflow/lite/interpreter.h>
15#include <tensorflow/lite/kernels/register.h>
Michael Martisa967f632018-08-10 10:39:00 +100016
Andrew Moylan44c352f2020-11-04 15:19:46 +110017#include "ml/machine_learning_service_impl.h"
18#include "ml/request_metrics.h"
19
Honglin Yuc0cef102020-01-17 15:26:01 +110020namespace {
21
Andrew Moylanb481af72020-07-09 15:22:00 +100022// Callback for self-owned ModelImpl's to delete themselves upon disconnection.
Honglin Yuc0cef102020-01-17 15:26:01 +110023void DeleteModelImpl(const ml::ModelImpl* const model_impl) {
24 delete model_impl;
25}
26
27} // namespace
28
Michael Martisa967f632018-08-10 10:39:00 +100029namespace ml {
30
31using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
32using ::chromeos::machine_learning::mojom::GraphExecutor;
Alan Green55e16542020-05-11 14:06:46 +100033using ::chromeos::machine_learning::mojom::GraphExecutorOptions;
34using ::chromeos::machine_learning::mojom::GraphExecutorOptionsPtr;
Andrew Moylanb481af72020-07-09 15:22:00 +100035using ::chromeos::machine_learning::mojom::Model;
Michael Martisa967f632018-08-10 10:39:00 +100036
// Request name under which CreateGraphExecutor / CreateGraphExecutorWithOptions
// outcomes and timings are reported to UMA (combined with the model's metrics
// name by RequestMetrics).
constexpr char kMetricsRequestName[]{"CreateGraphExecutorResult"};
alanlxlcb1f8562018-11-01 15:16:11 +110039
Andrew Moylan44c352f2020-11-04 15:19:46 +110040AlignedModelData::AlignedModelData(std::string model_str) {
41 if (reinterpret_cast<std::uintptr_t>(model_str.c_str()) % 4 == 0) {
42 // `model_str` is aligned. Keep it.
43 original_model_str_ = std::make_unique<std::string>(std::move(model_str));
44 aligned_copy_ = nullptr;
45 aligned_copy_size_ = 0;
46 } else {
47 // `model_str` is unaligned. Discard it and make an aligned copy.
48 aligned_copy_.reset(new char[model_str.size()]);
49 std::copy(model_str.begin(), model_str.end(), aligned_copy_.get());
50 aligned_copy_size_ = model_str.size();
51 }
52}
53
54const char* AlignedModelData::data() const {
55 return aligned_copy_ ? aligned_copy_.get() : original_model_str_->c_str();
56}
57
58size_t AlignedModelData::size() const {
59 return aligned_copy_ ? aligned_copy_size_ : original_model_str_->size();
60}
61
62AlignedModelData::~AlignedModelData() = default;
63
Andrew Moylanb481af72020-07-09 15:22:00 +100064ModelImpl* ModelImpl::Create(std::map<std::string, int> required_inputs,
65 std::map<std::string, int> required_outputs,
66 std::unique_ptr<tflite::FlatBufferModel> model,
Andrew Moylan44c352f2020-11-04 15:19:46 +110067 std::unique_ptr<AlignedModelData> model_data,
Andrew Moylanb481af72020-07-09 15:22:00 +100068 mojo::PendingReceiver<Model> receiver,
69 const std::string& metrics_model_name) {
Honglin Yuc0cef102020-01-17 15:26:01 +110070 auto model_impl = new ModelImpl(
71 std::move(required_inputs), std::move(required_outputs), std::move(model),
Andrew Moylan44c352f2020-11-04 15:19:46 +110072 std::move(model_data), std::move(receiver), metrics_model_name);
Andrew Moylanb481af72020-07-09 15:22:00 +100073 // Use a disconnection handler to strongly bind `model_impl` to `receiver`.
74 model_impl->set_disconnect_handler(
Honglin Yuc0cef102020-01-17 15:26:01 +110075 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
76
77 return model_impl;
78}
79
Andrew Moylanb481af72020-07-09 15:22:00 +100080ModelImpl* ModelImpl::Create(std::map<std::string, int> required_inputs,
81 std::map<std::string, int> required_outputs,
82 std::unique_ptr<tflite::FlatBufferModel> model,
83 mojo::PendingReceiver<Model> receiver,
84 const std::string& metrics_model_name) {
Honglin Yuc0cef102020-01-17 15:26:01 +110085 auto model_impl = new ModelImpl(
86 std::move(required_inputs), std::move(required_outputs), std::move(model),
Andrew Moylanb481af72020-07-09 15:22:00 +100087 nullptr, std::move(receiver), metrics_model_name);
88 // Use a disconnection handler to strongly bind `model_impl` to `receiver`.
89 model_impl->set_disconnect_handler(
Honglin Yuc0cef102020-01-17 15:26:01 +110090 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
91
92 return model_impl;
93}
94
Honglin Yu0ed72352019-08-27 17:42:01 +100095ModelImpl::ModelImpl(std::map<std::string, int> required_inputs,
96 std::map<std::string, int> required_outputs,
Michael Martisa967f632018-08-10 10:39:00 +100097 std::unique_ptr<tflite::FlatBufferModel> model,
Andrew Moylan44c352f2020-11-04 15:19:46 +110098 std::unique_ptr<AlignedModelData> model_data,
Andrew Moylanb481af72020-07-09 15:22:00 +100099 mojo::PendingReceiver<Model> receiver,
Honglin Yu6adafcd2019-07-22 13:48:11 +1000100 const std::string& metrics_model_name)
Honglin Yu0ed72352019-08-27 17:42:01 +1000101 : required_inputs_(std::move(required_inputs)),
102 required_outputs_(std::move(required_outputs)),
Andrew Moylan44c352f2020-11-04 15:19:46 +1100103 model_data_(std::move(model_data)),
Michael Martisa967f632018-08-10 10:39:00 +1000104 model_(std::move(model)),
Andrew Moylanb481af72020-07-09 15:22:00 +1000105 receiver_(this, std::move(receiver)),
Honglin Yu6adafcd2019-07-22 13:48:11 +1000106 metrics_model_name_(metrics_model_name) {}
Michael Martisa967f632018-08-10 10:39:00 +1000107
Andrew Moylanb481af72020-07-09 15:22:00 +1000108void ModelImpl::set_disconnect_handler(base::Closure disconnect_handler) {
109 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Michael Martisa967f632018-08-10 10:39:00 +1000110}
111
112int ModelImpl::num_graph_executors_for_testing() const {
113 return graph_executors_.size();
114}
115
Andrew Moylanb481af72020-07-09 15:22:00 +1000116void ModelImpl::CreateGraphExecutor(
117 mojo::PendingReceiver<GraphExecutor> receiver,
118 CreateGraphExecutorCallback callback) {
Alan Green3117b522020-05-20 09:50:27 +1000119 auto options = GraphExecutorOptions::New(/*use_nnapi=*/false);
Andrew Moylanb481af72020-07-09 15:22:00 +1000120 CreateGraphExecutorWithOptions(std::move(options), std::move(receiver),
Alan Green55e16542020-05-11 14:06:46 +1000121 std::move(callback));
122}
123
124void ModelImpl::CreateGraphExecutorWithOptions(
125 GraphExecutorOptionsPtr options,
Andrew Moylanb481af72020-07-09 15:22:00 +1000126 mojo::PendingReceiver<GraphExecutor> receiver,
Alan Green55e16542020-05-11 14:06:46 +1000127 CreateGraphExecutorCallback callback) {
Honglin Yu6adafcd2019-07-22 13:48:11 +1000128 DCHECK(!metrics_model_name_.empty());
129
charleszhao5a7050e2020-07-14 15:21:41 +1000130 RequestMetrics request_metrics(metrics_model_name_, kMetricsRequestName);
alanlxlcb1f8562018-11-01 15:16:11 +1100131 request_metrics.StartRecordingPerformanceMetrics();
Honglin Yu6adafcd2019-07-22 13:48:11 +1000132
Michael Martisa967f632018-08-10 10:39:00 +1000133 if (model_ == nullptr) {
134 LOG(ERROR) << "Null model provided.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900135 std::move(callback).Run(
136 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100137 request_metrics.RecordRequestEvent(
138 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000139 return;
140 }
141
142 // Instantiate interpreter.
143 tflite::ops::builtin::BuiltinOpResolver resolver;
144 std::unique_ptr<tflite::Interpreter> interpreter;
145 const TfLiteStatus resolve_status =
146 tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
147 if (resolve_status != kTfLiteOk || !interpreter) {
148 LOG(ERROR) << "Could not resolve model ops.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900149 std::move(callback).Run(
150 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100151 request_metrics.RecordRequestEvent(
152 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000153 return;
154 }
155
Alan Green55e16542020-05-11 14:06:46 +1000156 // If requested, load and apply NNAPI
157 if (options->use_nnapi) {
Alan Green3117b522020-05-20 09:50:27 +1000158 TfLiteDelegate* delegate = tflite::NnApiDelegate();
Alan Green55e16542020-05-11 14:06:46 +1000159 if (!delegate) {
160 LOG(ERROR) << "NNAPI requested but not available.";
161 std::move(callback).Run(CreateGraphExecutorResult::NNAPI_UNAVAILABLE);
162 request_metrics.RecordRequestEvent(
163 CreateGraphExecutorResult::NNAPI_UNAVAILABLE);
164 return;
165 }
166 if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
167 LOG(ERROR) << "Could not use NNAPI delegate.";
168 std::move(callback).Run(CreateGraphExecutorResult::NNAPI_USE_ERROR);
169 request_metrics.RecordRequestEvent(
170 CreateGraphExecutorResult::NNAPI_USE_ERROR);
171 return;
172 }
173 }
174
Michael Martisa967f632018-08-10 10:39:00 +1000175 // Allocate memory for tensors.
176 if (interpreter->AllocateTensors() != kTfLiteOk) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900177 std::move(callback).Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100178 request_metrics.RecordRequestEvent(
179 CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000180 return;
181 }
182
183 // Add graph executor and schedule its deletion on pipe closure.
184 graph_executors_.emplace_front(required_inputs_, required_outputs_,
Andrew Moylanb481af72020-07-09 15:22:00 +1000185 std::move(interpreter), std::move(receiver),
Honglin Yu6adafcd2019-07-22 13:48:11 +1000186 metrics_model_name_);
Andrew Moylanb481af72020-07-09 15:22:00 +1000187 graph_executors_.front().set_disconnect_handler(
Michael Martisa967f632018-08-10 10:39:00 +1000188 base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
189 graph_executors_.begin()));
190
Qijiang Fan5d381a02020-04-19 23:42:37 +0900191 std::move(callback).Run(CreateGraphExecutorResult::OK);
alanlxlcb1f8562018-11-01 15:16:11 +1100192 request_metrics.FinishRecordingPerformanceMetrics();
193 request_metrics.RecordRequestEvent(CreateGraphExecutorResult::OK);
Michael Martisa967f632018-08-10 10:39:00 +1000194}
195
196void ModelImpl::EraseGraphExecutor(
197 const std::list<GraphExecutorImpl>::const_iterator it) {
198 graph_executors_.erase(it);
199}
200
201} // namespace ml