// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ml/model_impl.h"

#include <utility>

#include <base/bind.h>
#include <base/bind_helpers.h>
#include <tensorflow/contrib/lite/context.h>
#include <tensorflow/contrib/lite/interpreter.h>
#include <tensorflow/contrib/lite/kernels/register.h>

namespace ml {

using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
using ::chromeos::machine_learning::mojom::GraphExecutor;
using ::chromeos::machine_learning::mojom::GraphExecutorRequest;
using ::chromeos::machine_learning::mojom::ModelRequest;

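// Takes ownership of |model| and binds itself to the Mojo pipe carried by
// |request|. |required_inputs| and |required_outputs| are retained and later
// passed to every GraphExecutorImpl created by CreateGraphExecutor().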
ModelImpl::ModelImpl(const std::map<std::string, int>& required_inputs,
                     const std::map<std::string, int>& required_outputs,
                     std::unique_ptr<tflite::FlatBufferModel> model,
                     ModelRequest request)
    : required_inputs_(required_inputs),
      required_outputs_(required_outputs),
      model_(std::move(model)),
      binding_(this, std::move(request)) {}

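// Forwards |connection_error_handler| to the underlying Mojo binding, letting
// the owner of this ModelImpl observe when the Model pipe is closed.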
void ModelImpl::set_connection_error_handler(
    base::Closure connection_error_handler) {
  binding_.set_connection_error_handler(std::move(connection_error_handler));
}

int ModelImpl::num_graph_executors_for_testing() const {
  return graph_executors_.size();
}

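// A sketch of expected client-side usage, for orientation only (the proxy and
// callback spellings come from the generated mojom bindings and the caller's
// code, not from this file):
//
//   chromeos::machine_learning::mojom::GraphExecutorPtr executor;
//   model->CreateGraphExecutor(
//       mojo::MakeRequest(&executor),
//       base::Bind([](CreateGraphExecutorResult result) { /* ... */ }));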
void ModelImpl::CreateGraphExecutor(
    GraphExecutorRequest request, const CreateGraphExecutorCallback& callback) {
  if (model_ == nullptr) {
    LOG(ERROR) << "Null model provided.";
    callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    return;
  }

  // Instantiate interpreter.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  const TfLiteStatus resolve_status =
      tflite::InterpreterBuilder(*model_, resolver)(&interpreter);
  if (resolve_status != kTfLiteOk || !interpreter) {
    LOG(ERROR) << "Could not resolve model ops.";
    callback.Run(CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
    return;
  }

  // Allocate memory for tensors.
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    callback.Run(CreateGraphExecutorResult::MEMORY_ALLOCATION_ERROR);
    return;
  }

  // Add graph executor and schedule its deletion on pipe closure.
  graph_executors_.emplace_front(required_inputs_, required_outputs_,
                                 std::move(interpreter), std::move(request));
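  // The new executor sits at the front of the list, so begin() refers to it.
  // std::list iterators remain valid until their element is erased, and
  // base::Unretained is safe here because this ModelImpl owns the executors
  // and therefore outlives them.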
  graph_executors_.front().set_connection_error_handler(
      base::Bind(&ModelImpl::EraseGraphExecutor, base::Unretained(this),
                 graph_executors_.begin()));

  callback.Run(CreateGraphExecutorResult::OK);
}

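// Destroys one graph executor, closing its Mojo pipe if it is still open.
// Installed as the connection error handler in CreateGraphExecutor() above.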
void ModelImpl::EraseGraphExecutor(
    const std::list<GraphExecutorImpl>::const_iterator it) {
  graph_executors_.erase(it);
}

}  // namespace ml