blob: 090932f6bd625da9c8a1c53581a49774a56003d6 [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/text_classifier_impl.h"
6
7#include <utility>
8#include <vector>
9
10#include <base/logging.h>
Honglin Yuc5100022020-07-09 11:54:27 +100011#include <lang_id/lang-id-wrapper.h>
Honglin Yu568fc9a2020-06-05 11:57:21 +100012#include <utils/utf8/unicodetext.h>
Honglin Yuf33dce32019-12-05 15:10:39 +110013
14#include "ml/mojom/text_classifier.mojom.h"
15#include "ml/request_metrics.h"
16
17namespace ml {
18
19namespace {
20
21using ::chromeos::machine_learning::mojom::CodepointSpan;
Honglin Yuf33dce32019-12-05 15:10:39 +110022using ::chromeos::machine_learning::mojom::TextAnnotation;
23using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
24using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
Andrew Moylanb481af72020-07-09 15:22:00 +100025using ::chromeos::machine_learning::mojom::TextClassifier;
Honglin Yuf33dce32019-12-05 15:10:39 +110026using ::chromeos::machine_learning::mojom::TextEntity;
27using ::chromeos::machine_learning::mojom::TextEntityData;
28using ::chromeos::machine_learning::mojom::TextEntityPtr;
Honglin Yu273b3422020-07-13 12:20:36 +100029using ::chromeos::machine_learning::mojom::TextLanguage;
30using ::chromeos::machine_learning::mojom::TextLanguagePtr;
Honglin Yuf33dce32019-12-05 15:10:39 +110031using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
32
Honglin Yu3f99ff12020-10-15 00:40:11 +110033constexpr char kTextClassifierModelFilePath[] =
34 "/opt/google/chrome/ml_models/mlservice-model-text_classifier_en-v711.fb";
35
36constexpr char kLanguageIdentificationModelFilePath[] =
37 "/opt/google/chrome/ml_models/"
38 "mlservice-model-language_identification-20190924.smfb";
39
Honglin Yuf33dce32019-12-05 15:10:39 +110040// To avoid passing a lambda as a base::Closure.
41void DeleteTextClassifierImpl(
42 const TextClassifierImpl* const text_classifier_impl) {
43 delete text_classifier_impl;
44}
45
46} // namespace
47
48bool TextClassifierImpl::Create(
Andrew Moylanb481af72020-07-09 15:22:00 +100049 mojo::PendingReceiver<TextClassifier> receiver) {
Honglin Yu3f99ff12020-10-15 00:40:11 +110050 // Attempt to load model.
51 auto annotator_model_mmap = std::make_unique<libtextclassifier3::ScopedMmap>(
52 kTextClassifierModelFilePath);
53 if (!annotator_model_mmap->handle().ok()) {
54 LOG(ERROR) << "Failed to load the text classifier model file.";
55 return false;
56 }
57
Honglin Yuc5100022020-07-09 11:54:27 +100058 auto text_classifier_impl = new TextClassifierImpl(
Honglin Yu3f99ff12020-10-15 00:40:11 +110059 &annotator_model_mmap, kLanguageIdentificationModelFilePath,
60 std::move(receiver));
Honglin Yuc5100022020-07-09 11:54:27 +100061 if (text_classifier_impl->annotator_ == nullptr ||
62 text_classifier_impl->language_identifier_ == nullptr) {
Honglin Yuf33dce32019-12-05 15:10:39 +110063 // Fails to create annotator, return nullptr.
64 delete text_classifier_impl;
65 return false;
66 }
67
Andrew Moylanb481af72020-07-09 15:22:00 +100068 // Use a disconnection handler to strongly bind `text_classifier_impl` to
69 // `receiver`.
70 text_classifier_impl->SetDisconnectionHandler(base::Bind(
Honglin Yuf33dce32019-12-05 15:10:39 +110071 &DeleteTextClassifierImpl, base::Unretained(text_classifier_impl)));
72 return true;
73}
74
75TextClassifierImpl::TextClassifierImpl(
Honglin Yuc5100022020-07-09 11:54:27 +100076 std::unique_ptr<libtextclassifier3::ScopedMmap>* annotator_model_mmap,
77 const std::string& langid_model_path,
Andrew Moylanb481af72020-07-09 15:22:00 +100078 mojo::PendingReceiver<TextClassifier> receiver)
Honglin Yuf33dce32019-12-05 15:10:39 +110079 : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
Honglin Yuc5100022020-07-09 11:54:27 +100080 annotator_model_mmap, nullptr, nullptr)),
81 language_identifier_(
82 libtextclassifier3::langid::LoadFromPath(langid_model_path)),
Andrew Moylanb481af72020-07-09 15:22:00 +100083 receiver_(this, std::move(receiver)) {}
Honglin Yuf33dce32019-12-05 15:10:39 +110084
Andrew Moylanb481af72020-07-09 15:22:00 +100085void TextClassifierImpl::SetDisconnectionHandler(
86 base::Closure disconnect_handler) {
87 receiver_.set_disconnect_handler(std::move(disconnect_handler));
Honglin Yuf33dce32019-12-05 15:10:39 +110088}
89
90void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
91 AnnotateCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +100092 RequestMetrics request_metrics("TextClassifier", "Annotate");
Honglin Yuf33dce32019-12-05 15:10:39 +110093 request_metrics.StartRecordingPerformanceMetrics();
94
95 // Parse and set up the options.
96 libtextclassifier3::AnnotationOptions option;
97 if (request->default_locales) {
98 option.locales = request->default_locales.value();
99 }
100 if (request->reference_time) {
101 option.reference_time_ms_utc =
102 request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
103 }
104 if (request->reference_timezone) {
105 option.reference_timezone = request->reference_timezone.value();
106 }
107 if (request->enabled_entities) {
108 option.entity_types.insert(request->enabled_entities.value().begin(),
109 request->enabled_entities.value().end());
110 }
111 option.detected_text_language_tags =
112 request->detected_text_language_tags.value_or("en");
113 option.annotation_usecase =
114 static_cast<libtextclassifier3::AnnotationUsecase>(
115 request->annotation_usecase);
116
117 // Do the annotation.
118 const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
119 annotator_->Annotate(request->text, option);
120
121 // Parse the result.
122 std::vector<TextAnnotationPtr> annotations;
123 for (const auto& annotated_result : annotated_spans) {
124 DCHECK(annotated_result.span.second >= annotated_result.span.first);
125 std::vector<TextEntityPtr> entities;
126 for (const auto& classification : annotated_result.classification) {
127 // First, get entity data.
128 auto entity_data = TextEntityData::New();
129 if (classification.collection == "number") {
130 entity_data->set_numeric_value(classification.numeric_double_value);
131 } else {
132 // For the other types, just encode the substring into string_value.
133 // TODO(honglinyu): add data extraction for more types when needed
134 // and available.
Honglin Yu568fc9a2020-06-05 11:57:21 +1000135 // Note that the returned indices by annotator is unicode codepoints.
136 entity_data->set_string_value(
137 libtextclassifier3::UTF8ToUnicodeText(request->text, false)
138 .UTF8Substring(annotated_result.span.first,
139 annotated_result.span.second));
Honglin Yuf33dce32019-12-05 15:10:39 +1100140 }
141
142 // Second, create the entity.
143 entities.emplace_back(TextEntity::New(classification.collection,
144 classification.score,
145 std::move(entity_data)));
146 }
147 annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
148 annotated_result.span.second,
149 std::move(entities)));
150 }
151
152 std::move(callback).Run(std::move(annotations));
153
154 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100155}
156
157void TextClassifierImpl::SuggestSelection(
158 TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000159 RequestMetrics request_metrics("TextClassifier", "SuggestSelection");
Honglin Yuf33dce32019-12-05 15:10:39 +1100160 request_metrics.StartRecordingPerformanceMetrics();
161
162 libtextclassifier3::SelectionOptions option;
163 if (request->default_locales) {
164 option.locales = request->default_locales.value();
165 }
166 option.detected_text_language_tags =
167 request->detected_text_language_tags.value_or("en");
168 option.annotation_usecase =
169 static_cast<libtextclassifier3::AnnotationUsecase>(
170 request->annotation_usecase);
171
172 libtextclassifier3::CodepointSpan user_selection;
173 user_selection.first = request->user_selection->start_offset;
174 user_selection.second = request->user_selection->end_offset;
175
176 const libtextclassifier3::CodepointSpan suggested_span =
177 annotator_->SuggestSelection(request->text, user_selection, option);
178 auto result_span = CodepointSpan::New();
179 result_span->start_offset = suggested_span.first;
180 result_span->end_offset = suggested_span.second;
181
182 std::move(callback).Run(std::move(result_span));
183
184 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuf33dce32019-12-05 15:10:39 +1100185}
186
Honglin Yuc5100022020-07-09 11:54:27 +1000187void TextClassifierImpl::FindLanguages(const std::string& text,
188 FindLanguagesCallback callback) {
charleszhao5a7050e2020-07-14 15:21:41 +1000189 RequestMetrics request_metrics("TextClassifier", "FindLanguages");
Honglin Yuc5100022020-07-09 11:54:27 +1000190 request_metrics.StartRecordingPerformanceMetrics();
191
192 const std::vector<std::pair<std::string, float>> languages =
193 libtextclassifier3::langid::GetPredictions(language_identifier_.get(),
194 text);
195
196 std::vector<TextLanguagePtr> langid_result;
197 for (const auto& lang : languages) {
198 langid_result.emplace_back(TextLanguage::New(lang.first, lang.second));
199 }
200
201 std::move(callback).Run(std::move(langid_result));
202
203 request_metrics.FinishRecordingPerformanceMetrics();
Honglin Yuc5100022020-07-09 11:54:27 +1000204}
205
Honglin Yuf33dce32019-12-05 15:10:39 +1100206} // namespace ml