@@ -951,18 +951,19 @@ validateCommandDesc(ur_exp_command_buffer_command_handle_t Command,
951951}
952952
953953/* *
954- * Updates the arguments of CommandDesc->hNewKernel
955- * @param[in] Device The device associated with the kernel being updated.
956- * @param[in] UpdateCommandDesc The update command description that contains
957- * the new kernel and its arguments.
954+ * Updates the arguments of a kernel command.
955+ * @param[in] Command The command associated with the kernel node being updated.
956+ * @param[in] UpdateCommandDesc The update command description that contains the
957+ * new arguments.
958958 * @return UR_RESULT_SUCCESS or an error code on failure
959959 */
960960ur_result_t
961- updateKernelArguments (ur_device_handle_t Device ,
961+ updateKernelArguments (ur_exp_command_buffer_command_handle_t Command ,
962962 const ur_exp_command_buffer_update_kernel_launch_desc_t
963963 *UpdateCommandDesc) {
964964
965- ur_kernel_handle_t NewKernel = UpdateCommandDesc->hNewKernel ;
965+ ur_kernel_handle_t Kernel = Command->Kernel ;
966+ ur_device_handle_t Device = Command->CommandBuffer ->Device ;
966967
967968 // Update pointer arguments to the kernel
968969 uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs ;
@@ -974,7 +975,7 @@ updateKernelArguments(ur_device_handle_t Device,
974975 const void *ArgValue = PointerArgDesc.pNewPointerArg ;
975976
976977 try {
977- NewKernel ->setKernelArg (ArgIndex, sizeof (ArgValue), ArgValue);
978+ Kernel ->setKernelArg (ArgIndex, sizeof (ArgValue), ArgValue);
978979 } catch (ur_result_t Err) {
979980 return Err;
980981 }
@@ -991,10 +992,10 @@ updateKernelArguments(ur_device_handle_t Device,
991992
992993 try {
993994 if (ArgValue == nullptr ) {
994- NewKernel ->setKernelArg (ArgIndex, 0 , nullptr );
995+ Kernel ->setKernelArg (ArgIndex, 0 , nullptr );
995996 } else {
996997 void *HIPPtr = std::get<BufferMem>(ArgValue->Mem ).getVoid (Device);
997- NewKernel ->setKernelArg (ArgIndex, sizeof (void *), (void *)&HIPPtr);
998+ Kernel ->setKernelArg (ArgIndex, sizeof (void *), (void *)&HIPPtr);
998999 }
9991000 } catch (ur_result_t Err) {
10001001 return Err;
@@ -1012,7 +1013,7 @@ updateKernelArguments(ur_device_handle_t Device,
10121013 const void *ArgValue = ValueArgDesc.pNewValueArg ;
10131014
10141015 try {
1015- NewKernel ->setKernelArg (ArgIndex, ArgSize, ArgValue);
1016+ Kernel ->setKernelArg (ArgIndex, ArgSize, ArgValue);
10161017 } catch (ur_result_t Err) {
10171018 return Err;
10181019 }
@@ -1067,9 +1068,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
10671068 ur_exp_command_buffer_handle_t CommandBuffer = hCommand->CommandBuffer ;
10681069
10691070 UR_CHECK_ERROR (validateCommandDesc (hCommand, pUpdateKernelLaunch));
1070- UR_CHECK_ERROR (
1071- updateKernelArguments (CommandBuffer->Device , pUpdateKernelLaunch));
10721071 UR_CHECK_ERROR (updateCommand (hCommand, pUpdateKernelLaunch));
1072+ UR_CHECK_ERROR (updateKernelArguments (hCommand, pUpdateKernelLaunch));
10731073
10741074 // If no worksize is provided make sure we pass nullptr to setKernelParams
10751075 // so it can guess the local work size.
0 commit comments