// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ml/model_impl.h"
#include "ml/request_metrics.h"

#include <utility>

#include <base/bind.h>
#include <base/bind_helpers.h>
#include <tensorflow/lite/context.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>

namespace ml {

using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
using ::chromeos::machine_learning::mojom::GraphExecutor;
using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
using ::chromeos::machine_learning::mojom::ModelRequest;

// Base name for UMA metrics related to CreateGraphExecutor calls
constexpr char kMetricsRequestName[] = "CreateGraphExecutorResult";

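// Takes ownership of the flatbuffer-backed |model| and binds this
// implementation to the ModelRequest pipe; |metrics_model_name| labels the
// UMA metrics recorded for requests made against this model.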
ModelImpl::ModelImpl(const std::map<std::string, int>& required_inputs,
                     const std::map<std::string, int>& required_outputs,
                     std::unique_ptr<tflite::FlatBufferModel> model,
                     ModelRequest request,
                     const std::string& metrics_model_name)
    : required_inputs_(required_inputs),
      required_outputs_(required_outputs),
      model_(std::move(model)),
      binding_(this, std::move(request)),
      metrics_model_name_(metrics_model_name) {}

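// Forwards |connection_error_handler| to the underlying Mojo binding so it
// runs when the Model pipe is closed.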
void ModelImpl::set_connection_error_handler(
    base::Closure connection_error_handler) {
  binding_.set_connection_error_handler(std::move(connection_error_handler));
}

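// Test-only accessor for the number of live GraphExecutor instances owned by
// this model.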
int ModelImpl::num_graph_executors_for_testing() const {
  return graph_executors_.size();
}

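// Builds a TFLite interpreter for the loaded model, wraps it in a
// GraphExecutorImpl bound to |request|, and reports the outcome through both
// |callback| and UMA request metrics.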
void ModelImpl::CreateGraphExecutor(
    GraphExecutorRequest request, const CreateGraphExecutorCallback& callback) {
  DCHECK(!metrics_model_name_.empty());

  RequestMetrics<CreateGraphExecutorResult> request_metrics(
      metrics_model_name_, kMetricsRequestName);
  request_metrics.StartRecordingPerformanceMetrics();

  if (model_ == nullptr) {
    LOG(ERROR) << "Null model provided.";
    callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    request_metrics.RecordRequestEvent(
        CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    return;
  }

  // Instantiate interpreter.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  const TfLiteStatus resolve_status =
      tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
  if (resolve_status != kTfLiteOk || !interpreter) {
    LOG(ERROR) << "Could not resolve model ops.";
    callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    request_metrics.RecordRequestEvent(
        CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    return;
  }

  // Allocate memory for tensors.
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    callback.Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
    request_metrics.RecordRequestEvent(
        CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
    return;
  }

  // Add graph executor and schedule its deletion on pipe closure.
  graph_executors_.emplace_front(required_inputs_, required_outputs_,
                                 std::move(interpreter), std::move(request),
                                 metrics_model_name_);
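  // base::Unretained is safe here: the executor's binding is owned (via
  // |graph_executors_|) by this ModelImpl, so the error handler cannot run
  // after this object is destroyed.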
  graph_executors_.front().set_connection_error_handler(
      base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
                 graph_executors_.begin()));

  callback.Run(CreateGraphExecutorResult::OK);
  request_metrics.FinishRecordingPerformanceMetrics();
  request_metrics.RecordRequestEvent(CreateGraphExecutorResult::OK);
}

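// Connection error handler for a GraphExecutor pipe: destroys the
// corresponding GraphExecutorImpl by erasing it from |graph_executors_|.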
void ModelImpl::EraseGraphExecutor(
    const std::list<GraphExecutorImpl>::const_iterator it) {
  graph_executors_.erase(it);
}

}  // namespace ml