blob: 4b83016f6b01e67195132193ac99770a3296d413 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#ifndef ML_MODEL_IMPL_H_
6#define ML_MODEL_IMPL_H_
7
8#include <list>
9#include <map>
10#include <memory>
11#include <string>
12
13#include <base/macros.h>
Andrew Moylanb481af72020-07-09 15:22:00 +100014#include <mojo/public/cpp/bindings/pending_receiver.h>
15#include <mojo/public/cpp/bindings/receiver.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100016#include <tensorflow/lite/model.h>
Michael Martisa967f632018-08-10 10:39:00 +100017
18#include "ml/graph_executor_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090019#include "ml/mojom/model.mojom.h"
Michael Martisa967f632018-08-10 10:39:00 +100020
21namespace ml {
22
23// Holds a TensorFlow lite graph and produces GraphExecutors that may run the
24// graph.
25//
26// All GraphExecutors created by a ModelImpl reference its model definition (and
27// hence may not outlive the ModelImpl). Multiple such GraphExecutors may be
28// used concurrently from different sequences.
29class ModelImpl : public chromeos::machine_learning::mojom::Model {
30 public:
Andrew Moylanb481af72020-07-09 15:22:00 +100031 // Creates an instance bound to `receiver`.
Michael Martisa967f632018-08-10 10:39:00 +100032 //
Andrew Moylan79b34a42020-07-08 11:13:11 +100033 // The `required_inputs` and `required_outputs` arguments specify a mapping
Michael Martisa967f632018-08-10 10:39:00 +100034 // from required input / output tensor names to their indices in the TF lite
35 // graph, and must outlive this object.
Andrew Moylan79b34a42020-07-08 11:13:11 +100036 // `model_string` is optional string data that this class can take ownership
37 // of (presumably the backing data for `model`) and that is guaranteed to be
38 // destroyed *after* `model`. This is required by function
39 // `tflite::FlatBufferModel::BuildFromBuffer`.
Honglin Yuc0cef102020-01-17 15:26:01 +110040 //
41 // The RAM of the returned model is not owned by the caller. The model object
Andrew Moylanb481af72020-07-09 15:22:00 +100042 // will be deleted when the corresponding mojo connection is closed.
Honglin Yuc0cef102020-01-17 15:26:01 +110043 static ModelImpl* Create(
44 std::map<std::string, int> required_inputs,
45 std::map<std::string, int> required_outputs,
46 std::unique_ptr<tflite::FlatBufferModel> model,
47 std::unique_ptr<std::string> model_string,
Andrew Moylanb481af72020-07-09 15:22:00 +100048 mojo::PendingReceiver<chromeos::machine_learning::mojom::Model> receiver,
Honglin Yuc0cef102020-01-17 15:26:01 +110049 const std::string& metrics_model_name);
50
Andrew Moylan79b34a42020-07-08 11:13:11 +100051 // Use when constructed from file where no need to pass the `model_string`.
Honglin Yuc0cef102020-01-17 15:26:01 +110052 // The RAM of the returned model is not owned by the caller. The model object
Andrew Moylanb481af72020-07-09 15:22:00 +100053 // will be deleted when the corresponding mojo connection is closed.
Honglin Yuc0cef102020-01-17 15:26:01 +110054 static ModelImpl* Create(
55 std::map<std::string, int> required_inputs,
56 std::map<std::string, int> required_outputs,
57 std::unique_ptr<tflite::FlatBufferModel> model,
Andrew Moylanb481af72020-07-09 15:22:00 +100058 mojo::PendingReceiver<chromeos::machine_learning::mojom::Model> receiver,
Honglin Yuc0cef102020-01-17 15:26:01 +110059 const std::string& metrics_model_name);
60
61 int num_graph_executors_for_testing() const;
62
63 private:
Andrew Moylan79b34a42020-07-08 11:13:11 +100064 // Constructor is private, call `Create` to create objects.
Andrew Moylanb481af72020-07-09 15:22:00 +100065 ModelImpl(
66 std::map<std::string, int> required_inputs,
67 std::map<std::string, int> required_outputs,
68 std::unique_ptr<tflite::FlatBufferModel> model,
69 std::unique_ptr<std::string> model_string,
70 mojo::PendingReceiver<chromeos::machine_learning::mojom::Model> receiver,
71 const std::string& metrics_model_name);
Honglin Yu0ed72352019-08-27 17:42:01 +100072
Andrew Moylanb481af72020-07-09 15:22:00 +100073 void set_disconnect_handler(base::Closure disconnect_handler);
Michael Martisa967f632018-08-10 10:39:00 +100074
Michael Martisa967f632018-08-10 10:39:00 +100075 // chromeos::machine_learning::mojom::Model:
76 void CreateGraphExecutor(
Andrew Moylanb481af72020-07-09 15:22:00 +100077 mojo::PendingReceiver<chromeos::machine_learning::mojom::GraphExecutor>
78 receiver,
Qijiang Fan5d381a02020-04-19 23:42:37 +090079 CreateGraphExecutorCallback callback) override;
Alan Green55e16542020-05-11 14:06:46 +100080 void CreateGraphExecutorWithOptions(
81 chromeos::machine_learning::mojom::GraphExecutorOptionsPtr options,
Andrew Moylanb481af72020-07-09 15:22:00 +100082 mojo::PendingReceiver<chromeos::machine_learning::mojom::GraphExecutor>
83 receiver,
Alan Green55e16542020-05-11 14:06:46 +100084 CreateGraphExecutorCallback callback) override;
Michael Martisa967f632018-08-10 10:39:00 +100085
86 // Remove a graph executor from our hosted set.
87 void EraseGraphExecutor(std::list<GraphExecutorImpl>::const_iterator it);
88
Honglin Yu0ed72352019-08-27 17:42:01 +100089 const std::map<std::string, int> required_inputs_;
90 const std::map<std::string, int> required_outputs_;
91
Andrew Moylan79b34a42020-07-08 11:13:11 +100092 // Must be above `model_`.
Honglin Yu0ed72352019-08-27 17:42:01 +100093 const std::unique_ptr<std::string> model_string_;
Michael Martisa967f632018-08-10 10:39:00 +100094
95 const std::unique_ptr<tflite::FlatBufferModel> model_;
96
Andrew Moylanb481af72020-07-09 15:22:00 +100097 mojo::Receiver<chromeos::machine_learning::mojom::Model> receiver_;
Michael Martisa967f632018-08-10 10:39:00 +100098
Andrew Moylanb481af72020-07-09 15:22:00 +100099 // Emulate a strongly bound receiver set: hold a set of GraphExecutors,
100 // specific elements of which are erased on connection closure.
Michael Martisa967f632018-08-10 10:39:00 +1000101 //
102 // That is, when a pipe to a GraphExecutorImpl closes, that object is removed
Andrew Moylanb481af72020-07-09 15:22:00 +1000103 // from this set (by its binding disconnection handler). Further, when a
104 // ModelImpl is destroyed, its entire collection of GraphExecutorImpls is also
Michael Martisa967f632018-08-10 10:39:00 +1000105 // destroyed.
106 std::list<GraphExecutorImpl> graph_executors_;
107
Honglin Yu6adafcd2019-07-22 13:48:11 +1000108 // Model name as it should appear in UMA histogram names.
109 const std::string metrics_model_name_;
110
Michael Martisa967f632018-08-10 10:39:00 +1000111 DISALLOW_COPY_AND_ASSIGN(ModelImpl);
112};
113
114} // namespace ml
115
116#endif // ML_MODEL_IMPL_H_