blob: 6d8a62b82c493fd442773e11c6d20cf1a62ffe02 [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/text_classifier_impl.h"
6
#include <string>
#include <utility>
#include <vector>
9
10#include <base/logging.h>
11
12#include "ml/mojom/text_classifier.mojom.h"
13#include "ml/request_metrics.h"
14
15namespace ml {
16
17namespace {
18
19using ::chromeos::machine_learning::mojom::CodepointSpan;
20using ::chromeos::machine_learning::mojom::SuggestSelectionResult;
21using ::chromeos::machine_learning::mojom::TextAnnotation;
22using ::chromeos::machine_learning::mojom::TextAnnotationPtr;
23using ::chromeos::machine_learning::mojom::TextAnnotationRequestPtr;
24using ::chromeos::machine_learning::mojom::TextAnnotationResult;
25using ::chromeos::machine_learning::mojom::TextClassifierRequest;
26using ::chromeos::machine_learning::mojom::TextEntity;
27using ::chromeos::machine_learning::mojom::TextEntityData;
28using ::chromeos::machine_learning::mojom::TextEntityPtr;
29using ::chromeos::machine_learning::mojom::TextSuggestSelectionRequestPtr;
30
31// To avoid passing a lambda as a base::Closure.
32void DeleteTextClassifierImpl(
33 const TextClassifierImpl* const text_classifier_impl) {
34 delete text_classifier_impl;
35}
36
37} // namespace
38
39bool TextClassifierImpl::Create(
40 std::unique_ptr<libtextclassifier3::ScopedMmap>* mmap,
41 TextClassifierRequest request) {
42 auto text_classifier_impl = new TextClassifierImpl(mmap, std::move(request));
43 if (text_classifier_impl->annotator_ == nullptr) {
44 // Fails to create annotator, return nullptr.
45 delete text_classifier_impl;
46 return false;
47 }
48
49 // Use a connection error handler to strongly bind |text_classifier_impl| to
50 // |request|.
51 text_classifier_impl->SetConnectionErrorHandler(base::Bind(
52 &DeleteTextClassifierImpl, base::Unretained(text_classifier_impl)));
53 return true;
54}
55
56TextClassifierImpl::TextClassifierImpl(
57 std::unique_ptr<libtextclassifier3::ScopedMmap>* mmap,
58 TextClassifierRequest request)
59 : annotator_(libtextclassifier3::Annotator::FromScopedMmap(
60 mmap, nullptr, nullptr)),
61 binding_(this, std::move(request)) {}
62
63void TextClassifierImpl::SetConnectionErrorHandler(
64 base::Closure connection_error_handler) {
65 binding_.set_connection_error_handler(std::move(connection_error_handler));
66}
67
68void TextClassifierImpl::Annotate(TextAnnotationRequestPtr request,
69 AnnotateCallback callback) {
70 RequestMetrics<TextAnnotationResult> request_metrics("TextClassifier",
71 "Annotate");
72 request_metrics.StartRecordingPerformanceMetrics();
73
74 // Parse and set up the options.
75 libtextclassifier3::AnnotationOptions option;
76 if (request->default_locales) {
77 option.locales = request->default_locales.value();
78 }
79 if (request->reference_time) {
80 option.reference_time_ms_utc =
81 request->reference_time->ToTimeT() * base::Time::kMillisecondsPerSecond;
82 }
83 if (request->reference_timezone) {
84 option.reference_timezone = request->reference_timezone.value();
85 }
86 if (request->enabled_entities) {
87 option.entity_types.insert(request->enabled_entities.value().begin(),
88 request->enabled_entities.value().end());
89 }
90 option.detected_text_language_tags =
91 request->detected_text_language_tags.value_or("en");
92 option.annotation_usecase =
93 static_cast<libtextclassifier3::AnnotationUsecase>(
94 request->annotation_usecase);
95
96 // Do the annotation.
97 const std::vector<libtextclassifier3::AnnotatedSpan> annotated_spans =
98 annotator_->Annotate(request->text, option);
99
100 // Parse the result.
101 std::vector<TextAnnotationPtr> annotations;
102 for (const auto& annotated_result : annotated_spans) {
103 DCHECK(annotated_result.span.second >= annotated_result.span.first);
104 std::vector<TextEntityPtr> entities;
105 for (const auto& classification : annotated_result.classification) {
106 // First, get entity data.
107 auto entity_data = TextEntityData::New();
108 if (classification.collection == "number") {
109 entity_data->set_numeric_value(classification.numeric_double_value);
110 } else {
111 // For the other types, just encode the substring into string_value.
112 // TODO(honglinyu): add data extraction for more types when needed
113 // and available.
114 entity_data->set_string_value(request->text.substr(
115 annotated_result.span.first,
116 annotated_result.span.second - annotated_result.span.first));
117 }
118
119 // Second, create the entity.
120 entities.emplace_back(TextEntity::New(classification.collection,
121 classification.score,
122 std::move(entity_data)));
123 }
124 annotations.emplace_back(TextAnnotation::New(annotated_result.span.first,
125 annotated_result.span.second,
126 std::move(entities)));
127 }
128
129 std::move(callback).Run(std::move(annotations));
130
131 request_metrics.FinishRecordingPerformanceMetrics();
132 request_metrics.RecordRequestEvent(TextAnnotationResult::OK);
133}
134
135void TextClassifierImpl::SuggestSelection(
136 TextSuggestSelectionRequestPtr request, SuggestSelectionCallback callback) {
137 RequestMetrics<SuggestSelectionResult> request_metrics("TextClassifier",
138 "SuggestSelection");
139 request_metrics.StartRecordingPerformanceMetrics();
140
141 libtextclassifier3::SelectionOptions option;
142 if (request->default_locales) {
143 option.locales = request->default_locales.value();
144 }
145 option.detected_text_language_tags =
146 request->detected_text_language_tags.value_or("en");
147 option.annotation_usecase =
148 static_cast<libtextclassifier3::AnnotationUsecase>(
149 request->annotation_usecase);
150
151 libtextclassifier3::CodepointSpan user_selection;
152 user_selection.first = request->user_selection->start_offset;
153 user_selection.second = request->user_selection->end_offset;
154
155 const libtextclassifier3::CodepointSpan suggested_span =
156 annotator_->SuggestSelection(request->text, user_selection, option);
157 auto result_span = CodepointSpan::New();
158 result_span->start_offset = suggested_span.first;
159 result_span->end_offset = suggested_span.second;
160
161 std::move(callback).Run(std::move(result_span));
162
163 request_metrics.FinishRecordingPerformanceMetrics();
164 request_metrics.RecordRequestEvent(SuggestSelectionResult::OK);
165}
166
167} // namespace ml