blob: 22cdcd0f6c4c2dc3249015bfe32e8768e4247c4c [file] [log] [blame]
shaochuane58f9c72016-08-30 22:27:08 -07001// Copyright 2016 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "media/midi/midi_manager_winrt.h"
6
shaochuan80f1fba2016-09-01 20:44:51 -07007// TODO(shaochuan): Remove this once clang supports uuid syntax in <robuffer.h>.
8// https://reviews.llvm.org/D23895
9namespace Windows {
10namespace Storage {
11namespace Streams {
qiankun.miao53f2d662016-09-02 17:44:08 -070012#pragma warning(disable : 4467)
shaochuan80f1fba2016-09-01 20:44:51 -070013struct __declspec(uuid("905a0fef-bc53-11df-8c49-001e4fc686da"))
14 IBufferByteAccess;
15}
16}
17}
18
shaochuane58f9c72016-08-30 22:27:08 -070019#include <comdef.h>
20#include <robuffer.h>
shaochuan17bc4a02016-09-06 01:42:12 -070021#include <setupapi.h>
shaochuane58f9c72016-08-30 22:27:08 -070022#include <windows.devices.enumeration.h>
23#include <windows.devices.midi.h>
24#include <wrl/event.h>
25
26#include <iomanip>
27#include <unordered_map>
28#include <unordered_set>
29
30#include "base/bind.h"
shaochuan9ff63b82016-09-01 01:58:44 -070031#include "base/lazy_instance.h"
32#include "base/scoped_generic.h"
shaochuan17bc4a02016-09-06 01:42:12 -070033#include "base/strings/string_util.h"
shaochuane58f9c72016-08-30 22:27:08 -070034#include "base/strings/utf_string_conversions.h"
35#include "base/threading/thread_checker.h"
36#include "base/threading/thread_task_runner_handle.h"
37#include "base/timer/timer.h"
38#include "base/win/scoped_comptr.h"
shaochuane58f9c72016-08-30 22:27:08 -070039#include "media/midi/midi_scheduler.h"
40
41namespace media {
42namespace midi {
43namespace {
44
45namespace WRL = Microsoft::WRL;
46
47using namespace ABI::Windows::Devices::Enumeration;
48using namespace ABI::Windows::Devices::Midi;
49using namespace ABI::Windows::Foundation;
50using namespace ABI::Windows::Storage::Streams;
51
52using base::win::ScopedComPtr;
53
54// Helpers for printing HRESULTs.
55struct PrintHr {
56 PrintHr(HRESULT hr) : hr(hr) {}
57 HRESULT hr;
58};
59
60std::ostream& operator<<(std::ostream& os, const PrintHr& phr) {
61 std::ios_base::fmtflags ff = os.flags();
62 os << _com_error(phr.hr).ErrorMessage() << " (0x" << std::hex
63 << std::uppercase << std::setfill('0') << std::setw(8) << phr.hr << ")";
64 os.flags(ff);
65 return os;
66}
67
shaochuan9ff63b82016-09-01 01:58:44 -070068// Provides access to functions in combase.dll which may not be available on
69// Windows 7. Loads functions dynamically at runtime to prevent library
70// dependencies. Use this class through the global LazyInstance
71// |g_combase_functions|.
72class CombaseFunctions {
73 public:
74 CombaseFunctions() = default;
75
76 ~CombaseFunctions() {
77 if (combase_dll_)
78 ::FreeLibrary(combase_dll_);
79 }
80
81 bool LoadFunctions() {
82 combase_dll_ = ::LoadLibrary(L"combase.dll");
83 if (!combase_dll_)
84 return false;
85
86 get_factory_func_ = reinterpret_cast<decltype(&::RoGetActivationFactory)>(
87 ::GetProcAddress(combase_dll_, "RoGetActivationFactory"));
88 if (!get_factory_func_)
89 return false;
90
91 create_string_func_ = reinterpret_cast<decltype(&::WindowsCreateString)>(
92 ::GetProcAddress(combase_dll_, "WindowsCreateString"));
93 if (!create_string_func_)
94 return false;
95
96 delete_string_func_ = reinterpret_cast<decltype(&::WindowsDeleteString)>(
97 ::GetProcAddress(combase_dll_, "WindowsDeleteString"));
98 if (!delete_string_func_)
99 return false;
100
101 get_string_raw_buffer_func_ =
102 reinterpret_cast<decltype(&::WindowsGetStringRawBuffer)>(
103 ::GetProcAddress(combase_dll_, "WindowsGetStringRawBuffer"));
104 if (!get_string_raw_buffer_func_)
105 return false;
106
107 return true;
108 }
109
110 HRESULT RoGetActivationFactory(HSTRING class_id,
111 const IID& iid,
112 void** out_factory) {
113 DCHECK(get_factory_func_);
114 return get_factory_func_(class_id, iid, out_factory);
115 }
116
117 HRESULT WindowsCreateString(const base::char16* src,
118 uint32_t len,
119 HSTRING* out_hstr) {
120 DCHECK(create_string_func_);
121 return create_string_func_(src, len, out_hstr);
122 }
123
124 HRESULT WindowsDeleteString(HSTRING hstr) {
125 DCHECK(delete_string_func_);
126 return delete_string_func_(hstr);
127 }
128
129 const base::char16* WindowsGetStringRawBuffer(HSTRING hstr,
130 uint32_t* out_len) {
131 DCHECK(get_string_raw_buffer_func_);
132 return get_string_raw_buffer_func_(hstr, out_len);
133 }
134
135 private:
136 HMODULE combase_dll_ = nullptr;
137
138 decltype(&::RoGetActivationFactory) get_factory_func_ = nullptr;
139 decltype(&::WindowsCreateString) create_string_func_ = nullptr;
140 decltype(&::WindowsDeleteString) delete_string_func_ = nullptr;
141 decltype(&::WindowsGetStringRawBuffer) get_string_raw_buffer_func_ = nullptr;
142};
143
144base::LazyInstance<CombaseFunctions> g_combase_functions =
145 LAZY_INSTANCE_INITIALIZER;
146
147// Scoped HSTRING class to maintain lifetime of HSTRINGs allocated with
148// WindowsCreateString().
149class ScopedHStringTraits {
150 public:
151 static HSTRING InvalidValue() { return nullptr; }
152
153 static void Free(HSTRING hstr) {
154 g_combase_functions.Get().WindowsDeleteString(hstr);
155 }
156};
157
158class ScopedHString : public base::ScopedGeneric<HSTRING, ScopedHStringTraits> {
159 public:
160 explicit ScopedHString(const base::char16* str) : ScopedGeneric(nullptr) {
161 HSTRING hstr;
162 HRESULT hr = g_combase_functions.Get().WindowsCreateString(
163 str, static_cast<uint32_t>(wcslen(str)), &hstr);
164 if (FAILED(hr))
165 VLOG(1) << "WindowsCreateString failed: " << PrintHr(hr);
166 else
167 reset(hstr);
168 }
169};
170
shaochuane58f9c72016-08-30 22:27:08 -0700171// Factory functions that activate and create WinRT components. The caller takes
172// ownership of the returning ComPtr.
173template <typename InterfaceType, base::char16 const* runtime_class_id>
174ScopedComPtr<InterfaceType> WrlStaticsFactory() {
175 ScopedComPtr<InterfaceType> com_ptr;
176
shaochuan9ff63b82016-09-01 01:58:44 -0700177 ScopedHString class_id_hstring(runtime_class_id);
178 if (!class_id_hstring.is_valid()) {
179 com_ptr = nullptr;
180 return com_ptr;
181 }
182
183 HRESULT hr = g_combase_functions.Get().RoGetActivationFactory(
184 class_id_hstring.get(), __uuidof(InterfaceType), com_ptr.ReceiveVoid());
shaochuane58f9c72016-08-30 22:27:08 -0700185 if (FAILED(hr)) {
shaochuan9ff63b82016-09-01 01:58:44 -0700186 VLOG(1) << "RoGetActivationFactory failed: " << PrintHr(hr);
shaochuane58f9c72016-08-30 22:27:08 -0700187 com_ptr = nullptr;
188 }
189
190 return com_ptr;
191}
192
shaochuan80f1fba2016-09-01 20:44:51 -0700193std::string HStringToString(HSTRING hstr) {
shaochuane58f9c72016-08-30 22:27:08 -0700194 // Note: empty HSTRINGs are represent as nullptr, and instantiating
195 // std::string with nullptr (in base::WideToUTF8) is undefined behavior.
shaochuan9ff63b82016-09-01 01:58:44 -0700196 const base::char16* buffer =
shaochuan80f1fba2016-09-01 20:44:51 -0700197 g_combase_functions.Get().WindowsGetStringRawBuffer(hstr, nullptr);
shaochuane58f9c72016-08-30 22:27:08 -0700198 if (buffer)
199 return base::WideToUTF8(buffer);
200 return std::string();
201}
202
203template <typename T>
204std::string GetIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700205 HSTRING result;
206 HRESULT hr = obj->get_Id(&result);
207 if (FAILED(hr)) {
208 VLOG(1) << "get_Id failed: " << PrintHr(hr);
209 return std::string();
210 }
211 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700212}
213
214template <typename T>
215std::string GetDeviceIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700216 HSTRING result;
217 HRESULT hr = obj->get_DeviceId(&result);
218 if (FAILED(hr)) {
219 VLOG(1) << "get_DeviceId failed: " << PrintHr(hr);
220 return std::string();
221 }
222 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700223}
224
225std::string GetNameString(IDeviceInformation* info) {
shaochuan80f1fba2016-09-01 20:44:51 -0700226 HSTRING result;
227 HRESULT hr = info->get_Name(&result);
228 if (FAILED(hr)) {
229 VLOG(1) << "get_Name failed: " << PrintHr(hr);
230 return std::string();
231 }
232 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700233}
234
235HRESULT GetPointerToBufferData(IBuffer* buffer, uint8_t** out) {
236 ScopedComPtr<Windows::Storage::Streams::IBufferByteAccess> buffer_byte_access;
237
238 HRESULT hr = buffer_byte_access.QueryFrom(buffer);
239 if (FAILED(hr)) {
240 VLOG(1) << "QueryInterface failed: " << PrintHr(hr);
241 return hr;
242 }
243
244 // Lifetime of the pointing buffer is controlled by the buffer object.
245 hr = buffer_byte_access->Buffer(out);
246 if (FAILED(hr)) {
247 VLOG(1) << "Buffer failed: " << PrintHr(hr);
248 return hr;
249 }
250
251 return S_OK;
252}
253
shaochuan110262b2016-08-31 02:15:16 -0700254// Checks if given DeviceInformation represent a Microsoft GS Wavetable Synth
255// instance.
256bool IsMicrosoftSynthesizer(IDeviceInformation* info) {
257 auto midi_synthesizer_statics =
258 WrlStaticsFactory<IMidiSynthesizerStatics,
259 RuntimeClass_Windows_Devices_Midi_MidiSynthesizer>();
260 boolean result = FALSE;
261 HRESULT hr = midi_synthesizer_statics->IsSynthesizer(info, &result);
262 VLOG_IF(1, FAILED(hr)) << "IsSynthesizer failed: " << PrintHr(hr);
263 return result != FALSE;
264}
265
shaochuan17bc4a02016-09-06 01:42:12 -0700266class ScopedDeviceInfoListTraits {
267 public:
268 static HDEVINFO InvalidValue() { return INVALID_HANDLE_VALUE; }
269
270 static void Free(HDEVINFO devinfo_set) {
271 SetupDiDestroyDeviceInfoList(devinfo_set);
272 }
273};
274
275using ScopedDeviceInfoList =
276 base::ScopedGeneric<HDEVINFO, ScopedDeviceInfoListTraits>;
277
278// Retrieves manufacturer (provider) and version information of underlying
279// device driver using Setup API, given device (interface) ID provided by WinRT.
280// |out_manufacturer| and |out_driver_version| won't be modified if retrieval
281// fails. Note that SetupDiBuildDriverInfoList() would block for a few hundred
282// milliseconds so consider calling this function in an asynchronous manner.
283//
284// Device instance ID is extracted from device (interface) ID provided by WinRT
285// APIs, for example from the following interface ID:
286// \\?\SWD#MMDEVAPI#MIDII_60F39FCA.P_0002#{504be32c-ccf6-4d2c-b73f-6f8b3747e22b}
287// we extract the device instance ID: SWD\MMDEVAPI\MIDII_60F39FCA.P_0002
288void GetDriverInfoFromDeviceId(const std::string& dev_id,
289 std::string* out_manufacturer,
290 std::string* out_driver_version) {
291 base::string16 dev_instance_id =
292 base::UTF8ToWide(dev_id.substr(4, dev_id.size() - 43));
293 base::ReplaceChars(dev_instance_id, L"#", L"\\", &dev_instance_id);
294
295 ScopedDeviceInfoList devinfo_list(
296 SetupDiCreateDeviceInfoList(nullptr, nullptr));
297 if (!devinfo_list.is_valid()) {
298 VLOG(1) << "SetupDiCreateDeviceInfoList failed: "
299 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
300 return;
301 }
302
303 SP_DEVINFO_DATA devinfo_data = {0};
304 devinfo_data.cbSize = sizeof(SP_DEVINFO_DATA);
305 SP_DRVINFO_DATA drvinfo_data = {0};
306 drvinfo_data.cbSize = sizeof(SP_DRVINFO_DATA);
307
308 if (!SetupDiOpenDeviceInfo(devinfo_list.get(), dev_instance_id.c_str(),
309 nullptr, 0, &devinfo_data)) {
310 VLOG(1) << "SetupDiOpenDeviceInfo failed: "
311 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
312 return;
313 }
314
315 if (!SetupDiBuildDriverInfoList(devinfo_list.get(), &devinfo_data,
316 SPDIT_COMPATDRIVER)) {
317 VLOG(1) << "SetupDiBuildDriverInfoList failed: "
318 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
319 return;
320 }
321
322 // Assume only one entry in driver info list.
323 if (SetupDiEnumDriverInfo(devinfo_list.get(), &devinfo_data,
324 SPDIT_COMPATDRIVER, 0, &drvinfo_data)) {
325 *out_manufacturer = base::WideToUTF8(drvinfo_data.ProviderName);
326
327 std::stringstream ss;
328 ss << (drvinfo_data.DriverVersion >> 48) << "."
329 << (drvinfo_data.DriverVersion >> 32 & 0xffff) << "."
330 << (drvinfo_data.DriverVersion >> 16 & 0xffff) << "."
331 << (drvinfo_data.DriverVersion & 0xffff);
332 *out_driver_version = ss.str();
333 } else {
334 VLOG(1) << "SetupDiEnumDriverInfo failed: "
335 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
336 }
337
338 if (!SetupDiDestroyDriverInfoList(devinfo_list.get(), &devinfo_data,
339 SPDIT_COMPATDRIVER))
340 VLOG(1) << "SetupDiDestroyDriverInfoList failed: "
341 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
342}
343
shaochuane58f9c72016-08-30 22:27:08 -0700344// Tokens with value = 0 are considered invalid (as in <wrl/event.h>).
345const int64_t kInvalidTokenValue = 0;
346
347template <typename InterfaceType>
348struct MidiPort {
349 MidiPort() = default;
350
351 uint32_t index;
352 ScopedComPtr<InterfaceType> handle;
353 EventRegistrationToken token_MessageReceived;
354
355 private:
356 DISALLOW_COPY_AND_ASSIGN(MidiPort);
357};
358
359} // namespace
360
361template <typename InterfaceType,
362 typename RuntimeType,
363 typename StaticsInterfaceType,
364 base::char16 const* runtime_class_id>
365class MidiManagerWinrt::MidiPortManager {
366 public:
367 // MidiPortManager instances should be constructed on the COM thread.
368 MidiPortManager(MidiManagerWinrt* midi_manager)
369 : midi_manager_(midi_manager),
370 task_runner_(base::ThreadTaskRunnerHandle::Get()) {}
371
372 virtual ~MidiPortManager() { DCHECK(thread_checker_.CalledOnValidThread()); }
373
374 bool StartWatcher() {
375 DCHECK(thread_checker_.CalledOnValidThread());
376
377 HRESULT hr;
378
379 midi_port_statics_ =
380 WrlStaticsFactory<StaticsInterfaceType, runtime_class_id>();
381 if (!midi_port_statics_)
382 return false;
383
384 HSTRING device_selector = nullptr;
385 hr = midi_port_statics_->GetDeviceSelector(&device_selector);
386 if (FAILED(hr)) {
387 VLOG(1) << "GetDeviceSelector failed: " << PrintHr(hr);
388 return false;
389 }
390
391 auto dev_info_statics = WrlStaticsFactory<
392 IDeviceInformationStatics,
393 RuntimeClass_Windows_Devices_Enumeration_DeviceInformation>();
394 if (!dev_info_statics)
395 return false;
396
397 hr = dev_info_statics->CreateWatcherAqsFilter(device_selector,
398 watcher_.Receive());
399 if (FAILED(hr)) {
400 VLOG(1) << "CreateWatcherAqsFilter failed: " << PrintHr(hr);
401 return false;
402 }
403
404 // Register callbacks to WinRT that post state-modifying jobs back to COM
405 // thread. |weak_ptr| and |task_runner| are captured by lambda callbacks for
406 // posting jobs. Note that WinRT callback arguments should not be passed
407 // outside the callback since the pointers may be unavailable afterwards.
408 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
409 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
410
411 hr = watcher_->add_Added(
412 WRL::Callback<ITypedEventHandler<DeviceWatcher*, DeviceInformation*>>(
413 [weak_ptr, task_runner](IDeviceWatcher* watcher,
414 IDeviceInformation* info) {
shaochuan110262b2016-08-31 02:15:16 -0700415 // Disable Microsoft GS Wavetable Synth due to security reasons.
416 // http://crbug.com/499279
417 if (IsMicrosoftSynthesizer(info))
418 return S_OK;
419
shaochuane58f9c72016-08-30 22:27:08 -0700420 std::string dev_id = GetIdString(info),
421 dev_name = GetNameString(info);
422
423 task_runner->PostTask(
424 FROM_HERE, base::Bind(&MidiPortManager::OnAdded, weak_ptr,
425 dev_id, dev_name));
426
427 return S_OK;
428 })
429 .Get(),
430 &token_Added_);
431 if (FAILED(hr)) {
432 VLOG(1) << "add_Added failed: " << PrintHr(hr);
433 return false;
434 }
435
436 hr = watcher_->add_EnumerationCompleted(
437 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
438 [weak_ptr, task_runner](IDeviceWatcher* watcher,
439 IInspectable* insp) {
440 task_runner->PostTask(
441 FROM_HERE,
442 base::Bind(&MidiPortManager::OnEnumerationCompleted,
443 weak_ptr));
444
445 return S_OK;
446 })
447 .Get(),
448 &token_EnumerationCompleted_);
449 if (FAILED(hr)) {
450 VLOG(1) << "add_EnumerationCompleted failed: " << PrintHr(hr);
451 return false;
452 }
453
454 hr = watcher_->add_Removed(
455 WRL::Callback<
456 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
457 [weak_ptr, task_runner](IDeviceWatcher* watcher,
458 IDeviceInformationUpdate* update) {
459 std::string dev_id = GetIdString(update);
460
461 task_runner->PostTask(
462 FROM_HERE,
463 base::Bind(&MidiPortManager::OnRemoved, weak_ptr, dev_id));
464
465 return S_OK;
466 })
467 .Get(),
468 &token_Removed_);
469 if (FAILED(hr)) {
470 VLOG(1) << "add_Removed failed: " << PrintHr(hr);
471 return false;
472 }
473
474 hr = watcher_->add_Stopped(
475 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
476 [](IDeviceWatcher* watcher, IInspectable* insp) {
477 // Placeholder, does nothing for now.
478 return S_OK;
479 })
480 .Get(),
481 &token_Stopped_);
482 if (FAILED(hr)) {
483 VLOG(1) << "add_Stopped failed: " << PrintHr(hr);
484 return false;
485 }
486
487 hr = watcher_->add_Updated(
488 WRL::Callback<
489 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
490 [](IDeviceWatcher* watcher, IDeviceInformationUpdate* update) {
491 // TODO(shaochuan): Check for fields to be updated here.
492 return S_OK;
493 })
494 .Get(),
495 &token_Updated_);
496 if (FAILED(hr)) {
497 VLOG(1) << "add_Updated failed: " << PrintHr(hr);
498 return false;
499 }
500
501 hr = watcher_->Start();
502 if (FAILED(hr)) {
503 VLOG(1) << "Start failed: " << PrintHr(hr);
504 return false;
505 }
506
507 is_initialized_ = true;
508 return true;
509 }
510
511 void StopWatcher() {
512 DCHECK(thread_checker_.CalledOnValidThread());
513
514 HRESULT hr;
515
516 for (const auto& entry : ports_)
517 RemovePortEventHandlers(entry.second.get());
518
519 if (token_Added_.value != kInvalidTokenValue) {
520 hr = watcher_->remove_Added(token_Added_);
521 VLOG_IF(1, FAILED(hr)) << "remove_Added failed: " << PrintHr(hr);
522 token_Added_.value = kInvalidTokenValue;
523 }
524 if (token_EnumerationCompleted_.value != kInvalidTokenValue) {
525 hr = watcher_->remove_EnumerationCompleted(token_EnumerationCompleted_);
526 VLOG_IF(1, FAILED(hr)) << "remove_EnumerationCompleted failed: "
527 << PrintHr(hr);
528 token_EnumerationCompleted_.value = kInvalidTokenValue;
529 }
530 if (token_Removed_.value != kInvalidTokenValue) {
531 hr = watcher_->remove_Removed(token_Removed_);
532 VLOG_IF(1, FAILED(hr)) << "remove_Removed failed: " << PrintHr(hr);
533 token_Removed_.value = kInvalidTokenValue;
534 }
535 if (token_Stopped_.value != kInvalidTokenValue) {
536 hr = watcher_->remove_Stopped(token_Stopped_);
537 VLOG_IF(1, FAILED(hr)) << "remove_Stopped failed: " << PrintHr(hr);
538 token_Stopped_.value = kInvalidTokenValue;
539 }
540 if (token_Updated_.value != kInvalidTokenValue) {
541 hr = watcher_->remove_Updated(token_Updated_);
542 VLOG_IF(1, FAILED(hr)) << "remove_Updated failed: " << PrintHr(hr);
543 token_Updated_.value = kInvalidTokenValue;
544 }
545
546 if (is_initialized_) {
547 hr = watcher_->Stop();
548 VLOG_IF(1, FAILED(hr)) << "Stop failed: " << PrintHr(hr);
549 is_initialized_ = false;
550 }
551 }
552
553 MidiPort<InterfaceType>* GetPortByDeviceId(std::string dev_id) {
554 DCHECK(thread_checker_.CalledOnValidThread());
555 CHECK(is_initialized_);
556
557 auto it = ports_.find(dev_id);
558 if (it == ports_.end())
559 return nullptr;
560 return it->second.get();
561 }
562
563 MidiPort<InterfaceType>* GetPortByIndex(uint32_t port_index) {
564 DCHECK(thread_checker_.CalledOnValidThread());
565 CHECK(is_initialized_);
566
567 return GetPortByDeviceId(port_ids_[port_index]);
568 }
569
570 protected:
571 // Points to the MidiManagerWinrt instance, which is expected to outlive the
572 // MidiPortManager instance.
573 MidiManagerWinrt* midi_manager_;
574
575 // Task runner of the COM thread.
576 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
577
578 // Ensures all methods are called on the COM thread.
579 base::ThreadChecker thread_checker_;
580
581 private:
582 // DeviceWatcher callbacks:
583 void OnAdded(std::string dev_id, std::string dev_name) {
584 DCHECK(thread_checker_.CalledOnValidThread());
585 CHECK(is_initialized_);
586
shaochuane58f9c72016-08-30 22:27:08 -0700587 port_names_[dev_id] = dev_name;
588
shaochuan9ff63b82016-09-01 01:58:44 -0700589 ScopedHString dev_id_hstring(base::UTF8ToWide(dev_id).c_str());
590 if (!dev_id_hstring.is_valid())
shaochuane58f9c72016-08-30 22:27:08 -0700591 return;
shaochuane58f9c72016-08-30 22:27:08 -0700592
593 IAsyncOperation<RuntimeType*>* async_op;
594
shaochuan9ff63b82016-09-01 01:58:44 -0700595 HRESULT hr =
596 midi_port_statics_->FromIdAsync(dev_id_hstring.get(), &async_op);
shaochuane58f9c72016-08-30 22:27:08 -0700597 if (FAILED(hr)) {
598 VLOG(1) << "FromIdAsync failed: " << PrintHr(hr);
599 return;
600 }
601
602 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
603 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
604
605 hr = async_op->put_Completed(
606 WRL::Callback<IAsyncOperationCompletedHandler<RuntimeType*>>(
607 [weak_ptr, task_runner](IAsyncOperation<RuntimeType*>* async_op,
608 AsyncStatus status) {
609 InterfaceType* handle;
610 HRESULT hr = async_op->GetResults(&handle);
611 if (FAILED(hr)) {
612 VLOG(1) << "GetResults failed: " << PrintHr(hr);
613 return hr;
614 }
615
616 // A reference to |async_op| is kept in |async_ops_|, safe to pass
617 // outside.
618 task_runner->PostTask(
619 FROM_HERE,
620 base::Bind(&MidiPortManager::OnCompletedGetPortFromIdAsync,
621 weak_ptr, handle, async_op));
622
623 return S_OK;
624 })
625 .Get());
626 if (FAILED(hr)) {
627 VLOG(1) << "put_Completed failed: " << PrintHr(hr);
628 return;
629 }
630
631 // Keep a reference to incompleted |async_op| for releasing later.
632 async_ops_.insert(async_op);
633 }
634
635 void OnEnumerationCompleted() {
636 DCHECK(thread_checker_.CalledOnValidThread());
637 CHECK(is_initialized_);
638
639 if (async_ops_.empty())
640 midi_manager_->OnPortManagerReady();
641 else
642 enumeration_completed_not_ready_ = true;
643 }
644
645 void OnRemoved(std::string dev_id) {
646 DCHECK(thread_checker_.CalledOnValidThread());
647 CHECK(is_initialized_);
648
shaochuan110262b2016-08-31 02:15:16 -0700649 // Note: in case Microsoft GS Wavetable Synth triggers this event for some
650 // reason, it will be ignored here with log emitted.
shaochuane58f9c72016-08-30 22:27:08 -0700651 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
652 if (!port) {
653 VLOG(1) << "Removing non-existent port " << dev_id;
654 return;
655 }
656
657 SetPortState(port->index, MIDI_PORT_DISCONNECTED);
658
659 RemovePortEventHandlers(port);
660 port->handle = nullptr;
661 }
662
663 void OnCompletedGetPortFromIdAsync(InterfaceType* handle,
664 IAsyncOperation<RuntimeType*>* async_op) {
665 DCHECK(thread_checker_.CalledOnValidThread());
666 CHECK(is_initialized_);
667
668 EventRegistrationToken token = {kInvalidTokenValue};
669 if (!RegisterOnMessageReceived(handle, &token))
670 return;
671
672 std::string dev_id = GetDeviceIdString(handle);
673
674 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
675
676 if (port == nullptr) {
shaochuan17bc4a02016-09-06 01:42:12 -0700677 std::string manufacturer = "Unknown", driver_version = "Unknown";
678 GetDriverInfoFromDeviceId(dev_id, &manufacturer, &driver_version);
679
680 AddPort(MidiPortInfo(dev_id, manufacturer, port_names_[dev_id],
681 driver_version, MIDI_PORT_OPENED));
shaochuane58f9c72016-08-30 22:27:08 -0700682
683 port = new MidiPort<InterfaceType>;
684 port->index = static_cast<uint32_t>(port_ids_.size());
685
686 ports_[dev_id].reset(port);
687 port_ids_.push_back(dev_id);
688 } else {
689 SetPortState(port->index, MIDI_PORT_CONNECTED);
690 }
691
692 port->handle = handle;
693 port->token_MessageReceived = token;
694
695 // Manually release COM interface to completed |async_op|.
696 auto it = async_ops_.find(async_op);
697 CHECK(it != async_ops_.end());
698 (*it)->Release();
699 async_ops_.erase(it);
700
701 if (enumeration_completed_not_ready_ && async_ops_.empty()) {
702 midi_manager_->OnPortManagerReady();
703 enumeration_completed_not_ready_ = false;
704 }
705 }
706
707 // Overrided by MidiInPortManager to listen to input ports.
708 virtual bool RegisterOnMessageReceived(InterfaceType* handle,
709 EventRegistrationToken* p_token) {
710 return true;
711 }
712
713 // Overrided by MidiInPortManager to remove MessageReceived event handler.
714 virtual void RemovePortEventHandlers(MidiPort<InterfaceType>* port) {}
715
716 // Calls midi_manager_->Add{Input,Output}Port.
717 virtual void AddPort(MidiPortInfo info) = 0;
718
719 // Calls midi_manager_->Set{Input,Output}PortState.
720 virtual void SetPortState(uint32_t port_index, MidiPortState state) = 0;
721
722 // WeakPtrFactory has to be declared in derived class, use this method to
723 // retrieve upcasted WeakPtr for posting tasks.
724 virtual base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() = 0;
725
726 // Midi{In,Out}PortStatics instance.
727 ScopedComPtr<StaticsInterfaceType> midi_port_statics_;
728
729 // DeviceWatcher instance and event registration tokens for unsubscribing
730 // events in destructor.
731 ScopedComPtr<IDeviceWatcher> watcher_;
732 EventRegistrationToken token_Added_ = {kInvalidTokenValue},
733 token_EnumerationCompleted_ = {kInvalidTokenValue},
734 token_Removed_ = {kInvalidTokenValue},
735 token_Stopped_ = {kInvalidTokenValue},
736 token_Updated_ = {kInvalidTokenValue};
737
738 // All manipulations to these fields should be done on COM thread.
739 std::unordered_map<std::string, std::unique_ptr<MidiPort<InterfaceType>>>
740 ports_;
741 std::vector<std::string> port_ids_;
742 std::unordered_map<std::string, std::string> port_names_;
743
744 // Keeps AsyncOperation references before the operation completes. Note that
745 // raw pointers are used here and the COM interfaces should be released
746 // manually.
747 std::unordered_set<IAsyncOperation<RuntimeType*>*> async_ops_;
748
749 // Set when device enumeration is completed but OnPortManagerReady() is not
750 // called since some ports are not yet ready (i.e. |async_ops_| is not empty).
751 // In such cases, OnPortManagerReady() will be called in
752 // OnCompletedGetPortFromIdAsync() when the last pending port is ready.
753 bool enumeration_completed_not_ready_ = false;
754
755 // Set if the instance is initialized without error. Should be checked in all
756 // methods on COM thread except StartWatcher().
757 bool is_initialized_ = false;
758};
759
760class MidiManagerWinrt::MidiInPortManager final
761 : public MidiPortManager<IMidiInPort,
762 MidiInPort,
763 IMidiInPortStatics,
764 RuntimeClass_Windows_Devices_Midi_MidiInPort> {
765 public:
766 MidiInPortManager(MidiManagerWinrt* midi_manager)
767 : MidiPortManager(midi_manager), weak_factory_(this) {}
768
769 private:
770 // MidiPortManager overrides:
771 bool RegisterOnMessageReceived(IMidiInPort* handle,
772 EventRegistrationToken* p_token) override {
773 DCHECK(thread_checker_.CalledOnValidThread());
774
775 base::WeakPtr<MidiInPortManager> weak_ptr = weak_factory_.GetWeakPtr();
776 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
777
778 HRESULT hr = handle->add_MessageReceived(
779 WRL::Callback<
780 ITypedEventHandler<MidiInPort*, MidiMessageReceivedEventArgs*>>(
781 [weak_ptr, task_runner](IMidiInPort* handle,
782 IMidiMessageReceivedEventArgs* args) {
783 const base::TimeTicks now = base::TimeTicks::Now();
784
785 std::string dev_id = GetDeviceIdString(handle);
786
787 ScopedComPtr<IMidiMessage> message;
788 HRESULT hr = args->get_Message(message.Receive());
789 if (FAILED(hr)) {
790 VLOG(1) << "get_Message failed: " << PrintHr(hr);
791 return hr;
792 }
793
794 ScopedComPtr<IBuffer> buffer;
795 hr = message->get_RawData(buffer.Receive());
796 if (FAILED(hr)) {
797 VLOG(1) << "get_RawData failed: " << PrintHr(hr);
798 return hr;
799 }
800
801 uint8_t* p_buffer_data = nullptr;
802 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
803 if (FAILED(hr))
804 return hr;
805
806 uint32_t data_length = 0;
807 hr = buffer->get_Length(&data_length);
808 if (FAILED(hr)) {
809 VLOG(1) << "get_Length failed: " << PrintHr(hr);
810 return hr;
811 }
812
813 std::vector<uint8_t> data(p_buffer_data,
814 p_buffer_data + data_length);
815
816 task_runner->PostTask(
817 FROM_HERE, base::Bind(&MidiInPortManager::OnMessageReceived,
818 weak_ptr, dev_id, data, now));
819
820 return S_OK;
821 })
822 .Get(),
823 p_token);
824 if (FAILED(hr)) {
825 VLOG(1) << "add_MessageReceived failed: " << PrintHr(hr);
826 return false;
827 }
828
829 return true;
830 }
831
832 void RemovePortEventHandlers(MidiPort<IMidiInPort>* port) override {
833 if (!(port->handle &&
834 port->token_MessageReceived.value != kInvalidTokenValue))
835 return;
836
837 HRESULT hr =
838 port->handle->remove_MessageReceived(port->token_MessageReceived);
839 VLOG_IF(1, FAILED(hr)) << "remove_MessageReceived failed: " << PrintHr(hr);
840 port->token_MessageReceived.value = kInvalidTokenValue;
841 }
842
843 void AddPort(MidiPortInfo info) final { midi_manager_->AddInputPort(info); }
844
845 void SetPortState(uint32_t port_index, MidiPortState state) final {
846 midi_manager_->SetInputPortState(port_index, state);
847 }
848
849 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
850 DCHECK(thread_checker_.CalledOnValidThread());
851
852 return weak_factory_.GetWeakPtr();
853 }
854
855 // Callback on receiving MIDI input message.
856 void OnMessageReceived(std::string dev_id,
857 std::vector<uint8_t> data,
858 base::TimeTicks time) {
859 DCHECK(thread_checker_.CalledOnValidThread());
860
861 MidiPort<IMidiInPort>* port = GetPortByDeviceId(dev_id);
862 CHECK(port);
863
864 midi_manager_->ReceiveMidiData(port->index, &data[0], data.size(), time);
865 }
866
867 // Last member to ensure destructed first.
868 base::WeakPtrFactory<MidiInPortManager> weak_factory_;
869
870 DISALLOW_COPY_AND_ASSIGN(MidiInPortManager);
871};
872
873class MidiManagerWinrt::MidiOutPortManager final
874 : public MidiPortManager<IMidiOutPort,
875 IMidiOutPort,
876 IMidiOutPortStatics,
877 RuntimeClass_Windows_Devices_Midi_MidiOutPort> {
878 public:
879 MidiOutPortManager(MidiManagerWinrt* midi_manager)
880 : MidiPortManager(midi_manager), weak_factory_(this) {}
881
882 private:
883 // MidiPortManager overrides:
884 void AddPort(MidiPortInfo info) final { midi_manager_->AddOutputPort(info); }
885
886 void SetPortState(uint32_t port_index, MidiPortState state) final {
887 midi_manager_->SetOutputPortState(port_index, state);
888 }
889
890 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
891 DCHECK(thread_checker_.CalledOnValidThread());
892
893 return weak_factory_.GetWeakPtr();
894 }
895
896 // Last member to ensure destructed first.
897 base::WeakPtrFactory<MidiOutPortManager> weak_factory_;
898
899 DISALLOW_COPY_AND_ASSIGN(MidiOutPortManager);
900};
901
902MidiManagerWinrt::MidiManagerWinrt() : com_thread_("Windows MIDI COM Thread") {}
903
904MidiManagerWinrt::~MidiManagerWinrt() {
905 base::AutoLock auto_lock(lazy_init_member_lock_);
906
907 CHECK(!com_thread_checker_);
908 CHECK(!port_manager_in_);
909 CHECK(!port_manager_out_);
910 CHECK(!scheduler_);
911}
912
913void MidiManagerWinrt::StartInitialization() {
shaochuane58f9c72016-08-30 22:27:08 -0700914 com_thread_.init_com_with_mta(true);
915 com_thread_.Start();
916
917 com_thread_.task_runner()->PostTask(
918 FROM_HERE, base::Bind(&MidiManagerWinrt::InitializeOnComThread,
919 base::Unretained(this)));
920}
921
922void MidiManagerWinrt::Finalize() {
923 com_thread_.task_runner()->PostTask(
924 FROM_HERE, base::Bind(&MidiManagerWinrt::FinalizeOnComThread,
925 base::Unretained(this)));
926
927 // Blocks until FinalizeOnComThread() returns. Delayed MIDI send data tasks
928 // will be ignored.
929 com_thread_.Stop();
930}
931
932void MidiManagerWinrt::DispatchSendMidiData(MidiManagerClient* client,
933 uint32_t port_index,
934 const std::vector<uint8_t>& data,
935 double timestamp) {
936 CHECK(scheduler_);
937
938 scheduler_->PostSendDataTask(
939 client, data.size(), timestamp,
940 base::Bind(&MidiManagerWinrt::SendOnComThread, base::Unretained(this),
941 port_index, data));
942}
943
944void MidiManagerWinrt::InitializeOnComThread() {
945 base::AutoLock auto_lock(lazy_init_member_lock_);
946
947 com_thread_checker_.reset(new base::ThreadChecker);
948
shaochuan9ff63b82016-09-01 01:58:44 -0700949 if (!g_combase_functions.Get().LoadFunctions()) {
950 VLOG(1) << "Failed loading functions from combase.dll: "
951 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
952 CompleteInitialization(Result::INITIALIZATION_ERROR);
953 return;
954 }
955
shaochuane58f9c72016-08-30 22:27:08 -0700956 port_manager_in_.reset(new MidiInPortManager(this));
957 port_manager_out_.reset(new MidiOutPortManager(this));
958
959 scheduler_.reset(new MidiScheduler(this));
960
961 if (!(port_manager_in_->StartWatcher() &&
962 port_manager_out_->StartWatcher())) {
963 port_manager_in_->StopWatcher();
964 port_manager_out_->StopWatcher();
965 CompleteInitialization(Result::INITIALIZATION_ERROR);
966 }
967}
968
969void MidiManagerWinrt::FinalizeOnComThread() {
970 base::AutoLock auto_lock(lazy_init_member_lock_);
971
972 DCHECK(com_thread_checker_->CalledOnValidThread());
973
974 scheduler_.reset();
975
shaochuan9ff63b82016-09-01 01:58:44 -0700976 if (port_manager_in_) {
977 port_manager_in_->StopWatcher();
978 port_manager_in_.reset();
979 }
980
981 if (port_manager_out_) {
982 port_manager_out_->StopWatcher();
983 port_manager_out_.reset();
984 }
shaochuane58f9c72016-08-30 22:27:08 -0700985
986 com_thread_checker_.reset();
987}
988
989void MidiManagerWinrt::SendOnComThread(uint32_t port_index,
990 const std::vector<uint8_t>& data) {
991 DCHECK(com_thread_checker_->CalledOnValidThread());
992
993 MidiPort<IMidiOutPort>* port = port_manager_out_->GetPortByIndex(port_index);
994 if (!(port && port->handle)) {
995 VLOG(1) << "Port not available: " << port_index;
996 return;
997 }
998
999 auto buffer_factory =
1000 WrlStaticsFactory<IBufferFactory,
1001 RuntimeClass_Windows_Storage_Streams_Buffer>();
1002 if (!buffer_factory)
1003 return;
1004
1005 ScopedComPtr<IBuffer> buffer;
1006 HRESULT hr = buffer_factory->Create(static_cast<UINT32>(data.size()),
1007 buffer.Receive());
1008 if (FAILED(hr)) {
1009 VLOG(1) << "Create failed: " << PrintHr(hr);
1010 return;
1011 }
1012
1013 hr = buffer->put_Length(static_cast<UINT32>(data.size()));
1014 if (FAILED(hr)) {
1015 VLOG(1) << "put_Length failed: " << PrintHr(hr);
1016 return;
1017 }
1018
1019 uint8_t* p_buffer_data = nullptr;
1020 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
1021 if (FAILED(hr))
1022 return;
1023
1024 std::copy(data.begin(), data.end(), p_buffer_data);
1025
1026 hr = port->handle->SendBuffer(buffer.get());
1027 if (FAILED(hr)) {
1028 VLOG(1) << "SendBuffer failed: " << PrintHr(hr);
1029 return;
1030 }
1031}
1032
1033void MidiManagerWinrt::OnPortManagerReady() {
1034 DCHECK(com_thread_checker_->CalledOnValidThread());
1035 DCHECK(port_manager_ready_count_ < 2);
1036
1037 if (++port_manager_ready_count_ == 2)
1038 CompleteInitialization(Result::OK);
1039}
1040
shaochuane58f9c72016-08-30 22:27:08 -07001041} // namespace midi
1042} // namespace media