ml: Add mojo API of TextClassifier/Annotator.
This CL implements the mojo API of the text classifier (Annotate() and SuggestSelection()).
BUG=chromium:1020419
TEST=pass the existing unit tests.
TEST=on device (eve), 1+2=3 works.
TEST=on device (eve), Annotate() and SuggestSelection() work.
Change-Id: Iadfe5747f3b2d1f8c8a8d9f0b975e2e8cda50726
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/1951350
Tested-by: Honglin Yu <honglinyu@chromium.org>
Commit-Queue: Honglin Yu <honglinyu@chromium.org>
Reviewed-by: Sam McNally <sammc@chromium.org>
Reviewed-by: Andrew Moylan <amoylan@chromium.org>
diff --git a/ml/text_classifier_impl.cc b/ml/text_classifier_impl.cc
new file mode 100644
index 0000000..6d8a62b
--- /dev/null
+++ b/ml/text_classifier_impl.cc
@@ -0,0 +1,167 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "ml/text_classifier_impl.h"
+
+#include <utility>
+#include <vector>
+
+#include <base/logging.h>
+
+#include "ml/mojom/text_classifier.mojom.h"
+#include "ml/request_metrics.h"
+
+namespace ml {
+
+namespace {
+
+using ::chromeos::machine_learning::mojom::CodepointSpan;
+using ::chromeos::machine_learning::mojom::SuggestSelectionResult;
+using ::chromeos::machine_learning::mojom::TextAnnotation;
+using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
+using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
+using ::chromeos::machine_learning::mojom::TextAnnotationResult;
+using ::chromeos::machine_learning::mojom::TextClassifierRequest;
+using ::chromeos::machine_learning::mojom::TextEntity;
+using ::chromeos::machine_learning::mojom::TextEntityData;
+using ::chromeos::machine_learning::mojom::TextEntityPtr;
+using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
+
+// Deletes |impl|. Written as a named free function so that it can be bound
+// into a base::Closure without passing a lambda.
+void DeleteTextClassifierImpl(const TextClassifierImpl* const impl) {
+  delete impl;
+}
+
+} // namespace
+
+// Creates a TextClassifierImpl that serves |request| using the model held in
+// |mmap|. The instance is not handed back to the caller: on success it is
+// deleted by the connection error handler installed below.
+// Returns true on success, or false (after deleting the instance) if the
+// annotator could not be created.
+bool TextClassifierImpl::Create(
+    std::unique_ptr<libtextclassifier3::ScopedMmap>* mmap,
+    TextClassifierRequest request) {
+  auto* const impl = new TextClassifierImpl(mmap, std::move(request));
+
+  if (impl->annotator_ != nullptr) {
+    // Use a connection error handler to strongly bind |impl| to |request|:
+    // the handler is what deletes the instance.
+    impl->SetConnectionErrorHandler(
+        base::Bind(&DeleteTextClassifierImpl, base::Unretained(impl)));
+    return true;
+  }
+
+  // The annotator could not be created, so discard the instance and report
+  // failure.
+  delete impl;
+  return false;
+}
+
+// Initializes |annotator_| with the result of Annotator::FromScopedMmap() on
+// |mmap|, and binds |this| to |request| through |binding_|.
+// Create() treats a null |annotator_| after construction as a failure.
+// NOTE(review): the meaning of the two nullptr arguments passed to
+// FromScopedMmap() is not visible in this file -- confirm against the
+// libtextclassifier API.
+TextClassifierImpl::TextClassifierImpl(
+ std::unique_ptr<libtextclassifier3::ScopedMmap>* mmap,
+ TextClassifierRequest request)
+ : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
+ mmap, nullptr, nullptr)),
+ binding_(this, std::move(request)) {}
+
+// Installs |connection_error_handler| on |binding_| as its connection error
+// handler. Create() uses this to register the closure that deletes the
+// instance.
+void TextClassifierImpl::SetConnectionErrorHandler(
+ base::Closure connection_error_handler) {
+ binding_.set_connection_error_handler(std::move(connection_error_handler));
+}
+
+// Annotates |request->text| and passes the resulting annotations to
+// |callback|.
+//
+// The optional request fields (default locales, reference time, reference
+// timezone, enabled entities) are copied into the annotation options only
+// when they are set; the detected language tags fall back to "en".
+//
+// Each annotated span becomes one TextAnnotation holding one TextEntity per
+// classification. Entities of the "number" collection carry the numeric
+// value; every other entity carries the piece of text covered by the span.
+// Performance metrics and an OK request event are recorded after |callback|
+// has run.
+void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
+                                  AnnotateCallback callback) {
+  RequestMetrics<TextAnnotationResult> request_metrics("TextClassifier",
+                                                       "Annotate");
+  request_metrics.StartRecordingPerformanceMetrics();
+
+  // Parse and set up the options.
+  libtextclassifier3::AnnotationOptions option;
+  if (request->default_locales) {
+    option.locales = request->default_locales.value();
+  }
+  if (request->reference_time) {
+    // ToTimeT() yields seconds, while the option field is in milliseconds.
+    option.reference_time_ms_utc =
+        request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
+  }
+  if (request->reference_timezone) {
+    option.reference_timezone = request->reference_timezone.value();
+  }
+  if (request->enabled_entities) {
+    option.entity_types.insert(request->enabled_entities.value().begin(),
+                               request->enabled_entities.value().end());
+  }
+  option.detected_text_language_tags =
+      request->detected_text_language_tags.value_or("en");
+  option.annotation_usecase =
+      static_cast<libtextclassifier3::AnnotationUsecase>(
+          request->annotation_usecase);
+
+  const auto& text = request->text;
+
+  // The annotator reports spans as Unicode codepoint indices, whereas
+  // std::string::substr() takes byte offsets; the two differ as soon as
+  // |text| contains a multi-byte UTF-8 character. This helper returns the
+  // byte offset at which codepoint number |codepoint_index| starts, or
+  // text.size() if |text| has no such codepoint.
+  const auto byte_offset_of = [&text](int codepoint_index) -> size_t {
+    size_t offset = 0;
+    int codepoints_seen = 0;
+    for (; offset < text.size(); ++offset) {
+      // Bytes of the form 0b10xxxxxx continue the current codepoint; every
+      // other byte starts a new one.
+      const bool is_continuation =
+          (static_cast<unsigned char>(text[offset]) & 0xC0) == 0x80;
+      if (is_continuation)
+        continue;
+      if (codepoints_seen == codepoint_index)
+        break;
+      ++codepoints_seen;
+    }
+    return offset;
+  };
+
+  // Do the annotation.
+  const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
+      annotator_->Annotate(text, option);
+
+  // Parse the result.
+  std::vector<TextAnnotationPtr> annotations;
+  annotations.reserve(annotated_spans.size());
+  for (const auto& annotated_result : annotated_spans) {
+    DCHECK_GE(annotated_result.span.second, annotated_result.span.first);
+
+    // Byte range of |text| covered by this span; shared by every non-number
+    // classification of the span.
+    const size_t begin = byte_offset_of(annotated_result.span.first);
+    const size_t end = byte_offset_of(annotated_result.span.second);
+
+    std::vector<TextEntityPtr> entities;
+    entities.reserve(annotated_result.classification.size());
+    for (const auto& classification : annotated_result.classification) {
+      // First, get entity data.
+      auto entity_data = TextEntityData::New();
+      if (classification.collection == "number") {
+        entity_data->set_numeric_value(classification.numeric_double_value);
+      } else {
+        // For the other types, just encode the substring into string_value.
+        // TODO(honglinyu): add data extraction for more types when needed
+        // and available.
+        entity_data->set_string_value(text.substr(begin, end - begin));
+      }
+
+      // Second, create the entity.
+      entities.emplace_back(TextEntity::New(classification.collection,
+                                            classification.score,
+                                            std::move(entity_data)));
+    }
+    // The annotation itself keeps the span in codepoint units, exactly as
+    // reported by the annotator.
+    annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
+                                                 annotated_result.span.second,
+                                                 std::move(entities)));
+  }
+
+  std::move(callback).Run(std::move(annotations));
+
+  request_metrics.FinishRecordingPerformanceMetrics();
+  request_metrics.RecordRequestEvent(TextAnnotationResult::OK);
+}
+
+// Asks the annotator for a suggested selection around
+// |request->user_selection| within |request->text| and passes the suggested
+// span to |callback|.
+//
+// The default locales are forwarded only when set, and the detected language
+// tags fall back to "en". Performance metrics and an OK request event are
+// recorded after |callback| has run.
+void TextClassifierImpl::SuggestSelection(
+    TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
+  RequestMetrics<SuggestSelectionResult> metrics("TextClassifier",
+                                                 "SuggestSelection");
+  metrics.StartRecordingPerformanceMetrics();
+
+  // Translate the request into the library's selection options.
+  libtextclassifier3::SelectionOptions selection_options;
+  if (request->default_locales.has_value()) {
+    selection_options.locales = request->default_locales.value();
+  }
+  selection_options.detected_text_language_tags =
+      request->detected_text_language_tags.value_or("en");
+  selection_options.annotation_usecase =
+      static_cast<libtextclassifier3::AnnotationUsecase>(
+          request->annotation_usecase);
+
+  // Convert the mojo span into the library's span type.
+  libtextclassifier3::CodepointSpan selection_in;
+  selection_in.first = request->user_selection->start_offset;
+  selection_in.second = request->user_selection->end_offset;
+
+  const libtextclassifier3::CodepointSpan selection_out =
+      annotator_->SuggestSelection(request->text, selection_in,
+                                   selection_options);
+
+  // Convert the suggestion back into a mojo span for the reply.
+  auto reply = CodepointSpan::New();
+  reply->start_offset = selection_out.first;
+  reply->end_offset = selection_out.second;
+  std::move(callback).Run(std::move(reply));
+
+  metrics.FinishRecordingPerformanceMetrics();
+  metrics.RecordRequestEvent(SuggestSelectionResult::OK);
+}
+
+} // namespace ml