blob: 4243fccd5ac4fc60ba8676b91c35544ac8a7a77c [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/model_impl.h"
6
Andrew Moylan44c352f2020-11-04 15:19:46 +11007#include <algorithm>
Michael Martisa967f632018-08-10 10:39:00 +10008#include <utility>
9
10#include <base/bind.h>
hscham4ce3c992021-02-19 16:37:23 +090011#include <base/callback_helpers.h>
Qijiang Fan713061e2021-03-08 15:45:12 +090012#include <base/check.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100013#include <tensorflow/lite/context.h>
Alan Green55e16542020-05-11 14:06:46 +100014#include <tensorflow/lite/delegates/nnapi/nnapi_delegate.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100015#include <tensorflow/lite/interpreter.h>
16#include <tensorflow/lite/kernels/register.h>
Michael Martisa967f632018-08-10 10:39:00 +100017
Andrew Moylan44c352f2020-11-04 15:19:46 +110018#include "ml/machine_learning_service_impl.h"
19#include "ml/request_metrics.h"
20
Honglin Yuc0cef102020-01-17 15:26:01 +110021namespace {
22
Andrew Moylanb481af72020-07-09 15:22:00 +100023// Callback for self-owned ModelImpl's to delete themselves upon disconnection.
Honglin Yuc0cef102020-01-17 15:26:01 +110024void DeleteModelImpl(const ml::ModelImpl* const model_impl) {
25 delete model_impl;
26}
27
28} // namespace
29
Michael Martisa967f632018-08-10 10:39:00 +100030namespace ml {
31
32using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
33using ::chromeos::machine_learning::mojom::GraphExecutor;
Alan Green55e16542020-05-11 14:06:46 +100034using ::chromeos::machine_learning::mojom::GraphExecutorOptions;
35using ::chromeos::machine_learning::mojom::GraphExecutorOptionsPtr;
Andrew Moylanb481af72020-07-09 15:22:00 +100036using ::chromeos::machine_learning::mojom::Model;
Michael Martisa967f632018-08-10 10:39:00 +100037
// Request name passed to RequestMetrics for UMA metrics recorded by
// CreateGraphExecutor calls.
constexpr char kMetricsRequestName[] = "CreateGraphExecutorResult";
alanlxlcb1f8562018-11-01 15:16:11 +110040
Andrew Moylan44c352f2020-11-04 15:19:46 +110041AlignedModelData::AlignedModelData(std::string model_str) {
42 if (reinterpret_cast<std::uintptr_t>(model_str.c_str()) % 4 == 0) {
43 // `model_str` is aligned. Keep it.
44 original_model_str_ = std::make_unique<std::string>(std::move(model_str));
45 aligned_copy_ = nullptr;
46 aligned_copy_size_ = 0;
47 } else {
48 // `model_str` is unaligned. Discard it and make an aligned copy.
49 aligned_copy_.reset(new char[model_str.size()]);
50 std::copy(model_str.begin(), model_str.end(), aligned_copy_.get());
51 aligned_copy_size_ = model_str.size();
52 }
53}
54
55const char* AlignedModelData::data() const {
56 return aligned_copy_ ? aligned_copy_.get() : original_model_str_->c_str();
57}
58
59size_t AlignedModelData::size() const {
60 return aligned_copy_ ? aligned_copy_size_ : original_model_str_->size();
61}
62
63AlignedModelData::~AlignedModelData() = default;
64
Andrew Moylanb481af72020-07-09 15:22:00 +100065ModelImpl* ModelImpl::Create(std::map<std::string, int> required_inputs,
66 std::map<std::string, int> required_outputs,
67 std::unique_ptr<tflite::FlatBufferModel> model,
Andrew Moylan44c352f2020-11-04 15:19:46 +110068 std::unique_ptr<AlignedModelData> model_data,
Andrew Moylanb481af72020-07-09 15:22:00 +100069 mojo::PendingReceiver<Model> receiver,
70 const std::string& metrics_model_name) {
Honglin Yuc0cef102020-01-17 15:26:01 +110071 auto model_impl = new ModelImpl(
72 std::move(required_inputs), std::move(required_outputs), std::move(model),
Andrew Moylan44c352f2020-11-04 15:19:46 +110073 std::move(model_data), std::move(receiver), metrics_model_name);
Andrew Moylanb481af72020-07-09 15:22:00 +100074 // Use a disconnection handler to strongly bind `model_impl` to `receiver`.
75 model_impl->set_disconnect_handler(
Honglin Yuc0cef102020-01-17 15:26:01 +110076 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
77
78 return model_impl;
79}
80
Andrew Moylanb481af72020-07-09 15:22:00 +100081ModelImpl* ModelImpl::Create(std::map<std::string, int> required_inputs,
82 std::map<std::string, int> required_outputs,
83 std::unique_ptr<tflite::FlatBufferModel> model,
84 mojo::PendingReceiver<Model> receiver,
85 const std::string& metrics_model_name) {
Honglin Yuc0cef102020-01-17 15:26:01 +110086 auto model_impl = new ModelImpl(
87 std::move(required_inputs), std::move(required_outputs), std::move(model),
Andrew Moylanb481af72020-07-09 15:22:00 +100088 nullptr, std::move(receiver), metrics_model_name);
89 // Use a disconnection handler to strongly bind `model_impl` to `receiver`.
90 model_impl->set_disconnect_handler(
Honglin Yuc0cef102020-01-17 15:26:01 +110091 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
92
93 return model_impl;
94}
95
Honglin Yu0ed72352019-08-27 17:42:01 +100096ModelImpl::ModelImpl(std::map<std::string, int> required_inputs,
97 std::map<std::string, int> required_outputs,
Michael Martisa967f632018-08-10 10:39:00 +100098 std::unique_ptr<tflite::FlatBufferModel> model,
Andrew Moylan44c352f2020-11-04 15:19:46 +110099 std::unique_ptr<AlignedModelData> model_data,
Andrew Moylanb481af72020-07-09 15:22:00 +1000100 mojo::PendingReceiver<Model> receiver,
Honglin Yu6adafcd2019-07-22 13:48:11 +1000101 const std::string& metrics_model_name)
Honglin Yu0ed72352019-08-27 17:42:01 +1000102 : required_inputs_(std::move(required_inputs)),
103 required_outputs_(std::move(required_outputs)),
Andrew Moylan44c352f2020-11-04 15:19:46 +1100104 model_data_(std::move(model_data)),
Michael Martisa967f632018-08-10 10:39:00 +1000105 model_(std::move(model)),
Andrew Moylanb481af72020-07-09 15:22:00 +1000106 receiver_(this, std::move(receiver)),
Honglin Yu6adafcd2019-07-22 13:48:11 +1000107 metrics_model_name_(metrics_model_name) {}
Michael Martisa967f632018-08-10 10:39:00 +1000108
Andrew Moylanb481af72020-07-09 15:22:00 +1000109void ModelImpl::set_disconnect_handler(base::Closure disconnect_handler) {
110 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Michael Martisa967f632018-08-10 10:39:00 +1000111}
112
113int ModelImpl::num_graph_executors_for_testing() const {
114 return graph_executors_.size();
115}
116
Andrew Moylanb481af72020-07-09 15:22:00 +1000117void ModelImpl::CreateGraphExecutor(
118 mojo::PendingReceiver<GraphExecutor> receiver,
119 CreateGraphExecutorCallback callback) {
Alan Green3117b522020-05-20 09:50:27 +1000120 auto options = GraphExecutorOptions::New(/*use_nnapi=*/false);
Andrew Moylanb481af72020-07-09 15:22:00 +1000121 CreateGraphExecutorWithOptions(std::move(options), std::move(receiver),
Alan Green55e16542020-05-11 14:06:46 +1000122 std::move(callback));
123}
124
125void ModelImpl::CreateGraphExecutorWithOptions(
126 GraphExecutorOptionsPtr options,
Andrew Moylanb481af72020-07-09 15:22:00 +1000127 mojo::PendingReceiver<GraphExecutor> receiver,
Alan Green55e16542020-05-11 14:06:46 +1000128 CreateGraphExecutorCallback callback) {
Honglin Yu6adafcd2019-07-22 13:48:11 +1000129 DCHECK(!metrics_model_name_.empty());
130
charleszhao5a7050e2020-07-14 15:21:41 +1000131 RequestMetrics request_metrics(metrics_model_name_, kMetricsRequestName);
alanlxlcb1f8562018-11-01 15:16:11 +1100132 request_metrics.StartRecordingPerformanceMetrics();
Honglin Yu6adafcd2019-07-22 13:48:11 +1000133
Michael Martisa967f632018-08-10 10:39:00 +1000134 if (model_ == nullptr) {
135 LOG(ERROR) << "Null model provided.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900136 std::move(callback).Run(
137 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100138 request_metrics.RecordRequestEvent(
139 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000140 return;
141 }
142
143 // Instantiate interpreter.
144 tflite::ops::builtin::BuiltinOpResolver resolver;
145 std::unique_ptr<tflite::Interpreter> interpreter;
146 const TfLiteStatus resolve_status =
147 tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
148 if (resolve_status != kTfLiteOk || !interpreter) {
149 LOG(ERROR) << "Could not resolve model ops.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900150 std::move(callback).Run(
151 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100152 request_metrics.RecordRequestEvent(
153 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000154 return;
155 }
156
Alan Green55e16542020-05-11 14:06:46 +1000157 // If requested, load and apply NNAPI
158 if (options->use_nnapi) {
Alan Green3117b522020-05-20 09:50:27 +1000159 TfLiteDelegate* delegate = tflite::NnApiDelegate();
Alan Green55e16542020-05-11 14:06:46 +1000160 if (!delegate) {
161 LOG(ERROR) << "NNAPI requested but not available.";
162 std::move(callback).Run(CreateGraphExecutorResult::NNAPI_UNAVAILABLE);
163 request_metrics.RecordRequestEvent(
164 CreateGraphExecutorResult::NNAPI_UNAVAILABLE);
165 return;
166 }
167 if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
168 LOG(ERROR) << "Could not use NNAPI delegate.";
169 std::move(callback).Run(CreateGraphExecutorResult::NNAPI_USE_ERROR);
170 request_metrics.RecordRequestEvent(
171 CreateGraphExecutorResult::NNAPI_USE_ERROR);
172 return;
173 }
174 }
175
Michael Martisa967f632018-08-10 10:39:00 +1000176 // Allocate memory for tensors.
177 if (interpreter->AllocateTensors() != kTfLiteOk) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900178 std::move(callback).Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100179 request_metrics.RecordRequestEvent(
180 CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000181 return;
182 }
183
184 // Add graph executor and schedule its deletion on pipe closure.
185 graph_executors_.emplace_front(required_inputs_, required_outputs_,
Andrew Moylanb481af72020-07-09 15:22:00 +1000186 std::move(interpreter), std::move(receiver),
Honglin Yu6adafcd2019-07-22 13:48:11 +1000187 metrics_model_name_);
Andrew Moylanb481af72020-07-09 15:22:00 +1000188 graph_executors_.front().set_disconnect_handler(
Michael Martisa967f632018-08-10 10:39:00 +1000189 base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
190 graph_executors_.begin()));
191
Qijiang Fan5d381a02020-04-19 23:42:37 +0900192 std::move(callback).Run(CreateGraphExecutorResult::OK);
alanlxlcb1f8562018-11-01 15:16:11 +1100193 request_metrics.FinishRecordingPerformanceMetrics();
194 request_metrics.RecordRequestEvent(CreateGraphExecutorResult::OK);
Michael Martisa967f632018-08-10 10:39:00 +1000195}
196
197void ModelImpl::EraseGraphExecutor(
198 const std::list<GraphExecutorImpl>::const_iterator it) {
199 graph_executors_.erase(it);
200}
201
202} // namespace ml