blob: 5b5ce0f90adc5a2426c0225a53c91d8ae055b65e [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#ifndef ML_GRAPH_EXECUTOR_IMPL_H_
6#define ML_GRAPH_EXECUTOR_IMPL_H_
7
8#include <map>
9#include <memory>
10#include <string>
Hidehiko Abe31bb9632018-11-23 02:49:56 +090011#include <unordered_map>
12#include <vector>
Michael Martis26abcd82018-08-08 10:57:25 +100013
14#include <base/callback_forward.h>
15#include <base/macros.h>
Michael Martis26abcd82018-08-08 10:57:25 +100016#include <mojo/public/cpp/bindings/binding.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100017#include <tensorflow/lite/model.h>
Michael Martis26abcd82018-08-08 10:57:25 +100018
Hidehiko Abeaa488c32018-08-31 23:49:41 +090019#include "ml/mojom/graph_executor.mojom.h"
Michael Martis26abcd82018-08-08 10:57:25 +100020
21namespace ml {
22
23// Allows execution of TensorFlow lite graphs using input / output specified
24// with Mojo types.
25//
26// Holds as little state as possible (with the remainder living in the parent
27// Model object and shared between all sibling GraphExecutors). Hence, a
28// GraphExecutor becomes invalid when its parent Model object is destroyed.
29//
30// A given GraphExecutorImpl may not be used concurrently from different
31// sequences.
32class GraphExecutorImpl
33 : public chromeos::machine_learning::mojom::GraphExecutor {
34 public:
35 // Creates an instance bound to |request|.
36 //
37 // The |required_inputs| and |required_outputs| arguments specify a mapping
38 // from required input / output tensor names to their indices in the TF lite
39 // graph, and must outlive this object.
40 //
Honglin Yu6adafcd2019-07-22 13:48:11 +100041 // UMA metrics will be logged with the specified |metrics_model_name|.
42 //
Michael Martis26abcd82018-08-08 10:57:25 +100043 // As is standard, |interpreter| must outlive the model with which it was
44 // constructed.
45 GraphExecutorImpl(
46 const std::map<std::string, int>& required_inputs,
47 const std::map<std::string, int>& required_outputs,
48 std::unique_ptr<tflite::Interpreter> interpreter,
Honglin Yu6adafcd2019-07-22 13:48:11 +100049 chromeos::machine_learning::mojom::GraphExecutorRequest request,
50 const std::string& metrics_model_name);
Michael Martis26abcd82018-08-08 10:57:25 +100051
52 void set_connection_error_handler(base::Closure connection_error_handler);
53
54 private:
55 // chromeos::machine_learning::mojom::GraphExecutor:
Hidehiko Abe31bb9632018-11-23 02:49:56 +090056 void Execute(
57 std::unordered_map<std::string,
Michael Martis26abcd82018-08-08 10:57:25 +100058 chromeos::machine_learning::mojom::TensorPtr> inputs,
Hidehiko Abe31bb9632018-11-23 02:49:56 +090059 const std::vector<std::string>& output_names,
60 const ExecuteCallback& callback);
Michael Martis26abcd82018-08-08 10:57:25 +100061
62 const std::map<std::string, int>& required_inputs_;
63 const std::map<std::string, int>& required_outputs_;
64
65 const std::unique_ptr<tflite::Interpreter> interpreter_;
66
67 mojo::Binding<chromeos::machine_learning::mojom::GraphExecutor> binding_;
68
Honglin Yu6adafcd2019-07-22 13:48:11 +100069 // Model name as it should appear in UMA histogram names.
70 const std::string metrics_model_name_;
71
Michael Martis26abcd82018-08-08 10:57:25 +100072 DISALLOW_COPY_AND_ASSIGN(GraphExecutorImpl);
73};
74
75} // namespace ml
76
77#endif // ML_GRAPH_EXECUTOR_IMPL_H_