blob: b8400e944a6404235e3a223bdd528adcbb756bc1 [file] [log] [blame]
shaochuane58f9c72016-08-30 22:27:08 -07001// Copyright 2016 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "media/midi/midi_manager_winrt.h"
6
7#include <comdef.h>
8#include <robuffer.h>
9#include <windows.devices.enumeration.h>
10#include <windows.devices.midi.h>
11#include <wrl/event.h>
12
13#include <iomanip>
14#include <unordered_map>
15#include <unordered_set>
16
17#include "base/bind.h"
shaochuan9ff63b82016-09-01 01:58:44 -070018#include "base/lazy_instance.h"
19#include "base/scoped_generic.h"
shaochuane58f9c72016-08-30 22:27:08 -070020#include "base/strings/utf_string_conversions.h"
21#include "base/threading/thread_checker.h"
22#include "base/threading/thread_task_runner_handle.h"
23#include "base/timer/timer.h"
24#include "base/win/scoped_comptr.h"
25#include "base/win/windows_version.h"
26#include "media/midi/midi_scheduler.h"
27
28namespace media {
29namespace midi {
30namespace {
31
32namespace WRL = Microsoft::WRL;
33
34using namespace ABI::Windows::Devices::Enumeration;
35using namespace ABI::Windows::Devices::Midi;
36using namespace ABI::Windows::Foundation;
37using namespace ABI::Windows::Storage::Streams;
38
39using base::win::ScopedComPtr;
40
41// Helpers for printing HRESULTs.
42struct PrintHr {
43 PrintHr(HRESULT hr) : hr(hr) {}
44 HRESULT hr;
45};
46
47std::ostream& operator<<(std::ostream& os, const PrintHr& phr) {
48 std::ios_base::fmtflags ff = os.flags();
49 os << _com_error(phr.hr).ErrorMessage() << " (0x" << std::hex
50 << std::uppercase << std::setfill('0') << std::setw(8) << phr.hr << ")";
51 os.flags(ff);
52 return os;
53}
54
shaochuan9ff63b82016-09-01 01:58:44 -070055// Provides access to functions in combase.dll which may not be available on
56// Windows 7. Loads functions dynamically at runtime to prevent library
57// dependencies. Use this class through the global LazyInstance
58// |g_combase_functions|.
59class CombaseFunctions {
60 public:
61 CombaseFunctions() = default;
62
63 ~CombaseFunctions() {
64 if (combase_dll_)
65 ::FreeLibrary(combase_dll_);
66 }
67
68 bool LoadFunctions() {
69 combase_dll_ = ::LoadLibrary(L"combase.dll");
70 if (!combase_dll_)
71 return false;
72
73 get_factory_func_ = reinterpret_cast<decltype(&::RoGetActivationFactory)>(
74 ::GetProcAddress(combase_dll_, "RoGetActivationFactory"));
75 if (!get_factory_func_)
76 return false;
77
78 create_string_func_ = reinterpret_cast<decltype(&::WindowsCreateString)>(
79 ::GetProcAddress(combase_dll_, "WindowsCreateString"));
80 if (!create_string_func_)
81 return false;
82
83 delete_string_func_ = reinterpret_cast<decltype(&::WindowsDeleteString)>(
84 ::GetProcAddress(combase_dll_, "WindowsDeleteString"));
85 if (!delete_string_func_)
86 return false;
87
88 get_string_raw_buffer_func_ =
89 reinterpret_cast<decltype(&::WindowsGetStringRawBuffer)>(
90 ::GetProcAddress(combase_dll_, "WindowsGetStringRawBuffer"));
91 if (!get_string_raw_buffer_func_)
92 return false;
93
94 return true;
95 }
96
97 HRESULT RoGetActivationFactory(HSTRING class_id,
98 const IID& iid,
99 void** out_factory) {
100 DCHECK(get_factory_func_);
101 return get_factory_func_(class_id, iid, out_factory);
102 }
103
104 HRESULT WindowsCreateString(const base::char16* src,
105 uint32_t len,
106 HSTRING* out_hstr) {
107 DCHECK(create_string_func_);
108 return create_string_func_(src, len, out_hstr);
109 }
110
111 HRESULT WindowsDeleteString(HSTRING hstr) {
112 DCHECK(delete_string_func_);
113 return delete_string_func_(hstr);
114 }
115
116 const base::char16* WindowsGetStringRawBuffer(HSTRING hstr,
117 uint32_t* out_len) {
118 DCHECK(get_string_raw_buffer_func_);
119 return get_string_raw_buffer_func_(hstr, out_len);
120 }
121
122 private:
123 HMODULE combase_dll_ = nullptr;
124
125 decltype(&::RoGetActivationFactory) get_factory_func_ = nullptr;
126 decltype(&::WindowsCreateString) create_string_func_ = nullptr;
127 decltype(&::WindowsDeleteString) delete_string_func_ = nullptr;
128 decltype(&::WindowsGetStringRawBuffer) get_string_raw_buffer_func_ = nullptr;
129};
130
131base::LazyInstance<CombaseFunctions> g_combase_functions =
132 LAZY_INSTANCE_INITIALIZER;
133
134// Scoped HSTRING class to maintain lifetime of HSTRINGs allocated with
135// WindowsCreateString().
136class ScopedHStringTraits {
137 public:
138 static HSTRING InvalidValue() { return nullptr; }
139
140 static void Free(HSTRING hstr) {
141 g_combase_functions.Get().WindowsDeleteString(hstr);
142 }
143};
144
145class ScopedHString : public base::ScopedGeneric<HSTRING, ScopedHStringTraits> {
146 public:
147 explicit ScopedHString(const base::char16* str) : ScopedGeneric(nullptr) {
148 HSTRING hstr;
149 HRESULT hr = g_combase_functions.Get().WindowsCreateString(
150 str, static_cast<uint32_t>(wcslen(str)), &hstr);
151 if (FAILED(hr))
152 VLOG(1) << "WindowsCreateString failed: " << PrintHr(hr);
153 else
154 reset(hstr);
155 }
156};
157
shaochuane58f9c72016-08-30 22:27:08 -0700158// Factory functions that activate and create WinRT components. The caller takes
159// ownership of the returning ComPtr.
160template <typename InterfaceType, base::char16 const* runtime_class_id>
161ScopedComPtr<InterfaceType> WrlStaticsFactory() {
162 ScopedComPtr<InterfaceType> com_ptr;
163
shaochuan9ff63b82016-09-01 01:58:44 -0700164 ScopedHString class_id_hstring(runtime_class_id);
165 if (!class_id_hstring.is_valid()) {
166 com_ptr = nullptr;
167 return com_ptr;
168 }
169
170 HRESULT hr = g_combase_functions.Get().RoGetActivationFactory(
171 class_id_hstring.get(), __uuidof(InterfaceType), com_ptr.ReceiveVoid());
shaochuane58f9c72016-08-30 22:27:08 -0700172 if (FAILED(hr)) {
shaochuan9ff63b82016-09-01 01:58:44 -0700173 VLOG(1) << "RoGetActivationFactory failed: " << PrintHr(hr);
shaochuane58f9c72016-08-30 22:27:08 -0700174 com_ptr = nullptr;
175 }
176
177 return com_ptr;
178}
179
180template <typename T, HRESULT (T::*method)(HSTRING*)>
181std::string GetStringFromObjectMethod(T* obj) {
182 HSTRING result;
183 HRESULT hr = (obj->*method)(&result);
184 if (FAILED(hr)) {
185 VLOG(1) << "GetStringFromObjectMethod failed: " << PrintHr(hr);
186 return std::string();
187 }
188
189 // Note: empty HSTRINGs are represent as nullptr, and instantiating
190 // std::string with nullptr (in base::WideToUTF8) is undefined behavior.
shaochuan9ff63b82016-09-01 01:58:44 -0700191 const base::char16* buffer =
192 g_combase_functions.Get().WindowsGetStringRawBuffer(result, nullptr);
shaochuane58f9c72016-08-30 22:27:08 -0700193 if (buffer)
194 return base::WideToUTF8(buffer);
195 return std::string();
196}
197
198template <typename T>
199std::string GetIdString(T* obj) {
200 return GetStringFromObjectMethod<T, &T::get_Id>(obj);
201}
202
203template <typename T>
204std::string GetDeviceIdString(T* obj) {
205 return GetStringFromObjectMethod<T, &T::get_DeviceId>(obj);
206}
207
208std::string GetNameString(IDeviceInformation* info) {
209 return GetStringFromObjectMethod<IDeviceInformation,
210 &IDeviceInformation::get_Name>(info);
211}
212
213HRESULT GetPointerToBufferData(IBuffer* buffer, uint8_t** out) {
214 ScopedComPtr<Windows::Storage::Streams::IBufferByteAccess> buffer_byte_access;
215
216 HRESULT hr = buffer_byte_access.QueryFrom(buffer);
217 if (FAILED(hr)) {
218 VLOG(1) << "QueryInterface failed: " << PrintHr(hr);
219 return hr;
220 }
221
222 // Lifetime of the pointing buffer is controlled by the buffer object.
223 hr = buffer_byte_access->Buffer(out);
224 if (FAILED(hr)) {
225 VLOG(1) << "Buffer failed: " << PrintHr(hr);
226 return hr;
227 }
228
229 return S_OK;
230}
231
shaochuan110262b2016-08-31 02:15:16 -0700232// Checks if given DeviceInformation represent a Microsoft GS Wavetable Synth
233// instance.
234bool IsMicrosoftSynthesizer(IDeviceInformation* info) {
235 auto midi_synthesizer_statics =
236 WrlStaticsFactory<IMidiSynthesizerStatics,
237 RuntimeClass_Windows_Devices_Midi_MidiSynthesizer>();
238 boolean result = FALSE;
239 HRESULT hr = midi_synthesizer_statics->IsSynthesizer(info, &result);
240 VLOG_IF(1, FAILED(hr)) << "IsSynthesizer failed: " << PrintHr(hr);
241 return result != FALSE;
242}
243
// An EventRegistrationToken whose value is 0 is treated as unset/invalid,
// following the convention in <wrl/event.h>.
constexpr int64_t kInvalidTokenValue = 0;
246
247template <typename InterfaceType>
248struct MidiPort {
249 MidiPort() = default;
250
251 uint32_t index;
252 ScopedComPtr<InterfaceType> handle;
253 EventRegistrationToken token_MessageReceived;
254
255 private:
256 DISALLOW_COPY_AND_ASSIGN(MidiPort);
257};
258
259} // namespace
260
261template <typename InterfaceType,
262 typename RuntimeType,
263 typename StaticsInterfaceType,
264 base::char16 const* runtime_class_id>
265class MidiManagerWinrt::MidiPortManager {
266 public:
267 // MidiPortManager instances should be constructed on the COM thread.
268 MidiPortManager(MidiManagerWinrt* midi_manager)
269 : midi_manager_(midi_manager),
270 task_runner_(base::ThreadTaskRunnerHandle::Get()) {}
271
272 virtual ~MidiPortManager() { DCHECK(thread_checker_.CalledOnValidThread()); }
273
274 bool StartWatcher() {
275 DCHECK(thread_checker_.CalledOnValidThread());
276
277 HRESULT hr;
278
279 midi_port_statics_ =
280 WrlStaticsFactory<StaticsInterfaceType, runtime_class_id>();
281 if (!midi_port_statics_)
282 return false;
283
284 HSTRING device_selector = nullptr;
285 hr = midi_port_statics_->GetDeviceSelector(&device_selector);
286 if (FAILED(hr)) {
287 VLOG(1) << "GetDeviceSelector failed: " << PrintHr(hr);
288 return false;
289 }
290
291 auto dev_info_statics = WrlStaticsFactory<
292 IDeviceInformationStatics,
293 RuntimeClass_Windows_Devices_Enumeration_DeviceInformation>();
294 if (!dev_info_statics)
295 return false;
296
297 hr = dev_info_statics->CreateWatcherAqsFilter(device_selector,
298 watcher_.Receive());
299 if (FAILED(hr)) {
300 VLOG(1) << "CreateWatcherAqsFilter failed: " << PrintHr(hr);
301 return false;
302 }
303
304 // Register callbacks to WinRT that post state-modifying jobs back to COM
305 // thread. |weak_ptr| and |task_runner| are captured by lambda callbacks for
306 // posting jobs. Note that WinRT callback arguments should not be passed
307 // outside the callback since the pointers may be unavailable afterwards.
308 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
309 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
310
311 hr = watcher_->add_Added(
312 WRL::Callback<ITypedEventHandler<DeviceWatcher*, DeviceInformation*>>(
313 [weak_ptr, task_runner](IDeviceWatcher* watcher,
314 IDeviceInformation* info) {
shaochuan110262b2016-08-31 02:15:16 -0700315 // Disable Microsoft GS Wavetable Synth due to security reasons.
316 // http://crbug.com/499279
317 if (IsMicrosoftSynthesizer(info))
318 return S_OK;
319
shaochuane58f9c72016-08-30 22:27:08 -0700320 std::string dev_id = GetIdString(info),
321 dev_name = GetNameString(info);
322
323 task_runner->PostTask(
324 FROM_HERE, base::Bind(&MidiPortManager::OnAdded, weak_ptr,
325 dev_id, dev_name));
326
327 return S_OK;
328 })
329 .Get(),
330 &token_Added_);
331 if (FAILED(hr)) {
332 VLOG(1) << "add_Added failed: " << PrintHr(hr);
333 return false;
334 }
335
336 hr = watcher_->add_EnumerationCompleted(
337 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
338 [weak_ptr, task_runner](IDeviceWatcher* watcher,
339 IInspectable* insp) {
340 task_runner->PostTask(
341 FROM_HERE,
342 base::Bind(&MidiPortManager::OnEnumerationCompleted,
343 weak_ptr));
344
345 return S_OK;
346 })
347 .Get(),
348 &token_EnumerationCompleted_);
349 if (FAILED(hr)) {
350 VLOG(1) << "add_EnumerationCompleted failed: " << PrintHr(hr);
351 return false;
352 }
353
354 hr = watcher_->add_Removed(
355 WRL::Callback<
356 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
357 [weak_ptr, task_runner](IDeviceWatcher* watcher,
358 IDeviceInformationUpdate* update) {
359 std::string dev_id = GetIdString(update);
360
361 task_runner->PostTask(
362 FROM_HERE,
363 base::Bind(&MidiPortManager::OnRemoved, weak_ptr, dev_id));
364
365 return S_OK;
366 })
367 .Get(),
368 &token_Removed_);
369 if (FAILED(hr)) {
370 VLOG(1) << "add_Removed failed: " << PrintHr(hr);
371 return false;
372 }
373
374 hr = watcher_->add_Stopped(
375 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
376 [](IDeviceWatcher* watcher, IInspectable* insp) {
377 // Placeholder, does nothing for now.
378 return S_OK;
379 })
380 .Get(),
381 &token_Stopped_);
382 if (FAILED(hr)) {
383 VLOG(1) << "add_Stopped failed: " << PrintHr(hr);
384 return false;
385 }
386
387 hr = watcher_->add_Updated(
388 WRL::Callback<
389 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
390 [](IDeviceWatcher* watcher, IDeviceInformationUpdate* update) {
391 // TODO(shaochuan): Check for fields to be updated here.
392 return S_OK;
393 })
394 .Get(),
395 &token_Updated_);
396 if (FAILED(hr)) {
397 VLOG(1) << "add_Updated failed: " << PrintHr(hr);
398 return false;
399 }
400
401 hr = watcher_->Start();
402 if (FAILED(hr)) {
403 VLOG(1) << "Start failed: " << PrintHr(hr);
404 return false;
405 }
406
407 is_initialized_ = true;
408 return true;
409 }
410
411 void StopWatcher() {
412 DCHECK(thread_checker_.CalledOnValidThread());
413
414 HRESULT hr;
415
416 for (const auto& entry : ports_)
417 RemovePortEventHandlers(entry.second.get());
418
419 if (token_Added_.value != kInvalidTokenValue) {
420 hr = watcher_->remove_Added(token_Added_);
421 VLOG_IF(1, FAILED(hr)) << "remove_Added failed: " << PrintHr(hr);
422 token_Added_.value = kInvalidTokenValue;
423 }
424 if (token_EnumerationCompleted_.value != kInvalidTokenValue) {
425 hr = watcher_->remove_EnumerationCompleted(token_EnumerationCompleted_);
426 VLOG_IF(1, FAILED(hr)) << "remove_EnumerationCompleted failed: "
427 << PrintHr(hr);
428 token_EnumerationCompleted_.value = kInvalidTokenValue;
429 }
430 if (token_Removed_.value != kInvalidTokenValue) {
431 hr = watcher_->remove_Removed(token_Removed_);
432 VLOG_IF(1, FAILED(hr)) << "remove_Removed failed: " << PrintHr(hr);
433 token_Removed_.value = kInvalidTokenValue;
434 }
435 if (token_Stopped_.value != kInvalidTokenValue) {
436 hr = watcher_->remove_Stopped(token_Stopped_);
437 VLOG_IF(1, FAILED(hr)) << "remove_Stopped failed: " << PrintHr(hr);
438 token_Stopped_.value = kInvalidTokenValue;
439 }
440 if (token_Updated_.value != kInvalidTokenValue) {
441 hr = watcher_->remove_Updated(token_Updated_);
442 VLOG_IF(1, FAILED(hr)) << "remove_Updated failed: " << PrintHr(hr);
443 token_Updated_.value = kInvalidTokenValue;
444 }
445
446 if (is_initialized_) {
447 hr = watcher_->Stop();
448 VLOG_IF(1, FAILED(hr)) << "Stop failed: " << PrintHr(hr);
449 is_initialized_ = false;
450 }
451 }
452
453 MidiPort<InterfaceType>* GetPortByDeviceId(std::string dev_id) {
454 DCHECK(thread_checker_.CalledOnValidThread());
455 CHECK(is_initialized_);
456
457 auto it = ports_.find(dev_id);
458 if (it == ports_.end())
459 return nullptr;
460 return it->second.get();
461 }
462
463 MidiPort<InterfaceType>* GetPortByIndex(uint32_t port_index) {
464 DCHECK(thread_checker_.CalledOnValidThread());
465 CHECK(is_initialized_);
466
467 return GetPortByDeviceId(port_ids_[port_index]);
468 }
469
470 protected:
471 // Points to the MidiManagerWinrt instance, which is expected to outlive the
472 // MidiPortManager instance.
473 MidiManagerWinrt* midi_manager_;
474
475 // Task runner of the COM thread.
476 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
477
478 // Ensures all methods are called on the COM thread.
479 base::ThreadChecker thread_checker_;
480
481 private:
482 // DeviceWatcher callbacks:
483 void OnAdded(std::string dev_id, std::string dev_name) {
484 DCHECK(thread_checker_.CalledOnValidThread());
485 CHECK(is_initialized_);
486
shaochuane58f9c72016-08-30 22:27:08 -0700487 port_names_[dev_id] = dev_name;
488
shaochuan9ff63b82016-09-01 01:58:44 -0700489 ScopedHString dev_id_hstring(base::UTF8ToWide(dev_id).c_str());
490 if (!dev_id_hstring.is_valid())
shaochuane58f9c72016-08-30 22:27:08 -0700491 return;
shaochuane58f9c72016-08-30 22:27:08 -0700492
493 IAsyncOperation<RuntimeType*>* async_op;
494
shaochuan9ff63b82016-09-01 01:58:44 -0700495 HRESULT hr =
496 midi_port_statics_->FromIdAsync(dev_id_hstring.get(), &async_op);
shaochuane58f9c72016-08-30 22:27:08 -0700497 if (FAILED(hr)) {
498 VLOG(1) << "FromIdAsync failed: " << PrintHr(hr);
499 return;
500 }
501
502 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
503 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
504
505 hr = async_op->put_Completed(
506 WRL::Callback<IAsyncOperationCompletedHandler<RuntimeType*>>(
507 [weak_ptr, task_runner](IAsyncOperation<RuntimeType*>* async_op,
508 AsyncStatus status) {
509 InterfaceType* handle;
510 HRESULT hr = async_op->GetResults(&handle);
511 if (FAILED(hr)) {
512 VLOG(1) << "GetResults failed: " << PrintHr(hr);
513 return hr;
514 }
515
516 // A reference to |async_op| is kept in |async_ops_|, safe to pass
517 // outside.
518 task_runner->PostTask(
519 FROM_HERE,
520 base::Bind(&MidiPortManager::OnCompletedGetPortFromIdAsync,
521 weak_ptr, handle, async_op));
522
523 return S_OK;
524 })
525 .Get());
526 if (FAILED(hr)) {
527 VLOG(1) << "put_Completed failed: " << PrintHr(hr);
528 return;
529 }
530
531 // Keep a reference to incompleted |async_op| for releasing later.
532 async_ops_.insert(async_op);
533 }
534
535 void OnEnumerationCompleted() {
536 DCHECK(thread_checker_.CalledOnValidThread());
537 CHECK(is_initialized_);
538
539 if (async_ops_.empty())
540 midi_manager_->OnPortManagerReady();
541 else
542 enumeration_completed_not_ready_ = true;
543 }
544
545 void OnRemoved(std::string dev_id) {
546 DCHECK(thread_checker_.CalledOnValidThread());
547 CHECK(is_initialized_);
548
shaochuan110262b2016-08-31 02:15:16 -0700549 // Note: in case Microsoft GS Wavetable Synth triggers this event for some
550 // reason, it will be ignored here with log emitted.
shaochuane58f9c72016-08-30 22:27:08 -0700551 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
552 if (!port) {
553 VLOG(1) << "Removing non-existent port " << dev_id;
554 return;
555 }
556
557 SetPortState(port->index, MIDI_PORT_DISCONNECTED);
558
559 RemovePortEventHandlers(port);
560 port->handle = nullptr;
561 }
562
563 void OnCompletedGetPortFromIdAsync(InterfaceType* handle,
564 IAsyncOperation<RuntimeType*>* async_op) {
565 DCHECK(thread_checker_.CalledOnValidThread());
566 CHECK(is_initialized_);
567
568 EventRegistrationToken token = {kInvalidTokenValue};
569 if (!RegisterOnMessageReceived(handle, &token))
570 return;
571
572 std::string dev_id = GetDeviceIdString(handle);
573
574 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
575
576 if (port == nullptr) {
577 // TODO(crbug.com/642604): Fill in manufacturer and driver version.
578 AddPort(MidiPortInfo(dev_id, std::string("Manufacturer"),
579 port_names_[dev_id], std::string("DriverVersion"),
580 MIDI_PORT_OPENED));
581
582 port = new MidiPort<InterfaceType>;
583 port->index = static_cast<uint32_t>(port_ids_.size());
584
585 ports_[dev_id].reset(port);
586 port_ids_.push_back(dev_id);
587 } else {
588 SetPortState(port->index, MIDI_PORT_CONNECTED);
589 }
590
591 port->handle = handle;
592 port->token_MessageReceived = token;
593
594 // Manually release COM interface to completed |async_op|.
595 auto it = async_ops_.find(async_op);
596 CHECK(it != async_ops_.end());
597 (*it)->Release();
598 async_ops_.erase(it);
599
600 if (enumeration_completed_not_ready_ && async_ops_.empty()) {
601 midi_manager_->OnPortManagerReady();
602 enumeration_completed_not_ready_ = false;
603 }
604 }
605
606 // Overrided by MidiInPortManager to listen to input ports.
607 virtual bool RegisterOnMessageReceived(InterfaceType* handle,
608 EventRegistrationToken* p_token) {
609 return true;
610 }
611
612 // Overrided by MidiInPortManager to remove MessageReceived event handler.
613 virtual void RemovePortEventHandlers(MidiPort<InterfaceType>* port) {}
614
615 // Calls midi_manager_->Add{Input,Output}Port.
616 virtual void AddPort(MidiPortInfo info) = 0;
617
618 // Calls midi_manager_->Set{Input,Output}PortState.
619 virtual void SetPortState(uint32_t port_index, MidiPortState state) = 0;
620
621 // WeakPtrFactory has to be declared in derived class, use this method to
622 // retrieve upcasted WeakPtr for posting tasks.
623 virtual base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() = 0;
624
625 // Midi{In,Out}PortStatics instance.
626 ScopedComPtr<StaticsInterfaceType> midi_port_statics_;
627
628 // DeviceWatcher instance and event registration tokens for unsubscribing
629 // events in destructor.
630 ScopedComPtr<IDeviceWatcher> watcher_;
631 EventRegistrationToken token_Added_ = {kInvalidTokenValue},
632 token_EnumerationCompleted_ = {kInvalidTokenValue},
633 token_Removed_ = {kInvalidTokenValue},
634 token_Stopped_ = {kInvalidTokenValue},
635 token_Updated_ = {kInvalidTokenValue};
636
637 // All manipulations to these fields should be done on COM thread.
638 std::unordered_map<std::string, std::unique_ptr<MidiPort<InterfaceType>>>
639 ports_;
640 std::vector<std::string> port_ids_;
641 std::unordered_map<std::string, std::string> port_names_;
642
643 // Keeps AsyncOperation references before the operation completes. Note that
644 // raw pointers are used here and the COM interfaces should be released
645 // manually.
646 std::unordered_set<IAsyncOperation<RuntimeType*>*> async_ops_;
647
648 // Set when device enumeration is completed but OnPortManagerReady() is not
649 // called since some ports are not yet ready (i.e. |async_ops_| is not empty).
650 // In such cases, OnPortManagerReady() will be called in
651 // OnCompletedGetPortFromIdAsync() when the last pending port is ready.
652 bool enumeration_completed_not_ready_ = false;
653
654 // Set if the instance is initialized without error. Should be checked in all
655 // methods on COM thread except StartWatcher().
656 bool is_initialized_ = false;
657};
658
659class MidiManagerWinrt::MidiInPortManager final
660 : public MidiPortManager<IMidiInPort,
661 MidiInPort,
662 IMidiInPortStatics,
663 RuntimeClass_Windows_Devices_Midi_MidiInPort> {
664 public:
665 MidiInPortManager(MidiManagerWinrt* midi_manager)
666 : MidiPortManager(midi_manager), weak_factory_(this) {}
667
668 private:
669 // MidiPortManager overrides:
670 bool RegisterOnMessageReceived(IMidiInPort* handle,
671 EventRegistrationToken* p_token) override {
672 DCHECK(thread_checker_.CalledOnValidThread());
673
674 base::WeakPtr<MidiInPortManager> weak_ptr = weak_factory_.GetWeakPtr();
675 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
676
677 HRESULT hr = handle->add_MessageReceived(
678 WRL::Callback<
679 ITypedEventHandler<MidiInPort*, MidiMessageReceivedEventArgs*>>(
680 [weak_ptr, task_runner](IMidiInPort* handle,
681 IMidiMessageReceivedEventArgs* args) {
682 const base::TimeTicks now = base::TimeTicks::Now();
683
684 std::string dev_id = GetDeviceIdString(handle);
685
686 ScopedComPtr<IMidiMessage> message;
687 HRESULT hr = args->get_Message(message.Receive());
688 if (FAILED(hr)) {
689 VLOG(1) << "get_Message failed: " << PrintHr(hr);
690 return hr;
691 }
692
693 ScopedComPtr<IBuffer> buffer;
694 hr = message->get_RawData(buffer.Receive());
695 if (FAILED(hr)) {
696 VLOG(1) << "get_RawData failed: " << PrintHr(hr);
697 return hr;
698 }
699
700 uint8_t* p_buffer_data = nullptr;
701 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
702 if (FAILED(hr))
703 return hr;
704
705 uint32_t data_length = 0;
706 hr = buffer->get_Length(&data_length);
707 if (FAILED(hr)) {
708 VLOG(1) << "get_Length failed: " << PrintHr(hr);
709 return hr;
710 }
711
712 std::vector<uint8_t> data(p_buffer_data,
713 p_buffer_data + data_length);
714
715 task_runner->PostTask(
716 FROM_HERE, base::Bind(&MidiInPortManager::OnMessageReceived,
717 weak_ptr, dev_id, data, now));
718
719 return S_OK;
720 })
721 .Get(),
722 p_token);
723 if (FAILED(hr)) {
724 VLOG(1) << "add_MessageReceived failed: " << PrintHr(hr);
725 return false;
726 }
727
728 return true;
729 }
730
731 void RemovePortEventHandlers(MidiPort<IMidiInPort>* port) override {
732 if (!(port->handle &&
733 port->token_MessageReceived.value != kInvalidTokenValue))
734 return;
735
736 HRESULT hr =
737 port->handle->remove_MessageReceived(port->token_MessageReceived);
738 VLOG_IF(1, FAILED(hr)) << "remove_MessageReceived failed: " << PrintHr(hr);
739 port->token_MessageReceived.value = kInvalidTokenValue;
740 }
741
742 void AddPort(MidiPortInfo info) final { midi_manager_->AddInputPort(info); }
743
744 void SetPortState(uint32_t port_index, MidiPortState state) final {
745 midi_manager_->SetInputPortState(port_index, state);
746 }
747
748 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
749 DCHECK(thread_checker_.CalledOnValidThread());
750
751 return weak_factory_.GetWeakPtr();
752 }
753
754 // Callback on receiving MIDI input message.
755 void OnMessageReceived(std::string dev_id,
756 std::vector<uint8_t> data,
757 base::TimeTicks time) {
758 DCHECK(thread_checker_.CalledOnValidThread());
759
760 MidiPort<IMidiInPort>* port = GetPortByDeviceId(dev_id);
761 CHECK(port);
762
763 midi_manager_->ReceiveMidiData(port->index, &data[0], data.size(), time);
764 }
765
766 // Last member to ensure destructed first.
767 base::WeakPtrFactory<MidiInPortManager> weak_factory_;
768
769 DISALLOW_COPY_AND_ASSIGN(MidiInPortManager);
770};
771
772class MidiManagerWinrt::MidiOutPortManager final
773 : public MidiPortManager<IMidiOutPort,
774 IMidiOutPort,
775 IMidiOutPortStatics,
776 RuntimeClass_Windows_Devices_Midi_MidiOutPort> {
777 public:
778 MidiOutPortManager(MidiManagerWinrt* midi_manager)
779 : MidiPortManager(midi_manager), weak_factory_(this) {}
780
781 private:
782 // MidiPortManager overrides:
783 void AddPort(MidiPortInfo info) final { midi_manager_->AddOutputPort(info); }
784
785 void SetPortState(uint32_t port_index, MidiPortState state) final {
786 midi_manager_->SetOutputPortState(port_index, state);
787 }
788
789 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
790 DCHECK(thread_checker_.CalledOnValidThread());
791
792 return weak_factory_.GetWeakPtr();
793 }
794
795 // Last member to ensure destructed first.
796 base::WeakPtrFactory<MidiOutPortManager> weak_factory_;
797
798 DISALLOW_COPY_AND_ASSIGN(MidiOutPortManager);
799};
800
801MidiManagerWinrt::MidiManagerWinrt() : com_thread_("Windows MIDI COM Thread") {}
802
803MidiManagerWinrt::~MidiManagerWinrt() {
804 base::AutoLock auto_lock(lazy_init_member_lock_);
805
806 CHECK(!com_thread_checker_);
807 CHECK(!port_manager_in_);
808 CHECK(!port_manager_out_);
809 CHECK(!scheduler_);
810}
811
812void MidiManagerWinrt::StartInitialization() {
813 if (base::win::GetVersion() < base::win::VERSION_WIN10) {
814 VLOG(1) << "WinRT MIDI backend is only supported on Windows 10 or later.";
815 CompleteInitialization(Result::INITIALIZATION_ERROR);
816 return;
817 }
818
819 com_thread_.init_com_with_mta(true);
820 com_thread_.Start();
821
822 com_thread_.task_runner()->PostTask(
823 FROM_HERE, base::Bind(&MidiManagerWinrt::InitializeOnComThread,
824 base::Unretained(this)));
825}
826
827void MidiManagerWinrt::Finalize() {
828 com_thread_.task_runner()->PostTask(
829 FROM_HERE, base::Bind(&MidiManagerWinrt::FinalizeOnComThread,
830 base::Unretained(this)));
831
832 // Blocks until FinalizeOnComThread() returns. Delayed MIDI send data tasks
833 // will be ignored.
834 com_thread_.Stop();
835}
836
837void MidiManagerWinrt::DispatchSendMidiData(MidiManagerClient* client,
838 uint32_t port_index,
839 const std::vector<uint8_t>& data,
840 double timestamp) {
841 CHECK(scheduler_);
842
843 scheduler_->PostSendDataTask(
844 client, data.size(), timestamp,
845 base::Bind(&MidiManagerWinrt::SendOnComThread, base::Unretained(this),
846 port_index, data));
847}
848
849void MidiManagerWinrt::InitializeOnComThread() {
850 base::AutoLock auto_lock(lazy_init_member_lock_);
851
852 com_thread_checker_.reset(new base::ThreadChecker);
853
shaochuan9ff63b82016-09-01 01:58:44 -0700854 if (!g_combase_functions.Get().LoadFunctions()) {
855 VLOG(1) << "Failed loading functions from combase.dll: "
856 << PrintHr(HRESULT_FROM_WIN32(GetLastError()));
857 CompleteInitialization(Result::INITIALIZATION_ERROR);
858 return;
859 }
860
shaochuane58f9c72016-08-30 22:27:08 -0700861 port_manager_in_.reset(new MidiInPortManager(this));
862 port_manager_out_.reset(new MidiOutPortManager(this));
863
864 scheduler_.reset(new MidiScheduler(this));
865
866 if (!(port_manager_in_->StartWatcher() &&
867 port_manager_out_->StartWatcher())) {
868 port_manager_in_->StopWatcher();
869 port_manager_out_->StopWatcher();
870 CompleteInitialization(Result::INITIALIZATION_ERROR);
871 }
872}
873
874void MidiManagerWinrt::FinalizeOnComThread() {
875 base::AutoLock auto_lock(lazy_init_member_lock_);
876
877 DCHECK(com_thread_checker_->CalledOnValidThread());
878
879 scheduler_.reset();
880
shaochuan9ff63b82016-09-01 01:58:44 -0700881 if (port_manager_in_) {
882 port_manager_in_->StopWatcher();
883 port_manager_in_.reset();
884 }
885
886 if (port_manager_out_) {
887 port_manager_out_->StopWatcher();
888 port_manager_out_.reset();
889 }
shaochuane58f9c72016-08-30 22:27:08 -0700890
891 com_thread_checker_.reset();
892}
893
894void MidiManagerWinrt::SendOnComThread(uint32_t port_index,
895 const std::vector<uint8_t>& data) {
896 DCHECK(com_thread_checker_->CalledOnValidThread());
897
898 MidiPort<IMidiOutPort>* port = port_manager_out_->GetPortByIndex(port_index);
899 if (!(port && port->handle)) {
900 VLOG(1) << "Port not available: " << port_index;
901 return;
902 }
903
904 auto buffer_factory =
905 WrlStaticsFactory<IBufferFactory,
906 RuntimeClass_Windows_Storage_Streams_Buffer>();
907 if (!buffer_factory)
908 return;
909
910 ScopedComPtr<IBuffer> buffer;
911 HRESULT hr = buffer_factory->Create(static_cast<UINT32>(data.size()),
912 buffer.Receive());
913 if (FAILED(hr)) {
914 VLOG(1) << "Create failed: " << PrintHr(hr);
915 return;
916 }
917
918 hr = buffer->put_Length(static_cast<UINT32>(data.size()));
919 if (FAILED(hr)) {
920 VLOG(1) << "put_Length failed: " << PrintHr(hr);
921 return;
922 }
923
924 uint8_t* p_buffer_data = nullptr;
925 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
926 if (FAILED(hr))
927 return;
928
929 std::copy(data.begin(), data.end(), p_buffer_data);
930
931 hr = port->handle->SendBuffer(buffer.get());
932 if (FAILED(hr)) {
933 VLOG(1) << "SendBuffer failed: " << PrintHr(hr);
934 return;
935 }
936}
937
938void MidiManagerWinrt::OnPortManagerReady() {
939 DCHECK(com_thread_checker_->CalledOnValidThread());
940 DCHECK(port_manager_ready_count_ < 2);
941
942 if (++port_manager_ready_count_ == 2)
943 CompleteInitialization(Result::OK);
944}
945
946MidiManager* MidiManager::Create() {
947 return new MidiManagerWinrt();
948}
949
950} // namespace midi
951} // namespace media