Skip to content

Commit f09b710

Browse files
committed
Revert the int8 changes: int8 computations are handled inside the ONNX graph, because they require additional scaling
1 parent e7e21b4 commit f09b710

6 files changed

Lines changed: 50 additions & 182 deletions

File tree

Common/ML/src/OrtInterface.cxx

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ void OrtModel::initEnvironment()
140140

141141
void OrtModel::initSessionFromBuffer(const char* buffer, size_t bufferSize)
142142
{
143+
if (mAllocateDeviceMemory) {
144+
memoryOnDevice(mDeviceId);
145+
}
143146
mPImplOrt->sessionOptions.AddConfigEntry("session.load_model_format", "ONNX");
144147
mPImplOrt->sessionOptions.AddConfigEntry("session.use_ort_model_bytes_directly", "1");
145148

@@ -354,11 +357,6 @@ template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t
354357
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t, float*);
355358
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, int64_t, OrtDataType::Float16_t*);
356359
template void OrtModel::inference<float, float>(float*, int64_t, float*);
357-
template void OrtModel::inference<int8_t, int8_t>(int8_t*, int64_t, int8_t*);
358-
template void OrtModel::inference<int8_t, float>(int8_t*, int64_t, float*);
359-
template void OrtModel::inference<float, int8_t>(float*, int64_t, int8_t*);
360-
template void OrtModel::inference<int8_t, OrtDataType::Float16_t>(int8_t*, int64_t, OrtDataType::Float16_t*);
361-
template void OrtModel::inference<OrtDataType::Float16_t, int8_t>(OrtDataType::Float16_t*, int64_t, int8_t*);
362360

363361
template <class I, class O>
364362
void OrtModel::inference(I** input, int64_t input_size, O* output)
@@ -419,11 +417,6 @@ template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t
419417
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t, float*);
420418
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, int64_t, OrtDataType::Float16_t*);
421419
template void OrtModel::inference<float, float>(float**, int64_t, float*);
422-
template void OrtModel::inference<int8_t, int8_t>(int8_t**, int64_t, int8_t*);
423-
template void OrtModel::inference<int8_t, float>(int8_t**, int64_t, float*);
424-
template void OrtModel::inference<float, int8_t>(float**, int64_t, int8_t*);
425-
template void OrtModel::inference<int8_t, OrtDataType::Float16_t>(int8_t**, int64_t, OrtDataType::Float16_t*);
426-
template void OrtModel::inference<OrtDataType::Float16_t, int8_t>(OrtDataType::Float16_t**, int64_t, int8_t*);
427420

428421
template <class I, class O>
429422
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)

GPU/GPUTracking/Base/cuda/GPUReconstructionCUDA.cu

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -631,34 +631,40 @@ void GPUReconstructionCUDA::loadKernelModules(bool perKernel)
631631
} \
632632
}
633633

634-
void GPUReconstructionCUDA::SetONNXGPUStream(Ort::SessionOptions& session_options, int32_t stream, int32_t* deviceId)
635-
{
634+
void GPUReconstructionCUDA::SetONNXGPUStream(Ort::SessionOptions& sessionOptions, int32_t stream, int32_t* deviceId) {
636635
GPUChkErr(cudaGetDevice(deviceId));
636+
637637
#if !defined(__HIPCC__) && defined(ORT_CUDA_BUILD)
638638
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
639-
OrtCUDAProviderOptionsV2* cuda_options = nullptr;
640-
ORTCHK(api->CreateCUDAProviderOptions(&cuda_options));
641639

642-
// std::vector<const char*> keys{"device_id", "gpu_mem_limit", "arena_extend_strategy", "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"};
643-
// std::vector<const char*> values{"0", "2147483648", "kSameAsRequested", "DEFAULT", "1", "1", "1"};
644-
// UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size());
640+
#ifdef ORT_TENSORRT_BUILD
641+
OrtTensorRTProviderOptionsV2* trtOptions = nullptr;
642+
ORTCHK(api->CreateTensorRTProviderOptions(&trtOptions));
643+
644+
const std::string device = std::to_string(*deviceId);
645+
const char* keys[] = {"device_id", "trt_int8_enable"};
646+
const char* values[] = {device.c_str(), "1"};
645647

646-
// this implicitly sets "has_user_compute_stream"
647-
ORTCHK(api->UpdateCUDAProviderOptionsWithValue(cuda_options, "user_compute_stream", mInternals->Streams[stream]));
648-
ORTCHK(api->SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options));
648+
ORTCHK(api->UpdateTensorRTProviderOptions(trtOptions,keys,values,sizeof(keys) / sizeof(keys[0])));
649+
ORTCHK(api->UpdateTensorRTProviderOptionsWithValue(trtOptions,"user_compute_stream",mInternals->Streams[stream]));
650+
ORTCHK(api->SessionOptionsAppendExecutionProvider_TensorRT_V2(sessionOptions,trtOptions)); // Register TensorRT first: it consequently has higher priority.
651+
api->ReleaseTensorRTProviderOptions(trtOptions);
652+
#endif
653+
654+
// CUDA is the fallback for nodes unsupported by TensorRT.
655+
OrtCUDAProviderOptionsV2* cudaOptions = nullptr;
656+
ORTCHK(api->CreateCUDAProviderOptions(&cudaOptions));
657+
ORTCHK(api->UpdateCUDAProviderOptionsWithValue(cudaOptions,"user_compute_stream",mInternals->Streams[stream]));
658+
ORTCHK(api->SessionOptionsAppendExecutionProvider_CUDA_V2(sessionOptions,cudaOptions));
659+
api->ReleaseCUDAProviderOptions(cudaOptions);
649660

650-
// Finally, don't forget to release the provider options
651-
api->ReleaseCUDAProviderOptions(cuda_options);
652661
#elif defined(ORT_ROCM_BUILD)
653-
// const auto& api = Ort::GetApi();
654-
// api.GetCurrentGpuDeviceId(deviceId);
655-
OrtROCMProviderOptions rocm_options;
656-
rocm_options.has_user_compute_stream = 1; // Indicate that we are passing a user stream
657-
rocm_options.arena_extend_strategy = 0; // kNextPowerOfTwo = 0, kSameAsRequested = 1 -> https://github.com/search?q=repo%3Amicrosoft%2Fonnxruntime%20kSameAsRequested&type=code
658-
// rocm_options.gpu_mem_limit = 1073741824; // 0 means no limit
659-
rocm_options.user_compute_stream = mInternals->Streams[stream];
660-
session_options.AppendExecutionProvider_ROCM(rocm_options);
661-
#endif // ORT_ROCM_BUILD
662+
OrtROCMProviderOptions rocmOptions;
663+
rocmOptions.has_user_compute_stream = 1;
664+
rocmOptions.arena_extend_strategy = 0;
665+
rocmOptions.user_compute_stream = mInternals->Streams[stream];
666+
sessionOptions.AppendExecutionProvider_ROCM(rocmOptions);
667+
#endif
662668
}
663669

