blob: 9a510c7b60eb9e4b6685f170929154610c5d0ba7 [file] [log] [blame]
shaochuane58f9c72016-08-30 22:27:08 -07001// Copyright 2016 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "media/midi/midi_manager_winrt.h"
6
qiankun.miao53f2d662016-09-02 17:44:08 -07007#pragma warning(disable : 4467)
shaochuan80f1fba2016-09-01 20:44:51 -07008
shaochuan4eff30e2016-09-09 01:24:14 -07009#include <initguid.h> // Required by <devpkey.h>
10
11#include <cfgmgr32.h>
shaochuane58f9c72016-08-30 22:27:08 -070012#include <comdef.h>
shaochuan4eff30e2016-09-09 01:24:14 -070013#include <devpkey.h>
shaochuane58f9c72016-08-30 22:27:08 -070014#include <robuffer.h>
15#include <windows.devices.enumeration.h>
16#include <windows.devices.midi.h>
17#include <wrl/event.h>
18
19#include <iomanip>
20#include <unordered_map>
21#include <unordered_set>
22
23#include "base/bind.h"
shaochuan9ff63b82016-09-01 01:58:44 -070024#include "base/lazy_instance.h"
25#include "base/scoped_generic.h"
shaochuan17bc4a02016-09-06 01:42:12 -070026#include "base/strings/string_util.h"
shaochuane58f9c72016-08-30 22:27:08 -070027#include "base/strings/utf_string_conversions.h"
28#include "base/threading/thread_checker.h"
29#include "base/threading/thread_task_runner_handle.h"
30#include "base/timer/timer.h"
31#include "base/win/scoped_comptr.h"
shaochuane58f9c72016-08-30 22:27:08 -070032#include "media/midi/midi_scheduler.h"
33
34namespace media {
35namespace midi {
36namespace {
37
38namespace WRL = Microsoft::WRL;
39
40using namespace ABI::Windows::Devices::Enumeration;
41using namespace ABI::Windows::Devices::Midi;
42using namespace ABI::Windows::Foundation;
43using namespace ABI::Windows::Storage::Streams;
44
45using base::win::ScopedComPtr;
46
47// Helpers for printing HRESULTs.
48struct PrintHr {
49 PrintHr(HRESULT hr) : hr(hr) {}
50 HRESULT hr;
51};
52
53std::ostream& operator<<(std::ostream& os, const PrintHr& phr) {
54 std::ios_base::fmtflags ff = os.flags();
55 os << _com_error(phr.hr).ErrorMessage() << " (0x" << std::hex
56 << std::uppercase << std::setfill('0') << std::setw(8) << phr.hr << ")";
57 os.flags(ff);
58 return os;
59}
60
shaochuan9ff63b82016-09-01 01:58:44 -070061// Provides access to functions in combase.dll which may not be available on
62// Windows 7. Loads functions dynamically at runtime to prevent library
63// dependencies. Use this class through the global LazyInstance
64// |g_combase_functions|.
65class CombaseFunctions {
66 public:
67 CombaseFunctions() = default;
68
69 ~CombaseFunctions() {
70 if (combase_dll_)
71 ::FreeLibrary(combase_dll_);
72 }
73
74 bool LoadFunctions() {
75 combase_dll_ = ::LoadLibrary(L"combase.dll");
76 if (!combase_dll_)
77 return false;
78
79 get_factory_func_ = reinterpret_cast<decltype(&::RoGetActivationFactory)>(
80 ::GetProcAddress(combase_dll_, "RoGetActivationFactory"));
81 if (!get_factory_func_)
82 return false;
83
84 create_string_func_ = reinterpret_cast<decltype(&::WindowsCreateString)>(
85 ::GetProcAddress(combase_dll_, "WindowsCreateString"));
86 if (!create_string_func_)
87 return false;
88
89 delete_string_func_ = reinterpret_cast<decltype(&::WindowsDeleteString)>(
90 ::GetProcAddress(combase_dll_, "WindowsDeleteString"));
91 if (!delete_string_func_)
92 return false;
93
94 get_string_raw_buffer_func_ =
95 reinterpret_cast<decltype(&::WindowsGetStringRawBuffer)>(
96 ::GetProcAddress(combase_dll_, "WindowsGetStringRawBuffer"));
97 if (!get_string_raw_buffer_func_)
98 return false;
99
100 return true;
101 }
102
103 HRESULT RoGetActivationFactory(HSTRING class_id,
104 const IID& iid,
105 void** out_factory) {
106 DCHECK(get_factory_func_);
107 return get_factory_func_(class_id, iid, out_factory);
108 }
109
110 HRESULT WindowsCreateString(const base::char16* src,
111 uint32_t len,
112 HSTRING* out_hstr) {
113 DCHECK(create_string_func_);
114 return create_string_func_(src, len, out_hstr);
115 }
116
117 HRESULT WindowsDeleteString(HSTRING hstr) {
118 DCHECK(delete_string_func_);
119 return delete_string_func_(hstr);
120 }
121
122 const base::char16* WindowsGetStringRawBuffer(HSTRING hstr,
123 uint32_t* out_len) {
124 DCHECK(get_string_raw_buffer_func_);
125 return get_string_raw_buffer_func_(hstr, out_len);
126 }
127
128 private:
129 HMODULE combase_dll_ = nullptr;
130
131 decltype(&::RoGetActivationFactory) get_factory_func_ = nullptr;
132 decltype(&::WindowsCreateString) create_string_func_ = nullptr;
133 decltype(&::WindowsDeleteString) delete_string_func_ = nullptr;
134 decltype(&::WindowsGetStringRawBuffer) get_string_raw_buffer_func_ = nullptr;
135};
136
137base::LazyInstance<CombaseFunctions> g_combase_functions =
138 LAZY_INSTANCE_INITIALIZER;
139
140// Scoped HSTRING class to maintain lifetime of HSTRINGs allocated with
141// WindowsCreateString().
142class ScopedHStringTraits {
143 public:
144 static HSTRING InvalidValue() { return nullptr; }
145
146 static void Free(HSTRING hstr) {
147 g_combase_functions.Get().WindowsDeleteString(hstr);
148 }
149};
150
151class ScopedHString : public base::ScopedGeneric<HSTRING, ScopedHStringTraits> {
152 public:
153 explicit ScopedHString(const base::char16* str) : ScopedGeneric(nullptr) {
154 HSTRING hstr;
155 HRESULT hr = g_combase_functions.Get().WindowsCreateString(
156 str, static_cast<uint32_t>(wcslen(str)), &hstr);
157 if (FAILED(hr))
158 VLOG(1) << "WindowsCreateString failed: " << PrintHr(hr);
159 else
160 reset(hstr);
161 }
162};
163
shaochuane58f9c72016-08-30 22:27:08 -0700164// Factory functions that activate and create WinRT components. The caller takes
165// ownership of the returning ComPtr.
166template <typename InterfaceType, base::char16 const* runtime_class_id>
167ScopedComPtr<InterfaceType> WrlStaticsFactory() {
168 ScopedComPtr<InterfaceType> com_ptr;
169
shaochuan9ff63b82016-09-01 01:58:44 -0700170 ScopedHString class_id_hstring(runtime_class_id);
171 if (!class_id_hstring.is_valid()) {
172 com_ptr = nullptr;
173 return com_ptr;
174 }
175
176 HRESULT hr = g_combase_functions.Get().RoGetActivationFactory(
177 class_id_hstring.get(), __uuidof(InterfaceType), com_ptr.ReceiveVoid());
shaochuane58f9c72016-08-30 22:27:08 -0700178 if (FAILED(hr)) {
shaochuan9ff63b82016-09-01 01:58:44 -0700179 VLOG(1) << "RoGetActivationFactory failed: " << PrintHr(hr);
shaochuane58f9c72016-08-30 22:27:08 -0700180 com_ptr = nullptr;
181 }
182
183 return com_ptr;
184}
185
shaochuan80f1fba2016-09-01 20:44:51 -0700186std::string HStringToString(HSTRING hstr) {
shaochuane58f9c72016-08-30 22:27:08 -0700187 // Note: empty HSTRINGs are represent as nullptr, and instantiating
188 // std::string with nullptr (in base::WideToUTF8) is undefined behavior.
shaochuan9ff63b82016-09-01 01:58:44 -0700189 const base::char16* buffer =
shaochuan80f1fba2016-09-01 20:44:51 -0700190 g_combase_functions.Get().WindowsGetStringRawBuffer(hstr, nullptr);
shaochuane58f9c72016-08-30 22:27:08 -0700191 if (buffer)
192 return base::WideToUTF8(buffer);
193 return std::string();
194}
195
196template <typename T>
197std::string GetIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700198 HSTRING result;
199 HRESULT hr = obj->get_Id(&result);
200 if (FAILED(hr)) {
201 VLOG(1) << "get_Id failed: " << PrintHr(hr);
202 return std::string();
203 }
204 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700205}
206
207template <typename T>
208std::string GetDeviceIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700209 HSTRING result;
210 HRESULT hr = obj->get_DeviceId(&result);
211 if (FAILED(hr)) {
212 VLOG(1) << "get_DeviceId failed: " << PrintHr(hr);
213 return std::string();
214 }
215 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700216}
217
218std::string GetNameString(IDeviceInformation* info) {
shaochuan80f1fba2016-09-01 20:44:51 -0700219 HSTRING result;
220 HRESULT hr = info->get_Name(&result);
221 if (FAILED(hr)) {
222 VLOG(1) << "get_Name failed: " << PrintHr(hr);
223 return std::string();
224 }
225 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700226}
227
228HRESULT GetPointerToBufferData(IBuffer* buffer, uint8_t** out) {
229 ScopedComPtr<Windows::Storage::Streams::IBufferByteAccess> buffer_byte_access;
230
231 HRESULT hr = buffer_byte_access.QueryFrom(buffer);
232 if (FAILED(hr)) {
233 VLOG(1) << "QueryInterface failed: " << PrintHr(hr);
234 return hr;
235 }
236
237 // Lifetime of the pointing buffer is controlled by the buffer object.
238 hr = buffer_byte_access->Buffer(out);
239 if (FAILED(hr)) {
240 VLOG(1) << "Buffer failed: " << PrintHr(hr);
241 return hr;
242 }
243
244 return S_OK;
245}
246
shaochuan110262b2016-08-31 02:15:16 -0700247// Checks if given DeviceInformation represent a Microsoft GS Wavetable Synth
248// instance.
249bool IsMicrosoftSynthesizer(IDeviceInformation* info) {
250 auto midi_synthesizer_statics =
251 WrlStaticsFactory<IMidiSynthesizerStatics,
252 RuntimeClass_Windows_Devices_Midi_MidiSynthesizer>();
253 boolean result = FALSE;
254 HRESULT hr = midi_synthesizer_statics->IsSynthesizer(info, &result);
255 VLOG_IF(1, FAILED(hr)) << "IsSynthesizer failed: " << PrintHr(hr);
256 return result != FALSE;
257}
258
shaochuan4eff30e2016-09-09 01:24:14 -0700259void GetDevPropString(DEVINST handle,
260 const DEVPROPKEY* devprop_key,
261 std::string* out) {
262 DEVPROPTYPE devprop_type;
263 unsigned long buffer_size = 0;
shaochuan17bc4a02016-09-06 01:42:12 -0700264
shaochuan4eff30e2016-09-09 01:24:14 -0700265 // Retrieve |buffer_size| and allocate buffer later for receiving data.
266 CONFIGRET cr = CM_Get_DevNode_Property(handle, devprop_key, &devprop_type,
267 nullptr, &buffer_size, 0);
268 if (cr != CR_BUFFER_SMALL) {
269 // Here we print error codes in hex instead of using PrintHr() with
270 // HRESULT_FROM_WIN32() and CM_MapCrToWin32Err(), since only a minor set of
271 // CONFIGRET values are mapped to Win32 errors. Same for following VLOG()s.
272 VLOG(1) << "CM_Get_DevNode_Property failed: CONFIGRET 0x" << std::hex << cr;
273 return;
shaochuan17bc4a02016-09-06 01:42:12 -0700274 }
shaochuan4eff30e2016-09-09 01:24:14 -0700275 if (devprop_type != DEVPROP_TYPE_STRING) {
276 VLOG(1) << "CM_Get_DevNode_Property returns wrong data type, "
277 << "expected DEVPROP_TYPE_STRING";
278 return;
279 }
shaochuan17bc4a02016-09-06 01:42:12 -0700280
shaochuan4eff30e2016-09-09 01:24:14 -0700281 std::unique_ptr<uint8_t[]> buffer(new uint8_t[buffer_size]);
282
283 // Receive property data.
284 cr = CM_Get_DevNode_Property(handle, devprop_key, &devprop_type, buffer.get(),
285 &buffer_size, 0);
286 if (cr != CR_SUCCESS)
287 VLOG(1) << "CM_Get_DevNode_Property failed: CONFIGRET 0x" << std::hex << cr;
288 else
289 *out = base::WideToUTF8(reinterpret_cast<base::char16*>(buffer.get()));
290}
shaochuan17bc4a02016-09-06 01:42:12 -0700291
292// Retrieves manufacturer (provider) and version information of underlying
shaochuan4eff30e2016-09-09 01:24:14 -0700293// device driver through PnP Configuration Manager, given device (interface) ID
294// provided by WinRT. |out_manufacturer| and |out_driver_version| won't be
295// modified if retrieval fails.
shaochuan17bc4a02016-09-06 01:42:12 -0700296//
297// Device instance ID is extracted from device (interface) ID provided by WinRT
298// APIs, for example from the following interface ID:
299// \\?\SWD#MMDEVAPI#MIDII_60F39FCA.P_0002#{504be32c-ccf6-4d2c-b73f-6f8b3747e22b}
300// we extract the device instance ID: SWD\MMDEVAPI\MIDII_60F39FCA.P_0002
shaochuan4eff30e2016-09-09 01:24:14 -0700301//
302// However the extracted device instance ID represent a "software device"
303// provided by Microsoft, which is an interface on top of the hardware for each
304// input/output port. Therefore we further locate its parent device, which is
305// the actual hardware device, for driver information.
shaochuan17bc4a02016-09-06 01:42:12 -0700306void GetDriverInfoFromDeviceId(const std::string& dev_id,
307 std::string* out_manufacturer,
308 std::string* out_driver_version) {
309 base::string16 dev_instance_id =
310 base::UTF8ToWide(dev_id.substr(4, dev_id.size() - 43));
311 base::ReplaceChars(dev_instance_id, L"#", L"\\", &dev_instance_id);
312
shaochuan4eff30e2016-09-09 01:24:14 -0700313 DEVINST dev_instance_handle;
314 CONFIGRET cr = CM_Locate_DevNode(&dev_instance_handle, &dev_instance_id[0],
315 CM_LOCATE_DEVNODE_NORMAL);
316 if (cr != CR_SUCCESS) {
317 VLOG(1) << "CM_Locate_DevNode failed: CONFIGRET 0x" << std::hex << cr;
shaochuan17bc4a02016-09-06 01:42:12 -0700318 return;
319 }
320
shaochuan4eff30e2016-09-09 01:24:14 -0700321 DEVINST parent_handle;
322 cr = CM_Get_Parent(&parent_handle, dev_instance_handle, 0);
323 if (cr != CR_SUCCESS) {
324 VLOG(1) << "CM_Get_Parent failed: CONFIGRET 0x" << std::hex << cr;
shaochuan17bc4a02016-09-06 01:42:12 -0700325 return;
326 }
327
shaochuan4eff30e2016-09-09 01:24:14 -0700328 GetDevPropString(parent_handle, &DEVPKEY_Device_DriverProvider,
329 out_manufacturer);
330 GetDevPropString(parent_handle, &DEVPKEY_Device_DriverVersion,
331 out_driver_version);
shaochuan17bc4a02016-09-06 01:42:12 -0700332}
333
shaochuane58f9c72016-08-30 22:27:08 -0700334// Tokens with value = 0 are considered invalid (as in <wrl/event.h>).
335const int64_t kInvalidTokenValue = 0;
336
337template <typename InterfaceType>
338struct MidiPort {
339 MidiPort() = default;
340
341 uint32_t index;
342 ScopedComPtr<InterfaceType> handle;
343 EventRegistrationToken token_MessageReceived;
344
345 private:
346 DISALLOW_COPY_AND_ASSIGN(MidiPort);
347};
348
349} // namespace
350
351template <typename InterfaceType,
352 typename RuntimeType,
353 typename StaticsInterfaceType,
354 base::char16 const* runtime_class_id>
355class MidiManagerWinrt::MidiPortManager {
356 public:
357 // MidiPortManager instances should be constructed on the COM thread.
358 MidiPortManager(MidiManagerWinrt* midi_manager)
359 : midi_manager_(midi_manager),
360 task_runner_(base::ThreadTaskRunnerHandle::Get()) {}
361
362 virtual ~MidiPortManager() { DCHECK(thread_checker_.CalledOnValidThread()); }
363
364 bool StartWatcher() {
365 DCHECK(thread_checker_.CalledOnValidThread());
366
367 HRESULT hr;
368
369 midi_port_statics_ =
370 WrlStaticsFactory<StaticsInterfaceType, runtime_class_id>();
371 if (!midi_port_statics_)
372 return false;
373
374 HSTRING device_selector = nullptr;
375 hr = midi_port_statics_->GetDeviceSelector(&device_selector);
376 if (FAILED(hr)) {
377 VLOG(1) << "GetDeviceSelector failed: " << PrintHr(hr);
378 return false;
379 }
380
381 auto dev_info_statics = WrlStaticsFactory<
382 IDeviceInformationStatics,
383 RuntimeClass_Windows_Devices_Enumeration_DeviceInformation>();
384 if (!dev_info_statics)
385 return false;
386
387 hr = dev_info_statics->CreateWatcherAqsFilter(device_selector,
388 watcher_.Receive());
389 if (FAILED(hr)) {
390 VLOG(1) << "CreateWatcherAqsFilter failed: " << PrintHr(hr);
391 return false;
392 }
393
394 // Register callbacks to WinRT that post state-modifying jobs back to COM
395 // thread. |weak_ptr| and |task_runner| are captured by lambda callbacks for
396 // posting jobs. Note that WinRT callback arguments should not be passed
397 // outside the callback since the pointers may be unavailable afterwards.
398 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
399 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
400
401 hr = watcher_->add_Added(
402 WRL::Callback<ITypedEventHandler<DeviceWatcher*, DeviceInformation*>>(
403 [weak_ptr, task_runner](IDeviceWatcher* watcher,
404 IDeviceInformation* info) {
shaochuanc2894522016-09-20 01:10:50 -0700405 if (!info) {
406 VLOG(1) << "DeviceWatcher.Added callback provides null "
407 "pointer, ignoring";
408 return S_OK;
409 }
410
shaochuan110262b2016-08-31 02:15:16 -0700411 // Disable Microsoft GS Wavetable Synth due to security reasons.
412 // http://crbug.com/499279
413 if (IsMicrosoftSynthesizer(info))
414 return S_OK;
415
shaochuane58f9c72016-08-30 22:27:08 -0700416 std::string dev_id = GetIdString(info),
417 dev_name = GetNameString(info);
418
419 task_runner->PostTask(
420 FROM_HERE, base::Bind(&MidiPortManager::OnAdded, weak_ptr,
421 dev_id, dev_name));
422
423 return S_OK;
424 })
425 .Get(),
426 &token_Added_);
427 if (FAILED(hr)) {
428 VLOG(1) << "add_Added failed: " << PrintHr(hr);
429 return false;
430 }
431
432 hr = watcher_->add_EnumerationCompleted(
433 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
434 [weak_ptr, task_runner](IDeviceWatcher* watcher,
435 IInspectable* insp) {
436 task_runner->PostTask(
437 FROM_HERE,
438 base::Bind(&MidiPortManager::OnEnumerationCompleted,
439 weak_ptr));
440
441 return S_OK;
442 })
443 .Get(),
444 &token_EnumerationCompleted_);
445 if (FAILED(hr)) {
446 VLOG(1) << "add_EnumerationCompleted failed: " << PrintHr(hr);
447 return false;
448 }
449
450 hr = watcher_->add_Removed(
451 WRL::Callback<
452 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
453 [weak_ptr, task_runner](IDeviceWatcher* watcher,
454 IDeviceInformationUpdate* update) {
shaochuanc2894522016-09-20 01:10:50 -0700455 if (!update) {
456 VLOG(1) << "DeviceWatcher.Removed callback provides null "
457 "pointer, ignoring";
458 return S_OK;
459 }
460
shaochuane58f9c72016-08-30 22:27:08 -0700461 std::string dev_id = GetIdString(update);
462
463 task_runner->PostTask(
464 FROM_HERE,
465 base::Bind(&MidiPortManager::OnRemoved, weak_ptr, dev_id));
466
467 return S_OK;
468 })
469 .Get(),
470 &token_Removed_);
471 if (FAILED(hr)) {
472 VLOG(1) << "add_Removed failed: " << PrintHr(hr);
473 return false;
474 }
475
476 hr = watcher_->add_Stopped(
477 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
478 [](IDeviceWatcher* watcher, IInspectable* insp) {
479 // Placeholder, does nothing for now.
480 return S_OK;
481 })
482 .Get(),
483 &token_Stopped_);
484 if (FAILED(hr)) {
485 VLOG(1) << "add_Stopped failed: " << PrintHr(hr);
486 return false;
487 }
488
489 hr = watcher_->add_Updated(
490 WRL::Callback<
491 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
492 [](IDeviceWatcher* watcher, IDeviceInformationUpdate* update) {
493 // TODO(shaochuan): Check for fields to be updated here.
494 return S_OK;
495 })
496 .Get(),
497 &token_Updated_);
498 if (FAILED(hr)) {
499 VLOG(1) << "add_Updated failed: " << PrintHr(hr);
500 return false;
501 }
502
503 hr = watcher_->Start();
504 if (FAILED(hr)) {
505 VLOG(1) << "Start failed: " << PrintHr(hr);
506 return false;
507 }
508
509 is_initialized_ = true;
510 return true;
511 }
512
513 void StopWatcher() {
514 DCHECK(thread_checker_.CalledOnValidThread());
515
516 HRESULT hr;
517
518 for (const auto& entry : ports_)
519 RemovePortEventHandlers(entry.second.get());
520
521 if (token_Added_.value != kInvalidTokenValue) {
522 hr = watcher_->remove_Added(token_Added_);
523 VLOG_IF(1, FAILED(hr)) << "remove_Added failed: " << PrintHr(hr);
524 token_Added_.value = kInvalidTokenValue;
525 }
526 if (token_EnumerationCompleted_.value != kInvalidTokenValue) {
527 hr = watcher_->remove_EnumerationCompleted(token_EnumerationCompleted_);
528 VLOG_IF(1, FAILED(hr)) << "remove_EnumerationCompleted failed: "
529 << PrintHr(hr);
530 token_EnumerationCompleted_.value = kInvalidTokenValue;
531 }
532 if (token_Removed_.value != kInvalidTokenValue) {
533 hr = watcher_->remove_Removed(token_Removed_);
534 VLOG_IF(1, FAILED(hr)) << "remove_Removed failed: " << PrintHr(hr);
535 token_Removed_.value = kInvalidTokenValue;
536 }
537 if (token_Stopped_.value != kInvalidTokenValue) {
538 hr = watcher_->remove_Stopped(token_Stopped_);
539 VLOG_IF(1, FAILED(hr)) << "remove_Stopped failed: " << PrintHr(hr);
540 token_Stopped_.value = kInvalidTokenValue;
541 }
542 if (token_Updated_.value != kInvalidTokenValue) {
543 hr = watcher_->remove_Updated(token_Updated_);
544 VLOG_IF(1, FAILED(hr)) << "remove_Updated failed: " << PrintHr(hr);
545 token_Updated_.value = kInvalidTokenValue;
546 }
547
548 if (is_initialized_) {
549 hr = watcher_->Stop();
550 VLOG_IF(1, FAILED(hr)) << "Stop failed: " << PrintHr(hr);
551 is_initialized_ = false;
552 }
553 }
554
555 MidiPort<InterfaceType>* GetPortByDeviceId(std::string dev_id) {
556 DCHECK(thread_checker_.CalledOnValidThread());
557 CHECK(is_initialized_);
558
559 auto it = ports_.find(dev_id);
560 if (it == ports_.end())
561 return nullptr;
562 return it->second.get();
563 }
564
565 MidiPort<InterfaceType>* GetPortByIndex(uint32_t port_index) {
566 DCHECK(thread_checker_.CalledOnValidThread());
567 CHECK(is_initialized_);
568
569 return GetPortByDeviceId(port_ids_[port_index]);
570 }
571
572 protected:
573 // Points to the MidiManagerWinrt instance, which is expected to outlive the
574 // MidiPortManager instance.
575 MidiManagerWinrt* midi_manager_;
576
577 // Task runner of the COM thread.
578 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
579
580 // Ensures all methods are called on the COM thread.
581 base::ThreadChecker thread_checker_;
582
583 private:
584 // DeviceWatcher callbacks:
585 void OnAdded(std::string dev_id, std::string dev_name) {
586 DCHECK(thread_checker_.CalledOnValidThread());
587 CHECK(is_initialized_);
588
shaochuane58f9c72016-08-30 22:27:08 -0700589 port_names_[dev_id] = dev_name;
590
shaochuan9ff63b82016-09-01 01:58:44 -0700591 ScopedHString dev_id_hstring(base::UTF8ToWide(dev_id).c_str());
592 if (!dev_id_hstring.is_valid())
shaochuane58f9c72016-08-30 22:27:08 -0700593 return;
shaochuane58f9c72016-08-30 22:27:08 -0700594
595 IAsyncOperation<RuntimeType*>* async_op;
596
shaochuan9ff63b82016-09-01 01:58:44 -0700597 HRESULT hr =
598 midi_port_statics_->FromIdAsync(dev_id_hstring.get(), &async_op);
shaochuane58f9c72016-08-30 22:27:08 -0700599 if (FAILED(hr)) {
600 VLOG(1) << "FromIdAsync failed: " << PrintHr(hr);
601 return;
602 }
603
604 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
605 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
606
607 hr = async_op->put_Completed(
608 WRL::Callback<IAsyncOperationCompletedHandler<RuntimeType*>>(
609 [weak_ptr, task_runner](IAsyncOperation<RuntimeType*>* async_op,
610 AsyncStatus status) {
shaochuane58f9c72016-08-30 22:27:08 -0700611 // A reference to |async_op| is kept in |async_ops_|, safe to pass
612 // outside.
613 task_runner->PostTask(
614 FROM_HERE,
615 base::Bind(&MidiPortManager::OnCompletedGetPortFromIdAsync,
shaochuanc2894522016-09-20 01:10:50 -0700616 weak_ptr, async_op));
shaochuane58f9c72016-08-30 22:27:08 -0700617
618 return S_OK;
619 })
620 .Get());
621 if (FAILED(hr)) {
622 VLOG(1) << "put_Completed failed: " << PrintHr(hr);
623 return;
624 }
625
626 // Keep a reference to incompleted |async_op| for releasing later.
627 async_ops_.insert(async_op);
628 }
629
630 void OnEnumerationCompleted() {
631 DCHECK(thread_checker_.CalledOnValidThread());
632 CHECK(is_initialized_);
633
634 if (async_ops_.empty())
635 midi_manager_->OnPortManagerReady();
636 else
637 enumeration_completed_not_ready_ = true;
638 }
639
640 void OnRemoved(std::string dev_id) {
641 DCHECK(thread_checker_.CalledOnValidThread());
642 CHECK(is_initialized_);
643
shaochuan110262b2016-08-31 02:15:16 -0700644 // Note: in case Microsoft GS Wavetable Synth triggers this event for some
645 // reason, it will be ignored here with log emitted.
shaochuane58f9c72016-08-30 22:27:08 -0700646 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
647 if (!port) {
648 VLOG(1) << "Removing non-existent port " << dev_id;
649 return;
650 }
651
652 SetPortState(port->index, MIDI_PORT_DISCONNECTED);
653
654 RemovePortEventHandlers(port);
655 port->handle = nullptr;
656 }
657
shaochuanc2894522016-09-20 01:10:50 -0700658 void OnCompletedGetPortFromIdAsync(IAsyncOperation<RuntimeType*>* async_op) {
shaochuane58f9c72016-08-30 22:27:08 -0700659 DCHECK(thread_checker_.CalledOnValidThread());
660 CHECK(is_initialized_);
661
shaochuanc2894522016-09-20 01:10:50 -0700662 InterfaceType* handle = nullptr;
663 HRESULT hr = async_op->GetResults(&handle);
664 if (FAILED(hr)) {
665 VLOG(1) << "GetResults failed: " << PrintHr(hr);
666 return;
667 }
668
669 // Manually release COM interface to completed |async_op|.
670 auto it = async_ops_.find(async_op);
671 CHECK(it != async_ops_.end());
672 (*it)->Release();
673 async_ops_.erase(it);
674
675 if (!handle) {
676 VLOG(1) << "Midi{In,Out}Port.FromIdAsync callback provides null pointer, "
677 "ignoring";
678 return;
679 }
680
shaochuane58f9c72016-08-30 22:27:08 -0700681 EventRegistrationToken token = {kInvalidTokenValue};
682 if (!RegisterOnMessageReceived(handle, &token))
683 return;
684
685 std::string dev_id = GetDeviceIdString(handle);
686
687 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
688
689 if (port == nullptr) {
shaochuan17bc4a02016-09-06 01:42:12 -0700690 std::string manufacturer = "Unknown", driver_version = "Unknown";
691 GetDriverInfoFromDeviceId(dev_id, &manufacturer, &driver_version);
692
693 AddPort(MidiPortInfo(dev_id, manufacturer, port_names_[dev_id],
694 driver_version, MIDI_PORT_OPENED));
shaochuane58f9c72016-08-30 22:27:08 -0700695
696 port = new MidiPort<InterfaceType>;
697 port->index = static_cast<uint32_t>(port_ids_.size());
698
699 ports_[dev_id].reset(port);
700 port_ids_.push_back(dev_id);
701 } else {
702 SetPortState(port->index, MIDI_PORT_CONNECTED);
703 }
704
705 port->handle = handle;
706 port->token_MessageReceived = token;
707
shaochuane58f9c72016-08-30 22:27:08 -0700708 if (enumeration_completed_not_ready_ && async_ops_.empty()) {
709 midi_manager_->OnPortManagerReady();
710 enumeration_completed_not_ready_ = false;
711 }
712 }
713
714 // Overrided by MidiInPortManager to listen to input ports.
715 virtual bool RegisterOnMessageReceived(InterfaceType* handle,
716 EventRegistrationToken* p_token) {
717 return true;
718 }
719
720 // Overrided by MidiInPortManager to remove MessageReceived event handler.
721 virtual void RemovePortEventHandlers(MidiPort<InterfaceType>* port) {}
722
723 // Calls midi_manager_->Add{Input,Output}Port.
724 virtual void AddPort(MidiPortInfo info) = 0;
725
726 // Calls midi_manager_->Set{Input,Output}PortState.
727 virtual void SetPortState(uint32_t port_index, MidiPortState state) = 0;
728
729 // WeakPtrFactory has to be declared in derived class, use this method to
730 // retrieve upcasted WeakPtr for posting tasks.
731 virtual base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() = 0;
732
733 // Midi{In,Out}PortStatics instance.
734 ScopedComPtr<StaticsInterfaceType> midi_port_statics_;
735
736 // DeviceWatcher instance and event registration tokens for unsubscribing
737 // events in destructor.
738 ScopedComPtr<IDeviceWatcher> watcher_;
739 EventRegistrationToken token_Added_ = {kInvalidTokenValue},
740 token_EnumerationCompleted_ = {kInvalidTokenValue},
741 token_Removed_ = {kInvalidTokenValue},
742 token_Stopped_ = {kInvalidTokenValue},
743 token_Updated_ = {kInvalidTokenValue};
744
745 // All manipulations to these fields should be done on COM thread.
746 std::unordered_map<std::string, std::unique_ptr<MidiPort<InterfaceType>>>
747 ports_;
748 std::vector<std::string> port_ids_;
749 std::unordered_map<std::string, std::string> port_names_;
750
751 // Keeps AsyncOperation references before the operation completes. Note that
752 // raw pointers are used here and the COM interfaces should be released
753 // manually.
754 std::unordered_set<IAsyncOperation<RuntimeType*>*> async_ops_;
755
756 // Set when device enumeration is completed but OnPortManagerReady() is not
757 // called since some ports are not yet ready (i.e. |async_ops_| is not empty).
758 // In such cases, OnPortManagerReady() will be called in
759 // OnCompletedGetPortFromIdAsync() when the last pending port is ready.
760 bool enumeration_completed_not_ready_ = false;
761
762 // Set if the instance is initialized without error. Should be checked in all
763 // methods on COM thread except StartWatcher().
764 bool is_initialized_ = false;
765};
766
767class MidiManagerWinrt::MidiInPortManager final
768 : public MidiPortManager<IMidiInPort,
769 MidiInPort,
770 IMidiInPortStatics,
771 RuntimeClass_Windows_Devices_Midi_MidiInPort> {
772 public:
773 MidiInPortManager(MidiManagerWinrt* midi_manager)
774 : MidiPortManager(midi_manager), weak_factory_(this) {}
775
776 private:
777 // MidiPortManager overrides:
778 bool RegisterOnMessageReceived(IMidiInPort* handle,
779 EventRegistrationToken* p_token) override {
780 DCHECK(thread_checker_.CalledOnValidThread());
781
782 base::WeakPtr<MidiInPortManager> weak_ptr = weak_factory_.GetWeakPtr();
783 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
784
785 HRESULT hr = handle->add_MessageReceived(
786 WRL::Callback<
787 ITypedEventHandler<MidiInPort*, MidiMessageReceivedEventArgs*>>(
788 [weak_ptr, task_runner](IMidiInPort* handle,
789 IMidiMessageReceivedEventArgs* args) {
790 const base::TimeTicks now = base::TimeTicks::Now();
791
792 std::string dev_id = GetDeviceIdString(handle);
793
794 ScopedComPtr<IMidiMessage> message;
795 HRESULT hr = args->get_Message(message.Receive());
796 if (FAILED(hr)) {
797 VLOG(1) << "get_Message failed: " << PrintHr(hr);
798 return hr;
799 }
800
801 ScopedComPtr<IBuffer> buffer;
802 hr = message->get_RawData(buffer.Receive());
803 if (FAILED(hr)) {
804 VLOG(1) << "get_RawData failed: " << PrintHr(hr);
805 return hr;
806 }
807
808 uint8_t* p_buffer_data = nullptr;
809 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
810 if (FAILED(hr))
811 return hr;
812
813 uint32_t data_length = 0;
814 hr = buffer->get_Length(&data_length);
815 if (FAILED(hr)) {
816 VLOG(1) << "get_Length failed: " << PrintHr(hr);
817 return hr;
818 }
819
820 std::vector<uint8_t> data(p_buffer_data,
821 p_buffer_data + data_length);
822
823 task_runner->PostTask(
824 FROM_HERE, base::Bind(&MidiInPortManager::OnMessageReceived,
825 weak_ptr, dev_id, data, now));
826
827 return S_OK;
828 })
829 .Get(),
830 p_token);
831 if (FAILED(hr)) {
832 VLOG(1) << "add_MessageReceived failed: " << PrintHr(hr);
833 return false;
834 }
835
836 return true;
837 }
838
839 void RemovePortEventHandlers(MidiPort<IMidiInPort>* port) override {
840 if (!(port->handle &&
841 port->token_MessageReceived.value != kInvalidTokenValue))
842 return;
843
844 HRESULT hr =
845 port->handle->remove_MessageReceived(port->token_MessageReceived);
846 VLOG_IF(1, FAILED(hr)) << "remove_MessageReceived failed: " << PrintHr(hr);
847 port->token_MessageReceived.value = kInvalidTokenValue;
848 }
849
850 void AddPort(MidiPortInfo info) final { midi_manager_->AddInputPort(info); }
851
852 void SetPortState(uint32_t port_index, MidiPortState state) final {
853 midi_manager_->SetInputPortState(port_index, state);
854 }
855
856 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
857 DCHECK(thread_checker_.CalledOnValidThread());
858
859 return weak_factory_.GetWeakPtr();
860 }
861
862 // Callback on receiving MIDI input message.
863 void OnMessageReceived(std::string dev_id,
864 std::vector<uint8_t> data,
865 base::TimeTicks time) {
866 DCHECK(thread_checker_.CalledOnValidThread());
867
868 MidiPort<IMidiInPort>* port = GetPortByDeviceId(dev_id);
869 CHECK(port);
870
871 midi_manager_->ReceiveMidiData(port->index, &data[0], data.size(), time);
872 }
873
874 // Last member to ensure destructed first.
875 base::WeakPtrFactory<MidiInPortManager> weak_factory_;
876
877 DISALLOW_COPY_AND_ASSIGN(MidiInPortManager);
878};
879
880class MidiManagerWinrt::MidiOutPortManager final
881 : public MidiPortManager<IMidiOutPort,
882 IMidiOutPort,
883 IMidiOutPortStatics,
884 RuntimeClass_Windows_Devices_Midi_MidiOutPort> {
885 public:
886 MidiOutPortManager(MidiManagerWinrt* midi_manager)
887 : MidiPortManager(midi_manager), weak_factory_(this) {}
888
889 private:
890 // MidiPortManager overrides:
891 void AddPort(MidiPortInfo info) final { midi_manager_->AddOutputPort(info); }
892
893 void SetPortState(uint32_t port_index, MidiPortState state) final {
894 midi_manager_->SetOutputPortState(port_index, state);
895 }
896
897 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
898 DCHECK(thread_checker_.CalledOnValidThread());
899
900 return weak_factory_.GetWeakPtr();
901 }
902
903 // Last member to ensure destructed first.
904 base::WeakPtrFactory<MidiOutPortManager> weak_factory_;
905
906 DISALLOW_COPY_AND_ASSIGN(MidiOutPortManager);
907};
908
909MidiManagerWinrt::MidiManagerWinrt() : com_thread_("Windows MIDI COM Thread") {}
910
911MidiManagerWinrt::~MidiManagerWinrt() {
912 base::AutoLock auto_lock(lazy_init_member_lock_);
913
914 CHECK(!com_thread_checker_);
915 CHECK(!port_manager_in_);
916 CHECK(!port_manager_out_);
917 CHECK(!scheduler_);
918}
919
920void MidiManagerWinrt::StartInitialization() {
shaochuane58f9c72016-08-30 22:27:08 -0700921 com_thread_.init_com_with_mta(true);
922 com_thread_.Start();
923
924 com_thread_.task_runner()->PostTask(
925 FROM_HERE, base::Bind(&MidiManagerWinrt::InitializeOnComThread,
926 base::Unretained(this)));
927}
928
929void MidiManagerWinrt::Finalize() {
930 com_thread_.task_runner()->PostTask(
931 FROM_HERE, base::Bind(&MidiManagerWinrt::FinalizeOnComThread,
932 base::Unretained(this)));
933
934 // Blocks until FinalizeOnComThread() returns. Delayed MIDI send data tasks
935 // will be ignored.
936 com_thread_.Stop();
937}
938
939void MidiManagerWinrt::DispatchSendMidiData(MidiManagerClient* client,
940 uint32_t port_index,
941 const std::vector<uint8_t>& data,
942 double timestamp) {
943 CHECK(scheduler_);
944
945 scheduler_->PostSendDataTask(
946 client, data.size(), timestamp,
947 base::Bind(&MidiManagerWinrt::SendOnComThread, base::Unretained(this),
948 port_index, data));
949}
950
951void MidiManagerWinrt::InitializeOnComThread() {
952 base::AutoLock auto_lock(lazy_init_member_lock_);
953
954 com_thread_checker_.reset(new base::ThreadChecker);
955
shaochuan9ff63b82016-09-01 01:58:44 -0700956 if (!g_combase_functions.Get().LoadFunctions()) {
957 VLOG(1) << "Failed loading functions from combase.dll: "
958 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
959 CompleteInitialization(Result::INITIALIZATION_ERROR);
960 return;
961 }
962
shaochuane58f9c72016-08-30 22:27:08 -0700963 port_manager_in_.reset(new MidiInPortManager(this));
964 port_manager_out_.reset(new MidiOutPortManager(this));
965
966 scheduler_.reset(new MidiScheduler(this));
967
968 if (!(port_manager_in_->StartWatcher() &&
969 port_manager_out_->StartWatcher())) {
970 port_manager_in_->StopWatcher();
971 port_manager_out_->StopWatcher();
972 CompleteInitialization(Result::INITIALIZATION_ERROR);
973 }
974}
975
976void MidiManagerWinrt::FinalizeOnComThread() {
977 base::AutoLock auto_lock(lazy_init_member_lock_);
978
979 DCHECK(com_thread_checker_->CalledOnValidThread());
980
981 scheduler_.reset();
982
shaochuan9ff63b82016-09-01 01:58:44 -0700983 if (port_manager_in_) {
984 port_manager_in_->StopWatcher();
985 port_manager_in_.reset();
986 }
987
988 if (port_manager_out_) {
989 port_manager_out_->StopWatcher();
990 port_manager_out_.reset();
991 }
shaochuane58f9c72016-08-30 22:27:08 -0700992
993 com_thread_checker_.reset();
994}
995
996void MidiManagerWinrt::SendOnComThread(uint32_t port_index,
997 const std::vector<uint8_t>& data) {
998 DCHECK(com_thread_checker_->CalledOnValidThread());
999
1000 MidiPort<IMidiOutPort>* port = port_manager_out_->GetPortByIndex(port_index);
1001 if (!(port && port->handle)) {
1002 VLOG(1) << "Port not available: " << port_index;
1003 return;
1004 }
1005
1006 auto buffer_factory =
1007 WrlStaticsFactory<IBufferFactory,
1008 RuntimeClass_Windows_Storage_Streams_Buffer>();
1009 if (!buffer_factory)
1010 return;
1011
1012 ScopedComPtr<IBuffer> buffer;
1013 HRESULT hr = buffer_factory->Create(static_cast<UINT32>(data.size()),
1014 buffer.Receive());
1015 if (FAILED(hr)) {
1016 VLOG(1) << "Create failed: " << PrintHr(hr);
1017 return;
1018 }
1019
1020 hr = buffer->put_Length(static_cast<UINT32>(data.size()));
1021 if (FAILED(hr)) {
1022 VLOG(1) << "put_Length failed: " << PrintHr(hr);
1023 return;
1024 }
1025
1026 uint8_t* p_buffer_data = nullptr;
1027 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
1028 if (FAILED(hr))
1029 return;
1030
1031 std::copy(data.begin(), data.end(), p_buffer_data);
1032
1033 hr = port->handle->SendBuffer(buffer.get());
1034 if (FAILED(hr)) {
1035 VLOG(1) << "SendBuffer failed: " << PrintHr(hr);
1036 return;
1037 }
1038}
1039
1040void MidiManagerWinrt::OnPortManagerReady() {
1041 DCHECK(com_thread_checker_->CalledOnValidThread());
1042 DCHECK(port_manager_ready_count_ < 2);
1043
1044 if (++port_manager_ready_count_ == 2)
1045 CompleteInitialization(Result::OK);
1046}
1047
shaochuane58f9c72016-08-30 22:27:08 -07001048} // namespace midi
1049} // namespace media