blob: e2d19b431e69c1b942a2c33eaa75b8c6dc65e66c [file] [log] [blame]
/*
 * Copyright 2016 The WebRTC Project Authors. All rights reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */
10
11#include "webrtc/base/task_queue.h"
12
13#include <string.h>
14#include <unordered_map>
15
16#include "webrtc/base/checks.h"
17#include "webrtc/base/logging.h"
18
19namespace rtc {
20namespace {
// Private thread-message ids used to hand work to the queue's thread:
// WM_RUN_TASK carries a task to run right away, WM_QUEUE_DELAYED_TASK carries
// a task to run once a delay has elapsed. The expansions are parenthesized so
// each macro stays a single operand when used inside a larger expression.
#define WM_RUN_TASK (WM_USER + 1)
#define WM_QUEUE_DELAYED_TASK (WM_USER + 2)
23
24DWORD g_queue_ptr_tls = 0;
25
26BOOL CALLBACK InitializeTls(PINIT_ONCE init_once, void* param, void** context) {
27 g_queue_ptr_tls = TlsAlloc();
28 return TRUE;
29}
30
31DWORD GetQueuePtrTls() {
32 static INIT_ONCE init_once = INIT_ONCE_STATIC_INIT;
33 InitOnceExecuteOnce(&init_once, InitializeTls, nullptr, nullptr);
34 return g_queue_ptr_tls;
35}
36
37struct ThreadStartupData {
38 Event* started;
39 void* thread_context;
40};
41
42void CALLBACK InitializeQueueThread(ULONG_PTR param) {
43 MSG msg;
44 PeekMessage(&msg, NULL, WM_USER, WM_USER, PM_NOREMOVE);
45 ThreadStartupData* data = reinterpret_cast<ThreadStartupData*>(param);
46 TlsSetValue(GetQueuePtrTls(), data->thread_context);
47 data->started->Set();
48}
49} // namespace
50
51TaskQueue::TaskQueue(const char* queue_name)
52 : thread_(&TaskQueue::ThreadMain, this, queue_name) {
53 RTC_DCHECK(queue_name);
54 thread_.Start();
55 Event event(false, false);
56 ThreadStartupData startup = {&event, this};
57 RTC_CHECK(thread_.QueueAPC(&InitializeQueueThread,
58 reinterpret_cast<ULONG_PTR>(&startup)));
59 event.Wait(Event::kForever);
60}
61
62TaskQueue::~TaskQueue() {
63 RTC_DCHECK(!IsCurrent());
64 while (!PostThreadMessage(thread_.GetThreadRef(), WM_QUIT, 0, 0)) {
tommie1104112016-06-14 14:37:54 -070065 DWORD last_error = ::GetLastError();
66 if (last_error == ERROR_SUCCESS) {
67 // TODO(tommi): Figure out what's going on on the Win10 build bot when
68 // we get this error.
69 break;
70 }
71 RTC_CHECK_EQ(static_cast<DWORD>(ERROR_NOT_ENOUGH_QUOTA), last_error);
tommic06b1332016-05-14 11:31:40 -070072 Sleep(1);
73 }
74 thread_.Stop();
75}
76
77// static
78TaskQueue* TaskQueue::Current() {
79 return static_cast<TaskQueue*>(TlsGetValue(GetQueuePtrTls()));
80}
81
82// static
83bool TaskQueue::IsCurrent(const char* queue_name) {
84 TaskQueue* current = Current();
85 return current && current->thread_.name().compare(queue_name) == 0;
86}
87
88bool TaskQueue::IsCurrent() const {
89 return IsThreadRefEqual(thread_.GetThreadRef(), CurrentThreadRef());
90}
91
92void TaskQueue::PostTask(std::unique_ptr<QueuedTask> task) {
93 if (PostThreadMessage(thread_.GetThreadRef(), WM_RUN_TASK, 0,
94 reinterpret_cast<LPARAM>(task.get()))) {
95 task.release();
96 }
97}
98
99void TaskQueue::PostDelayedTask(std::unique_ptr<QueuedTask> task,
100 uint32_t milliseconds) {
101 WPARAM wparam;
102#if defined(_WIN64)
103 // GetTickCount() returns a fairly coarse tick count (resolution or about 8ms)
104 // so this compensation isn't that accurate, but since we have unused 32 bits
105 // on Win64, we might as well use them.
106 wparam = (static_cast<WPARAM>(::GetTickCount()) << 32) | milliseconds;
107#else
108 wparam = milliseconds;
109#endif
110 if (PostThreadMessage(thread_.GetThreadRef(), WM_QUEUE_DELAYED_TASK, wparam,
111 reinterpret_cast<LPARAM>(task.get()))) {
112 task.release();
113 }
114}
115
116void TaskQueue::PostTaskAndReply(std::unique_ptr<QueuedTask> task,
117 std::unique_ptr<QueuedTask> reply,
118 TaskQueue* reply_queue) {
119 QueuedTask* task_ptr = task.release();
120 QueuedTask* reply_task_ptr = reply.release();
121 DWORD reply_thread_id = reply_queue->thread_.GetThreadRef();
122 PostTask([task_ptr, reply_task_ptr, reply_thread_id]() {
123 if (task_ptr->Run())
124 delete task_ptr;
125 // If the thread's message queue is full, we can't queue the task and will
126 // have to drop it (i.e. delete).
127 if (!PostThreadMessage(reply_thread_id, WM_RUN_TASK, 0,
128 reinterpret_cast<LPARAM>(reply_task_ptr))) {
129 delete reply_task_ptr;
130 }
131 });
132}
133
134void TaskQueue::PostTaskAndReply(std::unique_ptr<QueuedTask> task,
135 std::unique_ptr<QueuedTask> reply) {
136 return PostTaskAndReply(std::move(task), std::move(reply), Current());
137}
138
139// static
140bool TaskQueue::ThreadMain(void* context) {
141 std::unordered_map<UINT_PTR, std::unique_ptr<QueuedTask>> delayed_tasks;
142
143 BOOL ret;
144 MSG msg;
145
146 while ((ret = GetMessage(&msg, nullptr, 0, 0)) != 0 && ret != -1) {
147 if (!msg.hwnd) {
148 switch (msg.message) {
149 case WM_RUN_TASK: {
150 QueuedTask* task = reinterpret_cast<QueuedTask*>(msg.lParam);
151 if (task->Run())
152 delete task;
153 break;
154 }
155 case WM_QUEUE_DELAYED_TASK: {
156 QueuedTask* task = reinterpret_cast<QueuedTask*>(msg.lParam);
157 uint32_t milliseconds = msg.wParam & 0xFFFFFFFF;
158#if defined(_WIN64)
159 // Subtract the time it took to queue the timer.
160 const DWORD now = GetTickCount();
161 DWORD post_time = now - (msg.wParam >> 32);
162 milliseconds =
163 post_time > milliseconds ? 0 : milliseconds - post_time;
164#endif
165 UINT_PTR timer_id = SetTimer(nullptr, 0, milliseconds, nullptr);
166 delayed_tasks.insert(std::make_pair(timer_id, task));
167 break;
168 }
169 case WM_TIMER: {
170 KillTimer(nullptr, msg.wParam);
171 auto found = delayed_tasks.find(msg.wParam);
172 RTC_DCHECK(found != delayed_tasks.end());
173 if (!found->second->Run())
174 found->second.release();
175 delayed_tasks.erase(found);
176 break;
177 }
178 default:
179 RTC_NOTREACHED();
180 break;
181 }
182 } else {
183 TranslateMessage(&msg);
184 DispatchMessage(&msg);
185 }
186 }
187
188 return false;
189}
190} // namespace rtc