blob: bc0b4496591e90220194d024a417bd24cd39f78c [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <base/bind.h>
#include <base/containers/flat_map.h>
#include <base/macros.h>
#include <base/run_loop.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <mojo/public/cpp/bindings/remote.h>
#include <tensorflow/lite/model.h>

#include "ml/model_impl.h"
#include "ml/mojom/graph_executor.mojom.h"
#include "ml/mojom/model.mojom.h"
#include "ml/tensor_view.h"
#include "ml/test_utils.h"
Michael Martisa967f632018-08-10 10:39:00 +100024
25namespace ml {
26namespace {
27
28using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
29using ::chromeos::machine_learning::mojom::ExecuteResult;
Andrew Moylanb481af72020-07-09 15:22:00 +100030using ::chromeos::machine_learning::mojom::GraphExecutor;
Michael Martisa967f632018-08-10 10:39:00 +100031using ::chromeos::machine_learning::mojom::Model;
Michael Martisa967f632018-08-10 10:39:00 +100032using ::chromeos::machine_learning::mojom::TensorPtr;
33using ::testing::ElementsAre;
Andrew Moylan44c352f2020-11-04 15:19:46 +110034using ::testing::Eq;
Michael Martisa967f632018-08-10 10:39:00 +100035
36class ModelImplTest : public testing::Test {
37 protected:
38 // Metadata for the example model:
39 // A simple model that adds up two tensors. Inputs and outputs are 1x1 float
40 // tensors.
41 const std::string model_path_ =
Michael Martisc22823d2018-09-14 12:54:04 +100042 GetTestModelDir() + "mlservice-model-test_add-20180914.tflite";
Michael Martisa967f632018-08-10 10:39:00 +100043 const std::map<std::string, int> model_inputs_ = {{"x", 1}, {"y", 2}};
44 const std::map<std::string, int> model_outputs_ = {{"z", 0}};
45};
46
Andrew Moylan44c352f2020-11-04 15:19:46 +110047// Tests that AlignedModelData ensures that short strings have aligned .c_str().
48TEST(AlignedModelData, MaybeUnalignedInput) {
49 // Short strings can have unaligned .c_str() because they are stored directly
50 // inside the string struct rather than on the heap.
51 const std::string test_str = "short string";
52 std::string maybe_unaligned_str = test_str;
53 // Note: Whether `maybe_unaligned_str` *actually* has unaligned .c_str()
54 // depends on the particular impl of std::string. At the time of writing, it
55 // is indeed unaligned on e.g. amd64-generic.
56 const AlignedModelData aligned_model_data(std::move(maybe_unaligned_str));
57 // The .data() should now be aligned.
58 EXPECT_THAT(reinterpret_cast<std::uintptr_t>(aligned_model_data.data()) % 4,
59 Eq(0));
60 // The contents agree.
61 EXPECT_TRUE(
62 std::equal(test_str.begin(), test_str.end(), aligned_model_data.data()));
63}
64
Michael Martisa967f632018-08-10 10:39:00 +100065// Test loading an invalid model.
66TEST_F(ModelImplTest, TestBadModel) {
67 // Pass nullptr instead of a valid model.
Andrew Moylanb481af72020-07-09 15:22:00 +100068 mojo::Remote<Model> model;
Honglin Yuc0cef102020-01-17 15:26:01 +110069 ModelImpl::Create(model_inputs_, model_outputs_, nullptr /*model*/,
Andrew Moylanb481af72020-07-09 15:22:00 +100070 model.BindNewPipeAndPassReceiver(), "TestModel");
71 ASSERT_TRUE(model.is_bound());
Michael Martisa967f632018-08-10 10:39:00 +100072
73 // Ensure that creating a graph executor fails.
74 bool callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +100075 mojo::Remote<GraphExecutor> graph_executor;
76 model->CreateGraphExecutor(
77 graph_executor.BindNewPipeAndPassReceiver(),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +090078 base::Bind(
79 [](bool* callback_done, const CreateGraphExecutorResult result) {
80 EXPECT_EQ(result,
81 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
82 *callback_done = true;
83 },
84 &callback_done));
Michael Martisa967f632018-08-10 10:39:00 +100085
86 base::RunLoop().RunUntilIdle();
87 ASSERT_TRUE(callback_done);
88}
89
90// Test loading the valid example model.
91TEST_F(ModelImplTest, TestExampleModel) {
92 // Read the example TF model from disk.
Andrew Moylanb481af72020-07-09 15:22:00 +100093 std::unique_ptr<tflite::FlatBufferModel> tflite_model =
Michael Martisa967f632018-08-10 10:39:00 +100094 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
Andrew Moylanb481af72020-07-09 15:22:00 +100095 ASSERT_NE(tflite_model.get(), nullptr);
Michael Martisa967f632018-08-10 10:39:00 +100096
97 // Create model object.
Andrew Moylanb481af72020-07-09 15:22:00 +100098 mojo::Remote<Model> model;
99 ModelImpl::Create(model_inputs_, model_outputs_, std::move(tflite_model),
100 model.BindNewPipeAndPassReceiver(), "TestModel");
101 ASSERT_TRUE(model.is_bound());
Michael Martisa967f632018-08-10 10:39:00 +1000102
103 // Create a graph executor.
104 bool cge_callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +1000105 mojo::Remote<GraphExecutor> graph_executor;
106 model->CreateGraphExecutor(
107 graph_executor.BindNewPipeAndPassReceiver(),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900108 base::Bind(
109 [](bool* cge_callback_done, const CreateGraphExecutorResult result) {
110 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
111 *cge_callback_done = true;
112 },
113 &cge_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000114
115 base::RunLoop().RunUntilIdle();
116 ASSERT_TRUE(cge_callback_done);
117
118 // Construct input/output for graph execution.
hscham3d0632f2019-12-11 15:58:57 +0900119 base::flat_map<std::string, TensorPtr> inputs;
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900120 inputs.emplace("x", NewTensor<double>({1}, {0.5}));
121 inputs.emplace("y", NewTensor<double>({1}, {0.25}));
122 std::vector<std::string> outputs({"z"});
Michael Martisa967f632018-08-10 10:39:00 +1000123
124 // Execute graph.
125 bool exe_callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +1000126 graph_executor->Execute(
Michael Martisa967f632018-08-10 10:39:00 +1000127 std::move(inputs), std::move(outputs),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900128 base::Bind(
129 [](bool* exe_callback_done, const ExecuteResult result,
130 base::Optional<std::vector<TensorPtr>> outputs) {
131 // Check that the inference succeeded and gives the expected number
132 // of outputs.
133 EXPECT_EQ(result, ExecuteResult::OK);
134 ASSERT_TRUE(outputs.has_value());
135 ASSERT_EQ(outputs->size(), 1);
Michael Martisa967f632018-08-10 10:39:00 +1000136
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900137 // Check that the output tensor has the right type and format.
138 const TensorView<double> out_tensor((*outputs)[0]);
139 EXPECT_TRUE(out_tensor.IsValidType());
140 EXPECT_TRUE(out_tensor.IsValidFormat());
Michael Martisa967f632018-08-10 10:39:00 +1000141
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900142 // Check the output tensor has the expected shape and values.
143 EXPECT_THAT(out_tensor.GetShape(), ElementsAre(1));
144 EXPECT_THAT(out_tensor.GetValues(), ElementsAre(0.75));
Michael Martisa967f632018-08-10 10:39:00 +1000145
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900146 *exe_callback_done = true;
147 },
148 &exe_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000149
150 base::RunLoop().RunUntilIdle();
151 ASSERT_TRUE(exe_callback_done);
152}
153
154TEST_F(ModelImplTest, TestGraphExecutorCleanup) {
155 // Read the example TF model from disk.
Andrew Moylanb481af72020-07-09 15:22:00 +1000156 std::unique_ptr<tflite::FlatBufferModel> tflite_model =
Michael Martisa967f632018-08-10 10:39:00 +1000157 tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
Andrew Moylanb481af72020-07-09 15:22:00 +1000158 ASSERT_NE(tflite_model.get(), nullptr);
Michael Martisa967f632018-08-10 10:39:00 +1000159
160 // Create model object.
Andrew Moylanb481af72020-07-09 15:22:00 +1000161 mojo::Remote<Model> model;
Honglin Yuc0cef102020-01-17 15:26:01 +1100162 const ModelImpl* model_impl =
Andrew Moylanb481af72020-07-09 15:22:00 +1000163 ModelImpl::Create(model_inputs_, model_outputs_, std::move(tflite_model),
164 model.BindNewPipeAndPassReceiver(), "TestModel");
165 ASSERT_TRUE(model.is_bound());
Michael Martisa967f632018-08-10 10:39:00 +1000166
167 // Create one graph executor.
168 bool cge1_callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +1000169 mojo::Remote<GraphExecutor> graph_executor_1;
170 model->CreateGraphExecutor(
171 graph_executor_1.BindNewPipeAndPassReceiver(),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900172 base::Bind(
173 [](bool* cge1_callback_done, const CreateGraphExecutorResult result) {
174 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
175 *cge1_callback_done = true;
176 },
177 &cge1_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000178
179 base::RunLoop().RunUntilIdle();
180 ASSERT_TRUE(cge1_callback_done);
Andrew Moylanb481af72020-07-09 15:22:00 +1000181 ASSERT_TRUE(graph_executor_1.is_bound());
Honglin Yuc0cef102020-01-17 15:26:01 +1100182 ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 1);
Michael Martisa967f632018-08-10 10:39:00 +1000183
184 // Create another graph executor.
185 bool cge2_callback_done = false;
Andrew Moylanb481af72020-07-09 15:22:00 +1000186 mojo::Remote<GraphExecutor> graph_executor_2;
187 model->CreateGraphExecutor(
188 graph_executor_2.BindNewPipeAndPassReceiver(),
Hidehiko Abe7ac22bb2018-11-08 00:29:25 +0900189 base::Bind(
190 [](bool* cge2_callback_done, const CreateGraphExecutorResult result) {
191 EXPECT_EQ(result, CreateGraphExecutorResult::OK);
192 *cge2_callback_done = true;
193 },
194 &cge2_callback_done));
Michael Martisa967f632018-08-10 10:39:00 +1000195
196 base::RunLoop().RunUntilIdle();
197 ASSERT_TRUE(cge2_callback_done);
Andrew Moylanb481af72020-07-09 15:22:00 +1000198 ASSERT_TRUE(graph_executor_2.is_bound());
Honglin Yuc0cef102020-01-17 15:26:01 +1100199 ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 2);
Michael Martisa967f632018-08-10 10:39:00 +1000200
201 // Destroy one graph executor.
Andrew Moylanb481af72020-07-09 15:22:00 +1000202 graph_executor_1.reset();
Michael Martisa967f632018-08-10 10:39:00 +1000203 base::RunLoop().RunUntilIdle();
Andrew Moylanb481af72020-07-09 15:22:00 +1000204 ASSERT_TRUE(graph_executor_2.is_bound());
Honglin Yuc0cef102020-01-17 15:26:01 +1100205 ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 1);
Michael Martisa967f632018-08-10 10:39:00 +1000206
207 // Destroy the other graph executor.
Andrew Moylanb481af72020-07-09 15:22:00 +1000208 graph_executor_2.reset();
Michael Martisa967f632018-08-10 10:39:00 +1000209 base::RunLoop().RunUntilIdle();
Honglin Yuc0cef102020-01-17 15:26:01 +1100210 ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 0);
Michael Martisa967f632018-08-10 10:39:00 +1000211}
212
213} // namespace
214} // namespace ml