blob: de71d72e467fdab2609c569e4df58d43bff7a713 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100013#include <tensorflow/lite/model.h>
Michael Martisa74af932018-08-13 16:52:36 +100014
15#include "ml/model_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090016#include "ml/mojom/model.mojom.h"
Michael Martisa74af932018-08-13 16:52:36 +100017
Andrew Moylanff6be512018-07-03 11:05:01 +100018namespace ml {
19
namespace {

using ::chromeos::machine_learning::mojom::BuiltinModelId;
using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr;
using ::chromeos::machine_learning::mojom::LoadModelResult;
using ::chromeos::machine_learning::mojom::ModelRequest;

// Default directory holding builtin model files; passed as |model_dir| by the
// two-argument MachineLearningServiceImpl constructor. The trailing slash is
// required: builtin model paths are formed by plain string concatenation of
// the directory and the model file name.
constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
// Base name for UMA metrics related to model loading (either
// |LoadBuiltinModel| or |LoadFlatBufferModel|) requests.
constexpr char kMetricsRequestName[] = "LoadModelResult";

}  // namespace
34
// Binds this service implementation to the Mojo message pipe |pipe| as a
// MachineLearningService interface and installs |connection_error_handler|
// as the binding's connection error handler.
//
// |model_dir| is the directory prefix for builtin model files; it is joined
// to each model's file name by plain concatenation, so it should end in '/'.
// The builtin model metadata table is captured once, at construction, from
// GetBuiltinModelMetadata().
MachineLearningServiceImpl::MachineLearningServiceImpl(
    mojo::ScopedMessagePipeHandle pipe,
    base::Closure connection_error_handler,
    const std::string& model_dir)
    : builtin_model_metadata_(GetBuiltinModelMetadata()),
      model_dir_(model_dir),
      binding_(this,
               mojo::InterfaceRequest<
                   chromeos::machine_learning::mojom::MachineLearningService>(
                   std::move(pipe))) {
  binding_.set_connection_error_handler(std::move(connection_error_handler));
}
47
// Same as the three-argument constructor, but loads builtin models from the
// default system directory |kSystemModelDir|.
MachineLearningServiceImpl::MachineLearningServiceImpl(
    mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
    : MachineLearningServiceImpl(std::move(pipe),
                                 std::move(connection_error_handler),
                                 kSystemModelDir) {}
53
Honglin Yu0ed72352019-08-27 17:42:01 +100054void MachineLearningServiceImpl::LoadBuiltinModel(
55 BuiltinModelSpecPtr spec,
56 ModelRequest request,
57 const LoadBuiltinModelCallback& callback) {
58 // Unsupported models do not have metadata entries.
59 const auto metadata_lookup = builtin_model_metadata_.find(spec->id);
60 if (metadata_lookup == builtin_model_metadata_.end()) {
Honglin Yua81145a2019-09-23 15:20:13 +100061 LOG(WARNING) << "LoadBuiltinModel requested for unsupported model ID "
62 << spec->id << ".";
Honglin Yu0ed72352019-08-27 17:42:01 +100063 callback.Run(LoadModelResult::MODEL_SPEC_ERROR);
64 RecordModelSpecificationErrorEvent();
65 return;
66 }
67
68 const BuiltinModelMetadata& metadata = metadata_lookup->second;
69
70 DCHECK(!metadata.metrics_model_name.empty());
71
72 RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
73 kMetricsRequestName);
74 request_metrics.StartRecordingPerformanceMetrics();
75
76 // Attempt to load model.
77 const std::string model_path = model_dir_ + metadata.model_file;
78 std::unique_ptr<tflite::FlatBufferModel> model =
79 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
80 if (model == nullptr) {
81 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
82 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
83 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
84 return;
85 }
86
Honglin Yuc0cef102020-01-17 15:26:01 +110087 ModelImpl::Create(metadata.required_inputs, metadata.required_outputs,
88 std::move(model), std::move(request),
89 metadata.metrics_model_name);
Honglin Yu0ed72352019-08-27 17:42:01 +100090
Honglin Yu0ed72352019-08-27 17:42:01 +100091 callback.Run(LoadModelResult::OK);
92
93 request_metrics.FinishRecordingPerformanceMetrics();
94 request_metrics.RecordRequestEvent(LoadModelResult::OK);
95}
96
97void MachineLearningServiceImpl::LoadFlatBufferModel(
98 FlatBufferModelSpecPtr spec,
99 ModelRequest request,
100 const LoadFlatBufferModelCallback& callback) {
101 DCHECK(!spec->metrics_model_name.empty());
102
103 RequestMetrics<LoadModelResult> request_metrics(spec->metrics_model_name,
104 kMetricsRequestName);
105 request_metrics.StartRecordingPerformanceMetrics();
106
107 // Take the ownership of the content of |model_string| because |ModelImpl| has
108 // to hold the memory.
109 auto model_string_impl =
110 std::make_unique<std::string>(std::move(spec->model_string));
111
112 std::unique_ptr<tflite::FlatBufferModel> model =
113 tflite::FlatBufferModel::BuildFromBuffer(model_string_impl->c_str(),
114 model_string_impl->length());
115 if (model == nullptr) {
116 LOG(ERROR) << "Failed to load model string of metric name: "
117 << spec->metrics_model_name << "'.";
118 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
119 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
120 return;
121 }
122
Honglin Yuc0cef102020-01-17 15:26:01 +1100123 ModelImpl::Create(
Honglin Yu0ed72352019-08-27 17:42:01 +1000124 std::map<std::string, int>(spec->inputs.begin(), spec->inputs.end()),
125 std::map<std::string, int>(spec->outputs.begin(), spec->outputs.end()),
126 std::move(model), std::move(model_string_impl), std::move(request),
127 spec->metrics_model_name);
128
Honglin Yu0ed72352019-08-27 17:42:01 +1000129 callback.Run(LoadModelResult::OK);
130
131 request_metrics.FinishRecordingPerformanceMetrics();
132 request_metrics.RecordRequestEvent(LoadModelResult::OK);
133}
134
Andrew Moylanff6be512018-07-03 11:05:01 +1000135} // namespace ml