blob: 93bd081f85e61bbe8d1f324ab66228a5be6910b8 [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/text_classifier_impl.h"
6
7#include <utility>
8#include <vector>
9
10#include <base/logging.h>
Honglin Yuc5100022020-07-09 11:54:27 +100011#include <lang_id/lang-id-wrapper.h>
Honglin Yu568fc9a2020-06-05 11:57:21 +100012#include <utils/utf8/unicodetext.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110013
14#include "ml/mojom/text_classifier.mojom.h"
15#include "ml/request_metrics.h"
16
17namespace ml {
18
19namespace {
20
21using ::chromeos::machine_learning::mojom::CodepointSpan;
Honglin Yuc5100022020-07-09 11:54:27 +100022using ::chromeos::machine_learning::mojom::FindLanguagesResult;
23using ::chromeos::machine_learning::mojom::TextLanguage;
24using ::chromeos::machine_learning::mojom::TextLanguagePtr;
Honglin Yuf33dce32019-12-05 15:10:39 +110025using ::chromeos::machine_learning::mojom::SuggestSelectionResult;
26using ::chromeos::machine_learning::mojom::TextAnnotation;
27using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
28using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
29using ::chromeos::machine_learning::mojom::TextAnnotationResult;
30using ::chromeos::machine_learning::mojom::TextClassifierRequest;
31using ::chromeos::machine_learning::mojom::TextEntity;
32using ::chromeos::machine_learning::mojom::TextEntityData;
33using ::chromeos::machine_learning::mojom::TextEntityPtr;
34using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
35
36// To avoid passing a lambda as a base::Closure.
37void DeleteTextClassifierImpl(
38 const TextClassifierImpl* const text_classifier_impl) {
39 delete text_classifier_impl;
40}
41
42} // namespace
43
44bool TextClassifierImpl::Create(
Honglin Yuc5100022020-07-09 11:54:27 +100045 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
46 const std::string& langid_model_path,
Honglin Yuf33dce32019-12-05 15:10:39 +110047 TextClassifierRequest request) {
Honglin Yuc5100022020-07-09 11:54:27 +100048 auto text_classifier_impl = new TextClassifierImpl(
49 annotator_model_mmap, langid_model_path, std::move(request));
50 if (text_classifier_impl->annotator_ == nullptr ||
51 text_classifier_impl->language_identifier_ == nullptr) {
Honglin Yuf33dce32019-12-05 15:10:39 +110052 // Fails to create annotator, return nullptr.
53 delete text_classifier_impl;
54 return false;
55 }
56
Andrew Moylan79b34a42020-07-08 11:13:11 +100057 // Use a connection error handler to strongly bind `text_classifier_impl` to
58 // `request`.
Honglin Yuf33dce32019-12-05 15:10:39 +110059 text_classifier_impl->SetConnectionErrorHandler(base::Bind(
60 &DeleteTextClassifierImpl, base::Unretained(text_classifier_impl)));
61 return true;
62}
63
64TextClassifierImpl::TextClassifierImpl(
Honglin Yuc5100022020-07-09 11:54:27 +100065 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
66 const std::string& langid_model_path,
Honglin Yuf33dce32019-12-05 15:10:39 +110067 TextClassifierRequest request)
68 : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
Honglin Yuc5100022020-07-09 11:54:27 +100069 annotator_model_mmap, nullptr, nullptr)),
70 language_identifier_(
71 libtextclassifier3::langid::LoadFromPath(langid_model_path)),
Honglin Yuf33dce32019-12-05 15:10:39 +110072 binding_(this, std::move(request)) {}
73
74void TextClassifierImpl::SetConnectionErrorHandler(
75 base::Closure connection_error_handler) {
76 binding_.set_connection_error_handler(std::move(connection_error_handler));
77}
78
79void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
80 AnnotateCallback callback) {
81 RequestMetrics<TextAnnotationResult> request_metrics("TextClassifier",
82 "Annotate");
83 request_metrics.StartRecordingPerformanceMetrics();
84
85 // Parse and set up the options.
86 libtextclassifier3::AnnotationOptions option;
87 if (request->default_locales) {
88 option.locales = request->default_locales.value();
89 }
90 if (request->reference_time) {
91 option.reference_time_ms_utc =
92 request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
93 }
94 if (request->reference_timezone) {
95 option.reference_timezone = request->reference_timezone.value();
96 }
97 if (request->enabled_entities) {
98 option.entity_types.insert(request->enabled_entities.value().begin(),
99 request->enabled_entities.value().end());
100 }
101 option.detected_text_language_tags =
102 request->detected_text_language_tags.value_or("en");
103 option.annotation_usecase =
104 static_cast<libtextclassifier3::AnnotationUsecase>(
105 request->annotation_usecase);
106
107 // Do the annotation.
108 const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
109 annotator_->Annotate(request->text, option);
110
111 // Parse the result.
112 std::vector<TextAnnotationPtr> annotations;
113 for (const auto& annotated_result : annotated_spans) {
114 DCHECK(annotated_result.span.second >= annotated_result.span.first);
115 std::vector<TextEntityPtr> entities;
116 for (const auto& classification : annotated_result.classification) {
117 // First, get entity data.
118 auto entity_data = TextEntityData::New();
119 if (classification.collection == "number") {
120 entity_data->set_numeric_value(classification.numeric_double_value);
121 } else {
122 // For the other types, just encode the substring into string_value.
123 // TODO(honglinyu): add data extraction for more types when needed
124 // and available.
Honglin Yu568fc9a2020-06-05 11:57:21 +1000125 // Note that the returned indices by annotator is unicode codepoints.
126 entity_data->set_string_value(
127 libtextclassifier3::UTF8ToUnicodeText(request->text, false)
128 .UTF8Substring(annotated_result.span.first,
129 annotated_result.span.second));
Honglin Yuf33dce32019-12-05 15:10:39 +1100130 }
131
132 // Second, create the entity.
133 entities.emplace_back(TextEntity::New(classification.collection,
134 classification.score,
135 std::move(entity_data)));
136 }
137 annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
138 annotated_result.span.second,
139 std::move(entities)));
140 }
141
142 std::move(callback).Run(std::move(annotations));
143
144 request_metrics.FinishRecordingPerformanceMetrics();
145 request_metrics.RecordRequestEvent(TextAnnotationResult::OK);
146}
147
148void TextClassifierImpl::SuggestSelection(
149 TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
150 RequestMetrics<SuggestSelectionResult> request_metrics("TextClassifier",
151 "SuggestSelection");
152 request_metrics.StartRecordingPerformanceMetrics();
153
154 libtextclassifier3::SelectionOptions option;
155 if (request->default_locales) {
156 option.locales = request->default_locales.value();
157 }
158 option.detected_text_language_tags =
159 request->detected_text_language_tags.value_or("en");
160 option.annotation_usecase =
161 static_cast<libtextclassifier3::AnnotationUsecase>(
162 request->annotation_usecase);
163
164 libtextclassifier3::CodepointSpan user_selection;
165 user_selection.first = request->user_selection->start_offset;
166 user_selection.second = request->user_selection->end_offset;
167
168 const libtextclassifier3::CodepointSpan suggested_span =
169 annotator_->SuggestSelection(request->text, user_selection, option);
170 auto result_span = CodepointSpan::New();
171 result_span->start_offset = suggested_span.first;
172 result_span->end_offset = suggested_span.second;
173
174 std::move(callback).Run(std::move(result_span));
175
176 request_metrics.FinishRecordingPerformanceMetrics();
177 request_metrics.RecordRequestEvent(SuggestSelectionResult::OK);
178}
179
Honglin Yuc5100022020-07-09 11:54:27 +1000180void TextClassifierImpl::FindLanguages(const std::string& text,
181 FindLanguagesCallback callback) {
182 RequestMetrics<FindLanguagesResult> request_metrics("TextClassifier",
183 "FindLanguages");
184 request_metrics.StartRecordingPerformanceMetrics();
185
186 const std::vector<std::pair<std::string, float>> languages =
187 libtextclassifier3::langid::GetPredictions(language_identifier_.get(),
188 text);
189
190 std::vector<TextLanguagePtr> langid_result;
191 for (const auto& lang : languages) {
192 langid_result.emplace_back(TextLanguage::New(lang.first, lang.second));
193 }
194
195 std::move(callback).Run(std::move(langid_result));
196
197 request_metrics.FinishRecordingPerformanceMetrics();
198 request_metrics.RecordRequestEvent(FindLanguagesResult::OK);
199}
200
Honglin Yuf33dce32019-12-05 15:10:39 +1100201} // namespace ml