Skip to content

Commit e7e21b4

Browse files
committed
Adding support for int8 inputs
1 parent 77c2fc3 commit e7e21b4

8 files changed

Lines changed: 242 additions & 78 deletions

File tree

Common/ML/src/OrtInterface.cxx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,11 @@ template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t
354354
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, int64_t, float*);
355355
template void OrtModel::inference<float, OrtDataType::Float16_t>(float*, int64_t, OrtDataType::Float16_t*);
356356
template void OrtModel::inference<float, float>(float*, int64_t, float*);
357+
template void OrtModel::inference<int8_t, int8_t>(int8_t*, int64_t, int8_t*);
358+
template void OrtModel::inference<int8_t, float>(int8_t*, int64_t, float*);
359+
template void OrtModel::inference<float, int8_t>(float*, int64_t, int8_t*);
360+
template void OrtModel::inference<int8_t, OrtDataType::Float16_t>(int8_t*, int64_t, OrtDataType::Float16_t*);
361+
template void OrtModel::inference<OrtDataType::Float16_t, int8_t>(OrtDataType::Float16_t*, int64_t, int8_t*);
357362

358363
template <class I, class O>
359364
void OrtModel::inference(I** input, int64_t input_size, O* output)
@@ -414,6 +419,11 @@ template void OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t
414419
template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t**, int64_t, float*);
415420
template void OrtModel::inference<float, OrtDataType::Float16_t>(float**, int64_t, OrtDataType::Float16_t*);
416421
template void OrtModel::inference<float, float>(float**, int64_t, float*);
422+
template void OrtModel::inference<int8_t, int8_t>(int8_t**, int64_t, int8_t*);
423+
template void OrtModel::inference<int8_t, float>(int8_t**, int64_t, float*);
424+
template void OrtModel::inference<float, int8_t>(float**, int64_t, int8_t*);
425+
template void OrtModel::inference<int8_t, OrtDataType::Float16_t>(int8_t**, int64_t, OrtDataType::Float16_t*);
426+
template void OrtModel::inference<OrtDataType::Float16_t, int8_t>(OrtDataType::Float16_t**, int64_t, int8_t*);
417427

418428
template <class I, class O>
419429
std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& inputs)

GPU/GPUTracking/Definitions/GPUSettingsList.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ AddOption(nnCCDBClassificationLayerType, std::string, "FC", "", 0, "Distinguishe
300300
AddOption(nnCCDBRegressionLayerType, std::string, "FC", "", 0, "Distinguishes between network with different layer types. Options: FC, CNN")
301301
AddOption(nnCCDBBeamType, std::string, "pp", "", 0, "Distinguishes between networks trained for different beam types. Options: pp, pPb, PbPb")
302302
AddOption(nnCCDBInteractionRate, std::string, "500", "", 0, "Distinguishes between networks for different interaction rates [kHz].")
303+
AddOption(nnCCDBExtraMetadata, std::string, "", "", 0, "Extra metadata to distinguish between networks, e.g. for different internal datatypes, etc.")
303304
AddHelp("help", 'h')
304305
EndConfig()
305306

GPU/GPUTracking/Global/GPUChainTrackingClusterizer.cxx

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,15 +1269,27 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
12691269
if(GetProcessingSettings().debugLevel >= 1 && (doGPU || lane < 4)) { nnTimers[3*lane]->Start(); }
12701270
if (clustererNNShadow.mNnInferenceInputDType == 0) {
12711271
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1272-
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mModelProbabilities_16);
1272+
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mModelProbabilities_32);
12731273
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1274-
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mModelProbabilities_32);
1274+
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mModelProbabilities_16);
1275+
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1276+
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mModelProbabilities_8);
12751277
}
12761278
} else if (clustererNNShadow.mNnInferenceInputDType == 1) {
12771279
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1278-
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mModelProbabilities_16);
1280+
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mModelProbabilities_32);
12791281
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1280-
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mModelProbabilities_32);
1282+
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mModelProbabilities_16);
1283+
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1284+
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mModelProbabilities_8);
1285+
}
1286+
} else if (clustererNNShadow.mNnInferenceInputDType == 2) {
1287+
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1288+
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mModelProbabilities_32);
1289+
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1290+
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mModelProbabilities_16);
1291+
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1292+
(nnApplication.mModelClass).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mModelProbabilities_8);
12811293
}
12821294
}
12831295
if(GetProcessingSettings().debugLevel >= 1 && (doGPU || lane < 4)) { nnTimers[3*lane]->Stop(); } // doGPU || lane<4 -> only for GPU or first 4 CPU lanes (to limit number of concurrent timers). At least gives some statistics for CPU time...
@@ -1289,31 +1301,55 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
12891301
if(GetProcessingSettings().debugLevel >= 1 && (doGPU || lane < 4)) { nnTimers[3*lane + 1]->Start(); }
12901302
if (clustererNNShadow.mNnInferenceInputDType == 0) {
12911303
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1292-
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg1_16);
1304+
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg1_32);
12931305
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1294-
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg1_32);
1306+
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg1_16);
1307+
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1308+
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg1_8);
12951309
}
12961310
} else if (clustererNNShadow.mNnInferenceInputDType == 1) {
12971311
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1298-
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg1_16);
1312+
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg1_32);
12991313
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1300-
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg1_32);
1314+
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg1_16);
1315+
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1316+
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg1_8);
1317+
}
1318+
} else if (clustererNNShadow.mNnInferenceInputDType == 2) {
1319+
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1320+
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg1_32);
1321+
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1322+
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg1_16);
1323+
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1324+
(nnApplication.mModelReg1).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg1_8);
13011325
}
13021326
}
13031327
if(GetProcessingSettings().debugLevel >= 1 && (doGPU || lane < 4)) { nnTimers[3*lane + 1]->Stop(); }
13041328
if (nnApplication.mModelClass.getNumOutputNodes()[0][1] > 1 && nnApplication.mModelReg2.isInitialized()) {
13051329
if(GetProcessingSettings().debugLevel >= 1 && (doGPU || lane < 4)) { nnTimers[3*lane + 2]->Start(); }
13061330
if (clustererNNShadow.mNnInferenceInputDType == 0) {
13071331
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1308-
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg2_16);
1332+
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg2_32);
13091333
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1310-
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg2_32);
1334+
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg2_16);
1335+
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1336+
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg2_8);
13111337
}
13121338
} else if (clustererNNShadow.mNnInferenceInputDType == 1) {
13131339
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1314-
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg2_16);
1340+
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg2_32);
13151341
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1316-
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_32, iSize, clustererNNShadow.mOutputDataReg2_32);
1342+
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg2_16);
1343+
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1344+
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_16, iSize, clustererNNShadow.mOutputDataReg2_8);
1345+
}
1346+
} else if (clustererNNShadow.mNnInferenceInputDType == 2) {
1347+
if (clustererNNShadow.mNnInferenceOutputDType == 0) {
1348+
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg2_32);
1349+
} else if (clustererNNShadow.mNnInferenceOutputDType == 1) {
1350+
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg2_16);
1351+
} else if (clustererNNShadow.mNnInferenceOutputDType == 2) {
1352+
(nnApplication.mModelReg2).inference(clustererNNShadow.mInputData_8, iSize, clustererNNShadow.mOutputDataReg2_8);
13171353
}
13181354
}
13191355
if(GetProcessingSettings().debugLevel >= 1 && (doGPU || lane < 4)) { nnTimers[3*lane + 2]->Stop(); }

0 commit comments

Comments
 (0)