blob: c5ae65fc19951ddc7df54cc94232448fad73b053 [file] [log] [blame]
alanlxl30f15bd2020-08-11 21:26:12 +10001// Copyright 2020 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "federated/utils.h"
6
7#include <memory>
8
9#include <gmock/gmock.h>
10#include <gtest/gtest.h>
11
12#include "chrome/knowledge/federated/example.pb.h"
13#include "chrome/knowledge/federated/feature.pb.h"
14#include "federated/mojom/example.mojom.h"
15#include "federated/test_utils.h"
16
17namespace federated {
18namespace {
19using chromeos::federated::mojom::Example;
20using chromeos::federated::mojom::ExamplePtr;
21using chromeos::federated::mojom::Features;
22using testing::ElementsAre;
23
24TEST(UtilsTest, ConvertToTensorFlowExampleProto) {
25 auto example = CreateExamplePtr();
26
27 tensorflow::Example tf_example_converted =
28 ConvertToTensorFlowExampleProto(example);
29 const auto& tf_feature_map = tf_example_converted.features().feature();
30
31 EXPECT_EQ(tf_feature_map.size(), 4);
32
33 EXPECT_TRUE(tf_feature_map.contains("int_feature1"));
34 const auto& int_feature1 = tf_feature_map.at("int_feature1");
35 EXPECT_TRUE(int_feature1.has_int64_list() && !int_feature1.has_float_list() &&
36 !int_feature1.has_bytes_list());
37 EXPECT_THAT(int_feature1.int64_list().value(), ElementsAre(1, 2, 3, 4, 5));
38
39 EXPECT_TRUE(tf_feature_map.contains("int_feature2"));
40 const auto& int_feature2 = tf_feature_map.at("int_feature2");
41 EXPECT_TRUE(int_feature2.has_int64_list() && !int_feature2.has_float_list() &&
42 !int_feature2.has_bytes_list());
43 EXPECT_THAT(int_feature2.int64_list().value(),
44 ElementsAre(10, 20, 30, 40, 50));
45
46 EXPECT_TRUE(tf_feature_map.contains("float_feature1"));
47 const auto& float_feature = tf_feature_map.at("float_feature1");
48 EXPECT_TRUE(!float_feature.has_int64_list() &&
49 float_feature.has_float_list() &&
50 !float_feature.has_bytes_list());
51 EXPECT_THAT(float_feature.float_list().value(),
52 ElementsAre(1.1, 2.1, 3.1, 4.1, 5.1));
53
54 EXPECT_TRUE(tf_feature_map.contains("string_feature1"));
55 const auto& string_feature = tf_feature_map.at("string_feature1");
56 EXPECT_TRUE(!string_feature.has_int64_list() &&
57 !string_feature.has_float_list() &&
58 string_feature.has_bytes_list());
59 EXPECT_THAT(string_feature.bytes_list().value(),
60 ElementsAre("abc", "123", "xyz"));
61}
62
63} // namespace
64} // namespace federated