blob: 0ffc0fa20e5d820569cea1a116ec0677a94362bb [file] [log] [blame]
Frank Tang3e05d9d2021-11-08 14:04:04 -08001// © 2021 and later: Unicode, Inc. and others.
2// License & terms of use: http://www.unicode.org/copyright.html
3
4#include "unicode/utypes.h"
5
6#if !UCONFIG_NO_BREAK_ITERATION
7
8#include "lstmbetst.h"
9#include "lstmbe.h"
10
11#include <algorithm>
12#include <sstream>
13#include <vector>
14
15#include "charstr.h"
16
17//---------------------------------------------
18// runIndexedTest
19//---------------------------------------------
20
21
22void LSTMBETest::runIndexedTest( int32_t index, UBool exec, const char* &name, char* params )
23{
24 fTestParams = params;
25
26 TESTCASE_AUTO_BEGIN;
27
28 TESTCASE_AUTO(TestThaiGraphclust);
29 TESTCASE_AUTO(TestThaiCodepoints);
30 TESTCASE_AUTO(TestBurmeseGraphclust);
31 TESTCASE_AUTO(TestThaiGraphclustWithLargeMemory);
32 TESTCASE_AUTO(TestThaiCodepointsWithLargeMemory);
33
34 TESTCASE_AUTO_END;
35}
36
37
38//--------------------------------------------------------------------------------------
39//
40// LSTMBETest constructor and destructor
41//
42//--------------------------------------------------------------------------------------
43
44LSTMBETest::LSTMBETest() {
45 fTestParams = NULL;
46}
47
48
49LSTMBETest::~LSTMBETest() {
50}
51
52UScriptCode getScriptFromModelName(const std::string& modelName) {
53 if (modelName.find("Thai") == 0) {
54 return USCRIPT_THAI;
55 } else if (modelName.find("Burmese") == 0) {
56 return USCRIPT_MYANMAR;
57 }
58 // Add for other script codes.
59 UPRV_UNREACHABLE_EXIT;
60}
61
// Reads a file generated by
// https://github.com/unicode-org/lstm_word_segmentation/blob/master/segment_text.py
// as test cases and compares the engine's breaks with the expected Output.
// Format of the file:
// Model:\t[Model Name (such as 'Thai_graphclust_model4_heavy')]
// Embedding:\t[Embedding type (such as 'grapheme_clusters_tf')]
// Input:\t[source text]
// Output:\t[expected output separated by | ]
// Input: ...
// Output: ...
// The test ensures that the Input contains only characters that can be handled
// by the model. Since the LSTM models are not included by default, all the
// tested models need to be included under source/test/testdata.
75
76void LSTMBETest::runTestFromFile(const char* filename) {
77 UErrorCode status = U_ZERO_ERROR;
78 LocalPointer<const LanguageBreakEngine> engine;
79 // Open and read the test data file.
80 const char *testDataDirectory = IntlTest::getSourceTestData(status);
81 CharString testFileName(testDataDirectory, -1, status);
82 testFileName.append(filename, -1, status);
83
84 int len;
85 UChar *testFile = ReadAndConvertFile(testFileName.data(), len, "UTF-8", status);
86 if (U_FAILURE(status)) {
87 errln("%s:%d Error %s opening test file %s", __FILE__, __LINE__, u_errorName(status), filename);
88 return;
89 }
90
91 // Put the test data into a UnicodeString
92 UnicodeString testString(FALSE, testFile, len);
93
94 int32_t start = 0;
95
96 UnicodeString line;
97 int32_t end;
98 std::string actual_sep_str;
99 int32_t caseNum = 0;
100 // Iterate through all the lines in the test file.
101 do {
102 int32_t cr = testString.indexOf(u'\r', start);
103 int32_t lf = testString.indexOf(u'\n', start);
104 end = cr >= 0 ? (lf >= 0 ? std::min(cr, lf) : cr) : lf;
105 line = testString.tempSubString(start, end < 0 ? INT32_MAX : end - start);
106 if (line.length() > 0) {
107 // Separate each line to key and value by TAB.
108 int32_t tab = line.indexOf(u'\t');
109 UnicodeString key = line.tempSubString(0, tab);
110 const UnicodeString value = line.tempSubString(tab+1);
111
112 if (key == "Model:") {
113 std::string modelName;
114 value.toUTF8String<std::string>(modelName);
115 engine.adoptInstead(createEngineFromTestData(modelName.c_str(), getScriptFromModelName(modelName), status));
116 if (U_FAILURE(status)) {
117 dataerrln("Could not CreateLSTMBreakEngine for " + line + UnicodeString(u_errorName(status)));
118 return;
119 }
120 } else if (key == "Input:") {
121 // First, we ensure all the char in the Input lines are accepted
122 // by the engine before we test them.
123 caseNum++;
124 bool canHandleAllChars = true;
125 for (int32_t i = 0; i < value.length(); i++) {
126 if (!engine->handles(value.charAt(i))) {
127 errln(UnicodeString("Test Case#") + caseNum + " contains char '" +
128 UnicodeString(value.charAt(i)) +
129 "' cannot be handled by the engine in offset " + i + "\n" + line);
130 canHandleAllChars = false;
131 break;
132 }
133 }
134 if (! canHandleAllChars) {
135 return;
136 }
137
138 // If the engine can handle all the chars in the Input line, we
139 // then find the break points by calling the engine.
140 std::stringstream ss;
141
142 // Construct the UText which is expected by the the engine as
143 // input from the UnicodeString.
144 UText ut = UTEXT_INITIALIZER;
145 utext_openConstUnicodeString(&ut, &value, &status);
146 if (U_FAILURE(status)) {
147 dataerrln("Could not utext_openConstUnicodeString for " + value + UnicodeString(u_errorName(status)));
148 return;
149 }
150
151 UVector32 actual(status);
152 if (U_FAILURE(status)) {
153 dataerrln("%s:%d Error %s Could not allocate UVextor32", __FILE__, __LINE__, u_errorName(status));
154 return;
155 }
Frank Tangd2858cb2022-04-08 20:34:12 -0700156 engine->findBreaks(&ut, 0, value.length(), actual, false, status);
Frank Tang3e05d9d2021-11-08 14:04:04 -0800157 if (U_FAILURE(status)) {
158 dataerrln("%s:%d Error %s findBreaks failed", __FILE__, __LINE__, u_errorName(status));
159 return;
160 }
161 utext_close(&ut);
162 for (int32_t i = 0; i < actual.size(); i++) {
163 ss << actual.elementAti(i) << ", ";
164 }
165 ss << value.length();
166 // Turn the break points into a string for easy comparison
167 // output.
168 actual_sep_str = "{" + ss.str() + "}";
169 } else if (key == "Output:" && !actual_sep_str.empty()) {
170 std::string d;
171 int32_t sep;
172 int32_t start = 0;
173 int32_t curr = 0;
174 std::stringstream ss;
175 while ((sep = value.indexOf(u'|', start)) >= 0) {
176 int32_t len = sep - start;
177 if (len > 0) {
178 if (curr > 0) {
179 ss << ", ";
180 }
181 curr += len;
182 ss << curr;
183 }
184 start = sep + 1;
185 }
186 // Turn the break points into a string for easy comparison
187 // output.
188 std::string expected = "{" + ss.str() + "}";
189 std::string utf8;
190
191 assertEquals((value + " Test Case#" + caseNum).toUTF8String<std::string>(utf8).c_str(),
192 expected.c_str(), actual_sep_str.c_str());
193 actual_sep_str.clear();
194 }
195 }
196 start = std::max(cr, lf) + 1;
197 } while (end >= 0);
198
199 delete [] testFile;
200}
201
202void LSTMBETest::TestThaiGraphclust() {
203 runTestFromFile("Thai_graphclust_model4_heavy_Test.txt");
204}
205
206void LSTMBETest::TestThaiCodepoints() {
207 runTestFromFile("Thai_codepoints_exclusive_model5_heavy_Test.txt");
208}
209
210void LSTMBETest::TestBurmeseGraphclust() {
211 runTestFromFile("Burmese_graphclust_model5_heavy_Test.txt");
212}
213
214const LanguageBreakEngine* LSTMBETest::createEngineFromTestData(
215 const char* model, UScriptCode script, UErrorCode& status) {
216 const char* testdatapath=loadTestData(status);
217 if(U_FAILURE(status))
218 {
219 dataerrln("Could not load testdata.dat " + UnicodeString(testdatapath) + ", " +
220 UnicodeString(u_errorName(status)));
221 return nullptr;
222 }
223
224 LocalUResourceBundlePointer rb(
225 ures_openDirect(testdatapath, model, &status));
226 if (U_FAILURE(status)) {
227 dataerrln("Could not open " + UnicodeString(model) + " under " + UnicodeString(testdatapath) + ", " +
228 UnicodeString(u_errorName(status)));
229 return nullptr;
230 }
231
232 const LSTMData* data = CreateLSTMData(rb.orphan(), status);
233 if (U_FAILURE(status)) {
234 dataerrln("Could not CreateLSTMData " + UnicodeString(model) + " under " + UnicodeString(testdatapath) + ", " +
235 UnicodeString(u_errorName(status)));
236 return nullptr;
237 }
238 if (data == nullptr) {
239 return nullptr;
240 }
241
242 LocalPointer<const LanguageBreakEngine> engine(CreateLSTMBreakEngine(script, data, status));
243 if (U_FAILURE(status) || engine.getAlias() == nullptr) {
244 dataerrln("Could not CreateLSTMBreakEngine " + UnicodeString(testdatapath) + ", " +
245 UnicodeString(u_errorName(status)));
246 DeleteLSTMData(data);
247 return nullptr;
248 }
249 return engine.orphan();
250}
251
252
253void LSTMBETest::TestThaiGraphclustWithLargeMemory() {
254 runTestWithLargeMemory("Thai_graphclust_model4_heavy", USCRIPT_THAI);
255
256}
257
258void LSTMBETest::TestThaiCodepointsWithLargeMemory() {
259 runTestWithLargeMemory("Thai_codepoints_exclusive_model5_heavy", USCRIPT_THAI);
260}
261
262constexpr int32_t MEMORY_TEST_THESHOLD_SHORT = 2 * 1024; // 2 K Unicode Chars.
263constexpr int32_t MEMORY_TEST_THESHOLD = 32 * 1024; // 32 K Unicode Chars.
264
265// Test with very long unicode string.
266void LSTMBETest::runTestWithLargeMemory( const char* model, UScriptCode script) {
267 UErrorCode status = U_ZERO_ERROR;
268 int32_t test_threshold = quick ? MEMORY_TEST_THESHOLD_SHORT : MEMORY_TEST_THESHOLD;
269 LocalPointer<const LanguageBreakEngine> engine(
270 createEngineFromTestData(model, script, status));
271 if (U_FAILURE(status)) {
272 dataerrln("Could not CreateLSTMBreakEngine for " + UnicodeString(model) + UnicodeString(u_errorName(status)));
273 return;
274 }
275 UnicodeString text(u"อ"); // start with a single Thai char.
276 UVector32 actual(status);
277 if (U_FAILURE(status)) {
278 dataerrln("%s:%d Error %s Could not allocate UVextor32", __FILE__, __LINE__, u_errorName(status));
279 return;
280 }
281 while (U_SUCCESS(status) && text.length() <= test_threshold) {
282 // Construct the UText which is expected by the the engine as
283 // input from the UnicodeString.
284 UText ut = UTEXT_INITIALIZER;
285 utext_openConstUnicodeString(&ut, &text, &status);
286 if (U_FAILURE(status)) {
287 dataerrln("Could not utext_openConstUnicodeString for " + text + UnicodeString(u_errorName(status)));
288 return;
289 }
290
Frank Tangd2858cb2022-04-08 20:34:12 -0700291 engine->findBreaks(&ut, 0, text.length(), actual, false, status);
Frank Tang3e05d9d2021-11-08 14:04:04 -0800292 utext_close(&ut);
293 text += text;
294 }
295}
296#endif // #if !UCONFIG_NO_BREAK_ITERATION