Skip to content

Commit 7527704

Browse files
committed
feat(virtq): add support for multi-descriptor payloads
Signed-off-by: Tomasz Andrzejak <andreiltd@gmail.com>
1 parent eeb9423 commit 7527704

9 files changed

Lines changed: 369 additions & 104 deletions

File tree

src/hyperlight_common/src/virtq/consumer.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ pub struct VirtqConsumer<M, N> {
255255
inner: RingConsumer<M>,
256256
notifier: N,
257257
inflight: FixedBitSet,
258+
next_token: u32,
258259
}
259260

260261
impl<M: MemOps + Clone, N: Notifier> VirtqConsumer<M, N> {
@@ -273,6 +274,7 @@ impl<M: MemOps + Clone, N: Notifier> VirtqConsumer<M, N> {
273274
inner,
274275
notifier,
275276
inflight,
277+
next_token: 0,
276278
}
277279
}
278280

@@ -323,7 +325,8 @@ impl<M: MemOps + Clone, N: Notifier> VirtqConsumer<M, N> {
323325
}
324326

325327
self.inflight.insert(id_idx);
326-
let token = Token(id);
328+
let token = Token(self.next_token, id);
329+
self.next_token = self.next_token.wrapping_add(1);
327330

328331
// Copy entry data from shared memory
329332
let data = entry_elem

