blob: b58d9cd532a67e159b2b5054d34f54a85ec64ce3 [file] [log] [blame]
Michael Martis26abcd82018-08-08 10:57:25 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef ML_GRAPH_EXECUTOR_IMPL_H_
6#define ML_GRAPH_EXECUTOR_IMPL_H_
7
8#include <map>
9#include <memory>
10#include <string>
Hidehiko Abe31bb9632018-11-23 02:49:56 +090011#include <vector>
Michael Martis26abcd82018-08-08 10:57:25 +100012
13#include <base/callback_forward.h>
hscham3d0632f2019-12-11 15:58:57 +090014#include <base/containers/flat_map.h>
Michael Martis26abcd82018-08-08 10:57:25 +100015#include <base/macros.h>
Andrew Moylanb481af72020-07-09 15:22:00 +100016#include <mojo/public/cpp/bindings/pending_receiver.h>
17#include <mojo/public/cpp/bindings/receiver.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100018#include <tensorflow/lite/model.h>
Michael Martis26abcd82018-08-08 10:57:25 +100019
Hidehiko Abeaa488c32018-08-31 23:49:41 +090020#include "ml/mojom/graph_executor.mojom.h"
Michael Martis26abcd82018-08-08 10:57:25 +100021
22namespace ml {
23
24// Allows execution of TensorFlow lite graphs using input / output specified
25// with Mojo types.
26//
27// Holds as little state as possible (with the remainder living in the parent
28// Model object and shared between all sibling GraphExecutors). Hence, a
29// GraphExecutor becomes invalid when its parent Model object is destroyed.
30//
31// A given GraphExecutorImpl may not be used concurrently from different
32// sequences.
33class GraphExecutorImpl
34 : public chromeos::machine_learning::mojom::GraphExecutor {
35 public:
Andrew Moylanb481af72020-07-09 15:22:00 +100036 // Creates an instance bound to `receiver`.
Michael Martis26abcd82018-08-08 10:57:25 +100037 //
Andrew Moylan79b34a42020-07-08 11:13:11 +100038 // The `required_inputs` and `required_outputs` arguments specify a mapping
Michael Martis26abcd82018-08-08 10:57:25 +100039 // from required input / output tensor names to their indices in the TF lite
40 // graph, and must outlive this object.
41 //
Andrew Moylan79b34a42020-07-08 11:13:11 +100042 // UMA metrics will be logged with the specified `metrics_model_name`.
Honglin Yu6adafcd2019-07-22 13:48:11 +100043 //
Andrew Moylan79b34a42020-07-08 11:13:11 +100044 // As is standard, `interpreter` must outlive the model with which it was
Michael Martis26abcd82018-08-08 10:57:25 +100045 // constructed.
46 GraphExecutorImpl(
47 const std::map<std::string, int>& required_inputs,
48 const std::map<std::string, int>& required_outputs,
49 std::unique_ptr<tflite::Interpreter> interpreter,
Andrew Moylanb481af72020-07-09 15:22:00 +100050 mojo::PendingReceiver<chromeos::machine_learning::mojom::GraphExecutor>
51 receiver,
Honglin Yu6adafcd2019-07-22 13:48:11 +100052 const std::string& metrics_model_name);
Qijiang Fan6bc59e12020-11-11 02:51:06 +090053 GraphExecutorImpl(const GraphExecutorImpl&) = delete;
54 GraphExecutorImpl& operator=(const GraphExecutorImpl&) = delete;
Michael Martis26abcd82018-08-08 10:57:25 +100055
Andrew Moylanb481af72020-07-09 15:22:00 +100056 void set_disconnect_handler(base::Closure disconnect_handler);
Michael Martis26abcd82018-08-08 10:57:25 +100057
58 private:
59 // chromeos::machine_learning::mojom::GraphExecutor:
Hidehiko Abe31bb9632018-11-23 02:49:56 +090060 void Execute(
Qijiang Fan5d381a02020-04-19 23:42:37 +090061 base::flat_map<std::string, chromeos::machine_learning::mojom::TensorPtr>
62 inputs,
Hidehiko Abe31bb9632018-11-23 02:49:56 +090063 const std::vector<std::string>& output_names,
Qijiang Fan5d381a02020-04-19 23:42:37 +090064 ExecuteCallback callback);
Michael Martis26abcd82018-08-08 10:57:25 +100065
66 const std::map<std::string, int>& required_inputs_;
67 const std::map<std::string, int>& required_outputs_;
68
69 const std::unique_ptr<tflite::Interpreter> interpreter_;
70
Andrew Moylanb481af72020-07-09 15:22:00 +100071 mojo::Receiver<chromeos::machine_learning::mojom::GraphExecutor> receiver_;
Michael Martis26abcd82018-08-08 10:57:25 +100072
Honglin Yu6adafcd2019-07-22 13:48:11 +100073 // Model name as it should appear in UMA histogram names.
74 const std::string metrics_model_name_;
Michael Martis26abcd82018-08-08 10:57:25 +100075};
76
77} // namespace ml
78
79#endif // ML_GRAPH_EXECUTOR_IMPL_H_