@@ -163,6 +163,36 @@ pub(super) fn required_bytes_key_size(ty: &AlgebraicType, is_ranged_idx: bool) -
163163 }
164164}
165165
166+ /// Validates BSATN `byte` to conform to `seed`.
167+ ///
168+ /// The BSATN can originate from untrusted sources, e.g., from module code.
169+ /// This also means that e.g., a `BytesKey` can be trusted to hold valid BSATN
170+ /// for the key type, which we can rely on in e.g., `decode_algebraic_value`,
171+ /// which isn't used in a context where it would be appropriate to fail.
172+ ///
173+ /// Another reason to validate is that we wish for `BytesKey` to be strictly
174+ /// an optimization and not allow things that would be rejected by the non-optimized code.
175+ ///
176+ /// After validating, we also don't need to validate that `bytes`
177+ /// will fit into e.g., a `BytesKey<N>`
178+ /// since if all parts that are encoded into it are valid according to a key type,
179+ /// then `bytes` cannot be longer than `N`.
180+ fn validate < ' a , ' de , S : ' a + ?Sized > ( seed : & ' a S , mut bytes : & ' de [ u8 ] ) -> DecodeResult < ( ) >
181+ where
182+ WithTypespace < ' a , S > : DeserializeSeed < ' de > ,
183+ {
184+ WithTypespace :: empty ( seed) . validate ( Deserializer :: new ( & mut bytes) ) ?;
185+
186+ if !bytes. is_empty ( ) {
187+ return Err ( DecodeError :: custom ( format_args ! (
188+ "after decoding, there are {} extra bytes" ,
189+ bytes. len( )
190+ ) ) ) ;
191+ }
192+
193+ Ok ( ( ) )
194+ }
195+
166196impl < const N : usize > BytesKey < N > {
167197 fn new ( length : usize , bytes : [ u8 ; N ] ) -> Self {
168198 let length = length as _ ;
@@ -181,56 +211,13 @@ impl<const N: usize> BytesKey<N> {
181211 . expect ( "A `BytesKey` should by construction always deserialize to the right `key_type`" )
182212 }
183213
184- /// Ensure bytes of length `got` fit in `N` or return an error.
185- fn ensure_key_fits ( got : usize ) -> DecodeResult < ( ) > {
186- if got > N {
187- return Err ( DecodeError :: custom ( format_args ! (
188- "key provided is too long, expected at most {N}, but got {got}"
189- ) ) ) ;
190- }
191- Ok ( ( ) )
192- }
193-
194- /// Decodes `prefix` and `endpoint` in BSATN to a [`BytesKey<N>`]
195- /// by copying over both if they fit into the key.
196- pub ( super ) fn from_bsatn_prefix_and_endpoint (
197- prefix : & [ u8 ] ,
198- prefix_types : & [ ProductTypeElement ] ,
199- endpoint : & [ u8 ] ,
200- range_type : & AlgebraicType ,
201- ) -> DecodeResult < Self > {
202- // Validate the BSATN.
203- //
204- // The BSATN can originate from untrusted sources, e.g., from module code.
205- // This also means that a `BytesKey` can be trusted to hold valid BSATN
206- // for the key type, which we can rely on in e.g., `decode_algebraic_value`,
207- // which isn't used in a context where it would be appropriate to fail.
208- //
209- // Another reason to validate is that we wish for `BytesKey` to be strictly
210- // an optimization and not allow things that would be rejected by the non-optimized code.
211- WithTypespace :: empty ( prefix_types) . validate ( Deserializer :: new ( & mut { prefix } ) ) ?;
212- WithTypespace :: empty ( range_type) . validate ( Deserializer :: new ( & mut { endpoint } ) ) ?;
213- // Check that the `prefix` and the `endpoint` together fit into the key.
214- let prefix_len = prefix. len ( ) ;
215- let endpoint_len = endpoint. len ( ) ;
216- let total_len = prefix_len + endpoint_len;
217- Self :: ensure_key_fits ( total_len) ?;
218- // Copy the `prefix` and the `endpoint` over.
219- let mut bytes = [ 0 ; N ] ;
220- bytes[ ..prefix_len] . copy_from_slice ( prefix) ;
221- bytes[ prefix_len..total_len] . copy_from_slice ( endpoint) ;
222- Ok ( Self :: new ( total_len, bytes) )
223- }
224-
225214 /// Decodes `bytes` in BSATN to a [`BytesKey<N>`]
226215 /// by copying over the bytes if they fit into the key.
227216 pub ( super ) fn from_bsatn ( ty : & AlgebraicType , bytes : & [ u8 ] ) -> DecodeResult < Self > {
228- // Validate the BSATN. See `Self::from_bsatn_prefix_and_endpoint` for more details.
229- WithTypespace :: empty ( ty) . validate ( Deserializer :: new ( & mut { bytes } ) ) ?;
230- // Check that the `bytes` fit into the key.
231- let got = bytes. len ( ) ;
232- Self :: ensure_key_fits ( got) ?;
217+ // Validate the BSATN.
218+ validate ( ty, bytes) ?;
233219 // Copy the bytes over.
220+ let got = bytes. len ( ) ;
234221 let mut arr = [ 0 ; N ] ;
235222 arr[ ..got] . copy_from_slice ( bytes) ;
236223 Ok ( Self :: new ( got, arr) )
@@ -343,6 +330,11 @@ fn split_map_write_back<const N: usize>(slice: &mut [u8], map_bytes: impl FnOnce
343330}
344331
345332impl < const N : usize > RangeCompatBytesKey < N > {
333+ fn new ( length : usize , bytes : [ u8 ; N ] ) -> Self {
334+ let length = length as _ ;
335+ Self { length, bytes }
336+ }
337+
346338 /// Decodes `self` as an [`AlgebraicValue`] at `key_type`.
347339 ///
348340 /// An incorrect `key_type`,
@@ -354,6 +346,26 @@ impl<const N: usize> RangeCompatBytesKey<N> {
354346 Self :: to_bytes_key ( * self , key_type) . decode_algebraic_value ( key_type)
355347 }
356348
349+ /// Decodes `prefix` in BSATN to a [`RangeCompatBytesKey<N>`]
350+ /// by copying over `prefix` and massaging if they fit into the key.
351+ pub ( super ) fn from_bsatn_prefix ( prefix : & [ u8 ] , prefix_types : & [ ProductTypeElement ] ) -> DecodeResult < Self > {
352+ // Validate the BSATN.
353+ validate ( prefix_types, prefix) ?;
354+
355+ // Copy the `prefix` over.
356+ let mut bytes = [ 0 ; N ] ;
357+ let got = prefix. len ( ) ;
358+ bytes[ ..got] . copy_from_slice ( prefix) ;
359+
360+ // Massage the `bytes`.
361+ let mut slice = bytes. as_mut_slice ( ) ;
362+ for ty in prefix_types {
363+ slice = Self :: process_from_bytes_key ( slice, & ty. algebraic_type ) ;
364+ }
365+
366+ Ok ( Self :: new ( got, bytes) )
367+ }
368+
357369 /// Decodes `prefix` and `endpoint` in BSATN to a [`RangeCompatBytesKey<N>`]
358370 /// by copying over both and massaging if they fit into the key.
359371 pub ( super ) fn from_bsatn_prefix_and_endpoint (
@@ -362,17 +374,28 @@ impl<const N: usize> RangeCompatBytesKey<N> {
362374 endpoint : & [ u8 ] ,
363375 range_type : & AlgebraicType ,
364376 ) -> DecodeResult < Self > {
365- let BytesKey { length, mut bytes } =
366- BytesKey :: from_bsatn_prefix_and_endpoint ( prefix, prefix_types, endpoint, range_type) ?;
377+ // Validate the BSATN.
378+ validate ( prefix_types, prefix) ?;
379+ validate ( range_type, endpoint) ?;
380+
381+ // Sum up the lengths.
382+ let prefix_len = prefix. len ( ) ;
383+ let endpoint_len = endpoint. len ( ) ;
384+ let total_len = prefix_len + endpoint_len;
367385
368- // Masage the bytes in `key`.
386+ // Copy the `prefix` and the `endpoint` over.
387+ let mut bytes = [ 0 ; N ] ;
388+ bytes[ ..prefix_len] . copy_from_slice ( prefix) ;
389+ bytes[ prefix_len..total_len] . copy_from_slice ( endpoint) ;
390+
391+ // Massage the bytes.
369392 let mut slice = bytes. as_mut_slice ( ) ;
370393 for ty in prefix_types {
371394 slice = Self :: process_from_bytes_key ( slice, & ty. algebraic_type ) ;
372395 }
373396 Self :: process_from_bytes_key ( slice, range_type) ;
374397
375- Ok ( Self { length , bytes } )
398+ Ok ( Self :: new ( total_len , bytes) )
376399 }
377400
378401 /// Decodes `bytes` in BSATN to a [`RangeCompatBytesKey<N>`]
@@ -492,6 +515,14 @@ impl<const N: usize> RangeCompatBytesKey<N> {
492515 process ( bytes. as_mut_slice ( ) , ty) ;
493516 BytesKey { length, bytes }
494517 }
518+
519+ /// Extend the length to `N` by filling with `u8::MAX`.
520+ pub ( super ) fn add_max_suffix ( mut self ) -> Self {
521+ let len = self . len ( ) ;
522+ self . bytes [ len..] . fill ( u8:: MAX ) ;
523+ self . length = N as u8 ;
524+ self
525+ }
495526}
496527
497528#[ cfg( test) ]
0 commit comments