blob: 8247db1e762fcfe7f31dd1ffaacb5131c50f750a [file] [log] [blame]
Andrew Moylanff6be512018-07-03 11:05:01 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Andrew Moylanff6be512018-07-03 11:05:01 +10007
Michael Martisa74af932018-08-13 16:52:36 +10008#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10009#include <utility>
10
Michael Martisa74af932018-08-13 16:52:36 +100011#include <base/bind.h>
12#include <base/bind_helpers.h>
13#include <tensorflow/contrib/lite/model.h>
14
15#include "ml/model_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090016#include "ml/mojom/model.mojom.h"
Michael Martisa74af932018-08-13 16:52:36 +100017
Andrew Moylanff6be512018-07-03 11:05:01 +100018namespace ml {
19
Michael Martisa74af932018-08-13 16:52:36 +100020namespace {
21
22using ::chromeos::machine_learning::mojom::LoadModelResult;
23using ::chromeos::machine_learning::mojom::ModelId;
24using ::chromeos::machine_learning::mojom::ModelRequest;
25using ::chromeos::machine_learning::mojom::ModelSpecPtr;
26
27constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
alanlxlcb1f8562018-11-01 15:16:11 +110028// Base name for UMA metrics related to LoadModel requests
29constexpr char kMetricsNameBase[] = "LoadModelResult";
Michael Martisa74af932018-08-13 16:52:36 +100030
31// To avoid passing a lambda as a base::Closure.
32void DeleteModelImpl(const ModelImpl* const model_impl) {
33 delete model_impl;
34}
35
36} // namespace
37
Andrew Moylanff6be512018-07-03 11:05:01 +100038MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100039 mojo::ScopedMessagePipeHandle pipe,
40 base::Closure connection_error_handler,
41 const std::string& model_dir)
42 : model_metadata_(GetModelMetadata()),
43 model_dir_(model_dir),
44 binding_(this, std::move(pipe)) {
Andrew Moylanff6be512018-07-03 11:05:01 +100045 binding_.set_connection_error_handler(std::move(connection_error_handler));
46}
47
Michael Martisa74af932018-08-13 16:52:36 +100048MachineLearningServiceImpl::MachineLearningServiceImpl(
49 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
50 : MachineLearningServiceImpl(std::move(pipe),
51 std::move(connection_error_handler),
52 kSystemModelDir) {}
53
54void MachineLearningServiceImpl::LoadModel(ModelSpecPtr spec,
55 ModelRequest request,
56 const LoadModelCallback& callback) {
alanlxlcb1f8562018-11-01 15:16:11 +110057 RequestMetrics<LoadModelResult> request_metrics(kMetricsNameBase);
58 request_metrics.StartRecordingPerformanceMetrics();
Prashant Malani6f2de072018-11-19 13:45:38 -080059 if (spec->id <= ModelId::UNKNOWN || spec->id > ModelId::kMax) {
Michael Martisa74af932018-08-13 16:52:36 +100060 callback.Run(LoadModelResult::MODEL_SPEC_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +110061 request_metrics.RecordRequestEvent(LoadModelResult::MODEL_SPEC_ERROR);
Michael Martisa74af932018-08-13 16:52:36 +100062 return;
63 }
64
65 // Shouldn't happen (as we maintain a metadata entry for every valid model),
66 // but can't hurt to be defensive.
67 const auto metadata_lookup = model_metadata_.find(spec->id);
68 if (metadata_lookup == model_metadata_.end()) {
69 LOG(ERROR) << "No metadata present for model ID " << spec->id << ".";
70 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +110071 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
Michael Martisa74af932018-08-13 16:52:36 +100072 return;
73 }
74 const ModelMetadata& metadata = metadata_lookup->second;
75
76 // Attempt to load model.
77 const std::string model_path = model_dir_ + metadata.model_file;
78 std::unique_ptr<tflite::FlatBufferModel> model =
79 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
80 if (model == nullptr) {
81 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
82 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +110083 request_metrics.RecordRequestEvent(LoadModelResult::LOAD_MODEL_ERROR);
Michael Martisa74af932018-08-13 16:52:36 +100084 return;
85 }
86
87 // Use a connection error handler to strongly bind |model_impl| to |request|.
88 ModelImpl* const model_impl =
89 new ModelImpl(metadata.required_inputs, metadata.required_outputs,
90 std::move(model), std::move(request));
91 model_impl->set_connection_error_handler(
92 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
93 callback.Run(LoadModelResult::OK);
alanlxlcb1f8562018-11-01 15:16:11 +110094 request_metrics.FinishRecordingPerformanceMetrics();
95 request_metrics.RecordRequestEvent(LoadModelResult::OK);
Andrew Moylanff6be512018-07-03 11:05:01 +100096}
97
98} // namespace ml