Michael Martis | 26abcd8 | 2018-08-08 10:57:25 +1000 | [diff] [blame] | 1 | // Copyright 2018 The Chromium OS Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style license that can be |
| 3 | // found in the LICENSE file. |
| 4 | |
| 5 | // Implementations of specializations of TensorView<> for all supported tensor |
| 6 | // data types |
| 7 | |
| 8 | #include "ml/tensor_view.h" |
| 9 | |
| 10 | namespace ml { |
| 11 | |
| 12 | using ::chromeos::machine_learning::mojom::FloatList; |
| 13 | using ::chromeos::machine_learning::mojom::Int64List; |
| 14 | using ::chromeos::machine_learning::mojom::ValueList; |
| 15 | |
| 16 | template <> |
Hidehiko Abe | 8ab64a6 | 2018-09-19 00:04:39 +0900 | [diff] [blame] | 17 | std::vector<int64_t>& TensorView<int64_t>::GetValues() { |
Michael Martis | 26abcd8 | 2018-08-08 10:57:25 +1000 | [diff] [blame] | 18 | return tensor_->data->get_int64_list()->value; |
| 19 | } |
| 20 | |
| 21 | template <> |
| 22 | bool TensorView<int64_t>::IsValidType() const { |
| 23 | return tensor_->data->which() == ValueList::Tag::INT64_LIST; |
| 24 | } |
| 25 | |
| 26 | template <> |
| 27 | void TensorView<int64_t>::AllocateValues() { |
| 28 | tensor_->data->set_int64_list(Int64List::New()); |
Andrew Moylan | 79b34a4 | 2020-07-08 11:13:11 +1000 | [diff] [blame] | 29 | // TODO(hidehiko): assigning std::vector<>() to `value` is unneeded |
Hidehiko Abe | 8ab64a6 | 2018-09-19 00:04:39 +0900 | [diff] [blame] | 30 | // on libmojo uprev. Remove them after the uprev. |
| 31 | tensor_->data->get_int64_list()->value = std::vector<int64_t>(); |
Michael Martis | 26abcd8 | 2018-08-08 10:57:25 +1000 | [diff] [blame] | 32 | } |
| 33 | |
| 34 | template <> |
Hidehiko Abe | 8ab64a6 | 2018-09-19 00:04:39 +0900 | [diff] [blame] | 35 | std::vector<double>& TensorView<double>::GetValues() { |
Michael Martis | 26abcd8 | 2018-08-08 10:57:25 +1000 | [diff] [blame] | 36 | return tensor_->data->get_float_list()->value; |
| 37 | } |
| 38 | |
| 39 | template <> |
| 40 | bool TensorView<double>::IsValidType() const { |
| 41 | return tensor_->data->which() == ValueList::Tag::FLOAT_LIST; |
| 42 | } |
| 43 | |
| 44 | template <> |
| 45 | void TensorView<double>::AllocateValues() { |
| 46 | tensor_->data->set_float_list(FloatList::New()); |
Andrew Moylan | 79b34a4 | 2020-07-08 11:13:11 +1000 | [diff] [blame] | 47 | // TODO(hidehiko): assigning std::vector<>() to `value` is unneeded |
Hidehiko Abe | 8ab64a6 | 2018-09-19 00:04:39 +0900 | [diff] [blame] | 48 | // on libmojo uprev. Remove them after the uprev. |
| 49 | tensor_->data->get_float_list()->value = std::vector<double>(); |
Michael Martis | 26abcd8 | 2018-08-08 10:57:25 +1000 | [diff] [blame] | 50 | } |
| 51 | |
| 52 | } // namespace ml |