@@ -170,10 +170,12 @@ std::vector<T> flatbufferDimsToVector(
170170/* *
171171Gets the constant data pointer associated with the given tensor value.
172172Obtaining the constant data pointer can either be from within the flatbuffer
173- payload (deprecated) or via offsets to the constant_data_ptr. If no constant
174- data associated with the tensor value, then returns nullptr.
173+ payload (deprecated) or via offsets to the constant_data_ptr.
174+
175+ Failures are returned as an Error, and the successful value may be nullptr
176+ when the tensor has no associated constant data.
175177*/
176- const uint8_t * getConstantDataPtr (
178+ Result< const uint8_t *> getConstantDataPtr (
177179 uint32_t buffer_idx,
178180 GraphPtr flatbuffer_graph,
179181 const uint8_t * constant_data_ptr,
@@ -184,26 +186,56 @@ const uint8_t* getConstantDataPtr(
184186 if (!constant_data_ptr) {
185187 // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
186188 // window
187- const auto & constant_buffer = *flatbuffer_graph->constant_buffer ();
188- return constant_buffer[buffer_idx]->storage ()->data ();
189+ auto * cb = flatbuffer_graph->constant_buffer ();
190+ ET_CHECK_OR_RETURN_ERROR (
191+ cb != nullptr , InvalidProgram, " constant_buffer is null" );
192+ ET_CHECK_OR_RETURN_ERROR (
193+ buffer_idx < cb->size (),
194+ InvalidProgram,
195+ " buffer_idx %u out of bounds for constant_buffer of size %zu" ,
196+ buffer_idx,
197+ cb->size ());
198+ auto * buffer_entry = (*cb)[buffer_idx];
199+ ET_CHECK_OR_RETURN_ERROR (
200+ buffer_entry != nullptr && buffer_entry->storage () != nullptr ,
201+ InvalidProgram,
202+ " Null constant_buffer entry at buffer_idx %u" ,
203+ buffer_idx);
204+ return buffer_entry->storage ()->data ();
189205 } else {
190- ConstantDataOffsetPtr constant_data_offset =
191- flatbuffer_graph->constant_data ()->Get (buffer_idx);
206+ auto * cd = flatbuffer_graph->constant_data ();
207+ ET_CHECK_OR_RETURN_ERROR (
208+ cd != nullptr , InvalidProgram, " constant_data is null" );
209+ ET_CHECK_OR_RETURN_ERROR (
210+ buffer_idx < cd->size (),
211+ InvalidProgram,
212+ " buffer_idx %u out of bounds for constant_data of size %zu" ,
213+ buffer_idx,
214+ cd->size ());
215+ ConstantDataOffsetPtr constant_data_offset = cd->Get (buffer_idx);
216+ ET_CHECK_OR_RETURN_ERROR (
217+ constant_data_offset != nullptr ,
218+ InvalidProgram,
219+ " Null constant_data entry at buffer_idx %u" ,
220+ buffer_idx);
192221 uint64_t offset = constant_data_offset->offset ();
193-
194222 bool has_named_key = flatbuffers::IsFieldPresent (
195223 constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
196224 // If there is no tensor name
197225 if (!has_named_key) {
198226 return constant_data_ptr + offset;
199227 } else {
228+ ET_CHECK_OR_RETURN_ERROR (
229+ constant_data_offset->named_key () != nullptr ,
230+ InvalidProgram,
231+ " Named key is null" );
200232 const std::string& data_name = constant_data_offset->named_key ()->str ();
201233#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
202234 Result<const uint8_t *> data_ptr =
203235 weights_cache->load_unpacked_data (data_name);
204236 if (!data_ptr.ok ()) {
205237 ET_LOG (Error, " Failed to load weights from cache" );
206- return nullptr ;
238+ return data_ptr. error () ;
207239 }
208240 return data_ptr.get ();
209241#else
@@ -215,7 +247,7 @@ const uint8_t* getConstantDataPtr(
215247 " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
216248 data_name.c_str (),
217249 static_cast <uint32_t >(buffer.error ()));
218- return nullptr ;
250+ return buffer. error () ;
219251 }
220252 const uint8_t * data_ptr =
221253 static_cast <const uint8_t *>(buffer.get ().data ());
@@ -229,7 +261,7 @@ const uint8_t* getConstantDataPtr(
229261 return nullptr ;
230262}
231263
232- const uint8_t * getConstantDataPtr (
264+ Result< const uint8_t *> getConstantDataPtr (
233265 const fb_xnnpack::XNNTensorValue* tensor_value,
234266 GraphPtr flatbuffer_graph,
235267 const uint8_t * constant_data_ptr,
@@ -298,13 +330,17 @@ Error defineTensor(
298330
299331 // Get Pointer to constant data from flatbuffer, if its non-constant
300332 // it is a nullptr
301- const uint8_t * buffer_ptr = getConstantDataPtr (
333+ auto buffer_result = getConstantDataPtr (
302334 tensor_value,
303335 flatbuffer_graph,
304336 constant_data_ptr,
305337 named_data_map,
306338 freeable_buffers,
307339 weights_cache);
340+ if (!buffer_result.ok ()) {
341+ return buffer_result.error ();
342+ }
343+ const uint8_t * buffer_ptr = buffer_result.get ();
308344
309345 xnn_status status;
310346 // The type we might have to convert to
@@ -449,13 +485,17 @@ Error defineTensor(
449485 const float * scale = qparams->scale ()->data ();
450486
451487 if (qparams->scale_buffer_idx () != 0 ) {
452- scale = reinterpret_cast < const float *>( getConstantDataPtr (
488+ auto scale_result = getConstantDataPtr (
453489 qparams->scale_buffer_idx (),
454490 flatbuffer_graph,
455491 constant_data_ptr,
456492 named_data_map,
457493 freeable_buffers,
458- weights_cache));
494+ weights_cache);
495+ if (!scale_result.ok ()) {
496+ return scale_result.error ();
497+ }
498+ scale = reinterpret_cast <const float *>(scale_result.get ());
459499 ET_CHECK_OR_RETURN_ERROR (
460500 scale != nullptr , Internal, " Failed to load scale data." );
461501 }
@@ -491,13 +531,18 @@ Error defineTensor(
491531 // Block scales are preferably serialized as bf16 but can also be
492532 // serialized as fp32 for backwards compatability.
493533 if (qparams->scale_buffer_idx () != 0 ) {
494- scale_data = reinterpret_cast < const uint16_t *>( getConstantDataPtr (
534+ auto scale_data_result = getConstantDataPtr (
495535 qparams->scale_buffer_idx (),
496536 flatbuffer_graph,
497537 constant_data_ptr,
498538 named_data_map,
499539 freeable_buffers,
500- weights_cache));
540+ weights_cache);
541+ if (!scale_data_result.ok ()) {
542+ return scale_data_result.error ();
543+ }
544+ scale_data =
545+ reinterpret_cast <const uint16_t *>(scale_data_result.get ());
501546 ET_CHECK_OR_RETURN_ERROR (
502547 scale_data != nullptr , Internal, " Failed to load scale data." );
503548 scale_numel = qparams->num_scales ();
0 commit comments