Skip to content

Commit 43dabda

Browse files
committed
Refine some code
1 parent 0e5144d commit 43dabda

2 files changed

Lines changed: 84 additions & 69 deletions

File tree

source/loader/layers/sanitizer/asan_interceptor.cpp

Lines changed: 53 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,6 @@ ur_result_t SanitizerInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel,
248248
auto Device = GetDevice(Queue);
249249
auto ContextInfo = getContextInfo(Context);
250250
auto DeviceInfo = getDeviceInfo(Device);
251-
auto KernelInfo = getKernelInfo(Kernel);
252-
253-
UR_CALL(LaunchInfo.updateKernelInfo(*KernelInfo.get()));
254251

255252
ManagedQueue InternalQueue(Context, Device);
256253
if (!InternalQueue) {
@@ -663,34 +660,19 @@ ur_result_t SanitizerInterceptor::prepareLaunch(
663660
LocalWorkSize[Dim];
664661
}
665662

666-
// Set launch info argument
667663
auto ArgNums = GetKernelNumArgs(Kernel);
668664
if (ArgNums == 0) {
669665
return UR_RESULT_SUCCESS;
670666
}
671667

668+
// Prepare asan runtime data
672669
LaunchInfo.Data.Host.GlobalShadowOffset =
673670
DeviceInfo->Shadow->ShadowBegin;
674671
LaunchInfo.Data.Host.GlobalShadowOffsetEnd =
675672
DeviceInfo->Shadow->ShadowEnd;
676673
LaunchInfo.Data.Host.DeviceTy = DeviceInfo->Type;
677674
LaunchInfo.Data.Host.Debug = getOptions().Debug ? 1 : 0;
678675

679-
UR_CALL(getContext()->urDdiTable.USM.pfnDeviceAlloc(
680-
ContextInfo->Handle, DeviceInfo->Handle, nullptr, nullptr,
681-
sizeof(LaunchInfo), (void **)&LaunchInfo.Data.DevicePtr));
682-
getContext()->logger.debug(
683-
"launch_info {} (numLocalArgs={}, localArgs={})",
684-
(void *)LaunchInfo.Data.DevicePtr,
685-
LaunchInfo.Data.Host.NumLocalArgs,
686-
(void *)LaunchInfo.Data.Host.LocalArgs);
687-
ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
688-
Kernel, ArgNums - 1, nullptr, LaunchInfo.Data.DevicePtr);
689-
if (URes != UR_RESULT_SUCCESS) {
690-
getContext()->logger.error("Failed to set launch info: {}", URes);
691-
return URes;
692-
}
693-
694676
auto EnqueueAllocateShadowMemory = [Context = ContextInfo->Handle,
695677
Device = DeviceInfo->Handle,
696678
Queue](size_t Size, uptr &Ptr) {
@@ -807,8 +789,34 @@ ur_result_t SanitizerInterceptor::prepareLaunch(
807789
}
808790
}
809791

810-
// Prepare launch info for device side
792+
// Write local arguments info
793+
if (!KernelInfo->LocalArgs.empty()) {
794+
std::vector<LocalArgsInfo> LocalArgsInfo;
795+
for (auto [ArgIndex, ArgInfo] : KernelInfo->LocalArgs) {
796+
LocalArgsInfo.push_back(ArgInfo);
797+
getContext()->logger.debug(
798+
"local_args (argIndex={}, size={}, sizeWithRZ={})",
799+
ArgIndex, ArgInfo.Size, ArgInfo.SizeWithRedZone);
800+
}
801+
UR_CALL(LaunchInfo.Data.importLocalArgsInfo(Queue, LocalArgsInfo));
802+
}
803+
804+
// Sync asan runtime data to device side
811805
UR_CALL(LaunchInfo.Data.syncToDevice(Queue));
806+
807+
// Set the asan runtime data device pointer as the last kernel argument
808+
ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
809+
Kernel, ArgNums - 1, nullptr, LaunchInfo.Data.getDevicePtr());
810+
if (URes != UR_RESULT_SUCCESS) {
811+
getContext()->logger.error("Failed to set launch info: {}", URes);
812+
return URes;
813+
}
814+
815+
getContext()->logger.debug(
816+
"launch_info {} (numLocalArgs={}, localArgs={})",
817+
(void *)LaunchInfo.Data.getDevicePtr(),
818+
LaunchInfo.Data.Host.NumLocalArgs,
819+
(void *)LaunchInfo.Data.Host.LocalArgs);
812820
} while (false);
813821

