blob: 46f93957c98deb13f7cbbfb52263826bba72fc82 [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/text_classifier_impl.h"
6
7#include <utility>
8#include <vector>
9
Qijiang Fan713061e2021-03-08 15:45:12 +090010#include <base/check.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110011#include <base/logging.h>
Honglin Yuc5100022020-07-09 11:54:27 +100012#include <lang_id/lang-id-wrapper.h>
Honglin Yu568fc9a2020-06-05 11:57:21 +100013#include <utils/utf8/unicodetext.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110014
15#include "ml/mojom/text_classifier.mojom.h"
16#include "ml/request_metrics.h"
17
18namespace ml {
19
20namespace {
21
22using ::chromeos::machine_learning::mojom::CodepointSpan;
Honglin Yuf33dce32019-12-05 15:10:39 +110023using ::chromeos::machine_learning::mojom::TextAnnotation;
24using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
25using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
Andrew Moylanb481af72020-07-09 15:22:00 +100026using ::chromeos::machine_learning::mojom::TextClassifier;
Honglin Yuf33dce32019-12-05 15:10:39 +110027using ::chromeos::machine_learning::mojom::TextEntity;
28using ::chromeos::machine_learning::mojom::TextEntityData;
29using ::chromeos::machine_learning::mojom::TextEntityPtr;
Honglin Yu273b3422020-07-13 12:20:36 +100030using ::chromeos::machine_learning::mojom::TextLanguage;
31using ::chromeos::machine_learning::mojom::TextLanguagePtr;
Honglin Yuf33dce32019-12-05 15:10:39 +110032using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
33
Honglin Yu3f99ff12020-10-15 00:40:11 +110034constexpr char kTextClassifierModelFilePath[] =
Honglin Yu30c65612020-10-23 10:29:01 +110035 "/opt/google/chrome/ml_models/"
36 "mlservice-model-text_classifier_en-v711_vocab-v1.fb";
Honglin Yu3f99ff12020-10-15 00:40:11 +110037
38constexpr char kLanguageIdentificationModelFilePath[] =
39 "/opt/google/chrome/ml_models/"
40 "mlservice-model-language_identification-20190924.smfb";
41
Honglin Yuf33dce32019-12-05 15:10:39 +110042// To avoid passing a lambda as a base::Closure.
43void DeleteTextClassifierImpl(
44 const TextClassifierImpl* const text_classifier_impl) {
45 delete text_classifier_impl;
46}
47
48} // namespace
49
50bool TextClassifierImpl::Create(
Andrew Moylanb481af72020-07-09 15:22:00 +100051 mojo::PendingReceiver<TextClassifier> receiver) {
Honglin Yu3f99ff12020-10-15 00:40:11 +110052 // Attempt to load model.
53 auto annotator_model_mmap = std::make_unique<libtextclassifier3::ScopedMmap>(
54 kTextClassifierModelFilePath);
55 if (!annotator_model_mmap->handle().ok()) {
56 LOG(ERROR) << "Failed to load the text classifier model file.";
57 return false;
58 }
59
Honglin Yuc5100022020-07-09 11:54:27 +100060 auto text_classifier_impl = new TextClassifierImpl(
Honglin Yu3f99ff12020-10-15 00:40:11 +110061 &annotator_model_mmap, kLanguageIdentificationModelFilePath,
62 std::move(receiver));
Honglin Yuc5100022020-07-09 11:54:27 +100063 if (text_classifier_impl->annotator_ == nullptr ||
64 text_classifier_impl->language_identifier_ == nullptr) {
Honglin Yuf33dce32019-12-05 15:10:39 +110065 // Fails to create annotator, return nullptr.
66 delete text_classifier_impl;
67 return false;
68 }
69
Andrew Moylanb481af72020-07-09 15:22:00 +100070 // Use a disconnection handler to strongly bind `text_classifier_impl` to
71 // `receiver`.
72 text_classifier_impl->SetDisconnectionHandler(base::Bind(
Honglin Yuf33dce32019-12-05 15:10:39 +110073 &DeleteTextClassifierImpl, base::Unretained(text_classifier_impl)));
74 return true;
75}
76
77TextClassifierImpl::TextClassifierImpl(
Honglin Yuc5100022020-07-09 11:54:27 +100078 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
79 const std::string& langid_model_path,
Andrew Moylanb481af72020-07-09 15:22:00 +100080 mojo::PendingReceiver<TextClassifier> receiver)
Honglin Yuf33dce32019-12-05 15:10:39 +110081 : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
Honglin Yuc5100022020-07-09 11:54:27 +100082 annotator_model_mmap, nullptr, nullptr)),
83 language_identifier_(
84 libtextclassifier3::langid::LoadFromPath(langid_model_path)),
Andrew Moylanb481af72020-07-09 15:22:00 +100085 receiver_(this, std::move(receiver)) {}
Honglin Yuf33dce32019-12-05 15:10:39 +110086
Andrew Moylanb481af72020-07-09 15:22:00 +100087void TextClassifierImpl::SetDisconnectionHandler(
88 base::Closure disconnect_handler) {
89 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Honglin Yuf33dce32019-12-05 15:10:39 +110090}
91
92void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
93 AnnotateCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +100094 RequestMetrics request_metrics("TextClassifier", "Annotate");
Honglin Yuf33dce32019-12-05 15:10:39 +110095 request_metrics.StartRecordingPerformanceMetrics();
96
97 // Parse and set up the options.
98 libtextclassifier3::AnnotationOptions option;
99 if (request->default_locales) {
100 option.locales = request->default_locales.value();
101 }
102 if (request->reference_time) {
103 option.reference_time_ms_utc =
104 request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
105 }
106 if (request->reference_timezone) {
107 option.reference_timezone = request->reference_timezone.value();
108 }
109 if (request->enabled_entities) {
110 option.entity_types.insert(request->enabled_entities.value().begin(),
111 request->enabled_entities.value().end());
112 }
113 option.detected_text_language_tags =
114 request->detected_text_language_tags.value_or("en");
115 option.annotation_usecase =
116 static_cast<libtextclassifier3::AnnotationUsecase>(
117 request->annotation_usecase);
118
Honglin Yu6f9185b2020-11-06 10:59:23 +1100119 // Uses the vocab based model.
120 option.use_vocab_annotator = true;
121
Honglin Yuf33dce32019-12-05 15:10:39 +1100122 // Do the annotation.
123 const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
124 annotator_->Annotate(request->text, option);
125
126 // Parse the result.
127 std::vector<TextAnnotationPtr> annotations;
128 for (const auto& annotated_result : annotated_spans) {
129 DCHECK(annotated_result.span.second >= annotated_result.span.first);
130 std::vector<TextEntityPtr> entities;
131 for (const auto& classification : annotated_result.classification) {
132 // First, get entity data.
133 auto entity_data = TextEntityData::New();
134 if (classification.collection == "number") {
135 entity_data->set_numeric_value(classification.numeric_double_value);
136 } else {
137 // For the other types, just encode the substring into string_value.
138 // TODO(honglinyu): add data extraction for more types when needed
139 // and available.
Honglin Yu568fc9a2020-06-05 11:57:21 +1000140 // Note that the returned indices by annotator is unicode codepoints.
141 entity_data->set_string_value(
142 libtextclassifier3::UTF8ToUnicodeText(request->text, false)
143 .UTF8Substring(annotated_result.span.first,
144 annotated_result.span.second));
Honglin Yuf33dce32019-12-05 15:10:39 +1100145 }
146
147 // Second, create the entity.
148 entities.emplace_back(TextEntity::New(classification.collection,
149 classification.score,
150 std::move(entity_data)));
151 }
152 annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
153 annotated_result.span.second,
154 std::move(entities)));
155 }
156
157 std::move(callback).Run(std::move(annotations));
158
159 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100160}
161
162void TextClassifierImpl::SuggestSelection(
163 TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000164 RequestMetrics request_metrics("TextClassifier", "SuggestSelection");
Honglin Yuf33dce32019-12-05 15:10:39 +1100165 request_metrics.StartRecordingPerformanceMetrics();
166
167 libtextclassifier3::SelectionOptions option;
168 if (request->default_locales) {
169 option.locales = request->default_locales.value();
170 }
171 option.detected_text_language_tags =
172 request->detected_text_language_tags.value_or("en");
173 option.annotation_usecase =
174 static_cast<libtextclassifier3::AnnotationUsecase>(
175 request->annotation_usecase);
176
177 libtextclassifier3::CodepointSpan user_selection;
178 user_selection.first = request->user_selection->start_offset;
179 user_selection.second = request->user_selection->end_offset;
180
181 const libtextclassifier3::CodepointSpan suggested_span =
182 annotator_->SuggestSelection(request->text, user_selection, option);
183 auto result_span = CodepointSpan::New();
184 result_span->start_offset = suggested_span.first;
185 result_span->end_offset = suggested_span.second;
186
187 std::move(callback).Run(std::move(result_span));
188
189 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100190}
191
Honglin Yuc5100022020-07-09 11:54:27 +1000192void TextClassifierImpl::FindLanguages(const std::string& text,
193 FindLanguagesCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000194 RequestMetrics request_metrics("TextClassifier", "FindLanguages");
Honglin Yuc5100022020-07-09 11:54:27 +1000195 request_metrics.StartRecordingPerformanceMetrics();
196
197 const std::vector<std::pair<std::string, float>> languages =
198 libtextclassifier3::langid::GetPredictions(language_identifier_.get(),
199 text);
200
201 std::vector<TextLanguagePtr> langid_result;
202 for (const auto& lang : languages) {
203 langid_result.emplace_back(TextLanguage::New(lang.first, lang.second));
204 }
205
206 std::move(callback).Run(std::move(langid_result));
207
208 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuc5100022020-07-09 11:54:27 +1000209}
210
Honglin Yuf33dce32019-12-05 15:10:39 +1100211} // namespace ml