blob: d3da04132f23dfec17145c5119dc6584bd4609d3 [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "ml/model_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Michael Martisa967f632018-08-10 10:39:00 +10007
8#include <utility>
9
10#include <base/bind.h>
11#include <base/bind_helpers.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100012#include <tensorflow/lite/context.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
Michael Martisa967f632018-08-10 10:39:00 +100015
Honglin Yuc0cef102020-01-17 15:26:01 +110016namespace {
17
18// Callback for self-owned ModelImpl's to delete themselves upon connection
19// error.
20void DeleteModelImpl(const ml::ModelImpl* const model_impl) {
21 delete model_impl;
22}
23
24} // namespace
25
Michael Martisa967f632018-08-10 10:39:00 +100026namespace ml {
27
28using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
29using ::chromeos::machine_learning::mojom::GraphExecutor;
30using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
31using ::chromeos::machine_learning::mojom::ModelRequest;
32
// Request name under which UMA metrics for CreateGraphExecutor calls are
// recorded (combined with the per-model metrics name by RequestMetrics).
constexpr char kMetricsRequestName[] = "CreateGraphExecutorResult";
alanlxlcb1f8562018-11-01 15:16:11 +110035
Honglin Yuc0cef102020-01-17 15:26:01 +110036ModelImpl* ModelImpl::Create(
37 std::map<std::string, int> required_inputs,
38 std::map<std::string, int> required_outputs,
39 std::unique_ptr<tflite::FlatBufferModel> model,
40 std::unique_ptr<std::string> model_string,
41 chromeos::machine_learning::mojom::ModelRequest request,
42 const std::string& metrics_model_name) {
43 auto model_impl = new ModelImpl(
44 std::move(required_inputs), std::move(required_outputs), std::move(model),
45 std::move(model_string), std::move(request), metrics_model_name);
46 // Use a connection error handler to strongly bind |model_impl| to |request|.
47 model_impl->set_connection_error_handler(
48 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
49
50 return model_impl;
51}
52
53ModelImpl* ModelImpl::Create(
54 std::map<std::string, int> required_inputs,
55 std::map<std::string, int> required_outputs,
56 std::unique_ptr<tflite::FlatBufferModel> model,
57 chromeos::machine_learning::mojom::ModelRequest request,
58 const std::string& metrics_model_name) {
59 auto model_impl = new ModelImpl(
60 std::move(required_inputs), std::move(required_outputs), std::move(model),
61 nullptr, std::move(request), metrics_model_name);
62 // Use a connection error handler to strongly bind |model_impl| to |request|.
63 model_impl->set_connection_error_handler(
64 base::Bind(&DeleteModelImpl, base::Unretained(model_impl)));
65
66 return model_impl;
67}
68
Honglin Yu0ed72352019-08-27 17:42:01 +100069ModelImpl::ModelImpl(std::map<std::string, int> required_inputs,
70 std::map<std::string, int> required_outputs,
Michael Martisa967f632018-08-10 10:39:00 +100071 std::unique_ptr<tflite::FlatBufferModel> model,
Honglin Yu0ed72352019-08-27 17:42:01 +100072 std::unique_ptr<std::string> model_string,
Honglin Yu6adafcd2019-07-22 13:48:11 +100073 ModelRequest request,
74 const std::string& metrics_model_name)
Honglin Yu0ed72352019-08-27 17:42:01 +100075 : required_inputs_(std::move(required_inputs)),
76 required_outputs_(std::move(required_outputs)),
77 model_string_(std::move(model_string)),
Michael Martisa967f632018-08-10 10:39:00 +100078 model_(std::move(model)),
Honglin Yu6adafcd2019-07-22 13:48:11 +100079 binding_(this, std::move(request)),
80 metrics_model_name_(metrics_model_name) {}
Michael Martisa967f632018-08-10 10:39:00 +100081
82void ModelImpl::set_connection_error_handler(
83 base::Closure connection_error_handler) {
84 binding_.set_connection_error_handler(std::move(connection_error_handler));
85}
86
87int ModelImpl::num_graph_executors_for_testing() const {
88 return graph_executors_.size();
89}
90
91void ModelImpl::CreateGraphExecutor(
92 GraphExecutorRequest request, const CreateGraphExecutorCallback& callback) {
Honglin Yu6adafcd2019-07-22 13:48:11 +100093 DCHECK(!metrics_model_name_.empty());
94
95 RequestMetrics<CreateGraphExecutorResult> request_metrics(
96 metrics_model_name_, kMetricsRequestName);
alanlxlcb1f8562018-11-01 15:16:11 +110097 request_metrics.StartRecordingPerformanceMetrics();
Honglin Yu6adafcd2019-07-22 13:48:11 +100098
Michael Martisa967f632018-08-10 10:39:00 +100099 if (model_ == nullptr) {
100 LOG(ERROR) << "Null model provided.";
101 callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100102 request_metrics.RecordRequestEvent(
103 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000104 return;
105 }
106
107 // Instantiate interpreter.
108 tflite::ops::builtin::BuiltinOpResolver resolver;
109 std::unique_ptr<tflite::Interpreter> interpreter;
110 const TfLiteStatus resolve_status =
111 tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
112 if (resolve_status != kTfLiteOk || !interpreter) {
113 LOG(ERROR) << "Could not resolve model ops.";
114 callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100115 request_metrics.RecordRequestEvent(
116 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000117 return;
118 }
119
120 // Allocate memory for tensors.
121 if (interpreter->AllocateTensors() != kTfLiteOk) {
122 callback.Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +1100123 request_metrics.RecordRequestEvent(
124 CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +1000125 return;
126 }
127
128 // Add graph executor and schedule its deletion on pipe closure.
129 graph_executors_.emplace_front(required_inputs_, required_outputs_,
Honglin Yu6adafcd2019-07-22 13:48:11 +1000130 std::move(interpreter), std::move(request),
131 metrics_model_name_);
Michael Martisa967f632018-08-10 10:39:00 +1000132 graph_executors_.front().set_connection_error_handler(
133 base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
134 graph_executors_.begin()));
135
136 callback.Run(CreateGraphExecutorResult::OK);
alanlxlcb1f8562018-11-01 15:16:11 +1100137 request_metrics.FinishRecordingPerformanceMetrics();
138 request_metrics.RecordRequestEvent(CreateGraphExecutorResult::OK);
Michael Martisa967f632018-08-10 10:39:00 +1000139}
140
141void ModelImpl::EraseGraphExecutor(
142 const std::list<GraphExecutorImpl>::const_iterator it) {
143 graph_executors_.erase(it);
144}
145
146} // namespace ml