814822
return UR_RESULT_SUCCESS;
@@ -860,54 +868,39 @@ ContextInfo::~ContextInfo() {
860868
}
861869
}
862870

863-
ur_result_t LaunchInfo::updateKernelInfo(const KernelInfo &KI) {
864-
if (!KI.LocalArgs.empty()) {
865-
std::vector<LocalArgsInfo> LocalArgsInfo;
866-
for (auto [ArgIndex, ArgInfo] : KI.LocalArgs) {
867-
LocalArgsInfo.push_back(ArgInfo);
868-
getContext()->logger.debug(
869-
"local_args (argIndex={}, size={}, sizeWithRZ={})", ArgIndex,
870-
ArgInfo.Size, ArgInfo.SizeWithRedZone);
871-
}
872-
ManagedQueue Queue(Context, Device);
873-
UR_CALL(
874-
Data.importLocalArgsInfo(Context, Device, Queue, LocalArgsInfo));
875-
}
876-
return UR_RESULT_SUCCESS;
877-
}
878-
879-
LaunchInfo::~LaunchInfo() {
871+
AsanRuntimeDataWrapper::~AsanRuntimeDataWrapper() {
880872
[[maybe_unused]] ur_result_t Result;
881-
if (Data.DevicePtr) {
882-
auto Type = GetDeviceType(Context, Device);
883-
auto ContextInfo = getContext()->interceptor->getContextInfo(Context);
884-
if (Type == DeviceType::GPU_PVC || Type == DeviceType::GPU_DG2) {
885-
if (Data.Host.PrivateShadowOffset) {
886-
ContextInfo->Stats.UpdateShadowFreed(
887-
Data.Host.PrivateShadowOffsetEnd -
888-
Data.Host.PrivateShadowOffset + 1);
889-
Result = getContext()->urDdiTable.USM.pfnFree(
890-
Context, (void *)Data.Host.PrivateShadowOffset);
891-
assert(Result == UR_RESULT_SUCCESS);
892-
}
893-
if (Data.Host.LocalShadowOffset) {
894-
ContextInfo->Stats.UpdateShadowFreed(
895-
Data.Host.LocalShadowOffsetEnd -
896-
Data.Host.LocalShadowOffset + 1);
897-
Result = getContext()->urDdiTable.USM.pfnFree(
898-
Context, (void *)Data.Host.LocalShadowOffset);
899-
assert(Result == UR_RESULT_SUCCESS);
900-
}
873+
auto Type = GetDeviceType(Context, Device);
874+
auto ContextInfo = getContext()->interceptor->getContextInfo(Context);
875+
if (Type == DeviceType::GPU_PVC || Type == DeviceType::GPU_DG2) {
876+
if (Host.PrivateShadowOffset) {
877+
ContextInfo->Stats.UpdateShadowFreed(Host.PrivateShadowOffsetEnd -
878+
Host.PrivateShadowOffset + 1);
879+
Result = getContext()->urDdiTable.USM.pfnFree(
880+
Context, (void *)Host.PrivateShadowOffset);
881+
assert(Result == UR_RESULT_SUCCESS);
901882
}
902-
if (Data.Host.LocalArgs) {
883+
if (Host.LocalShadowOffset) {
884+
ContextInfo->Stats.UpdateShadowFreed(Host.LocalShadowOffsetEnd -
885+
Host.LocalShadowOffset + 1);
903886
Result = getContext()->urDdiTable.USM.pfnFree(
904-
Context, (void *)Data.Host.LocalArgs);
887+
Context, (void *)Host.LocalShadowOffset);
905888
assert(Result == UR_RESULT_SUCCESS);
906889
}
890+
}
891+
if (Host.LocalArgs) {
907892
Result = getContext()->urDdiTable.USM.pfnFree(Context,
908-
(void *)Data.DevicePtr);
893+
(void *)Host.LocalArgs);
909894
assert(Result == UR_RESULT_SUCCESS);
910895
}
896+
if (DevicePtr) {
897+
Result = getContext()->urDdiTable.USM.pfnFree(Context, DevicePtr);
898+
assert(Result == UR_RESULT_SUCCESS);
899+
}
900+
}
901+
902+
LaunchInfo::~LaunchInfo() {
903+
[[maybe_unused]] ur_result_t Result;
911904
Result = getContext()->urDdiTable.Context.pfnRelease(Context);
912905
assert(Result == UR_RESULT_SUCCESS);
913906
Result = getContext()->urDdiTable.Device.pfnRelease(Device);

source/loader/layers/sanitizer/asan_interceptor.hpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,25 +159,48 @@ struct AsanRuntimeDataWrapper {
159159

160160
AsanRuntimeData *DevicePtr = nullptr;
161161

162+
ur_context_handle_t Context{};
163+
164+
ur_device_handle_t Device{};
165+
166+
AsanRuntimeDataWrapper(ur_context_handle_t Context,
167+
ur_device_handle_t Device)
168+
: Context(Context), Device(Device) {}
169+
170+
~AsanRuntimeDataWrapper();
171+
172+
AsanRuntimeData *getDevicePtr() {
173+
if (DevicePtr == nullptr) {
174+
ur_result_t Result = getContext()->urDdiTable.USM.pfnDeviceAlloc(
175+
Context, Device, nullptr, nullptr, sizeof(AsanRuntimeData),
176+
(void **)&DevicePtr);
177+
if (Result != UR_RESULT_SUCCESS) {
178+
getContext()->logger.error(
179+
"Failed to alloc device usm for asan runtime data: {}",
180+
Result);
181+
}
182+
}
183+
return DevicePtr;
184+
}
185+
162186
ur_result_t syncFromDevice(ur_queue_handle_t Queue) {
163187
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
164-
Queue, true, ur_cast<void *>(&Host), DevicePtr,
188+
Queue, true, ur_cast<void *>(&Host), getDevicePtr(),
165189
sizeof(AsanRuntimeData), 0, nullptr, nullptr));
166190

167191
return UR_RESULT_SUCCESS;
168192
}
169193

170194
ur_result_t syncToDevice(ur_queue_handle_t Queue) {
171195
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
172-
Queue, true, DevicePtr, ur_cast<void *>(&Host),
196+
Queue, true, getDevicePtr(), ur_cast<void *>(&Host),
173197
sizeof(AsanRuntimeData), 0, nullptr, nullptr));
174198

175199
return UR_RESULT_SUCCESS;
176200
}
177201