src/hyperlight_common/src/virtq/mod.rs

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,11 @@ pub enum SuppressionKind {
335335

336336
/// A token representing a sent entry in the virtqueue.
337337
///
338-
/// Tokens uniquely identify in-flight requests and are used to correlate
339-
/// requests with their responses. The token value corresponds to the
340-
/// descriptor ID in the underlying ring.
338+
/// Tokens uniquely identify in-flight requests and are used to correlate requests with their responses.
339+
/// The first element is a monotonically increasing generation counter. The second element is the
340+
/// underlying descriptor ID
341341
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
342-
pub struct Token(pub u16);
342+
pub struct Token(pub u32, pub u16);
343343

344344
impl From<BufferElement> for Allocation {
345345
fn from(value: BufferElement) -> Self {
@@ -972,6 +972,54 @@ mod tests {
972972
assert_eq!(cqe2.token, tok_rw);
973973
assert_eq!(&cqe2.data[..], b"reply");
974974
}
975+
976+
/// Regression test: reclaim + submit must not cause token collisions.
977+
///
978+
/// Before the monotonic generation counter, Token wrapped the descriptor
979+
/// ID which gets recycled. This caused stale pending completions to
980+
/// match newly submitted entries with the same recycled descriptor ID.
981+
#[test]
982+
fn test_reclaim_submit_no_token_collision() {
983+
let ring = make_ring(8);
984+
let (mut producer, mut consumer, _) = make_test_producer(&ring);
985+
986+
// Submit and complete a ReadOnly entry
987+
let tok_old = send_readonly(&mut producer, b"log");
988+
989+
let (_, c) = consumer.poll(1024).unwrap().unwrap();
990+
consumer.complete(c).unwrap();
991+
992+
// Reclaim pushes the completion to pending (token = tok_old)
993+
let count = producer.reclaim().unwrap();
994+
assert_eq!(count, 1);
995+
996+
// Submit a new ReadWrite entry - may reuse the same descriptor ID
997+
let tok_new = send_readwrite(&mut producer, b"call", 64);
998+
999+
// Tokens must differ even if the descriptor ID was recycled
1000+
assert_ne!(
1001+
tok_old, tok_new,
1002+
"tokens must be unique across reclaim/submit cycles"
1003+
);
1004+
1005+
// Complete the ReadWrite entry
1006+
let (_, c) = consumer.poll(1024).unwrap().unwrap();
1007+
let SendCompletion::Writable(mut wc) = c else {
1008+
panic!("expected writable");
1009+
};
1010+
wc.write_all(b"result").unwrap();
1011+
consumer.complete(wc.into()).unwrap();
1012+
1013+
// Poll should return the stale ReadOnly completion first (wrong token)
1014+
let cqe1 = producer.poll().unwrap().unwrap();
1015+
assert_eq!(cqe1.token, tok_old);
1016+
assert!(cqe1.data.is_empty());
1017+
1018+
// Then the new ReadWrite completion (matching token)
1019+
let cqe2 = producer.poll().unwrap().unwrap();
1020+
assert_eq!(cqe2.token, tok_new);
1021+
assert_eq!(&cqe2.data[..], b"result");
1022+
}
9751023
}
9761024
#[cfg(all(test, loom))]
9771025
mod fuzz {

src/hyperlight_common/src/virtq/msg.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ limitations under the License.
2020
//! fixed 8-byte header, enabling message type discrimination and
2121
//! request/response correlation.
2222
23+
use bitflags::bitflags;
24+
2325
/// Message types for the virtqueue wire protocol.
2426
#[repr(u8)]
2527
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -54,24 +56,33 @@ impl TryFrom<u8> for MsgKind {
5456
}
5557
}
5658

59+
bitflags! {
60+
#[repr(transparent)]
61+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
62+
pub struct MsgFlags: u8 {
63+
/// More descriptors follow for this message.
64+
const MORE = 1 << 0;
65+
}
66+
}
67+
5768
/// Wire header for all virtqueue messages
5869
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
5970
#[repr(C)]
6071
pub struct VirtqMsgHeader {
6172
/// Discriminates the message type.
6273
pub kind: u8,
63-
/// Per-type flags TODO(ring): add flags type.
74+
/// Per-message flags (see [`MsgFlags`]).
6475
pub flags: u8,
6576
/// Caller-assigned correlation ID. Responses echo the request's ID.
6677
pub req_id: u16,
67-
/// Byte length of the payload following this header.
78+
/// Byte length of the payload following this header in this descriptor.
6879
pub payload_len: u32,
6980
}
7081

7182
impl VirtqMsgHeader {
7283
pub const SIZE: usize = core::mem::size_of::<Self>();
7384

74-
/// Create a new message header.
85+
/// Create a new message header with no flags set.
7586
pub const fn new(kind: MsgKind, req_id: u16, payload_len: u32) -> Self {
7687
Self {
7788
kind: kind as u8,
@@ -82,10 +93,10 @@ impl VirtqMsgHeader {
8293
}
8394

8495
/// Create a new header with flags.
85-
pub const fn with_flags(kind: MsgKind, flags: u8, req_id: u16, payload_len: u32) -> Self {
96+
pub const fn with_flags(kind: MsgKind, flags: MsgFlags, req_id: u16, payload_len: u32) -> Self {
8697
Self {
8798
kind: kind as u8,
88-
flags,
99+
flags: flags.bits(),
89100
req_id,
90101
payload_len,
91102
}
@@ -95,4 +106,15 @@ impl VirtqMsgHeader {
95106
pub fn msg_kind(&self) -> Result<MsgKind, u8> {
96107
MsgKind::try_from(self.kind)
97108
}
109+
110+
/// Interpret the raw flags field as [`MsgFlags`].
111+
pub fn msg_flags(&self) -> MsgFlags {
112+
MsgFlags::from_bits_truncate(self.flags)
113+
}
114+
115+
/// Returns true if [`MsgFlags::MORE`] is set, indicating more
116+
/// descriptors follow for this message.
117+
pub const fn has_more(&self) -> bool {
118+
self.flags & MsgFlags::MORE.bits() != 0
119+
}
98120
}

src/hyperlight_common/src/virtq/pool.rs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,10 @@ impl<const N: usize> Slab<N> {
150150
}
151151

152152
// Fallback to full search
153+
let total = self.used_slots.len();
153154
self.used_slots.zeroes().find(|&next_free| {
154-
self.used_slots
155-
.count_zeroes(next_free..next_free + slots_num)
156-
== slots_num
155+
let end = next_free + slots_num;
156+
end <= total && self.used_slots.count_zeroes(next_free..end) == slots_num
157157
})
158158
}
159159

@@ -416,6 +416,11 @@ impl<const L: usize, const U: usize> BufferPool<L, U> {
416416
inner: SyncWrap(Rc::new(RefCell::new(inner))),
417417
})
418418
}
419+
420+
/// Upper slab slot size in bytes.
421+
pub const fn upper_slot_size() -> usize {
422+
U
423+
}
419424
}
420425

421426
#[cfg(all(test, loom))]
@@ -821,6 +826,40 @@ mod tests {
821826
assert!(matches!(result, Err(AllocError::InvalidFree(_, _))));
822827
}
823828

829+
#[test]
830+
fn test_slab_multi_slot_alloc_near_end() {
831+
let mut slab = make_slab::<256>(1792); // 7 slots
832+
let a0 = slab.alloc(256).unwrap();
833+
let a1 = slab.alloc(256).unwrap();
834+
let _a2 = slab.alloc(256).unwrap();
835+
let _a3 = slab.alloc(256).unwrap();
836+
let _a4 = slab.alloc(256).unwrap();
837+
let _a5 = slab.alloc(256).unwrap();
838+
let _a6 = slab.alloc(256).unwrap();
839+
840+
slab.dealloc(a0).unwrap();
841+
slab.dealloc(a1).unwrap();
842+
843+
// 2-slot run fits at indices 0..2 but the search visits index 6
844+
// (a free zero) first if slots 0-1 are not found before it.
845+
// Actually slots 0-1 are free, so it should find them.
846+
let run = slab.alloc(300).unwrap(); // needs 2 slots
847+
assert_eq!(run.len, 512);
848+
}
849+
850+
#[test]
851+
fn test_slab_multi_slot_alloc_no_room_at_end() {
852+
// Only the last slot is free but a 2-slot run is requested.
853+
// find_slots must not panic when checking beyond the bitset.
854+
let mut slab = make_slab::<256>(1792); // 7 slots
855+
let allocs: Vec<_> = (0..7).map(|_| slab.alloc(256).unwrap()).collect();
856+
// Free only the last slot (index 6)
857+
slab.dealloc(allocs[6]).unwrap();
858+
859+
let result = slab.alloc(300); // needs 2 slots, only 1 free
860+
assert!(matches!(result, Err(AllocError::NoSpace)));
861+
}
862+
824863
#[test]
825864
fn test_slab_free_invalid_address() {
826865
let mut slab = make_slab::<256>(1024);

src/hyperlight_common/src/virtq/producer.rs

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ pub struct VirtqProducer<M, N, P> {
125125
inner: RingProducer<M>,
126126
notifier: N,
127127
pool: P,
128-
inflight: Vec<Option<Inflight>>,
128+
next_token: u32,
129+
inflight: Vec<Option<(Token, Inflight)>>,
129130
pending: VecDeque<RecvCompletion>,
130131
}
131132

@@ -152,6 +153,7 @@ where
152153
pool,
153154
notifier,
154155
inflight,
156+
next_token: 0,
155157
pending: VecDeque::new(),
156158
}
157159
}
@@ -218,13 +220,20 @@ where
218220
};
219221

220222
let id = used.id as usize;
221-
let inf = self
223+
let (token, inf) = self
222224
.inflight
223225
.get_mut(id)
224226
.ok_or(VirtqError::InvalidState)?
225227
.take()
226228
.ok_or(VirtqError::InvalidState)?;
227229

230+
// the token's descriptor ID must match the ring's
231+
debug_assert_eq!(
232+
token.1, used.id,
233+
"ring returned desc_id={} but inflight slot {} has token with desc_id={}",
234+
used.id, id, token.1,
235+
);
236+
228237
let written = used.len as usize;
229238

230239
// Free entry buffers (request data no longer needed)
@@ -250,10 +259,7 @@ where
250259
None => Bytes::new(),
251260
};
252261

