blob: 6130f8fe6a6c5c48df73efff5e7e378ec9df5a81 [file] [log] [blame]
Andrew Moylanff6be512018-07-03 11:05:01 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
13#include <tensorflow/contrib/lite/model.h>
14
15#include "ml/model_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090016#include "ml/mojom/model.mojom.h"
Michael Martisa74af932018-08-13 16:52:36 +100017
Andrew Moylanff6be512018-07-03 11:05:01 +100018namespace ml {
19
Michael Martisa74af932018-08-13 16:52:36 +100020namespace {
21
22using ::chromeos::machine_learning::mojom::LoadModelResult;
23using ::chromeos::machine_learning::mojom::ModelId;
24using ::chromeos::machine_learning::mojom::ModelRequest;
25using ::chromeos::machine_learning::mojom::ModelSpecPtr;
26
27constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
alanlxlcb1f8562018-11-01 15:16:11 +110028// Base name for UMA metrics related to LoadModel requests
Honglin Yu6adafcd2019-07-22 13:48:11 +100029constexpr char kMetricsRequestName[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100030
31// To avoid passing a lambda as a base::Closure.
32void DeleteModelImpl(const ModelImpl* const model_impl) {
33 delete model_impl;
34}
35
36} // namespace
37
Andrew Moylanff6be512018-07-03 11:05:01 +100038MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100039 mojo::ScopedMessagePipeHandle pipe,
40 base::Closure connection_error_handler,
41 const std::string& model_dir)
42 : model_metadata_(GetModelMetadata()),
43 model_dir_(model_dir),
44 binding_(this, std::move(pipe)) {
Andrew Moylanff6be512018-07-03 11:05:01 +100045 binding_.set_connection_error_handler(std::move(connection_error_handler));
46}
47
Michael Martisa74af932018-08-13 16:52:36 +100048MachineLearningServiceImpl::MachineLearningServiceImpl(
49 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
50 : MachineLearningServiceImpl(std::move(pipe),
51 std::move(connection_error_handler),
52 kSystemModelDir) {}
53
54void MachineLearningServiceImpl::LoadModel(ModelSpecPtr spec,
55 ModelRequest request,
56 const LoadModelCallback& callback) {
Andrew Moylan195a6f52019-05-16 20:57:32 +100057 // Unsupported models do not have metadata entries.
Michael Martisa74af932018-08-13 16:52:36 +100058 const auto metadata_lookup = model_metadata_.find(spec->id);
59 if (metadata_lookup == model_metadata_.end()) {
Andrew Moylan195a6f52019-05-16 20:57:32 +100060 LOG(WARNING) << "LoadModel requested for unsupported model ID " << spec->id
61 << ".";
62 callback.Run(LoadModelResult::MODEL_SPEC_ERROR);
Honglin Yu6adafcd2019-07-22 13:48:11 +100063 RecordModelSpecificationErrorEvent();
Michael Martisa74af932018-08-13 16:52:36 +100064 return;
65 }
Honglin Yu6adafcd2019-07-22 13:48:11 +100066
Michael Martisa74af932018-08-13 16:52:36 +100067 const ModelMetadata& metadata = metadata_lookup->second;
68
Honglin Yu6adafcd2019-07-22 13:48:11 +100069 DCHECK(!metadata.metrics_model_name.empty());
70
71 RequestMetrics<LoadModelResult> request_metrics(metadata.metrics_model_name,
72 kMetricsRequestName);
73 request_metrics.StartRecordingPerformanceMetrics();
74
Michael Martisa74af932018-08-13 16:52:36 +100075 // Attempt to load model.
76 const std::string model_path = model_dir_ + metadata.model_file;
77 std::unique_ptr<tflite::FlatBufferModel> model =
78 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
79 if (model == nullptr) {
80 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
81 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +110082 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
Michael Martisa74af932018-08-13 16:52:36 +100083 return;
84 }
85
86 // Use a connection error handler to strongly bind |model_impl| to |request|.
Honglin Yu6adafcd2019-07-22 13:48:11 +100087 ModelImpl* const model_impl = new ModelImpl(
88 metadata.required_inputs, metadata.required_outputs, std::move(model),
89 std::move(request), metadata.metrics_model_name);
Michael Martisa74af932018-08-13 16:52:36 +100090 model_impl->set_connection_error_handler(
91 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
92 callback.Run(LoadModelResult::OK);
Honglin Yu6adafcd2019-07-22 13:48:11 +100093
alanlxlcb1f8562018-11-01 15:16:11 +110094 request_metrics.FinishRecordingPerformanceMetrics();
95 request_metrics.RecordRequestEvent(LoadModelResult::OK);
Andrew Moylanff6be512018-07-03 11:05:01 +100096}
97
98} // namespace ml