// blob: ef718d7113676d6a5889bc261a29df9a1d406ec5
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
use std::{io::Write, num::Wrapping};
use vm_memory::{bitmap::BitmapSlice, VolatileSlice};
use crate::vhu_vsock::{Error, Result};
/// Ring buffer that temporarily holds tx data until it can be flushed to a
/// stream.
///
/// `head` and `tail` are free-running wrapping counters; they are reduced
/// modulo the buffer length whenever they are used to index `buf`, and the
/// amount of buffered data is `tail - head`.
#[derive(Debug)]
pub(crate) struct LocalTxBuf {
/// Buffer holding data to be forwarded to a host-side application
buf: Vec<u8>,
/// Read counter: position from which buffered data is consumed (advanced
/// when data is flushed)
head: Wrapping<u32>,
/// Write counter: position at which new data is added (advanced when data
/// is pushed)
tail: Wrapping<u32>,
}
impl LocalTxBuf {
/// Create a new instance of LocalTxBuf.
pub fn new(buf_size: u32) -> Self {
Self {
buf: vec![0; buf_size as usize],
head: Wrapping(0),
tail: Wrapping(0),
}
}
/// Get the buffer size
pub fn get_buf_size(&self) -> u32 {
self.buf.len() as u32
}
/// Check if the buf is empty.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Add new data to the tx buffer, push all or none.
/// Returns LocalTxBufFull error if space not sufficient.
pub fn push<B: BitmapSlice>(&mut self, data_buf: &VolatileSlice<B>) -> Result<()> {
if self.get_buf_size() as usize - self.len() < data_buf.len() {
// Tx buffer is full
return Err(Error::LocalTxBufFull);
}
// Get index into buffer at which data can be inserted
let tail_idx = self.tail.0 as usize % self.get_buf_size() as usize;
// Check if we can fit the data buffer between head and end of buffer
let len = std::cmp::min(self.get_buf_size() as usize - tail_idx, data_buf.len());
let txbuf = &mut self.buf[tail_idx..tail_idx + len];
data_buf.copy_to(txbuf);
// Check if there is more data to be wrapped around
if len < data_buf.len() {
let remain_txbuf = &mut self.buf[..(data_buf.len() - len)];
data_buf.copy_to(remain_txbuf);
}
// Increment tail by the amount of data that has been added to the buffer
self.tail += Wrapping(data_buf.len() as u32);
Ok(())
}
/// Flush buf data to stream.
pub fn flush_to<S: Write>(&mut self, stream: &mut S) -> Result<usize> {
if self.is_empty() {
// No data to be flushed
return Ok(0);
}
// Get index into buffer from which data can be read
let head_idx = self.head.0 as usize % self.get_buf_size() as usize;
// First write from head to end of buffer
let len = std::cmp::min(self.get_buf_size() as usize - head_idx, self.len());
let written = stream
.write(&self.buf[head_idx..(head_idx + len)])
.map_err(Error::LocalTxBufFlush)?;
// Increment head by amount of data that has been flushed to the stream
self.head += Wrapping(written as u32);
// If written length is less than the expected length we can try again in the future
if written < len {
return Ok(written);
}
// The head index has wrapped around the end of the buffer, we call self again
Ok(written + self.flush_to(stream).unwrap_or(0))
}
/// Return amount of data in the buffer.
fn len(&self) -> usize {
(self.tail - self.head).0 as usize
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Buffer size used by all tests below.
    const CONN_TX_BUF_SIZE: u32 = 64 * 1024;

    /// `len()` is the wrapping distance between tail and head.
    #[test]
    fn test_txbuf_len() {
        let mut loc_tx_buf = LocalTxBuf::new(CONN_TX_BUF_SIZE);

        // Zero length tx buf
        assert_eq!(loc_tx_buf.len(), 0);

        // finite length tx buf
        loc_tx_buf.head = Wrapping(0);
        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE);
        assert_eq!(loc_tx_buf.len(), CONN_TX_BUF_SIZE as usize);

        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE / 2);
        assert_eq!(loc_tx_buf.len(), (CONN_TX_BUF_SIZE / 2) as usize);

