blob: 709af65dcd4088c433fe85d3f0d7d87a6bc7b1cb [file] [log] [blame]
Alan Greenb9d0c832020-04-30 08:29:50 +10001// Copyright 2020 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4//
5// A simplified interface to the ML service. Used to implement the ml_cmdline
6// tool.
7#include "ml/simple.h"
8
9#include <string>
10#include <utility>
11#include <vector>
12
13#include <base/bind.h>
14#include <base/run_loop.h>
15#include <mojo/public/cpp/bindings/interface_request.h>
16
17#include "ml/machine_learning_service_impl.h"
18#include "ml/mojom/graph_executor.mojom.h"
19#include "ml/mojom/machine_learning_service.mojom.h"
20#include "ml/mojom/model.mojom.h"
21#include "ml/tensor_view.h"
22
23using ::chromeos::machine_learning::mojom::BuiltinModelId;
24using ::chromeos::machine_learning::mojom::BuiltinModelSpec;
25using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
26using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
27using ::chromeos::machine_learning::mojom::ExecuteResult;
28using ::chromeos::machine_learning::mojom::GraphExecutorPtr;
29using ::chromeos::machine_learning::mojom::LoadModelResult;
30using ::chromeos::machine_learning::mojom::MachineLearningServicePtr;
31using ::chromeos::machine_learning::mojom::ModelPtr;
32using ::chromeos::machine_learning::mojom::TensorPtr;
33
34namespace ml {
35namespace simple {
36namespace {
37
38// Creates a 1-D tensor containing a single value
39TensorPtr NewSingleValueTensor(const double value) {
40 auto tensor(chromeos::machine_learning::mojom::Tensor::New());
41 TensorView<double> tensor_view(tensor);
42 tensor_view.Allocate();
43 tensor_view.GetShape() = {1};
44 tensor_view.GetValues() = {value};
45 return tensor;
46}
47
48} // namespace
49
50AddResult Add(const double x, const double y) {
51 AddResult result = {"Not completed.", -1.0};
52
53 // Create ML Service
54 MachineLearningServicePtr ml_service;
55 const MachineLearningServiceImpl ml_service_impl(
56 mojo::MakeRequest(&ml_service).PassMessagePipe(), base::Closure());
57
58 // Load model.
59 BuiltinModelSpecPtr spec = BuiltinModelSpec::New();
60 spec->id = BuiltinModelId::TEST_MODEL;
61 ModelPtr model;
62 bool model_load_ok = false;
63 ml_service->LoadBuiltinModel(
64 std::move(spec), mojo::MakeRequest(&model),
65 base::Bind(
66 [](bool* const model_load_ok, const LoadModelResult result) {
67 *model_load_ok = result == LoadModelResult::OK;
68 },
69 &model_load_ok));
70 base::RunLoop().RunUntilIdle();
71 if (!model_load_ok) {
72 result.status = "Failed to load model.";
73 return result;
74 }
75
76 // Get graph executor for model.
77 GraphExecutorPtr graph_executor;
78 bool graph_executor_ok = false;
79 model->CreateGraphExecutor(mojo::MakeRequest(&graph_executor),
80 base::Bind(
81 [](bool* const graph_executor_ok,
82 const CreateGraphExecutorResult result) {
83 *graph_executor_ok =
84 result == CreateGraphExecutorResult::OK;
85 },
86 &graph_executor_ok));
87 base::RunLoop().RunUntilIdle();
88 if (!model_load_ok) {
89 result.status = "Failed to get graph executor";
90 return result;
91 }
92
93 // Construct input to graph executor and perform inference
94 base::flat_map<std::string, TensorPtr> inputs;
95 inputs.emplace("x", NewSingleValueTensor(x));
96 inputs.emplace("y", NewSingleValueTensor(y));
97 std::vector<std::string> outputs({"z"});
98 bool inference_ok = false;
99 graph_executor->Execute(
100 std::move(inputs), std::move(outputs),
101 base::Bind(
102 [](bool* const inference_ok, double* const sum,
103 const ExecuteResult execute_result,
104 base::Optional<std::vector<TensorPtr>> outputs) {
105 // Check that the inference succeeded and gave the expected number
106 // of outputs.
107 *inference_ok = execute_result == ExecuteResult::OK &&
108 outputs.has_value() && outputs->size() == 1;
109 if (!*inference_ok) {
110 return;
111 }
112
113 // Get value from output
114 const TensorView<double> out_tensor((*outputs)[0]);
115 *sum = out_tensor.GetValues()[0];
116 },
117 &inference_ok, &result.sum));
118 base::RunLoop().RunUntilIdle();
119 if (!inference_ok) {
120 result.status = "Inference failed.";
121 return result;
122 }
123
124 result.status = "Inference succeeded.";
125 return result;
126}
127
128} // namespace simple
129} // namespace ml