Skip to content

Commit 8b31cc7

Browse files
author
Deepak Raj H R
authored
[SYCLomatic] Add support for two-way type-cast from dpct::kernel_library and dpct::kernel_function to uint64_t conversion (#2606)
1 parent 821800f commit 8b31cc7

2 files changed

Lines changed: 27 additions & 11 deletions

File tree

clang/runtime/dpct-rt/include/dpct/kernel.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,10 @@ class kernel_library {
318318
public:
319319
constexpr kernel_library() : ptr{nullptr} {}
320320
constexpr kernel_library(void *ptr) : ptr{ptr} {}
321+
kernel_library(uint64_t addr) : ptr(reinterpret_cast<void *>(addr)) {}
321322

322323
operator void *() const { return ptr; }
324+
explicit operator uint64_t() const { return reinterpret_cast<uint64_t>(ptr); }
323325

324326
private:
325327
void *ptr;
@@ -393,15 +395,17 @@ class kernel_function {
393395
public:
394396
constexpr kernel_function() : ptr{nullptr} {}
395397
constexpr kernel_function(dpct::kernel_functor ptr) : ptr{ptr} {}
398+
kernel_function(uint64_t addr)
399+
: ptr(reinterpret_cast<dpct::kernel_functor>(addr)) {}
396400

397401
operator void *() const { return ((void *)ptr); }
398402

399403
void operator()(sycl::queue &q, const sycl::nd_range<3> &range,
400-
unsigned int a, void **args, void **extra) {
404+
unsigned int a, void **args, void **extra) const {
401405
ptr(q, range, a, args, extra);
402406
}
403407

404-
explicit operator uint64_t() const { return (uint64_t)this; }
408+
explicit operator uint64_t() const { return reinterpret_cast<uint64_t>(ptr); }
405409

406410
private:
407411
dpct::kernel_functor ptr;
@@ -411,7 +415,7 @@ class kernel_function {
411415
/// \param [in] library Handle to the kernel library.
412416
/// \param [in] name Name of the kernel function.
413417
static inline dpct::kernel_function
414-
get_kernel_function(kernel_library &library, const std::string &name) {
418+
get_kernel_function(const kernel_library &library, const std::string &name) {
415419
#ifdef _WIN32
416420
dpct::kernel_functor fn = reinterpret_cast<dpct::kernel_functor>(
417421
GetProcAddress(static_cast<HMODULE>(static_cast<void *>(library)),
@@ -434,7 +438,7 @@ get_kernel_function(kernel_library &library, const std::string &name) {
434438
/// function.
435439
/// \param [in] kernelParams Array of pointers to kernel arguments.
436440
/// \param [in] extra Extra arguments.
437-
static inline void invoke_kernel_function(dpct::kernel_function &function,
441+
static inline void invoke_kernel_function(const dpct::kernel_function &function,
438442
sycl::queue &queue,
439443
sycl::range<3> groupRange,
440444
sycl::range<3> localRange,

clang/test/dpct/kernel-function-typecast.cu

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,24 @@
77

88
typedef uint64_t u64;
99

10-
// CHECK: u64 foo(dpct::kernel_function cuFunc, dpct::kernel_library cuMod) {
11-
u64 foo(CUfunction cuFunc, CUmodule cuMod) {
12-
// CHECK: cuFunc = dpct::get_kernel_function(cuMod, "kfoo");
13-
cuModuleGetFunction(&cuFunc, cuMod, "kfoo");
14-
u64 function = (u64)cuFunc;
10+
// CHECK: void exec_kernel(dpct::kernel_function cuFunc, dpct::kernel_library cuMod, dpct::queue_ptr stream) {
11+
void exec_kernel(CUfunction cuFunc, CUmodule cuMod, CUstream stream) {
12+
u64 mod;
13+
u64 function;
1514

16-
return function;
17-
}
15+
// verify the conversion from dpct::kernel_library to uint64_t
16+
mod = (u64)cuMod;
17+
18+
// verify the conversion from uint64_t to dpct::kernel_library
19+
// CHECK: cuFunc = dpct::get_kernel_function((dpct::kernel_library)mod, "kfoo");
20+
cuModuleGetFunction(&cuFunc, (CUmodule)mod, "kfoo");
21+
22+
// verify the conversion from dpct::kernel_function to uint64_t
23+
function = (u64)cuFunc;
1824

25+
void *config[] = {0};
26+
27+
// verify the conversion from uint64_t to dpct::kernel_function
28+
// CHECK: dpct::invoke_kernel_function((dpct::kernel_function)function, *stream, sycl::range<3>(100, 100, 100), sycl::range<3>(100, 100, 100), 1024, NULL, config);
29+
cuLaunchKernel((CUfunction)function, 100, 100, 100, 100, 100, 100, 1024, stream, NULL, config);
30+
}

0 commit comments

Comments
 (0)