blob: 6fa87aef091eeb01255ed72d83daa9dc4a7aed38 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100013#include <tensorflow/lite/model.h>
Michael Martisa74af932018-08-13 16:52:36 +100014
15#include "ml/model_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090016#include "ml/mojom/model.mojom.h"
Michael Martisa74af932018-08-13 16:52:36 +100017
Andrew Moylanff6be512018-07-03 11:05:01 +100018namespace ml {
19
Michael Martisa74af932018-08-13 16:52:36 +100020namespace {
21
Honglin Yu0ed72352019-08-27 17:42:01 +100022using ::chromeos::machine_learning::mojom::BuiltinModelId;
23using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
24using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr;
Michael Martisa74af932018-08-13 16:52:36 +100025using ::chromeos::machine_learning::mojom::LoadModelResult;
Michael Martisa74af932018-08-13 16:52:36 +100026using ::chromeos::machine_learning::mojom::ModelRequest;
Michael Martisa74af932018-08-13 16:52:36 +100027
28constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
Honglin Yua81145a2019-09-23 15:20:13 +100029// Base name for UMA metrics related to model loading (either |LoadBuiltinModel|
30// or |LoadFlatBufferModel|) requests
Honglin Yu6adafcd2019-07-22 13:48:11 +100031constexpr char kMetricsRequestName[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100032
33// To avoid passing a lambda as a base::Closure.
34void DeleteModelImpl(const ModelImpl* const model_impl) {
35 delete model_impl;
36}
37
38} // namespace
39
Andrew Moylanff6be512018-07-03 11:05:01 +100040MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100041 mojo::ScopedMessagePipeHandle pipe,
42 base::Closure connection_error_handler,
43 const std::string& model_dir)
Honglin Yua81145a2019-09-23 15:20:13 +100044 : builtin_model_metadata_(GetBuiltinModelMetadata()),
Michael Martisa74af932018-08-13 16:52:36 +100045 model_dir_(model_dir),
hscham68867652020-01-06 11:40:47 +090046 binding_(this,
47 mojo::InterfaceRequest<
48 chromeos::machine_learning::mojom::MachineLearningService>(
49 std::move(pipe))) {
Andrew Moylanff6be512018-07-03 11:05:01 +100050 binding_.set_connection_error_handler(std::move(connection_error_handler));
51}
52
Michael Martisa74af932018-08-13 16:52:36 +100053MachineLearningServiceImpl::MachineLearningServiceImpl(
54 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
55 : MachineLearningServiceImpl(std::move(pipe),
56 std::move(connection_error_handler),
57 kSystemModelDir) {}
58
Honglin Yu0ed72352019-08-27 17:42:01 +100059void MachineLearningServiceImpl::LoadBuiltinModel(
60 BuiltinModelSpecPtr spec,
61 ModelRequest request,
62 const LoadBuiltinModelCallback& callback) {
63 // Unsupported models do not have metadata entries.
64 const auto metadata_lookup = builtin_model_metadata_.find(spec->id);
65 if (metadata_lookup == builtin_model_metadata_.end()) {
Honglin Yua81145a2019-09-23 15:20:13 +100066 LOG(WARNING) << "LoadBuiltinModel requested for unsupported model ID "
67 << spec->id << ".";
Honglin Yu0ed72352019-08-27 17:42:01 +100068 callback.Run(LoadModelResult::MODEL_SPEC_ERROR);
69 RecordModelSpecificationErrorEvent();
70 return;
71 }
72
73 const BuiltinModelMetadata& metadata = metadata_lookup->second;
74
75 DCHECK(!metadata.metrics_model_name.empty());
76
77 RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
78 kMetricsRequestName);
79 request_metrics.StartRecordingPerformanceMetrics();
80
81 // Attempt to load model.
82 const std::string model_path = model_dir_ + metadata.model_file;
83 std::unique_ptr<tflite::FlatBufferModel> model =
84 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
85 if (model == nullptr) {
86 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
87 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
88 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
89 return;
90 }
91
92 // Use a connection error handler to strongly bind |model_impl| to |request|.
93 ModelImpl* const model_impl = new ModelImpl(
94 metadata.required_inputs, metadata.required_outputs, std::move(model),
95 std::move(request), metadata.metrics_model_name);
96
97 model_impl->set_connection_error_handler(
98 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
99 callback.Run(LoadModelResult::OK);
100
101 request_metrics.FinishRecordingPerformanceMetrics();
102 request_metrics.RecordRequestEvent(LoadModelResult::OK);
103}
104
105void MachineLearningServiceImpl::LoadFlatBufferModel(
106 FlatBufferModelSpecPtr spec,
107 ModelRequest request,
108 const LoadFlatBufferModelCallback& callback) {
109 DCHECK(!spec->metrics_model_name.empty());
110
111 RequestMetrics<LoadModelResult> request_metrics(spec->metrics_model_name,
112 kMetricsRequestName);
113 request_metrics.StartRecordingPerformanceMetrics();
114
115 // Take the ownership of the content of |model_string| because |ModelImpl| has
116 // to hold the memory.
117 auto model_string_impl =
118 std::make_unique<std::string>(std::move(spec->model_string));
119
120 std::unique_ptr<tflite::FlatBufferModel> model =
121 tflite::FlatBufferModel::BuildFromBuffer(model_string_impl->c_str(),
122 model_string_impl->length());
123 if (model == nullptr) {
124 LOG(ERROR) << "Failed to load model string of metric name: "
125 << spec->metrics_model_name << "'.";
126 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
127 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
128 return;
129 }
130
131 // Use a connection error handler to strongly bind |model_impl| to |request|.
132 ModelImpl* model_impl = new ModelImpl(
133 std::map<std::string, int>(spec->inputs.begin(), spec->inputs.end()),
134 std::map<std::string, int>(spec->outputs.begin(), spec->outputs.end()),
135 std::move(model), std::move(model_string_impl), std::move(request),
136 spec->metrics_model_name);
137
138 model_impl->set_connection_error_handler(
139 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
140 callback.Run(LoadModelResult::OK);
141
142 request_metrics.FinishRecordingPerformanceMetrics();
143 request_metrics.RecordRequestEvent(LoadModelResult::OK);
144}
145
Andrew Moylanff6be512018-07-03 11:05:01 +1000146} // namespace ml