blob: 11841cbde431cb09c8c6a70dee481c8fe16a081d [file] [log] [blame]
shaochuane58f9c72016-08-30 22:27:08 -07001// Copyright 2016 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "media/midi/midi_manager_winrt.h"
6
7#include <comdef.h>
8#include <robuffer.h>
9#include <windows.devices.enumeration.h>
10#include <windows.devices.midi.h>
11#include <wrl/event.h>
12
13#include <iomanip>
14#include <unordered_map>
15#include <unordered_set>
16
17#include "base/bind.h"
18#include "base/strings/utf_string_conversions.h"
19#include "base/threading/thread_checker.h"
20#include "base/threading/thread_task_runner_handle.h"
21#include "base/timer/timer.h"
22#include "base/win/scoped_comptr.h"
23#include "base/win/windows_version.h"
24#include "media/midi/midi_scheduler.h"
25
26namespace media {
27namespace midi {
28namespace {
29
30namespace WRL = Microsoft::WRL;
31
32using namespace ABI::Windows::Devices::Enumeration;
33using namespace ABI::Windows::Devices::Midi;
34using namespace ABI::Windows::Foundation;
35using namespace ABI::Windows::Storage::Streams;
36
37using base::win::ScopedComPtr;
38
39// Helpers for printing HRESULTs.
40struct PrintHr {
41 PrintHr(HRESULT hr) : hr(hr) {}
42 HRESULT hr;
43};
44
45std::ostream& operator<<(std::ostream& os, const PrintHr& phr) {
46 std::ios_base::fmtflags ff = os.flags();
47 os << _com_error(phr.hr).ErrorMessage() << " (0x" << std::hex
48 << std::uppercase << std::setfill('0') << std::setw(8) << phr.hr << ")";
49 os.flags(ff);
50 return os;
51}
52
53// Factory functions that activate and create WinRT components. The caller takes
54// ownership of the returning ComPtr.
55template <typename InterfaceType, base::char16 const* runtime_class_id>
56ScopedComPtr<InterfaceType> WrlStaticsFactory() {
57 ScopedComPtr<InterfaceType> com_ptr;
58
59 HRESULT hr = GetActivationFactory(
60 WRL::Wrappers::HStringReference(runtime_class_id).Get(),
61 com_ptr.Receive());
62 if (FAILED(hr)) {
63 VLOG(1) << "GetActivationFactory failed: " << PrintHr(hr);
64 com_ptr = nullptr;
65 }
66
67 return com_ptr;
68}
69
70template <typename T, HRESULT (T::*method)(HSTRING*)>
71std::string GetStringFromObjectMethod(T* obj) {
72 HSTRING result;
73 HRESULT hr = (obj->*method)(&result);
74 if (FAILED(hr)) {
75 VLOG(1) << "GetStringFromObjectMethod failed: " << PrintHr(hr);
76 return std::string();
77 }
78
79 // Note: empty HSTRINGs are represent as nullptr, and instantiating
80 // std::string with nullptr (in base::WideToUTF8) is undefined behavior.
81 const base::char16* buffer = WindowsGetStringRawBuffer(result, nullptr);
82 if (buffer)
83 return base::WideToUTF8(buffer);
84 return std::string();
85}
86
87template <typename T>
88std::string GetIdString(T* obj) {
89 return GetStringFromObjectMethod<T, &T::get_Id>(obj);
90}
91
92template <typename T>
93std::string GetDeviceIdString(T* obj) {
94 return GetStringFromObjectMethod<T, &T::get_DeviceId>(obj);
95}
96
97std::string GetNameString(IDeviceInformation* info) {
98 return GetStringFromObjectMethod<IDeviceInformation,
99 &IDeviceInformation::get_Name>(info);
100}
101
102HRESULT GetPointerToBufferData(IBuffer* buffer, uint8_t** out) {
103 ScopedComPtr<Windows::Storage::Streams::IBufferByteAccess> buffer_byte_access;
104
105 HRESULT hr = buffer_byte_access.QueryFrom(buffer);
106 if (FAILED(hr)) {
107 VLOG(1) << "QueryInterface failed: " << PrintHr(hr);
108 return hr;
109 }
110
111 // Lifetime of the pointing buffer is controlled by the buffer object.
112 hr = buffer_byte_access->Buffer(out);
113 if (FAILED(hr)) {
114 VLOG(1) << "Buffer failed: " << PrintHr(hr);
115 return hr;
116 }
117
118 return S_OK;
119}
120
// EventRegistrationToken values equal to 0 are treated as "not registered"
// (matching the convention in <wrl/event.h>).
constexpr int64_t kInvalidTokenValue = 0;
123
124template <typename InterfaceType>
125struct MidiPort {
126 MidiPort() = default;
127
128 uint32_t index;
129 ScopedComPtr<InterfaceType> handle;
130 EventRegistrationToken token_MessageReceived;
131
132 private:
133 DISALLOW_COPY_AND_ASSIGN(MidiPort);
134};
135
136} // namespace
137
138template <typename InterfaceType,
139 typename RuntimeType,
140 typename StaticsInterfaceType,
141 base::char16 const* runtime_class_id>
142class MidiManagerWinrt::MidiPortManager {
143 public:
144 // MidiPortManager instances should be constructed on the COM thread.
145 MidiPortManager(MidiManagerWinrt* midi_manager)
146 : midi_manager_(midi_manager),
147 task_runner_(base::ThreadTaskRunnerHandle::Get()) {}
148
149 virtual ~MidiPortManager() { DCHECK(thread_checker_.CalledOnValidThread()); }
150
151 bool StartWatcher() {
152 DCHECK(thread_checker_.CalledOnValidThread());
153
154 HRESULT hr;
155
156 midi_port_statics_ =
157 WrlStaticsFactory<StaticsInterfaceType, runtime_class_id>();
158 if (!midi_port_statics_)
159 return false;
160
161 HSTRING device_selector = nullptr;
162 hr = midi_port_statics_->GetDeviceSelector(&device_selector);
163 if (FAILED(hr)) {
164 VLOG(1) << "GetDeviceSelector failed: " << PrintHr(hr);
165 return false;
166 }
167
168 auto dev_info_statics = WrlStaticsFactory<
169 IDeviceInformationStatics,
170 RuntimeClass_Windows_Devices_Enumeration_DeviceInformation>();
171 if (!dev_info_statics)
172 return false;
173
174 hr = dev_info_statics->CreateWatcherAqsFilter(device_selector,
175 watcher_.Receive());
176 if (FAILED(hr)) {
177 VLOG(1) << "CreateWatcherAqsFilter failed: " << PrintHr(hr);
178 return false;
179 }
180
181 // Register callbacks to WinRT that post state-modifying jobs back to COM
182 // thread. |weak_ptr| and |task_runner| are captured by lambda callbacks for
183 // posting jobs. Note that WinRT callback arguments should not be passed
184 // outside the callback since the pointers may be unavailable afterwards.
185 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
186 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
187
188 hr = watcher_->add_Added(
189 WRL::Callback<ITypedEventHandler<DeviceWatcher*, DeviceInformation*>>(
190 [weak_ptr, task_runner](IDeviceWatcher* watcher,
191 IDeviceInformation* info) {
192 std::string dev_id = GetIdString(info),
193 dev_name = GetNameString(info);
194
195 task_runner->PostTask(
196 FROM_HERE, base::Bind(&MidiPortManager::OnAdded, weak_ptr,
197 dev_id, dev_name));
198
199 return S_OK;
200 })
201 .Get(),
202 &token_Added_);
203 if (FAILED(hr)) {
204 VLOG(1) << "add_Added failed: " << PrintHr(hr);
205 return false;
206 }
207
208 hr = watcher_->add_EnumerationCompleted(
209 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
210 [weak_ptr, task_runner](IDeviceWatcher* watcher,
211 IInspectable* insp) {
212 task_runner->PostTask(
213 FROM_HERE,
214 base::Bind(&MidiPortManager::OnEnumerationCompleted,
215 weak_ptr));
216
217 return S_OK;
218 })
219 .Get(),
220 &token_EnumerationCompleted_);
221 if (FAILED(hr)) {
222 VLOG(1) << "add_EnumerationCompleted failed: " << PrintHr(hr);
223 return false;
224 }
225
226 hr = watcher_->add_Removed(
227 WRL::Callback<
228 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
229 [weak_ptr, task_runner](IDeviceWatcher* watcher,
230 IDeviceInformationUpdate* update) {
231 std::string dev_id = GetIdString(update);
232
233 task_runner->PostTask(
234 FROM_HERE,
235 base::Bind(&MidiPortManager::OnRemoved, weak_ptr, dev_id));
236
237 return S_OK;
238 })
239 .Get(),
240 &token_Removed_);
241 if (FAILED(hr)) {
242 VLOG(1) << "add_Removed failed: " << PrintHr(hr);
243 return false;
244 }
245
246 hr = watcher_->add_Stopped(
247 WRL::Callback<ITypedEventHandler<DeviceWatcher*, IInspectable*>>(
248 [](IDeviceWatcher* watcher, IInspectable* insp) {
249 // Placeholder, does nothing for now.
250 return S_OK;
251 })
252 .Get(),
253 &token_Stopped_);
254 if (FAILED(hr)) {
255 VLOG(1) << "add_Stopped failed: " << PrintHr(hr);
256 return false;
257 }
258
259 hr = watcher_->add_Updated(
260 WRL::Callback<
261 ITypedEventHandler<DeviceWatcher*, DeviceInformationUpdate*>>(
262 [](IDeviceWatcher* watcher, IDeviceInformationUpdate* update) {
263 // TODO(shaochuan): Check for fields to be updated here.
264 return S_OK;
265 })
266 .Get(),
267 &token_Updated_);
268 if (FAILED(hr)) {
269 VLOG(1) << "add_Updated failed: " << PrintHr(hr);
270 return false;
271 }
272
273 hr = watcher_->Start();
274 if (FAILED(hr)) {
275 VLOG(1) << "Start failed: " << PrintHr(hr);
276 return false;
277 }
278
279 is_initialized_ = true;
280 return true;
281 }
282
283 void StopWatcher() {
284 DCHECK(thread_checker_.CalledOnValidThread());
285
286 HRESULT hr;
287
288 for (const auto& entry : ports_)
289 RemovePortEventHandlers(entry.second.get());
290
291 if (token_Added_.value != kInvalidTokenValue) {
292 hr = watcher_->remove_Added(token_Added_);
293 VLOG_IF(1, FAILED(hr)) << "remove_Added failed: " << PrintHr(hr);
294 token_Added_.value = kInvalidTokenValue;
295 }
296 if (token_EnumerationCompleted_.value != kInvalidTokenValue) {
297 hr = watcher_->remove_EnumerationCompleted(token_EnumerationCompleted_);
298 VLOG_IF(1, FAILED(hr)) << "remove_EnumerationCompleted failed: "
299 << PrintHr(hr);
300 token_EnumerationCompleted_.value = kInvalidTokenValue;
301 }
302 if (token_Removed_.value != kInvalidTokenValue) {
303 hr = watcher_->remove_Removed(token_Removed_);
304 VLOG_IF(1, FAILED(hr)) << "remove_Removed failed: " << PrintHr(hr);
305 token_Removed_.value = kInvalidTokenValue;
306 }
307 if (token_Stopped_.value != kInvalidTokenValue) {
308 hr = watcher_->remove_Stopped(token_Stopped_);
309 VLOG_IF(1, FAILED(hr)) << "remove_Stopped failed: " << PrintHr(hr);
310 token_Stopped_.value = kInvalidTokenValue;
311 }
312 if (token_Updated_.value != kInvalidTokenValue) {
313 hr = watcher_->remove_Updated(token_Updated_);
314 VLOG_IF(1, FAILED(hr)) << "remove_Updated failed: " << PrintHr(hr);
315 token_Updated_.value = kInvalidTokenValue;
316 }
317
318 if (is_initialized_) {
319 hr = watcher_->Stop();
320 VLOG_IF(1, FAILED(hr)) << "Stop failed: " << PrintHr(hr);
321 is_initialized_ = false;
322 }
323 }
324
325 MidiPort<InterfaceType>* GetPortByDeviceId(std::string dev_id) {
326 DCHECK(thread_checker_.CalledOnValidThread());
327 CHECK(is_initialized_);
328
329 auto it = ports_.find(dev_id);
330 if (it == ports_.end())
331 return nullptr;
332 return it->second.get();
333 }
334
335 MidiPort<InterfaceType>* GetPortByIndex(uint32_t port_index) {
336 DCHECK(thread_checker_.CalledOnValidThread());
337 CHECK(is_initialized_);
338
339 return GetPortByDeviceId(port_ids_[port_index]);
340 }
341
342 protected:
343 // Points to the MidiManagerWinrt instance, which is expected to outlive the
344 // MidiPortManager instance.
345 MidiManagerWinrt* midi_manager_;
346
347 // Task runner of the COM thread.
348 scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
349
350 // Ensures all methods are called on the COM thread.
351 base::ThreadChecker thread_checker_;
352
353 private:
354 // DeviceWatcher callbacks:
355 void OnAdded(std::string dev_id, std::string dev_name) {
356 DCHECK(thread_checker_.CalledOnValidThread());
357 CHECK(is_initialized_);
358
359 // TODO(shaochuan): Disable Microsoft GS Wavetable Synth due to security
360 // reasons. http://crbug.com/499279
361
362 port_names_[dev_id] = dev_name;
363
364 WRL::Wrappers::HString dev_id_hstring;
365 HRESULT hr = dev_id_hstring.Set(base::UTF8ToWide(dev_id).c_str());
366 if (FAILED(hr)) {
367 VLOG(1) << "Set failed: " << PrintHr(hr);
368 return;
369 }
370
371 IAsyncOperation<RuntimeType*>* async_op;
372
373 hr = midi_port_statics_->FromIdAsync(dev_id_hstring.Get(), &async_op);
374 if (FAILED(hr)) {
375 VLOG(1) << "FromIdAsync failed: " << PrintHr(hr);
376 return;
377 }
378
379 base::WeakPtr<MidiPortManager> weak_ptr = GetWeakPtrFromFactory();
380 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
381
382 hr = async_op->put_Completed(
383 WRL::Callback<IAsyncOperationCompletedHandler<RuntimeType*>>(
384 [weak_ptr, task_runner](IAsyncOperation<RuntimeType*>* async_op,
385 AsyncStatus status) {
386 InterfaceType* handle;
387 HRESULT hr = async_op->GetResults(&handle);
388 if (FAILED(hr)) {
389 VLOG(1) << "GetResults failed: " << PrintHr(hr);
390 return hr;
391 }
392
393 // A reference to |async_op| is kept in |async_ops_|, safe to pass
394 // outside.
395 task_runner->PostTask(
396 FROM_HERE,
397 base::Bind(&MidiPortManager::OnCompletedGetPortFromIdAsync,
398 weak_ptr, handle, async_op));
399
400 return S_OK;
401 })
402 .Get());
403 if (FAILED(hr)) {
404 VLOG(1) << "put_Completed failed: " << PrintHr(hr);
405 return;
406 }
407
408 // Keep a reference to incompleted |async_op| for releasing later.
409 async_ops_.insert(async_op);
410 }
411
412 void OnEnumerationCompleted() {
413 DCHECK(thread_checker_.CalledOnValidThread());
414 CHECK(is_initialized_);
415
416 if (async_ops_.empty())
417 midi_manager_->OnPortManagerReady();
418 else
419 enumeration_completed_not_ready_ = true;
420 }
421
422 void OnRemoved(std::string dev_id) {
423 DCHECK(thread_checker_.CalledOnValidThread());
424 CHECK(is_initialized_);
425
426 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
427 if (!port) {
428 VLOG(1) << "Removing non-existent port " << dev_id;
429 return;
430 }
431
432 SetPortState(port->index, MIDI_PORT_DISCONNECTED);
433
434 RemovePortEventHandlers(port);
435 port->handle = nullptr;
436 }
437
438 void OnCompletedGetPortFromIdAsync(InterfaceType* handle,
439 IAsyncOperation<RuntimeType*>* async_op) {
440 DCHECK(thread_checker_.CalledOnValidThread());
441 CHECK(is_initialized_);
442
443 EventRegistrationToken token = {kInvalidTokenValue};
444 if (!RegisterOnMessageReceived(handle, &token))
445 return;
446
447 std::string dev_id = GetDeviceIdString(handle);
448
449 MidiPort<InterfaceType>* port = GetPortByDeviceId(dev_id);
450
451 if (port == nullptr) {
452 // TODO(crbug.com/642604): Fill in manufacturer and driver version.
453 AddPort(MidiPortInfo(dev_id, std::string("Manufacturer"),
454 port_names_[dev_id], std::string("DriverVersion"),
455 MIDI_PORT_OPENED));
456
457 port = new MidiPort<InterfaceType>;
458 port->index = static_cast<uint32_t>(port_ids_.size());
459
460 ports_[dev_id].reset(port);
461 port_ids_.push_back(dev_id);
462 } else {
463 SetPortState(port->index, MIDI_PORT_CONNECTED);
464 }
465
466 port->handle = handle;
467 port->token_MessageReceived = token;
468
469 // Manually release COM interface to completed |async_op|.
470 auto it = async_ops_.find(async_op);
471 CHECK(it != async_ops_.end());
472 (*it)->Release();
473 async_ops_.erase(it);
474
475 if (enumeration_completed_not_ready_ && async_ops_.empty()) {
476 midi_manager_->OnPortManagerReady();
477 enumeration_completed_not_ready_ = false;
478 }
479 }
480
481 // Overrided by MidiInPortManager to listen to input ports.
482 virtual bool RegisterOnMessageReceived(InterfaceType* handle,
483 EventRegistrationToken* p_token) {
484 return true;
485 }
486
487 // Overrided by MidiInPortManager to remove MessageReceived event handler.
488 virtual void RemovePortEventHandlers(MidiPort<InterfaceType>* port) {}
489
490 // Calls midi_manager_->Add{Input,Output}Port.
491 virtual void AddPort(MidiPortInfo info) = 0;
492
493 // Calls midi_manager_->Set{Input,Output}PortState.
494 virtual void SetPortState(uint32_t port_index, MidiPortState state) = 0;
495
496 // WeakPtrFactory has to be declared in derived class, use this method to
497 // retrieve upcasted WeakPtr for posting tasks.
498 virtual base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() = 0;
499
500 // Midi{In,Out}PortStatics instance.
501 ScopedComPtr<StaticsInterfaceType> midi_port_statics_;
502
503 // DeviceWatcher instance and event registration tokens for unsubscribing
504 // events in destructor.
505 ScopedComPtr<IDeviceWatcher> watcher_;
506 EventRegistrationToken token_Added_ = {kInvalidTokenValue},
507 token_EnumerationCompleted_ = {kInvalidTokenValue},
508 token_Removed_ = {kInvalidTokenValue},
509 token_Stopped_ = {kInvalidTokenValue},
510 token_Updated_ = {kInvalidTokenValue};
511
512 // All manipulations to these fields should be done on COM thread.
513 std::unordered_map<std::string, std::unique_ptr<MidiPort<InterfaceType>>>
514 ports_;
515 std::vector<std::string> port_ids_;
516 std::unordered_map<std::string, std::string> port_names_;
517
518 // Keeps AsyncOperation references before the operation completes. Note that
519 // raw pointers are used here and the COM interfaces should be released
520 // manually.
521 std::unordered_set<IAsyncOperation<RuntimeType*>*> async_ops_;
522
523 // Set when device enumeration is completed but OnPortManagerReady() is not
524 // called since some ports are not yet ready (i.e. |async_ops_| is not empty).
525 // In such cases, OnPortManagerReady() will be called in
526 // OnCompletedGetPortFromIdAsync() when the last pending port is ready.
527 bool enumeration_completed_not_ready_ = false;
528
529 // Set if the instance is initialized without error. Should be checked in all
530 // methods on COM thread except StartWatcher().
531 bool is_initialized_ = false;
532};
533
534class MidiManagerWinrt::MidiInPortManager final
535 : public MidiPortManager<IMidiInPort,
536 MidiInPort,
537 IMidiInPortStatics,
538 RuntimeClass_Windows_Devices_Midi_MidiInPort> {
539 public:
540 MidiInPortManager(MidiManagerWinrt* midi_manager)
541 : MidiPortManager(midi_manager), weak_factory_(this) {}
542
543 private:
544 // MidiPortManager overrides:
545 bool RegisterOnMessageReceived(IMidiInPort* handle,
546 EventRegistrationToken* p_token) override {
547 DCHECK(thread_checker_.CalledOnValidThread());
548
549 base::WeakPtr<MidiInPortManager> weak_ptr = weak_factory_.GetWeakPtr();
550 scoped_refptr<base::SingleThreadTaskRunner> task_runner = task_runner_;
551
552 HRESULT hr = handle->add_MessageReceived(
553 WRL::Callback<
554 ITypedEventHandler<MidiInPort*, MidiMessageReceivedEventArgs*>>(
555 [weak_ptr, task_runner](IMidiInPort* handle,
556 IMidiMessageReceivedEventArgs* args) {
557 const base::TimeTicks now = base::TimeTicks::Now();
558
559 std::string dev_id = GetDeviceIdString(handle);
560
561 ScopedComPtr<IMidiMessage> message;
562 HRESULT hr = args->get_Message(message.Receive());
563 if (FAILED(hr)) {
564 VLOG(1) << "get_Message failed: " << PrintHr(hr);
565 return hr;
566 }
567
568 ScopedComPtr<IBuffer> buffer;
569 hr = message->get_RawData(buffer.Receive());
570 if (FAILED(hr)) {
571 VLOG(1) << "get_RawData failed: " << PrintHr(hr);
572 return hr;
573 }
574
575 uint8_t* p_buffer_data = nullptr;
576 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
577 if (FAILED(hr))
578 return hr;
579
580 uint32_t data_length = 0;
581 hr = buffer->get_Length(&data_length);
582 if (FAILED(hr)) {
583 VLOG(1) << "get_Length failed: " << PrintHr(hr);
584 return hr;
585 }
586
587 std::vector<uint8_t> data(p_buffer_data,
588 p_buffer_data + data_length);
589
590 task_runner->PostTask(
591 FROM_HERE, base::Bind(&MidiInPortManager::OnMessageReceived,
592 weak_ptr, dev_id, data, now));
593
594 return S_OK;
595 })
596 .Get(),
597 p_token);
598 if (FAILED(hr)) {
599 VLOG(1) << "add_MessageReceived failed: " << PrintHr(hr);
600 return false;
601 }
602
603 return true;
604 }
605
606 void RemovePortEventHandlers(MidiPort<IMidiInPort>* port) override {
607 if (!(port->handle &&
608 port->token_MessageReceived.value != kInvalidTokenValue))
609 return;
610
611 HRESULT hr =
612 port->handle->remove_MessageReceived(port->token_MessageReceived);
613 VLOG_IF(1, FAILED(hr)) << "remove_MessageReceived failed: " << PrintHr(hr);
614 port->token_MessageReceived.value = kInvalidTokenValue;
615 }
616
617 void AddPort(MidiPortInfo info) final { midi_manager_->AddInputPort(info); }
618
619 void SetPortState(uint32_t port_index, MidiPortState state) final {
620 midi_manager_->SetInputPortState(port_index, state);
621 }
622
623 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
624 DCHECK(thread_checker_.CalledOnValidThread());
625
626 return weak_factory_.GetWeakPtr();
627 }
628
629 // Callback on receiving MIDI input message.
630 void OnMessageReceived(std::string dev_id,
631 std::vector<uint8_t> data,
632 base::TimeTicks time) {
633 DCHECK(thread_checker_.CalledOnValidThread());
634
635 MidiPort<IMidiInPort>* port = GetPortByDeviceId(dev_id);
636 CHECK(port);
637
638 midi_manager_->ReceiveMidiData(port->index, &data[0], data.size(), time);
639 }
640
641 // Last member to ensure destructed first.
642 base::WeakPtrFactory<MidiInPortManager> weak_factory_;
643
644 DISALLOW_COPY_AND_ASSIGN(MidiInPortManager);
645};
646
647class MidiManagerWinrt::MidiOutPortManager final
648 : public MidiPortManager<IMidiOutPort,
649 IMidiOutPort,
650 IMidiOutPortStatics,
651 RuntimeClass_Windows_Devices_Midi_MidiOutPort> {
652 public:
653 MidiOutPortManager(MidiManagerWinrt* midi_manager)
654 : MidiPortManager(midi_manager), weak_factory_(this) {}
655
656 private:
657 // MidiPortManager overrides:
658 void AddPort(MidiPortInfo info) final { midi_manager_->AddOutputPort(info); }
659
660 void SetPortState(uint32_t port_index, MidiPortState state) final {
661 midi_manager_->SetOutputPortState(port_index, state);
662 }
663
664 base::WeakPtr<MidiPortManager> GetWeakPtrFromFactory() final {
665 DCHECK(thread_checker_.CalledOnValidThread());
666
667 return weak_factory_.GetWeakPtr();
668 }
669
670 // Last member to ensure destructed first.
671 base::WeakPtrFactory<MidiOutPortManager> weak_factory_;
672
673 DISALLOW_COPY_AND_ASSIGN(MidiOutPortManager);
674};
675
676MidiManagerWinrt::MidiManagerWinrt() : com_thread_("Windows MIDI COM Thread") {}
677
678MidiManagerWinrt::~MidiManagerWinrt() {
679 base::AutoLock auto_lock(lazy_init_member_lock_);
680
681 CHECK(!com_thread_checker_);
682 CHECK(!port_manager_in_);
683 CHECK(!port_manager_out_);
684 CHECK(!scheduler_);
685}
686
687void MidiManagerWinrt::StartInitialization() {
688 if (base::win::GetVersion() < base::win::VERSION_WIN10) {
689 VLOG(1) << "WinRT MIDI backend is only supported on Windows 10 or later.";
690 CompleteInitialization(Result::INITIALIZATION_ERROR);
691 return;
692 }
693
694 com_thread_.init_com_with_mta(true);
695 com_thread_.Start();
696
697 com_thread_.task_runner()->PostTask(
698 FROM_HERE, base::Bind(&MidiManagerWinrt::InitializeOnComThread,
699 base::Unretained(this)));
700}
701
702void MidiManagerWinrt::Finalize() {
703 com_thread_.task_runner()->PostTask(
704 FROM_HERE, base::Bind(&MidiManagerWinrt::FinalizeOnComThread,
705 base::Unretained(this)));
706
707 // Blocks until FinalizeOnComThread() returns. Delayed MIDI send data tasks
708 // will be ignored.
709 com_thread_.Stop();
710}
711
712void MidiManagerWinrt::DispatchSendMidiData(MidiManagerClient* client,
713 uint32_t port_index,
714 const std::vector<uint8_t>& data,
715 double timestamp) {
716 CHECK(scheduler_);
717
718 scheduler_->PostSendDataTask(
719 client, data.size(), timestamp,
720 base::Bind(&MidiManagerWinrt::SendOnComThread, base::Unretained(this),
721 port_index, data));
722}
723
724void MidiManagerWinrt::InitializeOnComThread() {
725 base::AutoLock auto_lock(lazy_init_member_lock_);
726
727 com_thread_checker_.reset(new base::ThreadChecker);
728
729 port_manager_in_.reset(new MidiInPortManager(this));
730 port_manager_out_.reset(new MidiOutPortManager(this));
731
732 scheduler_.reset(new MidiScheduler(this));
733
734 if (!(port_manager_in_->StartWatcher() &&
735 port_manager_out_->StartWatcher())) {
736 port_manager_in_->StopWatcher();
737 port_manager_out_->StopWatcher();
738 CompleteInitialization(Result::INITIALIZATION_ERROR);
739 }
740}
741
742void MidiManagerWinrt::FinalizeOnComThread() {
743 base::AutoLock auto_lock(lazy_init_member_lock_);
744
745 DCHECK(com_thread_checker_->CalledOnValidThread());
746
747 scheduler_.reset();
748
749 port_manager_in_->StopWatcher();
750 port_manager_in_.reset();
751 port_manager_out_->StopWatcher();
752 port_manager_out_.reset();
753
754 com_thread_checker_.reset();
755}
756
757void MidiManagerWinrt::SendOnComThread(uint32_t port_index,
758 const std::vector<uint8_t>& data) {
759 DCHECK(com_thread_checker_->CalledOnValidThread());
760
761 MidiPort<IMidiOutPort>* port = port_manager_out_->GetPortByIndex(port_index);
762 if (!(port && port->handle)) {
763 VLOG(1) << "Port not available: " << port_index;
764 return;
765 }
766
767 auto buffer_factory =
768 WrlStaticsFactory<IBufferFactory,
769 RuntimeClass_Windows_Storage_Streams_Buffer>();
770 if (!buffer_factory)
771 return;
772
773 ScopedComPtr<IBuffer> buffer;
774 HRESULT hr = buffer_factory->Create(static_cast<UINT32>(data.size()),
775 buffer.Receive());
776 if (FAILED(hr)) {
777 VLOG(1) << "Create failed: " << PrintHr(hr);
778 return;
779 }
780
781 hr = buffer->put_Length(static_cast<UINT32>(data.size()));
782 if (FAILED(hr)) {
783 VLOG(1) << "put_Length failed: " << PrintHr(hr);
784 return;
785 }
786
787 uint8_t* p_buffer_data = nullptr;
788 hr = GetPointerToBufferData(buffer.get(), &p_buffer_data);
789 if (FAILED(hr))
790 return;
791
792 std::copy(data.begin(), data.end(), p_buffer_data);
793
794 hr = port->handle->SendBuffer(buffer.get());
795 if (FAILED(hr)) {
796 VLOG(1) << "SendBuffer failed: " << PrintHr(hr);
797 return;
798 }
799}
800
801void MidiManagerWinrt::OnPortManagerReady() {
802 DCHECK(com_thread_checker_->CalledOnValidThread());
803 DCHECK(port_manager_ready_count_ < 2);
804
805 if (++port_manager_ready_count_ == 2)
806 CompleteInitialization(Result::OK);
807}
808
809MidiManager* MidiManager::Create() {
810 return new MidiManagerWinrt();
811}
812
813} // namespace midi
814} // namespace media