blob: a5141a6ce4d00681f98c1479ff5c25480f69b860 [file] [log] [blame]
// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "media/midi/midi_manager_winrt.h"
6
shaochuan80f1fba2016-09-01 20:44:51 -07007// TODO(shaochuan): Remove this once clang supports uuid syntax in <robuffer.h>.
8// https://reviews.llvm.org/D23895
9namespace Windows {
10namespace Storage {
11namespace Streams {
qiankun.miao53f2d662016-09-02 17:44:08 -070012#pragma warning(disable : 4467)
shaochuan80f1fba2016-09-01 20:44:51 -070013struct __declspec(uuid("905a0fef-bc53-11df-8c49-001e4fc686da"))
14 IBufferByteAccess;
15}
16}
17}
18
shaochuane58f9c72016-08-30 22:27:08 -070019#include <comdef.h>
20#include <robuffer.h>
21#include <windows.devices.enumeration.h>
22#include <windows.devices.midi.h>
23#include <wrl/event.h>
24
25#include <iomanip>
26#include <unordered_map>
27#include <unordered_set>
28
29#include "base/bind.h"
shaochuan9ff63b82016-09-01 01:58:44 -070030#include "base/lazy_instance.h"
31#include "base/scoped_generic.h"
shaochuane58f9c72016-08-30 22:27:08 -070032#include "base/strings/utf_string_conversions.h"
33#include "base/threading/thread_checker.h"
34#include "base/threading/thread_task_runner_handle.h"
35#include "base/timer/timer.h"
36#include "base/win/scoped_comptr.h"
shaochuane58f9c72016-08-30 22:27:08 -070037#include "media/midi/midi_scheduler.h"
38
39namespace media {
40namespace midi {
41namespace {
42
43namespace WRL = Microsoft::WRL;
44
45using namespace ABI::Windows::Devices::Enumeration;
46using namespace ABI::Windows::Devices::Midi;
47using namespace ABI::Windows::Foundation;
48using namespace ABI::Windows::Storage::Streams;
49
50using base::win::ScopedComPtr;
51
52// Helpers for printing HRESULTs.
53struct PrintHr {
54 PrintHr(HRESULT hr) : hr(hr) {}
55 HRESULT hr;
56};
57
58std::ostream& operator<<(std::ostream& os, const PrintHr& phr) {
59 std::ios_base::fmtflags ff = os.flags();
60 os << _com_error(phr.hr).ErrorMessage() << " (0x" << std::hex
61 << std::uppercase << std::setfill('0') << std::setw(8) << phr.hr << ")";
62 os.flags(ff);
63 return os;
64}
65
shaochuan9ff63b82016-09-01 01:58:44 -070066// Provides access to functions in combase.dll which may not be available on
67// Windows 7. Loads functions dynamically at runtime to prevent library
68// dependencies. Use this class through the global LazyInstance
69// |g_combase_functions|.
70class CombaseFunctions {
71 public:
72 CombaseFunctions() = default;
73
74 ~CombaseFunctions() {
75 if (combase_dll_)
76 ::FreeLibrary(combase_dll_);
77 }
78
79 bool LoadFunctions() {
80 combase_dll_ = ::LoadLibrary(L"combase.dll");
81 if (!combase_dll_)
82 return false;
83
84 get_factory_func_ = reinterpret_cast<decltype(&::RoGetActivationFactory)>(
85 ::GetProcAddress(combase_dll_, "RoGetActivationFactory"));
86 if (!get_factory_func_)
87 return false;
88
89 create_string_func_ = reinterpret_cast<decltype(&::WindowsCreateString)>(
90 ::GetProcAddress(combase_dll_, "WindowsCreateString"));
91 if (!create_string_func_)
92 return false;
93
94 delete_string_func_ = reinterpret_cast<decltype(&::WindowsDeleteString)>(
95 ::GetProcAddress(combase_dll_, "WindowsDeleteString"));
96 if (!delete_string_func_)
97 return false;
98
99 get_string_raw_buffer_func_ =
100 reinterpret_cast<decltype(&::WindowsGetStringRawBuffer)>(
101 ::GetProcAddress(combase_dll_, "WindowsGetStringRawBuffer"));
102 if (!get_string_raw_buffer_func_)
103 return false;
104
105 return true;
106 }
107
108 HRESULT RoGetActivationFactory(HSTRING class_id,
109 const IID& iid,
110 void** out_factory) {
111 DCHECK(get_factory_func_);
112 return get_factory_func_(class_id, iid, out_factory);
113 }
114
115 HRESULT WindowsCreateString(const base::char16* src,
116 uint32_t len,
117 HSTRING* out_hstr) {
118 DCHECK(create_string_func_);
119 return create_string_func_(src, len, out_hstr);
120 }
121
122 HRESULT WindowsDeleteString(HSTRING hstr) {
123 DCHECK(delete_string_func_);
124 return delete_string_func_(hstr);
125 }
126
127 const base::char16* WindowsGetStringRawBuffer(HSTRING hstr,
128 uint32_t* out_len) {
129 DCHECK(get_string_raw_buffer_func_);
130 return get_string_raw_buffer_func_(hstr, out_len);
131 }
132
133 private:
134 HMODULE combase_dll_ = nullptr;
135
136 decltype(&::RoGetActivationFactory) get_factory_func_ = nullptr;
137 decltype(&::WindowsCreateString) create_string_func_ = nullptr;
138 decltype(&::WindowsDeleteString) delete_string_func_ = nullptr;
139 decltype(&::WindowsGetStringRawBuffer) get_string_raw_buffer_func_ = nullptr;
140};
141
142base::LazyInstance<CombaseFunctions> g_combase_functions =
143 LAZY_INSTANCE_INITIALIZER;
144
145// Scoped HSTRING class to maintain lifetime of HSTRINGs allocated with
146// WindowsCreateString().
147class ScopedHStringTraits {
148 public:
149 static HSTRING InvalidValue() { return nullptr; }
150
151 static void Free(HSTRING hstr) {
152 g_combase_functions.Get().WindowsDeleteString(hstr);
153 }
154};
155
156class ScopedHString : public base::ScopedGeneric<HSTRING, ScopedHStringTraits> {
157 public:
158 explicit ScopedHString(const base::char16* str) : ScopedGeneric(nullptr) {
159 HSTRING hstr;
160 HRESULT hr = g_combase_functions.Get().WindowsCreateString(
161 str, static_cast<uint32_t>(wcslen(str)), &hstr);
162 if (FAILED(hr))
163 VLOG(1) << "WindowsCreateString failed: " << PrintHr(hr);
164 else
165 reset(hstr);
166 }
167};
168
shaochuane58f9c72016-08-30 22:27:08 -0700169// Factory functions that activate and create WinRT components. The caller takes
170// ownership of the returning ComPtr.
171template <typename InterfaceType, base::char16 const* runtime_class_id>
172ScopedComPtr<InterfaceType> WrlStaticsFactory() {
173 ScopedComPtr<InterfaceType> com_ptr;
174
shaochuan9ff63b82016-09-01 01:58:44 -0700175 ScopedHString class_id_hstring(runtime_class_id);
176 if (!class_id_hstring.is_valid()) {
177 com_ptr = nullptr;
178 return com_ptr;
179 }
180
181 HRESULT hr = g_combase_functions.Get().RoGetActivationFactory(
182 class_id_hstring.get(), __uuidof(InterfaceType), com_ptr.ReceiveVoid());
shaochuane58f9c72016-08-30 22:27:08 -0700183 if (FAILED(hr)) {
shaochuan9ff63b82016-09-01 01:58:44 -0700184 VLOG(1) << "RoGetActivationFactory failed: " << PrintHr(hr);
shaochuane58f9c72016-08-30 22:27:08 -0700185 com_ptr = nullptr;
186 }
187
188 return com_ptr;
189}
190
shaochuan80f1fba2016-09-01 20:44:51 -0700191std::string HStringToString(HSTRING hstr) {
shaochuane58f9c72016-08-30 22:27:08 -0700192 // Note: empty HSTRINGs are represent as nullptr, and instantiating
193 // std::string with nullptr (in base::WideToUTF8) is undefined behavior.
shaochuan9ff63b82016-09-01 01:58:44 -0700194 const base::char16* buffer =
shaochuan80f1fba2016-09-01 20:44:51 -0700195 g_combase_functions.Get().WindowsGetStringRawBuffer(hstr, nullptr);
shaochuane58f9c72016-08-30 22:27:08 -0700196 if (buffer)
197 return base::WideToUTF8(buffer);
198 return std::string();
199}
200
201template <typename T>
202std::string GetIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700203 HSTRING result;
204 HRESULT hr = obj->get_Id(&result);
205 if (FAILED(hr)) {
206 VLOG(1) << "get_Id failed: " << PrintHr(hr);
207 return std::string();
208 }
209 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700210}
211
212template <typename T>
213std::string GetDeviceIdString(T* obj) {
shaochuan80f1fba2016-09-01 20:44:51 -0700214 HSTRING result;
215 HRESULT hr = obj->get_DeviceId(&result);
216 if (FAILED(hr)) {
217 VLOG(1) << "get_DeviceId failed: " << PrintHr(hr);
218 return std::string();
219 }
220 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700221}
222
223std::string GetNameString(IDeviceInformation* info) {
shaochuan80f1fba2016-09-01 20:44:51 -0700224 HSTRING result;
225 HRESULT hr = info->get_Name(&result);
226 if (FAILED(hr)) {
227 VLOG(1) << "get_Name failed: " << PrintHr(hr);
228 return std::string();
229 }
230 return HStringToString(result);
shaochuane58f9c72016-08-30 22:27:08 -0700231}
232
233HRESULT GetPointerToBufferData(IBuffer* buffer, uint8_t** out) {
234 ScopedComPtr<Windows::Storage::Streams::IBufferByteAccess> buffer_byte_access;
235
236 HRESULT hr = buffer_byte_access.QueryFrom(buffer);
237 if (FAILED(hr)) {
238 VLOG(1) << "QueryInterface failed: " << PrintHr(hr);
239 return hr;
240 }
241
242 // Lifetime of the pointing buffer is controlled by the buffer object.
243 hr = buffer_byte_access->Buffer(out);
244 if (FAILED(hr)) {
245 VLOG(1) << "Buffer failed: " << PrintHr(hr);
246 return hr;
247 }
248
249 return S_OK;
250}
251
shaochuan110262b2016-08-31 02:15:16 -0700252// Checks if given DeviceInformation represent a Microsoft GS Wavetable Synth
253// instance.
254bool IsMicrosoftSynthesizer(IDeviceInformation* info) {
255 auto midi_synthesizer_statics =
256 WrlStaticsFactory<IMidiSynthesizerStatics,
257 RuntimeClass_Windows_Devices_Midi_MidiSynthesizer>();
258 boolean result = FALSE;
259 HRESULT hr = midi_synthesizer_statics->IsSynthesizer(info, &result);
260 VLOG_IF(1, FAILED(hr)) << "IsSynthesizer failed: " << PrintHr(hr);
261 return result != FALSE;
262}
263
shaochuane58f9c72016-08-30 22:27:08 -0700264// Tokens with value = 0 are considered invalid (as in <wrl/event.h>).
265const int64_t kInvalidTokenValue = 0;
266
267template <typename InterfaceType>
268struct MidiPort {
269 MidiPort() = default;
270
271 uint32_t index;
272 ScopedComPtr<InterfaceType> handle;
273 EventRegistrationToken token_MessageReceived;
274
275 private:
276 DISALLOW_COPY_AND_ASSIGN(MidiPort);
277};
278
279} // namespace
280
281template <typename InterfaceType,
282 typename RuntimeType,
283 typename StaticsInterfaceType,
284 base::char16 const* runtime_class_id>
285class MidiManagerWinrt::MidiPortManager {
286 public:
287 // MidiPortManager instances should be constructed on the COM thread.
288 MidiPortManager(MidiManagerWinrt* midi_manager)
289 : midi_manager_(midi_manager),
290 task_runner_(base::ThreadTaskRunnerHandle::Get()) {}
291
292 virtual ~MidiPortManager() { DCHECK(thread_checker_.CalledOnValidThread()); }
293
294 bool StartWatcher() {
295 DCHECK(thread_checker_.CalledOnValidThread());
296
297 HRESULT hr;
298
299 midi_port_statics_ =
300 WrlStaticsFactory<StaticsInterfaceType, runtime_class_id>();
301 if (!midi_port_statics_)
302 return false;
303
304 HSTRING device_selector = nullptr;
305 hr = midi_port_statics_->GetDeviceSelector(&device_selector);
306 if (FAILED(hr)) {
307 VLOG(1) << "GetDeviceSelector failed: " << PrintHr(hr);
308 return false;
309 }
310
311 auto dev_info_statics = WrlStaticsFactory<
312 IDeviceInformationStatics,
313 RuntimeClass_Windows_Devices_Enumeration_DeviceInformation>();
314 if (!dev_info_statics)
315 return false;
316
317 hr = dev_info_statics->CreateWatcherAqsFilter(device_selector,
318 watcher_.Receive());
319 if (FAILED(hr)) {
320 VLOG(1) << "CreateWatcherAqsFilter failed: " << PrintHr(hr);
321 return false;
322 }
323
324 // Register callbacks to WinRT that post state-modifying jobs back to COM
325 // thread. |weak_ptr| and |task_runner| are captured by lambda callbacks for
326 // posting jobs. Note that WinRT callback arguments should not be passed
327 // outside the callback since the pointers may be unavailable afterwards.
328 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
329 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
330
331 hr = watcher_->add_Added(
332 WRL::Callback<ITypedEventHandler<DeviceWatcher*, DeviceInformation*>>(
333 [weak_ptr, task_runner](IDeviceWatcher* watcher,
334 IDeviceInformation* info) {
shaochuan110262b2016-08-31 02:15:16 -0700335 // Disable Microsoft GS Wavetable Synth due to security reasons.
336 // http://crbug.com/499279
337 if (IsMicrosoftSynthesizer(info))
338 return S_OK;
339
shaochuane58f9c72016-08-30 22:27:08 -0700340 std::string dev_id = GetIdString(info),
341 dev_name = GetNameString(info);
342
343 task_runner->PostTask(
344 FROM_HERE, base::Bind(&MidiPortManager::OnAdded, weak_ptr,
345 dev_id, dev_name));
346
347 return S_OK;
348 })
349 .Get(),
350 &token_Added_);
351 if (FAILED(hr)) {
352 VLOG(1) << "add_Added failed: " << PrintHr(hr);
353 return false;
354 }
355
356 hr = watcher_->add_EnumerationCompleted(
357 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
358 [weak_ptr, task_runner](IDeviceWatcher* watcher,
359 IInspectable* insp) {
360 task_runner->PostTask(
361 FROM_HERE,
362 base::Bind(&MidiPortManager::OnEnumerationCompleted,
363 weak_ptr));
364
365 return S_OK;
366 })
367 .Get(),
368 &token_EnumerationCompleted_);
369 if (FAILED(hr)) {
370 VLOG(1) << "add_EnumerationCompleted failed: " << PrintHr(hr);
371 return false;
372 }
373
374 hr = watcher_->add_Removed(
375 WRL::Callback<
376 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
377 [weak_ptr, task_runner](IDeviceWatcher* watcher,
378 IDeviceInformationUpdate* update) {
379 std::string dev_id = GetIdString(update);
380
381 task_runner->PostTask(
382 FROM_HERE,
383 base::Bind(&MidiPortManager::OnRemoved, weak_ptr, dev_id));
384
385 return S_OK;
386 })
387 .Get(),
388 &token_Removed_);
389 if (FAILED(hr)) {
390 VLOG(1) << "add_Removed failed: " << PrintHr(hr);
391 return false;
392 }
393
394 hr = watcher_->add_Stopped(
395 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
396 [](IDeviceWatcher* watcher, IInspectable* insp) {
397 // Placeholder, does nothing for now.
398 return S_OK;
399 })
400 .Get(),
401 &token_Stopped_);
402 if (FAILED(hr)) {
403 VLOG(1) << "add_Stopped failed: " << PrintHr(hr);
404 return false;
405 }
406
407 hr = watcher_->add_Updated(
408 WRL::Callback<
409 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
410 [](IDeviceWatcher* watcher, IDeviceInformationUpdate* update) {
411 // TODO(shaochuan): Check for fields to be updated here.
412 return S_OK;
413 })
414 .Get(),
415 &token_Updated_);
416 if (FAILED(hr)) {
417 VLOG(1) << "add_Updated failed: " << PrintHr(hr);
418 return false;
419 }
420
421 hr = watcher_->Start();
422 if (FAILED(hr)) {
423 VLOG(1) << "Start failed: " << PrintHr(hr);
424 return false;
425 }
426
427 is_initialized_ = true;
428 return true;
429 }
430
431 void StopWatcher() {
432 DCHECK(thread_checker_.CalledOnValidThread());
433
434 HRESULT hr;
435
436 for (const auto& entry : ports_)
437 RemovePortEventHandlers(entry.second.get());
438
439 if (token_Added_.value != kInvalidTokenValue) {
440 hr = watcher_->remove_Added(token_Added_);
441 VLOG_IF(1, FAILED(hr)) << "remove_Added failed: " << PrintHr(hr);
442 token_Added_.value = kInvalidTokenValue;
443 }
444 if (token_EnumerationCompleted_.value != kInvalidTokenValue) {
445 hr = watcher_->remove_EnumerationCompleted(token_EnumerationCompleted_);
446 VLOG_IF(1, FAILED(hr)) << "remove_EnumerationCompleted failed: "
447 << PrintHr(hr);
448 token_EnumerationCompleted_.value = kInvalidTokenValue;
449 }
450 if (token_Removed_.value != kInvalidTokenValue) {
451 hr = watcher_->remove_Removed(token_Removed_);
452 VLOG_IF(1, FAILED(hr)) << "remove_Removed failed: " << PrintHr(hr);
453 token_Removed_.value = kInvalidTokenValue;
454 }
455 if (token_Stopped_.value != kInvalidTokenValue) {
456 hr = watcher_->remove_Stopped(token_Stopped_);
457 VLOG_IF(1, FAILED(hr)) << "remove_Stopped failed: " << PrintHr(hr);
458 token_Stopped_.value = kInvalidTokenValue;
459 }
460 if (token_Updated_.value != kInvalidTokenValue) {
461 hr = watcher_->remove_Updated(token_Updated_);
462 VLOG_IF(1, FAILED(hr)) << "remove_Updated failed: " << PrintHr(hr);
463 token_Updated_.value = kInvalidTokenValue;
464 }
465
466 if (is_initialized_) {
467 hr = watcher_->Stop();
468 VLOG_IF(1, FAILED(hr)) << "Stop failed: " << PrintHr(hr);
469 is_initialized_ = false;
470 }
471 }
472
473 MidiPort<InterfaceType>* GetPortByDeviceId(std::string dev_id) {
474 DCHECK(thread_checker_.CalledOnValidThread());
475 CHECK(is_initialized_);
476
477 auto it = ports_.find(dev_id);
478 if (it == ports_.end())
479 return nullptr;
480 return it->second.get();
481 }
482
483 MidiPort<InterfaceType>* GetPortByIndex(uint32_t port_index) {
484 DCHECK(thread_checker_.CalledOnValidThread());
485 CHECK(is_initialized_);
486
487 return GetPortByDeviceId(port_ids_[port_index]);
488 }
489
490 protected:
491 // Points to the MidiManagerWinrt instance, which is expected to outlive the
492 // MidiPortManager instance.
493 MidiManagerWinrt* midi_manager_;
494
495 // Task runner of the COM thread.
496 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
497
498 // Ensures all methods are called on the COM thread.
499 base::ThreadChecker thread_checker_;
500
501 private:
502 // DeviceWatcher callbacks:
503 void OnAdded(std::string dev_id, std::string dev_name) {
504 DCHECK(thread_checker_.CalledOnValidThread());
505 CHECK(is_initialized_);
506
shaochuane58f9c72016-08-30 22:27:08 -0700507 port_names_[dev_id] = dev_name;
508
shaochuan9ff63b82016-09-01 01:58:44 -0700509 ScopedHString dev_id_hstring(base::UTF8ToWide(dev_id).c_str());
510 if (!dev_id_hstring.is_valid())
shaochuane58f9c72016-08-30 22:27:08 -0700511 return;
shaochuane58f9c72016-08-30 22:27:08 -0700512
513 IAsyncOperation<RuntimeType*>* async_op;
514
shaochuan9ff63b82016-09-01 01:58:44 -0700515 HRESULT hr =
516 midi_port_statics_->FromIdAsync(dev_id_hstring.get(), &async_op);
shaochuane58f9c72016-08-30 22:27:08 -0700517 if (FAILED(hr)) {
518 VLOG(1) << "FromIdAsync failed: " << PrintHr(hr);
519 return;
520 }
521
522 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
523 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
524
525 hr = async_op->put_Completed(
526 WRL::Callback<IAsyncOperationCompletedHandler<RuntimeType*>>(
527 [weak_ptr, task_runner](IAsyncOperation<RuntimeType*>* async_op,
528 AsyncStatus status) {
529 InterfaceType* handle;
530 HRESULT hr = async_op->GetResults(&handle);
531 if (FAILED(hr)) {
532 VLOG(1) << "GetResults failed: " << PrintHr(hr);
533 return hr;
534 }
535
536 // A reference to |async_op| is kept in |async_ops_|, safe to pass
537 // outside.
538 task_runner->PostTask(
539 FROM_HERE,
540 base::Bind(&MidiPortManager::OnCompletedGetPortFromIdAsync,
541 weak_ptr, handle, async_op));
542
543 return S_OK;
544 })
545 .Get());
546 if (FAILED(hr)) {
547 VLOG(1) << "put_Completed failed: " << PrintHr(hr);
548 return;
549 }
550
551 // Keep a reference to incompleted |async_op| for releasing later.
552 async_ops_.insert(async_op);
553 }
554
555 void OnEnumerationCompleted() {
556 DCHECK(thread_checker_.CalledOnValidThread());
557 CHECK(is_initialized_);
558
559 if (async_ops_.empty())
560 midi_manager_->OnPortManagerReady();
561 else
562 enumeration_completed_not_ready_ = true;
563 }
564
565 void OnRemoved(std::string dev_id) {
566 DCHECK(thread_checker_.CalledOnValidThread());
567 CHECK(is_initialized_);
568
shaochuan110262b2016-08-31 02:15:16 -0700569 // Note: in case Microsoft GS Wavetable Synth triggers this event for some
570 // reason, it will be ignored here with log emitted.
shaochuane58f9c72016-08-30 22:27:08 -0700571 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
572 if (!port) {
573 VLOG(1) << "Removing non-existent port " << dev_id;
574 return;
575 }
576
577 SetPortState(port->index, MIDI_PORT_DISCONNECTED);
578
579 RemovePortEventHandlers(port);
580 port->handle = nullptr;
581 }
582
583 void OnCompletedGetPortFromIdAsync(InterfaceType* handle,
584 IAsyncOperation<RuntimeType*>* async_op) {
585 DCHECK(thread_checker_.CalledOnValidThread());
586 CHECK(is_initialized_);
587
588 EventRegistrationToken token = {kInvalidTokenValue};
589 if (!RegisterOnMessageReceived(handle, &token))
590 return;
591
592 std::string dev_id = GetDeviceIdString(handle);
593
594 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
595
596 if (port == nullptr) {
597 // TODO(crbug.com/642604): Fill in manufacturer and driver version.
598 AddPort(MidiPortInfo(dev_id, std::string("Manufacturer"),
599 port_names_[dev_id], std::string("DriverVersion"),
600 MIDI_PORT_OPENED));
601
602 port = new MidiPort<InterfaceType>;
603 port->index = static_cast<uint32_t>(port_ids_.size());
604
605 ports_[dev_id].reset(port);
606 port_ids_.push_back(dev_id);
607 } else {
608 SetPortState(port->index, MIDI_PORT_CONNECTED);
609 }
610
611 port->handle = handle;
612 port->token_MessageReceived = token;
613
614 // Manually release COM interface to completed |async_op|.
615 auto it = async_ops_.find(async_op);
616 CHECK(it != async_ops_.end());
617 (*it)->Release();
618 async_ops_.erase(it);
619
620 if (enumeration_completed_not_ready_ && async_ops_.empty()) {
621 midi_manager_->OnPortManagerReady();
622 enumeration_completed_not_ready_ = false;
623 }
624 }
625
626 // Overrided by MidiInPortManager to listen to input ports.
627 virtual bool RegisterOnMessageReceived(InterfaceType* handle,
628 EventRegistrationToken* p_token) {
629 return true;
630 }
631
632 // Overrided by MidiInPortManager to remove MessageReceived event handler.
633 virtual void RemovePortEventHandlers(MidiPort<InterfaceType>* port) {}
634
635 // Calls midi_manager_->Add{Input,Output}Port.
636 virtual void AddPort(MidiPortInfo info) = 0;
637
638 // Calls midi_manager_->Set{Input,Output}PortState.
639 virtual void SetPortState(uint32_t port_index, MidiPortState state) = 0;
640
641 // WeakPtrFactory has to be declared in derived class, use this method to
642 // retrieve upcasted WeakPtr for posting tasks.
643 virtual base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() = 0;
644
645 // Midi{In,Out}PortStatics instance.
646 ScopedComPtr<StaticsInterfaceType> midi_port_statics_;
647
648 // DeviceWatcher instance and event registration tokens for unsubscribing
649 // events in destructor.
650 ScopedComPtr<IDeviceWatcher> watcher_;
651 EventRegistrationToken token_Added_ = {kInvalidTokenValue},
652 token_EnumerationCompleted_ = {kInvalidTokenValue},
653 token_Removed_ = {kInvalidTokenValue},
654 token_Stopped_ = {kInvalidTokenValue},
655 token_Updated_ = {kInvalidTokenValue};
656
657 // All manipulations to these fields should be done on COM thread.
658 std::unordered_map<std::string, std::unique_ptr<MidiPort<InterfaceType>>>
659 ports_;
660 std::vector<std::string> port_ids_;
661 std::unordered_map<std::string, std::string> port_names_;
662
663 // Keeps AsyncOperation references before the operation completes. Note that
664 // raw pointers are used here and the COM interfaces should be released
665 // manually.
666 std::unordered_set<IAsyncOperation<RuntimeType*>*> async_ops_;
667
668 // Set when device enumeration is completed but OnPortManagerReady() is not
669 // called since some ports are not yet ready (i.e. |async_ops_| is not empty).
670 // In such cases, OnPortManagerReady() will be called in
671 // OnCompletedGetPortFromIdAsync() when the last pending port is ready.
672 bool enumeration_completed_not_ready_ = false;
673
674 // Set if the instance is initialized without error. Should be checked in all
675 // methods on COM thread except StartWatcher().
676 bool is_initialized_ = false;
677};
678
679class MidiManagerWinrt::MidiInPortManager final
680 : public MidiPortManager<IMidiInPort,
681 MidiInPort,
682 IMidiInPortStatics,
683 RuntimeClass_Windows_Devices_Midi_MidiInPort> {
684 public:
685 MidiInPortManager(MidiManagerWinrt* midi_manager)
686 : MidiPortManager(midi_manager), weak_factory_(this) {}
687
688 private:
689 // MidiPortManager overrides:
690 bool RegisterOnMessageReceived(IMidiInPort* handle,
691 EventRegistrationToken* p_token) override {
692 DCHECK(thread_checker_.CalledOnValidThread());
693
694 base::WeakPtr<MidiInPortManager> weak_ptr = weak_factory_.GetWeakPtr();
695 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
696
697 HRESULT hr = handle->add_MessageReceived(
698 WRL::Callback<
699 ITypedEventHandler<MidiInPort*, MidiMessageReceivedEventArgs*>>(
700 [weak_ptr, task_runner](IMidiInPort* handle,
701 IMidiMessageReceivedEventArgs* args) {
702 const base::TimeTicks now = base::TimeTicks::Now();
703
704 std::string dev_id = GetDeviceIdString(handle);
705
706 ScopedComPtr<IMidiMessage> message;
707 HRESULT hr = args->get_Message(message.Receive());
708 if (FAILED(hr)) {
709 VLOG(1) << "get_Message failed: " << PrintHr(hr);
710 return hr;
711 }
712
713 ScopedComPtr<IBuffer> buffer;
714 hr = message->get_RawData(buffer.Receive());
715 if (FAILED(hr)) {
716 VLOG(1) << "get_RawData failed: " << PrintHr(hr);
717 return hr;
718 }
719
720 uint8_t* p_buffer_data = nullptr;
721 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
722 if (FAILED(hr))
723 return hr;
724
725 uint32_t data_length = 0;
726 hr = buffer->get_Length(&data_length);
727 if (FAILED(hr)) {
728 VLOG(1) << "get_Length failed: " << PrintHr(hr);
729 return hr;
730 }
731
732 std::vector<uint8_t> data(p_buffer_data,
733 p_buffer_data + data_length);
734
735 task_runner->PostTask(
736 FROM_HERE, base::Bind(&MidiInPortManager::OnMessageReceived,
737 weak_ptr, dev_id, data, now));
738
739 return S_OK;
740 })
741 .Get(),
742 p_token);
743 if (FAILED(hr)) {
744 VLOG(1) << "add_MessageReceived failed: " << PrintHr(hr);
745 return false;
746 }
747
748 return true;
749 }
750
751 void RemovePortEventHandlers(MidiPort<IMidiInPort>* port) override {
752 if (!(port->handle &&
753 port->token_MessageReceived.value != kInvalidTokenValue))
754 return;
755
756 HRESULT hr =
757 port->handle->remove_MessageReceived(port->token_MessageReceived);
758 VLOG_IF(1, FAILED(hr)) << "remove_MessageReceived failed: " << PrintHr(hr);
759 port->token_MessageReceived.value = kInvalidTokenValue;
760 }
761
762 void AddPort(MidiPortInfo info) final { midi_manager_->AddInputPort(info); }
763
764 void SetPortState(uint32_t port_index, MidiPortState state) final {
765 midi_manager_->SetInputPortState(port_index, state);
766 }
767
768 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
769 DCHECK(thread_checker_.CalledOnValidThread());
770
771 return weak_factory_.GetWeakPtr();
772 }
773
774 // Callback on receiving MIDI input message.
775 void OnMessageReceived(std::string dev_id,
776 std::vector<uint8_t> data,
777 base::TimeTicks time) {
778 DCHECK(thread_checker_.CalledOnValidThread());
779
780 MidiPort<IMidiInPort>* port = GetPortByDeviceId(dev_id);
781 CHECK(port);
782
783 midi_manager_->ReceiveMidiData(port->index, &data[0], data.size(), time);
784 }
785
786 // Last member to ensure destructed first.
787 base::WeakPtrFactory<MidiInPortManager> weak_factory_;
788
789 DISALLOW_COPY_AND_ASSIGN(MidiInPortManager);
790};
791
792class MidiManagerWinrt::MidiOutPortManager final
793 : public MidiPortManager<IMidiOutPort,
794 IMidiOutPort,
795 IMidiOutPortStatics,
796 RuntimeClass_Windows_Devices_Midi_MidiOutPort> {
797 public:
798 MidiOutPortManager(MidiManagerWinrt* midi_manager)
799 : MidiPortManager(midi_manager), weak_factory_(this) {}
800
801 private:
802 // MidiPortManager overrides:
803 void AddPort(MidiPortInfo info) final { midi_manager_->AddOutputPort(info); }
804
805 void SetPortState(uint32_t port_index, MidiPortState state) final {
806 midi_manager_->SetOutputPortState(port_index, state);
807 }
808
809 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
810 DCHECK(thread_checker_.CalledOnValidThread());
811
812 return weak_factory_.GetWeakPtr();
813 }
814
815 // Last member to ensure destructed first.
816 base::WeakPtrFactory<MidiOutPortManager> weak_factory_;
817
818 DISALLOW_COPY_AND_ASSIGN(MidiOutPortManager);
819};
820
821MidiManagerWinrt::MidiManagerWinrt() : com_thread_("Windows MIDI COM Thread") {}
822
823MidiManagerWinrt::~MidiManagerWinrt() {
824 base::AutoLock auto_lock(lazy_init_member_lock_);
825
826 CHECK(!com_thread_checker_);
827 CHECK(!port_manager_in_);
828 CHECK(!port_manager_out_);
829 CHECK(!scheduler_);
830}
831
832void MidiManagerWinrt::StartInitialization() {
shaochuane58f9c72016-08-30 22:27:08 -0700833 com_thread_.init_com_with_mta(true);
834 com_thread_.Start();
835
836 com_thread_.task_runner()->PostTask(
837 FROM_HERE, base::Bind(&MidiManagerWinrt::InitializeOnComThread,
838 base::Unretained(this)));
839}
840
841void MidiManagerWinrt::Finalize() {
842 com_thread_.task_runner()->PostTask(
843 FROM_HERE, base::Bind(&MidiManagerWinrt::FinalizeOnComThread,
844 base::Unretained(this)));
845
846 // Blocks until FinalizeOnComThread() returns. Delayed MIDI send data tasks
847 // will be ignored.
848 com_thread_.Stop();
849}
850
851void MidiManagerWinrt::DispatchSendMidiData(MidiManagerClient* client,
852 uint32_t port_index,
853 const std::vector<uint8_t>& data,
854 double timestamp) {
855 CHECK(scheduler_);
856
857 scheduler_->PostSendDataTask(
858 client, data.size(), timestamp,
859 base::Bind(&MidiManagerWinrt::SendOnComThread, base::Unretained(this),
860 port_index, data));
861}
862
863void MidiManagerWinrt::InitializeOnComThread() {
864 base::AutoLock auto_lock(lazy_init_member_lock_);
865
866 com_thread_checker_.reset(new base::ThreadChecker);
867
shaochuan9ff63b82016-09-01 01:58:44 -0700868 if (!g_combase_functions.Get().LoadFunctions()) {
869 VLOG(1) << "Failed loading functions from combase.dll: "
870 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
871 CompleteInitialization(Result::INITIALIZATION_ERROR);
872 return;
873 }
874
shaochuane58f9c72016-08-30 22:27:08 -0700875 port_manager_in_.reset(new MidiInPortManager(this));
876 port_manager_out_.reset(new MidiOutPortManager(this));
877
878 scheduler_.reset(new MidiScheduler(this));
879
880 if (!(port_manager_in_->StartWatcher() &&
881 port_manager_out_->StartWatcher())) {
882 port_manager_in_->StopWatcher();
883 port_manager_out_->StopWatcher();
884 CompleteInitialization(Result::INITIALIZATION_ERROR);
885 }
886}
887
888void MidiManagerWinrt::FinalizeOnComThread() {
889 base::AutoLock auto_lock(lazy_init_member_lock_);
890
891 DCHECK(com_thread_checker_->CalledOnValidThread());
892
893 scheduler_.reset();
894
shaochuan9ff63b82016-09-01 01:58:44 -0700895 if (port_manager_in_) {
896 port_manager_in_->StopWatcher();
897 port_manager_in_.reset();
898 }
899
900 if (port_manager_out_) {
901 port_manager_out_->StopWatcher();
902 port_manager_out_.reset();
903 }
shaochuane58f9c72016-08-30 22:27:08 -0700904
905 com_thread_checker_.reset();
906}
907
908void MidiManagerWinrt::SendOnComThread(uint32_t port_index,
909 const std::vector<uint8_t>& data) {
910 DCHECK(com_thread_checker_->CalledOnValidThread());
911
912 MidiPort<IMidiOutPort>* port = port_manager_out_->GetPortByIndex(port_index);
913 if (!(port && port->handle)) {
914 VLOG(1) << "Port not available: " << port_index;
915 return;
916 }
917
918 auto buffer_factory =
919 WrlStaticsFactory<IBufferFactory,
920 RuntimeClass_Windows_Storage_Streams_Buffer>();
921 if (!buffer_factory)
922 return;
923
924 ScopedComPtr<IBuffer> buffer;
925 HRESULT hr = buffer_factory->Create(static_cast<UINT32>(data.size()),
926 buffer.Receive());
927 if (FAILED(hr)) {
928 VLOG(1) << "Create failed: " << PrintHr(hr);
929 return;
930 }
931
932 hr = buffer->put_Length(static_cast<UINT32>(data.size()));
933 if (FAILED(hr)) {
934 VLOG(1) << "put_Length failed: " << PrintHr(hr);
935 return;
936 }
937
938 uint8_t* p_buffer_data = nullptr;
939 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
940 if (FAILED(hr))
941 return;
942
943 std::copy(data.begin(), data.end(), p_buffer_data);
944
945 hr = port->handle->SendBuffer(buffer.get());
946 if (FAILED(hr)) {
947 VLOG(1) << "SendBuffer failed: " << PrintHr(hr);
948 return;
949 }
950}
951
952void MidiManagerWinrt::OnPortManagerReady() {
953 DCHECK(com_thread_checker_->CalledOnValidThread());
954 DCHECK(port_manager_ready_count_ < 2);
955
956 if (++port_manager_ready_count_ == 2)
957 CompleteInitialization(Result::OK);
958}
959
shaochuane58f9c72016-08-30 22:27:08 -0700960} // namespace midi
961} // namespace media