blob: 2a98ad197c676e938138b7ae462d4437f81e1bcb [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef ML_TEST_UTILS_H_
6#define ML_TEST_UTILS_H_
7
Michael Martisc22823d2018-09-14 12:54:04 +10008#include <string>
Michael Martisa967f632018-08-10 10:39:00 +10009#include <vector>
10
Michael Martisc22823d2018-09-14 12:54:04 +100011#include "ml/tensor_view.h"
Michael Martisa967f632018-08-10 10:39:00 +100012#include "mojom/tensor.mojom.h"
13
14namespace ml {
15
16// Create a tensor with the given shape and values. Does no validity checking
17// (by design, as we sometimes need to pass bad tensors to test error handling).
18template <typename T>
19chromeos::machine_learning::mojom::TensorPtr NewTensor(
20 const std::vector<int64_t>& shape, const std::vector<T>& values) {
21 auto tensor(chromeos::machine_learning::mojom::Tensor::New());
22 TensorView<T> tensor_view(tensor);
23 tensor_view.Allocate();
24
25 mojo::Array<int64_t>& tensor_shape = tensor_view.GetShape();
26 for (const int64_t dim : shape) {
27 tensor_shape.push_back(dim);
28 }
29
30 mojo::Array<T>& tensor_values = tensor_view.GetValues();
31 for (const T& value : values) {
32 tensor_values.push_back(value);
33 }
34
35 return tensor;
36}
37
Michael Martisc22823d2018-09-14 12:54:04 +100038// Return the model directory for tests (or die if it cannot be obtained).
39std::string GetTestModelDir();
40
Michael Martisa967f632018-08-10 10:39:00 +100041} // namespace ml
42
43#endif // ML_TEST_UTILS_H_