664670
#ifndef __HIPCC__ // CUDA

GPU/GPUTracking/Global/GPUChainTrackingClusterizer.cxx

Lines changed: 7 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,26 +1270,14 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
12701270
if (clustererNNShadow.mNnInferenceInputDType == 0) {
12711271
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
12721272
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mModelProbabilities_32);
1273-
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1273+
} else {
12741274
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mModelProbabilities_16);
1275-
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1276-
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mModelProbabilities_8);
12771275
}
12781276
} else if (clustererNNShadow.mNnInferenceInputDType == 1) {
12791277
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
12801278
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mModelProbabilities_32);
1281-
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1279+
} else {
12821280
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mModelProbabilities_16);
1283-
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1284-
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mModelProbabilities_8);
1285-
}
1286-
} else if (clustererNNShadow.mNnInferenceInputDType == 2) {
1287-
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1288-
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mModelProbabilities_32);
1289-
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1290-
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mModelProbabilities_16);
1291-
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1292-
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mModelProbabilities_8);
12931281
}
12941282
}
12951283
if(GetProcessingSettings().debugLevel >= 1 && (doGPU || lane < 4)) { nnTimers[3*lane]->Stop(); } // doGPU || lane<4 -> only for GPU or first 4 CPU lanes (to limit number of concurrent timers). At least gives some statistics for CPU time...
@@ -1302,26 +1290,14 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
13021290
if (clustererNNShadow.mNnInferenceInputDType == 0) {
13031291
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
13041292
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg1_32);
1305-
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1293+
} else {
13061294
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg1_16);
1307-
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1308-
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg1_8);
13091295
}
1310-
} else if (clustererNNShadow.mNnInferenceInputDType == 1) {
1296+
} else {
13111297
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
13121298
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg1_32);
1313-
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1299+
} else {
13141300
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg1_16);
1315-
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1316-
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg1_8);
1317-
}
1318-
} else if (clustererNNShadow.mNnInferenceInputDType == 2) {
1319-
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1320-
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg1_32);
1321-
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1322-
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg1_16);
1323-
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1324-
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg1_8);
13251301
}
13261302
}
13271303
if(GetProcessingSettings().debugLevel >= 1 && (doGPU || lane < 4)) { nnTimers[3*lane + 1]->Stop(); }
@@ -1330,26 +1306,14 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
13301306
if (clustererNNShadow.mNnInferenceInputDType == 0) {
13311307
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
13321308
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg2_32);
1333-
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1309+
} else {
13341310
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg2_16);
1335-
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1336-
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg2_8);
13371311
}
13381312
} else if (clustererNNShadow.mNnInferenceInputDType == 1) {
13391313
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
13401314
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg2_32);
1341-
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1315+
} else {
13421316
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg2_16);
1343-
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1344-
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg2_8);
1345-
}
1346-
} else if (clustererNNShadow.mNnInferenceInputDType == 2) {
1347-
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1348-
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg2_32);
1349-
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1350-
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg2_16);
1351-
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1352-
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg2_8);
13531317
}
13541318
}
13551319
if(GetProcessingSettings().debugLevel >= 1 && (doGPU || lane < 4)) { nnTimers[3*lane + 2]->Stop(); }

0 commit comments

Comments
 (0)