blob: d595050cc2b93656df4985b29681beb8fd60f7be [file] [log] [blame]
Honglin Yuf33dce32019-12-05 15:10:39 +11001// Copyright 2020 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/text_classifier_impl.h"
6
7#include <utility>
8#include <vector>
9
10#include <base/logging.h>
Honglin Yuc5100022020-07-09 11:54:27 +100011#include <lang_id/lang-id-wrapper.h>
Honglin Yu568fc9a2020-06-05 11:57:21 +100012#include <utils/utf8/unicodetext.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110013
14#include "ml/mojom/text_classifier.mojom.h"
15#include "ml/request_metrics.h"
16
17namespace ml {
18
19namespace {
20
21using ::chromeos::machine_learning::mojom::CodepointSpan;
Honglin Yuf33dce32019-12-05 15:10:39 +110022using ::chromeos::machine_learning::mojom::TextAnnotation;
23using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
24using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
Andrew Moylanb481af72020-07-09 15:22:00 +100025using ::chromeos::machine_learning::mojom::TextClassifier;
Honglin Yuf33dce32019-12-05 15:10:39 +110026using ::chromeos::machine_learning::mojom::TextEntity;
27using ::chromeos::machine_learning::mojom::TextEntityData;
28using ::chromeos::machine_learning::mojom::TextEntityPtr;
Honglin Yu273b3422020-07-13 12:20:36 +100029using ::chromeos::machine_learning::mojom::TextLanguage;
30using ::chromeos::machine_learning::mojom::TextLanguagePtr;
Honglin Yuf33dce32019-12-05 15:10:39 +110031using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
32
33// To avoid passing a lambda as a base::Closure.
34void DeleteTextClassifierImpl(
35 const TextClassifierImpl* const text_classifier_impl) {
36 delete text_classifier_impl;
37}
38
39} // namespace
40
41bool TextClassifierImpl::Create(
Honglin Yuc5100022020-07-09 11:54:27 +100042 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
43 const std::string& langid_model_path,
Andrew Moylanb481af72020-07-09 15:22:00 +100044 mojo::PendingReceiver<TextClassifier> receiver) {
Honglin Yuc5100022020-07-09 11:54:27 +100045 auto text_classifier_impl = new TextClassifierImpl(
Andrew Moylanb481af72020-07-09 15:22:00 +100046 annotator_model_mmap, langid_model_path, std::move(receiver));
Honglin Yuc5100022020-07-09 11:54:27 +100047 if (text_classifier_impl->annotator_ == nullptr ||
48 text_classifier_impl->language_identifier_ == nullptr) {
Honglin Yuf33dce32019-12-05 15:10:39 +110049 // Fails to create annotator, return nullptr.
50 delete text_classifier_impl;
51 return false;
52 }
53
Andrew Moylanb481af72020-07-09 15:22:00 +100054 // Use a disconnection handler to strongly bind `text_classifier_impl` to
55 // `receiver`.
56 text_classifier_impl->SetDisconnectionHandler(base::Bind(
Honglin Yuf33dce32019-12-05 15:10:39 +110057 &DeleteTextClassifierImpl, base::Unretained(text_classifier_impl)));
58 return true;
59}
60
61TextClassifierImpl::TextClassifierImpl(
Honglin Yuc5100022020-07-09 11:54:27 +100062 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
63 const std::string& langid_model_path,
Andrew Moylanb481af72020-07-09 15:22:00 +100064 mojo::PendingReceiver<TextClassifier> receiver)
Honglin Yuf33dce32019-12-05 15:10:39 +110065 : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
Honglin Yuc5100022020-07-09 11:54:27 +100066 annotator_model_mmap, nullptr, nullptr)),
67 language_identifier_(
68 libtextclassifier3::langid::LoadFromPath(langid_model_path)),
Andrew Moylanb481af72020-07-09 15:22:00 +100069 receiver_(this, std::move(receiver)) {}
Honglin Yuf33dce32019-12-05 15:10:39 +110070
Andrew Moylanb481af72020-07-09 15:22:00 +100071void TextClassifierImpl::SetDisconnectionHandler(
72 base::Closure disconnect_handler) {
73 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Honglin Yuf33dce32019-12-05 15:10:39 +110074}
75
76void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
77 AnnotateCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +100078 RequestMetrics request_metrics("TextClassifier", "Annotate");
Honglin Yuf33dce32019-12-05 15:10:39 +110079 request_metrics.StartRecordingPerformanceMetrics();
80
81 // Parse and set up the options.
82 libtextclassifier3::AnnotationOptions option;
83 if (request->default_locales) {
84 option.locales = request->default_locales.value();
85 }
86 if (request->reference_time) {
87 option.reference_time_ms_utc =
88 request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
89 }
90 if (request->reference_timezone) {
91 option.reference_timezone = request->reference_timezone.value();
92 }
93 if (request->enabled_entities) {
94 option.entity_types.insert(request->enabled_entities.value().begin(),
95 request->enabled_entities.value().end());
96 }
97 option.detected_text_language_tags =
98 request->detected_text_language_tags.value_or("en");
99 option.annotation_usecase =
100 static_cast<libtextclassifier3::AnnotationUsecase>(
101 request->annotation_usecase);
102
103 // Do the annotation.
104 const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
105 annotator_->Annotate(request->text, option);
106
107 // Parse the result.
108 std::vector<TextAnnotationPtr> annotations;
109 for (const auto& annotated_result : annotated_spans) {
110 DCHECK(annotated_result.span.second >= annotated_result.span.first);
111 std::vector<TextEntityPtr> entities;
112 for (const auto& classification : annotated_result.classification) {
113 // First, get entity data.
114 auto entity_data = TextEntityData::New();
115 if (classification.collection == "number") {
116 entity_data->set_numeric_value(classification.numeric_double_value);
117 } else {
118 // For the other types, just encode the substring into string_value.
119 // TODO(honglinyu): add data extraction for more types when needed
120 // and available.
Honglin Yu568fc9a2020-06-05 11:57:21 +1000121 // Note that the returned indices by annotator is unicode codepoints.
122 entity_data->set_string_value(
123 libtextclassifier3::UTF8ToUnicodeText(request->text, false)
124 .UTF8Substring(annotated_result.span.first,
125 annotated_result.span.second));
Honglin Yuf33dce32019-12-05 15:10:39 +1100126 }
127
128 // Second, create the entity.
129 entities.emplace_back(TextEntity::New(classification.collection,
130 classification.score,
131 std::move(entity_data)));
132 }
133 annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
134 annotated_result.span.second,
135 std::move(entities)));
136 }
137
138 std::move(callback).Run(std::move(annotations));
139
140 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100141}
142
143void TextClassifierImpl::SuggestSelection(
144 TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000145 RequestMetrics request_metrics("TextClassifier", "SuggestSelection");
Honglin Yuf33dce32019-12-05 15:10:39 +1100146 request_metrics.StartRecordingPerformanceMetrics();
147
148 libtextclassifier3::SelectionOptions option;
149 if (request->default_locales) {
150 option.locales = request->default_locales.value();
151 }
152 option.detected_text_language_tags =
153 request->detected_text_language_tags.value_or("en");
154 option.annotation_usecase =
155 static_cast<libtextclassifier3::AnnotationUsecase>(
156 request->annotation_usecase);
157
158 libtextclassifier3::CodepointSpan user_selection;
159 user_selection.first = request->user_selection->start_offset;
160 user_selection.second = request->user_selection->end_offset;
161
162 const libtextclassifier3::CodepointSpan suggested_span =
163 annotator_->SuggestSelection(request->text, user_selection, option);
164 auto result_span = CodepointSpan::New();
165 result_span->start_offset = suggested_span.first;
166 result_span->end_offset = suggested_span.second;
167
168 std::move(callback).Run(std::move(result_span));
169
170 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100171}
172
Honglin Yuc5100022020-07-09 11:54:27 +1000173void TextClassifierImpl::FindLanguages(const std::string& text,
174 FindLanguagesCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000175 RequestMetrics request_metrics("TextClassifier", "FindLanguages");
Honglin Yuc5100022020-07-09 11:54:27 +1000176 request_metrics.StartRecordingPerformanceMetrics();
177
178 const std::vector<std::pair<std::string, float>> languages =
179 libtextclassifier3::langid::GetPredictions(language_identifier_.get(),
180 text);
181
182 std::vector<TextLanguagePtr> langid_result;
183 for (const auto& lang : languages) {
184 langid_result.emplace_back(TextLanguage::New(lang.first, lang.second));
185 }
186
187 std::move(callback).Run(std::move(langid_result));
188
189 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuc5100022020-07-09 11:54:27 +1000190}
191
Honglin Yuf33dce32019-12-05 15:10:39 +1100192} // namespace ml