// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// Implementations of the TensorView<> specializations for every supported
// tensor data type.

#include "ml/tensor_view.h"

#include <vector>

10namespace ml {
11
12using ::chromeos::machine_learning::mojom::FloatList;
13using ::chromeos::machine_learning::mojom::Int64List;
14using ::chromeos::machine_learning::mojom::ValueList;
15
16template <>
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090017std::vector<int64_t>& TensorView<int64_t>::GetValues() {
Michael Martis26abcd82018-08-08 10:57:25 +100018 return tensor_->data->get_int64_list()->value;
19}
20
21template <>
22bool TensorView<int64_t>::IsValidType() const {
23 return tensor_->data->which() == ValueList::Tag::INT64_LIST;
24}
25
26template <>
27void TensorView<int64_t>::AllocateValues() {
28 tensor_->data->set_int64_list(Int64List::New());
Andrew Moylan79b34a42020-07-08 11:13:11 +100029 // TODO(hidehiko): assigning std::vector<>() to `value` is unneeded
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090030 // on libmojo uprev. Remove them after the uprev.
31 tensor_->data->get_int64_list()->value = std::vector<int64_t>();
Michael Martis26abcd82018-08-08 10:57:25 +100032}
33
34template <>
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090035std::vector<double>& TensorView<double>::GetValues() {
Michael Martis26abcd82018-08-08 10:57:25 +100036 return tensor_->data->get_float_list()->value;
37}
38
39template <>
40bool TensorView<double>::IsValidType() const {
41 return tensor_->data->which() == ValueList::Tag::FLOAT_LIST;
42}
43
44template <>
45void TensorView<double>::AllocateValues() {
46 tensor_->data->set_float_list(FloatList::New());
Andrew Moylan79b34a42020-07-08 11:13:11 +100047 // TODO(hidehiko): assigning std::vector<>() to `value` is unneeded
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090048 // on libmojo uprev. Remove them after the uprev.
49 tensor_->data->get_float_list()->value = std::vector<double>();
Michael Martis26abcd82018-08-08 10:57:25 +100050}
51
52} // namespace ml