blob: 36bf24b188b78b4948ddbaf9fcd96b3c3c8bb1a6 [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <base/bind.h>
#include <base/containers/flat_map.h>
#include <base/macros.h>
#include <base/optional.h>
#include <base/run_loop.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <mojo/public/cpp/bindings/remote.h>
#include <tensorflow/lite/model.h>

#include "ml/model_impl.h"
#include "ml/mojom/graph_executor.mojom.h"
#include "ml/mojom/model.mojom.h"
#include "ml/tensor_view.h"
#include "ml/test_utils.h"
Michael Martisa967f632018-08-10 10:39:00 +100024
25namespace ml {
26namespace {
27
28using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
29using ::chromeos::machine_learning::mojom::ExecuteResult;
Andrew Moylanb481af72020-07-09 15:22:00 +100030using ::chromeos::machine_learning::mojom::GraphExecutor;
Michael Martisa967f632018-08-10 10:39:00 +100031using ::chromeos::machine_learning::mojom::Model;
Michael Martisa967f632018-08-10 10:39:00 +100032using ::chromeos::machine_learning::mojom::TensorPtr;
33using ::testing::ElementsAre;
34
35class ModelImplTest : public testing::Test {
36 protected:
37 // Metadata for the example model:
38 // A simple model that adds up two tensors. Inputs and outputs are 1x1 float
39 // tensors.
40 const std::string model_path_ =
Michael Martisc22823d2018-09-14 12:54:04 +100041 GetTestModelDir() + "mlservice-model-test_add-20180914.tflite";
Michael Martisa967f632018-08-10 10:39:00 +100042 const std::map<std::string, int> model_inputs_ = {{"x", 1}, {"y", 2}};
43 const std::map<std::string, int> model_outputs_ = {{"z", 0}};
44};
45
46// Test loading an invalid model.
47TEST_F(ModelImplTest, TestBadModel) {
48 // Pass nullptr instead of a valid model.
Andrew Moylanb481af72020-07-09 15:22:00 +100049 mojo::Remote<Model> model;
Honglin Yuc0cef102020-01-17 15:26:01 +110050 ModelImpl::Create(model_inputs_, model_outputs_, nullptr /*model*/,
Andrew Moylanb481af72020-07-09 15:22:00 +100051 model.BindNewPipeAndPassReceiver(), "TestModel");
52 ASSERT_TRUE(model.is_bound());
Michael Martisa967f632018-08-10 10:39:00 +100053
54 // Ensure that creating a graph executor fails.
55 bool callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +100056 mojo::Remote<GraphExecutor> graph_executor;
57 model->CreateGraphExecutor(
58 graph_executor.BindNewPipeAndPassReceiver(),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090059 base::Bind(
60 [](bool* callback_done, const CreateGraphExecutorResult result) {
61 EXPECT_EQ(result,
62 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
63 *callback_done = true;
64 },
65 &callback_done));
Michael Martisa967f632018-08-10 10:39:00 +100066
67 base::RunLoop().RunUntilIdle();
68 ASSERT_TRUE(callback_done);
69}
70
71// Test loading the valid example model.
72TEST_F(ModelImplTest, TestExampleModel) {
73 // Read the example TF model from disk.
Andrew Moylanb481af72020-07-09 15:22:00 +100074 std::unique_ptr<tflite::FlatBufferModel> tflite_model =
Michael Martisa967f632018-08-10 10:39:00 +100075 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
Andrew Moylanb481af72020-07-09 15:22:00 +100076 ASSERT_NE(tflite_model.get(), nullptr);
Michael Martisa967f632018-08-10 10:39:00 +100077
78 // Create model object.
Andrew Moylanb481af72020-07-09 15:22:00 +100079 mojo::Remote<Model> model;
80 ModelImpl::Create(model_inputs_, model_outputs_, std::move(tflite_model),
81 model.BindNewPipeAndPassReceiver(), "TestModel");
82 ASSERT_TRUE(model.is_bound());
Michael Martisa967f632018-08-10 10:39:00 +100083
84 // Create a graph executor.
85 bool cge_callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +100086 mojo::Remote<GraphExecutor> graph_executor;
87 model->CreateGraphExecutor(
88 graph_executor.BindNewPipeAndPassReceiver(),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090089 base::Bind(
90 [](bool* cge_callback_done, const CreateGraphExecutorResult result) {
91 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
92 *cge_callback_done = true;
93 },
94 &cge_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +100095
96 base::RunLoop().RunUntilIdle();
97 ASSERT_TRUE(cge_callback_done);
98
99 // Construct input/output for graph execution.
hscham3d0632f2019-12-11 15:58:57 +0900100 base::flat_map<std::string, TensorPtr> inputs;
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900101 inputs.emplace("x", NewTensor<double>({1}, {0.5}));
102 inputs.emplace("y", NewTensor<double>({1}, {0.25}));
103 std::vector<std::string> outputs({"z"});
Michael Martisa967f632018-08-10 10:39:00 +1000104
105 // Execute graph.
106 bool exe_callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +1000107 graph_executor->Execute(
Michael Martisa967f632018-08-10 10:39:00 +1000108 std::move(inputs), std::move(outputs),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900109 base::Bind(
110 [](bool* exe_callback_done, const ExecuteResult result,
111 base::Optional<std::vector<TensorPtr>> outputs) {
112 // Check that the inference succeeded and gives the expected number
113 // of outputs.
114 EXPECT_EQ(result, ExecuteResult::OK);
115 ASSERT_TRUE(outputs.has_value());
116 ASSERT_EQ(outputs->size(), 1);
Michael Martisa967f632018-08-10 10:39:00 +1000117
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900118 // Check that the output tensor has the right type and format.
119 const TensorView<double> out_tensor((*outputs)[0]);
120 EXPECT_TRUE(out_tensor.IsValidType());
121 EXPECT_TRUE(out_tensor.IsValidFormat());
Michael Martisa967f632018-08-10 10:39:00 +1000122
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900123 // Check the output tensor has the expected shape and values.
124 EXPECT_THAT(out_tensor.GetShape(), ElementsAre(1));
125 EXPECT_THAT(out_tensor.GetValues(), ElementsAre(0.75));
Michael Martisa967f632018-08-10 10:39:00 +1000126
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900127 *exe_callback_done = true;
128 },
129 &exe_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000130
131 base::RunLoop().RunUntilIdle();
132 ASSERT_TRUE(exe_callback_done);
133}
134
135TEST_F(ModelImplTest, TestGraphExecutorCleanup) {
136 // Read the example TF model from disk.
Andrew Moylanb481af72020-07-09 15:22:00 +1000137 std::unique_ptr<tflite::FlatBufferModel> tflite_model =
Michael Martisa967f632018-08-10 10:39:00 +1000138 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
Andrew Moylanb481af72020-07-09 15:22:00 +1000139 ASSERT_NE(tflite_model.get(), nullptr);
Michael Martisa967f632018-08-10 10:39:00 +1000140
141 // Create model object.
Andrew Moylanb481af72020-07-09 15:22:00 +1000142 mojo::Remote<Model> model;
Honglin Yuc0cef102020-01-17 15:26:01 +1100143 const ModelImpl* model_impl =
Andrew Moylanb481af72020-07-09 15:22:00 +1000144 ModelImpl::Create(model_inputs_, model_outputs_, std::move(tflite_model),
145 model.BindNewPipeAndPassReceiver(), "TestModel");
146 ASSERT_TRUE(model.is_bound());
Michael Martisa967f632018-08-10 10:39:00 +1000147
148 // Create one graph executor.
149 bool cge1_callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +1000150 mojo::Remote<GraphExecutor> graph_executor_1;
151 model->CreateGraphExecutor(
152 graph_executor_1.BindNewPipeAndPassReceiver(),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900153 base::Bind(
154 [](bool* cge1_callback_done, const CreateGraphExecutorResult result) {
155 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
156 *cge1_callback_done = true;
157 },
158 &cge1_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000159
160 base::RunLoop().RunUntilIdle();
161 ASSERT_TRUE(cge1_callback_done);
Andrew Moylanb481af72020-07-09 15:22:00 +1000162 ASSERT_TRUE(graph_executor_1.is_bound());
Honglin Yuc0cef102020-01-17 15:26:01 +1100163 ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 1);
Michael Martisa967f632018-08-10 10:39:00 +1000164
165 // Create another graph executor.
166 bool cge2_callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +1000167 mojo::Remote<GraphExecutor> graph_executor_2;
168 model->CreateGraphExecutor(
169 graph_executor_2.BindNewPipeAndPassReceiver(),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900170 base::Bind(
171 [](bool* cge2_callback_done, const CreateGraphExecutorResult result) {
172 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
173 *cge2_callback_done = true;
174 },
175 &cge2_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000176
177 base::RunLoop().RunUntilIdle();
178 ASSERT_TRUE(cge2_callback_done);
Andrew Moylanb481af72020-07-09 15:22:00 +1000179 ASSERT_TRUE(graph_executor_2.is_bound());
Honglin Yuc0cef102020-01-17 15:26:01 +1100180 ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 2);
Michael Martisa967f632018-08-10 10:39:00 +1000181
182 // Destroy one graph executor.
Andrew Moylanb481af72020-07-09 15:22:00 +1000183 graph_executor_1.reset();
Michael Martisa967f632018-08-10 10:39:00 +1000184 base::RunLoop().RunUntilIdle();
Andrew Moylanb481af72020-07-09 15:22:00 +1000185 ASSERT_TRUE(graph_executor_2.is_bound());
Honglin Yuc0cef102020-01-17 15:26:01 +1100186 ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 1);
Michael Martisa967f632018-08-10 10:39:00 +1000187
188 // Destroy the other graph executor.
Andrew Moylanb481af72020-07-09 15:22:00 +1000189 graph_executor_2.reset();
Michael Martisa967f632018-08-10 10:39:00 +1000190 base::RunLoop().RunUntilIdle();
Honglin Yuc0cef102020-01-17 15:26:01 +1100191 ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 0);
Michael Martisa967f632018-08-10 10:39:00 +1000192}
193
194} // namespace
195} // namespace ml