// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ml/graph_executor_impl.h"

#include <set>
#include <utility>

#include <base/logging.h>
#include <base/macros.h>

#include "ml/mojom/tensor.mojom.h"
#include "ml/tensor_view.h"

namespace ml {

namespace {

using ::chromeos::machine_learning::mojom::ExecuteResult;
using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
using ::chromeos::machine_learning::mojom::Int64List;
using ::chromeos::machine_learning::mojom::Tensor;
using ::chromeos::machine_learning::mojom::TensorPtr;
using ::chromeos::machine_learning::mojom::ValueList;

// Verifies that |tensor| is valid (i.e. that it is of type |TensorType| and
// has the correct shape for this input) and copies its data into the graph
// |interpreter| at position |index|.
template <typename TensorType, typename MemoryType>
ExecuteResult PopulateInput(const TensorPtr& tensor,
                            const int index,
                            tflite::Interpreter* const interpreter) {
  const TensorView<TensorType> tensor_view(tensor);

  if (!tensor_view.IsValidType())
    return ExecuteResult::INPUT_TYPE_ERROR;

  if (!tensor_view.IsValidFormat())
    return ExecuteResult::INPUT_FORMAT_ERROR;

  // Check that the given input shape matches that expected by TF lite.

  const TfLiteIntArray& expected_dims = *interpreter->tensor(index)->dims;
  const mojo::Array<int64_t>& actual_dims = tensor_view.GetShape();

  bool shape_matches = expected_dims.size == actual_dims.size();
  for (int i = 0; shape_matches && i < expected_dims.size; ++i) {
    shape_matches = expected_dims.data[i] == actual_dims[i];
  }

  if (!shape_matches)
    return ExecuteResult::INPUT_SHAPE_ERROR;

  // Copy values into the interpreter's input buffer. Note that this
  // implicitly narrows from the wide mojo type |TensorType| (e.g. int64_t)
  // to the node's native type |MemoryType| (e.g. uint8_t).
  MemoryType* const input_memory =
      interpreter->typed_tensor<MemoryType>(index);
  const mojo::Array<TensorType>& tensor_values = tensor_view.GetValues();
  for (int i = 0; i < tensor_values.size(); ++i) {
    input_memory[i] = tensor_values[i];
  }

  return ExecuteResult::OK;
}
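
// For example, a hypothetical image input node expecting dims
// {1, 224, 224, 3} rejects a tensor whose shape is {224, 224, 3}: the sizes
// differ, so the check above fails with INPUT_SHAPE_ERROR.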

ExecuteResult InvalidInput(const TensorPtr&, int, tflite::Interpreter*) {
  return ExecuteResult::EXECUTION_ERROR;
}

// A table of functions to validate / populate data for model nodes expecting
// input of each TF lite type.
//
// This table is indexed by TfLiteType, the possible values of which can be
// found at <tensorflow/contrib/lite/context.h>. We make the following
// assumptions about index values:
// 1) They will remain consistent across TF lite releases, and
// 2) They will always start from (close to) 0 and be (mostly) consecutive.
//
// Since TfLiteType is part of the stable C API for TF lite, these assumptions
// seem fair.
constexpr decltype(&InvalidInput) kPopulateInputFns[] = {
    &InvalidInput,                     // kTfLiteNoType
    &PopulateInput<double, float>,     // kTfLiteFloat32
    &PopulateInput<int64_t, int32_t>,  // kTfLiteInt32
    &PopulateInput<int64_t, uint8_t>,  // kTfLiteUInt8
    &PopulateInput<int64_t, int64_t>,  // kTfLiteInt64
    &InvalidInput,                     // kTfLiteString
    &PopulateInput<int64_t, bool>,     // kTfLiteBool
};
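
// For example, a float32 input node is populated by looking up index
// kTfLiteFloat32 (== 1, per the ordering above) and dispatching to
// PopulateInput<double, float>. A minimal sketch, with |tensor| and
// |node_index| as hypothetical placeholders:
//
//   const ExecuteResult result = (*kPopulateInputFns[kTfLiteFloat32])(
//       tensor, node_index, interpreter.get());
//
// Execute() below performs exactly this dispatch, after bounds-checking the
// node type against the table size.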

// Copies data from position |index| in the graph |interpreter| into the given
// tensor object.
template <typename TensorType, typename MemoryType>
ExecuteResult PopulateOutput(const int index,
                             const tflite::Interpreter& interpreter,
                             const TensorPtr& tensor) {
  TensorView<TensorType> tensor_view(tensor);
  tensor_view.Allocate();

  // Empty output is not valid.
  const TfLiteIntArray& dims = *interpreter.tensor(index)->dims;
  if (dims.size == 0)
    return ExecuteResult::EXECUTION_ERROR;

  // Copy across size information and calculate the number of elements being
  // output.
  int64_t num_entries = 1;
  mojo::Array<int64_t>& tensor_dims = tensor_view.GetShape();
  tensor_dims.resize(dims.size);
  for (int i = 0; i < dims.size; ++i) {
    const int64_t dim_length = dims.data[i];

    if (dim_length <= 0)
      return ExecuteResult::EXECUTION_ERROR;

    tensor_dims[i] = dim_length;
    num_entries *= dim_length;
  }

  // Populate tensor values.
  const MemoryType* const output_memory =
      interpreter.typed_tensor<MemoryType>(index);
  mojo::Array<TensorType>& tensor_values = tensor_view.GetValues();
  tensor_values.resize(num_entries);
  for (int i = 0; i < num_entries; ++i) {
    tensor_values[i] = output_memory[i];
  }

  return ExecuteResult::OK;
}
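
// For example, extracting a float32 output node whose dims are {2, 3} yields
// a tensor with shape {2, 3} and six values, copied in flat buffer order and
// widened from float to double by PopulateOutput<double, float>.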

ExecuteResult InvalidOutput(int, const tflite::Interpreter&, const TensorPtr&) {
  return ExecuteResult::EXECUTION_ERROR;
}

// A table of functions that populate tensor objects with output from model
// nodes of each TF lite type.
//
// This table is indexed by TfLiteType, the possible values of which can be
// found at <tensorflow/contrib/lite/context.h>. See the caveats discussed in
// the comment above |kPopulateInputFns|.
constexpr decltype(&InvalidOutput) kPopulateOutputFns[] = {
    &InvalidOutput,                     // kTfLiteNoType
    &PopulateOutput<double, float>,     // kTfLiteFloat32
    &PopulateOutput<int64_t, int32_t>,  // kTfLiteInt32
    &PopulateOutput<int64_t, uint8_t>,  // kTfLiteUInt8
    &PopulateOutput<int64_t, int64_t>,  // kTfLiteInt64
    &InvalidOutput,                      // kTfLiteString
    &PopulateOutput<int64_t, bool>,     // kTfLiteBool
};

// Returns a null tensor array, keeping error-path callback invocations
// concise.
mojo::Array<TensorPtr> NullArray() {
  return mojo::Array<TensorPtr>(nullptr);
}

}  // namespace

GraphExecutorImpl::GraphExecutorImpl(
    const std::map<std::string, int>& required_inputs,
    const std::map<std::string, int>& required_outputs,
    std::unique_ptr<tflite::Interpreter> interpreter,
    GraphExecutorRequest request)
    : required_inputs_(required_inputs),
      required_outputs_(required_outputs),
      interpreter_(std::move(interpreter)),
      binding_(this, std::move(request)) {}
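
// Illustrative sketch only (not code in this file): a hypothetical owner,
// e.g. a Model implementation, might wire up a GraphExecutorImpl roughly as
// follows, where |model| is a loaded tflite::FlatBufferModel and
// |required_inputs| / |required_outputs| map node names to tensor indices:
//
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
//   interpreter->AllocateTensors();
//   auto executor = std::make_unique<GraphExecutorImpl>(
//       required_inputs, required_outputs, std::move(interpreter),
//       std::move(request));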

void GraphExecutorImpl::set_connection_error_handler(
    base::Closure connection_error_handler) {
  binding_.set_connection_error_handler(std::move(connection_error_handler));
}

void GraphExecutorImpl::Execute(
    const mojo::Map<mojo::String, TensorPtr> tensors,
    const mojo::Array<mojo::String> outputs,
    const ExecuteCallback& callback) {
  // Validate the input and output names (before executing the graph, for
  // efficiency).

  for (const auto& kv : tensors) {
    const std::string& cur_input_name = kv.first.get();

    const auto name_lookup = required_inputs_.find(cur_input_name);
    if (name_lookup == required_inputs_.end() ||
        name_lookup->second >= interpreter_->tensors_size()) {
      callback.Run(ExecuteResult::UNKNOWN_INPUT_ERROR, NullArray());
      return;
    }
  }
  // Since every supplied input name was found above and map keys are unique,
  // a size mismatch means some required input was not supplied.
  if (tensors.size() != required_inputs_.size()) {
    callback.Run(ExecuteResult::INPUT_MISSING_ERROR, NullArray());
    return;
  }

  std::set<std::string> seen_outputs;
  for (const mojo::String& cur_output_name : outputs) {
    const auto name_lookup = required_outputs_.find(cur_output_name.get());
    if (name_lookup == required_outputs_.end() ||
        name_lookup->second >= interpreter_->tensors_size()) {
      callback.Run(ExecuteResult::UNKNOWN_OUTPUT_ERROR, NullArray());
      return;
    }

    // Specifying the same output twice is an error.
    const auto insert_result = seen_outputs.insert(cur_output_name.get());
    if (!insert_result.second) {
      callback.Run(ExecuteResult::DUPLICATE_OUTPUT_ERROR, NullArray());
      return;
    }
  }
  // As above, a size mismatch now implies that some required output was not
  // requested.
  if (outputs.size() != required_outputs_.size()) {
    callback.Run(ExecuteResult::OUTPUT_MISSING_ERROR, NullArray());
    return;
  }

  // Copy input data into the interpreter.
  for (const auto& kv : tensors) {
    const std::string& cur_input_name = kv.first.get();
    const TensorPtr& cur_input = kv.second;

    // Always valid, by the input name check at the start of this function.
    const int cur_input_id = required_inputs_.find(cur_input_name)->second;

    // Check that the current input node is of a supported type.
    const uint32_t cur_input_type = interpreter_->tensor(cur_input_id)->type;
    if (cur_input_type >= arraysize(kPopulateInputFns)) {
      LOG(ERROR) << "TF lite graph contains invalid input node "
                 << cur_input_id << " of type " << cur_input_type << ".";
      callback.Run(ExecuteResult::EXECUTION_ERROR, NullArray());
      return;
    }

    // Attempt to copy input data into the current input node.
    const ExecuteResult populate_input_result =
        (*kPopulateInputFns[cur_input_type])(cur_input, cur_input_id,
                                             interpreter_.get());
    if (populate_input_result != ExecuteResult::OK) {
      callback.Run(populate_input_result, NullArray());
      return;
    }
  }

  // Execute the graph.
  if (interpreter_->Invoke() != kTfLiteOk) {
    LOG(ERROR) << "TF lite graph execution failed unexpectedly.";
    callback.Run(ExecuteResult::EXECUTION_ERROR, NullArray());
    return;
  }

  // Extract the output.
  mojo::Array<chromeos::machine_learning::mojom::TensorPtr> output_tensors;
  for (const mojo::String& cur_output_name : outputs) {
    output_tensors.push_back(Tensor::New());

    // Always valid, by the output name check at the start of this function.
    const int cur_output_id =
        required_outputs_.find(cur_output_name.get())->second;

    // Check that the current output node is of a supported type.
    const uint32_t cur_output_type = interpreter_->tensor(cur_output_id)->type;
    if (cur_output_type >= arraysize(kPopulateOutputFns)) {
      LOG(ERROR) << "TF lite graph contains invalid output node "
                 << cur_output_id << " of type " << cur_output_type << ".";
      callback.Run(ExecuteResult::EXECUTION_ERROR, NullArray());
      return;
    }

    // Attempt to extract data from the current output node into the tensor
    // just appended to |output_tensors|.
    const ExecuteResult populate_output_result =
        (*kPopulateOutputFns[cur_output_type])(cur_output_id, *interpreter_,
                                               *--output_tensors.end());
    if (populate_output_result != ExecuteResult::OK) {
      callback.Run(populate_output_result, NullArray());
      return;
    }
  }

  callback.Run(ExecuteResult::OK, std::move(output_tensors));
}
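
// Illustrative sketch only: a client holding a GraphExecutorPtr might invoke
// Execute() roughly as follows. The node names and the OnExecuteDone handler
// (taking an ExecuteResult and a mojo::Array<TensorPtr>) are hypothetical:
//
//   mojo::Map<mojo::String, TensorPtr> inputs;
//   inputs.insert("input_node", std::move(input_tensor));
//   mojo::Array<mojo::String> outputs;
//   outputs.push_back("output_node");
//   executor->Execute(std::move(inputs), std::move(outputs),
//                     base::Bind(&OnExecuteDone));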

}  // namespace ml