Skip to content

Commit 039f813

Browse files
committed
Bug-fix
1 parent f09b710 commit 039f813

1 file changed

Lines changed: 2 additions & 10 deletions

File tree

GPU/GPUTracking/TPCClusterFinder/GPUTPCNNClusterizerHost.cxx

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,21 +126,13 @@ void GPUTPCNNClusterizerHost::initClusterizer(const GPUSettingsProcessingNNclust
126126
// Define the datatype for input and output
127127
if (settings.nnInferenceInputDType.find("32") != std::string::npos) {
128128
clustererNN.mNnInferenceInputDType = 0;
129-
} else if (settings.nnInferenceInputDType.find("16") != std::string::npos) {
130-
clustererNN.mNnInferenceInputDType = 1;
131-
} else if (settings.nnInferenceInputDType.find("8") != std::string::npos) {
132-
clustererNN.mNnInferenceInputDType = 2;
133129
} else {
134-
clustererNN.mNnInferenceInputDType = 1; // Default to float32
130+
clustererNN.mNnInferenceInputDType = 1; // Default to float16
135131
}
136132
if (settings.nnInferenceOutputDType.find("32") != std::string::npos) {
137133
clustererNN.mNnInferenceOutputDType = 0;
138-
} else if (settings.nnInferenceOutputDType.find("16") != std::string::npos) {
139-
clustererNN.mNnInferenceOutputDType = 1;
140-
} else if (settings.nnInferenceOutputDType.find("8") != std::string::npos) {
141-
clustererNN.mNnInferenceOutputDType = 2;
142134
} else {
143-
clustererNN.mNnInferenceOutputDType = 1; // Default to float32
135+
clustererNN.mNnInferenceOutputDType = 1; // Default to float16
144136
}
145137
clustererNN.mNnClusterizerModelClassNumOutputNodes = mModelClass.getNumOutputNodes()[0][1];
146138
if (!settings.nnClusterizerUseCfRegression) {

0 commit comments

Comments
 (0)