blob: 83126b50b37974a6c643d5561849e73b2cd90590 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "ml/model_impl.h"
alanlxlcb1f8562018-11-01 15:16:11 +11006#include "ml/request_metrics.h"
Michael Martisa967f632018-08-10 10:39:00 +10007
8#include <utility>
9
10#include <base/bind.h>
11#include <base/bind_helpers.h>
12#include <tensorflow/contrib/lite/context.h>
13#include <tensorflow/contrib/lite/interpreter.h>
14#include <tensorflow/contrib/lite/kernels/register.h>
15
16namespace ml {
17
18using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
19using ::chromeos::machine_learning::mojom::GraphExecutor;
20using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
21using ::chromeos::machine_learning::mojom::ModelRequest;
22
alanlxlcb1f8562018-11-01 15:16:11 +110023// Base name for UMA metrics related to CreateGraphExecutor calls
24constexpr char kMetricsNameBase[] = "CreateGraphExecutorResult";
25
Michael Martisa967f632018-08-10 10:39:00 +100026ModelImpl::ModelImpl(const std::map<std::string, int>& required_inputs,
27 const std::map<std::string, int>& required_outputs,
28 std::unique_ptr<tflite::FlatBufferModel> model,
29 ModelRequest request)
30 : required_inputs_(required_inputs),
31 required_outputs_(required_outputs),
32 model_(std::move(model)),
33 binding_(this, std::move(request)) {}
34
35void ModelImpl::set_connection_error_handler(
36 base::Closure connection_error_handler) {
37 binding_.set_connection_error_handler(std::move(connection_error_handler));
38}
39
40int ModelImpl::num_graph_executors_for_testing() const {
41 return graph_executors_.size();
42}
43
44void ModelImpl::CreateGraphExecutor(
45 GraphExecutorRequest request, const CreateGraphExecutorCallback& callback) {
alanlxlcb1f8562018-11-01 15:16:11 +110046 RequestMetrics<CreateGraphExecutorResult> request_metrics(kMetricsNameBase);
47 request_metrics.StartRecordingPerformanceMetrics();
Michael Martisa967f632018-08-10 10:39:00 +100048 if (model_ == nullptr) {
49 LOG(ERROR) << "Null model provided.";
50 callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +110051 request_metrics.RecordRequestEvent(
52 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +100053 return;
54 }
55
56 // Instantiate interpreter.
57 tflite::ops::builtin::BuiltinOpResolver resolver;
58 std::unique_ptr<tflite::Interpreter> interpreter;
59 const TfLiteStatus resolve_status =
60 tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
61 if (resolve_status != kTfLiteOk || !interpreter) {
62 LOG(ERROR) << "Could not resolve model ops.";
63 callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +110064 request_metrics.RecordRequestEvent(
65 CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +100066 return;
67 }
68
69 // Allocate memory for tensors.
70 if (interpreter->AllocateTensors() != kTfLiteOk) {
71 callback.Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
alanlxlcb1f8562018-11-01 15:16:11 +110072 request_metrics.RecordRequestEvent(
73 CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
Michael Martisa967f632018-08-10 10:39:00 +100074 return;
75 }
76
77 // Add graph executor and schedule its deletion on pipe closure.
78 graph_executors_.emplace_front(required_inputs_, required_outputs_,
79 std::move(interpreter), std::move(request));
80 graph_executors_.front().set_connection_error_handler(
81 base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
82 graph_executors_.begin()));
83
84 callback.Run(CreateGraphExecutorResult::OK);
alanlxlcb1f8562018-11-01 15:16:11 +110085 request_metrics.FinishRecordingPerformanceMetrics();
86 request_metrics.RecordRequestEvent(CreateGraphExecutorResult::OK);
Michael Martisa967f632018-08-10 10:39:00 +100087}
88
89void ModelImpl::EraseGraphExecutor(
90 const std::list<GraphExecutorImpl>::const_iterator it) {
91 graph_executors_.erase(it);
92}
93
94} // namespace ml