blob: e95766f20e057cb491640a885ea2411bf61d770a [file] [log] [blame]
Michael Martis26abcd82018-08-08 10:57:25 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/graph_executor_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Michael Martis26abcd82018-08-08 10:57:25 +10007
8#include <set>
9#include <utility>
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090010#include <vector>
Michael Martis26abcd82018-08-08 10:57:25 +100011
Qijiang Fane19d67d2020-04-01 08:18:39 +090012#include <base/stl_util.h>
13
Hidehiko Abeaa488c32018-08-31 23:49:41 +090014#include "ml/mojom/tensor.mojom.h"
Michael Martis26abcd82018-08-08 10:57:25 +100015#include "ml/tensor_view.h"
Michael Martis26abcd82018-08-08 10:57:25 +100016
17namespace ml {
18
19namespace {
20
21using ::chromeos::machine_learning::mojom::ExecuteResult;
22using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
23using ::chromeos::machine_learning::mojom::Int64List;
24using ::chromeos::machine_learning::mojom::Tensor;
25using ::chromeos::machine_learning::mojom::TensorPtr;
26using ::chromeos::machine_learning::mojom::ValueList;
27
// Base name for the UMA metrics recorded around each Execute() request; the
// full metric name is namespaced by |metrics_model_name_| (see RequestMetrics).
constexpr char kMetricsRequestName[] = "ExecuteResult";
alanlxlcb1f8562018-11-01 15:16:11 +110030
Michael Martis26abcd82018-08-08 10:57:25 +100031// Verifies |tensor| is valid (i.e. is of type |TensorType| and of the correct
32// shape for this input) and copies its data into the graph |interpreter| at
33// position |index|.
34template <typename TensorType, typename MemoryType>
35ExecuteResult PopulateInput(const TensorPtr& tensor,
36 const int index,
37 tflite::Interpreter* const interpreter) {
38 const TensorView<TensorType> tensor_view(tensor);
39
40 if (!tensor_view.IsValidType())
41 return ExecuteResult::INPUT_TYPE_ERROR;
42
43 if (!tensor_view.IsValidFormat())
44 return ExecuteResult::INPUT_FORMAT_ERROR;
45
46 // Check that given input shape matches that expected by TF lite.
47
48 const TfLiteIntArray& expected_dims = *interpreter->tensor(index)->dims;
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090049 const std::vector<int64_t>& actual_dims = tensor_view.GetShape();
Michael Martis26abcd82018-08-08 10:57:25 +100050
51 bool shape_matches = expected_dims.size == actual_dims.size();
52 for (int i = 0; shape_matches && i < expected_dims.size; ++i) {
53 shape_matches = expected_dims.data[i] == actual_dims[i];
54 }
55
56 if (!shape_matches)
57 return ExecuteResult::INPUT_SHAPE_ERROR;
58
59 MemoryType* const input_memory = interpreter->typed_tensor<MemoryType>(index);
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090060 const std::vector<TensorType>& tensor_values = tensor_view.GetValues();
Michael Martis26abcd82018-08-08 10:57:25 +100061 for (int i = 0; i < tensor_values.size(); ++i) {
62 input_memory[i] = tensor_values[i];
63 }
64
65 return ExecuteResult::OK;
66}
67
68ExecuteResult InvalidInput(const TensorPtr&, int, tflite::Interpreter*) {
69 return ExecuteResult::EXECUTION_ERROR;
70}
71
// A table of functions to validate / populate data for model nodes expecting
// input of each TF lite type.
//
// This table is indexed by TfLiteType, the possible values of which can be
// found at <tensorflow/lite/context.h>. We make the following
// assumptions about index values:
// 1) They will remain consistent across TF lite releases, and
// 2) They will always start from (close to) 0 and be (mostly) consecutive.
//
// Since TfLiteType is part of the stable C API for TF lite, these assumptions
// seem fair.
//
// Entries mapped to |InvalidInput| (kTfLiteNoType, kTfLiteString) cause the
// request to fail with EXECUTION_ERROR. The first template argument is the
// mojom-side element type; the second is the TF lite buffer element type.
constexpr decltype(&InvalidInput) kPopulateInputFns[] = {
    &InvalidInput,                     // kTfLiteNoType
    &PopulateInput<double, float>,     // kTfLiteFloat32
    &PopulateInput<int64_t, int32_t>,  // kTfLiteInt32
    &PopulateInput<int64_t, uint8_t>,  // kTfLiteUInt8
    &PopulateInput<int64_t, int64_t>,  // kTfLiteInt64
    &InvalidInput,                     // kTfLiteString
    &PopulateInput<int64_t, bool>,     // kTfLiteBool
};
92
93// Copies data from position |index| in the graph |interpreter| into the given
94// tensor object.
95template <typename TensorType, typename MemoryType>
96ExecuteResult PopulateOutput(const int index,
97 const tflite::Interpreter& interpreter,
98 const TensorPtr& tensor) {
99 TensorView<TensorType> tensor_view(tensor);
100 tensor_view.Allocate();
101
102 // Empty output is not valid.
103 const TfLiteIntArray& dims = *interpreter.tensor(index)->dims;
104 if (dims.size == 0)
105 return ExecuteResult::EXECUTION_ERROR;
106
107 // Copy across size information and calculate the number of elements being
108 // output.
109 int64_t num_entries = 1;
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900110 std::vector<int64_t>& tensor_dims = tensor_view.GetShape();
Michael Martis26abcd82018-08-08 10:57:25 +1000111 tensor_dims.resize(dims.size);
112 for (int i = 0; i < dims.size; ++i) {
113 const int64_t dim_length = dims.data[i];
114
115 if (dim_length <= 0)
116 return ExecuteResult::EXECUTION_ERROR;
117
118 tensor_dims[i] = dim_length;
119 num_entries *= dim_length;
120 }
121
122 // Populate tensor values.
123 const MemoryType* const output_memory =
124 interpreter.typed_tensor<MemoryType>(index);
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900125 std::vector<TensorType>& tensor_values = tensor_view.GetValues();
Michael Martis26abcd82018-08-08 10:57:25 +1000126 tensor_values.resize(num_entries);
127 for (int i = 0; i < num_entries; ++i) {
128 tensor_values[i] = output_memory[i];
129 }
130
131 return ExecuteResult::OK;
132}
133
134ExecuteResult InvalidOutput(int, const tflite::Interpreter&, const TensorPtr&) {
135 return ExecuteResult::EXECUTION_ERROR;
136}
137
// A table of functions to populate data for tensors from output of each TF lite
// type.
//
// This table is indexed by TfLiteType, the possible values of which can be
// found at <tensorflow/lite/context.h>. See the caveats discussed in
// the comment above |kPopulateInputFns|.
//
// Entries mapped to |InvalidOutput| (kTfLiteNoType, kTfLiteString) cause the
// request to fail with EXECUTION_ERROR. The first template argument is the
// mojom-side element type; the second is the TF lite buffer element type.
constexpr decltype(&InvalidOutput) kPopulateOutputFns[] = {
    &InvalidOutput,                     // kTfLiteNoType
    &PopulateOutput<double, float>,     // kTfLiteFloat32
    &PopulateOutput<int64_t, int32_t>,  // kTfLiteInt32
    &PopulateOutput<int64_t, uint8_t>,  // kTfLiteUInt8
    &PopulateOutput<int64_t, int64_t>,  // kTfLiteInt64
    &InvalidOutput,                     // kTfLiteString
    &PopulateOutput<int64_t, bool>,     // kTfLiteBool
};
153
Michael Martis26abcd82018-08-08 10:57:25 +1000154} // namespace
155
// Constructs a GraphExecutorImpl bound to the given mojo |request|.
//
// |required_inputs| / |required_outputs| map tensor names to node indices in
// |interpreter|; Execute() validates its arguments against these maps.
// |metrics_model_name| namespaces the UMA metrics recorded per request;
// Execute() DCHECKs that it is non-empty.
GraphExecutorImpl::GraphExecutorImpl(
    const std::map<std::string, int>& required_inputs,
    const std::map<std::string, int>& required_outputs,
    std::unique_ptr<tflite::Interpreter> interpreter,
    GraphExecutorRequest request,
    const std::string& metrics_model_name)
    : required_inputs_(required_inputs),
      required_outputs_(required_outputs),
      interpreter_(std::move(interpreter)),
      binding_(this, std::move(request)),
      metrics_model_name_(metrics_model_name) {}
Michael Martis26abcd82018-08-08 10:57:25 +1000167
// Forwards |connection_error_handler| to the underlying mojo binding, which
// runs it when the connection to the remote end is dropped.
void GraphExecutorImpl::set_connection_error_handler(
    base::Closure connection_error_handler) {
  binding_.set_connection_error_handler(std::move(connection_error_handler));
}
172
173void GraphExecutorImpl::Execute(
hscham3d0632f2019-12-11 15:58:57 +0900174 base::flat_map<std::string, TensorPtr> tensors,
Hidehiko Abe31bb9632018-11-23 02:49:56 +0900175 const std::vector<std::string>& outputs,
Michael Martis26abcd82018-08-08 10:57:25 +1000176 const ExecuteCallback& callback) {
Honglin Yu6adafcd2019-07-22 13:48:11 +1000177 DCHECK(!metrics_model_name_.empty());
alanlxlcb1f8562018-11-01 15:16:11 +1100178
Honglin Yu6adafcd2019-07-22 13:48:11 +1000179 RequestMetrics<ExecuteResult> request_metrics(metrics_model_name_,
180 kMetricsRequestName);
alanlxlcb1f8562018-11-01 15:16:11 +1100181 request_metrics.StartRecordingPerformanceMetrics();
182
Michael Martis26abcd82018-08-08 10:57:25 +1000183 // Validate input and output names (before executing graph, for efficiency).
184
185 for (const auto& kv : tensors) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900186 const std::string& cur_input_name = kv.first;
Michael Martis26abcd82018-08-08 10:57:25 +1000187
188 const auto name_lookup = required_inputs_.find(cur_input_name);
189 if (name_lookup == required_inputs_.end() ||
190 name_lookup->second >= interpreter_->tensors_size()) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900191 callback.Run(ExecuteResult::UNKNOWN_INPUT_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100192 request_metrics.RecordRequestEvent(ExecuteResult::UNKNOWN_INPUT_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000193 return;
194 }
195 }
196 if (tensors.size() != required_inputs_.size()) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900197 callback.Run(ExecuteResult::INPUT_MISSING_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100198 request_metrics.RecordRequestEvent(ExecuteResult::INPUT_MISSING_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000199 return;
200 }
201
202 std::set<std::string> seen_outputs;
Hidehiko Abe31bb9632018-11-23 02:49:56 +0900203 for (const auto& cur_output_name : outputs) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900204 const auto name_lookup = required_outputs_.find(cur_output_name);
Michael Martis26abcd82018-08-08 10:57:25 +1000205 if (name_lookup == required_outputs_.end() ||
206 name_lookup->second >= interpreter_->tensors_size()) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900207 callback.Run(ExecuteResult::UNKNOWN_OUTPUT_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100208 request_metrics.RecordRequestEvent(ExecuteResult::UNKNOWN_OUTPUT_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000209 return;
210 }
211
212 // Specifying the same output twice is an error.
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900213 const auto insert_result = seen_outputs.insert(cur_output_name);
Michael Martis26abcd82018-08-08 10:57:25 +1000214 if (!insert_result.second) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900215 callback.Run(ExecuteResult::DUPLICATE_OUTPUT_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100216 request_metrics.RecordRequestEvent(ExecuteResult::DUPLICATE_OUTPUT_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000217 return;
218 }
219 }
220 if (outputs.size() != required_outputs_.size()) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900221 callback.Run(ExecuteResult::OUTPUT_MISSING_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100222 request_metrics.RecordRequestEvent(ExecuteResult::OUTPUT_MISSING_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000223 return;
224 }
225
226 // Copy input data into the interpreter.
227 for (const auto& kv : tensors) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900228 const std::string& cur_input_name = kv.first;
Michael Martis26abcd82018-08-08 10:57:25 +1000229 const TensorPtr& cur_input = kv.second;
230
231 // Always valid, by the input name check at the start of this function.
232 const int cur_input_id = required_inputs_.find(cur_input_name)->second;
233
234 // Check that the current input node is a supported type.
235 const uint32_t cur_input_type = interpreter_->tensor(cur_input_id)->type;
Qijiang Fane19d67d2020-04-01 08:18:39 +0900236 if (cur_input_type >= base::size(kPopulateInputFns)) {
Michael Martis26abcd82018-08-08 10:57:25 +1000237 LOG(ERROR) << "TF lite graph contains invalid input node " << cur_input_id
238 << " of type " << cur_input_type << ".";
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900239 callback.Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100240 request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000241 return;
242 }
243
244 // Attempt to copy input data into the current input node.
245 const ExecuteResult populate_input_result =
246 (*kPopulateInputFns[cur_input_type])(cur_input, cur_input_id,
247 interpreter_.get());
248 if (populate_input_result != ExecuteResult::OK) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900249 callback.Run(populate_input_result, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100250 request_metrics.RecordRequestEvent(populate_input_result);
Michael Martis26abcd82018-08-08 10:57:25 +1000251 return;
252 }
253 }
254
255 // Execute graph.
256 if (interpreter_->Invoke() != kTfLiteOk) {
257 LOG(ERROR) << "TF lite graph execution failed unexpectedly.";
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900258 callback.Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100259 request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000260 return;
261 }
262
263 // Extract output.
Hidehiko Abe31bb9632018-11-23 02:49:56 +0900264 std::vector<chromeos::machine_learning::mojom::TensorPtr> output_tensors;
265 for (const auto& cur_output_name : outputs) {
Michael Martis26abcd82018-08-08 10:57:25 +1000266 output_tensors.push_back(Tensor::New());
267
268 // Always valid, by the output name check at the start of this function.
269 const int cur_output_id =
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900270 required_outputs_.find(cur_output_name)->second;
Michael Martis26abcd82018-08-08 10:57:25 +1000271
272 // Check that the current output node is a supported type.
273 const uint32_t cur_output_type = interpreter_->tensor(cur_output_id)->type;
Qijiang Fane19d67d2020-04-01 08:18:39 +0900274 if (cur_output_type >= base::size(kPopulateOutputFns)) {
Michael Martis26abcd82018-08-08 10:57:25 +1000275 LOG(ERROR) << "TF lite graph contains invalid output node "
276 << cur_output_id << " of type " << cur_output_type << ".";
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900277 callback.Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100278 request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000279 return;
280 }
281
282 // Attempt to extract data from the current output node.
283 const ExecuteResult populate_output_result =
284 (*kPopulateOutputFns[cur_output_type])(cur_output_id, *interpreter_,
285 *--output_tensors.end());
286 if (populate_output_result != ExecuteResult::OK) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900287 callback.Run(populate_output_result, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100288 request_metrics.RecordRequestEvent(populate_output_result);
Michael Martis26abcd82018-08-08 10:57:25 +1000289 return;
290 }
291 }
292
293 callback.Run(ExecuteResult::OK, std::move(output_tensors));
alanlxlcb1f8562018-11-01 15:16:11 +1100294 request_metrics.FinishRecordingPerformanceMetrics();
295 request_metrics.RecordRequestEvent(ExecuteResult::OK);
Michael Martis26abcd82018-08-08 10:57:25 +1000296}
297
298} // namespace ml