blob: 3c9ef4b0957c28b1f0aaa7b588ce7e34d98e4531 [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/text_classifier_impl.h"
6
7#include <utility>
8#include <vector>
9
10#include <base/logging.h>
Honglin Yuc5100022020-07-09 11:54:27 +100011#include <lang_id/lang-id-wrapper.h>
Honglin Yu568fc9a2020-06-05 11:57:21 +100012#include <utils/utf8/unicodetext.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110013
14#include "ml/mojom/text_classifier.mojom.h"
15#include "ml/request_metrics.h"
16
17namespace ml {
18
19namespace {
20
21using ::chromeos::machine_learning::mojom::CodepointSpan;
Honglin Yuf33dce32019-12-05 15:10:39 +110022using ::chromeos::machine_learning::mojom::TextAnnotation;
23using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
24using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
Andrew Moylanb481af72020-07-09 15:22:00 +100025using ::chromeos::machine_learning::mojom::TextClassifier;
Honglin Yuf33dce32019-12-05 15:10:39 +110026using ::chromeos::machine_learning::mojom::TextEntity;
27using ::chromeos::machine_learning::mojom::TextEntityData;
28using ::chromeos::machine_learning::mojom::TextEntityPtr;
Honglin Yu273b3422020-07-13 12:20:36 +100029using ::chromeos::machine_learning::mojom::TextLanguage;
30using ::chromeos::machine_learning::mojom::TextLanguagePtr;
Honglin Yuf33dce32019-12-05 15:10:39 +110031using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
32
// Path of the text classifier model file that TextClassifierImpl::Create()
// maps into memory to build the annotator.
constexpr char kTextClassifierModelFilePath[] =
    "/opt/google/chrome/ml_models/"
    "mlservice-model-text_classifier_en-v711_vocab-v1.fb";
Honglin Yu3f99ff12020-10-15 00:40:11 +110036
// Path of the language identification model file handed to the
// TextClassifierImpl constructor.
constexpr char kLanguageIdentificationModelFilePath[] =
    "/opt/google/chrome/ml_models/"
    "mlservice-model-language_identification-20190924.smfb";
40
Honglin Yuf33dce32019-12-05 15:10:39 +110041// To avoid passing a lambda as a base::Closure.
42void DeleteTextClassifierImpl(
43 const TextClassifierImpl* const text_classifier_impl) {
44 delete text_classifier_impl;
45}
46
47} // namespace
48
49bool TextClassifierImpl::Create(
Andrew Moylanb481af72020-07-09 15:22:00 +100050 mojo::PendingReceiver<TextClassifier> receiver) {
Honglin Yu3f99ff12020-10-15 00:40:11 +110051 // Attempt to load model.
52 auto annotator_model_mmap = std::make_unique<libtextclassifier3::ScopedMmap>(
53 kTextClassifierModelFilePath);
54 if (!annotator_model_mmap->handle().ok()) {
55 LOG(ERROR) << "Failed to load the text classifier model file.";
56 return false;
57 }
58
Honglin Yuc5100022020-07-09 11:54:27 +100059 auto text_classifier_impl = new TextClassifierImpl(
Honglin Yu3f99ff12020-10-15 00:40:11 +110060 &annotator_model_mmap, kLanguageIdentificationModelFilePath,
61 std::move(receiver));
Honglin Yuc5100022020-07-09 11:54:27 +100062 if (text_classifier_impl->annotator_ == nullptr ||
63 text_classifier_impl->language_identifier_ == nullptr) {
Honglin Yuf33dce32019-12-05 15:10:39 +110064 // Fails to create annotator, return nullptr.
65 delete text_classifier_impl;
66 return false;
67 }
68
Andrew Moylanb481af72020-07-09 15:22:00 +100069 // Use a disconnection handler to strongly bind `text_classifier_impl` to
70 // `receiver`.
71 text_classifier_impl->SetDisconnectionHandler(base::Bind(
Honglin Yuf33dce32019-12-05 15:10:39 +110072 &DeleteTextClassifierImpl, base::Unretained(text_classifier_impl)));
73 return true;
74}
75
76TextClassifierImpl::TextClassifierImpl(
Honglin Yuc5100022020-07-09 11:54:27 +100077 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
78 const std::string& langid_model_path,
Andrew Moylanb481af72020-07-09 15:22:00 +100079 mojo::PendingReceiver<TextClassifier> receiver)
Honglin Yuf33dce32019-12-05 15:10:39 +110080 : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
Honglin Yuc5100022020-07-09 11:54:27 +100081 annotator_model_mmap, nullptr, nullptr)),
82 language_identifier_(
83 libtextclassifier3::langid::LoadFromPath(langid_model_path)),
Andrew Moylanb481af72020-07-09 15:22:00 +100084 receiver_(this, std::move(receiver)) {}
Honglin Yuf33dce32019-12-05 15:10:39 +110085
Andrew Moylanb481af72020-07-09 15:22:00 +100086void TextClassifierImpl::SetDisconnectionHandler(
87 base::Closure disconnect_handler) {
88 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Honglin Yuf33dce32019-12-05 15:10:39 +110089}
90
91void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
92 AnnotateCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +100093 RequestMetrics request_metrics("TextClassifier", "Annotate");
Honglin Yuf33dce32019-12-05 15:10:39 +110094 request_metrics.StartRecordingPerformanceMetrics();
95
96 // Parse and set up the options.
97 libtextclassifier3::AnnotationOptions option;
98 if (request->default_locales) {
99 option.locales = request->default_locales.value();
100 }
101 if (request->reference_time) {
102 option.reference_time_ms_utc =
103 request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
104 }
105 if (request->reference_timezone) {
106 option.reference_timezone = request->reference_timezone.value();
107 }
108 if (request->enabled_entities) {
109 option.entity_types.insert(request->enabled_entities.value().begin(),
110 request->enabled_entities.value().end());
111 }
112 option.detected_text_language_tags =
113 request->detected_text_language_tags.value_or("en");
114 option.annotation_usecase =
115 static_cast<libtextclassifier3::AnnotationUsecase>(
116 request->annotation_usecase);
117
Honglin Yu6f9185b2020-11-06 10:59:23 +1100118 // Uses the vocab based model.
119 option.use_vocab_annotator = true;
120
Honglin Yuf33dce32019-12-05 15:10:39 +1100121 // Do the annotation.
122 const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
123 annotator_->Annotate(request->text, option);
124
125 // Parse the result.
126 std::vector<TextAnnotationPtr> annotations;
127 for (const auto& annotated_result : annotated_spans) {
128 DCHECK(annotated_result.span.second >= annotated_result.span.first);
129 std::vector<TextEntityPtr> entities;
130 for (const auto& classification : annotated_result.classification) {
131 // First, get entity data.
132 auto entity_data = TextEntityData::New();
133 if (classification.collection == "number") {
134 entity_data->set_numeric_value(classification.numeric_double_value);
135 } else {
136 // For the other types, just encode the substring into string_value.
137 // TODO(honglinyu): add data extraction for more types when needed
138 // and available.
Honglin Yu568fc9a2020-06-05 11:57:21 +1000139 // Note that the returned indices by annotator is unicode codepoints.
140 entity_data->set_string_value(
141 libtextclassifier3::UTF8ToUnicodeText(request->text, false)
142 .UTF8Substring(annotated_result.span.first,
143 annotated_result.span.second));
Honglin Yuf33dce32019-12-05 15:10:39 +1100144 }
145
146 // Second, create the entity.
147 entities.emplace_back(TextEntity::New(classification.collection,
148 classification.score,
149 std::move(entity_data)));
150 }
151 annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
152 annotated_result.span.second,
153 std::move(entities)));
154 }
155
156 std::move(callback).Run(std::move(annotations));
157
158 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100159}
160
161void TextClassifierImpl::SuggestSelection(
162 TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000163 RequestMetrics request_metrics("TextClassifier", "SuggestSelection");
Honglin Yuf33dce32019-12-05 15:10:39 +1100164 request_metrics.StartRecordingPerformanceMetrics();
165
166 libtextclassifier3::SelectionOptions option;
167 if (request->default_locales) {
168 option.locales = request->default_locales.value();
169 }
170 option.detected_text_language_tags =
171 request->detected_text_language_tags.value_or("en");
172 option.annotation_usecase =
173 static_cast<libtextclassifier3::AnnotationUsecase>(
174 request->annotation_usecase);
175
176 libtextclassifier3::CodepointSpan user_selection;
177 user_selection.first = request->user_selection->start_offset;
178 user_selection.second = request->user_selection->end_offset;
179
180 const libtextclassifier3::CodepointSpan suggested_span =
181 annotator_->SuggestSelection(request->text, user_selection, option);
182 auto result_span = CodepointSpan::New();
183 result_span->start_offset = suggested_span.first;
184 result_span->end_offset = suggested_span.second;
185
186 std::move(callback).Run(std::move(result_span));
187
188 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100189}
190
Honglin Yuc5100022020-07-09 11:54:27 +1000191void TextClassifierImpl::FindLanguages(const std::string& text,
192 FindLanguagesCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000193 RequestMetrics request_metrics("TextClassifier", "FindLanguages");
Honglin Yuc5100022020-07-09 11:54:27 +1000194 request_metrics.StartRecordingPerformanceMetrics();
195
196 const std::vector<std::pair<std::string, float>> languages =
197 libtextclassifier3::langid::GetPredictions(language_identifier_.get(),
198 text);
199
200 std::vector<TextLanguagePtr> langid_result;
201 for (const auto& lang : languages) {
202 langid_result.emplace_back(TextLanguage::New(lang.first, lang.second));
203 }
204
205 std::move(callback).Run(std::move(langid_result));
206
207 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuc5100022020-07-09 11:54:27 +1000208}
209
Honglin Yuf33dce32019-12-05 15:10:39 +1100210} // namespace ml