blob: c95340b3f7f2fa4e256eec7feaa4927b4aa5f8f2 [file] [log] [blame]
Alan Greenb9d0c832020-04-30 08:29:50 +10001// Copyright 2020 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4//
5// A simplified interface to the ML service. Used to implement the ml_cmdline
6// tool.
7#include "ml/simple.h"
8
9#include <string>
10#include <utility>
11#include <vector>
12
13#include <base/bind.h>
14#include <base/run_loop.h>
Andrew Moylanb481af72020-07-09 15:22:00 +100015#include <mojo/public/cpp/bindings/remote.h>
Alan Greenb9d0c832020-04-30 08:29:50 +100016
17#include "ml/machine_learning_service_impl.h"
18#include "ml/mojom/graph_executor.mojom.h"
19#include "ml/mojom/machine_learning_service.mojom.h"
20#include "ml/mojom/model.mojom.h"
21#include "ml/tensor_view.h"
22
23using ::chromeos::machine_learning::mojom::BuiltinModelId;
24using ::chromeos::machine_learning::mojom::BuiltinModelSpec;
25using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
26using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
27using ::chromeos::machine_learning::mojom::ExecuteResult;
Andrew Moylanb481af72020-07-09 15:22:00 +100028using ::chromeos::machine_learning::mojom::GraphExecutor;
Alan Green55e16542020-05-11 14:06:46 +100029using ::chromeos::machine_learning::mojom::GraphExecutorOptions;
Alan Greenb9d0c832020-04-30 08:29:50 +100030using ::chromeos::machine_learning::mojom::LoadModelResult;
Andrew Moylanb481af72020-07-09 15:22:00 +100031using ::chromeos::machine_learning::mojom::MachineLearningService;
32using ::chromeos::machine_learning::mojom::Model;
Alan Greenb9d0c832020-04-30 08:29:50 +100033using ::chromeos::machine_learning::mojom::TensorPtr;
34
35namespace ml {
36namespace simple {
37namespace {
38
39// Creates a 1-D tensor containing a single value
40TensorPtr NewSingleValueTensor(const double value) {
41 auto tensor(chromeos::machine_learning::mojom::Tensor::New());
42 TensorView<double> tensor_view(tensor);
43 tensor_view.Allocate();
44 tensor_view.GetShape() = {1};
45 tensor_view.GetValues() = {value};
46 return tensor;
47}
48
49} // namespace
50
Alan Green3117b522020-05-20 09:50:27 +100051AddResult Add(const double x, const double y, const bool use_nnapi) {
Alan Greenb9d0c832020-04-30 08:29:50 +100052 AddResult result = {"Not completed.", -1.0};
53
54 // Create ML Service
Andrew Moylanb481af72020-07-09 15:22:00 +100055 mojo::Remote<MachineLearningService> ml_service;
Alan Greenb9d0c832020-04-30 08:29:50 +100056 const MachineLearningServiceImpl ml_service_impl(
Andrew Moylanb481af72020-07-09 15:22:00 +100057 ml_service.BindNewPipeAndPassReceiver().PassPipe(),
58 base::Closure());
Alan Greenb9d0c832020-04-30 08:29:50 +100059
60 // Load model.
61 BuiltinModelSpecPtr spec = BuiltinModelSpec::New();
62 spec->id = BuiltinModelId::TEST_MODEL;
Andrew Moylanb481af72020-07-09 15:22:00 +100063 mojo::Remote<Model> model;
Alan Greenb9d0c832020-04-30 08:29:50 +100064 bool model_load_ok = false;
65 ml_service->LoadBuiltinModel(
Andrew Moylanb481af72020-07-09 15:22:00 +100066 std::move(spec), model.BindNewPipeAndPassReceiver(),
Alan Greenb9d0c832020-04-30 08:29:50 +100067 base::Bind(
68 [](bool* const model_load_ok, const LoadModelResult result) {
69 *model_load_ok = result == LoadModelResult::OK;
70 },
71 &model_load_ok));
72 base::RunLoop().RunUntilIdle();
73 if (!model_load_ok) {
74 result.status = "Failed to load model.";
75 return result;
76 }
77
78 // Get graph executor for model.
Andrew Moylanb481af72020-07-09 15:22:00 +100079 mojo::Remote<GraphExecutor> graph_executor;
Alan Greenb9d0c832020-04-30 08:29:50 +100080 bool graph_executor_ok = false;
Alan Green55e16542020-05-11 14:06:46 +100081 auto options = GraphExecutorOptions::New(use_nnapi);
82 model->CreateGraphExecutorWithOptions(
Andrew Moylanb481af72020-07-09 15:22:00 +100083 std::move(options), graph_executor.BindNewPipeAndPassReceiver(),
Alan Green55e16542020-05-11 14:06:46 +100084 base::Bind(
85 [](bool* const graph_executor_ok,
86 const CreateGraphExecutorResult result) {
87 *graph_executor_ok = result == CreateGraphExecutorResult::OK;
88 },
89 &graph_executor_ok));
Alan Greenb9d0c832020-04-30 08:29:50 +100090 base::RunLoop().RunUntilIdle();
91 if (!model_load_ok) {
92 result.status = "Failed to get graph executor";
93 return result;
94 }
95
96 // Construct input to graph executor and perform inference
97 base::flat_map<std::string, TensorPtr> inputs;
98 inputs.emplace("x", NewSingleValueTensor(x));
99 inputs.emplace("y", NewSingleValueTensor(y));
100 std::vector<std::string> outputs({"z"});
101 bool inference_ok = false;
102 graph_executor->Execute(
103 std::move(inputs), std::move(outputs),
104 base::Bind(
105 [](bool* const inference_ok, double* const sum,
106 const ExecuteResult execute_result,
107 base::Optional<std::vector<TensorPtr>> outputs) {
108 // Check that the inference succeeded and gave the expected number
109 // of outputs.
110 *inference_ok = execute_result == ExecuteResult::OK &&
111 outputs.has_value() && outputs->size() == 1;
112 if (!*inference_ok) {
113 return;
114 }
115
116 // Get value from output
117 const TensorView<double> out_tensor((*outputs)[0]);
118 *sum = out_tensor.GetValues()[0];
119 },
120 &inference_ok, &result.sum));
121 base::RunLoop().RunUntilIdle();
122 if (!inference_ok) {
123 result.status = "Inference failed.";
124 return result;
125 }
126
Alan Greenc5bcbcd2020-05-07 11:44:26 +1000127 result.status = "OK";
Alan Greenb9d0c832020-04-30 08:29:50 +1000128 return result;
129}
130
131} // namespace simple
132} // namespace ml