blob: d52793168b76b6ab5f89c652a98cff9f4d7e22e7 [file] [log] [blame]
Michael Martis26abcd82018-08-08 10:57:25 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef ML_GRAPH_EXECUTOR_IMPL_H_
6#define ML_GRAPH_EXECUTOR_IMPL_H_
7
8#include <map>
9#include <memory>
10#include <string>
11
12#include <base/callback_forward.h>
13#include <base/macros.h>
Michael Martis26abcd82018-08-08 10:57:25 +100014#include <mojo/public/cpp/bindings/binding.h>
Michael Martis26abcd82018-08-08 10:57:25 +100015#include <tensorflow/contrib/lite/model.h>
16
Hidehiko Abeaa488c32018-08-31 23:49:41 +090017#include "ml/mojom/graph_executor.mojom.h"
Michael Martis26abcd82018-08-08 10:57:25 +100018
19namespace ml {
20
21// Allows execution of TensorFlow lite graphs using input / output specified
22// with Mojo types.
23//
24// Holds as little state as possible (with the remainder living in the parent
25// Model object and shared between all sibling GraphExecutors). Hence, a
26// GraphExecutor becomes invalid when its parent Model object is destroyed.
27//
28// A given GraphExecutorImpl may not be used concurrently from different
29// sequences.
30class GraphExecutorImpl
31 : public chromeos::machine_learning::mojom::GraphExecutor {
32 public:
33 // Creates an instance bound to |request|.
34 //
35 // The |required_inputs| and |required_outputs| arguments specify a mapping
36 // from required input / output tensor names to their indices in the TF lite
37 // graph, and must outlive this object.
38 //
39 // As is standard, |interpreter| must outlive the model with which it was
40 // constructed.
41 GraphExecutorImpl(
42 const std::map<std::string, int>& required_inputs,
43 const std::map<std::string, int>& required_outputs,
44 std::unique_ptr<tflite::Interpreter> interpreter,
45 chromeos::machine_learning::mojom::GraphExecutorRequest request);
46
47 void set_connection_error_handler(base::Closure connection_error_handler);
48
49 private:
50 // chromeos::machine_learning::mojom::GraphExecutor:
51 void Execute(mojo::Map<mojo::String,
52 chromeos::machine_learning::mojom::TensorPtr> inputs,
53 mojo::Array<mojo::String> output_names,
54 const ExecuteCallback& callback);
55
56 const std::map<std::string, int>& required_inputs_;
57 const std::map<std::string, int>& required_outputs_;
58
59 const std::unique_ptr<tflite::Interpreter> interpreter_;
60
61 mojo::Binding<chromeos::machine_learning::mojom::GraphExecutor> binding_;
62
63 DISALLOW_COPY_AND_ASSIGN(GraphExecutorImpl);
64};
65
66} // namespace ml
67
68#endif // ML_GRAPH_EXECUTOR_IMPL_H_