blob: 0969bbdc0831e4549ec53043ac99c93affc7d8d5 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <base/bind.h>
#include <base/containers/flat_map.h>
#include <base/macros.h>
#include <base/optional.h>
#include <base/run_loop.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <mojo/public/cpp/bindings/interface_request.h>
#include <tensorflow/lite/model.h>

#include "ml/model_impl.h"
#include "ml/mojom/graph_executor.mojom.h"
#include "ml/mojom/model.mojom.h"
#include "ml/tensor_view.h"
#include "ml/test_utils.h"

25namespace ml {
26namespace {
27
28using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
29using ::chromeos::machine_learning::mojom::ExecuteResult;
30using ::chromeos::machine_learning::mojom::GraphExecutorPtr;
31using ::chromeos::machine_learning::mojom::Model;
32using ::chromeos::machine_learning::mojom::ModelPtr;
33using ::chromeos::machine_learning::mojom::ModelRequest;
34using ::chromeos::machine_learning::mojom::TensorPtr;
35using ::testing::ElementsAre;
36
37class ModelImplTest : public testing::Test {
38 protected:
39 // Metadata for the example model:
40 // A simple model that adds up two tensors. Inputs and outputs are 1x1 float
41 // tensors.
42 const std::string model_path_ =
Michael Martisc22823d2018-09-14 12:54:04 +100043 GetTestModelDir() + "mlservice-model-test_add-20180914.tflite";
Michael Martisa967f632018-08-10 10:39:00 +100044 const std::map<std::string, int> model_inputs_ = {{"x", 1}, {"y", 2}};
45 const std::map<std::string, int> model_outputs_ = {{"z", 0}};
46};
47
48// Test loading an invalid model.
49TEST_F(ModelImplTest, TestBadModel) {
50 // Pass nullptr instead of a valid model.
51 ModelPtr model_ptr;
52 const ModelImpl model_impl(model_inputs_, model_outputs_, nullptr /*model*/,
Honglin Yu6adafcd2019-07-22 13:48:11 +100053 mojo::MakeRequest(&model_ptr), "TestModel");
Michael Martisa967f632018-08-10 10:39:00 +100054 ASSERT_TRUE(model_ptr.is_bound());
55
56 // Ensure that creating a graph executor fails.
57 bool callback_done = false;
58 GraphExecutorPtr graph_executor_ptr;
59 model_ptr->CreateGraphExecutor(
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090060 mojo::MakeRequest(&graph_executor_ptr),
61 base::Bind(
62 [](bool* callback_done, const CreateGraphExecutorResult result) {
63 EXPECT_EQ(result,
64 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
65 *callback_done = true;
66 },
67 &callback_done));
Michael Martisa967f632018-08-10 10:39:00 +100068
69 base::RunLoop().RunUntilIdle();
70 ASSERT_TRUE(callback_done);
71}
72
73// Test loading the valid example model.
74TEST_F(ModelImplTest, TestExampleModel) {
75 // Read the example TF model from disk.
76 std::unique_ptr<tflite::FlatBufferModel> model =
77 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
78 ASSERT_NE(model.get(), nullptr);
79
80 // Create model object.
81 ModelPtr model_ptr;
82 const ModelImpl model_impl(model_inputs_, model_outputs_, std::move(model),
Honglin Yu6adafcd2019-07-22 13:48:11 +100083 mojo::MakeRequest(&model_ptr), "TestModel");
Michael Martisa967f632018-08-10 10:39:00 +100084 ASSERT_TRUE(model_ptr.is_bound());
85
86 // Create a graph executor.
87 bool cge_callback_done = false;
88 GraphExecutorPtr graph_executor_ptr;
89 model_ptr->CreateGraphExecutor(
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090090 mojo::MakeRequest(&graph_executor_ptr),
91 base::Bind(
92 [](bool* cge_callback_done, const CreateGraphExecutorResult result) {
93 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
94 *cge_callback_done = true;
95 },
96 &cge_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +100097
98 base::RunLoop().RunUntilIdle();
99 ASSERT_TRUE(cge_callback_done);
100
101 // Construct input/output for graph execution.
hscham3d0632f2019-12-11 15:58:57 +0900102 base::flat_map<std::string, TensorPtr> inputs;
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900103 inputs.emplace("x", NewTensor<double>({1}, {0.5}));
104 inputs.emplace("y", NewTensor<double>({1}, {0.25}));
105 std::vector<std::string> outputs({"z"});
Michael Martisa967f632018-08-10 10:39:00 +1000106
107 // Execute graph.
108 bool exe_callback_done = false;
109 graph_executor_ptr->Execute(
110 std::move(inputs), std::move(outputs),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900111 base::Bind(
112 [](bool* exe_callback_done, const ExecuteResult result,
113 base::Optional<std::vector<TensorPtr>> outputs) {
114 // Check that the inference succeeded and gives the expected number
115 // of outputs.
116 EXPECT_EQ(result, ExecuteResult::OK);
117 ASSERT_TRUE(outputs.has_value());
118 ASSERT_EQ(outputs->size(), 1);
Michael Martisa967f632018-08-10 10:39:00 +1000119
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900120 // Check that the output tensor has the right type and format.
121 const TensorView<double> out_tensor((*outputs)[0]);
122 EXPECT_TRUE(out_tensor.IsValidType());
123 EXPECT_TRUE(out_tensor.IsValidFormat());
Michael Martisa967f632018-08-10 10:39:00 +1000124
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900125 // Check the output tensor has the expected shape and values.
126 EXPECT_THAT(out_tensor.GetShape(), ElementsAre(1));
127 EXPECT_THAT(out_tensor.GetValues(), ElementsAre(0.75));
Michael Martisa967f632018-08-10 10:39:00 +1000128
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900129 *exe_callback_done = true;
130 },
131 &exe_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000132
133 base::RunLoop().RunUntilIdle();
134 ASSERT_TRUE(exe_callback_done);
135}
136
137TEST_F(ModelImplTest, TestGraphExecutorCleanup) {
138 // Read the example TF model from disk.
139 std::unique_ptr<tflite::FlatBufferModel> model =
140 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
141 ASSERT_NE(model.get(), nullptr);
142
143 // Create model object.
144 ModelPtr model_ptr;
145 const ModelImpl model_impl(model_inputs_, model_outputs_, std::move(model),
Honglin Yu6adafcd2019-07-22 13:48:11 +1000146 mojo::MakeRequest(&model_ptr), "TestModel");
Michael Martisa967f632018-08-10 10:39:00 +1000147 ASSERT_TRUE(model_ptr.is_bound());
148
149 // Create one graph executor.
150 bool cge1_callback_done = false;
151 GraphExecutorPtr graph_executor_1_ptr;
152 model_ptr->CreateGraphExecutor(
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900153 mojo::MakeRequest(&graph_executor_1_ptr),
154 base::Bind(
155 [](bool* cge1_callback_done, const CreateGraphExecutorResult result) {
156 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
157 *cge1_callback_done = true;
158 },
159 &cge1_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000160
161 base::RunLoop().RunUntilIdle();
162 ASSERT_TRUE(cge1_callback_done);
163 ASSERT_TRUE(graph_executor_1_ptr.is_bound());
164 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 1);
165
166 // Create another graph executor.
167 bool cge2_callback_done = false;
168 GraphExecutorPtr graph_executor_2_ptr;
169 model_ptr->CreateGraphExecutor(
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900170 mojo::MakeRequest(&graph_executor_2_ptr),
171 base::Bind(
172 [](bool* cge2_callback_done, const CreateGraphExecutorResult result) {
173 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
174 *cge2_callback_done = true;
175 },
176 &cge2_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000177
178 base::RunLoop().RunUntilIdle();
179 ASSERT_TRUE(cge2_callback_done);
180 ASSERT_TRUE(graph_executor_2_ptr.is_bound());
181 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 2);
182
183 // Destroy one graph executor.
184 graph_executor_1_ptr.reset();
185 base::RunLoop().RunUntilIdle();
186 ASSERT_TRUE(graph_executor_2_ptr.is_bound());
187 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 1);
188
189 // Destroy the other graph executor.
190 graph_executor_2_ptr.reset();
191 base::RunLoop().RunUntilIdle();
192 ASSERT_EQ(model_impl.num_graph_executors_for_testing(), 0);
193}
194
195} // namespace
196} // namespace ml