Meng-Huan Yu | 315e6e4 | 2019-05-20 15:27:21 +0800 | [diff] [blame] | 1 | // Copyright 2019 The Chromium OS Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style license that can be |
| 3 | // found in the LICENSE file. |
| 4 | |
| 5 | #ifndef LIBHWSEC_TASK_DISPATCHING_FRAMEWORK_H_ |
| 6 | #define LIBHWSEC_TASK_DISPATCHING_FRAMEWORK_H_ |
| 7 | |
| 8 | #include <memory> |
| 9 | #include <string> |
| 10 | #include <utility> |
| 11 | |
| 12 | #include <base/bind.h> |
Hidehiko Abe | 018e0fa | 2019-06-14 19:53:22 +0900 | [diff] [blame] | 13 | #include <base/location.h> |
Qijiang Fan | 886c469 | 2021-02-19 11:54:10 +0900 | [diff] [blame] | 14 | #include <base/notreached.h> |
Meng-Huan Yu | 315e6e4 | 2019-05-20 15:27:21 +0800 | [diff] [blame] | 15 | #include <base/threading/thread_task_runner_handle.h> |
| 16 | #include <brillo/dbus/dbus_method_response.h> |
| 17 | |
| 18 | namespace hwsec { |
| 19 | |
| 20 | // This class allows DBusMethodResponse to Return(), ReplyWithError() or |
| 21 | // destruct from any thread. However, it should be noted that the creation of |
| 22 | // this class should be on the original dbus thread, and this class does not |
| 23 | // handle the situation whereby Return() or ReplyWithError() is called from two |
| 24 | // different threads. (It is the task of the caller to ensure that each instance |
| 25 | // returns only once.) |
| 26 | template <typename... Types> |
| 27 | class ThreadSafeDBusMethodResponse |
| 28 | : public brillo::dbus_utils::DBusMethodResponse<Types...> { |
| 29 | public: |
| 30 | using BaseClass = brillo::dbus_utils::DBusMethodResponse<Types...>; |
| 31 | using DBusMethodResponse = brillo::dbus_utils::DBusMethodResponse<Types...>; |
| 32 | |
| 33 | ThreadSafeDBusMethodResponse(ThreadSafeDBusMethodResponse&& callback) = |
| 34 | default; |
Hidehiko Abe | 018e0fa | 2019-06-14 19:53:22 +0900 | [diff] [blame] | 35 | explicit ThreadSafeDBusMethodResponse(BaseClass&& original_callback) |
Meng-Huan Yu | 315e6e4 | 2019-05-20 15:27:21 +0800 | [diff] [blame] | 36 | : BaseClass::DBusMethodResponse( |
| 37 | nullptr, |
| 38 | base::Bind([](std::unique_ptr<dbus::Response>) { NOTREACHED(); })), |
| 39 | origin_task_runner_(base::ThreadTaskRunnerHandle::Get()), |
| 40 | origin_thread_id_(base::PlatformThread::CurrentId()), |
| 41 | original_callback_(new BaseClass(std::move(original_callback))) {} |
| 42 | |
| 43 | ~ThreadSafeDBusMethodResponse() override { |
| 44 | // The base class can only be destroyed on the original thread, |
| 45 | // because if this method haven't been sent, then it'll try to send an |
| 46 | // empty response, and that may only happen on the original thread. |
| 47 | // |
| 48 | // If we are not on the original thread, we move out the |
| 49 | // |original_callback_|. The callback will be destruct at original thread, |
| 50 | // and this class is safe to destruct in current thread. |
| 51 | if (!IsOnOriginalThread()) { |
| 52 | origin_task_runner_->PostTask( |
| 53 | FROM_HERE, |
| 54 | base::BindOnce([](const std::unique_ptr<BaseClass>& callback) {}, |
| 55 | std::move(original_callback_))); |
| 56 | } |
| 57 | } |
| 58 | |
| 59 | void Return(const Types&... return_values) override { |
| 60 | if (IsOnOriginalThread()) { |
| 61 | original_callback_->Return(return_values...); |
| 62 | } else { |
| 63 | // We are not on the original thread, so we'll post it back |
| 64 | origin_task_runner_->PostTask( |
| 65 | FROM_HERE, |
| 66 | base::BindOnce(&BaseClass::Return, std::move(original_callback_), |
| 67 | return_values...)); |
| 68 | } |
| 69 | } |
| 70 | |
| 71 | void ReplyWithError(const brillo::Error* error) override { |
| 72 | if (IsOnOriginalThread()) { |
| 73 | original_callback_->ReplyWithError(error); |
| 74 | } else { |
| 75 | // We are not on the original thread, so we'll post it back. |
| 76 | origin_task_runner_->PostTask( |
| 77 | FROM_HERE, base::BindOnce( |
| 78 | [](std::unique_ptr<BaseClass> callback, |
| 79 | std::unique_ptr<brillo::Error> error) { |
| 80 | callback->ReplyWithError(error.get()); |
| 81 | }, |
| 82 | std::move(original_callback_), error->Clone())); |
| 83 | } |
| 84 | } |
| 85 | |
Hidehiko Abe | 018e0fa | 2019-06-14 19:53:22 +0900 | [diff] [blame] | 86 | void ReplyWithError(const base::Location& location, |
Meng-Huan Yu | 315e6e4 | 2019-05-20 15:27:21 +0800 | [diff] [blame] | 87 | const std::string& error_domain, |
| 88 | const std::string& error_code, |
| 89 | const std::string& error_message) override { |
| 90 | if (IsOnOriginalThread()) { |
| 91 | original_callback_->ReplyWithError(location, error_domain, error_code, |
| 92 | error_message); |
| 93 | } else { |
| 94 | // We are not on the original thread, so we'll post it back. |
| 95 | origin_task_runner_->PostTask( |
| 96 | FROM_HERE, |
| 97 | base::BindOnce( |
| 98 | [](std::unique_ptr<BaseClass> original_callback, |
Hidehiko Abe | 018e0fa | 2019-06-14 19:53:22 +0900 | [diff] [blame] | 99 | const base::Location& location, |
Meng-Huan Yu | 315e6e4 | 2019-05-20 15:27:21 +0800 | [diff] [blame] | 100 | const std::string& error_domain, const std::string& error_code, |
| 101 | const std::string& error_message) { |
| 102 | original_callback->ReplyWithError(location, error_domain, |
| 103 | error_code, error_message); |
| 104 | }, |
| 105 | std::move(original_callback_), location, error_domain, error_code, |
| 106 | error_message)); |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | static std::unique_ptr<DBusMethodResponse> MakeThreadSafe( |
| 111 | std::unique_ptr<DBusMethodResponse> response) { |
| 112 | return std::make_unique<ThreadSafeDBusMethodResponse>(std::move(*response)); |
| 113 | } |
| 114 | |
| 115 | private: |
| 116 | bool IsOnOriginalThread() const { |
| 117 | return base::PlatformThread::CurrentId() == origin_thread_id_; |
| 118 | } |
| 119 | |
| 120 | // We record the task runner and thread id from which this object is created |
| 121 | // so that when Reply(), ReplyWithError() is called, we can verify if it's on |
| 122 | // the original thread, if it's not, we can post it. |
| 123 | scoped_refptr<base::SingleThreadTaskRunner> origin_task_runner_; |
| 124 | base::PlatformThreadId origin_thread_id_; |
| 125 | |
| 126 | // The instatnce of base class. It is initialized at constructor. |
| 127 | // Because it should operate on the original thread, we will pass it to the |
| 128 | // original thread when needed, and it will deconstruct at the original thread |
| 129 | // when the task is compelete. By the design of original callback, this class |
| 130 | // is not designed to be called twice, and the caller should handle this. |
| 131 | std::unique_ptr<BaseClass> original_callback_; |
| 132 | }; |
| 133 | |
| 134 | } // namespace hwsec |
| 135 | |
| 136 | #endif // LIBHWSEC_TASK_DISPATCHING_FRAMEWORK_H_ |