blob: ec073af6f64e76b06f2d820490b6acc1fc22d1de [file] [log] [blame]
Honglin Yuf33dce32019-12-05 15:10:39 +11001// Copyright 2020 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/text_classifier_impl.h"
6
7#include <utility>
8#include <vector>
9
10#include <base/logging.h>
Honglin Yuc5100022020-07-09 11:54:27 +100011#include <lang_id/lang-id-wrapper.h>
Honglin Yu568fc9a2020-06-05 11:57:21 +100012#include <utils/utf8/unicodetext.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110013
14#include "ml/mojom/text_classifier.mojom.h"
15#include "ml/request_metrics.h"
16
17namespace ml {
18
19namespace {
20
21using ::chromeos::machine_learning::mojom::CodepointSpan;
Honglin Yuc5100022020-07-09 11:54:27 +100022using ::chromeos::machine_learning::mojom::FindLanguagesResult;
23using ::chromeos::machine_learning::mojom::TextLanguage;
24using ::chromeos::machine_learning::mojom::TextLanguagePtr;
Honglin Yuf33dce32019-12-05 15:10:39 +110025using ::chromeos::machine_learning::mojom::SuggestSelectionResult;
26using ::chromeos::machine_learning::mojom::TextAnnotation;
27using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
28using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
29using ::chromeos::machine_learning::mojom::TextAnnotationResult;
Andrew Moylanb481af72020-07-09 15:22:00 +100030using ::chromeos::machine_learning::mojom::TextClassifier;
Honglin Yuf33dce32019-12-05 15:10:39 +110031using ::chromeos::machine_learning::mojom::TextEntity;
32using ::chromeos::machine_learning::mojom::TextEntityData;
33using ::chromeos::machine_learning::mojom::TextEntityPtr;
34using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
35
36// To avoid passing a lambda as a base::Closure.
37void DeleteTextClassifierImpl(
38 const TextClassifierImpl* const text_classifier_impl) {
39 delete text_classifier_impl;
40}
41
42} // namespace
43
44bool TextClassifierImpl::Create(
Honglin Yuc5100022020-07-09 11:54:27 +100045 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
46 const std::string& langid_model_path,
Andrew Moylanb481af72020-07-09 15:22:00 +100047 mojo::PendingReceiver<TextClassifier> receiver) {
Honglin Yuc5100022020-07-09 11:54:27 +100048 auto text_classifier_impl = new TextClassifierImpl(
Andrew Moylanb481af72020-07-09 15:22:00 +100049 annotator_model_mmap, langid_model_path, std::move(receiver));
Honglin Yuc5100022020-07-09 11:54:27 +100050 if (text_classifier_impl->annotator_ == nullptr ||
51 text_classifier_impl->language_identifier_ == nullptr) {
Honglin Yuf33dce32019-12-05 15:10:39 +110052 // Fails to create annotator, return nullptr.
53 delete text_classifier_impl;
54 return false;
55 }
56
Andrew Moylanb481af72020-07-09 15:22:00 +100057 // Use a disconnection handler to strongly bind `text_classifier_impl` to
58 // `receiver`.
59 text_classifier_impl->SetDisconnectionHandler(base::Bind(
Honglin Yuf33dce32019-12-05 15:10:39 +110060 &DeleteTextClassifierImpl, base::Unretained(text_classifier_impl)));
61 return true;
62}
63
64TextClassifierImpl::TextClassifierImpl(
Honglin Yuc5100022020-07-09 11:54:27 +100065 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
66 const std::string& langid_model_path,
Andrew Moylanb481af72020-07-09 15:22:00 +100067 mojo::PendingReceiver<TextClassifier> receiver)
Honglin Yuf33dce32019-12-05 15:10:39 +110068 : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
Honglin Yuc5100022020-07-09 11:54:27 +100069 annotator_model_mmap, nullptr, nullptr)),
70 language_identifier_(
71 libtextclassifier3::langid::LoadFromPath(langid_model_path)),
Andrew Moylanb481af72020-07-09 15:22:00 +100072 receiver_(this, std::move(receiver)) {}
Honglin Yuf33dce32019-12-05 15:10:39 +110073
Andrew Moylanb481af72020-07-09 15:22:00 +100074void TextClassifierImpl::SetDisconnectionHandler(
75 base::Closure disconnect_handler) {
76 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Honglin Yuf33dce32019-12-05 15:10:39 +110077}
78
79void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
80 AnnotateCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +100081 RequestMetrics request_metrics("TextClassifier", "Annotate");
Honglin Yuf33dce32019-12-05 15:10:39 +110082 request_metrics.StartRecordingPerformanceMetrics();
83
84 // Parse and set up the options.
85 libtextclassifier3::AnnotationOptions option;
86 if (request->default_locales) {
87 option.locales = request->default_locales.value();
88 }
89 if (request->reference_time) {
90 option.reference_time_ms_utc =
91 request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
92 }
93 if (request->reference_timezone) {
94 option.reference_timezone = request->reference_timezone.value();
95 }
96 if (request->enabled_entities) {
97 option.entity_types.insert(request->enabled_entities.value().begin(),
98 request->enabled_entities.value().end());
99 }
100 option.detected_text_language_tags =
101 request->detected_text_language_tags.value_or("en");
102 option.annotation_usecase =
103 static_cast<libtextclassifier3::AnnotationUsecase>(
104 request->annotation_usecase);
105
106 // Do the annotation.
107 const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
108 annotator_->Annotate(request->text, option);
109
110 // Parse the result.
111 std::vector<TextAnnotationPtr> annotations;
112 for (const auto& annotated_result : annotated_spans) {
113 DCHECK(annotated_result.span.second >= annotated_result.span.first);
114 std::vector<TextEntityPtr> entities;
115 for (const auto& classification : annotated_result.classification) {
116 // First, get entity data.
117 auto entity_data = TextEntityData::New();
118 if (classification.collection == "number") {
119 entity_data->set_numeric_value(classification.numeric_double_value);
120 } else {
121 // For the other types, just encode the substring into string_value.
122 // TODO(honglinyu): add data extraction for more types when needed
123 // and available.
Honglin Yu568fc9a2020-06-05 11:57:21 +1000124 // Note that the returned indices by annotator is unicode codepoints.
125 entity_data->set_string_value(
126 libtextclassifier3::UTF8ToUnicodeText(request->text, false)
127 .UTF8Substring(annotated_result.span.first,
128 annotated_result.span.second));
Honglin Yuf33dce32019-12-05 15:10:39 +1100129 }
130
131 // Second, create the entity.
132 entities.emplace_back(TextEntity::New(classification.collection,
133 classification.score,
134 std::move(entity_data)));
135 }
136 annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
137 annotated_result.span.second,
138 std::move(entities)));
139 }
140
141 std::move(callback).Run(std::move(annotations));
142
143 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100144}
145
146void TextClassifierImpl::SuggestSelection(
147 TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000148 RequestMetrics request_metrics("TextClassifier", "SuggestSelection");
Honglin Yuf33dce32019-12-05 15:10:39 +1100149 request_metrics.StartRecordingPerformanceMetrics();
150
151 libtextclassifier3::SelectionOptions option;
152 if (request->default_locales) {
153 option.locales = request->default_locales.value();
154 }
155 option.detected_text_language_tags =
156 request->detected_text_language_tags.value_or("en");
157 option.annotation_usecase =
158 static_cast<libtextclassifier3::AnnotationUsecase>(
159 request->annotation_usecase);
160
161 libtextclassifier3::CodepointSpan user_selection;
162 user_selection.first = request->user_selection->start_offset;
163 user_selection.second = request->user_selection->end_offset;
164
165 const libtextclassifier3::CodepointSpan suggested_span =
166 annotator_->SuggestSelection(request->text, user_selection, option);
167 auto result_span = CodepointSpan::New();
168 result_span->start_offset = suggested_span.first;
169 result_span->end_offset = suggested_span.second;
170
171 std::move(callback).Run(std::move(result_span));
172
173 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100174}
175
Honglin Yuc5100022020-07-09 11:54:27 +1000176void TextClassifierImpl::FindLanguages(const std::string& text,
177 FindLanguagesCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000178 RequestMetrics request_metrics("TextClassifier", "FindLanguages");
Honglin Yuc5100022020-07-09 11:54:27 +1000179 request_metrics.StartRecordingPerformanceMetrics();
180
181 const std::vector<std::pair<std::string, float>> languages =
182 libtextclassifier3::langid::GetPredictions(language_identifier_.get(),
183 text);
184
185 std::vector<TextLanguagePtr> langid_result;
186 for (const auto& lang : languages) {
187 langid_result.emplace_back(TextLanguage::New(lang.first, lang.second));
188 }
189
190 std::move(callback).Run(std::move(langid_result));
191
192 request_metrics.FinishRecordingPerformanceMetrics();
193 request_metrics.RecordRequestEvent(FindLanguagesResult::OK);
194}
195
Honglin Yuf33dce32019-12-05 15:10:39 +1100196} // namespace ml