blob: df096f46132ee7108a416bb1e33711c63c94c6df [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <base/macros.h>
#include <base/run_loop.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <mojo/public/cpp/bindings/interface_request.h>
#include <tensorflow/contrib/lite/model.h>

#include "ml/model_impl.h"
#include "ml/tensor_view.h"
#include "ml/test_utils.h"
#include "mojom/graph_executor.mojom.h"
#include "mojom/model.mojom.h"
22
23namespace ml {
24namespace {
25
26using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
27using ::chromeos::machine_learning::mojom::ExecuteResult;
28using ::chromeos::machine_learning::mojom::GraphExecutorPtr;
29using ::chromeos::machine_learning::mojom::Model;
30using ::chromeos::machine_learning::mojom::ModelPtr;
31using ::chromeos::machine_learning::mojom::ModelRequest;
32using ::chromeos::machine_learning::mojom::TensorPtr;
33using ::testing::ElementsAre;
34
35class ModelImplTest : public testing::Test {
36 protected:
37 // Metadata for the example model:
38 // A simple model that adds up two tensors. Inputs and outputs are 1x1 float
39 // tensors.
40 const std::string model_path_ =
41 getenv("SRC") + std::string("/testdata/add.tflite");
42 const std::map<std::string, int> model_inputs_ = {{"x", 1}, {"y", 2}};
43 const std::map<std::string, int> model_outputs_ = {{"z", 0}};
44};
45
46// Test loading an invalid model.
47TEST_F(ModelImplTest, TestBadModel) {
48 // Pass nullptr instead of a valid model.
49 ModelPtr model_ptr;
50 const ModelImpl model_impl(model_inputs_, model_outputs_, nullptr /*model*/,
51 mojo::GetProxy(&model_ptr));
52 ASSERT_TRUE(model_ptr.is_bound());
53
54 // Ensure that creating a graph executor fails.
55 bool callback_done = false;
56 GraphExecutorPtr graph_executor_ptr;
57 model_ptr->CreateGraphExecutor(
58 mojo::GetProxy(&graph_executor_ptr),
59 [&callback_done](const CreateGraphExecutorResult result) {
60 EXPECT_EQ(result,
61 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
62 callback_done = true;
63 });
64
65 base::RunLoop().RunUntilIdle();
66 ASSERT_TRUE(callback_done);
67}
68
69// Test loading the valid example model.
70TEST_F(ModelImplTest, TestExampleModel) {
71 // Read the example TF model from disk.
72 std::unique_ptr<tflite::FlatBufferModel> model =
73 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
74 ASSERT_NE(model.get(), nullptr);
75
76 // Create model object.
77 ModelPtr model_ptr;
78 const ModelImpl model_impl(model_inputs_, model_outputs_, std::move(model),
79 mojo::GetProxy(&model_ptr));
80 ASSERT_TRUE(model_ptr.is_bound());
81
82 // Create a graph executor.
83 bool cge_callback_done = false;
84 GraphExecutorPtr graph_executor_ptr;
85 model_ptr->CreateGraphExecutor(
86 mojo::GetProxy(&graph_executor_ptr),
87 [&cge_callback_done](const CreateGraphExecutorResult result) {
88 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
89 cge_callback_done = true;
90 });
91
92 base::RunLoop().RunUntilIdle();
93 ASSERT_TRUE(cge_callback_done);
94
95 // Construct input/output for graph execution.
96 mojo::Map<mojo::String, TensorPtr> inputs;
97 inputs.insert("x", NewTensor<double>({1}, {0.5}));
98 inputs.insert("y", NewTensor<double>({1}, {0.25}));
99 mojo::Array<mojo::String> outputs({"z"});
100
101 // Execute graph.
102 bool exe_callback_done = false;
103 graph_executor_ptr->Execute(
104 std::move(inputs), std::move(outputs),
105 [&exe_callback_done](const ExecuteResult result,
106 const mojo::Array<TensorPtr> outputs) {
107 // Check that the inference succeeded and gives the expected number of
108 // outputs.
109 EXPECT_EQ(result, ExecuteResult::OK);
110 ASSERT_EQ(outputs.size(), 1);
111
112 // Check that the output tensor has the right type and format.
113 const TensorView<double> out_tensor(outputs[0]);
114 EXPECT_TRUE(out_tensor.IsValidType());
115 EXPECT_TRUE(out_tensor.IsValidFormat());
116
117 // Check the output tensor has the expected shape and values.
118 EXPECT_THAT(out_tensor.GetShape().storage(), ElementsAre(1));
119 EXPECT_THAT(out_tensor.GetValues().storage(), ElementsAre(0.75));
120
121 exe_callback_done = true;
122 });
123
124 base::RunLoop().RunUntilIdle();
125 ASSERT_TRUE(exe_callback_done);
126}
127
128TEST_F(ModelImplTest, TestGraphExecutorCleanup) {
129 // Read the example TF model from disk.
130 std::unique_ptr<tflite::FlatBufferModel> model =
131 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
132 ASSERT_NE(model.get(), nullptr);
133
134 // Create model object.
135 ModelPtr model_ptr;
136 const ModelImpl model_impl(model_inputs_, model_outputs_, std::move(model),
137 mojo::GetProxy(&model_ptr));
138 ASSERT_TRUE(model_ptr.is_bound());
139
140 // Create one graph executor.
141 bool cge1_callback_done = false;
142 GraphExecutorPtr graph_executor_1_ptr;
143 model_ptr->CreateGraphExecutor(
144 mojo::GetProxy(&graph_executor_1_ptr),
145 [&cge1_callback_done](const CreateGraphExecutorResult result) {
146 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
147 cge1_callback_done = true;
148 });
149
150 base::RunLoop().RunUntilIdle();
151 ASSERT_TRUE(cge1_callback_done);
152 ASSERT_TRUE(graph_executor_1_ptr.is_bound());
153 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 1);
154
155 // Create another graph executor.
156 bool cge2_callback_done = false;
157 GraphExecutorPtr graph_executor_2_ptr;
158 model_ptr->CreateGraphExecutor(
159 mojo::GetProxy(&graph_executor_2_ptr),
160 [&cge2_callback_done](const CreateGraphExecutorResult result) {
161 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
162 cge2_callback_done = true;
163 });
164
165 base::RunLoop().RunUntilIdle();
166 ASSERT_TRUE(cge2_callback_done);
167 ASSERT_TRUE(graph_executor_2_ptr.is_bound());
168 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 2);
169
170 // Destroy one graph executor.
171 graph_executor_1_ptr.reset();
172 base::RunLoop().RunUntilIdle();
173 ASSERT_TRUE(graph_executor_2_ptr.is_bound());
174 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 1);
175
176 // Destroy the other graph executor.
177 graph_executor_2_ptr.reset();
178 base::RunLoop().RunUntilIdle();
179 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 0);
180}
181
182} // namespace
183} // namespace ml