|
| 1 | +#include <cstdio> |
| 2 | +#include <algorithm> |
| 3 | +#include <cstring> |
| 4 | + |
| 5 | +#include <ATen/ATen.h> |
| 6 | +#include <ATen/cuda/CUDAContext.h> |
| 7 | + |
| 8 | +// #include <THC/THC.h> |
| 9 | +#include <THC/THCAtomics.cuh> |
| 10 | +// #include <THC/THCDeviceUtils.cuh> |
| 11 | + |
// Grid-stride loop: thread `i` starts at its global index and strides by the
// total number of launched threads, so the kernel covers all n elements for
// any grid/block configuration (including a single-block debug launch).
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
      i < (n); \
      i += blockDim.x * gridDim.x)
| 16 | + |
// Threads per block used by all kernel launches in this file.
const int CUDA_NUM_THREADS = 1024;

// Number of thread blocks needed to cover N work items at
// CUDA_NUM_THREADS threads per block (ceiling division).
inline int GET_BLOCKS(const int N)
{
    const int block_size = CUDA_NUM_THREADS;
    return (N + block_size - 1) / block_size;
}
| 22 | + |
// Accumulates a per-batch 3D histogram of pairwise displacement vectors
// X[b,i,:3] - Y[b,j,:3].
//
// Launch: one logical thread per (b, i, j) pair, n = batch * num_X * num_Y,
// iterated via a grid-stride loop so any grid size is valid.
//
// X:    [batch, num_X, dim] -- coordinates in channels 0..2, validity flag in
//       channel 3 (assumes dim >= 4 -- confirmed by the channel-3 reads below).
// Y:    [batch, num_Y, dim] -- same layout as X.
// bins: [batch, len_x, len_y, len_z] accumulator; caller must pre-zero it
//       (this kernel only adds).
//
// A pair contributes one count when both validity flags are positive and the
// displacement lies inside [min, max) on every axis.
template <typename scalar_t>
__global__ void hist_cuda_kernel(const int n,
                                 const scalar_t* X,
                                 const scalar_t* Y,
                                 const int batch, const int dim,
                                 const int num_X, const int num_Y,
                                 const float min_x, const float min_y, const float min_z,
                                 const float max_x, const float max_y, const float max_z,
                                 const int len_x, const int len_y, const int len_z,
                                 scalar_t* bins
                                 )
{
    CUDA_KERNEL_LOOP(index, n)
    {
        // Decompose the flat index into (batch, X-point, Y-point). For
        // non-negative operands (a/b)/c == a/(b*c), so this matches the
        // b*num_X*num_Y + i*num_Y + j flattening.
        const int b = index / num_X / num_Y % batch;
        const int i = index / num_Y % num_X;
        const int j = index % num_Y;

        // Channel 3 is a validity flag; skip pairs with an invalid endpoint.
        // Compare against scalar_t(0), not the double literal 0.0, to avoid
        // promoting every comparison to double when scalar_t is float.
        const scalar_t flag_x = X[b*num_X*dim + i*dim + 3];
        const scalar_t flag_y = Y[b*num_Y*dim + j*dim + 3];
        if (flag_x > scalar_t(0) && flag_y > scalar_t(0))
        {
            const scalar_t val_x = X[b*num_X*dim + i*dim + 0] - Y[b*num_Y*dim + j*dim + 0];
            const scalar_t val_y = X[b*num_X*dim + i*dim + 1] - Y[b*num_Y*dim + j*dim + 1];
            const scalar_t val_z = X[b*num_X*dim + i*dim + 2] - Y[b*num_Y*dim + j*dim + 2];
            if (val_x >= min_x && val_x < max_x && val_y >= min_y && val_y < max_y && val_z >= min_z && val_z < max_z)
            {
                // [): left edge included; right edge excluded.
                // min(p, len-1) guards the upper edge: for val just below max,
                // float rounding in the normalization can produce exactly len,
                // which would index one element past the histogram.
                int p_x = min(__float2int_rd( (val_x-min_x) / (max_x-min_x) * __int2float_rd(len_x)), len_x - 1);
                int p_y = min(__float2int_rd( (val_y-min_y) / (max_y-min_y) * __int2float_rd(len_y)), len_y - 1);
                int p_z = min(__float2int_rd( (val_z-min_z) / (max_z-min_z) * __int2float_rd(len_z)), len_z - 1);

                // Row-major [batch][len_x][len_y][len_z] flattening.
                const int bin_id = b*len_x*len_y*len_z + p_x*len_y*len_z + p_y*len_z + p_z;
                // Pass scalar_t(1), not the int literal 1, so the correct
                // atomicAdd overload is selected for every instantiation
                // (float/double/half via THCAtomics).
                atomicAdd(bins + bin_id, scalar_t(1));
            }
        }
    }
}
| 65 | + |
// Host-side launcher for hist_cuda_kernel: bins the pairwise displacements of
// X against Y into a per-batch 3D histogram on the given stream.
// See hist_cuda_kernel for the expected tensor layouts; `bins` must be
// pre-zeroed by the caller -- the kernel only accumulates.
template <typename scalar_t>
void hist_cuda_core(cudaStream_t stream,
                    const scalar_t* X, const scalar_t* Y,
                    const int batch, const int dim,
                    const int num_X, const int num_Y,
                    const float min_x, const float min_y, const float min_z,
                    const float max_x, const float max_y, const float max_z,
                    const int len_x, const int len_y, const int len_z,
                    scalar_t* bins
                    )
{
    // One logical thread per (batch, X-point, Y-point) pair.
    const int num_kernels = batch * num_X * num_Y;

    // GET_BLOCKS(0) == 0 and a <<<0, ...>>> launch is an invalid-configuration
    // error; with no pairs there is nothing to histogram, so return early.
    if (num_kernels <= 0)
    {
        return;
    }

    hist_cuda_kernel<scalar_t>
        <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
        num_kernels,
        X, Y,
        batch, dim,
        num_X, num_Y,
        min_x, min_y, min_z,
        max_x, max_y, max_z,
        len_x, len_y, len_z,
        bins
        );

    // Kernel launches are asynchronous and return no status directly;
    // cudaGetLastError() surfaces launch-configuration failures immediately
    // (execution errors only appear at the next synchronizing call).
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in hist_cuda_core: %s\n", cudaGetErrorString(err));
    }
}
| 100 | + |
0 commit comments