// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ml/model_impl.h"
#include "ml/request_metrics.h"

#include <utility>

#include <base/bind.h>
#include <base/bind_helpers.h>
#include <tensorflow/lite/context.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>

namespace {

// Callback for self-owned ModelImpl's to delete themselves upon connection
// error.
void DeleteModelImpl(const ml::ModelImpl* const model_impl) {
  delete model_impl;
}

}  // namespace

namespace ml {

using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
using ::chromeos::machine_learning::mojom::GraphExecutor;
using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
using ::chromeos::machine_learning::mojom::ModelRequest;

// Base name for UMA metrics related to CreateGraphExecutor calls.
constexpr char kMetricsRequestName[] = "CreateGraphExecutorResult";

ModelImpl* ModelImpl::Create(
    std::map<std::string, int> required_inputs,
    std::map<std::string, int> required_outputs,
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<std::string> model_string,
    chromeos::machine_learning::mojom::ModelRequest request,
    const std::string& metrics_model_name) {
  auto model_impl = new ModelImpl(
      std::move(required_inputs), std::move(required_outputs), std::move(model),
      std::move(model_string), std::move(request), metrics_model_name);
  // Use a connection error handler to strongly bind |model_impl| to |request|.
  model_impl->set_connection_error_handler(
      base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));

  return model_impl;
}

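// As above, but without a backing |model_string| to keep alive.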
ModelImpl* ModelImpl::Create(
    std::map<std::string, int> required_inputs,
    std::map<std::string, int> required_outputs,
    std::unique_ptr<tflite::FlatBufferModel> model,
    chromeos::machine_learning::mojom::ModelRequest request,
    const std::string& metrics_model_name) {
  auto model_impl = new ModelImpl(
      std::move(required_inputs), std::move(required_outputs), std::move(model),
      nullptr, std::move(request), metrics_model_name);
  // Use a connection error handler to strongly bind |model_impl| to |request|.
  model_impl->set_connection_error_handler(
      base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));

  return model_impl;
}

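// Note: |model_string|, if non-null, is retained because the flatbuffer in
// |model| may reference the string's buffer rather than copy it.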
ModelImpl::ModelImpl(std::map<std::string, int> required_inputs,
                     std::map<std::string, int> required_outputs,
                     std::unique_ptr<tflite::FlatBufferModel> model,
                     std::unique_ptr<std::string> model_string,
                     ModelRequest request,
                     const std::string& metrics_model_name)
    : required_inputs_(std::move(required_inputs)),
      required_outputs_(std::move(required_outputs)),
      model_string_(std::move(model_string)),
      model_(std::move(model)),
      binding_(this, std::move(request)),
      metrics_model_name_(metrics_model_name) {}

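// Forwards |connection_error_handler| to the underlying Mojo binding.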
void ModelImpl::set_connection_error_handler(
    base::Closure connection_error_handler) {
  binding_.set_connection_error_handler(std::move(connection_error_handler));
}

int ModelImpl::num_graph_executors_for_testing() const {
  return graph_executors_.size();
}

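// Builds a TF Lite interpreter from |model_| and binds a new GraphExecutor to
// |request|, reporting the outcome both to |callback| and to UMA.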
void ModelImpl::CreateGraphExecutor(GraphExecutorRequest request,
                                    CreateGraphExecutorCallback callback) {
  DCHECK(!metrics_model_name_.empty());

  RequestMetrics<CreateGraphExecutorResult> request_metrics(
      metrics_model_name_, kMetricsRequestName);
  request_metrics.StartRecordingPerformanceMetrics();

  if (model_ == nullptr) {
    LOG(ERROR) << "Null model provided.";
    std::move(callback).Run(
        CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    request_metrics.RecordRequestEvent(
        CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    return;
  }

  // Instantiate interpreter.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  const TfLiteStatus resolve_status =
      tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
  if (resolve_status != kTfLiteOk || !interpreter) {
    LOG(ERROR) << "Could not resolve model ops.";
    std::move(callback).Run(
        CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    request_metrics.RecordRequestEvent(
        CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    return;
  }

  // Allocate memory for tensors.
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    std::move(callback).Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
    request_metrics.RecordRequestEvent(
        CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
    return;
  }

  // Add graph executor and schedule its deletion on pipe closure.
  graph_executors_.emplace_front(required_inputs_, required_outputs_,
                                 std::move(interpreter), std::move(request),
                                 metrics_model_name_);
  graph_executors_.front().set_connection_error_handler(
      base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
                 graph_executors_.begin()));

  std::move(callback).Run(CreateGraphExecutorResult::OK);
  request_metrics.FinishRecordingPerformanceMetrics();
  request_metrics.RecordRequestEvent(CreateGraphExecutorResult::OK);
}

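// Invoked on a GraphExecutor's connection error to remove (and thereby
// destroy) that executor from |graph_executors_|.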
void ModelImpl::EraseGraphExecutor(
    const std::list<GraphExecutorImpl>::const_iterator it) {
  graph_executors_.erase(it);
}

}  // namespace ml