blob: 5a0dedfe07438488891661c72bbfd39ddb8195d0 [file] [log] [blame]
shaochuane58f9c72016-08-30 22:27:08 -07001// Copyright 2016 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "media/midi/midi_manager_winrt.h"
6
qiankun.miao53f2d662016-09-02 17:44:08 -07007#pragma warning(disable : 4467)
shaochuan80f1fba2016-09-01 20:44:51 -07008
shaochuan4eff30e2016-09-09 01:24:14 -07009#include <initguid.h> // Required by <devpkey.h>
10
11#include <cfgmgr32.h>
shaochuane58f9c72016-08-30 22:27:08 -070012#include <comdef.h>
shaochuan4eff30e2016-09-09 01:24:14 -070013#include <devpkey.h>
robliao048b1572017-04-21 12:46:39 -070014#include <objbase.h>
shaochuane58f9c72016-08-30 22:27:08 -070015#include <robuffer.h>
16#include <windows.devices.enumeration.h>
17#include <windows.devices.midi.h>
18#include <wrl/event.h>
19
20#include <iomanip>
21#include <unordered_map>
22#include <unordered_set>
23
24#include "base/bind.h"
shaochuan9ff63b82016-09-01 01:58:44 -070025#include "base/scoped_generic.h"
shaochuan17bc4a02016-09-06 01:42:12 -070026#include "base/strings/string_util.h"
shaochuane58f9c72016-08-30 22:27:08 -070027#include "base/strings/utf_string_conversions.h"
28#include "base/threading/thread_checker.h"
29#include "base/threading/thread_task_runner_handle.h"
30#include "base/timer/timer.h"
31#include "base/win/scoped_comptr.h"
shaochuane58f9c72016-08-30 22:27:08 -070032#include "media/midi/midi_scheduler.h"
33
shaochuane58f9c72016-08-30 22:27:08 -070034namespace midi {
35namespace {
36
37namespace WRL = Microsoft::WRL;
38
39using namespace ABI::Windows::Devices::Enumeration;
40using namespace ABI::Windows::Devices::Midi;
41using namespace ABI::Windows::Foundation;
42using namespace ABI::Windows::Storage::Streams;
43
44using base::win::ScopedComPtr;
toyoshimec2570a2016-10-21 02:15:27 -070045using mojom::PortState;
toyoshim2f3a48f2016-10-17 01:54:13 -070046using mojom::Result;
shaochuane58f9c72016-08-30 22:27:08 -070047
48// Helpers for printing HRESULTs.
49struct PrintHr {
50 PrintHr(HRESULT hr) : hr(hr) {}
51 HRESULT hr;
52};
53
54std::ostream& operator<<(std::ostream& os, const PrintHr& phr) {
55 std::ios_base::fmtflags ff = os.flags();
56 os << _com_error(phr.hr).ErrorMessage() << " (0x" << std::hex
57 << std::uppercase << std::setfill('0') << std::setw(8) << phr.hr << ")";
58 os.flags(ff);
59 return os;
60}
61
shaochuan9ff63b82016-09-01 01:58:44 -070062// Provides access to functions in combase.dll which may not be available on
63// Windows 7. Loads functions dynamically at runtime to prevent library
dalecurtis3f5ce942017-02-10 18:08:18 -080064// dependencies.
shaochuan9ff63b82016-09-01 01:58:44 -070065class CombaseFunctions {
66 public:
67 CombaseFunctions() = default;
68
69 ~CombaseFunctions() {
70 if (combase_dll_)
71 ::FreeLibrary(combase_dll_);
72 }
73
74 bool LoadFunctions() {
75 combase_dll_ = ::LoadLibrary(L"combase.dll");
76 if (!combase_dll_)
77 return false;
78
79 get_factory_func_ = reinterpret_cast<decltype(&::RoGetActivationFactory)>(
80 ::GetProcAddress(combase_dll_, "RoGetActivationFactory"));
81 if (!get_factory_func_)
82 return false;
83
84 create_string_func_ = reinterpret_cast<decltype(&::WindowsCreateString)>(
85 ::GetProcAddress(combase_dll_, "WindowsCreateString"));
86 if (!create_string_func_)
87 return false;
88
89 delete_string_func_ = reinterpret_cast<decltype(&::WindowsDeleteString)>(
90 ::GetProcAddress(combase_dll_, "WindowsDeleteString"));
91 if (!delete_string_func_)
92 return false;
93
94 get_string_raw_buffer_func_ =
95 reinterpret_cast<decltype(&::WindowsGetStringRawBuffer)>(
96 ::GetProcAddress(combase_dll_, "WindowsGetStringRawBuffer"));
97 if (!get_string_raw_buffer_func_)
98 return false;
99
100 return true;
101 }
102
103 HRESULT RoGetActivationFactory(HSTRING class_id,
104 const IID& iid,
105 void** out_factory) {
106 DCHECK(get_factory_func_);
107 return get_factory_func_(class_id, iid, out_factory);
108 }
109
110 HRESULT WindowsCreateString(const base::char16* src,
111 uint32_t len,
112 HSTRING* out_hstr) {
113 DCHECK(create_string_func_);
114 return create_string_func_(src, len, out_hstr);
115 }
116
117 HRESULT WindowsDeleteString(HSTRING hstr) {
118 DCHECK(delete_string_func_);
119 return delete_string_func_(hstr);
120 }
121
122 const base::char16* WindowsGetStringRawBuffer(HSTRING hstr,
123 uint32_t* out_len) {
124 DCHECK(get_string_raw_buffer_func_);
125 return get_string_raw_buffer_func_(hstr, out_len);
126 }
127
128 private:
129 HMODULE combase_dll_ = nullptr;
130
131 decltype(&::RoGetActivationFactory) get_factory_func_ = nullptr;
132 decltype(&::WindowsCreateString) create_string_func_ = nullptr;
133 decltype(&::WindowsDeleteString) delete_string_func_ = nullptr;
134 decltype(&::WindowsGetStringRawBuffer) get_string_raw_buffer_func_ = nullptr;
135};
136
dalecurtis3f5ce942017-02-10 18:08:18 -0800137CombaseFunctions* GetCombaseFunctions() {
138 static CombaseFunctions* functions = new CombaseFunctions();
139 return functions;
140}
shaochuan9ff63b82016-09-01 01:58:44 -0700141
142// Scoped HSTRING class to maintain lifetime of HSTRINGs allocated with
143// WindowsCreateString().
144class ScopedHStringTraits {
145 public:
146 static HSTRING InvalidValue() { return nullptr; }
147
148 static void Free(HSTRING hstr) {
dalecurtis3f5ce942017-02-10 18:08:18 -0800149 GetCombaseFunctions()->WindowsDeleteString(hstr);
shaochuan9ff63b82016-09-01 01:58:44 -0700150 }
151};
152
153class ScopedHString : public base::ScopedGeneric<HSTRING, ScopedHStringTraits> {
154 public:
155 explicit ScopedHString(const base::char16* str) : ScopedGeneric(nullptr) {
156 HSTRING hstr;
dalecurtis3f5ce942017-02-10 18:08:18 -0800157 HRESULT hr = GetCombaseFunctions()->WindowsCreateString(
shaochuan9ff63b82016-09-01 01:58:44 -0700158 str, static_cast<uint32_t>(wcslen(str)), &hstr);
159 if (FAILED(hr))
160 VLOG(1) << "WindowsCreateString failed: " << PrintHr(hr);
161 else
162 reset(hstr);
163 }
164};
165
shaochuane58f9c72016-08-30 22:27:08 -0700166// Factory functions that activate and create WinRT components. The caller takes
167// ownership of the returning ComPtr.
168template <typename InterfaceType, base::char16 const* runtime_class_id>
169ScopedComPtr<InterfaceType> WrlStaticsFactory() {
170 ScopedComPtr<InterfaceType> com_ptr;
171
shaochuan9ff63b82016-09-01 01:58:44 -0700172 ScopedHString class_id_hstring(runtime_class_id);
173 if (!class_id_hstring.is_valid()) {
174 com_ptr = nullptr;
175 return com_ptr;
176 }
177
dalecurtis3f5ce942017-02-10 18:08:18 -0800178 HRESULT hr = GetCombaseFunctions()->RoGetActivationFactory(
robliao048b1572017-04-21 12:46:39 -0700179 class_id_hstring.get(), IID_PPV_ARGS(&com_ptr));
shaochuane58f9c72016-08-30 22:27:08 -0700180 if (FAILED(hr)) {
shaochuan9ff63b82016-09-01 01:58:44 -0700181 VLOG(1) << "RoGetActivationFactory failed: " << PrintHr(hr);
shaochuane58f9c72016-08-30 22:27:08 -0700182 com_ptr = nullptr;
183 }
184
185 return com_ptr;
186}
187
shaochuan80f1fba2016-09-01 20:44:51 -0700188std::string HStringToString(HSTRING hstr) {
shaochuane58f9c72016-08-30 22:27:08 -0700189 // Note: empty HSTRINGs are represent as nullptr, and instantiating
190 // std::string with nullptr (in base::WideToUTF8) is undefined behavior.
shaochuan9ff63b82016-09-01 01:58:44 -0700191 const base::char16* buffer =
dalecurtis3f5ce942017-02-10 18:08:18 -0800192 GetCombaseFunctions()->WindowsGetStringRawBuffer(hstr, nullptr);
shaochuane58f9c72016-08-30 22:27:08 -0700193 if (buffer)
194 return base::WideToUTF8(buffer);
195 return std::string();
196}
197
198template <typename T>
199std::string GetIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700200 HSTRING result;
201 HRESULT hr = obj->get_Id(&result);
202 if (FAILED(hr)) {
203 VLOG(1) << "get_Id failed: " << PrintHr(hr);
204 return std::string();
205 }
206 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700207}
208
209template <typename T>
210std::string GetDeviceIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700211 HSTRING result;
212 HRESULT hr = obj->get_DeviceId(&result);
213 if (FAILED(hr)) {
214 VLOG(1) << "get_DeviceId failed: " << PrintHr(hr);
215 return std::string();
216 }
217 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700218}
219
220std::string GetNameString(IDeviceInformation* info) {
shaochuan80f1fba2016-09-01 20:44:51 -0700221 HSTRING result;
222 HRESULT hr = info->get_Name(&result);
223 if (FAILED(hr)) {
224 VLOG(1) << "get_Name failed: " << PrintHr(hr);
225 return std::string();
226 }
227 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700228}
229
230HRESULT GetPointerToBufferData(IBuffer* buffer, uint8_t** out) {
231 ScopedComPtr<Windows::Storage::Streams::IBufferByteAccess> buffer_byte_access;
232
233 HRESULT hr = buffer_byte_access.QueryFrom(buffer);
234 if (FAILED(hr)) {
235 VLOG(1) << "QueryInterface failed: " << PrintHr(hr);
236 return hr;
237 }
238
239 // Lifetime of the pointing buffer is controlled by the buffer object.
240 hr = buffer_byte_access->Buffer(out);
241 if (FAILED(hr)) {
242 VLOG(1) << "Buffer failed: " << PrintHr(hr);
243 return hr;
244 }
245
246 return S_OK;
247}
248
shaochuan110262b2016-08-31 02:15:16 -0700249// Checks if given DeviceInformation represent a Microsoft GS Wavetable Synth
250// instance.
251bool IsMicrosoftSynthesizer(IDeviceInformation* info) {
252 auto midi_synthesizer_statics =
253 WrlStaticsFactory<IMidiSynthesizerStatics,
254 RuntimeClass_Windows_Devices_Midi_MidiSynthesizer>();
255 boolean result = FALSE;
256 HRESULT hr = midi_synthesizer_statics->IsSynthesizer(info, &result);
257 VLOG_IF(1, FAILED(hr)) << "IsSynthesizer failed: " << PrintHr(hr);
258 return result != FALSE;
259}
260
shaochuan4eff30e2016-09-09 01:24:14 -0700261void GetDevPropString(DEVINST handle,
262 const DEVPROPKEY* devprop_key,
263 std::string* out) {
264 DEVPROPTYPE devprop_type;
265 unsigned long buffer_size = 0;
shaochuan17bc4a02016-09-06 01:42:12 -0700266
shaochuan4eff30e2016-09-09 01:24:14 -0700267 // Retrieve |buffer_size| and allocate buffer later for receiving data.
268 CONFIGRET cr = CM_Get_DevNode_Property(handle, devprop_key, &devprop_type,
269 nullptr, &buffer_size, 0);
270 if (cr != CR_BUFFER_SMALL) {
271 // Here we print error codes in hex instead of using PrintHr() with
272 // HRESULT_FROM_WIN32() and CM_MapCrToWin32Err(), since only a minor set of
273 // CONFIGRET values are mapped to Win32 errors. Same for following VLOG()s.
274 VLOG(1) << "CM_Get_DevNode_Property failed: CONFIGRET 0x" << std::hex << cr;
275 return;
shaochuan17bc4a02016-09-06 01:42:12 -0700276 }
shaochuan4eff30e2016-09-09 01:24:14 -0700277 if (devprop_type != DEVPROP_TYPE_STRING) {
278 VLOG(1) << "CM_Get_DevNode_Property returns wrong data type, "
279 << "expected DEVPROP_TYPE_STRING";
280 return;
281 }
shaochuan17bc4a02016-09-06 01:42:12 -0700282
shaochuan4eff30e2016-09-09 01:24:14 -0700283 std::unique_ptr<uint8_t[]> buffer(new uint8_t[buffer_size]);
284
285 // Receive property data.
286 cr = CM_Get_DevNode_Property(handle, devprop_key, &devprop_type, buffer.get(),
287 &buffer_size, 0);
288 if (cr != CR_SUCCESS)
289 VLOG(1) << "CM_Get_DevNode_Property failed: CONFIGRET 0x" << std::hex << cr;
290 else
291 *out = base::WideToUTF8(reinterpret_cast<base::char16*>(buffer.get()));
292}
shaochuan17bc4a02016-09-06 01:42:12 -0700293
294// Retrieves manufacturer (provider) and version information of underlying
shaochuan4eff30e2016-09-09 01:24:14 -0700295// device driver through PnP Configuration Manager, given device (interface) ID
296// provided by WinRT. |out_manufacturer| and |out_driver_version| won't be
297// modified if retrieval fails.
shaochuan17bc4a02016-09-06 01:42:12 -0700298//
299// Device instance ID is extracted from device (interface) ID provided by WinRT
300// APIs, for example from the following interface ID:
301// \\?\SWD#MMDEVAPI#MIDII_60F39FCA.P_0002#{504be32c-ccf6-4d2c-b73f-6f8b3747e22b}
302// we extract the device instance ID: SWD\MMDEVAPI\MIDII_60F39FCA.P_0002
shaochuan4eff30e2016-09-09 01:24:14 -0700303//
304// However the extracted device instance ID represent a "software device"
305// provided by Microsoft, which is an interface on top of the hardware for each
306// input/output port. Therefore we further locate its parent device, which is
307// the actual hardware device, for driver information.
shaochuan17bc4a02016-09-06 01:42:12 -0700308void GetDriverInfoFromDeviceId(const std::string& dev_id,
309 std::string* out_manufacturer,
310 std::string* out_driver_version) {
311 base::string16 dev_instance_id =
312 base::UTF8ToWide(dev_id.substr(4, dev_id.size() - 43));
313 base::ReplaceChars(dev_instance_id, L"#", L"\\", &dev_instance_id);
314
shaochuan4eff30e2016-09-09 01:24:14 -0700315 DEVINST dev_instance_handle;
316 CONFIGRET cr = CM_Locate_DevNode(&dev_instance_handle, &dev_instance_id[0],
317 CM_LOCATE_DEVNODE_NORMAL);
318 if (cr != CR_SUCCESS) {
319 VLOG(1) << "CM_Locate_DevNode failed: CONFIGRET 0x" << std::hex << cr;
shaochuan17bc4a02016-09-06 01:42:12 -0700320 return;
321 }
322
shaochuan4eff30e2016-09-09 01:24:14 -0700323 DEVINST parent_handle;
324 cr = CM_Get_Parent(&parent_handle, dev_instance_handle, 0);
325 if (cr != CR_SUCCESS) {
326 VLOG(1) << "CM_Get_Parent failed: CONFIGRET 0x" << std::hex << cr;
shaochuan17bc4a02016-09-06 01:42:12 -0700327 return;
328 }
329
shaochuan4eff30e2016-09-09 01:24:14 -0700330 GetDevPropString(parent_handle, &DEVPKEY_Device_DriverProvider,
331 out_manufacturer);
332 GetDevPropString(parent_handle, &DEVPKEY_Device_DriverVersion,
333 out_driver_version);
shaochuan17bc4a02016-09-06 01:42:12 -0700334}
335
// Tokens with value = 0 are considered invalid (as in <wrl/event.h>).
constexpr int64_t kInvalidTokenValue = 0;
338
339template <typename InterfaceType>
340struct MidiPort {
341 MidiPort() = default;
342
343 uint32_t index;
344 ScopedComPtr<InterfaceType> handle;
345 EventRegistrationToken token_MessageReceived;
346
347 private:
348 DISALLOW_COPY_AND_ASSIGN(MidiPort);
349};
350
351} // namespace
352
353template <typename InterfaceType,
354 typename RuntimeType,
355 typename StaticsInterfaceType,
356 base::char16 const* runtime_class_id>
357class MidiManagerWinrt::MidiPortManager {
358 public:
359 // MidiPortManager instances should be constructed on the COM thread.
360 MidiPortManager(MidiManagerWinrt* midi_manager)
361 : midi_manager_(midi_manager),
362 task_runner_(base::ThreadTaskRunnerHandle::Get()) {}
363
364 virtual ~MidiPortManager() { DCHECK(thread_checker_.CalledOnValidThread()); }
365
366 bool StartWatcher() {
367 DCHECK(thread_checker_.CalledOnValidThread());
368
369 HRESULT hr;
370
371 midi_port_statics_ =
372 WrlStaticsFactory<StaticsInterfaceType, runtime_class_id>();
373 if (!midi_port_statics_)
374 return false;
375
376 HSTRING device_selector = nullptr;
377 hr = midi_port_statics_->GetDeviceSelector(&device_selector);
378 if (FAILED(hr)) {
379 VLOG(1) << "GetDeviceSelector failed: " << PrintHr(hr);
380 return false;
381 }
382
383 auto dev_info_statics = WrlStaticsFactory<
384 IDeviceInformationStatics,
385 RuntimeClass_Windows_Devices_Enumeration_DeviceInformation>();
386 if (!dev_info_statics)
387 return false;
388
389 hr = dev_info_statics->CreateWatcherAqsFilter(device_selector,
390 watcher_.Receive());
391 if (FAILED(hr)) {
392 VLOG(1) << "CreateWatcherAqsFilter failed: " << PrintHr(hr);
393 return false;
394 }
395
396 // Register callbacks to WinRT that post state-modifying jobs back to COM
397 // thread. |weak_ptr| and |task_runner| are captured by lambda callbacks for
398 // posting jobs. Note that WinRT callback arguments should not be passed
399 // outside the callback since the pointers may be unavailable afterwards.
400 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
401 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
402
403 hr = watcher_->add_Added(
404 WRL::Callback<ITypedEventHandler<DeviceWatcher*, DeviceInformation*>>(
405 [weak_ptr, task_runner](IDeviceWatcher* watcher,
406 IDeviceInformation* info) {
shaochuanc2894522016-09-20 01:10:50 -0700407 if (!info) {
408 VLOG(1) << "DeviceWatcher.Added callback provides null "
409 "pointer, ignoring";
410 return S_OK;
411 }
412
shaochuan110262b2016-08-31 02:15:16 -0700413 // Disable Microsoft GS Wavetable Synth due to security reasons.
414 // http://crbug.com/499279
415 if (IsMicrosoftSynthesizer(info))
416 return S_OK;
417
shaochuane58f9c72016-08-30 22:27:08 -0700418 std::string dev_id = GetIdString(info),
419 dev_name = GetNameString(info);
420
421 task_runner->PostTask(
422 FROM_HERE, base::Bind(&MidiPortManager::OnAdded, weak_ptr,
423 dev_id, dev_name));
424
425 return S_OK;
426 })
427 .Get(),
428 &token_Added_);
429 if (FAILED(hr)) {
430 VLOG(1) << "add_Added failed: " << PrintHr(hr);
431 return false;
432 }
433
434 hr = watcher_->add_EnumerationCompleted(
435 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
436 [weak_ptr, task_runner](IDeviceWatcher* watcher,
437 IInspectable* insp) {
438 task_runner->PostTask(
439 FROM_HERE,
440 base::Bind(&MidiPortManager::OnEnumerationCompleted,
441 weak_ptr));
442
443 return S_OK;
444 })
445 .Get(),
446 &token_EnumerationCompleted_);
447 if (FAILED(hr)) {
448 VLOG(1) << "add_EnumerationCompleted failed: " << PrintHr(hr);
449 return false;
450 }
451
452 hr = watcher_->add_Removed(
453 WRL::Callback<
454 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
455 [weak_ptr, task_runner](IDeviceWatcher* watcher,
456 IDeviceInformationUpdate* update) {
shaochuanc2894522016-09-20 01:10:50 -0700457 if (!update) {
458 VLOG(1) << "DeviceWatcher.Removed callback provides null "
459 "pointer, ignoring";
460 return S_OK;
461 }
462
shaochuane58f9c72016-08-30 22:27:08 -0700463 std::string dev_id = GetIdString(update);
464
465 task_runner->PostTask(
466 FROM_HERE,
467 base::Bind(&MidiPortManager::OnRemoved, weak_ptr, dev_id));
468
469 return S_OK;
470 })
471 .Get(),
472 &token_Removed_);
473 if (FAILED(hr)) {
474 VLOG(1) << "add_Removed failed: " << PrintHr(hr);
475 return false;
476 }
477
478 hr = watcher_->add_Stopped(
479 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
480 [](IDeviceWatcher* watcher, IInspectable* insp) {
481 // Placeholder, does nothing for now.
482 return S_OK;
483 })
484 .Get(),
485 &token_Stopped_);
486 if (FAILED(hr)) {
487 VLOG(1) << "add_Stopped failed: " << PrintHr(hr);
488 return false;
489 }
490
491 hr = watcher_->add_Updated(
492 WRL::Callback<
493 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
494 [](IDeviceWatcher* watcher, IDeviceInformationUpdate* update) {
495 // TODO(shaochuan): Check for fields to be updated here.
496 return S_OK;
497 })
498 .Get(),
499 &token_Updated_);
500 if (FAILED(hr)) {
501 VLOG(1) << "add_Updated failed: " << PrintHr(hr);
502 return false;
503 }
504
505 hr = watcher_->Start();
506 if (FAILED(hr)) {
507 VLOG(1) << "Start failed: " << PrintHr(hr);
508 return false;
509 }
510
511 is_initialized_ = true;
512 return true;
513 }
514
515 void StopWatcher() {
516 DCHECK(thread_checker_.CalledOnValidThread());
517
518 HRESULT hr;
519
520 for (const auto& entry : ports_)
521 RemovePortEventHandlers(entry.second.get());
522
523 if (token_Added_.value != kInvalidTokenValue) {
524 hr = watcher_->remove_Added(token_Added_);
525 VLOG_IF(1, FAILED(hr)) << "remove_Added failed: " << PrintHr(hr);
526 token_Added_.value = kInvalidTokenValue;
527 }
528 if (token_EnumerationCompleted_.value != kInvalidTokenValue) {
529 hr = watcher_->remove_EnumerationCompleted(token_EnumerationCompleted_);
530 VLOG_IF(1, FAILED(hr)) << "remove_EnumerationCompleted failed: "
531 << PrintHr(hr);
532 token_EnumerationCompleted_.value = kInvalidTokenValue;
533 }
534 if (token_Removed_.value != kInvalidTokenValue) {
535 hr = watcher_->remove_Removed(token_Removed_);
536 VLOG_IF(1, FAILED(hr)) << "remove_Removed failed: " << PrintHr(hr);
537 token_Removed_.value = kInvalidTokenValue;
538 }
539 if (token_Stopped_.value != kInvalidTokenValue) {
540 hr = watcher_->remove_Stopped(token_Stopped_);
541 VLOG_IF(1, FAILED(hr)) << "remove_Stopped failed: " << PrintHr(hr);
542 token_Stopped_.value = kInvalidTokenValue;
543 }
544 if (token_Updated_.value != kInvalidTokenValue) {
545 hr = watcher_->remove_Updated(token_Updated_);
546 VLOG_IF(1, FAILED(hr)) << "remove_Updated failed: " << PrintHr(hr);
547 token_Updated_.value = kInvalidTokenValue;
548 }
549
550 if (is_initialized_) {
551 hr = watcher_->Stop();
552 VLOG_IF(1, FAILED(hr)) << "Stop failed: " << PrintHr(hr);
553 is_initialized_ = false;
554 }
555 }
556
557 MidiPort<InterfaceType>* GetPortByDeviceId(std::string dev_id) {
558 DCHECK(thread_checker_.CalledOnValidThread());
559 CHECK(is_initialized_);
560
561 auto it = ports_.find(dev_id);
562 if (it == ports_.end())
563 return nullptr;
564 return it->second.get();
565 }
566
567 MidiPort<InterfaceType>* GetPortByIndex(uint32_t port_index) {
568 DCHECK(thread_checker_.CalledOnValidThread());
569 CHECK(is_initialized_);
570
571 return GetPortByDeviceId(port_ids_[port_index]);
572 }
573
574 protected:
575 // Points to the MidiManagerWinrt instance, which is expected to outlive the
576 // MidiPortManager instance.
577 MidiManagerWinrt* midi_manager_;
578
579 // Task runner of the COM thread.
580 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
581
582 // Ensures all methods are called on the COM thread.
583 base::ThreadChecker thread_checker_;
584
585 private:
586 // DeviceWatcher callbacks:
587 void OnAdded(std::string dev_id, std::string dev_name) {
588 DCHECK(thread_checker_.CalledOnValidThread());
589 CHECK(is_initialized_);
590
shaochuane58f9c72016-08-30 22:27:08 -0700591 port_names_[dev_id] = dev_name;
592
shaochuan9ff63b82016-09-01 01:58:44 -0700593 ScopedHString dev_id_hstring(base::UTF8ToWide(dev_id).c_str());
594 if (!dev_id_hstring.is_valid())
shaochuane58f9c72016-08-30 22:27:08 -0700595 return;
shaochuane58f9c72016-08-30 22:27:08 -0700596
597 IAsyncOperation<RuntimeType*>* async_op;
598
shaochuan9ff63b82016-09-01 01:58:44 -0700599 HRESULT hr =
600 midi_port_statics_->FromIdAsync(dev_id_hstring.get(), &async_op);
shaochuane58f9c72016-08-30 22:27:08 -0700601 if (FAILED(hr)) {
602 VLOG(1) << "FromIdAsync failed: " << PrintHr(hr);
603 return;
604 }
605
606 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
607 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
608
609 hr = async_op->put_Completed(
610 WRL::Callback<IAsyncOperationCompletedHandler<RuntimeType*>>(
611 [weak_ptr, task_runner](IAsyncOperation<RuntimeType*>* async_op,
612 AsyncStatus status) {
shaochuane58f9c72016-08-30 22:27:08 -0700613 // A reference to |async_op| is kept in |async_ops_|, safe to pass
614 // outside.
615 task_runner->PostTask(
616 FROM_HERE,
617 base::Bind(&MidiPortManager::OnCompletedGetPortFromIdAsync,
shaochuanc2894522016-09-20 01:10:50 -0700618 weak_ptr, async_op));
shaochuane58f9c72016-08-30 22:27:08 -0700619
620 return S_OK;
621 })
622 .Get());
623 if (FAILED(hr)) {
624 VLOG(1) << "put_Completed failed: " << PrintHr(hr);
625 return;
626 }
627
628 // Keep a reference to incompleted |async_op| for releasing later.
629 async_ops_.insert(async_op);
630 }
631
632 void OnEnumerationCompleted() {
633 DCHECK(thread_checker_.CalledOnValidThread());
634 CHECK(is_initialized_);
635
636 if (async_ops_.empty())
637 midi_manager_->OnPortManagerReady();
638 else
639 enumeration_completed_not_ready_ = true;
640 }
641
642 void OnRemoved(std::string dev_id) {
643 DCHECK(thread_checker_.CalledOnValidThread());
644 CHECK(is_initialized_);
645
shaochuan110262b2016-08-31 02:15:16 -0700646 // Note: in case Microsoft GS Wavetable Synth triggers this event for some
647 // reason, it will be ignored here with log emitted.
shaochuane58f9c72016-08-30 22:27:08 -0700648 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
649 if (!port) {
650 VLOG(1) << "Removing non-existent port " << dev_id;
651 return;
652 }
653
toyoshimec2570a2016-10-21 02:15:27 -0700654 SetPortState(port->index, PortState::DISCONNECTED);
shaochuane58f9c72016-08-30 22:27:08 -0700655
656 RemovePortEventHandlers(port);
657 port->handle = nullptr;
658 }
659
shaochuanc2894522016-09-20 01:10:50 -0700660 void OnCompletedGetPortFromIdAsync(IAsyncOperation<RuntimeType*>* async_op) {
shaochuane58f9c72016-08-30 22:27:08 -0700661 DCHECK(thread_checker_.CalledOnValidThread());
662 CHECK(is_initialized_);
663
shaochuanc2894522016-09-20 01:10:50 -0700664 InterfaceType* handle = nullptr;
665 HRESULT hr = async_op->GetResults(&handle);
666 if (FAILED(hr)) {
667 VLOG(1) << "GetResults failed: " << PrintHr(hr);
668 return;
669 }
670
671 // Manually release COM interface to completed |async_op|.
672 auto it = async_ops_.find(async_op);
673 CHECK(it != async_ops_.end());
674 (*it)->Release();
675 async_ops_.erase(it);
676
677 if (!handle) {
678 VLOG(1) << "Midi{In,Out}Port.FromIdAsync callback provides null pointer, "
679 "ignoring";
680 return;
681 }
682
shaochuane58f9c72016-08-30 22:27:08 -0700683 EventRegistrationToken token = {kInvalidTokenValue};
684 if (!RegisterOnMessageReceived(handle, &token))
685 return;
686
687 std::string dev_id = GetDeviceIdString(handle);
688
689 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
690
691 if (port == nullptr) {
shaochuan17bc4a02016-09-06 01:42:12 -0700692 std::string manufacturer = "Unknown", driver_version = "Unknown";
693 GetDriverInfoFromDeviceId(dev_id, &manufacturer, &driver_version);
694
695 AddPort(MidiPortInfo(dev_id, manufacturer, port_names_[dev_id],
toyoshimec2570a2016-10-21 02:15:27 -0700696 driver_version, PortState::OPENED));
shaochuane58f9c72016-08-30 22:27:08 -0700697
698 port = new MidiPort<InterfaceType>;
699 port->index = static_cast<uint32_t>(port_ids_.size());
700
701 ports_[dev_id].reset(port);
702 port_ids_.push_back(dev_id);
703 } else {
toyoshimec2570a2016-10-21 02:15:27 -0700704 SetPortState(port->index, PortState::CONNECTED);
shaochuane58f9c72016-08-30 22:27:08 -0700705 }
706
707 port->handle = handle;
708 port->token_MessageReceived = token;
709
shaochuane58f9c72016-08-30 22:27:08 -0700710 if (enumeration_completed_not_ready_ && async_ops_.empty()) {
711 midi_manager_->OnPortManagerReady();
712 enumeration_completed_not_ready_ = false;
713 }
714 }
715
716 // Overrided by MidiInPortManager to listen to input ports.
717 virtual bool RegisterOnMessageReceived(InterfaceType* handle,
718 EventRegistrationToken* p_token) {
719 return true;
720 }
721
722 // Overrided by MidiInPortManager to remove MessageReceived event handler.
723 virtual void RemovePortEventHandlers(MidiPort<InterfaceType>* port) {}
724
725 // Calls midi_manager_->Add{Input,Output}Port.
726 virtual void AddPort(MidiPortInfo info) = 0;
727
728 // Calls midi_manager_->Set{Input,Output}PortState.
toyoshimec2570a2016-10-21 02:15:27 -0700729 virtual void SetPortState(uint32_t port_index, PortState state) = 0;
shaochuane58f9c72016-08-30 22:27:08 -0700730
731 // WeakPtrFactory has to be declared in derived class, use this method to
732 // retrieve upcasted WeakPtr for posting tasks.
733 virtual base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() = 0;
734
735 // Midi{In,Out}PortStatics instance.
736 ScopedComPtr<StaticsInterfaceType> midi_port_statics_;
737
738 // DeviceWatcher instance and event registration tokens for unsubscribing
739 // events in destructor.
740 ScopedComPtr<IDeviceWatcher> watcher_;
741 EventRegistrationToken token_Added_ = {kInvalidTokenValue},
742 token_EnumerationCompleted_ = {kInvalidTokenValue},
743 token_Removed_ = {kInvalidTokenValue},
744 token_Stopped_ = {kInvalidTokenValue},
745 token_Updated_ = {kInvalidTokenValue};
746
747 // All manipulations to these fields should be done on COM thread.
748 std::unordered_map<std::string, std::unique_ptr<MidiPort<InterfaceType>>>
749 ports_;
750 std::vector<std::string> port_ids_;
751 std::unordered_map<std::string, std::string> port_names_;
752
753 // Keeps AsyncOperation references before the operation completes. Note that
754 // raw pointers are used here and the COM interfaces should be released
755 // manually.
756 std::unordered_set<IAsyncOperation<RuntimeType*>*> async_ops_;
757
758 // Set when device enumeration is completed but OnPortManagerReady() is not
759 // called since some ports are not yet ready (i.e. |async_ops_| is not empty).
760 // In such cases, OnPortManagerReady() will be called in
761 // OnCompletedGetPortFromIdAsync() when the last pending port is ready.
762 bool enumeration_completed_not_ready_ = false;
763
764 // Set if the instance is initialized without error. Should be checked in all
765 // methods on COM thread except StartWatcher().
766 bool is_initialized_ = false;
767};
768
769class MidiManagerWinrt::MidiInPortManager final
770 : public MidiPortManager<IMidiInPort,
771 MidiInPort,
772 IMidiInPortStatics,
773 RuntimeClass_Windows_Devices_Midi_MidiInPort> {
774 public:
775 MidiInPortManager(MidiManagerWinrt* midi_manager)
776 : MidiPortManager(midi_manager), weak_factory_(this) {}
777
778 private:
779 // MidiPortManager overrides:
780 bool RegisterOnMessageReceived(IMidiInPort* handle,
781 EventRegistrationToken* p_token) override {
782 DCHECK(thread_checker_.CalledOnValidThread());
783
784 base::WeakPtr<MidiInPortManager> weak_ptr = weak_factory_.GetWeakPtr();
785 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
786
787 HRESULT hr = handle->add_MessageReceived(
788 WRL::Callback<
789 ITypedEventHandler<MidiInPort*, MidiMessageReceivedEventArgs*>>(
790 [weak_ptr, task_runner](IMidiInPort* handle,
791 IMidiMessageReceivedEventArgs* args) {
792 const base::TimeTicks now = base::TimeTicks::Now();
793
794 std::string dev_id = GetDeviceIdString(handle);
795
796 ScopedComPtr<IMidiMessage> message;
797 HRESULT hr = args->get_Message(message.Receive());
798 if (FAILED(hr)) {
799 VLOG(1) << "get_Message failed: " << PrintHr(hr);
800 return hr;
801 }
802
803 ScopedComPtr<IBuffer> buffer;
804 hr = message->get_RawData(buffer.Receive());
805 if (FAILED(hr)) {
806 VLOG(1) << "get_RawData failed: " << PrintHr(hr);
807 return hr;
808 }
809
810 uint8_t* p_buffer_data = nullptr;
robliao3566d1a2017-04-18 17:28:09 -0700811 hr = GetPointerToBufferData(buffer.Get(), &p_buffer_data);
shaochuane58f9c72016-08-30 22:27:08 -0700812 if (FAILED(hr))
813 return hr;
814
815 uint32_t data_length = 0;
816 hr = buffer->get_Length(&data_length);
817 if (FAILED(hr)) {
818 VLOG(1) << "get_Length failed: " << PrintHr(hr);
819 return hr;
820 }
821
822 std::vector<uint8_t> data(p_buffer_data,
823 p_buffer_data + data_length);
824
825 task_runner->PostTask(
826 FROM_HERE, base::Bind(&MidiInPortManager::OnMessageReceived,
827 weak_ptr, dev_id, data, now));
828
829 return S_OK;
830 })
831 .Get(),
832 p_token);
833 if (FAILED(hr)) {
834 VLOG(1) << "add_MessageReceived failed: " << PrintHr(hr);
835 return false;
836 }
837
838 return true;
839 }
840
841 void RemovePortEventHandlers(MidiPort<IMidiInPort>* port) override {
842 if (!(port->handle &&
843 port->token_MessageReceived.value != kInvalidTokenValue))
844 return;
845
846 HRESULT hr =
847 port->handle->remove_MessageReceived(port->token_MessageReceived);
848 VLOG_IF(1, FAILED(hr)) << "remove_MessageReceived failed: " << PrintHr(hr);
849 port->token_MessageReceived.value = kInvalidTokenValue;
850 }
851
852 void AddPort(MidiPortInfo info) final { midi_manager_->AddInputPort(info); }
853
toyoshimec2570a2016-10-21 02:15:27 -0700854 void SetPortState(uint32_t port_index, PortState state) final {
shaochuane58f9c72016-08-30 22:27:08 -0700855 midi_manager_->SetInputPortState(port_index, state);
856 }
857
858 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
859 DCHECK(thread_checker_.CalledOnValidThread());
860
861 return weak_factory_.GetWeakPtr();
862 }
863
864 // Callback on receiving MIDI input message.
865 void OnMessageReceived(std::string dev_id,
866 std::vector<uint8_t> data,
867 base::TimeTicks time) {
868 DCHECK(thread_checker_.CalledOnValidThread());
869
870 MidiPort<IMidiInPort>* port = GetPortByDeviceId(dev_id);
871 CHECK(port);
872
873 midi_manager_->ReceiveMidiData(port->index, &data[0], data.size(), time);
874 }
875
876 // Last member to ensure destructed first.
877 base::WeakPtrFactory<MidiInPortManager> weak_factory_;
878
879 DISALLOW_COPY_AND_ASSIGN(MidiInPortManager);
880};
881
882class MidiManagerWinrt::MidiOutPortManager final
883 : public MidiPortManager<IMidiOutPort,
884 IMidiOutPort,
885 IMidiOutPortStatics,
886 RuntimeClass_Windows_Devices_Midi_MidiOutPort> {
887 public:
888 MidiOutPortManager(MidiManagerWinrt* midi_manager)
889 : MidiPortManager(midi_manager), weak_factory_(this) {}
890
891 private:
892 // MidiPortManager overrides:
893 void AddPort(MidiPortInfo info) final { midi_manager_->AddOutputPort(info); }
894
toyoshimec2570a2016-10-21 02:15:27 -0700895 void SetPortState(uint32_t port_index, PortState state) final {
shaochuane58f9c72016-08-30 22:27:08 -0700896 midi_manager_->SetOutputPortState(port_index, state);
897 }
898
899 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
900 DCHECK(thread_checker_.CalledOnValidThread());
901
902 return weak_factory_.GetWeakPtr();
903 }
904
905 // Last member to ensure destructed first.
906 base::WeakPtrFactory<MidiOutPortManager> weak_factory_;
907
908 DISALLOW_COPY_AND_ASSIGN(MidiOutPortManager);
909};
910
toyoshimf4d61522017-02-10 02:03:32 -0800911MidiManagerWinrt::MidiManagerWinrt(MidiService* service)
912 : MidiManager(service), com_thread_("Windows MIDI COM Thread") {}
shaochuane58f9c72016-08-30 22:27:08 -0700913
914MidiManagerWinrt::~MidiManagerWinrt() {
915 base::AutoLock auto_lock(lazy_init_member_lock_);
916
917 CHECK(!com_thread_checker_);
918 CHECK(!port_manager_in_);
919 CHECK(!port_manager_out_);
920 CHECK(!scheduler_);
921}
922
923void MidiManagerWinrt::StartInitialization() {
shaochuane58f9c72016-08-30 22:27:08 -0700924 com_thread_.init_com_with_mta(true);
925 com_thread_.Start();
926
927 com_thread_.task_runner()->PostTask(
928 FROM_HERE, base::Bind(&MidiManagerWinrt::InitializeOnComThread,
929 base::Unretained(this)));
930}
931
932void MidiManagerWinrt::Finalize() {
933 com_thread_.task_runner()->PostTask(
934 FROM_HERE, base::Bind(&MidiManagerWinrt::FinalizeOnComThread,
935 base::Unretained(this)));
936
937 // Blocks until FinalizeOnComThread() returns. Delayed MIDI send data tasks
938 // will be ignored.
939 com_thread_.Stop();
940}
941
942void MidiManagerWinrt::DispatchSendMidiData(MidiManagerClient* client,
943 uint32_t port_index,
944 const std::vector<uint8_t>& data,
945 double timestamp) {
946 CHECK(scheduler_);
947
948 scheduler_->PostSendDataTask(
949 client, data.size(), timestamp,
950 base::Bind(&MidiManagerWinrt::SendOnComThread, base::Unretained(this),
951 port_index, data));
952}
953
954void MidiManagerWinrt::InitializeOnComThread() {
955 base::AutoLock auto_lock(lazy_init_member_lock_);
956
957 com_thread_checker_.reset(new base::ThreadChecker);
958
dalecurtis3f5ce942017-02-10 18:08:18 -0800959 if (!GetCombaseFunctions()->LoadFunctions()) {
shaochuan9ff63b82016-09-01 01:58:44 -0700960 VLOG(1) << "Failed loading functions from combase.dll: "
961 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
962 CompleteInitialization(Result::INITIALIZATION_ERROR);
963 return;
964 }
965
shaochuane58f9c72016-08-30 22:27:08 -0700966 port_manager_in_.reset(new MidiInPortManager(this));
967 port_manager_out_.reset(new MidiOutPortManager(this));
968
969 scheduler_.reset(new MidiScheduler(this));
970
971 if (!(port_manager_in_->StartWatcher() &&
972 port_manager_out_->StartWatcher())) {
973 port_manager_in_->StopWatcher();
974 port_manager_out_->StopWatcher();
975 CompleteInitialization(Result::INITIALIZATION_ERROR);
976 }
977}
978
979void MidiManagerWinrt::FinalizeOnComThread() {
980 base::AutoLock auto_lock(lazy_init_member_lock_);
981
982 DCHECK(com_thread_checker_->CalledOnValidThread());
983
984 scheduler_.reset();
985
shaochuan9ff63b82016-09-01 01:58:44 -0700986 if (port_manager_in_) {
987 port_manager_in_->StopWatcher();
988 port_manager_in_.reset();
989 }
990
991 if (port_manager_out_) {
992 port_manager_out_->StopWatcher();
993 port_manager_out_.reset();
994 }
shaochuane58f9c72016-08-30 22:27:08 -0700995
996 com_thread_checker_.reset();
997}
998
999void MidiManagerWinrt::SendOnComThread(uint32_t port_index,
1000 const std::vector<uint8_t>& data) {
1001 DCHECK(com_thread_checker_->CalledOnValidThread());
1002
1003 MidiPort<IMidiOutPort>* port = port_manager_out_->GetPortByIndex(port_index);
1004 if (!(port && port->handle)) {
1005 VLOG(1) << "Port not available: " << port_index;
1006 return;
1007 }
1008
1009 auto buffer_factory =
1010 WrlStaticsFactory<IBufferFactory,
1011 RuntimeClass_Windows_Storage_Streams_Buffer>();
1012 if (!buffer_factory)
1013 return;
1014
1015 ScopedComPtr<IBuffer> buffer;
1016 HRESULT hr = buffer_factory->Create(static_cast<UINT32>(data.size()),
1017 buffer.Receive());
1018 if (FAILED(hr)) {
1019 VLOG(1) << "Create failed: " << PrintHr(hr);
1020 return;
1021 }
1022
1023 hr = buffer->put_Length(static_cast<UINT32>(data.size()));
1024 if (FAILED(hr)) {
1025 VLOG(1) << "put_Length failed: " << PrintHr(hr);
1026 return;
1027 }
1028
1029 uint8_t* p_buffer_data = nullptr;
robliao3566d1a2017-04-18 17:28:09 -07001030 hr = GetPointerToBufferData(buffer.Get(), &p_buffer_data);
shaochuane58f9c72016-08-30 22:27:08 -07001031 if (FAILED(hr))
1032 return;
1033
1034 std::copy(data.begin(), data.end(), p_buffer_data);
1035
robliao3566d1a2017-04-18 17:28:09 -07001036 hr = port->handle->SendBuffer(buffer.Get());
shaochuane58f9c72016-08-30 22:27:08 -07001037 if (FAILED(hr)) {
1038 VLOG(1) << "SendBuffer failed: " << PrintHr(hr);
1039 return;
1040 }
1041}
1042
1043void MidiManagerWinrt::OnPortManagerReady() {
1044 DCHECK(com_thread_checker_->CalledOnValidThread());
1045 DCHECK(port_manager_ready_count_ < 2);
1046
1047 if (++port_manager_ready_count_ == 2)
1048 CompleteInitialization(Result::OK);
1049}
1050
shaochuane58f9c72016-08-30 22:27:08 -07001051} // namespace midi