@@ -50,12 +50,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
5050 return UR_RESULT_SUCCESS;
5151}
5252
53- bool isInstrumentedKernel (ur_kernel_handle_t hKernel) {
54- auto hProgram = GetProgram (hKernel);
55- auto PI = getMsanInterceptor ()->getProgramInfo (hProgram);
56- return PI->isKernelInstrumented (hKernel);
57- }
58-
5953} // namespace
6054
6155// /////////////////////////////////////////////////////////////////////////////
@@ -354,12 +348,6 @@ ur_result_t urEnqueueKernelLaunch(
354348
355349 getContext ()->logger .debug (" ==== urEnqueueKernelLaunch" );
356350
357- if (!isInstrumentedKernel (hKernel)) {
358- return pfnKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
359- pGlobalWorkSize, pLocalWorkSize,
360- numEventsInWaitList, phEventWaitList, phEvent);
361- }
362-
363351 USMLaunchInfo LaunchInfo (GetContext (hQueue), GetDevice (hQueue),
364352 pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
365353 workDim);
@@ -1155,26 +1143,6 @@ ur_result_t urEnqueueMemUnmap(
11551143 return UR_RESULT_SUCCESS;
11561144}
11571145
1158- // /////////////////////////////////////////////////////////////////////////////
1159- // / @brief Intercept function for urKernelCreate
1160- ur_result_t urKernelCreate (
1161- ur_program_handle_t hProgram, // /< [in] handle of the program instance
1162- const char *pKernelName, // /< [in] pointer to null-terminated string.
1163- ur_kernel_handle_t
1164- *phKernel // /< [out] pointer to handle of kernel object created.
1165- ) {
1166- auto pfnCreate = getContext ()->urDdiTable .Kernel .pfnCreate ;
1167-
1168- getContext ()->logger .debug (" ==== urKernelCreate" );
1169-
1170- UR_CALL (pfnCreate (hProgram, pKernelName, phKernel));
1171- if (isInstrumentedKernel (*phKernel)) {
1172- UR_CALL (getMsanInterceptor ()->insertKernel (*phKernel));
1173- }
1174-
1175- return UR_RESULT_SUCCESS;
1176- }
1177-
11781146// /////////////////////////////////////////////////////////////////////////////
11791147// / @brief Intercept function for urKernelRetain
11801148ur_result_t urKernelRetain (
@@ -1186,10 +1154,8 @@ ur_result_t urKernelRetain(
11861154
11871155 UR_CALL (pfnRetain (hKernel));
11881156
1189- auto KernelInfo = getMsanInterceptor ()->getKernelInfo (hKernel);
1190- if (KernelInfo) {
1191- KernelInfo->RefCount ++;
1192- }
1157+ auto &KernelInfo = getMsanInterceptor ()->getOrCreateKernelInfo (hKernel);
1158+ KernelInfo.RefCount ++;
11931159
11941160 return UR_RESULT_SUCCESS;
11951161}
@@ -1204,11 +1170,9 @@ ur_result_t urKernelRelease(
12041170 getContext ()->logger .debug (" ==== urKernelRelease" );
12051171 UR_CALL (pfnRelease (hKernel));
12061172
1207- auto KernelInfo = getMsanInterceptor ()->getKernelInfo (hKernel);
1208- if (KernelInfo) {
1209- if (--KernelInfo->RefCount == 0 ) {
1210- UR_CALL (getMsanInterceptor ()->eraseKernel (hKernel));
1211- }
1173+ auto &KernelInfo = getMsanInterceptor ()->getOrCreateKernelInfo (hKernel);
1174+ if (--KernelInfo.RefCount == 0 ) {
1175+ UR_CALL (getMsanInterceptor ()->eraseKernelInfo (hKernel));
12121176 }
12131177
12141178 return UR_RESULT_SUCCESS;
@@ -1230,13 +1194,12 @@ ur_result_t urKernelSetArgValue(
12301194 getContext ()->logger .debug (" ==== urKernelSetArgValue" );
12311195
12321196 std::shared_ptr<MemBuffer> MemBuffer;
1233- std::shared_ptr<KernelInfo> KernelInfo;
12341197 if (argSize == sizeof (ur_mem_handle_t ) &&
12351198 (MemBuffer = getMsanInterceptor ()->getMemBuffer (
1236- *ur_cast<const ur_mem_handle_t *>(pArgValue))) &&
1237- ( KernelInfo = getMsanInterceptor ()->getKernelInfo (hKernel))) {
1238- std::scoped_lock<ur_shared_mutex> Guard (KernelInfo-> Mutex );
1239- KernelInfo-> BufferArgs [argIndex] = std::move (MemBuffer);
1199+ *ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
1200+ auto & KernelInfo = getMsanInterceptor ()->getOrCreateKernelInfo (hKernel);
1201+ std::scoped_lock<ur_shared_mutex> Guard (KernelInfo. Mutex );
1202+ KernelInfo. BufferArgs [argIndex] = std::move (MemBuffer);
12401203 } else {
12411204 UR_CALL (
12421205 pfnSetArgValue (hKernel, argIndex, argSize, pProperties, pArgValue));
@@ -1260,10 +1223,10 @@ ur_result_t urKernelSetArgMemObj(
12601223
12611224 std::shared_ptr<MemBuffer> MemBuffer;
12621225 std::shared_ptr<KernelInfo> KernelInfo;
1263- if ((MemBuffer = getMsanInterceptor ()->getMemBuffer (hArgValue)) &&
1264- ( KernelInfo = getMsanInterceptor ()->getKernelInfo (hKernel))) {
1265- std::scoped_lock<ur_shared_mutex> Guard (KernelInfo-> Mutex );
1266- KernelInfo-> BufferArgs [argIndex] = std::move (MemBuffer);
1226+ if ((MemBuffer = getMsanInterceptor ()->getMemBuffer (hArgValue))) {
1227+ auto & KernelInfo = getMsanInterceptor ()->getOrCreateKernelInfo (hKernel);
1228+ std::scoped_lock<ur_shared_mutex> Guard (KernelInfo. Mutex );
1229+ KernelInfo. BufferArgs [argIndex] = std::move (MemBuffer);
12671230 } else {
12681231 UR_CALL (pfnSetArgMemObj (hKernel, argIndex, pProperties, hArgValue));
12691232 }
@@ -1348,7 +1311,6 @@ ur_result_t urGetKernelProcAddrTable(
13481311) {
13491312 ur_result_t result = UR_RESULT_SUCCESS;
13501313
1351- pDdiTable->pfnCreate = ur_sanitizer_layer::msan::urKernelCreate;
13521314 pDdiTable->pfnRetain = ur_sanitizer_layer::msan::urKernelRetain;
13531315 pDdiTable->pfnRelease = ur_sanitizer_layer::msan::urKernelRelease;
13541316 pDdiTable->pfnSetArgValue = ur_sanitizer_layer::msan::urKernelSetArgValue;
0 commit comments