blob: 111d32f10fa70d466c5a9b7a7c4e746f87d4994e [file] [log] [blame]
Honglin Yuf33dce32019-12-05 15:10:39 +11001// Copyright 2020 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/text_classifier_impl.h"
6
7#include <utility>
8#include <vector>
9
10#include <base/logging.h>
Honglin Yu568fc9a2020-06-05 11:57:21 +100011#include <utils/utf8/unicodetext.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110012
13#include "ml/mojom/text_classifier.mojom.h"
14#include "ml/request_metrics.h"
15
16namespace ml {
17
18namespace {
19
20using ::chromeos::machine_learning::mojom::CodepointSpan;
21using ::chromeos::machine_learning::mojom::SuggestSelectionResult;
22using ::chromeos::machine_learning::mojom::TextAnnotation;
23using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
24using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
25using ::chromeos::machine_learning::mojom::TextAnnotationResult;
26using ::chromeos::machine_learning::mojom::TextClassifierRequest;
27using ::chromeos::machine_learning::mojom::TextEntity;
28using ::chromeos::machine_learning::mojom::TextEntityData;
29using ::chromeos::machine_learning::mojom::TextEntityPtr;
30using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
31
32// To avoid passing a lambda as a base::Closure.
33void DeleteTextClassifierImpl(
34 const TextClassifierImpl* const text_classifier_impl) {
35 delete text_classifier_impl;
36}
37
38} // namespace
39
40bool TextClassifierImpl::Create(
41 std::unique_ptr<libtextclassifier3::ScopedMmap>* mmap,
42 TextClassifierRequest request) {
43 auto text_classifier_impl = new TextClassifierImpl(mmap, std::move(request));
44 if (text_classifier_impl->annotator_ == nullptr) {
45 // Fails to create annotator, return nullptr.
46 delete text_classifier_impl;
47 return false;
48 }
49
Andrew Moylan79b34a42020-07-08 11:13:11 +100050 // Use a connection error handler to strongly bind `text_classifier_impl` to
51 // `request`.
Honglin Yuf33dce32019-12-05 15:10:39 +110052 text_classifier_impl->SetConnectionErrorHandler(base::Bind(
53 &DeleteTextClassifierImpl, base::Unretained(text_classifier_impl)));
54 return true;
55}
56
57TextClassifierImpl::TextClassifierImpl(
58 std::unique_ptr<libtextclassifier3::ScopedMmap>* mmap,
59 TextClassifierRequest request)
60 : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
61 mmap, nullptr, nullptr)),
62 binding_(this, std::move(request)) {}
63
64void TextClassifierImpl::SetConnectionErrorHandler(
65 base::Closure connection_error_handler) {
66 binding_.set_connection_error_handler(std::move(connection_error_handler));
67}
68
69void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
70 AnnotateCallback callback) {
71 RequestMetrics<TextAnnotationResult> request_metrics("TextClassifier",
72 "Annotate");
73 request_metrics.StartRecordingPerformanceMetrics();
74
75 // Parse and set up the options.
76 libtextclassifier3::AnnotationOptions option;
77 if (request->default_locales) {
78 option.locales = request->default_locales.value();
79 }
80 if (request->reference_time) {
81 option.reference_time_ms_utc =
82 request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
83 }
84 if (request->reference_timezone) {
85 option.reference_timezone = request->reference_timezone.value();
86 }
87 if (request->enabled_entities) {
88 option.entity_types.insert(request->enabled_entities.value().begin(),
89 request->enabled_entities.value().end());
90 }
91 option.detected_text_language_tags =
92 request->detected_text_language_tags.value_or("en");
93 option.annotation_usecase =
94 static_cast<libtextclassifier3::AnnotationUsecase>(
95 request->annotation_usecase);
96
97 // Do the annotation.
98 const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
99 annotator_->Annotate(request->text, option);
100
101 // Parse the result.
102 std::vector<TextAnnotationPtr> annotations;
103 for (const auto& annotated_result : annotated_spans) {
104 DCHECK(annotated_result.span.second >= annotated_result.span.first);
105 std::vector<TextEntityPtr> entities;
106 for (const auto& classification : annotated_result.classification) {
107 // First, get entity data.
108 auto entity_data = TextEntityData::New();
109 if (classification.collection == "number") {
110 entity_data->set_numeric_value(classification.numeric_double_value);
111 } else {
112 // For the other types, just encode the substring into string_value.
113 // TODO(honglinyu): add data extraction for more types when needed
114 // and available.
Honglin Yu568fc9a2020-06-05 11:57:21 +1000115 // Note that the returned indices by annotator is unicode codepoints.
116 entity_data->set_string_value(
117 libtextclassifier3::UTF8ToUnicodeText(request->text, false)
118 .UTF8Substring(annotated_result.span.first,
119 annotated_result.span.second));
Honglin Yuf33dce32019-12-05 15:10:39 +1100120 }
121
122 // Second, create the entity.
123 entities.emplace_back(TextEntity::New(classification.collection,
124 classification.score,
125 std::move(entity_data)));
126 }
127 annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
128 annotated_result.span.second,
129 std::move(entities)));
130 }
131
132 std::move(callback).Run(std::move(annotations));
133
134 request_metrics.FinishRecordingPerformanceMetrics();
135 request_metrics.RecordRequestEvent(TextAnnotationResult::OK);
136}
137
138void TextClassifierImpl::SuggestSelection(
139 TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
140 RequestMetrics<SuggestSelectionResult> request_metrics("TextClassifier",
141 "SuggestSelection");
142 request_metrics.StartRecordingPerformanceMetrics();
143
144 libtextclassifier3::SelectionOptions option;
145 if (request->default_locales) {
146 option.locales = request->default_locales.value();
147 }
148 option.detected_text_language_tags =
149 request->detected_text_language_tags.value_or("en");
150 option.annotation_usecase =
151 static_cast<libtextclassifier3::AnnotationUsecase>(
152 request->annotation_usecase);
153
154 libtextclassifier3::CodepointSpan user_selection;
155 user_selection.first = request->user_selection->start_offset;
156 user_selection.second = request->user_selection->end_offset;
157
158 const libtextclassifier3::CodepointSpan suggested_span =
159 annotator_->SuggestSelection(request->text, user_selection, option);
160 auto result_span = CodepointSpan::New();
161 result_span->start_offset = suggested_span.first;
162 result_span->end_offset = suggested_span.second;
163
164 std::move(callback).Run(std::move(result_span));
165
166 request_metrics.FinishRecordingPerformanceMetrics();
167 request_metrics.RecordRequestEvent(SuggestSelectionResult::OK);
168}
169
170} // namespace ml