vhost_user_devices: Use async style
This should hopefully simplify some of the control flow and remove some
leaky abstractions.
BUG=none
TEST=`curl www.example.com` inside a vm with vhost-user-net
Cq-Depend: chromium:2893896
Change-Id: Ie22af368a2c0d92297e8a078c695d4015eae92d3
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2891123
Tested-by: kokoro <noreply+kokoro@google.com>
Commit-Queue: Chirantan Ekbote <chirantan@chromium.org>
Reviewed-by: Keiichi Watanabe <keiichiw@chromium.org>
diff --git a/Cargo.lock b/Cargo.lock
index 87c7e0f..0f436c9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -709,6 +709,12 @@
]
[[package]]
+name = "once_cell"
+version = "1.7.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af8b08b04175473088b46763e51ee54da5f9a164bc162f615b91bc179dbf15a3"
+
+[[package]]
name = "p9"
version = "0.1.0"
dependencies = [
@@ -1152,11 +1158,14 @@
version = "0.1.0"
dependencies = [
"base",
+ "cros_async",
"data_model",
"devices",
+ "futures",
"getopts",
"libc",
"net_util",
+ "once_cell",
"remain",
"sync",
"tempfile",
diff --git a/vhost_user_devices/Cargo.toml b/vhost_user_devices/Cargo.toml
index 4f76fbf..ee81e7e 100644
--- a/vhost_user_devices/Cargo.toml
+++ b/vhost_user_devices/Cargo.toml
@@ -17,11 +17,13 @@
[dependencies]
base = { path = "../base" }
+cros_async = { path = "../cros_async" }
data_model = { path = "../data_model", optional = true }
devices = { path = "../devices" }
getopts = { version = "0.2", optional = true }
libc = "*"
net_util = { path = "../net_util", optional = true }
+once_cell = "1.7.2"
remain = "*"
sync = { path = "../sync" }
thiserror = "*"
@@ -29,6 +31,10 @@
vm_memory = { path = "../vm_memory" }
vmm_vhost = { version = "*", features = ["vhost-user-slave"] }
+[dependencies.futures]
+version = "*"
+default-features = false
+
[dev-dependencies]
data_model = { path = "../data_model" }
tempfile = { path = "../tempfile" }
diff --git a/vhost_user_devices/src/lib.rs b/vhost_user_devices/src/lib.rs
index 1ee873f..e7d8d7c 100644
--- a/vhost_user_devices/src/lib.rs
+++ b/vhost_user_devices/src/lib.rs
@@ -9,9 +9,10 @@
//! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
//!
//! They are expected to be used as follows:
-//! 1. Define a struct which `VhostUserBackend` is implemented for.
-//! 2. Create an instance of `DeviceRequestHandler` with the backend and call its `start()` method
-//! to start an event loop.
+//!
+//! 1. Define a struct and implement `VhostUserBackend` for it.
+//! 2. Create a `DeviceRequestHandler` with the backend struct.
+//! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
//!
//! ```ignore
//! struct MyBackend {
@@ -22,30 +23,36 @@
//! /* implement methods */
//! }
//!
-//! fn main() {
+//! fn main() -> Result<(), Box<dyn Error>> {
//! let backend = MyBackend { /* initialize fields */ };
-//! let handler = DeviceRequestHandler::new(backend).unwrap();
+//! let handler = DeviceRequestHandler::new(backend);
//! let socket = std::path::Path("/path/to/socket");
+//! let ex = cros_async::Executor::new()?;
//!
-//! if let Err(e) = handler.start(socket) {
+//! if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
//! eprintln!("error happened: {}", e);
//! }
+//! Ok(())
//! }
//! ```
//!
+// Implementation note:
+// This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
+// protocol. DeviceRequestHandler implements the VhostUserSlaveReqHandlerMut trait from vmm_vhost,
+// and includes some common code for setting up guest memory and managing partially configured
+// vrings. DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request()
+// when it becomes readable. handle_request() reads and parses the message and then calls one of the
+// VhostUserSlaveReqHandlerMut trait methods. These dispatch back to the supplied VhostUserBackend
+// implementation (this is what our devices implement).
-use std::cell::RefCell;
use std::convert::TryFrom;
use std::num::Wrapping;
use std::os::unix::io::{AsRawFd, RawFd};
use std::path::Path;
-use std::rc::Rc;
use std::sync::Arc;
-use base::{
- error, AsRawDescriptor, Event, EventType, FromRawDescriptor, PollToken, SafeDescriptor,
- SharedMemory, SharedMemoryUnix, WaitContext,
-};
+use base::{error, Event, FromRawDescriptor, SafeDescriptor, SharedMemory, SharedMemoryUnix};
+use cros_async::{AsyncError, AsyncWrapper, Executor};
use devices::virtio::{Queue, SignalableInterrupt};
use remain::sorted;
use thiserror::Error as ThisError;
@@ -82,7 +89,7 @@
fn do_interrupt_resample(&self) {}
}
-/// Keeps a mpaaing from the vmm's virtual addresses to guest addresses.
+/// Keeps a mapping from the vmm's virtual addresses to guest addresses.
/// used to translate messages from the vmm to guest offsets.
#[derive(Default)]
struct MappingInfo {
@@ -104,21 +111,14 @@
pub trait VhostUserBackend
where
Self: Sized,
- Self::EventToken: PollToken + std::fmt::Debug,
Self::Error: std::error::Error + std::fmt::Debug,
{
const MAX_QUEUE_NUM: usize;
- const MAX_VRING_NUM: usize;
-
- /// Types of tokens that can be associated with polling events.
- type EventToken;
+ const MAX_VRING_LEN: u16;
/// Error type specific to this backend.
type Error;
- /// Translates a queue's index into `EventToken`.
- fn index_to_event_type(queue_index: usize) -> Option<Self::EventToken>;
-
/// The set of feature bits that this backend supports.
fn features(&self) -> u64;
@@ -140,30 +140,30 @@
/// Reads this device configuration space at `offset`.
fn read_config(&self, offset: u64, dst: &mut [u8]);
- /// Sets guest memory regions.
- fn set_guest_mem(&mut self, mem: GuestMemory);
-
- /// Returns a backend event to be waited for.
- fn backend_event(&self) -> Option<(&dyn AsRawDescriptor, EventType, Self::EventToken)>;
-
- /// Processes a given event.
- fn handle_event(
+ /// Indicates that the backend should start processing requests for virtio queue number `idx`.
+ /// This method must not block the current thread so device backends should either spawn an
+ /// async task or another thread to handle messages from the Queue.
+ fn start_queue(
&mut self,
- wait_ctx: &Rc<WaitContext<HandlerPollToken<Self>>>,
- event: &Self::EventToken,
- vrings: &[Rc<RefCell<Vring>>],
+ idx: usize,
+ queue: Queue,
+ mem: GuestMemory,
+ call_evt: CallEvent,
+ kick_evt: Event,
) -> std::result::Result<(), Self::Error>;
+ /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
+ fn stop_queue(&mut self, idx: usize);
+
/// Resets the vhost-user backend.
fn reset(&mut self);
}
/// A virtio ring entry.
-pub struct Vring {
- pub queue: Queue,
- pub call_evt: Option<Arc<CallEvent>>,
- pub kick_evt: Option<Event>,
- pub enabled: bool,
+struct Vring {
+ queue: Queue,
+ call_evt: Option<CallEvent>,
+ enabled: bool,
}
impl Vring {
@@ -171,7 +171,6 @@
Self {
queue: Queue::new(max_size),
call_evt: None,
- kick_evt: None,
enabled: false,
}
}
@@ -179,80 +178,51 @@
fn reset(&mut self) {
self.queue.reset();
self.call_evt = None;
- self.kick_evt = None;
self.enabled = false;
}
}
#[sorted]
#[derive(ThisError, Debug)]
-pub enum HandlerError<BackendError: std::error::Error> {
+pub enum HandlerError {
/// Failed to accept an incoming connection.
#[error("failed to accept an incoming connection: {0}")]
AcceptConnection(VhostError),
+ /// Failed to create an async source.
+ #[error("failed to create an async source: {0}")]
+ CreateAsyncSource(AsyncError),
/// Failed to create a connection listener.
#[error("failed to create a connection listener: {0}")]
CreateConnectionListener(VhostError),
/// Failed to create a UNIX domain socket listener.
#[error("failed to create a UNIX domain socket listener: {0}")]
CreateSocketListener(VhostError),
- /// Failed to handle a backend event.
- #[error("failed to handle a backend event: {0}")]
- HandleBackendEvent(BackendError),
/// Failed to handle a vhost-user request.
#[error("failed to handle a vhost-user request: {0}")]
HandleVhostUserRequest(VhostError),
/// Invalid queue index is given.
#[error("invalid queue index is given: {index}")]
InvalidQueueIndex { index: usize },
- /// Failed to add new FD(s) to wait context.
- #[error("failed to add new FD(s) to wait context: {0}")]
- WaitContextAdd(base::Error),
- /// Failed to create a wait context.
- #[error("failed to create a wait context: {0}")]
- WaitContextCreate(base::Error),
- /// Failed to delete a FD from wait context.
- #[error("failed to delete a FD from wait context: {0}")]
- WaitContextDel(base::Error),
- /// Failed to wait for event.
- #[error("failed to wait for an event triggered: {0}")]
- WaitContextWait(base::Error),
+ /// Failed to wait for the handler socket to become readable.
+ #[error("failed to wait for the handler socket to become readable: {0}")]
+ WaitForHandler(AsyncError),
+ /// Failed to wait for the listener socket to become readable.
+ #[error("failed to wait for the listener socket to become readable: {0}")]
+ WaitForListener(AsyncError),
}
-type HandlerResult<B, T> = std::result::Result<T, HandlerError<<B as VhostUserBackend>::Error>>;
-
-#[derive(Debug)]
-pub enum HandlerPollToken<B: VhostUserBackend> {
- BackendToken(B::EventToken),
- VhostUserRequest,
-}
-
-impl<B: VhostUserBackend> PollToken for HandlerPollToken<B> {
- fn as_raw_token(&self) -> u64 {
- match self {
- Self::BackendToken(t) => t.as_raw_token(),
- Self::VhostUserRequest => u64::MAX,
- }
- }
-
- fn from_raw_token(data: u64) -> Self {
- match data {
- u64::MAX => Self::VhostUserRequest,
- _ => Self::BackendToken(B::EventToken::from_raw_token(data)),
- }
- }
-}
+type HandlerResult<T> = std::result::Result<T, HandlerError>;
/// Structure to have an event loop for interaction between a VMM and `VhostUserBackend`.
pub struct DeviceRequestHandler<B>
where
B: 'static + VhostUserBackend,
{
+ vrings: Vec<Vring>,
owned: bool,
- vrings: Vec<Rc<RefCell<Vring>>>,
vmm_maps: Option<Vec<MappingInfo>>,
- backend: Rc<RefCell<B>>,
- wait_ctx: Rc<WaitContext<HandlerPollToken<B>>>,
+ mem: Option<GuestMemory>,
+ backend: B,
}
impl<B> DeviceRequestHandler<B>
@@ -260,88 +230,50 @@
B: 'static + VhostUserBackend,
{
/// Creates the handler instance for `backend`.
- pub fn new(backend: B) -> HandlerResult<B, Self> {
- let mut vrings = Vec::with_capacity(B::MAX_QUEUE_NUM as usize);
+ pub fn new(backend: B) -> Self {
+ let mut vrings = Vec::with_capacity(B::MAX_QUEUE_NUM);
for _ in 0..B::MAX_QUEUE_NUM {
- vrings.push(Rc::new(RefCell::new(Vring::new(B::MAX_VRING_NUM as u16))));
+ vrings.push(Vring::new(B::MAX_VRING_LEN as u16));
}
- let wait_ctx: WaitContext<HandlerPollToken<B>> =
- WaitContext::new().map_err(HandlerError::WaitContextCreate)?;
-
- if let Some((evt, typ, token)) = backend.backend_event() {
- wait_ctx
- .add_for_event(evt, typ, HandlerPollToken::BackendToken(token))
- .map_err(HandlerError::WaitContextAdd)?;
- }
-
- Ok(DeviceRequestHandler {
+ DeviceRequestHandler {
+ vrings,
owned: false,
vmm_maps: None,
- vrings,
- backend: Rc::new(RefCell::new(backend)),
- wait_ctx: Rc::new(wait_ctx),
- })
+ mem: None,
+ backend,
+ }
}
- /// Connects to `socket` and starts an event loop which handles incoming vhost-user requests from
- /// the VMM and events from the backend.
- // TODO(keiichiw): Remove the clippy annotation once we uprev clippy to 1.52.0 or later.
- // cf. https://github.com/rust-lang/rust-clippy/issues/6546
- #[allow(clippy::clippy::result_unit_err)]
- pub fn start<P: AsRef<Path>>(self, socket: P) -> HandlerResult<B, ()> {
- let vrings = self.vrings.clone();
- let backend = self.backend.clone();
- let wait_ctx = self.wait_ctx.clone();
-
- let listener = Listener::new(socket, true).map_err(HandlerError::CreateSocketListener)?;
- let mut s_listener = SlaveListener::new(listener, Arc::new(std::sync::Mutex::new(self)))
- .map_err(HandlerError::CreateConnectionListener)?;
-
- let mut req_handler = s_listener
+ /// Creates a listening socket at `socket` and handles incoming messages from the VMM, which are
+ /// dispatched to the device backend via the `VhostUserBackend` trait methods.
+ pub async fn run<P: AsRef<Path>>(self, socket: P, ex: &Executor) -> HandlerResult<()> {
+ let mut listener = Listener::new(socket, true)
+ .map_err(HandlerError::CreateSocketListener)
+ .and_then(|l| {
+ SlaveListener::new(l, Arc::new(std::sync::Mutex::new(self)))
+ .map_err(HandlerError::CreateConnectionListener)
+ })?;
+ let mut req_handler = listener
.accept()
.map_err(HandlerError::AcceptConnection)?
.expect("no incoming connection was detected");
- let sd = SafeDescriptor::try_from(&req_handler as &dyn AsRawFd)
+ let h = SafeDescriptor::try_from(&req_handler as &dyn AsRawFd)
+ .map(AsyncWrapper::new)
.expect("failed to get safe descriptor for handler");
- wait_ctx
- .add(&sd, HandlerPollToken::VhostUserRequest)
- .map_err(HandlerError::WaitContextAdd)?;
+ let handler_source = ex.async_from(h).map_err(HandlerError::CreateAsyncSource)?;
loop {
- let events = wait_ctx.wait().map_err(HandlerError::WaitContextWait)?;
- for event in events.iter() {
- match &event.token {
- HandlerPollToken::BackendToken(token) => {
- backend
- .borrow_mut()
- .handle_event(&wait_ctx, &token, &vrings)
- .map_err(HandlerError::HandleBackendEvent)?;
- }
- HandlerPollToken::VhostUserRequest => {
- req_handler
- .handle_request()
- .map_err(HandlerError::HandleVhostUserRequest)?;
- }
- }
- }
+ handler_source
+ .wait_readable()
+ .await
+ .map_err(HandlerError::WaitForHandler)?;
+ req_handler
+ .handle_request()
+ .map_err(HandlerError::HandleVhostUserRequest)?;
}
}
-
- fn register_kickfd(&self, index: usize, event: &Event) -> HandlerResult<B, ()> {
- let token =
- B::index_to_event_type(index).ok_or(HandlerError::InvalidQueueIndex { index })?;
- self.wait_ctx
- .add(&event.0, HandlerPollToken::BackendToken(token))
- .map_err(HandlerError::WaitContextAdd)
- }
-
- fn unregister_kickfd(&self, event: &Event) -> HandlerResult<B, ()> {
- self.wait_ctx
- .delete(&event.0)
- .map_err(HandlerError::WaitContextDel)
- }
}
impl<B: VhostUserBackend> VhostUserSlaveReqHandlerMut for DeviceRequestHandler<B> {
@@ -355,12 +287,12 @@
fn reset_owner(&mut self) -> VhostResult<()> {
self.owned = false;
- self.backend.borrow_mut().reset();
+ self.backend.reset();
Ok(())
}
fn get_features(&mut self) -> VhostResult<u64> {
- let features = self.backend.borrow().features();
+ let features = self.backend.features();
Ok(features)
}
@@ -369,11 +301,11 @@
return Err(VhostError::InvalidOperation);
}
- if (features & !(self.backend.borrow().features())) != 0 {
+ if (features & !(self.backend.features())) != 0 {
return Err(VhostError::InvalidParam);
}
- if let Err(e) = self.backend.borrow_mut().ack_features(features) {
+ if let Err(e) = self.backend.ack_features(features) {
error!("failed to acknowledge features 0x{:x}: {}", features, e);
return Err(VhostError::InvalidOperation);
}
@@ -385,22 +317,21 @@
// Client must not pass data to/from the backend until ring is enabled by
// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
// VHOST_USER_SET_VRING_ENABLE with parameter 0.
- let acked_features = self.backend.borrow().acked_features();
+ let acked_features = self.backend.acked_features();
let vring_enabled = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() & acked_features != 0;
for v in &mut self.vrings {
- let mut vring = v.borrow_mut();
- vring.enabled = vring_enabled;
+ v.enabled = vring_enabled;
}
Ok(())
}
fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
- Ok(self.backend.borrow().protocol_features())
+ Ok(self.backend.protocol_features())
}
fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
- if let Err(e) = self.backend.borrow_mut().ack_protocol_features(features) {
+ if let Err(e) = self.backend.ack_protocol_features(features) {
error!("failed to set protocol features 0x{:x}: {}", features, e);
return Err(VhostError::InvalidOperation);
}
@@ -451,8 +382,7 @@
})
.collect();
- self.backend.borrow_mut().set_guest_mem(guest_mem);
-
+ self.mem = Some(guest_mem);
self.vmm_maps = Some(vmm_maps);
Ok(())
}
@@ -462,11 +392,10 @@
}
fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
- if index as usize >= self.vrings.len() || num == 0 || num as usize > B::MAX_VRING_NUM {
+ if index as usize >= self.vrings.len() || num == 0 || num > B::MAX_VRING_LEN.into() {
return Err(VhostError::InvalidParam);
}
- let mut vring = self.vrings[index as usize].borrow_mut();
- vring.queue.size = num as u16;
+ self.vrings[index as usize].queue.size = num as u16;
Ok(())
}
@@ -485,7 +414,7 @@
}
let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
- let mut vring = self.vrings[index as usize].borrow_mut();
+ let vring = &mut self.vrings[index as usize];
vring.queue.desc_table = vmm_va_to_gpa(&vmm_maps, descriptor)?;
vring.queue.avail_ring = vmm_va_to_gpa(&vmm_maps, available)?;
vring.queue.used_ring = vmm_va_to_gpa(&vmm_maps, used)?;
@@ -494,11 +423,11 @@
}
fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
- if index as usize >= self.vrings.len() || base as usize >= B::MAX_VRING_NUM {
+ if index as usize >= self.vrings.len() || base >= B::MAX_VRING_LEN.into() {
return Err(VhostError::InvalidParam);
}
- let mut vring = self.vrings[index as usize].borrow_mut();
+ let vring = &mut self.vrings[index as usize];
vring.queue.next_avail = Wrapping(base as u16);
vring.queue.next_used = Wrapping(base as u16);
@@ -515,11 +444,10 @@
// that file descriptor is readable) on the descriptor specified by
// VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
// VHOST_USER_GET_VRING_BASE.
- let mut vring = self.vrings[index as usize].borrow_mut();
+ self.backend.stop_queue(index as usize);
+
+ let vring = &mut self.vrings[index as usize];
vring.reset();
- if let Some(kick) = &vring.kick_evt {
- self.unregister_kickfd(kick).expect("unregister_kickfd");
- }
Ok(VhostUserVringState::new(
index,
@@ -541,14 +469,26 @@
VhostError::InvalidParam
})?;
// Safe because the FD is now owned.
- let kick = unsafe { Event::from_raw_descriptor(rd) };
+ let kick_evt = unsafe { Event::from_raw_descriptor(rd) };
- self.register_kickfd(index as usize, &kick)
- .expect("register_kickfd");
-
- let mut vring = self.vrings[index as usize].borrow_mut();
- vring.kick_evt = Some(kick);
+ let vring = &mut self.vrings[index as usize];
vring.queue.ready = true;
+
+ let queue = vring.queue.clone();
+ let call_evt = vring.call_evt.take().ok_or(VhostError::InvalidOperation)?;
+ let mem = self
+ .mem
+ .as_ref()
+ .cloned()
+ .ok_or(VhostError::InvalidOperation)?;
+
+ if let Err(e) = self
+ .backend
+ .start_queue(index as usize, queue, mem, call_evt, kick_evt)
+ {
+ error!("Failed to start queue {}: {}", index, e);
+ return Err(VhostError::SlaveInternalError);
+ }
}
Ok(())
}
@@ -565,7 +505,7 @@
})?;
// Safe because the FD is now owned.
let call = unsafe { Event::from_raw_descriptor(rd) };
- self.vrings[index as usize].borrow_mut().call_evt = Some(Arc::new(CallEvent(call)));
+ self.vrings[index as usize].call_evt = Some(CallEvent(call));
}
Ok(())
@@ -583,10 +523,7 @@
// This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
// has been negotiated.
- if self.backend.borrow().acked_features()
- & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits()
- == 0
- {
+ if self.backend.acked_features() & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
return Err(VhostError::InvalidOperation);
}
@@ -594,8 +531,7 @@
// enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
// or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
// with parameter 0.
- let mut vring = self.vrings[index as usize].borrow_mut();
- vring.enabled = enable;
+ self.vrings[index as usize].enabled = enable;
Ok(())
}
@@ -611,9 +547,7 @@
}
let mut data = vec![0; size as usize];
- self.backend
- .borrow()
- .read_config(u64::from(offset), &mut data);
+ self.backend.read_config(u64::from(offset), &mut data);
Ok(data)
}
@@ -663,11 +597,6 @@
use tempfile::{Builder, TempDir};
use vmm_vhost::vhost_user::Master;
- #[derive(PollToken, Debug)]
- enum FakeToken {
- Queue0,
- }
-
#[derive(ThisError, Debug)]
enum FakeError {
#[error("invalid features are given: 0x{features:x}")]
@@ -688,7 +617,6 @@
const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };
struct FakeBackend {
- mem: Option<GuestMemory>,
avail_features: u64,
acked_features: u64,
acked_protocol_features: VhostUserProtocolFeatures,
@@ -697,7 +625,6 @@
impl FakeBackend {
fn new() -> Self {
Self {
- mem: None,
avail_features: VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(),
acked_features: 0,
acked_protocol_features: VhostUserProtocolFeatures::empty(),
@@ -707,18 +634,10 @@
impl VhostUserBackend for FakeBackend {
const MAX_QUEUE_NUM: usize = 16;
- const MAX_VRING_NUM: usize = 256;
+ const MAX_VRING_LEN: u16 = 256;
- type EventToken = FakeToken;
type Error = FakeError;
- fn index_to_event_type(queue_index: usize) -> Option<Self::EventToken> {
- match queue_index {
- 0 => Some(FakeToken::Queue0),
- _ => None,
- }
- }
-
fn features(&self) -> u64 {
self.avail_features
}
@@ -738,10 +657,6 @@
self.acked_features
}
- fn set_guest_mem(&mut self, mem: GuestMemory) {
- self.mem = Some(mem);
- }
-
fn protocol_features(&self) -> VhostUserProtocolFeatures {
VhostUserProtocolFeatures::CONFIG
}
@@ -758,24 +673,24 @@
self.acked_protocol_features.bits()
}
- fn backend_event(&self) -> Option<(&dyn AsRawDescriptor, EventType, Self::EventToken)> {
- None
- }
-
- fn handle_event(
- &mut self,
- _wait_ctx: &Rc<WaitContext<HandlerPollToken<Self>>>,
- _event: &Self::EventToken,
- _vrings: &[Rc<RefCell<Vring>>],
- ) -> std::result::Result<(), Self::Error> {
- Ok(())
- }
-
fn read_config(&self, offset: u64, dst: &mut [u8]) {
dst.copy_from_slice(&FAKE_CONFIG_DATA.as_slice()[offset as usize..]);
}
fn reset(&mut self) {}
+
+ fn start_queue(
+ &mut self,
+ _idx: usize,
+ _queue: Queue,
+ _mem: GuestMemory,
+ _call_evt: CallEvent,
+ _kick_evt: Event,
+ ) -> std::result::Result<(), Self::Error> {
+ Ok(())
+ }
+
+ fn stop_queue(&mut self, _idx: usize) {}
}
fn temp_dir() -> TempDir {
@@ -836,9 +751,9 @@
});
// Device side
- let handler = Arc::new(std::sync::Mutex::new(
- DeviceRequestHandler::new(FakeBackend::new()).unwrap(),
- ));
+ let handler = Arc::new(std::sync::Mutex::new(DeviceRequestHandler::new(
+ FakeBackend::new(),
+ )));
let mut listener = SlaveListener::new(listener, handler).unwrap();
// Notify listener is ready.
diff --git a/vhost_user_devices/src/net.rs b/vhost_user_devices/src/net.rs
index 6e0a39f..fc9012e 100644
--- a/vhost_user_devices/src/net.rs
+++ b/vhost_user_devices/src/net.rs
@@ -2,60 +2,104 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-use std::cell::RefCell;
use std::net::Ipv4Addr;
-use std::rc::Rc;
use std::str::FromStr;
-use base::{error, AsRawDescriptor, EventType, WaitContext};
+use base::{error, warn, Event};
+use cros_async::{AsyncError, EventAsync, Executor, IoSourceExt};
use data_model::DataInit;
use devices::virtio;
use devices::virtio::net::{
build_config, process_rx, process_tx, validate_and_configure_tap,
- virtio_features_to_tap_offload, NetError, Token,
+ virtio_features_to_tap_offload, NetError,
};
use devices::ProtectionType;
+use futures::future::{AbortHandle, Abortable};
use getopts::Options;
-use net_util::{MacAddress, Tap, TapT};
+use net_util::{Error as NetUtilError, MacAddress, Tap, TapT};
+use once_cell::sync::OnceCell;
use remain::sorted;
use thiserror::Error as ThisError;
-use vhost_user_devices::{DeviceRequestHandler, HandlerPollToken, VhostUserBackend, Vring};
+use vhost_user_devices::{CallEvent, DeviceRequestHandler, VhostUserBackend};
use virtio_sys::virtio_net;
use vm_memory::GuestMemory;
use vmm_vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
+static NET_EXECUTOR: OnceCell<Executor> = OnceCell::new();
+
#[sorted]
#[derive(ThisError, Debug)]
enum Error {
+ #[error("failed to clone tap device: {0}")]
+ CloneTap(NetUtilError),
+ #[error("failed to create async tap device: {0}")]
+ CreateAsyncTap(AsyncError),
+ #[error("failed to create EventAsync: {0}")]
+ CreateEventAsync(AsyncError),
#[error("invalid features are given: 0x{features:x}")]
InvalidFeatures { features: u64 },
#[error("invalid protocol features are given: 0x{features:x}")]
InvalidProtocolFeatures { features: u64 },
- #[error("call event is not set for vring {index}")]
- NoCallEvent { index: usize },
- #[error("guest memory is not set for vring {index}")]
- NoGuestMemory { index: usize },
- #[error("kill event is not set for vring {index}")]
- NoKillEvent { index: usize },
- #[error("failed to process rx queue: {0}")]
- ProcessRx(NetError),
- #[error("failed to read kick event for vring {index}: {err}")]
- ReadKickEvent { index: usize, err: base::Error },
+ #[error("guest memory is not set for queue {idx}")]
+ NoGuestMemory { idx: usize },
#[error("failed to set tap offload to match acked features: {0}")]
TapOffload(net_util::Error),
- #[error("unexpected token is given: {0:?}")]
- UnexpectedToken(Token),
- #[error("failed to modify wait context: {0}")]
- WaitCtxModify(base::Error),
+ #[error("attempted to start unknown queue: {0}")]
+ UnknownQueue(usize),
+}
+
+async fn run_tx_queue(
+ mut queue: virtio::Queue,
+ mem: GuestMemory,
+ mut tap: Tap,
+ call_evt: CallEvent,
+ kick_evt: EventAsync,
+) {
+ loop {
+ if let Err(e) = kick_evt.next_val().await {
+ error!("Failed to read kick event for tx queue: {}", e);
+ break;
+ }
+
+ process_tx(&call_evt, &mut queue, &mem, &mut tap);
+ }
+}
+
+async fn run_rx_queue(
+ mut queue: virtio::Queue,
+ mem: GuestMemory,
+ mut tap: Box<dyn IoSourceExt<Tap>>,
+ call_evt: CallEvent,
+ kick_evt: EventAsync,
+) {
+ loop {
+ if let Err(e) = tap.wait_readable().await {
+ error!("Failed to wait for tap device to become readable: {}", e);
+ break;
+ }
+
+ match process_rx(&call_evt, &mut queue, &mem, tap.as_source_mut()) {
+ Ok(()) => {}
+ Err(NetError::RxDescriptorsExhausted) => {
+ if let Err(e) = kick_evt.next_val().await {
+ error!("Failed to read kick event for rx queue: {}", e);
+ break;
+ }
+ }
+ Err(e) => {
+ error!("Failed to process rx queue: {}", e);
+ break;
+ }
+ }
+ }
}
struct NetBackend {
tap: Tap,
- mem: Option<GuestMemory>,
- tap_polling_enabled: bool,
avail_features: u64,
acked_features: u64,
acked_protocol_features: VhostUserProtocolFeatures,
+ workers: [Option<AbortHandle>; Self::MAX_QUEUE_NUM],
}
impl NetBackend {
@@ -84,11 +128,10 @@
Self {
tap,
- mem: None,
avail_features,
acked_features: 0,
acked_protocol_features: VhostUserProtocolFeatures::empty(),
- tap_polling_enabled: false,
+ workers: Default::default(),
}
}
@@ -100,19 +143,10 @@
impl VhostUserBackend for NetBackend {
// TODO(keiichiw): Support multiple queue pairs.
const MAX_QUEUE_NUM: usize = 2; /* 1 rx and 1 tx */
- const MAX_VRING_NUM: usize = 256;
+ const MAX_VRING_LEN: u16 = 256;
- type EventToken = Token;
type Error = Error;
- fn index_to_event_type(queue_index: usize) -> Option<Self::EventToken> {
- match queue_index {
- 0 => Some(Token::RxQueue),
- 1 => Some(Token::TxQueue),
- _ => None,
- }
- }
-
fn features(&self) -> u64 {
self.avail_features
}
@@ -138,10 +172,6 @@
self.acked_features
}
- fn set_guest_mem(&mut self, mem: GuestMemory) {
- self.mem = Some(mem);
- }
-
fn protocol_features(&self) -> VhostUserProtocolFeatures {
// TODO(keiichiw): Support MQ.
VhostUserProtocolFeatures::CONFIG
@@ -159,126 +189,61 @@
self.acked_protocol_features.bits()
}
- fn backend_event(&self) -> Option<(&dyn AsRawDescriptor, EventType, Self::EventToken)> {
- Some((
- &self.tap as &dyn AsRawDescriptor,
- EventType::None,
- Token::RxTap,
- ))
- }
-
- fn handle_event(
- &mut self,
- wait_ctx: &Rc<WaitContext<HandlerPollToken<Self>>>,
- token: &Self::EventToken,
- vrings: &[Rc<RefCell<Vring>>],
- ) -> std::result::Result<(), Error> {
- match token {
- Token::RxTap => {
- let index = 0;
- let mut vring = vrings[index].borrow_mut();
-
- if !vring.enabled {
- return Ok(());
- }
-
- let Vring {
- ref mut queue,
- ref call_evt,
- ..
- } = *vring;
-
- let call_evt = call_evt
- .as_ref()
- .ok_or(Error::NoCallEvent { index })?
- .as_ref();
-
- let guest_mem = self.mem.as_ref().ok_or(Error::NoGuestMemory { index })?;
-
- match process_rx(call_evt, queue, &guest_mem, &mut self.tap) {
- Ok(()) => Ok(()),
- Err(NetError::RxDescriptorsExhausted) => {
- wait_ctx
- .modify(
- &self.tap,
- EventType::None,
- HandlerPollToken::BackendToken(Token::RxTap),
- )
- .map_err(Error::WaitCtxModify)?;
- self.tap_polling_enabled = false;
-
- Ok(())
- }
- Err(e) => Err(Error::ProcessRx(e)),
- }
- }
- Token::RxQueue => {
- let index = 0;
- let vring = vrings[index].borrow();
- if !vring.enabled {
- return Ok(());
- }
-
- let kick_evt = vring
- .kick_evt
- .as_ref()
- .ok_or(Error::NoKillEvent { index })?;
- kick_evt
- .read()
- .map_err(|err| Error::ReadKickEvent { index, err })?;
-
- if !self.tap_polling_enabled {
- wait_ctx
- .modify(
- &self.tap,
- EventType::Read,
- HandlerPollToken::BackendToken(Token::RxTap),
- )
- .map_err(Error::WaitCtxModify)?;
- self.tap_polling_enabled = true;
- }
- Ok(())
- }
- Token::TxQueue => {
- let index = 1;
- let mut vring = vrings[index].borrow_mut();
-
- if !vring.enabled {
- return Ok(());
- }
-
- let Vring {
- ref mut queue,
- ref call_evt,
- ref kick_evt,
- ..
- } = *vring;
-
- let call_evt = call_evt
- .as_ref()
- .ok_or(Error::NoCallEvent { index })?
- .as_ref();
-
- let kick_evt = kick_evt.as_ref().ok_or(Error::NoKillEvent { index })?;
- if let Err(e) = kick_evt.read() {
- error!("error reading tx queue Event: {}", e);
- }
-
- let guest_mem = self.mem.as_ref().ok_or(Error::NoGuestMemory { index })?;
-
- process_tx(call_evt, queue, &guest_mem, &mut self.tap);
- Ok(())
- }
- token => Err(Error::UnexpectedToken(token.clone())),
- }
- }
-
fn read_config(&self, offset: u64, data: &mut [u8]) {
let config_space = build_config(Self::max_vq_pairs() as u16);
virtio::copy_config(data, 0, config_space.as_slice(), offset);
}
fn reset(&mut self) {}
+
+ fn start_queue(
+ &mut self,
+ idx: usize,
+ queue: virtio::Queue,
+ mem: GuestMemory,
+ call_evt: CallEvent,
+ kick_evt: Event,
+ ) -> std::result::Result<(), Self::Error> {
+ if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
+ warn!("Starting new queue handler without stopping old handler");
+ handle.abort();
+ }
+
+ // Safe because the executor is initialized in main() below.
+ let ex = NET_EXECUTOR.get().expect("Executor not initialized");
+
+ let kick_evt = EventAsync::new(kick_evt.0, ex).map_err(Error::CreateEventAsync)?;
+ let tap = self.tap.try_clone().map_err(Error::CloneTap)?;
+ let (handle, registration) = AbortHandle::new_pair();
+ match idx {
+ 0 => {
+ let tap = ex.async_from(tap).map_err(Error::CreateAsyncTap)?;
+
+ ex.spawn_local(Abortable::new(
+ run_rx_queue(queue, mem, tap, call_evt, kick_evt),
+ registration,
+ ))
+ .detach();
+ }
+ 1 => {
+ ex.spawn_local(Abortable::new(
+ run_tx_queue(queue, mem, tap, call_evt, kick_evt),
+ registration,
+ ))
+ .detach();
+ }
+ _ => return Err(Error::UnknownQueue(idx)),
+ }
+
+ self.workers[idx] = Some(handle);
+ Ok(())
+ }
+
+ fn stop_queue(&mut self, idx: usize) {
+ if let Some(handle) = self.workers.get_mut(idx).and_then(Option::take) {
+ handle.abort();
+ }
+ }
}
struct TapConfig {
@@ -359,9 +324,13 @@
}
};
+ let ex = Executor::new().expect("Failed to create executor");
+ let _ = NET_EXECUTOR.set(ex.clone());
+
let net = NetBackend::new(tap_cfg.host_ip, tap_cfg.netmask, tap_cfg.mac);
- let handler = DeviceRequestHandler::new(net).expect("new handler");
- if let Err(e) = handler.start(socket) {
+ let handler = DeviceRequestHandler::new(net);
+
+ if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
error!("error occurred: {}", e);
}
}