blob: 4d1be4f27610f57c0b901028e05a085b094a1c06 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#ifndef ML_MODEL_IMPL_H_
6#define ML_MODEL_IMPL_H_
7
8#include <list>
9#include <map>
10#include <memory>
11#include <string>
12
13#include <base/macros.h>
14#include <mojo/public/cpp/bindings/binding.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100015#include <tensorflow/lite/model.h>
Michael Martisa967f632018-08-10 10:39:00 +100016
17#include "ml/graph_executor_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090018#include "ml/mojom/model.mojom.h"
Michael Martisa967f632018-08-10 10:39:00 +100019
20namespace ml {
21
22// Holds a TensorFlow lite graph and produces GraphExecutors that may run the
23// graph.
24//
25// All GraphExecutors created by a ModelImpl reference its model definition (and
26// hence may not outlive the ModelImpl). Multiple such GraphExecutors may be
27// used concurrently from different sequences.
28class ModelImpl : public chromeos::machine_learning::mojom::Model {
29 public:
Andrew Moylan79b34a42020-07-08 11:13:11 +100030 // Creates an instance bound to `request`.
Michael Martisa967f632018-08-10 10:39:00 +100031 //
Andrew Moylan79b34a42020-07-08 11:13:11 +100032 // The `required_inputs` and `required_outputs` arguments specify a mapping
Michael Martisa967f632018-08-10 10:39:00 +100033 // from required input / output tensor names to their indices in the TF lite
34 // graph, and must outlive this object.
Andrew Moylan79b34a42020-07-08 11:13:11 +100035 // `model_string` is optional string data that this class can take ownership
36 // of (presumably the backing data for `model`) and that is guaranteed to be
37 // destroyed *after* `model`. This is required by function
38 // `tflite::FlatBufferModel::BuildFromBuffer`.
Honglin Yuc0cef102020-01-17 15:26:01 +110039 //
40 // The RAM of the returned model is not owned by the caller. The model object
41 // will be deleted when the corresponding mojo connection meets error.
42 static ModelImpl* Create(
43 std::map<std::string, int> required_inputs,
44 std::map<std::string, int> required_outputs,
45 std::unique_ptr<tflite::FlatBufferModel> model,
46 std::unique_ptr<std::string> model_string,
47 chromeos::machine_learning::mojom::ModelRequest request,
48 const std::string& metrics_model_name);
49
Andrew Moylan79b34a42020-07-08 11:13:11 +100050 // Use when constructed from file where no need to pass the `model_string`.
Honglin Yuc0cef102020-01-17 15:26:01 +110051 // The RAM of the returned model is not owned by the caller. The model object
52 // will be deleted when the corresponding mojo connection meets error.
53 static ModelImpl* Create(
54 std::map<std::string, int> required_inputs,
55 std::map<std::string, int> required_outputs,
56 std::unique_ptr<tflite::FlatBufferModel> model,
57 chromeos::machine_learning::mojom::ModelRequest request,
58 const std::string& metrics_model_name);
59
60 int num_graph_executors_for_testing() const;
61
62 private:
Andrew Moylan79b34a42020-07-08 11:13:11 +100063 // Constructor is private, call `Create` to create objects.
Honglin Yu0ed72352019-08-27 17:42:01 +100064 ModelImpl(std::map<std::string, int> required_inputs,
65 std::map<std::string, int> required_outputs,
66 std::unique_ptr<tflite::FlatBufferModel> model,
67 std::unique_ptr<std::string> model_string,
68 chromeos::machine_learning::mojom::ModelRequest request,
69 const std::string& metrics_model_name);
70
Michael Martisa967f632018-08-10 10:39:00 +100071 void set_connection_error_handler(base::Closure connection_error_handler);
72
Michael Martisa967f632018-08-10 10:39:00 +100073 // chromeos::machine_learning::mojom::Model:
74 void CreateGraphExecutor(
75 chromeos::machine_learning::mojom::GraphExecutorRequest request,
Qijiang Fan5d381a02020-04-19 23:42:37 +090076 CreateGraphExecutorCallback callback) override;
Alan Green55e16542020-05-11 14:06:46 +100077 void CreateGraphExecutorWithOptions(
78 chromeos::machine_learning::mojom::GraphExecutorOptionsPtr options,
79 chromeos::machine_learning::mojom::GraphExecutorRequest request,
80 CreateGraphExecutorCallback callback) override;
Michael Martisa967f632018-08-10 10:39:00 +100081
82 // Remove a graph executor from our hosted set.
83 void EraseGraphExecutor(std::list<GraphExecutorImpl>::const_iterator it);
84
Honglin Yu0ed72352019-08-27 17:42:01 +100085 const std::map<std::string, int> required_inputs_;
86 const std::map<std::string, int> required_outputs_;
87
Andrew Moylan79b34a42020-07-08 11:13:11 +100088 // Must be above `model_`.
Honglin Yu0ed72352019-08-27 17:42:01 +100089 const std::unique_ptr<std::string> model_string_;
Michael Martisa967f632018-08-10 10:39:00 +100090
91 const std::unique_ptr<tflite::FlatBufferModel> model_;
92
93 mojo::Binding<chromeos::machine_learning::mojom::Model> binding_;
94
95 // Emulate a strong binding set: hold a set of GraphExecutors, specific
96 // elements of which are erased on connection error.
97 //
98 // That is, when a pipe to a GraphExecutorImpl closes, that object is removed
99 // from this set (by its binding connection error handler). Further, when a
100 // ModelImpl is destoyed, its entire collection of GraphExecutorImpls is also
101 // destroyed.
102 std::list<GraphExecutorImpl> graph_executors_;
103
Honglin Yu6adafcd2019-07-22 13:48:11 +1000104 // Model name as it should appear in UMA histogram names.
105 const std::string metrics_model_name_;
106
Michael Martisa967f632018-08-10 10:39:00 +1000107 DISALLOW_COPY_AND_ASSIGN(ModelImpl);
108};
109
110} // namespace ml
111
112#endif // ML_MODEL_IMPL_H_