ml: implement text suggester mojo api
Implement the Mojo API that wraps the text suggestions shared library.
BUG=chromium:1146266
TEST=cros_run_unit_tests --board=${BOARD} --packages chromeos-base/ml
Change-Id: I64ce411656d70fa6d0e379c2acc6fdd41f9dd0b6
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/2626551
Tested-by: Curtis McMullan <curtismcmullan@chromium.org>
Commit-Queue: Darren Shen <shend@chromium.org>
Reviewed-by: Andrew Moylan <amoylan@chromium.org>
diff --git a/ml/machine_learning_service_impl.cc b/ml/machine_learning_service_impl.cc
index 4ff3a32..36b7a1a 100644
--- a/ml/machine_learning_service_impl.cc
+++ b/ml/machine_learning_service_impl.cc
@@ -29,6 +29,8 @@
#include "ml/request_metrics.h"
#include "ml/soda_recognizer_impl.h"
#include "ml/text_classifier_impl.h"
+#include "ml/text_suggester_impl.h"
+#include "ml/text_suggestions.h"
namespace ml {
@@ -419,4 +421,47 @@
request_metrics.RecordRequestEvent(LoadModelResult::OK);
}
+// Checks that the TextSuggestions library initialized successfully, then
+// creates a TextSuggesterImpl for |receiver| via TextSuggesterImpl::Create().
+// |callback| is run exactly once on every path.
+//
+// Result mapping:
+//   - library status kNotSupported        -> FEATURE_NOT_SUPPORTED_ERROR
+//   - any other library status except kOk -> LOAD_MODEL_ERROR
+//   - TextSuggesterImpl::Create() failure  -> LOAD_MODEL_ERROR
+//   - success                              -> OK
+// Every outcome is recorded as a request event under "TextSuggester";
+// performance metrics are only finished on the success path.
+void MachineLearningServiceImpl::LoadTextSuggester(
+    mojo::PendingReceiver<chromeos::machine_learning::mojom::TextSuggester>
+        receiver,
+    LoadTextSuggesterCallback callback) {
+  RequestMetrics request_metrics("TextSuggester", kMetricsRequestName);
+  request_metrics.StartRecordingPerformanceMetrics();
+
+  // Load TextSuggestions library. Query its status once so that the branch
+  // taken and the status value that gets logged always agree.
+  auto* const text_suggestions = ml::TextSuggestions::GetInstance();
+  const auto status = text_suggestions->GetStatus();
+
+  if (status == ml::TextSuggestions::Status::kNotSupported) {
+    LOG(ERROR) << "Initialize ml::TextSuggestions with error "
+               << static_cast<int>(status);
+
+    std::move(callback).Run(LoadModelResult::FEATURE_NOT_SUPPORTED_ERROR);
+    request_metrics.RecordRequestEvent(
+        LoadModelResult::FEATURE_NOT_SUPPORTED_ERROR);
+    return;
+  }
+
+  if (status != ml::TextSuggestions::Status::kOk) {
+    LOG(ERROR) << "Initialize ml::TextSuggestions with error "
+               << static_cast<int>(status);
+
+    std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
+    request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
+    return;
+  }
+
+  // Create TextSuggester, handing |receiver| over to it.
+  if (!TextSuggesterImpl::Create(std::move(receiver))) {
+    // Log like the library-status failures above so this path is not silent.
+    LOG(ERROR) << "TextSuggesterImpl::Create returned false.";
+    std::move(callback).Run(LoadModelResult::LOAD_MODEL_ERROR);
+    request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
+    return;
+  }
+
+  std::move(callback).Run(LoadModelResult::OK);
+
+  request_metrics.FinishRecordingPerformanceMetrics();
+  request_metrics.RecordRequestEvent(LoadModelResult::OK);
+}
+
} // namespace ml