blob: 24e5554946f004a5257b3191c04a55981636a430 [file] [log] [blame]
Charles Zhao6dde75b2020-09-15 14:38:28 +10001// Copyright 2020 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <memory>
6#include <string>
7#include <utility>
8
Qijiang Fan713061e2021-03-08 15:45:12 +09009#include <base/check.h>
Charles Zhao6dde75b2020-09-15 14:38:28 +100010#include <base/files/file.h>
11#include <base/files/file_path.h>
12#include <base/files/file_util.h>
13#include <base/files/scoped_temp_dir.h>
14#include <google/protobuf/text_format.h>
15#include <gtest/gtest.h>
16
17#include "ml/benchmark.h"
18#include "ml/benchmark.pb.h"
19#include "ml/mojom/model.mojom.h"
20#include "proto/benchmark_config.pb.h"
21
22namespace ml {
23
24using ::chrome::ml_benchmark::BenchmarkResults;
25using ::chrome::ml_benchmark::BenchmarkReturnStatus;
26using ::chrome::ml_benchmark::CrOSBenchmarkConfig;
27using ::google::protobuf::TextFormat;
28
// Test model: the on-device .tflite file whose bytes are embedded into the
// FlatBufferModelSpecProto written by the fixture.
// NOTE(review): despite the "SmartDim20181115" name, this path points at the
// "test_add" model (the input/output fixture below expects 0.5 and 0.25 to
// yield 0.75). Consider renaming the constant to match.
constexpr char kSmartDim20181115ModelFile[] =
    "/opt/google/chrome/ml_models/mlservice-model-test_add-20180914.tflite";
32
// Text-format FlatBufferModelSpecProto fields describing the model's
// interface: inputs "x" and "y" and output "z", each with an index and dims.
// The index values presumably refer to the model's TFLite tensor indices;
// confirm against the model.
constexpr char kModelProtoText[] = R"(
  required_inputs: {
    key: "x"
    value: {
      index: 1
      dims: [1]
    }
  }
  required_inputs: {
    key: "y"
    value: {
      index: 2
      dims: [1]
    }
  }
  required_outputs: {
    key: "z"
    value: {
      index: 0
      dims: [1]
    }
  }
)";
// Text-format ExpectedInputOutput: the input features (x = 0.5, y = 0.25)
// and the expected output feature z = 0.75. The fixture may overwrite the
// expected "z" value before writing this proto to disk.
constexpr char kInputOutputText[] = R"(
  input: {
    features: {
      feature: {
        key: "x"
        value: {
          float_list: { value:[ 0.5 ] }
        }
      }
      feature: {
        key: "y"
        value: {
          float_list: { value:[ 0.25 ] }
        }
      }
    }
  }
  expected_output:{
    features: {
      feature: {
        key: "z"
        value: {
          float_list: { value: [ 0.75 ] }
        }
      }
    }
  }
)";
85
86class MlBenchmarkTest : public ::testing::Test {
87 public:
88 MlBenchmarkTest() {
89 // Set benchmark_config_;
90 CHECK(temp_dir_.CreateUniqueTempDir());
91 const base::FilePath tflite_model_filepath =
92 temp_dir_.GetPath().Append("model.pb");
Charles Zhao5dbfd062020-10-26 19:53:42 +110093 input_output_filepath_ = temp_dir_.GetPath().Append("input_output.pb");
Charles Zhao6dde75b2020-09-15 14:38:28 +100094 TfliteBenchmarkConfig tflite_config;
95 tflite_config.set_tflite_model_filepath(tflite_model_filepath.value());
Charles Zhao5dbfd062020-10-26 19:53:42 +110096 tflite_config.set_input_output_filepath(input_output_filepath_.value());
Charles Zhao6dde75b2020-09-15 14:38:28 +100097 tflite_config.set_num_runs(100);
98 TextFormat::PrintToString(tflite_config,
99 benchmark_config_.mutable_driver_config());
100
101 // Set FlatBufferModelSpecProto;
102 FlatBufferModelSpecProto model_proto;
103 CHECK(TextFormat::ParseFromString(kModelProtoText, &model_proto));
104 base::ReadFileToString(base::FilePath(kSmartDim20181115ModelFile),
105 model_proto.mutable_model_string());
106 const std::string model_content = model_proto.SerializeAsString();
107 base::WriteFile(tflite_model_filepath, model_content.data(),
108 model_content.size());
109
110 // Set ExpectedInputOutput.
Charles Zhao5dbfd062020-10-26 19:53:42 +1100111 SetExpectedValue(0.75f);
112 }
Qijiang Fan6bc59e12020-11-11 02:51:06 +0900113 MlBenchmarkTest(const MlBenchmarkTest&) = delete;
114 MlBenchmarkTest& operator=(const MlBenchmarkTest&) = delete;
115
Charles Zhao5dbfd062020-10-26 19:53:42 +1100116 // Write the output with given expected value.
117 void SetExpectedValue(const float val) {
Charles Zhao6dde75b2020-09-15 14:38:28 +1000118 ExpectedInputOutput input_output;
119 CHECK(TextFormat::ParseFromString(kInputOutputText, &input_output));
Charles Zhao5dbfd062020-10-26 19:53:42 +1100120 (*(*input_output.mutable_expected_output()
121 ->mutable_features()
122 ->mutable_feature())["z"]
123 .mutable_float_list()
124 ->mutable_value())[0] = val;
Charles Zhao6dde75b2020-09-15 14:38:28 +1000125 const std::string input_content = input_output.SerializeAsString();
Charles Zhao5dbfd062020-10-26 19:53:42 +1100126 base::WriteFile(input_output_filepath_, input_content.data(),
Charles Zhao6dde75b2020-09-15 14:38:28 +1000127 input_content.size());
128 }
129
130 protected:
131 // Temporary directory containing a file used for the file mechanism.
132 base::ScopedTempDir temp_dir_;
Charles Zhao5dbfd062020-10-26 19:53:42 +1100133 base::FilePath input_output_filepath_;
Charles Zhao6dde75b2020-09-15 14:38:28 +1000134 CrOSBenchmarkConfig benchmark_config_;
Charles Zhao6dde75b2020-09-15 14:38:28 +1000135};
136
Charles Zhao5dbfd062020-10-26 19:53:42 +1100137TEST_F(MlBenchmarkTest, TfliteModelMatchedValueTest) {
Charles Zhao6dde75b2020-09-15 14:38:28 +1000138 // Step 1 run benchmark_start.
139 const std::string config = benchmark_config_.SerializeAsString();
140 void* results_data = nullptr;
141 int results_size = 0;
142 EXPECT_EQ(benchmark_start(config.c_str(), config.size(), &results_data,
143 &results_size),
144 BenchmarkReturnStatus::OK);
145
146 // Step 2 check results.
147 BenchmarkResults results;
148 CHECK(results.ParseFromArray(results_data, results_size));
149 free_benchmark_results(results_data);
150 EXPECT_EQ(results.status(), BenchmarkReturnStatus::OK);
Michael Pishchaginb652dae2021-03-11 14:11:28 +0000151
152 auto metrics = results.metrics();
153 EXPECT_EQ(metrics[0].name(), "average_error");
154 EXPECT_EQ(metrics[0].units(), chrome::ml_benchmark::Metric::UNITLESS);
155 EXPECT_EQ(metrics[0].direction(),
156 chrome::ml_benchmark::Metric::SMALLER_IS_BETTER);
157 EXPECT_EQ(metrics[0].cardinality(), chrome::ml_benchmark::Metric::SINGLE);
158 EXPECT_NEAR(metrics[0].values()[0], 0.0f, 1e-5);
Charles Zhao6dde75b2020-09-15 14:38:28 +1000159}
160
Charles Zhao5dbfd062020-10-26 19:53:42 +1100161TEST_F(MlBenchmarkTest, TfliteModelUnmachedValueTest) {
162 SetExpectedValue(0.76f);
163 // Step 1 run benchmark_start.
164 const std::string config = benchmark_config_.SerializeAsString();
165 void* results_data = nullptr;
166 int results_size = 0;
167 EXPECT_EQ(benchmark_start(config.c_str(), config.size(), &results_data,
168 &results_size),
169 BenchmarkReturnStatus::OK);
170
171 // Step 2 check results.
172 BenchmarkResults results;
173 CHECK(results.ParseFromArray(results_data, results_size));
174 free_benchmark_results(results_data);
175 EXPECT_EQ(results.status(), BenchmarkReturnStatus::OK);
Michael Pishchaginb652dae2021-03-11 14:11:28 +0000176 auto metrics = results.metrics();
177 EXPECT_EQ(metrics[0].name(), "average_error");
178 EXPECT_EQ(metrics[0].units(), chrome::ml_benchmark::Metric::UNITLESS);
179 EXPECT_EQ(metrics[0].direction(),
180 chrome::ml_benchmark::Metric::SMALLER_IS_BETTER);
181 EXPECT_EQ(metrics[0].cardinality(), chrome::ml_benchmark::Metric::SINGLE);
182 EXPECT_NEAR(metrics[0].values()[0], 0.01f, 1e-5);
Charles Zhao5dbfd062020-10-26 19:53:42 +1100183}
184
Charles Zhao6dde75b2020-09-15 14:38:28 +1000185} // namespace ml