blob: 4a00b28fa5b8fd9bbc9930d3406f820b3f10eb1c [file] [log] [blame]
Andrew Moylanff6be512018-07-03 11:05:01 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100013#include <tensorflow/lite/model.h>
Michael Martisa74af932018-08-13 16:52:36 +100014
15#include "ml/model_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090016#include "ml/mojom/model.mojom.h"
Michael Martisa74af932018-08-13 16:52:36 +100017
Andrew Moylanff6be512018-07-03 11:05:01 +100018namespace ml {
19
Michael Martisa74af932018-08-13 16:52:36 +100020namespace {
21
Honglin Yu0ed72352019-08-27 17:42:01 +100022using ::chromeos::machine_learning::mojom::BuiltinModelId;
23using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
24using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr;
Michael Martisa74af932018-08-13 16:52:36 +100025using ::chromeos::machine_learning::mojom::LoadModelResult;
Michael Martisa74af932018-08-13 16:52:36 +100026using ::chromeos::machine_learning::mojom::ModelRequest;
Michael Martisa74af932018-08-13 16:52:36 +100027
28constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
Honglin Yua81145a2019-09-23 15:20:13 +100029// Base name for UMA metrics related to model loading (either |LoadBuiltinModel|
30// or |LoadFlatBufferModel|) requests
Honglin Yu6adafcd2019-07-22 13:48:11 +100031constexpr char kMetricsRequestName[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100032
33// To avoid passing a lambda as a base::Closure.
34void DeleteModelImpl(const ModelImpl* const model_impl) {
35 delete model_impl;
36}
37
38} // namespace
39
Andrew Moylanff6be512018-07-03 11:05:01 +100040MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100041 mojo::ScopedMessagePipeHandle pipe,
42 base::Closure connection_error_handler,
43 const std::string& model_dir)
Honglin Yua81145a2019-09-23 15:20:13 +100044 : builtin_model_metadata_(GetBuiltinModelMetadata()),
Michael Martisa74af932018-08-13 16:52:36 +100045 model_dir_(model_dir),
46 binding_(this, std::move(pipe)) {
Andrew Moylanff6be512018-07-03 11:05:01 +100047 binding_.set_connection_error_handler(std::move(connection_error_handler));
48}
49
Michael Martisa74af932018-08-13 16:52:36 +100050MachineLearningServiceImpl::MachineLearningServiceImpl(
51 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
52 : MachineLearningServiceImpl(std::move(pipe),
53 std::move(connection_error_handler),
54 kSystemModelDir) {}
55
Honglin Yu0ed72352019-08-27 17:42:01 +100056void MachineLearningServiceImpl::LoadBuiltinModel(
57 BuiltinModelSpecPtr spec,
58 ModelRequest request,
59 const LoadBuiltinModelCallback& callback) {
60 // Unsupported models do not have metadata entries.
61 const auto metadata_lookup = builtin_model_metadata_.find(spec->id);
62 if (metadata_lookup == builtin_model_metadata_.end()) {
Honglin Yua81145a2019-09-23 15:20:13 +100063 LOG(WARNING) << "LoadBuiltinModel requested for unsupported model ID "
64 << spec->id << ".";
Honglin Yu0ed72352019-08-27 17:42:01 +100065 callback.Run(LoadModelResult::MODEL_SPEC_ERROR);
66 RecordModelSpecificationErrorEvent();
67 return;
68 }
69
70 const BuiltinModelMetadata& metadata = metadata_lookup->second;
71
72 DCHECK(!metadata.metrics_model_name.empty());
73
74 RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
75 kMetricsRequestName);
76 request_metrics.StartRecordingPerformanceMetrics();
77
78 // Attempt to load model.
79 const std::string model_path = model_dir_ + metadata.model_file;
80 std::unique_ptr<tflite::FlatBufferModel> model =
81 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
82 if (model == nullptr) {
83 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
84 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
85 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
86 return;
87 }
88
89 // Use a connection error handler to strongly bind |model_impl| to |request|.
90 ModelImpl* const model_impl = new ModelImpl(
91 metadata.required_inputs, metadata.required_outputs, std::move(model),
92 std::move(request), metadata.metrics_model_name);
93
94 model_impl->set_connection_error_handler(
95 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
96 callback.Run(LoadModelResult::OK);
97
98 request_metrics.FinishRecordingPerformanceMetrics();
99 request_metrics.RecordRequestEvent(LoadModelResult::OK);
100}
101
102void MachineLearningServiceImpl::LoadFlatBufferModel(
103 FlatBufferModelSpecPtr spec,
104 ModelRequest request,
105 const LoadFlatBufferModelCallback& callback) {
106 DCHECK(!spec->metrics_model_name.empty());
107
108 RequestMetrics<LoadModelResult> request_metrics(spec->metrics_model_name,
109 kMetricsRequestName);
110 request_metrics.StartRecordingPerformanceMetrics();
111
112 // Take the ownership of the content of |model_string| because |ModelImpl| has
113 // to hold the memory.
114 auto model_string_impl =
115 std::make_unique<std::string>(std::move(spec->model_string));
116
117 std::unique_ptr<tflite::FlatBufferModel> model =
118 tflite::FlatBufferModel::BuildFromBuffer(model_string_impl->c_str(),
119 model_string_impl->length());
120 if (model == nullptr) {
121 LOG(ERROR) << "Failed to load model string of metric name: "
122 << spec->metrics_model_name << "'.";
123 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
124 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
125 return;
126 }
127
128 // Use a connection error handler to strongly bind |model_impl| to |request|.
129 ModelImpl* model_impl = new ModelImpl(
130 std::map<std::string, int>(spec->inputs.begin(), spec->inputs.end()),
131 std::map<std::string, int>(spec->outputs.begin(), spec->outputs.end()),
132 std::move(model), std::move(model_string_impl), std::move(request),
133 spec->metrics_model_name);
134
135 model_impl->set_connection_error_handler(
136 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
137 callback.Run(LoadModelResult::OK);
138
139 request_metrics.FinishRecordingPerformanceMetrics();
140 request_metrics.RecordRequestEvent(LoadModelResult::OK);
141}
142
Andrew Moylanff6be512018-07-03 11:05:01 +1000143} // namespace ml