blob: 64ab4e84cd50f2fecb73c16efa990798d577657b [file] [log] [blame]
Michael Martis26abcd82018-08-08 10:57:25 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/graph_executor_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Michael Martis26abcd82018-08-08 10:57:25 +10007
8#include <set>
9#include <utility>
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090010#include <vector>
Michael Martis26abcd82018-08-08 10:57:25 +100011
Qijiang Fan713061e2021-03-08 15:45:12 +090012#include <base/check.h>
Qijiang Fane19d67d2020-04-01 08:18:39 +090013#include <base/stl_util.h>
14
Hidehiko Abeaa488c32018-08-31 23:49:41 +090015#include "ml/mojom/tensor.mojom.h"
Michael Martis26abcd82018-08-08 10:57:25 +100016#include "ml/tensor_view.h"
Michael Martis26abcd82018-08-08 10:57:25 +100017
18namespace ml {
19
20namespace {
21
22using ::chromeos::machine_learning::mojom::ExecuteResult;
Andrew Moylanb481af72020-07-09 15:22:00 +100023using ::chromeos::machine_learning::mojom::GraphExecutor;
Michael Martis26abcd82018-08-08 10:57:25 +100024using ::chromeos::machine_learning::mojom::Int64List;
25using ::chromeos::machine_learning::mojom::Tensor;
26using ::chromeos::machine_learning::mojom::TensorPtr;
27using ::chromeos::machine_learning::mojom::ValueList;
28
// Base name for UMA metrics related to graph execution.
constexpr char kMetricsRequestName[] = "ExecuteResult";
alanlxlcb1f8562018-11-01 15:16:11 +110031
Andrew Moylan79b34a42020-07-08 11:13:11 +100032// Verifies `tensor` is valid (i.e. is of type `TensorType` and of the correct
33// shape for this input) and copies its data into the graph `interpreter` at
34// position `index`.
Michael Martis26abcd82018-08-08 10:57:25 +100035template <typename TensorType, typename MemoryType>
36ExecuteResult PopulateInput(const TensorPtr& tensor,
37 const int index,
38 tflite::Interpreter* const interpreter) {
39 const TensorView<TensorType> tensor_view(tensor);
40
41 if (!tensor_view.IsValidType())
42 return ExecuteResult::INPUT_TYPE_ERROR;
43
44 if (!tensor_view.IsValidFormat())
45 return ExecuteResult::INPUT_FORMAT_ERROR;
46
47 // Check that given input shape matches that expected by TF lite.
48
49 const TfLiteIntArray& expected_dims = *interpreter->tensor(index)->dims;
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090050 const std::vector<int64_t>& actual_dims = tensor_view.GetShape();
Michael Martis26abcd82018-08-08 10:57:25 +100051
52 bool shape_matches = expected_dims.size == actual_dims.size();
53 for (int i = 0; shape_matches && i < expected_dims.size; ++i) {
54 shape_matches = expected_dims.data[i] == actual_dims[i];
55 }
56
57 if (!shape_matches)
58 return ExecuteResult::INPUT_SHAPE_ERROR;
59
60 MemoryType* const input_memory = interpreter->typed_tensor<MemoryType>(index);
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090061 const std::vector<TensorType>& tensor_values = tensor_view.GetValues();
Michael Martis26abcd82018-08-08 10:57:25 +100062 for (int i = 0; i < tensor_values.size(); ++i) {
63 input_memory[i] = tensor_values[i];
64 }
65
66 return ExecuteResult::OK;
67}
68
// Table fallback for TF lite input types that have no mojom Tensor
// representation (see the kTfLiteNoType / kTfLiteString entries in
// `kPopulateInputFns` below). Always fails with EXECUTION_ERROR.
ExecuteResult InvalidInput(const TensorPtr&, int, tflite::Interpreter*) {
  return ExecuteResult::EXECUTION_ERROR;
}
72
73// A table of functions to validate / populate data for model nodes expecting
74// input of each TF lite type.
75//
76// This table is indexed by TfLiteType, the possible values of which can be
Michael Martis8783c8e2019-06-26 17:30:54 +100077// found at <tensorflow/lite/context.h>. We make the following
Michael Martis26abcd82018-08-08 10:57:25 +100078// assumptions about index values:
79// 1) They will remain consistent across TF lite releases, and
80// 2) They will always start from (close to) 0 and be (mostly) consecutive.
81//
82// Since TfLiteType is part of the stable C API for TF lite, these assumptions
83// seem fair.
// Each supported entry maps a TfLiteType index to PopulateInput<TensorType,
// MemoryType>, where TensorType is the mojom-side element type and MemoryType
// is the type TF lite stores in its input buffer.
constexpr decltype(&InvalidInput) kPopulateInputFns[] = {
    &InvalidInput,                     // kTfLiteNoType
    &PopulateInput<double, float>,     // kTfLiteFloat32
    &PopulateInput<int64_t, int32_t>,  // kTfLiteInt32
    &PopulateInput<int64_t, uint8_t>,  // kTfLiteUInt8
    &PopulateInput<int64_t, int64_t>,  // kTfLiteInt64
    &InvalidInput,                     // kTfLiteString
    &PopulateInput<int64_t, bool>,     // kTfLiteBool
};
93
Andrew Moylan79b34a42020-07-08 11:13:11 +100094// Copies data from position `index` in the graph `interpreter` into the given
Michael Martis26abcd82018-08-08 10:57:25 +100095// tensor object.
96template <typename TensorType, typename MemoryType>
97ExecuteResult PopulateOutput(const int index,
98 const tflite::Interpreter& interpreter,
99 const TensorPtr& tensor) {
100 TensorView<TensorType> tensor_view(tensor);
101 tensor_view.Allocate();
102
103 // Empty output is not valid.
104 const TfLiteIntArray& dims = *interpreter.tensor(index)->dims;
105 if (dims.size == 0)
106 return ExecuteResult::EXECUTION_ERROR;
107
108 // Copy across size information and calculate the number of elements being
109 // output.
110 int64_t num_entries = 1;
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900111 std::vector<int64_t>& tensor_dims = tensor_view.GetShape();
Michael Martis26abcd82018-08-08 10:57:25 +1000112 tensor_dims.resize(dims.size);
113 for (int i = 0; i < dims.size; ++i) {
114 const int64_t dim_length = dims.data[i];
115
116 if (dim_length <= 0)
117 return ExecuteResult::EXECUTION_ERROR;
118
119 tensor_dims[i] = dim_length;
120 num_entries *= dim_length;
121 }
122
123 // Populate tensor values.
124 const MemoryType* const output_memory =
125 interpreter.typed_tensor<MemoryType>(index);
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900126 std::vector<TensorType>& tensor_values = tensor_view.GetValues();
Michael Martis26abcd82018-08-08 10:57:25 +1000127 tensor_values.resize(num_entries);
128 for (int i = 0; i < num_entries; ++i) {
129 tensor_values[i] = output_memory[i];
130 }
131
132 return ExecuteResult::OK;
133}
134
// Table fallback for TF lite output types that have no mojom Tensor
// representation (see the kTfLiteNoType / kTfLiteString entries in
// `kPopulateOutputFns` below). Always fails with EXECUTION_ERROR.
ExecuteResult InvalidOutput(int, const tflite::Interpreter&, const TensorPtr&) {
  return ExecuteResult::EXECUTION_ERROR;
}
138
139// A table of functions to populate data for tensors from output of each TF lite
140// type.
141//
142// This table is indexed by TfLiteType, the possible values of which can be
Michael Martis8783c8e2019-06-26 17:30:54 +1000143// found at <tensorflow/lite/context.h>. See the caveats discussed in
Andrew Moylan79b34a42020-07-08 11:13:11 +1000144// the comment above `kPopulateInputFns`.
// Each supported entry maps a TfLiteType index to PopulateOutput<TensorType,
// MemoryType>; the template arguments mirror those of `kPopulateInputFns`.
constexpr decltype(&InvalidOutput) kPopulateOutputFns[] = {
    &InvalidOutput,                     // kTfLiteNoType
    &PopulateOutput<double, float>,     // kTfLiteFloat32
    &PopulateOutput<int64_t, int32_t>,  // kTfLiteInt32
    &PopulateOutput<int64_t, uint8_t>,  // kTfLiteUInt8
    &PopulateOutput<int64_t, int64_t>,  // kTfLiteInt64
    &InvalidOutput,                     // kTfLiteString
    &PopulateOutput<int64_t, bool>,     // kTfLiteBool
};
154
Michael Martis26abcd82018-08-08 10:57:25 +1000155} // namespace
156
// Takes ownership of `interpreter` and binds `receiver` to this instance.
// `required_inputs` / `required_outputs` map node names to interpreter tensor
// indices; they are copied and later used by Execute() to validate requests.
// `metrics_model_name` must be non-empty (DCHECKed in Execute) and is used to
// name UMA metrics for this model.
GraphExecutorImpl::GraphExecutorImpl(
    const std::map<std::string, int>& required_inputs,
    const std::map<std::string, int>& required_outputs,
    std::unique_ptr<tflite::Interpreter> interpreter,
    mojo::PendingReceiver<GraphExecutor> receiver,
    const std::string& metrics_model_name)
    : required_inputs_(required_inputs),
      required_outputs_(required_outputs),
      interpreter_(std::move(interpreter)),
      receiver_(this, std::move(receiver)),
      metrics_model_name_(metrics_model_name) {}
Michael Martis26abcd82018-08-08 10:57:25 +1000168
// Installs `disconnect_handler` on the underlying mojo receiver so the owner
// can be notified (e.g. to destroy this executor) when the remote endpoint
// closes the connection.
void GraphExecutorImpl::set_disconnect_handler(
    base::Closure disconnect_handler) {
  receiver_.set_disconnect_handler(std::move(disconnect_handler));
}
173
Qijiang Fan5d381a02020-04-19 23:42:37 +0900174void GraphExecutorImpl::Execute(base::flat_map<std::string, TensorPtr> tensors,
175 const std::vector<std::string>& outputs,
176 ExecuteCallback callback) {
Honglin Yu6adafcd2019-07-22 13:48:11 +1000177 DCHECK(!metrics_model_name_.empty());
alanlxlcb1f8562018-11-01 15:16:11 +1100178
charleszhao5a7050e2020-07-14 15:21:41 +1000179 RequestMetrics request_metrics(metrics_model_name_, kMetricsRequestName);
alanlxlcb1f8562018-11-01 15:16:11 +1100180 request_metrics.StartRecordingPerformanceMetrics();
181
Michael Martis26abcd82018-08-08 10:57:25 +1000182 // Validate input and output names (before executing graph, for efficiency).
183
184 for (const auto& kv : tensors) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900185 const std::string& cur_input_name = kv.first;
Michael Martis26abcd82018-08-08 10:57:25 +1000186
187 const auto name_lookup = required_inputs_.find(cur_input_name);
188 if (name_lookup == required_inputs_.end() ||
189 name_lookup->second >= interpreter_->tensors_size()) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900190 std::move(callback).Run(ExecuteResult::UNKNOWN_INPUT_ERROR,
191 base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100192 request_metrics.RecordRequestEvent(ExecuteResult::UNKNOWN_INPUT_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000193 return;
194 }
195 }
196 if (tensors.size() != required_inputs_.size()) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900197 std::move(callback).Run(ExecuteResult::INPUT_MISSING_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100198 request_metrics.RecordRequestEvent(ExecuteResult::INPUT_MISSING_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000199 return;
200 }
201
202 std::set<std::string> seen_outputs;
Hidehiko Abe31bb9632018-11-23 02:49:56 +0900203 for (const auto& cur_output_name : outputs) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900204 const auto name_lookup = required_outputs_.find(cur_output_name);
Michael Martis26abcd82018-08-08 10:57:25 +1000205 if (name_lookup == required_outputs_.end() ||
206 name_lookup->second >= interpreter_->tensors_size()) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900207 std::move(callback).Run(ExecuteResult::UNKNOWN_OUTPUT_ERROR,
208 base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100209 request_metrics.RecordRequestEvent(ExecuteResult::UNKNOWN_OUTPUT_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000210 return;
211 }
212
213 // Specifying the same output twice is an error.
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900214 const auto insert_result = seen_outputs.insert(cur_output_name);
Michael Martis26abcd82018-08-08 10:57:25 +1000215 if (!insert_result.second) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900216 std::move(callback).Run(ExecuteResult::DUPLICATE_OUTPUT_ERROR,
217 base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100218 request_metrics.RecordRequestEvent(ExecuteResult::DUPLICATE_OUTPUT_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000219 return;
220 }
221 }
222 if (outputs.size() != required_outputs_.size()) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900223 std::move(callback).Run(ExecuteResult::OUTPUT_MISSING_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100224 request_metrics.RecordRequestEvent(ExecuteResult::OUTPUT_MISSING_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000225 return;
226 }
227
228 // Copy input data into the interpreter.
229 for (const auto& kv : tensors) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900230 const std::string& cur_input_name = kv.first;
Michael Martis26abcd82018-08-08 10:57:25 +1000231 const TensorPtr& cur_input = kv.second;
232
233 // Always valid, by the input name check at the start of this function.
234 const int cur_input_id = required_inputs_.find(cur_input_name)->second;
235
236 // Check that the current input node is a supported type.
237 const uint32_t cur_input_type = interpreter_->tensor(cur_input_id)->type;
Qijiang Fane19d67d2020-04-01 08:18:39 +0900238 if (cur_input_type >= base::size(kPopulateInputFns)) {
Michael Martis26abcd82018-08-08 10:57:25 +1000239 LOG(ERROR) << "TF lite graph contains invalid input node " << cur_input_id
240 << " of type " << cur_input_type << ".";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900241 std::move(callback).Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100242 request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000243 return;
244 }
245
246 // Attempt to copy input data into the current input node.
247 const ExecuteResult populate_input_result =
248 (*kPopulateInputFns[cur_input_type])(cur_input, cur_input_id,
249 interpreter_.get());
250 if (populate_input_result != ExecuteResult::OK) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900251 std::move(callback).Run(populate_input_result, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100252 request_metrics.RecordRequestEvent(populate_input_result);
Michael Martis26abcd82018-08-08 10:57:25 +1000253 return;
254 }
255 }
256
257 // Execute graph.
258 if (interpreter_->Invoke() != kTfLiteOk) {
259 LOG(ERROR) << "TF lite graph execution failed unexpectedly.";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900260 std::move(callback).Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100261 request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000262 return;
263 }
264
265 // Extract output.
Hidehiko Abe31bb9632018-11-23 02:49:56 +0900266 std::vector<chromeos::machine_learning::mojom::TensorPtr> output_tensors;
267 for (const auto& cur_output_name : outputs) {
Michael Martis26abcd82018-08-08 10:57:25 +1000268 output_tensors.push_back(Tensor::New());
269
270 // Always valid, by the output name check at the start of this function.
Tom Hughes1d1c1922020-08-27 16:16:53 -0700271 const int cur_output_id = required_outputs_.find(cur_output_name)->second;
Michael Martis26abcd82018-08-08 10:57:25 +1000272
273 // Check that the current output node is a supported type.
274 const uint32_t cur_output_type = interpreter_->tensor(cur_output_id)->type;
Qijiang Fane19d67d2020-04-01 08:18:39 +0900275 if (cur_output_type >= base::size(kPopulateOutputFns)) {
Michael Martis26abcd82018-08-08 10:57:25 +1000276 LOG(ERROR) << "TF lite graph contains invalid output node "
277 << cur_output_id << " of type " << cur_output_type << ".";
Qijiang Fan5d381a02020-04-19 23:42:37 +0900278 std::move(callback).Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100279 request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000280 return;
281 }
282
283 // Attempt to extract data from the current output node.
284 const ExecuteResult populate_output_result =
285 (*kPopulateOutputFns[cur_output_type])(cur_output_id, *interpreter_,
286 *--output_tensors.end());
287 if (populate_output_result != ExecuteResult::OK) {
Qijiang Fan5d381a02020-04-19 23:42:37 +0900288 std::move(callback).Run(populate_output_result, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100289 request_metrics.RecordRequestEvent(populate_output_result);
Michael Martis26abcd82018-08-08 10:57:25 +1000290 return;
291 }
292 }
293
Qijiang Fan5d381a02020-04-19 23:42:37 +0900294 std::move(callback).Run(ExecuteResult::OK, std::move(output_tensors));
alanlxlcb1f8562018-11-01 15:16:11 +1100295 request_metrics.FinishRecordingPerformanceMetrics();
296 request_metrics.RecordRequestEvent(ExecuteResult::OK);
Michael Martis26abcd82018-08-08 10:57:25 +1000297}
298
299} // namespace ml