Skip to content

Commit d5b0816

Browse files
lucylqGithub Executorch
andauthored
Add bounds checking on buffer_idx for constant_buffer and constant_data in XNNPACK (#18820)
Check that buffer_idx is within bounds of `flatbuffer_graph->constant_buffer()` and `flatbuffer_graph->constant_data()`, whichever branch is chosen Authored-with: Claude Co-authored-by: Github Executorch <github_executorch@arm.com>
1 parent 26e2ab8 commit d5b0816

1 file changed

Lines changed: 61 additions & 16 deletions

File tree

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,12 @@ std::vector<T> flatbufferDimsToVector(
170170
/**
171171
Gets the constant data pointer associated with the given tensor value.
172172
Obtaining the constant data pointer can either be from within the flatbuffer
173-
payload (deprecated) or via offsets to the constant_data_ptr. If no constant
174-
data associated with the tensor value, then returns nullptr.
173+
payload (deprecated) or via offsets to the constant_data_ptr.
174+
175+
Failures are returned as an Error, and the successful value may be nullptr
176+
when the tensor has no associated constant data.
175177
*/
176-
const uint8_t* getConstantDataPtr(
178+
Result<const uint8_t*> getConstantDataPtr(
177179
uint32_t buffer_idx,
178180
GraphPtr flatbuffer_graph,
179181
const uint8_t* constant_data_ptr,
@@ -184,26 +186,56 @@ const uint8_t* getConstantDataPtr(
184186
if (!constant_data_ptr) {
185187
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
186188
// window
187-
const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
188-
return constant_buffer[buffer_idx]->storage()->data();
189+
auto* cb = flatbuffer_graph->constant_buffer();
190+
ET_CHECK_OR_RETURN_ERROR(
191+
cb != nullptr, InvalidProgram, "constant_buffer is null");
192+
ET_CHECK_OR_RETURN_ERROR(
193+
buffer_idx < cb->size(),
194+
InvalidProgram,
195+
"buffer_idx %u out of bounds for constant_buffer of size %zu",
196+
buffer_idx,
197+
cb->size());
198+
auto* buffer_entry = (*cb)[buffer_idx];
199+
ET_CHECK_OR_RETURN_ERROR(
200+
buffer_entry != nullptr && buffer_entry->storage() != nullptr,
201+
InvalidProgram,
202+
"Null constant_buffer entry at buffer_idx %u",
203+
buffer_idx);
204+
return buffer_entry->storage()->data();
189205
} else {
190-
ConstantDataOffsetPtr constant_data_offset =
191-
flatbuffer_graph->constant_data()->Get(buffer_idx);
206+
auto* cd = flatbuffer_graph->constant_data();
207+
ET_CHECK_OR_RETURN_ERROR(
208+
cd != nullptr, InvalidProgram, "constant_data is null");
209+
ET_CHECK_OR_RETURN_ERROR(
210+
buffer_idx < cd->size(),
211+
InvalidProgram,
212+
"buffer_idx %u out of bounds for constant_data of size %zu",
213+
buffer_idx,
214+
cd->size());
215+
ConstantDataOffsetPtr constant_data_offset = cd->Get(buffer_idx);
216+
ET_CHECK_OR_RETURN_ERROR(
217+
constant_data_offset != nullptr,
218+
InvalidProgram,
219+
"Null constant_data entry at buffer_idx %u",
220+
buffer_idx);
192221
uint64_t offset = constant_data_offset->offset();
193-
194222
bool has_named_key = flatbuffers::IsFieldPresent(
195223
constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
196224
// If there is no tensor name
197225
if (!has_named_key) {
198226
return constant_data_ptr + offset;
199227
} else {
228+
ET_CHECK_OR_RETURN_ERROR(
229+
constant_data_offset->named_key() != nullptr,
230+
InvalidProgram,
231+
"Named key is null");
200232
const std::string& data_name = constant_data_offset->named_key()->str();
201233
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
202234
Result<const uint8_t*> data_ptr =
203235
weights_cache->load_unpacked_data(data_name);
204236
if (!data_ptr.ok()) {
205237
ET_LOG(Error, "Failed to load weights from cache");
206-
return nullptr;
238+
return data_ptr.error();
207239
}
208240
return data_ptr.get();
209241
#else
@@ -215,7 +247,7 @@ const uint8_t* getConstantDataPtr(
215247
"Failed to get constant data for key %s from named_data_map. Error code: %u",
216248
data_name.c_str(),
217249
static_cast<uint32_t>(buffer.error()));
218-
return nullptr;
250+
return buffer.error();
219251
}
220252
const uint8_t* data_ptr =
221253
static_cast<const uint8_t*>(buffer.get().data());
@@ -229,7 +261,7 @@ const uint8_t* getConstantDataPtr(
229261
return nullptr;
230262
}
231263

232-
const uint8_t* getConstantDataPtr(
264+
Result<const uint8_t*> getConstantDataPtr(
233265
const fb_xnnpack::XNNTensorValue* tensor_value,
234266
GraphPtr flatbuffer_graph,
235267
const uint8_t* constant_data_ptr,
@@ -298,13 +330,17 @@ Error defineTensor(
298330

299331
// Get Pointer to constant data from flatbuffer, if its non-constant
300332
// it is a nullptr
301-
const uint8_t* buffer_ptr = getConstantDataPtr(
333+
auto buffer_result = getConstantDataPtr(
302334
tensor_value,
303335
flatbuffer_graph,
304336
constant_data_ptr,
305337
named_data_map,
306338
freeable_buffers,
307339
weights_cache);
340+
if (!buffer_result.ok()) {
341+
return buffer_result.error();
342+
}
343+
const uint8_t* buffer_ptr = buffer_result.get();
308344

309345
xnn_status status;
310346
// The type we might have to convert to
@@ -449,13 +485,17 @@ Error defineTensor(
449485
const float* scale = qparams->scale()->data();
450486

451487
if (qparams->scale_buffer_idx() != 0) {
452-
scale = reinterpret_cast<const float*>(getConstantDataPtr(
488+
auto scale_result = getConstantDataPtr(
453489
qparams->scale_buffer_idx(),
454490
flatbuffer_graph,
455491
constant_data_ptr,
456492
named_data_map,
457493
freeable_buffers,
458-
weights_cache));
494+
weights_cache);
495+
if (!scale_result.ok()) {
496+
return scale_result.error();
497+
}
498+
scale = reinterpret_cast<const float*>(scale_result.get());
459499
ET_CHECK_OR_RETURN_ERROR(
460500
scale != nullptr, Internal, "Failed to load scale data.");
461501
}
@@ -491,13 +531,18 @@ Error defineTensor(
491531
// Block scales are preferably serialized as bf16 but can also be
492532
// serialized as fp32 for backwards compatability.
493533
if (qparams->scale_buffer_idx() != 0) {
494-
scale_data = reinterpret_cast<const uint16_t*>(getConstantDataPtr(
534+
auto scale_data_result = getConstantDataPtr(
495535
qparams->scale_buffer_idx(),
496536
flatbuffer_graph,
497537
constant_data_ptr,
498538
named_data_map,
499539
freeable_buffers,
500-
weights_cache));
540+
weights_cache);
541+
if (!scale_data_result.ok()) {
542+
return scale_data_result.error();
543+
}
544+
scale_data =
545+
reinterpret_cast<const uint16_t*>(scale_data_result.get());
501546
ET_CHECK_OR_RETURN_ERROR(
502547
scale_data != nullptr, Internal, "Failed to load scale data.");
503548
scale_numel = qparams->num_scales();

0 commit comments

Comments
 (0)