// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ml/graph_executor_impl.h"

#include <set>
#include <utility>
#include <vector>

#include "ml/mojom/tensor.mojom.h"
#include "ml/tensor_view.h"

namespace ml {

namespace {

using ::chromeos::machine_learning::mojom::ExecuteResult;
using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
using ::chromeos::machine_learning::mojom::Int64List;
using ::chromeos::machine_learning::mojom::Tensor;
using ::chromeos::machine_learning::mojom::TensorPtr;
using ::chromeos::machine_learning::mojom::ValueList;

// Verifies |tensor| is valid (i.e. is of type |TensorType| and of the correct
// shape for this input) and copies its data into the graph |interpreter| at
// position |index|.
template <typename TensorType, typename MemoryType>
ExecuteResult PopulateInput(const TensorPtr& tensor,
                            const int index,
                            tflite::Interpreter* const interpreter) {
  const TensorView<TensorType> tensor_view(tensor);

  if (!tensor_view.IsValidType())
    return ExecuteResult::INPUT_TYPE_ERROR;

  if (!tensor_view.IsValidFormat())
    return ExecuteResult::INPUT_FORMAT_ERROR;

  // Check that the given input shape matches that expected by TF lite.

  const TfLiteIntArray& expected_dims = *interpreter->tensor(index)->dims;
  const std::vector<int64_t>& actual_dims = tensor_view.GetShape();

  bool shape_matches = expected_dims.size == actual_dims.size();
  for (int i = 0; shape_matches && i < expected_dims.size; ++i) {
    shape_matches = expected_dims.data[i] == actual_dims[i];
  }

  if (!shape_matches)
    return ExecuteResult::INPUT_SHAPE_ERROR;

  MemoryType* const input_memory = interpreter->typed_tensor<MemoryType>(index);
  const std::vector<TensorType>& tensor_values = tensor_view.GetValues();
  for (size_t i = 0; i < tensor_values.size(); ++i) {
    input_memory[i] = tensor_values[i];
  }

  return ExecuteResult::OK;
}

ExecuteResult InvalidInput(const TensorPtr&, int, tflite::Interpreter*) {
  return ExecuteResult::EXECUTION_ERROR;
}

// A table of functions to validate / populate data for model nodes expecting
// input of each TF lite type.
//
// This table is indexed by TfLiteType, the possible values of which can be
// found at <tensorflow/contrib/lite/context.h>. We make the following
// assumptions about index values:
//    1) They will remain consistent across TF lite releases, and
//    2) They will always start from (close to) 0 and be (mostly) consecutive.
//
// Since TfLiteType is part of the stable C API for TF lite, these assumptions
// seem fair.
constexpr decltype(&InvalidInput) kPopulateInputFns[] = {
    &InvalidInput,                     // kTfLiteNoType
    &PopulateInput<double, float>,     // kTfLiteFloat32
    &PopulateInput<int64_t, int32_t>,  // kTfLiteInt32
    &PopulateInput<int64_t, uint8_t>,  // kTfLiteUInt8
    &PopulateInput<int64_t, int64_t>,  // kTfLiteInt64
    &InvalidInput,                     // kTfLiteString
    &PopulateInput<int64_t, bool>,     // kTfLiteBool
};
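
// Compile-time check of assumption (2) above. This assumes kTfLiteBool is the
// largest TfLiteType value handled by this table; if a TF lite update reorders
// or extends the enum, the build fails here rather than letting Execute()
// dispatch through a stale table entry.
static_assert(arraysize(kPopulateInputFns) == kTfLiteBool + 1,
              "kPopulateInputFns must have one entry per handled TfLiteType");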

// Copies data from position |index| in the graph |interpreter| into the given
// tensor object.
template <typename TensorType, typename MemoryType>
ExecuteResult PopulateOutput(const int index,
                             const tflite::Interpreter& interpreter,
                             const TensorPtr& tensor) {
  TensorView<TensorType> tensor_view(tensor);
  tensor_view.Allocate();

  // Empty output is not valid.
  const TfLiteIntArray& dims = *interpreter.tensor(index)->dims;
  if (dims.size == 0)
    return ExecuteResult::EXECUTION_ERROR;

  // Copy across size information and calculate the number of elements being
  // output.
  int64_t num_entries = 1;
  std::vector<int64_t>& tensor_dims = tensor_view.GetShape();
  tensor_dims.resize(dims.size);
  for (int i = 0; i < dims.size; ++i) {
    const int64_t dim_length = dims.data[i];

    if (dim_length <= 0)
      return ExecuteResult::EXECUTION_ERROR;

    tensor_dims[i] = dim_length;
    num_entries *= dim_length;
  }

  // Populate tensor values.
  const MemoryType* const output_memory =
      interpreter.typed_tensor<MemoryType>(index);
  std::vector<TensorType>& tensor_values = tensor_view.GetValues();
  tensor_values.resize(num_entries);
  for (int64_t i = 0; i < num_entries; ++i) {
    tensor_values[i] = output_memory[i];
  }

  return ExecuteResult::OK;
}

ExecuteResult InvalidOutput(int, const tflite::Interpreter&, const TensorPtr&) {
  return ExecuteResult::EXECUTION_ERROR;
}

// A table of functions to populate tensor objects with output data of each TF
// lite type.
//
// This table is indexed by TfLiteType, the possible values of which can be
// found at <tensorflow/contrib/lite/context.h>. See the caveats discussed in
// the comment above |kPopulateInputFns|.
constexpr decltype(&InvalidOutput) kPopulateOutputFns[] = {
    &InvalidOutput,                     // kTfLiteNoType
    &PopulateOutput<double, float>,     // kTfLiteFloat32
    &PopulateOutput<int64_t, int32_t>,  // kTfLiteInt32
    &PopulateOutput<int64_t, uint8_t>,  // kTfLiteUInt8
    &PopulateOutput<int64_t, int64_t>,  // kTfLiteInt64
    &InvalidOutput,                      // kTfLiteString
    &PopulateOutput<int64_t, bool>,     // kTfLiteBool
};
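
// Same compile-time sanity check as for |kPopulateInputFns|, under the same
// assumption that kTfLiteBool is the largest handled TfLiteType value.
static_assert(arraysize(kPopulateOutputFns) == kTfLiteBool + 1,
              "kPopulateOutputFns must have one entry per handled TfLiteType");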

}  // namespace

GraphExecutorImpl::GraphExecutorImpl(
    const std::map<std::string, int>& required_inputs,
    const std::map<std::string, int>& required_outputs,
    std::unique_ptr<tflite::Interpreter> interpreter,
    GraphExecutorRequest request)
    : required_inputs_(required_inputs),
      required_outputs_(required_outputs),
      interpreter_(std::move(interpreter)),
      binding_(this, std::move(request)) {}

void GraphExecutorImpl::set_connection_error_handler(
    base::Closure connection_error_handler) {
  binding_.set_connection_error_handler(std::move(connection_error_handler));
}

void GraphExecutorImpl::Execute(
    std::unordered_map<std::string, TensorPtr> tensors,
    const std::vector<std::string>& outputs,
    const ExecuteCallback& callback) {
  // Validate input and output names (before executing the graph, for
  // efficiency).

  for (const auto& kv : tensors) {
    const std::string& cur_input_name = kv.first;

    const auto name_lookup = required_inputs_.find(cur_input_name);
    if (name_lookup == required_inputs_.end() ||
        name_lookup->second >= interpreter_->tensors_size()) {
      callback.Run(ExecuteResult::UNKNOWN_INPUT_ERROR, base::nullopt);
      return;
    }
  }
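  // Every name in |tensors| is a distinct, known input, so a size mismatch
  // can only mean that some required input was not supplied.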
  if (tensors.size() != required_inputs_.size()) {
    callback.Run(ExecuteResult::INPUT_MISSING_ERROR, base::nullopt);
    return;
  }

  std::set<std::string> seen_outputs;
  for (const auto& cur_output_name : outputs) {
    const auto name_lookup = required_outputs_.find(cur_output_name);
    if (name_lookup == required_outputs_.end() ||
        name_lookup->second >= interpreter_->tensors_size()) {
      callback.Run(ExecuteResult::UNKNOWN_OUTPUT_ERROR, base::nullopt);
      return;
    }

    // Specifying the same output twice is an error.
    const auto insert_result = seen_outputs.insert(cur_output_name);
    if (!insert_result.second) {
      callback.Run(ExecuteResult::DUPLICATE_OUTPUT_ERROR, base::nullopt);
      return;
    }
  }
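  // |outputs| contains only distinct, known output names, so a size mismatch
  // can only mean that some required output was not requested.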
  if (outputs.size() != required_outputs_.size()) {
    callback.Run(ExecuteResult::OUTPUT_MISSING_ERROR, base::nullopt);
    return;
  }

  // Copy input data into the interpreter.
  for (const auto& kv : tensors) {
    const std::string& cur_input_name = kv.first;
    const TensorPtr& cur_input = kv.second;

    // Always valid, by the input name check at the start of this function.
    const int cur_input_id = required_inputs_.find(cur_input_name)->second;

    // Check that the current input node is of a supported type.
    const uint32_t cur_input_type = interpreter_->tensor(cur_input_id)->type;
    if (cur_input_type >= arraysize(kPopulateInputFns)) {
      LOG(ERROR) << "TF lite graph contains invalid input node " << cur_input_id
                 << " of type " << cur_input_type << ".";
      callback.Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
      return;
    }

    // Attempt to copy input data into the current input node.
    const ExecuteResult populate_input_result =
        (*kPopulateInputFns[cur_input_type])(cur_input, cur_input_id,
                                             interpreter_.get());
    if (populate_input_result != ExecuteResult::OK) {
      callback.Run(populate_input_result, base::nullopt);
      return;
    }
  }

  // Execute graph.
  if (interpreter_->Invoke() != kTfLiteOk) {
    LOG(ERROR) << "TF lite graph execution failed unexpectedly.";
    callback.Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
    return;
  }

  // Extract output.
  std::vector<chromeos::machine_learning::mojom::TensorPtr> output_tensors;
  for (const auto& cur_output_name : outputs) {
    output_tensors.push_back(Tensor::New());

    // Always valid, by the output name check at the start of this function.
    const int cur_output_id = required_outputs_.find(cur_output_name)->second;

    // Check that the current output node is of a supported type.
    const uint32_t cur_output_type = interpreter_->tensor(cur_output_id)->type;
    if (cur_output_type >= arraysize(kPopulateOutputFns)) {
      LOG(ERROR) << "TF lite graph contains invalid output node "
                 << cur_output_id << " of type " << cur_output_type << ".";
      callback.Run(ExecuteResult::EXECUTION_ERROR, base::nullopt);
      return;
    }

    // Attempt to extract data from the current output node.
    const ExecuteResult populate_output_result =
        (*kPopulateOutputFns[cur_output_type])(cur_output_id, *interpreter_,
                                               output_tensors.back());
    if (populate_output_result != ExecuteResult::OK) {
      callback.Run(populate_output_result, base::nullopt);
      return;
    }
  }

  callback.Run(ExecuteResult::OK, std::move(output_tensors));
}

}  // namespace ml