// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
| 4 | |
| 5 | #include "federated/utils.h" |
| 6 | |
| 7 | #include <memory> |
| 8 | |
| 9 | #include <gmock/gmock.h> |
| 10 | #include <gtest/gtest.h> |
| 11 | |
| 12 | #include "chrome/knowledge/federated/example.pb.h" |
| 13 | #include "chrome/knowledge/federated/feature.pb.h" |
| 14 | #include "federated/mojom/example.mojom.h" |
| 15 | #include "federated/test_utils.h" |
| 16 | |
| 17 | namespace federated { |
| 18 | namespace { |
| 19 | using chromeos::federated::mojom::Example; |
| 20 | using chromeos::federated::mojom::ExamplePtr; |
| 21 | using chromeos::federated::mojom::Features; |
| 22 | using testing::ElementsAre; |
| 23 | |
| 24 | TEST(UtilsTest, ConvertToTensorFlowExampleProto) { |
| 25 | auto example = CreateExamplePtr(); |
| 26 | |
| 27 | tensorflow::Example tf_example_converted = |
| 28 | ConvertToTensorFlowExampleProto(example); |
| 29 | const auto& tf_feature_map = tf_example_converted.features().feature(); |
| 30 | |
| 31 | EXPECT_EQ(tf_feature_map.size(), 4); |
| 32 | |
| 33 | EXPECT_TRUE(tf_feature_map.contains("int_feature1")); |
| 34 | const auto& int_feature1 = tf_feature_map.at("int_feature1"); |
| 35 | EXPECT_TRUE(int_feature1.has_int64_list() && !int_feature1.has_float_list() && |
| 36 | !int_feature1.has_bytes_list()); |
| 37 | EXPECT_THAT(int_feature1.int64_list().value(), ElementsAre(1, 2, 3, 4, 5)); |
| 38 | |
| 39 | EXPECT_TRUE(tf_feature_map.contains("int_feature2")); |
| 40 | const auto& int_feature2 = tf_feature_map.at("int_feature2"); |
| 41 | EXPECT_TRUE(int_feature2.has_int64_list() && !int_feature2.has_float_list() && |
| 42 | !int_feature2.has_bytes_list()); |
| 43 | EXPECT_THAT(int_feature2.int64_list().value(), |
| 44 | ElementsAre(10, 20, 30, 40, 50)); |
| 45 | |
| 46 | EXPECT_TRUE(tf_feature_map.contains("float_feature1")); |
| 47 | const auto& float_feature = tf_feature_map.at("float_feature1"); |
| 48 | EXPECT_TRUE(!float_feature.has_int64_list() && |
| 49 | float_feature.has_float_list() && |
| 50 | !float_feature.has_bytes_list()); |
| 51 | EXPECT_THAT(float_feature.float_list().value(), |
| 52 | ElementsAre(1.1, 2.1, 3.1, 4.1, 5.1)); |
| 53 | |
| 54 | EXPECT_TRUE(tf_feature_map.contains("string_feature1")); |
| 55 | const auto& string_feature = tf_feature_map.at("string_feature1"); |
| 56 | EXPECT_TRUE(!string_feature.has_int64_list() && |
| 57 | !string_feature.has_float_list() && |
| 58 | string_feature.has_bytes_list()); |
| 59 | EXPECT_THAT(string_feature.bytes_list().value(), |
| 60 | ElementsAre("abc", "123", "xyz")); |
| 61 | } |
| 62 | |
| 63 | } // namespace |
| 64 | } // namespace federated |