blob: 8c97447e69313dc347d5aac8e47a0cabdfdd6a3c [file] [log] [blame]
Andrew Moylanff6be512018-07-03 11:05:01 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/machine_learning_service_impl.h"
6
Michael Martisa74af932018-08-13 16:52:36 +10007#include <memory>
Andrew Moylanff6be512018-07-03 11:05:01 +10008#include <utility>
9
Michael Martisa74af932018-08-13 16:52:36 +100010#include <base/bind.h>
11#include <base/bind_helpers.h>
12#include <tensorflow/contrib/lite/model.h>
13
14#include "ml/model_impl.h"
15#include "mojom/model.mojom.h"
16
Andrew Moylanff6be512018-07-03 11:05:01 +100017namespace ml {
18
Michael Martisa74af932018-08-13 16:52:36 +100019namespace {
20
21using ::chromeos::machine_learning::mojom::LoadModelResult;
22using ::chromeos::machine_learning::mojom::ModelId;
23using ::chromeos::machine_learning::mojom::ModelRequest;
24using ::chromeos::machine_learning::mojom::ModelSpecPtr;
25
26constexpr char kSystemModelDir[] = "/opt/google/chrome/ml_models/";
27
28// To avoid passing a lambda as a base::Closure.
29void DeleteModelImpl(const ModelImpl* const model_impl) {
30 delete model_impl;
31}
32
33} // namespace
34
Andrew Moylanff6be512018-07-03 11:05:01 +100035MachineLearningServiceImpl::MachineLearningServiceImpl(
Michael Martisa74af932018-08-13 16:52:36 +100036 mojo::ScopedMessagePipeHandle pipe,
37 base::Closure connection_error_handler,
38 const std::string& model_dir)
39 : model_metadata_(GetModelMetadata()),
40 model_dir_(model_dir),
41 binding_(this, std::move(pipe)) {
Andrew Moylanff6be512018-07-03 11:05:01 +100042 binding_.set_connection_error_handler(std::move(connection_error_handler));
43}
44
Michael Martisa74af932018-08-13 16:52:36 +100045MachineLearningServiceImpl::MachineLearningServiceImpl(
46 mojo::ScopedMessagePipeHandle pipe, base::Closure connection_error_handler)
47 : MachineLearningServiceImpl(std::move(pipe),
48 std::move(connection_error_handler),
49 kSystemModelDir) {}
50
51void MachineLearningServiceImpl::LoadModel(ModelSpecPtr spec,
52 ModelRequest request,
53 const LoadModelCallback& callback) {
54 if (spec->id <= ModelId::UNKNOWN || spec->id > ModelId::TEST_MODEL) {
55 callback.Run(LoadModelResult::MODEL_SPEC_ERROR);
56 return;
57 }
58
59 // Shouldn't happen (as we maintain a metadata entry for every valid model),
60 // but can't hurt to be defensive.
61 const auto metadata_lookup = model_metadata_.find(spec->id);
62 if (metadata_lookup == model_metadata_.end()) {
63 LOG(ERROR) << "No metadata present for model ID " << spec->id << ".";
64 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
65 return;
66 }
67 const ModelMetadata& metadata = metadata_lookup->second;
68
69 // Attempt to load model.
70 const std::string model_path = model_dir_ + metadata.model_file;
71 std::unique_ptr<tflite::FlatBufferModel> model =
72 tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
73 if (model == nullptr) {
74 LOG(ERROR) << "Failed to load model file '" << model_path << "'.";
75 callback.Run(LoadModelResult::LOAD_MODEL_ERROR);
76 return;
77 }
78
79 // Use a connection error handler to strongly bind |model_impl| to |request|.
80 ModelImpl* const model_impl =
81 new ModelImpl(metadata.required_inputs, metadata.required_outputs,
82 std::move(model), std::move(request));
83 model_impl->set_connection_error_handler(
84 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
85 callback.Run(LoadModelResult::OK);
Andrew Moylanff6be512018-07-03 11:05:01 +100086}
87
88} // namespace ml