        loc_tx_buf.head = Wrapping(256);
        assert_eq!(loc_tx_buf.len(), 32512);
    }

    /// `is_empty()` is true only when head and tail coincide.
    #[test]
    fn test_txbuf_is_empty() {
        let mut loc_tx_buf = LocalTxBuf::new(CONN_TX_BUF_SIZE);

        // empty tx buffer
        assert!(loc_tx_buf.is_empty());

        // non empty tx buffer
        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE);
        assert!(!loc_tx_buf.is_empty());
    }

    /// `push()` accepts data that fits, rejects data that does not, and
    /// advances only the tail.
    #[test]
    fn test_txbuf_push() {
        let mut loc_tx_buf = LocalTxBuf::new(CONN_TX_BUF_SIZE);
        let mut buf = [0; CONN_TX_BUF_SIZE as usize];
        // SAFETY: Safe as the buffer is guaranteed to be valid here.
        let data = unsafe { VolatileSlice::new(buf.as_mut_ptr(), buf.len()) };

        // push data into empty tx buffer
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());
        assert_eq!(loc_tx_buf.head, Wrapping(0));
        assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE));

        // push data into full tx buffer
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_err());

        // head and tail wrap at full
        loc_tx_buf.head = Wrapping(CONN_TX_BUF_SIZE);
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());
        assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE * 2));

        // only tail wraps at full
        let mut buf = vec![1; 4];
        // SAFETY: Safe as the buffer is guaranteed to be valid here.
        let data = unsafe { VolatileSlice::new(buf.as_mut_ptr(), buf.len()) };

        let mut cmp_data = vec![1; 4];
        cmp_data.append(&mut vec![0; (CONN_TX_BUF_SIZE - 4) as usize]);
        loc_tx_buf.head = Wrapping(4);
        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE);
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());
        assert_eq!(loc_tx_buf.head, Wrapping(4));
        assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE + 4));
        assert_eq!(loc_tx_buf.buf, cmp_data);
    }

    /// `flush_to()` drains the buffer into a writer, including the case where
    /// the buffered data wraps around the end of the backing buffer.
    #[test]
    fn test_txbuf_flush_to() {
        let mut loc_tx_buf = LocalTxBuf::new(CONN_TX_BUF_SIZE);

        // data to be flushed
        let mut buf = vec![1; CONN_TX_BUF_SIZE as usize];
        // SAFETY: Safe as the buffer is guaranteed to be valid here.
        let data = unsafe { VolatileSlice::new(buf.as_mut_ptr(), buf.len()) };

        // target to which data is flushed
        let mut cmp_vec = Vec::with_capacity(data.len());

        // flush no data
        let res_flush = loc_tx_buf.flush_to(&mut cmp_vec);
        assert!(res_flush.is_ok());
        assert_eq!(res_flush.unwrap(), 0);

        // flush data of CONN_TX_BUF_SIZE amount
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());
        // Fail loudly on a flush error instead of silently skipping the
        // assertions below.
        let n = loc_tx_buf
            .flush_to(&mut cmp_vec)
            .expect("flushing into a Vec must succeed");
        assert_eq!(loc_tx_buf.head, Wrapping(n as u32));
        assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE));
        assert_eq!(n, cmp_vec.len());
        assert_eq!(cmp_vec, buf[..n]);

        // wrapping head flush
        let mut buf = vec![0; (CONN_TX_BUF_SIZE / 2) as usize];
        buf.append(&mut vec![1; (CONN_TX_BUF_SIZE / 2) as usize]);
        // SAFETY: Safe as the buffer is guaranteed to be valid here.
        let data = unsafe { VolatileSlice::new(buf.as_mut_ptr(), buf.len()) };

        loc_tx_buf.head = Wrapping(0);
        loc_tx_buf.tail = Wrapping(0);
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());

        cmp_vec.clear();
        // Move the read window so that it starts in the middle of the buffer
        // and wraps around its end.
        loc_tx_buf.head = Wrapping(CONN_TX_BUF_SIZE / 2);
        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE + (CONN_TX_BUF_SIZE / 2));
        let n = loc_tx_buf
            .flush_to(&mut cmp_vec)
            .expect("flushing into a Vec must succeed");
        assert_eq!(
            loc_tx_buf.head,
            Wrapping(CONN_TX_BUF_SIZE + (CONN_TX_BUF_SIZE / 2))
        );
        assert_eq!(
            loc_tx_buf.tail,
            Wrapping(CONN_TX_BUF_SIZE + (CONN_TX_BUF_SIZE / 2))
        );
        assert_eq!(n, cmp_vec.len());
        // Second half of the buffer (ones) comes out first, then the first
        // half (zeros).
        let mut expected = vec![1; (CONN_TX_BUF_SIZE / 2) as usize];
        expected.append(&mut vec![0; (CONN_TX_BUF_SIZE / 2) as usize]);
        assert_eq!(cmp_vec, expected[..n]);
    }
}