// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ml/text_classifier_impl.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <base/logging.h>
#include <lang_id/lang-id-wrapper.h>
#include <utils/utf8/unicodetext.h>

#include "ml/mojom/text_classifier.mojom.h"
#include "ml/request_metrics.h"

17namespace ml {
18
19namespace {
20
21using ::chromeos::machine_learning::mojom::CodepointSpan;
Honglin Yuf33dce32019-12-05 15:10:39 +110022using ::chromeos::machine_learning::mojom::TextAnnotation;
23using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
24using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
Andrew Moylanb481af72020-07-09 15:22:00 +100025using ::chromeos::machine_learning::mojom::TextClassifier;
Honglin Yuf33dce32019-12-05 15:10:39 +110026using ::chromeos::machine_learning::mojom::TextEntity;
27using ::chromeos::machine_learning::mojom::TextEntityData;
28using ::chromeos::machine_learning::mojom::TextEntityPtr;
Honglin Yu273b3422020-07-13 12:20:36 +100029using ::chromeos::machine_learning::mojom::TextLanguage;
30using ::chromeos::machine_learning::mojom::TextLanguagePtr;
Honglin Yuf33dce32019-12-05 15:10:39 +110031using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
32
Honglin Yu3f99ff12020-10-15 00:40:11 +110033constexpr char kTextClassifierModelFilePath[] =
Honglin Yu30c65612020-10-23 10:29:01 +110034 "/opt/google/chrome/ml_models/"
35 "mlservice-model-text_classifier_en-v711_vocab-v1.fb";
Honglin Yu3f99ff12020-10-15 00:40:11 +110036
37constexpr char kLanguageIdentificationModelFilePath[] =
38 "/opt/google/chrome/ml_models/"
39 "mlservice-model-language_identification-20190924.smfb";
40
Honglin Yuf33dce32019-12-05 15:10:39 +110041// To avoid passing a lambda as a base::Closure.
42void DeleteTextClassifierImpl(
43 const TextClassifierImpl* const text_classifier_impl) {
44 delete text_classifier_impl;
45}
46
47} // namespace
48
49bool TextClassifierImpl::Create(
Andrew Moylanb481af72020-07-09 15:22:00 +100050 mojo::PendingReceiver<TextClassifier> receiver) {
Honglin Yu3f99ff12020-10-15 00:40:11 +110051 // Attempt to load model.
52 auto annotator_model_mmap = std::make_unique<libtextclassifier3::ScopedMmap>(
53 kTextClassifierModelFilePath);
54 if (!annotator_model_mmap->handle().ok()) {
55 LOG(ERROR) << "Failed to load the text classifier model file.";
56 return false;
57 }
58
Honglin Yuc5100022020-07-09 11:54:27 +100059 auto text_classifier_impl = new TextClassifierImpl(
Honglin Yu3f99ff12020-10-15 00:40:11 +110060 &annotator_model_mmap, kLanguageIdentificationModelFilePath,
61 std::move(receiver));
Honglin Yuc5100022020-07-09 11:54:27 +100062 if (text_classifier_impl->annotator_ == nullptr ||
63 text_classifier_impl->language_identifier_ == nullptr) {
Honglin Yuf33dce32019-12-05 15:10:39 +110064 // Fails to create annotator, return nullptr.
65 delete text_classifier_impl;
66 return false;
67 }
68
Andrew Moylanb481af72020-07-09 15:22:00 +100069 // Use a disconnection handler to strongly bind `text_classifier_impl` to
70 // `receiver`.
71 text_classifier_impl->SetDisconnectionHandler(base::Bind(
Honglin Yuf33dce32019-12-05 15:10:39 +110072 &DeleteTextClassifierImpl, base::Unretained(text_classifier_impl)));
73 return true;
74}
75
76TextClassifierImpl::TextClassifierImpl(
Honglin Yuc5100022020-07-09 11:54:27 +100077 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
78 const std::string& langid_model_path,
Andrew Moylanb481af72020-07-09 15:22:00 +100079 mojo::PendingReceiver<TextClassifier> receiver)
Honglin Yuf33dce32019-12-05 15:10:39 +110080 : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
Honglin Yuc5100022020-07-09 11:54:27 +100081 annotator_model_mmap, nullptr, nullptr)),
82 language_identifier_(
83 libtextclassifier3::langid::LoadFromPath(langid_model_path)),
Andrew Moylanb481af72020-07-09 15:22:00 +100084 receiver_(this, std::move(receiver)) {}
Honglin Yuf33dce32019-12-05 15:10:39 +110085
Andrew Moylanb481af72020-07-09 15:22:00 +100086void TextClassifierImpl::SetDisconnectionHandler(
87 base::Closure disconnect_handler) {
88 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Honglin Yuf33dce32019-12-05 15:10:39 +110089}
90
91void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
92 AnnotateCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +100093 RequestMetrics request_metrics("TextClassifier", "Annotate");
Honglin Yuf33dce32019-12-05 15:10:39 +110094 request_metrics.StartRecordingPerformanceMetrics();
95
96 // Parse and set up the options.
97 libtextclassifier3::AnnotationOptions option;
98 if (request->default_locales) {
99 option.locales = request->default_locales.value();
100 }
101 if (request->reference_time) {
102 option.reference_time_ms_utc =
103 request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
104 }
105 if (request->reference_timezone) {
106 option.reference_timezone = request->reference_timezone.value();
107 }
108 if (request->enabled_entities) {
109 option.entity_types.insert(request->enabled_entities.value().begin(),
110 request->enabled_entities.value().end());
111 }
112 option.detected_text_language_tags =
113 request->detected_text_language_tags.value_or("en");
114 option.annotation_usecase =
115 static_cast<libtextclassifier3::AnnotationUsecase>(
116 request->annotation_usecase);
117
118 // Do the annotation.
119 const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
120 annotator_->Annotate(request->text, option);
121
122 // Parse the result.
123 std::vector<TextAnnotationPtr> annotations;
124 for (const auto& annotated_result : annotated_spans) {
125 DCHECK(annotated_result.span.second >= annotated_result.span.first);
126 std::vector<TextEntityPtr> entities;
127 for (const auto& classification : annotated_result.classification) {
128 // First, get entity data.
129 auto entity_data = TextEntityData::New();
130 if (classification.collection == "number") {
131 entity_data->set_numeric_value(classification.numeric_double_value);
132 } else {
133 // For the other types, just encode the substring into string_value.
134 // TODO(honglinyu): add data extraction for more types when needed
135 // and available.
Honglin Yu568fc9a2020-06-05 11:57:21 +1000136 // Note that the returned indices by annotator is unicode codepoints.
137 entity_data->set_string_value(
138 libtextclassifier3::UTF8ToUnicodeText(request->text, false)
139 .UTF8Substring(annotated_result.span.first,
140 annotated_result.span.second));
Honglin Yuf33dce32019-12-05 15:10:39 +1100141 }
142
143 // Second, create the entity.
144 entities.emplace_back(TextEntity::New(classification.collection,
145 classification.score,
146 std::move(entity_data)));
147 }
148 annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
149 annotated_result.span.second,
150 std::move(entities)));
151 }
152
153 std::move(callback).Run(std::move(annotations));
154
155 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100156}
157
158void TextClassifierImpl::SuggestSelection(
159 TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000160 RequestMetrics request_metrics("TextClassifier", "SuggestSelection");
Honglin Yuf33dce32019-12-05 15:10:39 +1100161 request_metrics.StartRecordingPerformanceMetrics();
162
163 libtextclassifier3::SelectionOptions option;
164 if (request->default_locales) {
165 option.locales = request->default_locales.value();
166 }
167 option.detected_text_language_tags =
168 request->detected_text_language_tags.value_or("en");
169 option.annotation_usecase =
170 static_cast<libtextclassifier3::AnnotationUsecase>(
171 request->annotation_usecase);
172
173 libtextclassifier3::CodepointSpan user_selection;
174 user_selection.first = request->user_selection->start_offset;
175 user_selection.second = request->user_selection->end_offset;
176
177 const libtextclassifier3::CodepointSpan suggested_span =
178 annotator_->SuggestSelection(request->text, user_selection, option);
179 auto result_span = CodepointSpan::New();
180 result_span->start_offset = suggested_span.first;
181 result_span->end_offset = suggested_span.second;
182
183 std::move(callback).Run(std::move(result_span));
184
185 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100186}
187
Honglin Yuc5100022020-07-09 11:54:27 +1000188void TextClassifierImpl::FindLanguages(const std::string& text,
189 FindLanguagesCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000190 RequestMetrics request_metrics("TextClassifier", "FindLanguages");
Honglin Yuc5100022020-07-09 11:54:27 +1000191 request_metrics.StartRecordingPerformanceMetrics();
192
193 const std::vector<std::pair<std::string, float>> languages =
194 libtextclassifier3::langid::GetPredictions(language_identifier_.get(),
195 text);
196
197 std::vector<TextLanguagePtr> langid_result;
198 for (const auto& lang : languages) {
199 langid_result.emplace_back(TextLanguage::New(lang.first, lang.second));
200 }
201
202 std::move(callback).Run(std::move(langid_result));
203
204 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuc5100022020-07-09 11:54:27 +1000205}
206
Honglin Yuf33dce32019-12-05 15:10:39 +1100207} // namespace ml