blob: 4f57f58be1061aa2a02de7e860288b8686dc7ad5 [file] [log] [blame]
Andrew Moylanff6be512018-07-03 11:05:01 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
13#include <tensorflow/contrib/lite/model.h>
14
15#include "ml/model_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090016#include "ml/mojom/model.mojom.h"
Michael Martisa74af932018-08-13 16:52:36 +100017
Andrew Moylanff6be512018-07-03 11:05:01 +100018namespace ml {
19
Michael Martisa74af932018-08-13 16:52:36 +100020namespace {
21
22using ::chromeos::machine_learning::mojom::LoadModelResult;
23using ::chromeos::machine_learning::mojom::ModelId;
24using ::chromeos::machine_learning::mojom::ModelRequest;
25using ::chromeos::machine_learning::mojom::ModelSpecPtr;
26
27constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
alanlxlcb1f8562018-11-01 15:16:11 +110028// Base name for UMA metrics related to LoadModel requests
29constexpr char kMetricsNameBase[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100030
31// To avoid passing a lambda as a base::Closure.
32void DeleteModelImpl(const ModelImpl* const model_impl) {
33 delete model_impl;
34}
35
36} // namespace
37
Andrew Moylanff6be512018-07-03 11:05:01 +100038MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100039 mojo::ScopedMessagePipeHandle pipe,
40 base::Closure connection_error_handler,
41 const std::string& model_dir)
42 : model_metadata_(GetModelMetadata()),
43 model_dir_(model_dir),
44 binding_(this, std::move(pipe)) {
Andrew Moylanff6be512018-07-03 11:05:01 +100045 binding_.set_connection_error_handler(std::move(connection_error_handler));
46}
47
Michael Martisa74af932018-08-13 16:52:36 +100048MachineLearningServiceImpl::MachineLearningServiceImpl(
49 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
50 : MachineLearningServiceImpl(std::move(pipe),
51 std::move(connection_error_handler),
52 kSystemModelDir) {}
53
54void MachineLearningServiceImpl::LoadModel(ModelSpecPtr spec,
55 ModelRequest request,
56 const LoadModelCallback& callback) {
alanlxlcb1f8562018-11-01 15:16:11 +110057 RequestMetrics<LoadModelResult> request_metrics(kMetricsNameBase);
58 request_metrics.StartRecordingPerformanceMetrics();
Michael Martisa74af932018-08-13 16:52:36 +100059
Andrew Moylan195a6f52019-05-16 20:57:32 +100060 // Unsupported models do not have metadata entries.
Michael Martisa74af932018-08-13 16:52:36 +100061 const auto metadata_lookup = model_metadata_.find(spec->id);
62 if (metadata_lookup == model_metadata_.end()) {
Andrew Moylan195a6f52019-05-16 20:57:32 +100063 LOG(WARNING) << "LoadModel requested for unsupported model ID " << spec->id
64 << ".";
65 callback.Run(LoadModelResult::MODEL_SPEC_ERROR);
66 request_metrics.RecordRequestEvent(LoadModelResult::MODEL_SPEC_ERROR);
Michael Martisa74af932018-08-13 16:52:36 +100067 return;
68 }
69 const ModelMetadata& metadata = metadata_lookup->second;
70
71 // Attempt to load model.
72 const std::string model_path = model_dir_ + metadata.model_file;
73 std::unique_ptr<tflite::FlatBufferModel> model =
74 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
75 if (model == nullptr) {
76 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
77 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +110078 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
Michael Martisa74af932018-08-13 16:52:36 +100079 return;
80 }
81
82 // Use a connection error handler to strongly bind |model_impl| to |request|.
83 ModelImpl* const model_impl =
84 new ModelImpl(metadata.required_inputs, metadata.required_outputs,
85 std::move(model), std::move(request));
86 model_impl->set_connection_error_handler(
87 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
88 callback.Run(LoadModelResult::OK);
alanlxlcb1f8562018-11-01 15:16:11 +110089 request_metrics.FinishRecordingPerformanceMetrics();
90 request_metrics.RecordRequestEvent(LoadModelResult::OK);
Andrew Moylanff6be512018-07-03 11:05:01 +100091}
92
93} // namespace ml