178202
ur_result_t
179-
importLocalArgsInfo(ur_context_handle_t Context, ur_device_handle_t Device,
180-
ur_queue_handle_t Queue,
203+
importLocalArgsInfo(ur_queue_handle_t Queue,
181204
const std::vector<LocalArgsInfo> &LocalArgs) {
182205
assert(!LocalArgs.empty());
183206

@@ -197,20 +220,21 @@ struct AsanRuntimeDataWrapper {
197220
};
198221

199222
struct LaunchInfo {
200-
AsanRuntimeDataWrapper Data{};
201-
202223
ur_context_handle_t Context = nullptr;
203224
ur_device_handle_t Device = nullptr;
204225
const size_t *GlobalWorkSize = nullptr;
205226
const size_t *GlobalWorkOffset = nullptr;
206227
std::vector<size_t> LocalWorkSize;
207228
uint32_t WorkDim = 0;
208229

230+
AsanRuntimeDataWrapper Data;
231+
209232
LaunchInfo(ur_context_handle_t Context, ur_device_handle_t Device,
210233
const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
211234
const size_t *GlobalWorkOffset, uint32_t WorkDim)
212235
: Context(Context), Device(Device), GlobalWorkSize(GlobalWorkSize),
213-
GlobalWorkOffset(GlobalWorkOffset), WorkDim(WorkDim) {
236+
GlobalWorkOffset(GlobalWorkOffset), WorkDim(WorkDim),
237+
Data(Context, Device) {
214238
if (LocalWorkSize) {
215239
this->LocalWorkSize =
216240
std::vector<size_t>(LocalWorkSize, LocalWorkSize + WorkDim);
@@ -222,8 +246,6 @@ struct LaunchInfo {
222246
assert(Result == UR_RESULT_SUCCESS);
223247
}
224248
~LaunchInfo();
225-
226-
ur_result_t updateKernelInfo(const KernelInfo &KI);
227249
};
228250

229251
struct DeviceGlobalInfo {

0 commit comments

Comments
 (0)