blob: e8f4c02a7a1a580914bc5501748e5413de616fc4 [file] [log] [blame]
shaochuane58f9c72016-08-30 22:27:08 -07001// Copyright 2016 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "media/midi/midi_manager_winrt.h"
6
qiankun.miao53f2d662016-09-02 17:44:08 -07007#pragma warning(disable : 4467)
shaochuan80f1fba2016-09-01 20:44:51 -07008
shaochuan4eff30e2016-09-09 01:24:14 -07009#include <initguid.h> // Required by <devpkey.h>
10
11#include <cfgmgr32.h>
shaochuane58f9c72016-08-30 22:27:08 -070012#include <comdef.h>
shaochuan4eff30e2016-09-09 01:24:14 -070013#include <devpkey.h>
shaochuane58f9c72016-08-30 22:27:08 -070014#include <robuffer.h>
15#include <windows.devices.enumeration.h>
16#include <windows.devices.midi.h>
17#include <wrl/event.h>
18
19#include <iomanip>
20#include <unordered_map>
21#include <unordered_set>
22
23#include "base/bind.h"
shaochuan9ff63b82016-09-01 01:58:44 -070024#include "base/lazy_instance.h"
25#include "base/scoped_generic.h"
shaochuan17bc4a02016-09-06 01:42:12 -070026#include "base/strings/string_util.h"
shaochuane58f9c72016-08-30 22:27:08 -070027#include "base/strings/utf_string_conversions.h"
28#include "base/threading/thread_checker.h"
29#include "base/threading/thread_task_runner_handle.h"
30#include "base/timer/timer.h"
31#include "base/win/scoped_comptr.h"
shaochuane58f9c72016-08-30 22:27:08 -070032#include "media/midi/midi_scheduler.h"
33
shaochuane58f9c72016-08-30 22:27:08 -070034namespace midi {
35namespace {
36
37namespace WRL = Microsoft::WRL;
38
39using namespace ABI::Windows::Devices::Enumeration;
40using namespace ABI::Windows::Devices::Midi;
41using namespace ABI::Windows::Foundation;
42using namespace ABI::Windows::Storage::Streams;
43
44using base::win::ScopedComPtr;
45
46// Helpers for printing HRESULTs.
47struct PrintHr {
48 PrintHr(HRESULT hr) : hr(hr) {}
49 HRESULT hr;
50};
51
52std::ostream& operator<<(std::ostream& os, const PrintHr& phr) {
53 std::ios_base::fmtflags ff = os.flags();
54 os << _com_error(phr.hr).ErrorMessage() << " (0x" << std::hex
55 << std::uppercase << std::setfill('0') << std::setw(8) << phr.hr << ")";
56 os.flags(ff);
57 return os;
58}
59
shaochuan9ff63b82016-09-01 01:58:44 -070060// Provides access to functions in combase.dll which may not be available on
61// Windows 7. Loads functions dynamically at runtime to prevent library
62// dependencies. Use this class through the global LazyInstance
63// |g_combase_functions|.
64class CombaseFunctions {
65 public:
66 CombaseFunctions() = default;
67
68 ~CombaseFunctions() {
69 if (combase_dll_)
70 ::FreeLibrary(combase_dll_);
71 }
72
73 bool LoadFunctions() {
74 combase_dll_ = ::LoadLibrary(L"combase.dll");
75 if (!combase_dll_)
76 return false;
77
78 get_factory_func_ = reinterpret_cast<decltype(&::RoGetActivationFactory)>(
79 ::GetProcAddress(combase_dll_, "RoGetActivationFactory"));
80 if (!get_factory_func_)
81 return false;
82
83 create_string_func_ = reinterpret_cast<decltype(&::WindowsCreateString)>(
84 ::GetProcAddress(combase_dll_, "WindowsCreateString"));
85 if (!create_string_func_)
86 return false;
87
88 delete_string_func_ = reinterpret_cast<decltype(&::WindowsDeleteString)>(
89 ::GetProcAddress(combase_dll_, "WindowsDeleteString"));
90 if (!delete_string_func_)
91 return false;
92
93 get_string_raw_buffer_func_ =
94 reinterpret_cast<decltype(&::WindowsGetStringRawBuffer)>(
95 ::GetProcAddress(combase_dll_, "WindowsGetStringRawBuffer"));
96 if (!get_string_raw_buffer_func_)
97 return false;
98
99 return true;
100 }
101
102 HRESULT RoGetActivationFactory(HSTRING class_id,
103 const IID& iid,
104 void** out_factory) {
105 DCHECK(get_factory_func_);
106 return get_factory_func_(class_id, iid, out_factory);
107 }
108
109 HRESULT WindowsCreateString(const base::char16* src,
110 uint32_t len,
111 HSTRING* out_hstr) {
112 DCHECK(create_string_func_);
113 return create_string_func_(src, len, out_hstr);
114 }
115
116 HRESULT WindowsDeleteString(HSTRING hstr) {
117 DCHECK(delete_string_func_);
118 return delete_string_func_(hstr);
119 }
120
121 const base::char16* WindowsGetStringRawBuffer(HSTRING hstr,
122 uint32_t* out_len) {
123 DCHECK(get_string_raw_buffer_func_);
124 return get_string_raw_buffer_func_(hstr, out_len);
125 }
126
127 private:
128 HMODULE combase_dll_ = nullptr;
129
130 decltype(&::RoGetActivationFactory) get_factory_func_ = nullptr;
131 decltype(&::WindowsCreateString) create_string_func_ = nullptr;
132 decltype(&::WindowsDeleteString) delete_string_func_ = nullptr;
133 decltype(&::WindowsGetStringRawBuffer) get_string_raw_buffer_func_ = nullptr;
134};
135
136base::LazyInstance<CombaseFunctions> g_combase_functions =
137 LAZY_INSTANCE_INITIALIZER;
138
139// Scoped HSTRING class to maintain lifetime of HSTRINGs allocated with
140// WindowsCreateString().
141class ScopedHStringTraits {
142 public:
143 static HSTRING InvalidValue() { return nullptr; }
144
145 static void Free(HSTRING hstr) {
146 g_combase_functions.Get().WindowsDeleteString(hstr);
147 }
148};
149
150class ScopedHString : public base::ScopedGeneric<HSTRING, ScopedHStringTraits> {
151 public:
152 explicit ScopedHString(const base::char16* str) : ScopedGeneric(nullptr) {
153 HSTRING hstr;
154 HRESULT hr = g_combase_functions.Get().WindowsCreateString(
155 str, static_cast<uint32_t>(wcslen(str)), &hstr);
156 if (FAILED(hr))
157 VLOG(1) << "WindowsCreateString failed: " << PrintHr(hr);
158 else
159 reset(hstr);
160 }
161};
162
shaochuane58f9c72016-08-30 22:27:08 -0700163// Factory functions that activate and create WinRT components. The caller takes
164// ownership of the returning ComPtr.
165template <typename InterfaceType, base::char16 const* runtime_class_id>
166ScopedComPtr<InterfaceType> WrlStaticsFactory() {
167 ScopedComPtr<InterfaceType> com_ptr;
168
shaochuan9ff63b82016-09-01 01:58:44 -0700169 ScopedHString class_id_hstring(runtime_class_id);
170 if (!class_id_hstring.is_valid()) {
171 com_ptr = nullptr;
172 return com_ptr;
173 }
174
175 HRESULT hr = g_combase_functions.Get().RoGetActivationFactory(
176 class_id_hstring.get(), __uuidof(InterfaceType), com_ptr.ReceiveVoid());
shaochuane58f9c72016-08-30 22:27:08 -0700177 if (FAILED(hr)) {
shaochuan9ff63b82016-09-01 01:58:44 -0700178 VLOG(1) << "RoGetActivationFactory failed: " << PrintHr(hr);
shaochuane58f9c72016-08-30 22:27:08 -0700179 com_ptr = nullptr;
180 }
181
182 return com_ptr;
183}
184
shaochuan80f1fba2016-09-01 20:44:51 -0700185std::string HStringToString(HSTRING hstr) {
shaochuane58f9c72016-08-30 22:27:08 -0700186 // Note: empty HSTRINGs are represent as nullptr, and instantiating
187 // std::string with nullptr (in base::WideToUTF8) is undefined behavior.
shaochuan9ff63b82016-09-01 01:58:44 -0700188 const base::char16* buffer =
shaochuan80f1fba2016-09-01 20:44:51 -0700189 g_combase_functions.Get().WindowsGetStringRawBuffer(hstr, nullptr);
shaochuane58f9c72016-08-30 22:27:08 -0700190 if (buffer)
191 return base::WideToUTF8(buffer);
192 return std::string();
193}
194
195template <typename T>
196std::string GetIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700197 HSTRING result;
198 HRESULT hr = obj->get_Id(&result);
199 if (FAILED(hr)) {
200 VLOG(1) << "get_Id failed: " << PrintHr(hr);
201 return std::string();
202 }
203 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700204}
205
206template <typename T>
207std::string GetDeviceIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700208 HSTRING result;
209 HRESULT hr = obj->get_DeviceId(&result);
210 if (FAILED(hr)) {
211 VLOG(1) << "get_DeviceId failed: " << PrintHr(hr);
212 return std::string();
213 }
214 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700215}
216
217std::string GetNameString(IDeviceInformation* info) {
shaochuan80f1fba2016-09-01 20:44:51 -0700218 HSTRING result;
219 HRESULT hr = info->get_Name(&result);
220 if (FAILED(hr)) {
221 VLOG(1) << "get_Name failed: " << PrintHr(hr);
222 return std::string();
223 }
224 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700225}
226
227HRESULT GetPointerToBufferData(IBuffer* buffer, uint8_t** out) {
228 ScopedComPtr<Windows::Storage::Streams::IBufferByteAccess> buffer_byte_access;
229
230 HRESULT hr = buffer_byte_access.QueryFrom(buffer);
231 if (FAILED(hr)) {
232 VLOG(1) << "QueryInterface failed: " << PrintHr(hr);
233 return hr;
234 }
235
236 // Lifetime of the pointing buffer is controlled by the buffer object.
237 hr = buffer_byte_access->Buffer(out);
238 if (FAILED(hr)) {
239 VLOG(1) << "Buffer failed: " << PrintHr(hr);
240 return hr;
241 }
242
243 return S_OK;
244}
245
shaochuan110262b2016-08-31 02:15:16 -0700246// Checks if given DeviceInformation represent a Microsoft GS Wavetable Synth
247// instance.
248bool IsMicrosoftSynthesizer(IDeviceInformation* info) {
249 auto midi_synthesizer_statics =
250 WrlStaticsFactory<IMidiSynthesizerStatics,
251 RuntimeClass_Windows_Devices_Midi_MidiSynthesizer>();
252 boolean result = FALSE;
253 HRESULT hr = midi_synthesizer_statics->IsSynthesizer(info, &result);
254 VLOG_IF(1, FAILED(hr)) << "IsSynthesizer failed: " << PrintHr(hr);
255 return result != FALSE;
256}
257
shaochuan4eff30e2016-09-09 01:24:14 -0700258void GetDevPropString(DEVINST handle,
259 const DEVPROPKEY* devprop_key,
260 std::string* out) {
261 DEVPROPTYPE devprop_type;
262 unsigned long buffer_size = 0;
shaochuan17bc4a02016-09-06 01:42:12 -0700263
shaochuan4eff30e2016-09-09 01:24:14 -0700264 // Retrieve |buffer_size| and allocate buffer later for receiving data.
265 CONFIGRET cr = CM_Get_DevNode_Property(handle, devprop_key, &devprop_type,
266 nullptr, &buffer_size, 0);
267 if (cr != CR_BUFFER_SMALL) {
268 // Here we print error codes in hex instead of using PrintHr() with
269 // HRESULT_FROM_WIN32() and CM_MapCrToWin32Err(), since only a minor set of
270 // CONFIGRET values are mapped to Win32 errors. Same for following VLOG()s.
271 VLOG(1) << "CM_Get_DevNode_Property failed: CONFIGRET 0x" << std::hex << cr;
272 return;
shaochuan17bc4a02016-09-06 01:42:12 -0700273 }
shaochuan4eff30e2016-09-09 01:24:14 -0700274 if (devprop_type != DEVPROP_TYPE_STRING) {
275 VLOG(1) << "CM_Get_DevNode_Property returns wrong data type, "
276 << "expected DEVPROP_TYPE_STRING";
277 return;
278 }
shaochuan17bc4a02016-09-06 01:42:12 -0700279
shaochuan4eff30e2016-09-09 01:24:14 -0700280 std::unique_ptr<uint8_t[]> buffer(new uint8_t[buffer_size]);
281
282 // Receive property data.
283 cr = CM_Get_DevNode_Property(handle, devprop_key, &devprop_type, buffer.get(),
284 &buffer_size, 0);
285 if (cr != CR_SUCCESS)
286 VLOG(1) << "CM_Get_DevNode_Property failed: CONFIGRET 0x" << std::hex << cr;
287 else
288 *out = base::WideToUTF8(reinterpret_cast<base::char16*>(buffer.get()));
289}
shaochuan17bc4a02016-09-06 01:42:12 -0700290
291// Retrieves manufacturer (provider) and version information of underlying
shaochuan4eff30e2016-09-09 01:24:14 -0700292// device driver through PnP Configuration Manager, given device (interface) ID
293// provided by WinRT. |out_manufacturer| and |out_driver_version| won't be
294// modified if retrieval fails.
shaochuan17bc4a02016-09-06 01:42:12 -0700295//
296// Device instance ID is extracted from device (interface) ID provided by WinRT
297// APIs, for example from the following interface ID:
298// \\?\SWD#MMDEVAPI#MIDII_60F39FCA.P_0002#{504be32c-ccf6-4d2c-b73f-6f8b3747e22b}
299// we extract the device instance ID: SWD\MMDEVAPI\MIDII_60F39FCA.P_0002
shaochuan4eff30e2016-09-09 01:24:14 -0700300//
301// However the extracted device instance ID represent a "software device"
302// provided by Microsoft, which is an interface on top of the hardware for each
303// input/output port. Therefore we further locate its parent device, which is
304// the actual hardware device, for driver information.
shaochuan17bc4a02016-09-06 01:42:12 -0700305void GetDriverInfoFromDeviceId(const std::string& dev_id,
306 std::string* out_manufacturer,
307 std::string* out_driver_version) {
308 base::string16 dev_instance_id =
309 base::UTF8ToWide(dev_id.substr(4, dev_id.size() - 43));
310 base::ReplaceChars(dev_instance_id, L"#", L"\\", &dev_instance_id);
311
shaochuan4eff30e2016-09-09 01:24:14 -0700312 DEVINST dev_instance_handle;
313 CONFIGRET cr = CM_Locate_DevNode(&dev_instance_handle, &dev_instance_id[0],
314 CM_LOCATE_DEVNODE_NORMAL);
315 if (cr != CR_SUCCESS) {
316 VLOG(1) << "CM_Locate_DevNode failed: CONFIGRET 0x" << std::hex << cr;
shaochuan17bc4a02016-09-06 01:42:12 -0700317 return;
318 }
319
shaochuan4eff30e2016-09-09 01:24:14 -0700320 DEVINST parent_handle;
321 cr = CM_Get_Parent(&parent_handle, dev_instance_handle, 0);
322 if (cr != CR_SUCCESS) {
323 VLOG(1) << "CM_Get_Parent failed: CONFIGRET 0x" << std::hex << cr;
shaochuan17bc4a02016-09-06 01:42:12 -0700324 return;
325 }
326
shaochuan4eff30e2016-09-09 01:24:14 -0700327 GetDevPropString(parent_handle, &DEVPKEY_Device_DriverProvider,
328 out_manufacturer);
329 GetDevPropString(parent_handle, &DEVPKEY_Device_DriverVersion,
330 out_driver_version);
shaochuan17bc4a02016-09-06 01:42:12 -0700331}
332
shaochuane58f9c72016-08-30 22:27:08 -0700333// Tokens with value = 0 are considered invalid (as in <wrl/event.h>).
334const int64_t kInvalidTokenValue = 0;
335
336template <typename InterfaceType>
337struct MidiPort {
338 MidiPort() = default;
339
340 uint32_t index;
341 ScopedComPtr<InterfaceType> handle;
342 EventRegistrationToken token_MessageReceived;
343
344 private:
345 DISALLOW_COPY_AND_ASSIGN(MidiPort);
346};
347
348} // namespace
349
350template <typename InterfaceType,
351 typename RuntimeType,
352 typename StaticsInterfaceType,
353 base::char16 const* runtime_class_id>
354class MidiManagerWinrt::MidiPortManager {
355 public:
356 // MidiPortManager instances should be constructed on the COM thread.
357 MidiPortManager(MidiManagerWinrt* midi_manager)
358 : midi_manager_(midi_manager),
359 task_runner_(base::ThreadTaskRunnerHandle::Get()) {}
360
361 virtual ~MidiPortManager() { DCHECK(thread_checker_.CalledOnValidThread()); }
362
363 bool StartWatcher() {
364 DCHECK(thread_checker_.CalledOnValidThread());
365
366 HRESULT hr;
367
368 midi_port_statics_ =
369 WrlStaticsFactory<StaticsInterfaceType, runtime_class_id>();
370 if (!midi_port_statics_)
371 return false;
372
373 HSTRING device_selector = nullptr;
374 hr = midi_port_statics_->GetDeviceSelector(&device_selector);
375 if (FAILED(hr)) {
376 VLOG(1) << "GetDeviceSelector failed: " << PrintHr(hr);
377 return false;
378 }
379
380 auto dev_info_statics = WrlStaticsFactory<
381 IDeviceInformationStatics,
382 RuntimeClass_Windows_Devices_Enumeration_DeviceInformation>();
383 if (!dev_info_statics)
384 return false;
385
386 hr = dev_info_statics->CreateWatcherAqsFilter(device_selector,
387 watcher_.Receive());
388 if (FAILED(hr)) {
389 VLOG(1) << "CreateWatcherAqsFilter failed: " << PrintHr(hr);
390 return false;
391 }
392
393 // Register callbacks to WinRT that post state-modifying jobs back to COM
394 // thread. |weak_ptr| and |task_runner| are captured by lambda callbacks for
395 // posting jobs. Note that WinRT callback arguments should not be passed
396 // outside the callback since the pointers may be unavailable afterwards.
397 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
398 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
399
400 hr = watcher_->add_Added(
401 WRL::Callback<ITypedEventHandler<DeviceWatcher*, DeviceInformation*>>(
402 [weak_ptr, task_runner](IDeviceWatcher* watcher,
403 IDeviceInformation* info) {
shaochuanc2894522016-09-20 01:10:50 -0700404 if (!info) {
405 VLOG(1) << "DeviceWatcher.Added callback provides null "
406 "pointer, ignoring";
407 return S_OK;
408 }
409
shaochuan110262b2016-08-31 02:15:16 -0700410 // Disable Microsoft GS Wavetable Synth due to security reasons.
411 // http://crbug.com/499279
412 if (IsMicrosoftSynthesizer(info))
413 return S_OK;
414
shaochuane58f9c72016-08-30 22:27:08 -0700415 std::string dev_id = GetIdString(info),
416 dev_name = GetNameString(info);
417
418 task_runner->PostTask(
419 FROM_HERE, base::Bind(&MidiPortManager::OnAdded, weak_ptr,
420 dev_id, dev_name));
421
422 return S_OK;
423 })
424 .Get(),
425 &token_Added_);
426 if (FAILED(hr)) {
427 VLOG(1) << "add_Added failed: " << PrintHr(hr);
428 return false;
429 }
430
431 hr = watcher_->add_EnumerationCompleted(
432 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
433 [weak_ptr, task_runner](IDeviceWatcher* watcher,
434 IInspectable* insp) {
435 task_runner->PostTask(
436 FROM_HERE,
437 base::Bind(&MidiPortManager::OnEnumerationCompleted,
438 weak_ptr));
439
440 return S_OK;
441 })
442 .Get(),
443 &token_EnumerationCompleted_);
444 if (FAILED(hr)) {
445 VLOG(1) << "add_EnumerationCompleted failed: " << PrintHr(hr);
446 return false;
447 }
448
449 hr = watcher_->add_Removed(
450 WRL::Callback<
451 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
452 [weak_ptr, task_runner](IDeviceWatcher* watcher,
453 IDeviceInformationUpdate* update) {
shaochuanc2894522016-09-20 01:10:50 -0700454 if (!update) {
455 VLOG(1) << "DeviceWatcher.Removed callback provides null "
456 "pointer, ignoring";
457 return S_OK;
458 }
459
shaochuane58f9c72016-08-30 22:27:08 -0700460 std::string dev_id = GetIdString(update);
461
462 task_runner->PostTask(
463 FROM_HERE,
464 base::Bind(&MidiPortManager::OnRemoved, weak_ptr, dev_id));
465
466 return S_OK;
467 })
468 .Get(),
469 &token_Removed_);
470 if (FAILED(hr)) {
471 VLOG(1) << "add_Removed failed: " << PrintHr(hr);
472 return false;
473 }
474
475 hr = watcher_->add_Stopped(
476 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
477 [](IDeviceWatcher* watcher, IInspectable* insp) {
478 // Placeholder, does nothing for now.
479 return S_OK;
480 })
481 .Get(),
482 &token_Stopped_);
483 if (FAILED(hr)) {
484 VLOG(1) << "add_Stopped failed: " << PrintHr(hr);
485 return false;
486 }
487
488 hr = watcher_->add_Updated(
489 WRL::Callback<
490 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
491 [](IDeviceWatcher* watcher, IDeviceInformationUpdate* update) {
492 // TODO(shaochuan): Check for fields to be updated here.
493 return S_OK;
494 })
495 .Get(),
496 &token_Updated_);
497 if (FAILED(hr)) {
498 VLOG(1) << "add_Updated failed: " << PrintHr(hr);
499 return false;
500 }
501
502 hr = watcher_->Start();
503 if (FAILED(hr)) {
504 VLOG(1) << "Start failed: " << PrintHr(hr);
505 return false;
506 }
507
508 is_initialized_ = true;
509 return true;
510 }
511
512 void StopWatcher() {
513 DCHECK(thread_checker_.CalledOnValidThread());
514
515 HRESULT hr;
516
517 for (const auto& entry : ports_)
518 RemovePortEventHandlers(entry.second.get());
519
520 if (token_Added_.value != kInvalidTokenValue) {
521 hr = watcher_->remove_Added(token_Added_);
522 VLOG_IF(1, FAILED(hr)) << "remove_Added failed: " << PrintHr(hr);
523 token_Added_.value = kInvalidTokenValue;
524 }
525 if (token_EnumerationCompleted_.value != kInvalidTokenValue) {
526 hr = watcher_->remove_EnumerationCompleted(token_EnumerationCompleted_);
527 VLOG_IF(1, FAILED(hr)) << "remove_EnumerationCompleted failed: "
528 << PrintHr(hr);
529 token_EnumerationCompleted_.value = kInvalidTokenValue;
530 }
531 if (token_Removed_.value != kInvalidTokenValue) {
532 hr = watcher_->remove_Removed(token_Removed_);
533 VLOG_IF(1, FAILED(hr)) << "remove_Removed failed: " << PrintHr(hr);
534 token_Removed_.value = kInvalidTokenValue;
535 }
536 if (token_Stopped_.value != kInvalidTokenValue) {
537 hr = watcher_->remove_Stopped(token_Stopped_);
538 VLOG_IF(1, FAILED(hr)) << "remove_Stopped failed: " << PrintHr(hr);
539 token_Stopped_.value = kInvalidTokenValue;
540 }
541 if (token_Updated_.value != kInvalidTokenValue) {
542 hr = watcher_->remove_Updated(token_Updated_);
543 VLOG_IF(1, FAILED(hr)) << "remove_Updated failed: " << PrintHr(hr);
544 token_Updated_.value = kInvalidTokenValue;
545 }
546
547 if (is_initialized_) {
548 hr = watcher_->Stop();
549 VLOG_IF(1, FAILED(hr)) << "Stop failed: " << PrintHr(hr);
550 is_initialized_ = false;
551 }
552 }
553
554 MidiPort<InterfaceType>* GetPortByDeviceId(std::string dev_id) {
555 DCHECK(thread_checker_.CalledOnValidThread());
556 CHECK(is_initialized_);
557
558 auto it = ports_.find(dev_id);
559 if (it == ports_.end())
560 return nullptr;
561 return it->second.get();
562 }
563
564 MidiPort<InterfaceType>* GetPortByIndex(uint32_t port_index) {
565 DCHECK(thread_checker_.CalledOnValidThread());
566 CHECK(is_initialized_);
567
568 return GetPortByDeviceId(port_ids_[port_index]);
569 }
570
571 protected:
572 // Points to the MidiManagerWinrt instance, which is expected to outlive the
573 // MidiPortManager instance.
574 MidiManagerWinrt* midi_manager_;
575
576 // Task runner of the COM thread.
577 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
578
579 // Ensures all methods are called on the COM thread.
580 base::ThreadChecker thread_checker_;
581
582 private:
583 // DeviceWatcher callbacks:
584 void OnAdded(std::string dev_id, std::string dev_name) {
585 DCHECK(thread_checker_.CalledOnValidThread());
586 CHECK(is_initialized_);
587
shaochuane58f9c72016-08-30 22:27:08 -0700588 port_names_[dev_id] = dev_name;
589
shaochuan9ff63b82016-09-01 01:58:44 -0700590 ScopedHString dev_id_hstring(base::UTF8ToWide(dev_id).c_str());
591 if (!dev_id_hstring.is_valid())
shaochuane58f9c72016-08-30 22:27:08 -0700592 return;
shaochuane58f9c72016-08-30 22:27:08 -0700593
594 IAsyncOperation<RuntimeType*>* async_op;
595
shaochuan9ff63b82016-09-01 01:58:44 -0700596 HRESULT hr =
597 midi_port_statics_->FromIdAsync(dev_id_hstring.get(), &async_op);
shaochuane58f9c72016-08-30 22:27:08 -0700598 if (FAILED(hr)) {
599 VLOG(1) << "FromIdAsync failed: " << PrintHr(hr);
600 return;
601 }
602
603 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
604 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
605
606 hr = async_op->put_Completed(
607 WRL::Callback<IAsyncOperationCompletedHandler<RuntimeType*>>(
608 [weak_ptr, task_runner](IAsyncOperation<RuntimeType*>* async_op,
609 AsyncStatus status) {
shaochuane58f9c72016-08-30 22:27:08 -0700610 // A reference to |async_op| is kept in |async_ops_|, safe to pass
611 // outside.
612 task_runner->PostTask(
613 FROM_HERE,
614 base::Bind(&MidiPortManager::OnCompletedGetPortFromIdAsync,
shaochuanc2894522016-09-20 01:10:50 -0700615 weak_ptr, async_op));
shaochuane58f9c72016-08-30 22:27:08 -0700616
617 return S_OK;
618 })
619 .Get());
620 if (FAILED(hr)) {
621 VLOG(1) << "put_Completed failed: " << PrintHr(hr);
622 return;
623 }
624
625 // Keep a reference to incompleted |async_op| for releasing later.
626 async_ops_.insert(async_op);
627 }
628
629 void OnEnumerationCompleted() {
630 DCHECK(thread_checker_.CalledOnValidThread());
631 CHECK(is_initialized_);
632
633 if (async_ops_.empty())
634 midi_manager_->OnPortManagerReady();
635 else
636 enumeration_completed_not_ready_ = true;
637 }
638
639 void OnRemoved(std::string dev_id) {
640 DCHECK(thread_checker_.CalledOnValidThread());
641 CHECK(is_initialized_);
642
shaochuan110262b2016-08-31 02:15:16 -0700643 // Note: in case Microsoft GS Wavetable Synth triggers this event for some
644 // reason, it will be ignored here with log emitted.
shaochuane58f9c72016-08-30 22:27:08 -0700645 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
646 if (!port) {
647 VLOG(1) << "Removing non-existent port " << dev_id;
648 return;
649 }
650
651 SetPortState(port->index, MIDI_PORT_DISCONNECTED);
652
653 RemovePortEventHandlers(port);
654 port->handle = nullptr;
655 }
656
shaochuanc2894522016-09-20 01:10:50 -0700657 void OnCompletedGetPortFromIdAsync(IAsyncOperation<RuntimeType*>* async_op) {
shaochuane58f9c72016-08-30 22:27:08 -0700658 DCHECK(thread_checker_.CalledOnValidThread());
659 CHECK(is_initialized_);
660
shaochuanc2894522016-09-20 01:10:50 -0700661 InterfaceType* handle = nullptr;
662 HRESULT hr = async_op->GetResults(&handle);
663 if (FAILED(hr)) {
664 VLOG(1) << "GetResults failed: " << PrintHr(hr);
665 return;
666 }
667
668 // Manually release COM interface to completed |async_op|.
669 auto it = async_ops_.find(async_op);
670 CHECK(it != async_ops_.end());
671 (*it)->Release();
672 async_ops_.erase(it);
673
674 if (!handle) {
675 VLOG(1) << "Midi{In,Out}Port.FromIdAsync callback provides null pointer, "
676 "ignoring";
677 return;
678 }
679
shaochuane58f9c72016-08-30 22:27:08 -0700680 EventRegistrationToken token = {kInvalidTokenValue};
681 if (!RegisterOnMessageReceived(handle, &token))
682 return;
683
684 std::string dev_id = GetDeviceIdString(handle);
685
686 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
687
688 if (port == nullptr) {
shaochuan17bc4a02016-09-06 01:42:12 -0700689 std::string manufacturer = "Unknown", driver_version = "Unknown";
690 GetDriverInfoFromDeviceId(dev_id, &manufacturer, &driver_version);
691
692 AddPort(MidiPortInfo(dev_id, manufacturer, port_names_[dev_id],
693 driver_version, MIDI_PORT_OPENED));
shaochuane58f9c72016-08-30 22:27:08 -0700694
695 port = new MidiPort<InterfaceType>;
696 port->index = static_cast<uint32_t>(port_ids_.size());
697
698 ports_[dev_id].reset(port);
699 port_ids_.push_back(dev_id);
700 } else {
701 SetPortState(port->index, MIDI_PORT_CONNECTED);
702 }
703
704 port->handle = handle;
705 port->token_MessageReceived = token;
706
shaochuane58f9c72016-08-30 22:27:08 -0700707 if (enumeration_completed_not_ready_ && async_ops_.empty()) {
708 midi_manager_->OnPortManagerReady();
709 enumeration_completed_not_ready_ = false;
710 }
711 }
712
713 // Overrided by MidiInPortManager to listen to input ports.
714 virtual bool RegisterOnMessageReceived(InterfaceType* handle,
715 EventRegistrationToken* p_token) {
716 return true;
717 }
718
719 // Overrided by MidiInPortManager to remove MessageReceived event handler.
720 virtual void RemovePortEventHandlers(MidiPort<InterfaceType>* port) {}
721
722 // Calls midi_manager_->Add{Input,Output}Port.
723 virtual void AddPort(MidiPortInfo info) = 0;
724
725 // Calls midi_manager_->Set{Input,Output}PortState.
726 virtual void SetPortState(uint32_t port_index, MidiPortState state) = 0;
727
728 // WeakPtrFactory has to be declared in derived class, use this method to
729 // retrieve upcasted WeakPtr for posting tasks.
730 virtual base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() = 0;
731
732 // Midi{In,Out}PortStatics instance.
733 ScopedComPtr<StaticsInterfaceType> midi_port_statics_;
734
735 // DeviceWatcher instance and event registration tokens for unsubscribing
736 // events in destructor.
737 ScopedComPtr<IDeviceWatcher> watcher_;
738 EventRegistrationToken token_Added_ = {kInvalidTokenValue},
739 token_EnumerationCompleted_ = {kInvalidTokenValue},
740 token_Removed_ = {kInvalidTokenValue},
741 token_Stopped_ = {kInvalidTokenValue},
742 token_Updated_ = {kInvalidTokenValue};
743
744 // All manipulations to these fields should be done on COM thread.
745 std::unordered_map<std::string, std::unique_ptr<MidiPort<InterfaceType>>>
746 ports_;
747 std::vector<std::string> port_ids_;
748 std::unordered_map<std::string, std::string> port_names_;
749
750 // Keeps AsyncOperation references before the operation completes. Note that
751 // raw pointers are used here and the COM interfaces should be released
752 // manually.
753 std::unordered_set<IAsyncOperation<RuntimeType*>*> async_ops_;
754
755 // Set when device enumeration is completed but OnPortManagerReady() is not
756 // called since some ports are not yet ready (i.e. |async_ops_| is not empty).
757 // In such cases, OnPortManagerReady() will be called in
758 // OnCompletedGetPortFromIdAsync() when the last pending port is ready.
759 bool enumeration_completed_not_ready_ = false;
760
761 // Set if the instance is initialized without error. Should be checked in all
762 // methods on COM thread except StartWatcher().
763 bool is_initialized_ = false;
764};
765
766class MidiManagerWinrt::MidiInPortManager final
767 : public MidiPortManager<IMidiInPort,
768 MidiInPort,
769 IMidiInPortStatics,
770 RuntimeClass_Windows_Devices_Midi_MidiInPort> {
771 public:
772 MidiInPortManager(MidiManagerWinrt* midi_manager)
773 : MidiPortManager(midi_manager), weak_factory_(this) {}
774
775 private:
776 // MidiPortManager overrides:
777 bool RegisterOnMessageReceived(IMidiInPort* handle,
778 EventRegistrationToken* p_token) override {
779 DCHECK(thread_checker_.CalledOnValidThread());
780
781 base::WeakPtr<MidiInPortManager> weak_ptr = weak_factory_.GetWeakPtr();
782 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
783
784 HRESULT hr = handle->add_MessageReceived(
785 WRL::Callback<
786 ITypedEventHandler<MidiInPort*, MidiMessageReceivedEventArgs*>>(
787 [weak_ptr, task_runner](IMidiInPort* handle,
788 IMidiMessageReceivedEventArgs* args) {
789 const base::TimeTicks now = base::TimeTicks::Now();
790
791 std::string dev_id = GetDeviceIdString(handle);
792
793 ScopedComPtr<IMidiMessage> message;
794 HRESULT hr = args->get_Message(message.Receive());
795 if (FAILED(hr)) {
796 VLOG(1) << "get_Message failed: " << PrintHr(hr);
797 return hr;
798 }
799
800 ScopedComPtr<IBuffer> buffer;
801 hr = message->get_RawData(buffer.Receive());
802 if (FAILED(hr)) {
803 VLOG(1) << "get_RawData failed: " << PrintHr(hr);
804 return hr;
805 }
806
807 uint8_t* p_buffer_data = nullptr;
808 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
809 if (FAILED(hr))
810 return hr;
811
812 uint32_t data_length = 0;
813 hr = buffer->get_Length(&data_length);
814 if (FAILED(hr)) {
815 VLOG(1) << "get_Length failed: " << PrintHr(hr);
816 return hr;
817 }
818
819 std::vector<uint8_t> data(p_buffer_data,
820 p_buffer_data + data_length);
821
822 task_runner->PostTask(
823 FROM_HERE, base::Bind(&MidiInPortManager::OnMessageReceived,
824 weak_ptr, dev_id, data, now));
825
826 return S_OK;
827 })
828 .Get(),
829 p_token);
830 if (FAILED(hr)) {
831 VLOG(1) << "add_MessageReceived failed: " << PrintHr(hr);
832 return false;
833 }
834
835 return true;
836 }
837
838 void RemovePortEventHandlers(MidiPort<IMidiInPort>* port) override {
839 if (!(port->handle &&
840 port->token_MessageReceived.value != kInvalidTokenValue))
841 return;
842
843 HRESULT hr =
844 port->handle->remove_MessageReceived(port->token_MessageReceived);
845 VLOG_IF(1, FAILED(hr)) << "remove_MessageReceived failed: " << PrintHr(hr);
846 port->token_MessageReceived.value = kInvalidTokenValue;
847 }
848
849 void AddPort(MidiPortInfo info) final { midi_manager_->AddInputPort(info); }
850
851 void SetPortState(uint32_t port_index, MidiPortState state) final {
852 midi_manager_->SetInputPortState(port_index, state);
853 }
854
855 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
856 DCHECK(thread_checker_.CalledOnValidThread());
857
858 return weak_factory_.GetWeakPtr();
859 }
860
861 // Callback on receiving MIDI input message.
862 void OnMessageReceived(std::string dev_id,
863 std::vector<uint8_t> data,
864 base::TimeTicks time) {
865 DCHECK(thread_checker_.CalledOnValidThread());
866
867 MidiPort<IMidiInPort>* port = GetPortByDeviceId(dev_id);
868 CHECK(port);
869
870 midi_manager_->ReceiveMidiData(port->index, &data[0], data.size(), time);
871 }
872
873 // Last member to ensure destructed first.
874 base::WeakPtrFactory<MidiInPortManager> weak_factory_;
875
876 DISALLOW_COPY_AND_ASSIGN(MidiInPortManager);
877};
878
879class MidiManagerWinrt::MidiOutPortManager final
880 : public MidiPortManager<IMidiOutPort,
881 IMidiOutPort,
882 IMidiOutPortStatics,
883 RuntimeClass_Windows_Devices_Midi_MidiOutPort> {
884 public:
885 MidiOutPortManager(MidiManagerWinrt* midi_manager)
886 : MidiPortManager(midi_manager), weak_factory_(this) {}
887
888 private:
889 // MidiPortManager overrides:
890 void AddPort(MidiPortInfo info) final { midi_manager_->AddOutputPort(info); }
891
892 void SetPortState(uint32_t port_index, MidiPortState state) final {
893 midi_manager_->SetOutputPortState(port_index, state);
894 }
895
896 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
897 DCHECK(thread_checker_.CalledOnValidThread());
898
899 return weak_factory_.GetWeakPtr();
900 }
901
902 // Last member to ensure destructed first.
903 base::WeakPtrFactory<MidiOutPortManager> weak_factory_;
904
905 DISALLOW_COPY_AND_ASSIGN(MidiOutPortManager);
906};
907
908MidiManagerWinrt::MidiManagerWinrt() : com_thread_("Windows MIDI COM Thread") {}
909
910MidiManagerWinrt::~MidiManagerWinrt() {
911 base::AutoLock auto_lock(lazy_init_member_lock_);
912
913 CHECK(!com_thread_checker_);
914 CHECK(!port_manager_in_);
915 CHECK(!port_manager_out_);
916 CHECK(!scheduler_);
917}
918
919void MidiManagerWinrt::StartInitialization() {
shaochuane58f9c72016-08-30 22:27:08 -0700920 com_thread_.init_com_with_mta(true);
921 com_thread_.Start();
922
923 com_thread_.task_runner()->PostTask(
924 FROM_HERE, base::Bind(&MidiManagerWinrt::InitializeOnComThread,
925 base::Unretained(this)));
926}
927
928void MidiManagerWinrt::Finalize() {
929 com_thread_.task_runner()->PostTask(
930 FROM_HERE, base::Bind(&MidiManagerWinrt::FinalizeOnComThread,
931 base::Unretained(this)));
932
933 // Blocks until FinalizeOnComThread() returns. Delayed MIDI send data tasks
934 // will be ignored.
935 com_thread_.Stop();
936}
937
938void MidiManagerWinrt::DispatchSendMidiData(MidiManagerClient* client,
939 uint32_t port_index,
940 const std::vector<uint8_t>& data,
941 double timestamp) {
942 CHECK(scheduler_);
943
944 scheduler_->PostSendDataTask(
945 client, data.size(), timestamp,
946 base::Bind(&MidiManagerWinrt::SendOnComThread, base::Unretained(this),
947 port_index, data));
948}
949
950void MidiManagerWinrt::InitializeOnComThread() {
951 base::AutoLock auto_lock(lazy_init_member_lock_);
952
953 com_thread_checker_.reset(new base::ThreadChecker);
954
shaochuan9ff63b82016-09-01 01:58:44 -0700955 if (!g_combase_functions.Get().LoadFunctions()) {
956 VLOG(1) << "Failed loading functions from combase.dll: "
957 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
958 CompleteInitialization(Result::INITIALIZATION_ERROR);
959 return;
960 }
961
shaochuane58f9c72016-08-30 22:27:08 -0700962 port_manager_in_.reset(new MidiInPortManager(this));
963 port_manager_out_.reset(new MidiOutPortManager(this));
964
965 scheduler_.reset(new MidiScheduler(this));
966
967 if (!(port_manager_in_->StartWatcher() &&
968 port_manager_out_->StartWatcher())) {
969 port_manager_in_->StopWatcher();
970 port_manager_out_->StopWatcher();
971 CompleteInitialization(Result::INITIALIZATION_ERROR);
972 }
973}
974
975void MidiManagerWinrt::FinalizeOnComThread() {
976 base::AutoLock auto_lock(lazy_init_member_lock_);
977
978 DCHECK(com_thread_checker_->CalledOnValidThread());
979
980 scheduler_.reset();
981
shaochuan9ff63b82016-09-01 01:58:44 -0700982 if (port_manager_in_) {
983 port_manager_in_->StopWatcher();
984 port_manager_in_.reset();
985 }
986
987 if (port_manager_out_) {
988 port_manager_out_->StopWatcher();
989 port_manager_out_.reset();
990 }
shaochuane58f9c72016-08-30 22:27:08 -0700991
992 com_thread_checker_.reset();
993}
994
995void MidiManagerWinrt::SendOnComThread(uint32_t port_index,
996 const std::vector<uint8_t>& data) {
997 DCHECK(com_thread_checker_->CalledOnValidThread());
998
999 MidiPort<IMidiOutPort>* port = port_manager_out_->GetPortByIndex(port_index);
1000 if (!(port && port->handle)) {
1001 VLOG(1) << "Port not available: " << port_index;
1002 return;
1003 }
1004
1005 auto buffer_factory =
1006 WrlStaticsFactory<IBufferFactory,
1007 RuntimeClass_Windows_Storage_Streams_Buffer>();
1008 if (!buffer_factory)
1009 return;
1010
1011 ScopedComPtr<IBuffer> buffer;
1012 HRESULT hr = buffer_factory->Create(static_cast<UINT32>(data.size()),
1013 buffer.Receive());
1014 if (FAILED(hr)) {
1015 VLOG(1) << "Create failed: " << PrintHr(hr);
1016 return;
1017 }
1018
1019 hr = buffer->put_Length(static_cast<UINT32>(data.size()));
1020 if (FAILED(hr)) {
1021 VLOG(1) << "put_Length failed: " << PrintHr(hr);
1022 return;
1023 }
1024
1025 uint8_t* p_buffer_data = nullptr;
1026 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
1027 if (FAILED(hr))
1028 return;
1029
1030 std::copy(data.begin(), data.end(), p_buffer_data);
1031
1032 hr = port->handle->SendBuffer(buffer.get());
1033 if (FAILED(hr)) {
1034 VLOG(1) << "SendBuffer failed: " << PrintHr(hr);
1035 return;
1036 }
1037}
1038
1039void MidiManagerWinrt::OnPortManagerReady() {
1040 DCHECK(com_thread_checker_->CalledOnValidThread());
1041 DCHECK(port_manager_ready_count_ < 2);
1042
1043 if (++port_manager_ready_count_ == 2)
1044 CompleteInitialization(Result::OK);
1045}
1046
shaochuane58f9c72016-08-30 22:27:08 -07001047} // namespace midi