blob: e1834ea7fcde3fad12da8cc741d85e3f3f1ba2a8 [file] [log] [blame]
Michael Martis26abcd82018-08-08 10:57:25 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/graph_executor_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Michael Martis26abcd82018-08-08 10:57:25 +10007
8#include <set>
9#include <utility>
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090010#include <vector>
Michael Martis26abcd82018-08-08 10:57:25 +100011
Hidehiko Abeaa488c32018-08-31 23:49:41 +090012#include "ml/mojom/tensor.mojom.h"
Michael Martis26abcd82018-08-08 10:57:25 +100013#include "ml/tensor_view.h"
Michael Martis26abcd82018-08-08 10:57:25 +100014
15namespace ml {
16
17namespace {
18
19using ::chromeos::machine_learning::mojom::ExecuteResult;
20using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
21using ::chromeos::machine_learning::mojom::Int64List;
22using ::chromeos::machine_learning::mojom::Tensor;
23using ::chromeos::machine_learning::mojom::TensorPtr;
24using ::chromeos::machine_learning::mojom::ValueList;
25
// Request name used for the UMA metrics recorded by GraphExecutorImpl::Execute;
// it is passed to RequestMetrics alongside the model's metrics name.
constexpr char kMetricsRequestName[]{"ExecuteResult"};
alanlxlcb1f8562018-11-01 15:16:11 +110028
Michael Martis26abcd82018-08-08 10:57:25 +100029// Verifies |tensor| is valid (i.e. is of type |TensorType| and of the correct
30// shape for this input) and copies its data into the graph |interpreter| at
31// position |index|.
32template <typename TensorType, typename MemoryType>
33ExecuteResult PopulateInput(const TensorPtr& tensor,
34 const int index,
35 tflite::Interpreter* const interpreter) {
36 const TensorView<TensorType> tensor_view(tensor);
37
38 if (!tensor_view.IsValidType())
39 return ExecuteResult::INPUT_TYPE_ERROR;
40
41 if (!tensor_view.IsValidFormat())
42 return ExecuteResult::INPUT_FORMAT_ERROR;
43
44 // Check that given input shape matches that expected by TF lite.
45
46 const TfLiteIntArray& expected_dims = *interpreter->tensor(index)->dims;
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090047 const std::vector<int64_t>& actual_dims = tensor_view.GetShape();
Michael Martis26abcd82018-08-08 10:57:25 +100048
49 bool shape_matches = expected_dims.size == actual_dims.size();
50 for (int i = 0; shape_matches && i < expected_dims.size; ++i) {
51 shape_matches = expected_dims.data[i] == actual_dims[i];
52 }
53
54 if (!shape_matches)
55 return ExecuteResult::INPUT_SHAPE_ERROR;
56
57 MemoryType* const input_memory = interpreter->typed_tensor<MemoryType>(index);
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090058 const std::vector<TensorType>& tensor_values = tensor_view.GetValues();
Michael Martis26abcd82018-08-08 10:57:25 +100059 for (int i = 0; i < tensor_values.size(); ++i) {
60 input_memory[i] = tensor_values[i];
61 }
62
63 return ExecuteResult::OK;
64}
65
66ExecuteResult InvalidInput(const TensorPtr&, int, tflite::Interpreter*) {
67 return ExecuteResult::EXECUTION_ERROR;
68}
69
70// A table of functions to validate / populate data for model nodes expecting
71// input of each TF lite type.
72//
73// This table is indexed by TfLiteType, the possible values of which can be
74// found at <tensorflow/contrib/lite/context.h>. We make the following
75// assumptions about index values:
76// 1) They will remain consistent across TF lite releases, and
77// 2) They will always start from (close to) 0 and be (mostly) consecutive.
78//
79// Since TfLiteType is part of the stable C API for TF lite, these assumptions
80// seem fair.
81constexpr decltype(&InvalidInput) kPopulateInputFns[] = {
82 &InvalidInput, // kTfLiteNoType
83 &PopulateInput<double, float>, // kTfLiteFloat32
84 &PopulateInput<int64_t, int32_t>, // kTfLiteInt32
85 &PopulateInput<int64_t, uint8_t>, // kTfLiteUInt8
86 &PopulateInput<int64_t, int64_t>, // kTfLiteInt64
87 &InvalidInput, // kTfLiteString
88 &PopulateInput<int64_t, bool>, // kTfLiteBool
89};
90
91// Copies data from position |index| in the graph |interpreter| into the given
92// tensor object.
93template <typename TensorType, typename MemoryType>
94ExecuteResult PopulateOutput(const int index,
95 const tflite::Interpreter& interpreter,
96 const TensorPtr& tensor) {
97 TensorView<TensorType> tensor_view(tensor);
98 tensor_view.Allocate();
99
100 // Empty output is not valid.
101 const TfLiteIntArray& dims = *interpreter.tensor(index)->dims;
102 if (dims.size == 0)
103 return ExecuteResult::EXECUTION_ERROR;
104
105 // Copy across size information and calculate the number of elements being
106 // output.
107 int64_t num_entries = 1;
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900108 std::vector<int64_t>& tensor_dims = tensor_view.GetShape();
Michael Martis26abcd82018-08-08 10:57:25 +1000109 tensor_dims.resize(dims.size);
110 for (int i = 0; i < dims.size; ++i) {
111 const int64_t dim_length = dims.data[i];
112
113 if (dim_length <= 0)
114 return ExecuteResult::EXECUTION_ERROR;
115
116 tensor_dims[i] = dim_length;
117 num_entries *= dim_length;
118 }
119
120 // Populate tensor values.
121 const MemoryType* const output_memory =
122 interpreter.typed_tensor<MemoryType>(index);
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900123 std::vector<TensorType>& tensor_values = tensor_view.GetValues();
Michael Martis26abcd82018-08-08 10:57:25 +1000124 tensor_values.resize(num_entries);
125 for (int i = 0; i < num_entries; ++i) {
126 tensor_values[i] = output_memory[i];
127 }
128
129 return ExecuteResult::OK;
130}
131
132ExecuteResult InvalidOutput(int, const tflite::Interpreter&, const TensorPtr&) {
133 return ExecuteResult::EXECUTION_ERROR;
134}
135
136// A table of functions to populate data for tensors from output of each TF lite
137// type.
138//
139// This table is indexed by TfLiteType, the possible values of which can be
140// found at <tensorflow/contrib/lite/context.h>. See the caveats discussed in
141// the comment above |kPopulateInputFns|.
142constexpr decltype(&InvalidOutput) kPopulateOutputFns[] = {
143 &InvalidOutput, // kTfLiteNoType
144 &PopulateOutput<double, float>, // kTfLiteFloat32
145 &PopulateOutput<int64_t, int32_t>, // kTfLiteInt32
146 &PopulateOutput<int64_t, uint8_t>, // kTfLiteUInt8
147 &PopulateOutput<int64_t, int64_t>, // kTfLiteInt64
148 &InvalidOutput, // kTfLiteString
149 &PopulateOutput<int64_t, bool>, // kTfLiteBool
150};
151
Michael Martis26abcd82018-08-08 10:57:25 +1000152} // namespace
153
154GraphExecutorImpl::GraphExecutorImpl(
155 const std::map<std::string, int>& required_inputs,
156 const std::map<std::string, int>& required_outputs,
157 std::unique_ptr<tflite::Interpreter> interpreter,
Honglin Yu6adafcd2019-07-22 13:48:11 +1000158 GraphExecutorRequest request,
159 const std::string& metrics_model_name)
Michael Martis26abcd82018-08-08 10:57:25 +1000160 : required_inputs_(required_inputs),
161 required_outputs_(required_outputs),
162 interpreter_(std::move(interpreter)),
Honglin Yu6adafcd2019-07-22 13:48:11 +1000163 binding_(this, std::move(request)),
164 metrics_model_name_(metrics_model_name) {}
Michael Martis26abcd82018-08-08 10:57:25 +1000165
166void GraphExecutorImpl::set_connection_error_handler(
167 base::Closure connection_error_handler) {
168 binding_.set_connection_error_handler(std::move(connection_error_handler));
169}
170
171void GraphExecutorImpl::Execute(
Hidehiko Abe31bb9632018-11-23 02:49:56 +0900172 std::unordered_map<std::string, TensorPtr> tensors,
173 const std::vector<std::string>& outputs,
Michael Martis26abcd82018-08-08 10:57:25 +1000174 const ExecuteCallback& callback) {
Honglin Yu6adafcd2019-07-22 13:48:11 +1000175 DCHECK(!metrics_model_name_.empty());
alanlxlcb1f8562018-11-01 15:16:11 +1100176
Honglin Yu6adafcd2019-07-22 13:48:11 +1000177 RequestMetrics<ExecuteResult> request_metrics(metrics_model_name_,
178 kMetricsRequestName);
alanlxlcb1f8562018-11-01 15:16:11 +1100179 request_metrics.StartRecordingPerformanceMetrics();
180
Michael Martis26abcd82018-08-08 10:57:25 +1000181 // Validate input and output names (before executing graph, for efficiency).
182
183 for (const auto& kv : tensors) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900184 const std::string& cur_input_name = kv.first;
Michael Martis26abcd82018-08-08 10:57:25 +1000185
186 const auto name_lookup = required_inputs_.find(cur_input_name);
187 if (name_lookup == required_inputs_.end() ||
188 name_lookup->second >= interpreter_->tensors_size()) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900189 callback.Run(ExecuteResult::UNKNOWN_INPUT_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100190 request_metrics.RecordRequestEvent(ExecuteResult::UNKNOWN_INPUT_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000191 return;
192 }
193 }
194 if (tensors.size() != required_inputs_.size()) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900195 callback.Run(ExecuteResult::INPUT_MISSING_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100196 request_metrics.RecordRequestEvent(ExecuteResult::INPUT_MISSING_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000197 return;
198 }
199
200 std::set<std::string> seen_outputs;
Hidehiko Abe31bb9632018-11-23 02:49:56 +0900201 for (const auto& cur_output_name : outputs) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900202 const auto name_lookup = required_outputs_.find(cur_output_name);
Michael Martis26abcd82018-08-08 10:57:25 +1000203 if (name_lookup == required_outputs_.end() ||
204 name_lookup->second >= interpreter_->tensors_size()) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900205 callback.Run(ExecuteResult::UNKNOWN_OUTPUT_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100206 request_metrics.RecordRequestEvent(ExecuteResult::UNKNOWN_OUTPUT_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000207 return;
208 }
209
210 // Specifying the same output twice is an error.
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900211 const auto insert_result = seen_outputs.insert(cur_output_name);
Michael Martis26abcd82018-08-08 10:57:25 +1000212 if (!insert_result.second) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900213 callback.Run(ExecuteResult::DUPLICATE_OUTPUT_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100214 request_metrics.RecordRequestEvent(ExecuteResult::DUPLICATE_OUTPUT_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000215 return;
216 }
217 }
218 if (outputs.size() != required_outputs_.size()) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900219 callback.Run(ExecuteResult::OUTPUT_MISSING_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100220 request_metrics.RecordRequestEvent(ExecuteResult::OUTPUT_MISSING_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000221 return;
222 }
223
224 // Copy input data into the interpreter.
225 for (const auto& kv : tensors) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900226 const std::string& cur_input_name = kv.first;
Michael Martis26abcd82018-08-08 10:57:25 +1000227 const TensorPtr& cur_input = kv.second;
228
229 // Always valid, by the input name check at the start of this function.
230 const int cur_input_id = required_inputs_.find(cur_input_name)->second;
231
232 // Check that the current input node is a supported type.
233 const uint32_t cur_input_type = interpreter_->tensor(cur_input_id)->type;
234 if (cur_input_type >= arraysize(kPopulateInputFns)) {
235 LOG(ERROR) << "TF lite graph contains invalid input node " << cur_input_id
236 << " of type " << cur_input_type << ".";
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900237 callback.Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100238 request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000239 return;
240 }
241
242 // Attempt to copy input data into the current input node.
243 const ExecuteResult populate_input_result =
244 (*kPopulateInputFns[cur_input_type])(cur_input, cur_input_id,
245 interpreter_.get());
246 if (populate_input_result != ExecuteResult::OK) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900247 callback.Run(populate_input_result, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100248 request_metrics.RecordRequestEvent(populate_input_result);
Michael Martis26abcd82018-08-08 10:57:25 +1000249 return;
250 }
251 }
252
253 // Execute graph.
254 if (interpreter_->Invoke() != kTfLiteOk) {
255 LOG(ERROR) << "TF lite graph execution failed unexpectedly.";
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900256 callback.Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100257 request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000258 return;
259 }
260
261 // Extract output.
Hidehiko Abe31bb9632018-11-23 02:49:56 +0900262 std::vector<chromeos::machine_learning::mojom::TensorPtr> output_tensors;
263 for (const auto& cur_output_name : outputs) {
Michael Martis26abcd82018-08-08 10:57:25 +1000264 output_tensors.push_back(Tensor::New());
265
266 // Always valid, by the output name check at the start of this function.
267 const int cur_output_id =
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900268 required_outputs_.find(cur_output_name)->second;
Michael Martis26abcd82018-08-08 10:57:25 +1000269
270 // Check that the current output node is a supported type.
271 const uint32_t cur_output_type = interpreter_->tensor(cur_output_id)->type;
272 if (cur_output_type >= arraysize(kPopulateOutputFns)) {
273 LOG(ERROR) << "TF lite graph contains invalid output node "
274 << cur_output_id << " of type " << cur_output_type << ".";
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900275 callback.Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100276 request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR);
Michael Martis26abcd82018-08-08 10:57:25 +1000277 return;
278 }
279
280 // Attempt to extract data from the current output node.
281 const ExecuteResult populate_output_result =
282 (*kPopulateOutputFns[cur_output_type])(cur_output_id, *interpreter_,
283 *--output_tensors.end());
284 if (populate_output_result != ExecuteResult::OK) {
Hidehiko Abe8ab64a62018-09-19 00:04:39 +0900285 callback.Run(populate_output_result, base::nullopt);
alanlxlcb1f8562018-11-01 15:16:11 +1100286 request_metrics.RecordRequestEvent(populate_output_result);
Michael Martis26abcd82018-08-08 10:57:25 +1000287 return;
288 }
289 }
290
291 callback.Run(ExecuteResult::OK, std::move(output_tensors));
alanlxlcb1f8562018-11-01 15:16:11 +1100292 request_metrics.FinishRecordingPerformanceMetrics();
293 request_metrics.RecordRequestEvent(ExecuteResult::OK);
Michael Martis26abcd82018-08-08 10:57:25 +1000294}
295
296} // namespace ml