blob: 401dd9477a7d3aa28793204648fe2542a411a95b [file] [log] [blame]
Michael Martisa967f632018-08-10 10:39:00 +10001// Copyright 2018 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef ML_MODEL_IMPL_H_
6#define ML_MODEL_IMPL_H_
7
8#include <list>
9#include <map>
10#include <memory>
11#include <string>
12
13#include <base/macros.h>
Andrew Moylanb481af72020-07-09 15:22:00 +100014#include <mojo/public/cpp/bindings/pending_receiver.h>
15#include <mojo/public/cpp/bindings/receiver.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100016#include <tensorflow/lite/model.h>
Michael Martisa967f632018-08-10 10:39:00 +100017
18#include "ml/graph_executor_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090019#include "ml/mojom/model.mojom.h"
Michael Martisa967f632018-08-10 10:39:00 +100020
21namespace ml {
22
// Holds 4-byte aligned char[] data suitable for a flatbuffer model.
class AlignedModelData {
 public:
  // Constructs from a std::string. If its .c_str() is not 4-byte aligned, an
  // aligned copy is made.
  explicit AlignedModelData(std::string model_str);

  ~AlignedModelData();

  AlignedModelData(const AlignedModelData&) = delete;
  AlignedModelData& operator=(const AlignedModelData&) = delete;

  // The start of the model data. The result will be 4-byte aligned.
  const char* data() const;
  // The length of the buffer starting at `data()`.
  size_t size() const;

 private:
  // Original std::string containing model data. May be empty.
  std::unique_ptr<std::string> original_model_str_;
  // Aligned copy of the original std::string. May be empty.
  std::unique_ptr<char[]> aligned_copy_;
  // Number of bytes held in `aligned_copy_`. Default-initialized to 0 so the
  // value is never indeterminate, even on a construction path that does not
  // make an aligned copy.
  size_t aligned_copy_size_ = 0;
};
47
Michael Martisa967f632018-08-10 10:39:00 +100048// Holds a TensorFlow lite graph and produces GraphExecutors that may run the
49// graph.
50//
51// All GraphExecutors created by a ModelImpl reference its model definition (and
52// hence may not outlive the ModelImpl). Multiple such GraphExecutors may be
53// used concurrently from different sequences.
54class ModelImpl : public chromeos::machine_learning::mojom::Model {
55 public:
Andrew Moylanb481af72020-07-09 15:22:00 +100056 // Creates an instance bound to `receiver`.
Michael Martisa967f632018-08-10 10:39:00 +100057 //
Andrew Moylan79b34a42020-07-08 11:13:11 +100058 // The `required_inputs` and `required_outputs` arguments specify a mapping
Michael Martisa967f632018-08-10 10:39:00 +100059 // from required input / output tensor names to their indices in the TF lite
60 // graph, and must outlive this object.
Andrew Moylan44c352f2020-11-04 15:19:46 +110061 // `model_data` is backing data for `model` which this class will take
62 // ownership of. It will be destroyed *after* `model`.
Honglin Yuc0cef102020-01-17 15:26:01 +110063 //
64 // The RAM of the returned model is not owned by the caller. The model object
Andrew Moylanb481af72020-07-09 15:22:00 +100065 // will be deleted when the corresponding mojo connection is closed.
Honglin Yuc0cef102020-01-17 15:26:01 +110066 static ModelImpl* Create(
67 std::map<std::string, int> required_inputs,
68 std::map<std::string, int> required_outputs,
69 std::unique_ptr<tflite::FlatBufferModel> model,
Andrew Moylan44c352f2020-11-04 15:19:46 +110070 std::unique_ptr<AlignedModelData> model_data,
Andrew Moylanb481af72020-07-09 15:22:00 +100071 mojo::PendingReceiver<chromeos::machine_learning::mojom::Model> receiver,
Honglin Yuc0cef102020-01-17 15:26:01 +110072 const std::string& metrics_model_name);
73
Andrew Moylan79b34a42020-07-08 11:13:11 +100074 // Use when constructed from file where no need to pass the `model_string`.
Honglin Yuc0cef102020-01-17 15:26:01 +110075 // The RAM of the returned model is not owned by the caller. The model object
Andrew Moylanb481af72020-07-09 15:22:00 +100076 // will be deleted when the corresponding mojo connection is closed.
Honglin Yuc0cef102020-01-17 15:26:01 +110077 static ModelImpl* Create(
78 std::map<std::string, int> required_inputs,
79 std::map<std::string, int> required_outputs,
80 std::unique_ptr<tflite::FlatBufferModel> model,
Andrew Moylanb481af72020-07-09 15:22:00 +100081 mojo::PendingReceiver<chromeos::machine_learning::mojom::Model> receiver,
Honglin Yuc0cef102020-01-17 15:26:01 +110082 const std::string& metrics_model_name);
83
84 int num_graph_executors_for_testing() const;
85
86 private:
Andrew Moylan79b34a42020-07-08 11:13:11 +100087 // Constructor is private, call `Create` to create objects.
Andrew Moylanb481af72020-07-09 15:22:00 +100088 ModelImpl(
89 std::map<std::string, int> required_inputs,
90 std::map<std::string, int> required_outputs,
91 std::unique_ptr<tflite::FlatBufferModel> model,
Andrew Moylan44c352f2020-11-04 15:19:46 +110092 std::unique_ptr<AlignedModelData> model_data,
Andrew Moylanb481af72020-07-09 15:22:00 +100093 mojo::PendingReceiver<chromeos::machine_learning::mojom::Model> receiver,
94 const std::string& metrics_model_name);
Honglin Yu0ed72352019-08-27 17:42:01 +100095
Andrew Moylanb481af72020-07-09 15:22:00 +100096 void set_disconnect_handler(base::Closure disconnect_handler);
Michael Martisa967f632018-08-10 10:39:00 +100097
Michael Martisa967f632018-08-10 10:39:00 +100098 // chromeos::machine_learning::mojom::Model:
99 void CreateGraphExecutor(
Andrew Moylanb481af72020-07-09 15:22:00 +1000100 mojo::PendingReceiver<chromeos::machine_learning::mojom::GraphExecutor>
101 receiver,
Qijiang Fan5d381a02020-04-19 23:42:37 +0900102 CreateGraphExecutorCallback callback) override;
Alan Green55e16542020-05-11 14:06:46 +1000103 void CreateGraphExecutorWithOptions(
104 chromeos::machine_learning::mojom::GraphExecutorOptionsPtr options,
Andrew Moylanb481af72020-07-09 15:22:00 +1000105 mojo::PendingReceiver<chromeos::machine_learning::mojom::GraphExecutor>
106 receiver,
Alan Green55e16542020-05-11 14:06:46 +1000107 CreateGraphExecutorCallback callback) override;
Michael Martisa967f632018-08-10 10:39:00 +1000108
109 // Remove a graph executor from our hosted set.
110 void EraseGraphExecutor(std::list<GraphExecutorImpl>::const_iterator it);
111
Honglin Yu0ed72352019-08-27 17:42:01 +1000112 const std::map<std::string, int> required_inputs_;
113 const std::map<std::string, int> required_outputs_;
114
Andrew Moylan79b34a42020-07-08 11:13:11 +1000115 // Must be above `model_`.
Andrew Moylan44c352f2020-11-04 15:19:46 +1100116 const std::unique_ptr<AlignedModelData> model_data_;
Michael Martisa967f632018-08-10 10:39:00 +1000117
118 const std::unique_ptr<tflite::FlatBufferModel> model_;
119
Andrew Moylanb481af72020-07-09 15:22:00 +1000120 mojo::Receiver<chromeos::machine_learning::mojom::Model> receiver_;
Michael Martisa967f632018-08-10 10:39:00 +1000121
Andrew Moylanb481af72020-07-09 15:22:00 +1000122 // Emulate a strongly bound receiver set: hold a set of GraphExecutors,
123 // specific elements of which are erased on connection closure.
Michael Martisa967f632018-08-10 10:39:00 +1000124 //
125 // That is, when a pipe to a GraphExecutorImpl closes, that object is removed
Andrew Moylanb481af72020-07-09 15:22:00 +1000126 // from this set (by its binding disconnection handler). Further, when a
127 // ModelImpl is destroyed, its entire collection of GraphExecutorImpls is also
Michael Martisa967f632018-08-10 10:39:00 +1000128 // destroyed.
129 std::list<GraphExecutorImpl> graph_executors_;
130
Honglin Yu6adafcd2019-07-22 13:48:11 +1000131 // Model name as it should appear in UMA histogram names.
132 const std::string metrics_model_name_;
133
Michael Martisa967f632018-08-10 10:39:00 +1000134 DISALLOW_COPY_AND_ASSIGN(ModelImpl);
135};
136
137} // namespace ml
138
139#endif // ML_MODEL_IMPL_H_