Skip to content

Commit e649f64

Browse files
committed
fix bounds logic for multi column indices
1 parent 08037bc commit e649f64

File tree

2 files changed

+227
-68
lines changed

2 files changed

+227
-68
lines changed

crates/table/src/table_index/bytes_key.rs

Lines changed: 81 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,36 @@ pub(super) fn required_bytes_key_size(ty: &AlgebraicType, is_ranged_idx: bool) -
163163
}
164164
}
165165

166+
/// Validates BSATN `byte` to conform to `seed`.
167+
///
168+
/// The BSATN can originate from untrusted sources, e.g., from module code.
169+
/// This also means that e.g., a `BytesKey` can be trusted to hold valid BSATN
170+
/// for the key type, which we can rely on in e.g., `decode_algebraic_value`,
171+
/// which isn't used in a context where it would be appropriate to fail.
172+
///
173+
/// Another reason to validate is that we wish for `BytesKey` to be strictly
174+
/// an optimization and not allow things that would be rejected by the non-optimized code.
175+
///
176+
/// After validating, we also don't need to validate that `bytes`
177+
/// will fit into e.g., a `BytesKey<N>`
178+
/// since if all parts that are encoded into it are valid according to a key type,
179+
/// then `bytes` cannot be longer than `N`.
180+
fn validate<'a, 'de, S: 'a + ?Sized>(seed: &'a S, mut bytes: &'de [u8]) -> DecodeResult<()>
181+
where
182+
WithTypespace<'a, S>: DeserializeSeed<'de>,
183+
{
184+
WithTypespace::empty(seed).validate(Deserializer::new(&mut bytes))?;
185+
186+
if !bytes.is_empty() {
187+
return Err(DecodeError::custom(format_args!(
188+
"after decoding, there are {} extra bytes",
189+
bytes.len()
190+
)));
191+
}
192+
193+
Ok(())
194+
}
195+
166196
impl<const N: usize> BytesKey<N> {
167197
fn new(length: usize, bytes: [u8; N]) -> Self {
168198
let length = length as _;
@@ -181,56 +211,13 @@ impl<const N: usize> BytesKey<N> {
181211
.expect("A `BytesKey` should by construction always deserialize to the right `key_type`")
182212
}
183213

184-
/// Ensure bytes of length `got` fit in `N` or return an error.
185-
fn ensure_key_fits(got: usize) -> DecodeResult<()> {
186-
if got > N {
187-
return Err(DecodeError::custom(format_args!(
188-
"key provided is too long, expected at most {N}, but got {got}"
189-
)));
190-
}
191-
Ok(())
192-
}
193-
194-
/// Decodes `prefix` and `endpoint` in BSATN to a [`BytesKey<N>`]
195-
/// by copying over both if they fit into the key.
196-
pub(super) fn from_bsatn_prefix_and_endpoint(
197-
prefix: &[u8],
198-
prefix_types: &[ProductTypeElement],
199-
endpoint: &[u8],
200-
range_type: &AlgebraicType,
201-
) -> DecodeResult<Self> {
202-
// Validate the BSATN.
203-
//
204-
// The BSATN can originate from untrusted sources, e.g., from module code.
205-
// This also means that a `BytesKey` can be trusted to hold valid BSATN
206-
// for the key type, which we can rely on in e.g., `decode_algebraic_value`,
207-
// which isn't used in a context where it would be appropriate to fail.
208-
//
209-
// Another reason to validate is that we wish for `BytesKey` to be strictly
210-
// an optimization and not allow things that would be rejected by the non-optimized code.
211-
WithTypespace::empty(prefix_types).validate(Deserializer::new(&mut { prefix }))?;
212-
WithTypespace::empty(range_type).validate(Deserializer::new(&mut { endpoint }))?;
213-
// Check that the `prefix` and the `endpoint` together fit into the key.
214-
let prefix_len = prefix.len();
215-
let endpoint_len = endpoint.len();
216-
let total_len = prefix_len + endpoint_len;
217-
Self::ensure_key_fits(total_len)?;
218-
// Copy the `prefix` and the `endpoint` over.
219-
let mut bytes = [0; N];
220-
bytes[..prefix_len].copy_from_slice(prefix);
221-
bytes[prefix_len..total_len].copy_from_slice(endpoint);
222-
Ok(Self::new(total_len, bytes))
223-
}
224-
225214
/// Decodes `bytes` in BSATN to a [`BytesKey<N>`]
226215
/// by copying over the bytes if they fit into the key.
227216
pub(super) fn from_bsatn(ty: &AlgebraicType, bytes: &[u8]) -> DecodeResult<Self> {
228-
// Validate the BSATN. See `Self::from_bsatn_prefix_and_endpoint` for more details.
229-
WithTypespace::empty(ty).validate(Deserializer::new(&mut { bytes }))?;
230-
// Check that the `bytes` fit into the key.
231-
let got = bytes.len();
232-
Self::ensure_key_fits(got)?;
217+
// Validate the BSATN.
218+
validate(ty, bytes)?;
233219
// Copy the bytes over.
220+
let got = bytes.len();
234221
let mut arr = [0; N];
235222
arr[..got].copy_from_slice(bytes);
236223
Ok(Self::new(got, arr))
@@ -343,6 +330,11 @@ fn split_map_write_back<const N: usize>(slice: &mut [u8], map_bytes: impl FnOnce
343330
}
344331

345332
impl<const N: usize> RangeCompatBytesKey<N> {
333+
fn new(length: usize, bytes: [u8; N]) -> Self {
334+
let length = length as _;
335+
Self { length, bytes }
336+
}
337+
346338
/// Decodes `self` as an [`AlgebraicValue`] at `key_type`.
347339
///
348340
/// An incorrect `key_type`,
@@ -354,6 +346,26 @@ impl<const N: usize> RangeCompatBytesKey<N> {
354346
Self::to_bytes_key(*self, key_type).decode_algebraic_value(key_type)
355347
}
356348

349+
/// Decodes `prefix` in BSATN to a [`RangeCompatBytesKey<N>`]
350+
/// by copying over `prefix` and massaging if they fit into the key.
351+
pub(super) fn from_bsatn_prefix(prefix: &[u8], prefix_types: &[ProductTypeElement]) -> DecodeResult<Self> {
352+
// Validate the BSATN.
353+
validate(prefix_types, prefix)?;
354+
355+
// Copy the `prefix` over.
356+
let mut bytes = [0; N];
357+
let got = prefix.len();
358+
bytes[..got].copy_from_slice(prefix);
359+
360+
// Massage the `bytes`.
361+
let mut slice = bytes.as_mut_slice();
362+
for ty in prefix_types {
363+
slice = Self::process_from_bytes_key(slice, &ty.algebraic_type);
364+
}
365+
366+
Ok(Self::new(got, bytes))
367+
}
368+
357369
/// Decodes `prefix` and `endpoint` in BSATN to a [`RangeCompatBytesKey<N>`]
358370
/// by copying over both and massaging if they fit into the key.
359371
pub(super) fn from_bsatn_prefix_and_endpoint(
@@ -362,17 +374,28 @@ impl<const N: usize> RangeCompatBytesKey<N> {
362374
endpoint: &[u8],
363375
range_type: &AlgebraicType,
364376
) -> DecodeResult<Self> {
365-
let BytesKey { length, mut bytes } =
366-
BytesKey::from_bsatn_prefix_and_endpoint(prefix, prefix_types, endpoint, range_type)?;
377+
// Validate the BSATN.
378+
validate(prefix_types, prefix)?;
379+
validate(range_type, endpoint)?;
380+
381+
// Sum up the lengths.
382+
let prefix_len = prefix.len();
383+
let endpoint_len = endpoint.len();
384+
let total_len = prefix_len + endpoint_len;
367385

368-
// Masage the bytes in `key`.
386+
// Copy the `prefix` and the `endpoint` over.
387+
let mut bytes = [0; N];
388+
bytes[..prefix_len].copy_from_slice(prefix);
389+
bytes[prefix_len..total_len].copy_from_slice(endpoint);
390+
391+
// Massage the bytes.
369392
let mut slice = bytes.as_mut_slice();
370393
for ty in prefix_types {
371394
slice = Self::process_from_bytes_key(slice, &ty.algebraic_type);
372395
}
373396
Self::process_from_bytes_key(slice, range_type);
374397

375-
Ok(Self { length, bytes })
398+
Ok(Self::new(total_len, bytes))
376399
}
377400

378401
/// Decodes `bytes` in BSATN to a [`RangeCompatBytesKey<N>`]
@@ -492,6 +515,14 @@ impl<const N: usize> RangeCompatBytesKey<N> {
492515
process(bytes.as_mut_slice(), ty);
493516
BytesKey { length, bytes }
494517
}
518+
519+
/// Extend the length to `N` by filling with `u8::MAX`.
520+
pub(super) fn add_max_suffix(mut self) -> Self {
521+
let len = self.len();
522+
self.bytes[len..].fill(u8::MAX);
523+
self.length = N as u8;
524+
self
525+
}
495526
}
496527

497528
#[cfg(test)]

0 commit comments

Comments
 (0)