blob: 1dfcdc71ee39accb221d9739cd48851f2ff08f33 [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/model_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Michael Martisa967f632018-08-10 10:39:00 +10007
8#include <utility>
9
10#include <base/bind.h>
11#include <base/bind_helpers.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100012#include <tensorflow/lite/context.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
Michael Martisa967f632018-08-10 10:39:00 +100015
16namespace ml {
17
18using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
19using ::chromeos::machine_learning::mojom::GraphExecutor;
20using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
21using ::chromeos::machine_learning::mojom::ModelRequest;
22
// Base name for UMA metrics related to CreateGraphExecutor calls; combined
// with |metrics_model_name_| to form the full histogram name.
constexpr char kMetricsRequestName[] = "CreateGraphExecutorResult";
alanlxlcb1f8562018-11-01 15:16:11 +110025
// Main constructor.
//
// |required_inputs| / |required_outputs| map logical tensor names to tflite
// tensor indices for the graph executors created later.
// |model_string| presumably owns the raw flatbuffer bytes that |model| was
// parsed from (may be null when the backing memory is owned elsewhere) —
// confirm against callers. Note the member declaration order must keep
// |model_string_| alive for at least as long as |model_| — TODO confirm
// declaration order in the header.
// |request| binds this instance to the mojo Model pipe.
// |metrics_model_name| prefixes all UMA metrics reported by this object.
ModelImpl::ModelImpl(std::map<std::string, int> required_inputs,
                     std::map<std::string, int> required_outputs,
                     std::unique_ptr<tflite::FlatBufferModel> model,
                     std::unique_ptr<std::string> model_string,
                     ModelRequest request,
                     const std::string& metrics_model_name)
    : required_inputs_(std::move(required_inputs)),
      required_outputs_(std::move(required_outputs)),
      model_string_(std::move(model_string)),
      model_(std::move(model)),
      binding_(this, std::move(request)),
      metrics_model_name_(metrics_model_name) {}
Michael Martisa967f632018-08-10 10:39:00 +100038
// Convenience constructor for models whose flatbuffer storage is owned
// elsewhere: delegates to the main constructor with a null |model_string|.
ModelImpl::ModelImpl(std::map<std::string, int> required_inputs,
                     std::map<std::string, int> required_outputs,
                     std::unique_ptr<tflite::FlatBufferModel> model,
                     ModelRequest request,
                     const std::string& metrics_model_name)
    : ModelImpl(std::move(required_inputs),
                std::move(required_outputs),
                std::move(model),
                nullptr,
                std::move(request),
                metrics_model_name) {}
50
// Installs |connection_error_handler| on the mojo binding; it runs when the
// client end of the Model pipe is closed or errors out — presumably so the
// owner can destroy this instance, confirm at the call site.
void ModelImpl::set_connection_error_handler(
    base::Closure connection_error_handler) {
  binding_.set_connection_error_handler(std::move(connection_error_handler));
}
55
56int ModelImpl::num_graph_executors_for_testing() const {
57 return graph_executors_.size();
58}
59
// Builds a tflite interpreter for |model_| and wraps it in a new
// GraphExecutorImpl bound to |request|. Reports the outcome (and, on
// success, performance metrics) to UMA via |request_metrics|. |callback|
// always receives exactly one CreateGraphExecutorResult.
void ModelImpl::CreateGraphExecutor(
    GraphExecutorRequest request, const CreateGraphExecutorCallback& callback) {
  DCHECK(!metrics_model_name_.empty());

  // Records timing/outcome for this call under
  // "<metrics_model_name_>.<kMetricsRequestName>".
  RequestMetrics<CreateGraphExecutorResult> request_metrics(
      metrics_model_name_, kMetricsRequestName);
  request_metrics.StartRecordingPerformanceMetrics();

  if (model_ == nullptr) {
    LOG(ERROR) << "Null model provided.";
    callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    request_metrics.RecordRequestEvent(
        CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    return;
  }

  // Instantiate interpreter.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  const TfLiteStatus resolve_status =
      tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
  if (resolve_status != kTfLiteOk || !interpreter) {
    LOG(ERROR) << "Could not resolve model ops.";
    callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    request_metrics.RecordRequestEvent(
        CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    return;
  }

  // Allocate memory for tensors.
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    callback.Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
    request_metrics.RecordRequestEvent(
        CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
    return;
  }

  // Add graph executor and schedule its deletion on pipe closure.
  // graph_executors_ is a std::list, whose iterators stay valid until the
  // element itself is erased — so capturing begin() (the element just
  // emplaced at the front) in the error handler below is safe even as other
  // executors come and go. base::Unretained(this) is presumably safe because
  // this object owns the list and hence outlives each executor's binding —
  // TODO confirm.
  graph_executors_.emplace_front(required_inputs_, required_outputs_,
                                 std::move(interpreter), std::move(request),
                                 metrics_model_name_);
  graph_executors_.front().set_connection_error_handler(
      base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
                 graph_executors_.begin()));

  callback.Run(CreateGraphExecutorResult::OK);
  // Note: elapsed-time metrics are only finished (reported) on the success
  // path; error paths record the event only.
  request_metrics.FinishRecordingPerformanceMetrics();
  request_metrics.RecordRequestEvent(CreateGraphExecutorResult::OK);
}
109
110void ModelImpl::EraseGraphExecutor(
111 const std::list<GraphExecutorImpl>::const_iterator it) {
112 graph_executors_.erase(it);
113}
114
115} // namespace ml