Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 1 | // Copyright 2020 The Chromium OS Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style license that can be |
| 3 | // found in the LICENSE file. |
| 4 | // |
| 5 | // A simplified interface to the ML service. Used to implement the ml_cmdline |
| 6 | // tool. |
| 7 | #include "ml/simple.h" |
| 8 | |
| 9 | #include <string> |
| 10 | #include <utility> |
| 11 | #include <vector> |
| 12 | |
| 13 | #include <base/bind.h> |
| 14 | #include <base/run_loop.h> |
Andrew Moylan | b481af7 | 2020-07-09 15:22:00 +1000 | [diff] [blame^] | 15 | #include <mojo/public/cpp/bindings/remote.h> |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 16 | |
| 17 | #include "ml/machine_learning_service_impl.h" |
| 18 | #include "ml/mojom/graph_executor.mojom.h" |
| 19 | #include "ml/mojom/machine_learning_service.mojom.h" |
| 20 | #include "ml/mojom/model.mojom.h" |
| 21 | #include "ml/tensor_view.h" |
| 22 | |
| 23 | using ::chromeos::machine_learning::mojom::BuiltinModelId; |
| 24 | using ::chromeos::machine_learning::mojom::BuiltinModelSpec; |
| 25 | using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr; |
| 26 | using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult; |
| 27 | using ::chromeos::machine_learning::mojom::ExecuteResult; |
Andrew Moylan | b481af7 | 2020-07-09 15:22:00 +1000 | [diff] [blame^] | 28 | using ::chromeos::machine_learning::mojom::GraphExecutor; |
Alan Green | 55e1654 | 2020-05-11 14:06:46 +1000 | [diff] [blame] | 29 | using ::chromeos::machine_learning::mojom::GraphExecutorOptions; |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 30 | using ::chromeos::machine_learning::mojom::LoadModelResult; |
Andrew Moylan | b481af7 | 2020-07-09 15:22:00 +1000 | [diff] [blame^] | 31 | using ::chromeos::machine_learning::mojom::MachineLearningService; |
| 32 | using ::chromeos::machine_learning::mojom::Model; |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 33 | using ::chromeos::machine_learning::mojom::TensorPtr; |
| 34 | |
| 35 | namespace ml { |
| 36 | namespace simple { |
| 37 | namespace { |
| 38 | |
| 39 | // Creates a 1-D tensor containing a single value |
| 40 | TensorPtr NewSingleValueTensor(const double value) { |
| 41 | auto tensor(chromeos::machine_learning::mojom::Tensor::New()); |
| 42 | TensorView<double> tensor_view(tensor); |
| 43 | tensor_view.Allocate(); |
| 44 | tensor_view.GetShape() = {1}; |
| 45 | tensor_view.GetValues() = {value}; |
| 46 | return tensor; |
| 47 | } |
| 48 | |
| 49 | } // namespace |
| 50 | |
Alan Green | 3117b52 | 2020-05-20 09:50:27 +1000 | [diff] [blame] | 51 | AddResult Add(const double x, const double y, const bool use_nnapi) { |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 52 | AddResult result = {"Not completed.", -1.0}; |
| 53 | |
| 54 | // Create ML Service |
Andrew Moylan | b481af7 | 2020-07-09 15:22:00 +1000 | [diff] [blame^] | 55 | mojo::Remote<MachineLearningService> ml_service; |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 56 | const MachineLearningServiceImpl ml_service_impl( |
Andrew Moylan | b481af7 | 2020-07-09 15:22:00 +1000 | [diff] [blame^] | 57 | ml_service.BindNewPipeAndPassReceiver().PassPipe(), |
| 58 | base::Closure()); |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 59 | |
| 60 | // Load model. |
| 61 | BuiltinModelSpecPtr spec = BuiltinModelSpec::New(); |
| 62 | spec->id = BuiltinModelId::TEST_MODEL; |
Andrew Moylan | b481af7 | 2020-07-09 15:22:00 +1000 | [diff] [blame^] | 63 | mojo::Remote<Model> model; |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 64 | bool model_load_ok = false; |
| 65 | ml_service->LoadBuiltinModel( |
Andrew Moylan | b481af7 | 2020-07-09 15:22:00 +1000 | [diff] [blame^] | 66 | std::move(spec), model.BindNewPipeAndPassReceiver(), |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 67 | base::Bind( |
| 68 | [](bool* const model_load_ok, const LoadModelResult result) { |
| 69 | *model_load_ok = result == LoadModelResult::OK; |
| 70 | }, |
| 71 | &model_load_ok)); |
| 72 | base::RunLoop().RunUntilIdle(); |
| 73 | if (!model_load_ok) { |
| 74 | result.status = "Failed to load model."; |
| 75 | return result; |
| 76 | } |
| 77 | |
| 78 | // Get graph executor for model. |
Andrew Moylan | b481af7 | 2020-07-09 15:22:00 +1000 | [diff] [blame^] | 79 | mojo::Remote<GraphExecutor> graph_executor; |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 80 | bool graph_executor_ok = false; |
Alan Green | 55e1654 | 2020-05-11 14:06:46 +1000 | [diff] [blame] | 81 | auto options = GraphExecutorOptions::New(use_nnapi); |
| 82 | model->CreateGraphExecutorWithOptions( |
Andrew Moylan | b481af7 | 2020-07-09 15:22:00 +1000 | [diff] [blame^] | 83 | std::move(options), graph_executor.BindNewPipeAndPassReceiver(), |
Alan Green | 55e1654 | 2020-05-11 14:06:46 +1000 | [diff] [blame] | 84 | base::Bind( |
| 85 | [](bool* const graph_executor_ok, |
| 86 | const CreateGraphExecutorResult result) { |
| 87 | *graph_executor_ok = result == CreateGraphExecutorResult::OK; |
| 88 | }, |
| 89 | &graph_executor_ok)); |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 90 | base::RunLoop().RunUntilIdle(); |
| 91 | if (!model_load_ok) { |
| 92 | result.status = "Failed to get graph executor"; |
| 93 | return result; |
| 94 | } |
| 95 | |
| 96 | // Construct input to graph executor and perform inference |
| 97 | base::flat_map<std::string, TensorPtr> inputs; |
| 98 | inputs.emplace("x", NewSingleValueTensor(x)); |
| 99 | inputs.emplace("y", NewSingleValueTensor(y)); |
| 100 | std::vector<std::string> outputs({"z"}); |
| 101 | bool inference_ok = false; |
| 102 | graph_executor->Execute( |
| 103 | std::move(inputs), std::move(outputs), |
| 104 | base::Bind( |
| 105 | [](bool* const inference_ok, double* const sum, |
| 106 | const ExecuteResult execute_result, |
| 107 | base::Optional<std::vector<TensorPtr>> outputs) { |
| 108 | // Check that the inference succeeded and gave the expected number |
| 109 | // of outputs. |
| 110 | *inference_ok = execute_result == ExecuteResult::OK && |
| 111 | outputs.has_value() && outputs->size() == 1; |
| 112 | if (!*inference_ok) { |
| 113 | return; |
| 114 | } |
| 115 | |
| 116 | // Get value from output |
| 117 | const TensorView<double> out_tensor((*outputs)[0]); |
| 118 | *sum = out_tensor.GetValues()[0]; |
| 119 | }, |
| 120 | &inference_ok, &result.sum)); |
| 121 | base::RunLoop().RunUntilIdle(); |
| 122 | if (!inference_ok) { |
| 123 | result.status = "Inference failed."; |
| 124 | return result; |
| 125 | } |
| 126 | |
Alan Green | c5bcbcd | 2020-05-07 11:44:26 +1000 | [diff] [blame] | 127 | result.status = "OK"; |
Alan Green | b9d0c83 | 2020-04-30 08:29:50 +1000 | [diff] [blame] | 128 | return result; |
| 129 | } |
| 130 | |
| 131 | } // namespace simple |
| 132 | } // namespace ml |