ml: Add mojo API of TextClassifier/Annotator.
This CL implements the mojo API of the text classifier.
BUG=chromium:1020419
TEST=pass the existing unit tests.
TEST=on device (eve), 1+2=3 works.
TEST=on device (eve), Annotate() and SuggestSelection() work.
Change-Id: Iadfe5747f3b2d1f8c8a8d9f0b975e2e8cda50726
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/1951350
Tested-by: Honglin Yu <honglinyu@chromium.org>
Commit-Queue: Honglin Yu <honglinyu@chromium.org>
Reviewed-by: Sam McNally <sammc@chromium.org>
Reviewed-by: Andrew Moylan <amoylan@chromium.org>
diff --git a/ml/machine_learning_service_impl.cc b/ml/machine_learning_service_impl.cc
index 62b2f81..efcf64c 100644
--- a/ml/machine_learning_service_impl.cc
+++ b/ml/machine_learning_service_impl.cc
@@ -10,10 +10,16 @@
#include <base/bind.h>
#include <base/bind_helpers.h>
+#include <base/files/file.h>
+#include <base/files/file_util.h>
#include <tensorflow/lite/model.h>
+#include <unicode/putil.h>
+#include <unicode/udata.h>
+#include <utils/memory/mmap.h>
#include "ml/model_impl.h"
#include "ml/mojom/model.mojom.h"
+#include "ml/text_classifier_impl.h"
namespace ml {
@@ -30,18 +36,25 @@
// or |LoadFlatBufferModel|) requests
constexpr char kMetricsRequestName[] = "LoadModelResult";
+// Default file name of the text classifier model. It initializes
+// |text_classifier_model_filename_|, which LoadTextClassifier() appends to
+// |model_dir_| to form the path it memory-maps.
+constexpr char kTextClassifierModelFile[] =
+ "mlservice-model-text_classifier_en-v706.fb";
+
+// ICU common data file that InitIcuIfNeeded() reads into memory and registers
+// with ICU via udata_setCommonData().
+constexpr char kIcuDataFilePath[] = "/opt/google/chrome/icudtl.dat";
+
} // namespace
+// Binds the MachineLearningService interface to |pipe| and installs
+// |connection_error_handler|. |model_dir| is the directory prefix that
+// LoadTextClassifier() prepends to the model file name.
MachineLearningServiceImpl::MachineLearningServiceImpl(
mojo::ScopedMessagePipeHandle pipe,
base::Closure connection_error_handler,
const std::string& model_dir)
- : builtin_model_metadata_(GetBuiltinModelMetadata()),
+ // |icu_data_| stays null until InitIcuIfNeeded() loads the ICU data file.
+ : icu_data_(nullptr),
+ // Tests can override this with SetTextClassifierModelFilenameForTesting().
+ text_classifier_model_filename_(kTextClassifierModelFile),
+ builtin_model_metadata_(GetBuiltinModelMetadata()),
model_dir_(model_dir),
binding_(this,
mojo::InterfaceRequest<
chromeos::machine_learning::mojom::MachineLearningService>(
- std::move(pipe))) {
+ std::move(pipe))) {
binding_.set_connection_error_handler(std::move(connection_error_handler));
}
@@ -51,6 +64,11 @@
std::move(connection_error_handler),
kSystemModelDir) {}
+// Test hook: replaces the text classifier model file name. The new name is
+// resolved relative to |model_dir_| the next time LoadTextClassifier() runs.
+void MachineLearningServiceImpl::SetTextClassifierModelFilenameForTesting(
+    const std::string& filename) {
+  text_classifier_model_filename_.assign(filename);
+}
+
void MachineLearningServiceImpl::LoadBuiltinModel(
BuiltinModelSpecPtr spec,
ModelRequest request,
@@ -132,4 +150,59 @@
request_metrics.RecordRequestEvent(LoadModelResult::OK);
}
+// Memory-maps the text classifier model located at
+// |model_dir_| + |text_classifier_model_filename_|, makes sure ICU has its
+// common data, and creates a TextClassifierImpl for |request|.
+// The outcome is reported through |callback| and recorded as a
+// LoadModelResult request event under "TextClassifier". Performance metrics
+// are finished only on success.
+void MachineLearningServiceImpl::LoadTextClassifier(
+    chromeos::machine_learning::mojom::TextClassifierRequest request,
+    LoadTextClassifierCallback callback) {
+  RequestMetrics<LoadModelResult> request_metrics("TextClassifier",
+                                                  kMetricsRequestName);
+  request_metrics.StartRecordingPerformanceMetrics();
+
+  // Attempt to memory-map the model file.
+  const std::string model_path = model_dir_ + text_classifier_model_filename_;
+  auto scoped_mmap =
+      std::make_unique<libtextclassifier3::ScopedMmap>(model_path);
+  if (!scoped_mmap->handle().ok()) {
+    LOG(ERROR) << "Failed to load the text classifier model file '"
+               << model_path << "'.";
+    std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
+    request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
+    return;
+  }
+
+  // Initialize ICU *before* creating the classifier. udata_setCommonData()
+  // only takes effect if ICU has not yet tried to load any data. Otherwise it
+  // returns U_USING_DEFAULT_WARNING and the buffer is ignored.
+  // NOTE(review): TextClassifierImpl::Create() presumably uses ICU while
+  // loading the model; initializing first is safe either way.
+  InitIcuIfNeeded();
+
+  // Create the TextClassifier from the mapped model for |request|.
+  if (!TextClassifierImpl::Create(&scoped_mmap, std::move(request))) {
+    LOG(ERROR) << "Failed to create TextClassifierImpl object.";
+    std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
+    request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
+    return;
+  }
+
+  std::move(callback).Run(LoadModelResult::OK);
+
+  request_metrics.FinishRecordingPerformanceMetrics();
+  request_metrics.RecordRequestEvent(LoadModelResult::OK);
+}
+
+// Reads the ICU common data file (|kIcuDataFilePath|) into |icu_data_| and
+// registers it with ICU. It then restricts ICU to that in-memory package so
+// ICU never loads data files itself. It does nothing if |icu_data_| is
+// already set.
+// The buffer is intentionally not freed here, because ICU reads directly from
+// the memory passed to udata_setCommonData(). A missing or unreadable data
+// file is fatal (CHECK).
+// NOTE(review): ICU's common data is process-global while |icu_data_| is
+// per-instance. A second instance would load the file again and ICU would
+// reject it. Confirm only one instance is expected per process.
+void MachineLearningServiceImpl::InitIcuIfNeeded() {
+  if (icu_data_ != nullptr) {
+    return;  // Already loaded and registered by this instance.
+  }
+
+  int64_t file_size = 0;
+  const base::FilePath icu_data_file_path(kIcuDataFilePath);
+  CHECK(base::GetFileSize(icu_data_file_path, &file_size));
+  // An empty buffer would be silently handed to ICU; fail loudly instead.
+  CHECK_GT(file_size, 0) << "ICU data file is empty: " << kIcuDataFilePath;
+  icu_data_ = new char[file_size];
+  CHECK(base::ReadFile(icu_data_file_path, icu_data_,
+                       static_cast<int>(file_size)) == file_size);
+
+  // Point ICU at the in-memory common data.
+  UErrorCode err = U_ZERO_ERROR;
+  udata_setCommonData(reinterpret_cast<void*>(icu_data_), &err);
+  // DCHECK is compiled out of release builds, so also log. For example,
+  // U_USING_DEFAULT_WARNING means ICU had already loaded data and ignored
+  // this buffer.
+  LOG_IF(ERROR, err != U_ZERO_ERROR)
+      << "udata_setCommonData() returned " << u_errorName(err);
+  DCHECK(err == U_ZERO_ERROR);
+
+  // Never try to load ICU data from files.
+  udata_setFileAccess(UDATA_ONLY_PACKAGES, &err);
+}
+
+
} // namespace ml