blob: e8734e1647b75f36c757025083baff413dfe0c48 [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef ML_TEST_UTILS_H_
6#define ML_TEST_UTILS_H_
7
Michael Martisc22823d2018-09-14 12:54:04 +10008#include <string>
Michael Martisa967f632018-08-10 10:39:00 +10009#include <vector>
10
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090011#include "ml/mojom/tensor.mojom.h"
Michael Martisc22823d2018-09-14 12:54:04 +100012#include "ml/tensor_view.h"
Michael Martisa967f632018-08-10 10:39:00 +100013
14namespace ml {
15
16// Create a tensor with the given shape and values. Does no validity checking
17// (by design, as we sometimes need to pass bad tensors to test error handling).
18template <typename T>
19chromeos::machine_learning::mojom::TensorPtr NewTensor(
20 const std::vector<int64_t>& shape, const std::vector<T>& values) {
21 auto tensor(chromeos::machine_learning::mojom::Tensor::New());
22 TensorView<T> tensor_view(tensor);
23 tensor_view.Allocate();
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090024 tensor_view.GetShape() = shape;
25 tensor_view.GetValues() = values;
Michael Martisa967f632018-08-10 10:39:00 +100026
27 return tensor;
28}
29
Michael Martisc22823d2018-09-14 12:54:04 +100030// Return the model directory for tests (or die if it cannot be obtained).
31std::string GetTestModelDir();
32
Michael Martisa967f632018-08-10 10:39:00 +100033} // namespace ml
34
35#endif // ML_TEST_UTILS_H_