blob: 6aa0bb5fa5e777f1f4a6570b71f97f114ae1b4d6 [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <memory>
6#include <string>
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +09007#include <unordered_map>
Michael Martisa967f632018-08-10 10:39:00 +10008#include <utility>
9#include <vector>
10
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090011#include <base/bind.h>
Michael Martisa967f632018-08-10 10:39:00 +100012#include <base/macros.h>
13#include <base/run_loop.h>
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090014#include <brillo/bind_lambda.h>
Michael Martisa967f632018-08-10 10:39:00 +100015#include <gmock/gmock.h>
16#include <gtest/gtest.h>
17#include <mojo/public/cpp/bindings/interface_request.h>
18#include <tensorflow/contrib/lite/model.h>
19
20#include "ml/model_impl.h"
Hidehiko Abe8ab64a62018-09-19 00:04:39 +090021#include "ml/mojom/graph_executor.mojom.h"
22#include "ml/mojom/model.mojom.h"
Michael Martisa967f632018-08-10 10:39:00 +100023#include "ml/tensor_view.h"
24#include "ml/test_utils.h"
Michael Martisa967f632018-08-10 10:39:00 +100025
26namespace ml {
27namespace {
28
29using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
30using ::chromeos::machine_learning::mojom::ExecuteResult;
31using ::chromeos::machine_learning::mojom::GraphExecutorPtr;
32using ::chromeos::machine_learning::mojom::Model;
33using ::chromeos::machine_learning::mojom::ModelPtr;
34using ::chromeos::machine_learning::mojom::ModelRequest;
35using ::chromeos::machine_learning::mojom::TensorPtr;
36using ::testing::ElementsAre;
37
38class ModelImplTest : public testing::Test {
39 protected:
40 // Metadata for the example model:
41 // A simple model that adds up two tensors. Inputs and outputs are 1x1 float
42 // tensors.
43 const std::string model_path_ =
Michael Martisc22823d2018-09-14 12:54:04 +100044 GetTestModelDir() + "mlservice-model-test_add-20180914.tflite";
Michael Martisa967f632018-08-10 10:39:00 +100045 const std::map<std::string, int> model_inputs_ = {{"x", 1}, {"y", 2}};
46 const std::map<std::string, int> model_outputs_ = {{"z", 0}};
47};
48
49// Test loading an invalid model.
50TEST_F(ModelImplTest, TestBadModel) {
51 // Pass nullptr instead of a valid model.
52 ModelPtr model_ptr;
53 const ModelImpl model_impl(model_inputs_, model_outputs_, nullptr /*model*/,
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090054 mojo::MakeRequest(&model_ptr));
Michael Martisa967f632018-08-10 10:39:00 +100055 ASSERT_TRUE(model_ptr.is_bound());
56
57 // Ensure that creating a graph executor fails.
58 bool callback_done = false;
59 GraphExecutorPtr graph_executor_ptr;
60 model_ptr->CreateGraphExecutor(
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090061 mojo::MakeRequest(&graph_executor_ptr),
62 base::Bind(
63 [](bool* callback_done, const CreateGraphExecutorResult result) {
64 EXPECT_EQ(result,
65 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
66 *callback_done = true;
67 },
68 &callback_done));
Michael Martisa967f632018-08-10 10:39:00 +100069
70 base::RunLoop().RunUntilIdle();
71 ASSERT_TRUE(callback_done);
72}
73
74// Test loading the valid example model.
75TEST_F(ModelImplTest, TestExampleModel) {
76 // Read the example TF model from disk.
77 std::unique_ptr<tflite::FlatBufferModel> model =
78 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
79 ASSERT_NE(model.get(), nullptr);
80
81 // Create model object.
82 ModelPtr model_ptr;
83 const ModelImpl model_impl(model_inputs_, model_outputs_, std::move(model),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090084 mojo::MakeRequest(&model_ptr));
Michael Martisa967f632018-08-10 10:39:00 +100085 ASSERT_TRUE(model_ptr.is_bound());
86
87 // Create a graph executor.
88 bool cge_callback_done = false;
89 GraphExecutorPtr graph_executor_ptr;
90 model_ptr->CreateGraphExecutor(
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090091 mojo::MakeRequest(&graph_executor_ptr),
92 base::Bind(
93 [](bool* cge_callback_done, const CreateGraphExecutorResult result) {
94 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
95 *cge_callback_done = true;
96 },
97 &cge_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +100098
99 base::RunLoop().RunUntilIdle();
100 ASSERT_TRUE(cge_callback_done);
101
102 // Construct input/output for graph execution.
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900103 std::unordered_map<std::string, TensorPtr> inputs;
104 inputs.emplace("x", NewTensor<double>({1}, {0.5}));
105 inputs.emplace("y", NewTensor<double>({1}, {0.25}));
106 std::vector<std::string> outputs({"z"});
Michael Martisa967f632018-08-10 10:39:00 +1000107
108 // Execute graph.
109 bool exe_callback_done = false;
110 graph_executor_ptr->Execute(
111 std::move(inputs), std::move(outputs),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900112 base::Bind(
113 [](bool* exe_callback_done, const ExecuteResult result,
114 base::Optional<std::vector<TensorPtr>> outputs) {
115 // Check that the inference succeeded and gives the expected number
116 // of outputs.
117 EXPECT_EQ(result, ExecuteResult::OK);
118 ASSERT_TRUE(outputs.has_value());
119 ASSERT_EQ(outputs->size(), 1);
Michael Martisa967f632018-08-10 10:39:00 +1000120
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900121 // Check that the output tensor has the right type and format.
122 const TensorView<double> out_tensor((*outputs)[0]);
123 EXPECT_TRUE(out_tensor.IsValidType());
124 EXPECT_TRUE(out_tensor.IsValidFormat());
Michael Martisa967f632018-08-10 10:39:00 +1000125
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900126 // Check the output tensor has the expected shape and values.
127 EXPECT_THAT(out_tensor.GetShape(), ElementsAre(1));
128 EXPECT_THAT(out_tensor.GetValues(), ElementsAre(0.75));
Michael Martisa967f632018-08-10 10:39:00 +1000129
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900130 *exe_callback_done = true;
131 },
132 &exe_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000133
134 base::RunLoop().RunUntilIdle();
135 ASSERT_TRUE(exe_callback_done);
136}
137
138TEST_F(ModelImplTest, TestGraphExecutorCleanup) {
139 // Read the example TF model from disk.
140 std::unique_ptr<tflite::FlatBufferModel> model =
141 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
142 ASSERT_NE(model.get(), nullptr);
143
144 // Create model object.
145 ModelPtr model_ptr;
146 const ModelImpl model_impl(model_inputs_, model_outputs_, std::move(model),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900147 mojo::MakeRequest(&model_ptr));
Michael Martisa967f632018-08-10 10:39:00 +1000148 ASSERT_TRUE(model_ptr.is_bound());
149
150 // Create one graph executor.
151 bool cge1_callback_done = false;
152 GraphExecutorPtr graph_executor_1_ptr;
153 model_ptr->CreateGraphExecutor(
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900154 mojo::MakeRequest(&graph_executor_1_ptr),
155 base::Bind(
156 [](bool* cge1_callback_done, const CreateGraphExecutorResult result) {
157 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
158 *cge1_callback_done = true;
159 },
160 &cge1_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000161
162 base::RunLoop().RunUntilIdle();
163 ASSERT_TRUE(cge1_callback_done);
164 ASSERT_TRUE(graph_executor_1_ptr.is_bound());
165 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 1);
166
167 // Create another graph executor.
168 bool cge2_callback_done = false;
169 GraphExecutorPtr graph_executor_2_ptr;
170 model_ptr->CreateGraphExecutor(
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900171 mojo::MakeRequest(&graph_executor_2_ptr),
172 base::Bind(
173 [](bool* cge2_callback_done, const CreateGraphExecutorResult result) {
174 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
175 *cge2_callback_done = true;
176 },
177 &cge2_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000178
179 base::RunLoop().RunUntilIdle();
180 ASSERT_TRUE(cge2_callback_done);
181 ASSERT_TRUE(graph_executor_2_ptr.is_bound());
182 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 2);
183
184 // Destroy one graph executor.
185 graph_executor_1_ptr.reset();
186 base::RunLoop().RunUntilIdle();
187 ASSERT_TRUE(graph_executor_2_ptr.is_bound());
188 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 1);
189
190 // Destroy the other graph executor.
191 graph_executor_2_ptr.reset();
192 base::RunLoop().RunUntilIdle();
193 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 0);
194}
195
196} // namespace
197} // namespace ml