253-
Ok(Some(RecvCompletion {
254-
token: Token(used.id),
255-
data,
256-
}))
262+
Ok(Some(RecvCompletion { token, data }))
257263
}
258264

259265
/// Drain all available completions, calling the provided closure for each.
@@ -310,6 +316,9 @@ where
310316
let chain = inflight.try_into_chain(written)?;
311317
let id = self.inner.submit_available(&chain)?;
312318

319+
let token = Token(self.next_token, id);
320+
self.next_token = self.next_token.wrapping_add(1);
321+
313322
let slot = self
314323
.inflight
315324
.get_mut(id as usize)
@@ -319,7 +328,7 @@ where
319328
return Err(VirtqError::InvalidState);
320329
}
321330

322-
*slot = Some(inflight);
331+
*slot = Some((token, inflight));
323332

324333
let should_notify = self.inner.should_notify_since(cursor_before)?;
325334

@@ -336,7 +345,7 @@ where
336345
});
337346
}
338347

339-
Ok(Token(id))
348+
Ok(token)
340349
}
341350

342351
/// Signal backpressure to the consumer.
@@ -474,12 +483,18 @@ where
474483
.slot_addr(pos as usize)
475484
.ok_or(VirtqError::InvalidState)?;
476485

477-
self.inflight[id as usize] = Some(Inflight::WriteOnly {
478-
completion: Allocation {
479-
addr,
480-
len: slot_size,
486+
let token = Token(self.next_token, id);
487+
self.next_token = self.next_token.wrapping_add(1);
488+
489+
self.inflight[id as usize] = Some((
490+
token,
491+
Inflight::WriteOnly {
492+
completion: Allocation {
493+
addr,
494+
len: slot_size,
495+
},
481496
},
482-
});
497+
));
483498

484499
ids.push(id);
485500
}
@@ -869,7 +884,7 @@ mod tests {
869884
// Ring should still be fully usable
870885
let se = producer.chain().entry(64).completion(128).build().unwrap();
871886
let tok = producer.submit(se).unwrap();
872-
assert!(tok.0 < 16);
887+
assert!(tok.1 < 16);
873888
}
874889

875890
#[test]
@@ -885,7 +900,7 @@ mod tests {
885900
// Ring should still be fully usable
886901
let se = producer.chain().entry(64).completion(128).build().unwrap();
887902
let tok = producer.submit(se).unwrap();
888-
assert!(tok.0 < 16);
903+
assert!(tok.1 < 16);
889904
}
890905

891906
#[test]

0 commit comments

Comments
 (0)