blob: cbe3ce164b78379b9714a4819a0e2ad5f6dc7935 [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef ML_TEST_UTILS_H_
6#define ML_TEST_UTILS_H_
7
8#include <vector>
9
10#include "mojom/tensor.mojom.h"
11
12namespace ml {
13
14// Create a tensor with the given shape and values. Does no validity checking
15// (by design, as we sometimes need to pass bad tensors to test error handling).
16template <typename T>
17chromeos::machine_learning::mojom::TensorPtr NewTensor(
18 const std::vector<int64_t>& shape, const std::vector<T>& values) {
19 auto tensor(chromeos::machine_learning::mojom::Tensor::New());
20 TensorView<T> tensor_view(tensor);
21 tensor_view.Allocate();
22
23 mojo::Array<int64_t>& tensor_shape = tensor_view.GetShape();
24 for (const int64_t dim : shape) {
25 tensor_shape.push_back(dim);
26 }
27
28 mojo::Array<T>& tensor_values = tensor_view.GetValues();
29 for (const T& value : values) {
30 tensor_values.push_back(value);
31 }
32
33 return tensor;
34}
35
36} // namespace ml
37
38#endif // ML_TEST_UTILS_H_