summaryrefslogtreecommitdiff
path: root/rust/kernel/kasync/net.rs
diff options
context:
space:
mode:
Diffstat (limited to 'rust/kernel/kasync/net.rs')
-rw-r--r--rust/kernel/kasync/net.rs322
1 files changed, 322 insertions, 0 deletions
diff --git a/rust/kernel/kasync/net.rs b/rust/kernel/kasync/net.rs
new file mode 100644
index 000000000000..b4bbeffad94a
--- /dev/null
+++ b/rust/kernel/kasync/net.rs
@@ -0,0 +1,322 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Async networking.
+
+use crate::{bindings, error::code::*, net, sync::NoWaitLock, types::Opaque, Result};
+use core::{
+ future::Future,
+ marker::{PhantomData, PhantomPinned},
+ ops::Deref,
+ pin::Pin,
+ task::{Context, Poll, Waker},
+};
+
+/// A socket listening on a TCP port.
+///
+/// The [`TcpListener::accept`] method is meant to be used in async contexts.
+pub struct TcpListener {
+ listener: net::TcpListener,
+}
+
+impl TcpListener {
+ /// Creates a new TCP listener.
+ ///
+ /// It is configured to listen on the given socket address for the given namespace.
+ pub fn try_new(ns: &net::Namespace, addr: &net::SocketAddr) -> Result<Self> {
+ Ok(Self {
+ listener: net::TcpListener::try_new(ns, addr)?,
+ })
+ }
+
+ /// Accepts a new connection.
+ ///
+ /// Returns a future that when ready indicates the result of the accept operation; on success,
+ /// it contains the newly-accepted tcp stream.
+ pub fn accept(&self) -> impl Future<Output = Result<TcpStream>> + '_ {
+ SocketFuture::from_listener(
+ self,
+ bindings::BINDINGS_EPOLLIN | bindings::BINDINGS_EPOLLERR,
+ || {
+ Ok(TcpStream {
+ stream: self.listener.accept(false)?,
+ })
+ },
+ )
+ }
+}
+
+impl Deref for TcpListener {
+ type Target = net::TcpListener;
+
+ fn deref(&self) -> &Self::Target {
+ &self.listener
+ }
+}
+
+/// A connected TCP socket.
+///
+/// The potentially blocking methods (e.g., [`TcpStream::read`], [`TcpStream::write`]) are meant
+/// to be used in async contexts.
+///
+/// # Examples
+///
+/// ```
+/// # use kernel::prelude::*;
+/// # use kernel::kasync::net::TcpStream;
+/// async fn echo_server(stream: TcpStream) -> Result {
+/// let mut buf = [0u8; 1024];
+/// loop {
+/// let n = stream.read(&mut buf).await?;
+/// if n == 0 {
+/// return Ok(());
+/// }
+/// stream.write_all(&buf[..n]).await?;
+/// }
+/// }
+/// ```
+pub struct TcpStream {
+ stream: net::TcpStream,
+}
+
+impl TcpStream {
+ /// Reads data from a connected socket.
+ ///
+ /// Returns a future that when ready indicates the result of the read operation; on success, it
+ /// contains the number of bytes read, which will be zero if the connection is closed.
+ pub fn read<'a>(&'a self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize>> + 'a {
+ SocketFuture::from_stream(
+ self,
+ bindings::BINDINGS_EPOLLIN | bindings::BINDINGS_EPOLLHUP | bindings::BINDINGS_EPOLLERR,
+ || self.stream.read(buf, false),
+ )
+ }
+
+ /// Writes data to the connected socket.
+ ///
+ /// Returns a future that when ready indicates the result of the write operation; on success, it
+ /// contains the number of bytes written.
+ pub fn write<'a>(&'a self, buf: &'a [u8]) -> impl Future<Output = Result<usize>> + 'a {
+ SocketFuture::from_stream(
+ self,
+ bindings::BINDINGS_EPOLLOUT | bindings::BINDINGS_EPOLLHUP | bindings::BINDINGS_EPOLLERR,
+ || self.stream.write(buf, false),
+ )
+ }
+
+ /// Writes all the data to the connected socket.
+ ///
+ /// Returns a future that when ready indicates the result of the write operation; on success, it
+ /// has written all the data.
+ pub async fn write_all<'a>(&'a self, buf: &'a [u8]) -> Result {
+ let mut rem = buf;
+
+ while !rem.is_empty() {
+ let n = self.write(rem).await?;
+ rem = &rem[n..];
+ }
+
+ Ok(())
+ }
+}
+
+impl Deref for TcpStream {
+ type Target = net::TcpStream;
+
+ fn deref(&self) -> &Self::Target {
+ &self.stream
+ }
+}
+
+/// A future for a socket operation.
+///
+/// # Invariants
+///
+/// `sock` is always non-null and valid for the duration of the lifetime of the instance.
+struct SocketFuture<'a, Out, F: FnMut() -> Result<Out> + Send + 'a> {
+ sock: *mut bindings::socket,
+ mask: u32,
+ is_queued: bool,
+ wq_entry: Opaque<bindings::wait_queue_entry>,
+ waker: NoWaitLock<Option<Waker>>,
+ _p: PhantomData<&'a ()>,
+ _pin: PhantomPinned,
+ operation: F,
+}
+
+// SAFETY: A kernel socket can be used from any thread, `wq_entry` is only used on drop and when
+// `is_queued` is initially `false`.
+unsafe impl<Out, F: FnMut() -> Result<Out> + Send> Send for SocketFuture<'_, Out, F> {}
+
+impl<'a, Out, F: FnMut() -> Result<Out> + Send + 'a> SocketFuture<'a, Out, F> {
+ /// Creates a new socket future.
+ ///
+ /// # Safety
+ ///
+ /// Callers must ensure that `sock` is non-null, valid, and remains valid for the lifetime
+ /// (`'a`) of the returned instance.
+ unsafe fn new(sock: *mut bindings::socket, mask: u32, operation: F) -> Self {
+ Self {
+ sock,
+ mask,
+ is_queued: false,
+ wq_entry: Opaque::uninit(),
+ waker: NoWaitLock::new(None),
+ operation,
+ _p: PhantomData,
+ _pin: PhantomPinned,
+ }
+ }
+
+ /// Creates a new socket future for a tcp listener.
+ fn from_listener(listener: &'a TcpListener, mask: u32, operation: F) -> Self {
+ // SAFETY: The socket is guaranteed to remain valid because it is bound to the reference to
+ // the listener (whose existence guarantees the socket remains valid).
+ unsafe { Self::new(listener.listener.sock, mask, operation) }
+ }
+
+ /// Creates a new socket future for a tcp stream.
+ fn from_stream(stream: &'a TcpStream, mask: u32, operation: F) -> Self {
+ // SAFETY: The socket is guaranteed to remain valid because it is bound to the reference to
+ // the stream (whose existence guarantees the socket remains valid).
+ unsafe { Self::new(stream.stream.sock, mask, operation) }
+ }
+
+ /// Callback called when the socket changes state.
+ ///
+ /// If the state matches the one we're waiting on, we wake up the task so that the future can be
+ /// polled again.
+ unsafe extern "C" fn wake_callback(
+ wq_entry: *mut bindings::wait_queue_entry,
+ _mode: core::ffi::c_uint,
+ _flags: core::ffi::c_int,
+ key: *mut core::ffi::c_void,
+ ) -> core::ffi::c_int {
+ let mask = key as u32;
+
+ // SAFETY: The future is valid while this callback is called because we remove from the
+ // queue on drop.
+ //
+ // There is a potential soundness issue here because we're generating a shared reference to
+ // `Self` while `Self::poll` has a mutable (unique) reference. However, for `!Unpin` types
+ // (like `Self`), `&mut T` is treated as `*mut T` per
+ // <https://github.com/rust-lang/rust/issues/63818> -- so we avoid the unsoundness. Once a
+ // more definitive solution is available, we can change this to use it.
+ let s = unsafe { &*crate::container_of!(wq_entry, Self, wq_entry) };
+ if mask & s.mask == 0 {
+ // Nothing to do as this notification doesn't interest us.
+ return 0;
+ }
+
+ // If we can't acquire the waker lock, the waker is in the process of being modified. Our
+ // attempt to acquire the lock will be reported to the lock owner, so it will trigger the
+ // wake up.
+ if let Some(guard) = s.waker.try_lock() {
+ if let Some(ref w) = *guard {
+ let cloned = w.clone();
+ drop(guard);
+ cloned.wake();
+ return 1;
+ }
+ }
+ 0
+ }
+
+ /// Poll the future once.
+ ///
+ /// It calls the operation and converts `EAGAIN` errors into a pending state.
+ fn poll_once(self: Pin<&mut Self>) -> Poll<Result<Out>> {
+ // SAFETY: We never move out of `this`.
+ let this = unsafe { self.get_unchecked_mut() };
+ match (this.operation)() {
+ Ok(s) => Poll::Ready(Ok(s)),
+ Err(e) => {
+ if e == EAGAIN {
+ Poll::Pending
+ } else {
+ Poll::Ready(Err(e))
+ }
+ }
+ }
+ }
+
+ /// Updates the waker stored in the future.
+ ///
+ /// It automatically triggers a wake up on races with the reactor.
+ fn set_waker(&self, waker: &Waker) {
+ if let Some(mut guard) = self.waker.try_lock() {
+ let old = core::mem::replace(&mut *guard, Some(waker.clone()));
+ let contention = guard.unlock();
+ drop(old);
+ if !contention {
+ return;
+ }
+ }
+
+ // We either couldn't store the waker because the existing one is being awakened, or the
+ // reactor tried to acquire the lock while we held it (contention). In either case, we just
+ // wake it up to ensure we don't miss any notification.
+ waker.wake_by_ref();
+ }
+}
+
+impl<Out, F: FnMut() -> Result<Out> + Send> Future for SocketFuture<'_, Out, F> {
+ type Output = Result<Out>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ match self.as_mut().poll_once() {
+ Poll::Ready(r) => Poll::Ready(r),
+ Poll::Pending => {
+ // Store away the latest waker every time we may `Pending`.
+ self.set_waker(cx.waker());
+ if self.is_queued {
+ // Nothing else to do was the waiter is already queued.
+ return Poll::Pending;
+ }
+
+ // SAFETY: We never move out of `this`.
+ let this = unsafe { self.as_mut().get_unchecked_mut() };
+
+ this.is_queued = true;
+
+ // SAFETY: `wq_entry` is valid for write.
+ unsafe {
+ bindings::init_waitqueue_func_entry(
+ this.wq_entry.get(),
+ Some(Self::wake_callback),
+ )
+ };
+
+ // SAFETY: `wq_entry` was just initialised above and is valid for read/write.
+ // By the type invariants, the socket is always valid.
+ unsafe {
+ bindings::add_wait_queue(
+ core::ptr::addr_of_mut!((*this.sock).wq.wait),
+ this.wq_entry.get(),
+ )
+ };
+
+ // If the future wasn't queued yet, we need to poll again in case it reached
+ // the desired state between the last poll and being queued (in which case we
+ // would have missed the notification).
+ self.poll_once()
+ }
+ }
+ }
+}
+
+impl<Out, F: FnMut() -> Result<Out> + Send> Drop for SocketFuture<'_, Out, F> {
+ fn drop(&mut self) {
+ if !self.is_queued {
+ return;
+ }
+
+ // SAFETY: `wq_entry` is initialised because `is_queued` is set to `true`, so it is valid
+ // for read/write. By the type invariants, the socket is always valid.
+ unsafe {
+ bindings::remove_wait_queue(
+ core::ptr::addr_of_mut!((*self.sock).wq.wait),
+ self.wq_entry.get(),
+ )
+ };
+ }
+}