blob: 386439f5d43064b689157c65e54174ace0724772 [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/model_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Michael Martisa967f632018-08-10 10:39:00 +10007
8#include <utility>
9
10#include <base/bind.h>
11#include <base/bind_helpers.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100012#include <tensorflow/lite/context.h>
Alan Green55e16542020-05-11 14:06:46 +100013#include <tensorflow/lite/delegates/nnapi/nnapi_delegate.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100014#include <tensorflow/lite/interpreter.h>
15#include <tensorflow/lite/kernels/register.h>
Michael Martisa967f632018-08-10 10:39:00 +100016
Honglin Yuc0cef102020-01-17 15:26:01 +110017namespace {
18
19// Callback for self-owned ModelImpl's to delete themselves upon connection
20// error.
21void DeleteModelImpl(const ml::ModelImpl* const model_impl) {
22 delete model_impl;
23}
24
25} // namespace
26
Michael Martisa967f632018-08-10 10:39:00 +100027namespace ml {
28
29using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
30using ::chromeos::machine_learning::mojom::GraphExecutor;
Alan Green55e16542020-05-11 14:06:46 +100031using ::chromeos::machine_learning::mojom::GraphExecutorOptions;
32using ::chromeos::machine_learning::mojom::GraphExecutorOptionsPtr;
Michael Martisa967f632018-08-10 10:39:00 +100033using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
34using ::chromeos::machine_learning::mojom::ModelRequest;
35
alanlxlcb1f8562018-11-01 15:16:11 +110036// Base name for UMA metrics related to CreateGraphExecutor calls
Honglin Yu6adafcd2019-07-22 13:48:11 +100037constexpr char kMetricsRequestName[] = "CreateGraphExecutorResult";
alanlxlcb1f8562018-11-01 15:16:11 +110038
Honglin Yuc0cef102020-01-17 15:26:01 +110039ModelImpl* ModelImpl::Create(
40 std::map<std::string, int> required_inputs,
41 std::map<std::string, int> required_outputs,
42 std::unique_ptr<tflite::FlatBufferModel> model,
43 std::unique_ptr<std::string> model_string,
44 chromeos::machine_learning::mojom::ModelRequest request,
45 const std::string& metrics_model_name) {
46 auto model_impl = new ModelImpl(
47 std::move(required_inputs), std::move(required_outputs), std::move(model),
48 std::move(model_string), std::move(request), metrics_model_name);
Andrew Moylan79b34a42020-07-08 11:13:11 +100049 // Use a connection error handler to strongly bind `model_impl` to `request`.
Honglin Yuc0cef102020-01-17 15:26:01 +110050 model_impl->set_connection_error_handler(
51 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
52
53 return model_impl;
54}
55
56ModelImpl* ModelImpl::Create(
57 std::map<std::string, int> required_inputs,
58 std::map<std::string, int> required_outputs,
59 std::unique_ptr<tflite::FlatBufferModel> model,
60 chromeos::machine_learning::mojom::ModelRequest request,
61 const std::string& metrics_model_name) {
62 auto model_impl = new ModelImpl(
63 std::move(required_inputs), std::move(required_outputs), std::move(model),
64 nullptr, std::move(request), metrics_model_name);
Andrew Moylan79b34a42020-07-08 11:13:11 +100065 // Use a connection error handler to strongly bind `model_impl` to `request`.
Honglin Yuc0cef102020-01-17 15:26:01 +110066 model_impl->set_connection_error_handler(
67 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
68
69 return model_impl;
70}
71
Honglin Yu0ed72352019-08-27 17:42:01 +100072ModelImpl::ModelImpl(std::map<std::string, int> required_inputs,
73 std::map<std::string, int> required_outputs,
Michael Martisa967f632018-08-10 10:39:00 +100074 std::unique_ptr<tflite::FlatBufferModel> model,
Honglin Yu0ed72352019-08-27 17:42:01 +100075 std::unique_ptr<std::string> model_string,
Honglin Yu6adafcd2019-07-22 13:48:11 +100076 ModelRequest request,
77 const std::string& metrics_model_name)
Honglin Yu0ed72352019-08-27 17:42:01 +100078 : required_inputs_(std::move(required_inputs)),
79 required_outputs_(std::move(required_outputs)),
80 model_string_(std::move(model_string)),
Michael Martisa967f632018-08-10 10:39:00 +100081 model_(std::move(model)),
Honglin Yu6adafcd2019-07-22 13:48:11 +100082 binding_(this, std::move(request)),
83 metrics_model_name_(metrics_model_name) {}
Michael Martisa967f632018-08-10 10:39:00 +100084
85void ModelImpl::set_connection_error_handler(
86 base::Closure connection_error_handler) {
87 binding_.set_connection_error_handler(std::move(connection_error_handler));
88}
89
90int ModelImpl::num_graph_executors_for_testing() const {
91 return graph_executors_.size();
92}
93
Qijiang Fan5d381a02020-04-19 23:42:37 +090094void ModelImpl::CreateGraphExecutor(GraphExecutorRequest request,
95 CreateGraphExecutorCallback callback) {
Alan Green3117b522020-05-20 09:50:27 +100096 auto options = GraphExecutorOptions::New(/*use_nnapi=*/false);
Alan Green55e16542020-05-11 14:06:46 +100097 CreateGraphExecutorWithOptions(std::move(options), std::move(request),
98 std::move(callback));
99}
100
101void ModelImpl::CreateGraphExecutorWithOptions(
102 GraphExecutorOptionsPtr options,
103 GraphExecutorRequest request,
104 CreateGraphExecutorCallback callback) {
Honglin Yu6adafcd2019-07-22 13:48:11 +1000105 DCHECK(!metrics_model_name_.empty());
106
107 RequestMetrics<CreateGraphExecutorResult> request_metrics(
108 metrics_model_name_, kMetricsRequestName);
alanlxlcb1f8562018-11-01 15:16:11 +1100109 request_metrics.StartRecordingPerformanceMetrics();
Honglin Yu6adafcd2019-07-22 13:48:11 +1000110
Michael Martisa967f632018-08-10 10:39:00 +1000111 if (model_ == nullptr) {
112 LOG(ERROR) << "Null model provided.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900113 std::move(callback).Run(
114 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100115 request_metrics.RecordRequestEvent(
116 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000117 return;
118 }
119
120 // Instantiate interpreter.
121 tflite::ops::builtin::BuiltinOpResolver resolver;
122 std::unique_ptr<tflite::Interpreter> interpreter;
123 const TfLiteStatus resolve_status =
124 tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
125 if (resolve_status != kTfLiteOk || !interpreter) {
126 LOG(ERROR) << "Could not resolve model ops.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900127 std::move(callback).Run(
128 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100129 request_metrics.RecordRequestEvent(
130 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000131 return;
132 }
133
Alan Green55e16542020-05-11 14:06:46 +1000134 // If requested, load and apply NNAPI
135 if (options->use_nnapi) {
Alan Green3117b522020-05-20 09:50:27 +1000136 TfLiteDelegate* delegate = tflite::NnApiDelegate();
Alan Green55e16542020-05-11 14:06:46 +1000137 if (!delegate) {
138 LOG(ERROR) << "NNAPI requested but not available.";
139 std::move(callback).Run(CreateGraphExecutorResult::NNAPI_UNAVAILABLE);
140 request_metrics.RecordRequestEvent(
141 CreateGraphExecutorResult::NNAPI_UNAVAILABLE);
142 return;
143 }
144 if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
145 LOG(ERROR) << "Could not use NNAPI delegate.";
146 std::move(callback).Run(CreateGraphExecutorResult::NNAPI_USE_ERROR);
147 request_metrics.RecordRequestEvent(
148 CreateGraphExecutorResult::NNAPI_USE_ERROR);
149 return;
150 }
151 }
152
Michael Martisa967f632018-08-10 10:39:00 +1000153 // Allocate memory for tensors.
154 if (interpreter->AllocateTensors() != kTfLiteOk) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900155 std::move(callback).Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100156 request_metrics.RecordRequestEvent(
157 CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000158 return;
159 }
160
161 // Add graph executor and schedule its deletion on pipe closure.
162 graph_executors_.emplace_front(required_inputs_, required_outputs_,
Honglin Yu6adafcd2019-07-22 13:48:11 +1000163 std::move(interpreter), std::move(request),
164 metrics_model_name_);
Michael Martisa967f632018-08-10 10:39:00 +1000165 graph_executors_.front().set_connection_error_handler(
166 base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
167 graph_executors_.begin()));
168
Qijiang Fan5d381a02020-04-19 23:42:37 +0900169 std::move(callback).Run(CreateGraphExecutorResult::OK);
alanlxlcb1f8562018-11-01 15:16:11 +1100170 request_metrics.FinishRecordingPerformanceMetrics();
171 request_metrics.RecordRequestEvent(CreateGraphExecutorResult::OK);
Michael Martisa967f632018-08-10 10:39:00 +1000172}
173
174void ModelImpl::EraseGraphExecutor(
175 const std::list<GraphExecutorImpl>::const_iterator it) {
176 graph_executors_.erase(it);
177}
178
179} // namespace ml