blob: 0c707837792fe1984d724f03bb47f6da6ca938a6 [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/model_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Michael Martisa967f632018-08-10 10:39:00 +10007
8#include <utility>
9
10#include <base/bind.h>
11#include <base/bind_helpers.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100012#include <tensorflow/lite/context.h>
Alan Green55e16542020-05-11 14:06:46 +100013#include <tensorflow/lite/delegates/nnapi/nnapi_delegate.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100014#include <tensorflow/lite/interpreter.h>
15#include <tensorflow/lite/kernels/register.h>
Michael Martisa967f632018-08-10 10:39:00 +100016
Honglin Yuc0cef102020-01-17 15:26:01 +110017namespace {
18
Andrew Moylanb481af72020-07-09 15:22:00 +100019// Callback for self-owned ModelImpl's to delete themselves upon disconnection.
Honglin Yuc0cef102020-01-17 15:26:01 +110020void DeleteModelImpl(const ml::ModelImpl* const model_impl) {
21 delete model_impl;
22}
23
24} // namespace
25
Michael Martisa967f632018-08-10 10:39:00 +100026namespace ml {
27
28using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
29using ::chromeos::machine_learning::mojom::GraphExecutor;
Alan Green55e16542020-05-11 14:06:46 +100030using ::chromeos::machine_learning::mojom::GraphExecutorOptions;
31using ::chromeos::machine_learning::mojom::GraphExecutorOptionsPtr;
Andrew Moylanb481af72020-07-09 15:22:00 +100032using ::chromeos::machine_learning::mojom::GraphExecutor;
33using ::chromeos::machine_learning::mojom::Model;
Michael Martisa967f632018-08-10 10:39:00 +100034
alanlxlcb1f8562018-11-01 15:16:11 +110035// Base name for UMA metrics related to CreateGraphExecutor calls
Honglin Yu6adafcd2019-07-22 13:48:11 +100036constexpr char kMetricsRequestName[] = "CreateGraphExecutorResult";
alanlxlcb1f8562018-11-01 15:16:11 +110037
Andrew Moylanb481af72020-07-09 15:22:00 +100038ModelImpl* ModelImpl::Create(std::map<std::string, int> required_inputs,
39 std::map<std::string, int> required_outputs,
40 std::unique_ptr<tflite::FlatBufferModel> model,
41 std::unique_ptr<std::string> model_string,
42 mojo::PendingReceiver<Model> receiver,
43 const std::string& metrics_model_name) {
Honglin Yuc0cef102020-01-17 15:26:01 +110044 auto model_impl = new ModelImpl(
45 std::move(required_inputs), std::move(required_outputs), std::move(model),
Andrew Moylanb481af72020-07-09 15:22:00 +100046 std::move(model_string), std::move(receiver), metrics_model_name);
47 // Use a disconnection handler to strongly bind `model_impl` to `receiver`.
48 model_impl->set_disconnect_handler(
Honglin Yuc0cef102020-01-17 15:26:01 +110049 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
50
51 return model_impl;
52}
53
Andrew Moylanb481af72020-07-09 15:22:00 +100054ModelImpl* ModelImpl::Create(std::map<std::string, int> required_inputs,
55 std::map<std::string, int> required_outputs,
56 std::unique_ptr<tflite::FlatBufferModel> model,
57 mojo::PendingReceiver<Model> receiver,
58 const std::string& metrics_model_name) {
Honglin Yuc0cef102020-01-17 15:26:01 +110059 auto model_impl = new ModelImpl(
60 std::move(required_inputs), std::move(required_outputs), std::move(model),
Andrew Moylanb481af72020-07-09 15:22:00 +100061 nullptr, std::move(receiver), metrics_model_name);
62 // Use a disconnection handler to strongly bind `model_impl` to `receiver`.
63 model_impl->set_disconnect_handler(
Honglin Yuc0cef102020-01-17 15:26:01 +110064 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
65
66 return model_impl;
67}
68
// Takes ownership of the model (and, if non-null, its `model_string`) and binds
// `receiver` to this instance. Only reached via ModelImpl::Create(), which also
// installs the self-deleting disconnect handler.
//
// NOTE(review): `model_string_` is initialized before `model_`; if it backs the
// FlatBufferModel's buffer, it must also be declared before `model_` in the
// header so it is destroyed after it — confirm against model_impl.h.
ModelImpl::ModelImpl(std::map<std::string, int> required_inputs,
                     std::map<std::string, int> required_outputs,
                     std::unique_ptr<tflite::FlatBufferModel> model,
                     std::unique_ptr<std::string> model_string,
                     mojo::PendingReceiver<Model> receiver,
                     const std::string& metrics_model_name)
    : required_inputs_(std::move(required_inputs)),
      required_outputs_(std::move(required_outputs)),
      model_string_(std::move(model_string)),
      model_(std::move(model)),
      receiver_(this, std::move(receiver)),
      metrics_model_name_(metrics_model_name) {}
Michael Martisa967f632018-08-10 10:39:00 +100081
Andrew Moylanb481af72020-07-09 15:22:00 +100082void ModelImpl::set_disconnect_handler(base::Closure disconnect_handler) {
83 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Michael Martisa967f632018-08-10 10:39:00 +100084}
85
86int ModelImpl::num_graph_executors_for_testing() const {
87 return graph_executors_.size();
88}
89
Andrew Moylanb481af72020-07-09 15:22:00 +100090void ModelImpl::CreateGraphExecutor(
91 mojo::PendingReceiver<GraphExecutor> receiver,
92 CreateGraphExecutorCallback callback) {
Alan Green3117b522020-05-20 09:50:27 +100093 auto options = GraphExecutorOptions::New(/*use_nnapi=*/false);
Andrew Moylanb481af72020-07-09 15:22:00 +100094 CreateGraphExecutorWithOptions(std::move(options), std::move(receiver),
Alan Green55e16542020-05-11 14:06:46 +100095 std::move(callback));
96}
97
98void ModelImpl::CreateGraphExecutorWithOptions(
99 GraphExecutorOptionsPtr options,
Andrew Moylanb481af72020-07-09 15:22:00 +1000100 mojo::PendingReceiver<GraphExecutor> receiver,
Alan Green55e16542020-05-11 14:06:46 +1000101 CreateGraphExecutorCallback callback) {
Honglin Yu6adafcd2019-07-22 13:48:11 +1000102 DCHECK(!metrics_model_name_.empty());
103
charleszhao5a7050e2020-07-14 15:21:41 +1000104 RequestMetrics request_metrics(metrics_model_name_, kMetricsRequestName);
alanlxlcb1f8562018-11-01 15:16:11 +1100105 request_metrics.StartRecordingPerformanceMetrics();
Honglin Yu6adafcd2019-07-22 13:48:11 +1000106
Michael Martisa967f632018-08-10 10:39:00 +1000107 if (model_ == nullptr) {
108 LOG(ERROR) << "Null model provided.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900109 std::move(callback).Run(
110 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100111 request_metrics.RecordRequestEvent(
112 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000113 return;
114 }
115
116 // Instantiate interpreter.
117 tflite::ops::builtin::BuiltinOpResolver resolver;
118 std::unique_ptr<tflite::Interpreter> interpreter;
119 const TfLiteStatus resolve_status =
120 tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
121 if (resolve_status != kTfLiteOk || !interpreter) {
122 LOG(ERROR) << "Could not resolve model ops.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900123 std::move(callback).Run(
124 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100125 request_metrics.RecordRequestEvent(
126 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000127 return;
128 }
129
Alan Green55e16542020-05-11 14:06:46 +1000130 // If requested, load and apply NNAPI
131 if (options->use_nnapi) {
Alan Green3117b522020-05-20 09:50:27 +1000132 TfLiteDelegate* delegate = tflite::NnApiDelegate();
Alan Green55e16542020-05-11 14:06:46 +1000133 if (!delegate) {
134 LOG(ERROR) << "NNAPI requested but not available.";
135 std::move(callback).Run(CreateGraphExecutorResult::NNAPI_UNAVAILABLE);
136 request_metrics.RecordRequestEvent(
137 CreateGraphExecutorResult::NNAPI_UNAVAILABLE);
138 return;
139 }
140 if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
141 LOG(ERROR) << "Could not use NNAPI delegate.";
142 std::move(callback).Run(CreateGraphExecutorResult::NNAPI_USE_ERROR);
143 request_metrics.RecordRequestEvent(
144 CreateGraphExecutorResult::NNAPI_USE_ERROR);
145 return;
146 }
147 }
148
Michael Martisa967f632018-08-10 10:39:00 +1000149 // Allocate memory for tensors.
150 if (interpreter->AllocateTensors() != kTfLiteOk) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900151 std::move(callback).Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100152 request_metrics.RecordRequestEvent(
153 CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000154 return;
155 }
156
157 // Add graph executor and schedule its deletion on pipe closure.
158 graph_executors_.emplace_front(required_inputs_, required_outputs_,
Andrew Moylanb481af72020-07-09 15:22:00 +1000159 std::move(interpreter), std::move(receiver),
Honglin Yu6adafcd2019-07-22 13:48:11 +1000160 metrics_model_name_);
Andrew Moylanb481af72020-07-09 15:22:00 +1000161 graph_executors_.front().set_disconnect_handler(
Michael Martisa967f632018-08-10 10:39:00 +1000162 base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
163 graph_executors_.begin()));
164
Qijiang Fan5d381a02020-04-19 23:42:37 +0900165 std::move(callback).Run(CreateGraphExecutorResult::OK);
alanlxlcb1f8562018-11-01 15:16:11 +1100166 request_metrics.FinishRecordingPerformanceMetrics();
167 request_metrics.RecordRequestEvent(CreateGraphExecutorResult::OK);
Michael Martisa967f632018-08-10 10:39:00 +1000168}
169
170void ModelImpl::EraseGraphExecutor(
171 const std::list<GraphExecutorImpl>::const_iterator it) {
172 graph_executors_.erase(it);
173}
174
175} // namespace ml