blob: a5249a0b70b582d41ff5793a349a61a5c9b0c9a7 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#ifndef ML_MODEL_IMPL_H_
6#define ML_MODEL_IMPL_H_
7
8#include <list>
9#include <map>
10#include <memory>
11#include <string>
12
13#include <base/macros.h>
14#include <mojo/public/cpp/bindings/binding.h>
Michael Martis8783c8e2019-06-26 17:30:54 +100015#include <tensorflow/lite/model.h>
Michael Martisa967f632018-08-10 10:39:00 +100016
17#include "ml/graph_executor_impl.h"
Hidehiko Abeaa488c32018-08-31 23:49:41 +090018#include "ml/mojom/model.mojom.h"
Michael Martisa967f632018-08-10 10:39:00 +100019
20namespace ml {
21
22// Holds a TensorFlow lite graph and produces GraphExecutors that may run the
23// graph.
24//
25// All GraphExecutors created by a ModelImpl reference its model definition (and
26// hence may not outlive the ModelImpl). Multiple such GraphExecutors may be
27// used concurrently from different sequences.
28class ModelImpl : public chromeos::machine_learning::mojom::Model {
29 public:
30 // Creates an instance bound to |request|.
31 //
32 // The |required_inputs| and |required_outputs| arguments specify a mapping
33 // from required input / output tensor names to their indices in the TF lite
34 // graph, and must outlive this object.
35 ModelImpl(const std::map<std::string, int>& required_inputs,
36 const std::map<std::string, int>& required_outputs,
37 std::unique_ptr<tflite::FlatBufferModel> model,
Honglin Yu6adafcd2019-07-22 13:48:11 +100038 chromeos::machine_learning::mojom::ModelRequest request,
39 const std::string& metrics_model_name);
Michael Martisa967f632018-08-10 10:39:00 +100040
41 void set_connection_error_handler(base::Closure connection_error_handler);
42
43 int num_graph_executors_for_testing() const;
44
45 private:
46 // chromeos::machine_learning::mojom::Model:
47 void CreateGraphExecutor(
48 chromeos::machine_learning::mojom::GraphExecutorRequest request,
49 const CreateGraphExecutorCallback& callback) override;
50
51 // Remove a graph executor from our hosted set.
52 void EraseGraphExecutor(std::list<GraphExecutorImpl>::const_iterator it);
53
54 const std::map<std::string, int>& required_inputs_;
55 const std::map<std::string, int>& required_outputs_;
56
57 const std::unique_ptr<tflite::FlatBufferModel> model_;
58
59 mojo::Binding<chromeos::machine_learning::mojom::Model> binding_;
60
61 // Emulate a strong binding set: hold a set of GraphExecutors, specific
62 // elements of which are erased on connection error.
63 //
64 // That is, when a pipe to a GraphExecutorImpl closes, that object is removed
65 // from this set (by its binding connection error handler). Further, when a
66 // ModelImpl is destoyed, its entire collection of GraphExecutorImpls is also
67 // destroyed.
68 std::list<GraphExecutorImpl> graph_executors_;
69
Honglin Yu6adafcd2019-07-22 13:48:11 +100070 // Model name as it should appear in UMA histogram names.
71 const std::string metrics_model_name_;
72
Michael Martisa967f632018-08-10 10:39:00 +100073 DISALLOW_COPY_AND_ASSIGN(ModelImpl);
74};
75
76} // namespace ml
77
78#endif // ML_MODEL_IMPL_H_