Writing your own CUDA kernel (Part 1)

Posted on Tue 02 October 2018 in Tutorial by Laura Domine

At some point you might come across an operation that you wish existed in TensorFlow or PyTorch, but for which no GPU implementation is available, even though it is easily parallelizable on the GPU. So why not write your own CUDA kernel and integrate it into your main program? Let us start with the CUDA kernel itself, since it will be the same in both implementations.

Introduction

Some vocabulary first:

  • Kernel: a function that CUDA runs on the GPU.
  • Thread: CUDA runs many threads in parallel on the GPU; each thread executes the kernel.
  • Block: threads are grouped into blocks, a programming abstraction. Currently a thread block can contain up to 1024 threads.
  • Grid: the set of all thread blocks launched for a kernel.

Illustration of threads and blocks, from the CUDA documentation.
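
To make these terms concrete, here is a minimal, purely illustrative kernel (the name AddKernel and its arguments are placeholders, not part of this tutorial's code) that adds two vectors, one element per thread:

// Each thread computes one element of out = a + b.
// The global index combines the block index, the block size and the thread index.
__global__ void AddKernel(const float* a, const float* b, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // unique index of this thread
  if (i < n) out[i] = a[i] + b[i];
}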

When should we write a custom CUDA kernel?

  • Data size: make sure you will launch enough threads and blocks to amortize the CUDA overhead; otherwise you might not see much improvement over a CPU version.
  • Parallelizable: you should be able to pinpoint a single or double for loop whose iterations are independent of each other.

The only tricky part is figuring out how to balance the load: how many threads and blocks should be launched, and what portion of the data each of them will process.
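
In CUDA C++ this split is expressed at launch time with the <<<blocks, threads>>> syntax. Taking the illustrative AddKernel from above, a minimal sketch of a launch configuration (the numbers are placeholders, not recommendations) could be:

// Enough blocks of 1024 threads to cover all n elements of the vectors.
const int threads_per_block = 1024;
const int num_blocks = (n + threads_per_block - 1) / threads_per_block;
AddKernel<<<num_blocks, threads_per_block>>>(a, b, out, n);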

Use case description

We want to write a CUDA kernel to crop a single big image into several smaller crops. A sequential implementation would loop over all the voxels of each crop and copy the corresponding voxel from the original image. Our input is thus:

  • A single big image of shape ($N$, $N$, $N$, $C$), where $N$ is the image size and $C$ the number of channels.
  • A list of crop center coordinates in the original image, of shape ($M$, 3), where $M$ is the total number of crops.
  • The size $D$ of a crop (for simplicity we require that all crops have the same size).

The output should be the list of crops, of shape ($M$, $D$, $D$, $D$, $C$).
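
To get a sense of the sizes involved, here is a small sketch of the flattened buffers; the values of $N$, $C$ and $D$ below are assumptions chosen for illustration (only $M$ of around 100 comes from our use case):

// Illustrative buffer sizes: N, C and D are assumed values, M is around 100 in our use case.
const int N = 192, C = 1;   // image size and number of channels
const int M = 100, D = 32;  // number of crops and crop size
const size_t image_elems   = (size_t)N * N * N * C;      // input image, flattened
const size_t centers_elems = (size_t)M * 3;              // crop centers, flattened
const size_t crops_elems   = (size_t)M * D * D * D * C;  // output crops, flattened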

Naive approach

In our case, a first naive approach is to assign each thread a single voxel to copy from the input image array to the output crop array. We launch as many blocks as there are crops (i.e. $M$ blocks), and the threads inside a block go over all the voxels of a single crop (i.e. $D^3 \times C$ voxels per block). Remember that the number of threads per block is limited and fixed at launch, so a single thread might have to work on several voxels, not just one.

template <typename T>
__global__ void CropCudaKernel(
  const T* image_ptr,
  const int* crop_centers_ptr,
  const int image_size,
  const int channels,
  int crop_size,
  const int num_crops,
  T* crops_ptr
) {

The keyword __global__ signals that the function runs on the GPU and can be called from host code; it will be compiled by nvcc (the NVIDIA compiler driver, which wraps a host compiler such as gcc). In our case we need a pointer to the (flattened) big image, an array of (flattened) crop center coordinates, as well as the image size, the number of channels, the crop size, and the total number of crops. The result will be stored in the crops_ptr array.

  const int crop_id = blockIdx.x;
  const int center_x = crop_centers_ptr[crop_id*3];
  const int center_y = crop_centers_ptr[1 + 3*crop_id];
  const int center_z = crop_centers_ptr[2 + 3*crop_id];

Since the crop center coordinates array is flattened, we retrieve the current crop's center coordinates from the block index blockIdx.x: the center of crop $i$ is stored at indices $3i$, $3i + 1$ and $3i + 2$. We launch one block per crop, so the block index corresponds exactly to the crop index.

  for (int id = threadIdx.x; id < crop_size*crop_size*crop_size*channels; id += blockDim.x) {

We have to process all the voxels of the output crop. Each thread loops over them with a stride equal to the block dimension blockDim.x (i.e. the total number of threads working on this crop). For example, with 1024 threads per block, thread 0 handles voxels 0, 1024, 2048, and so on.

    // Coordinates inside the crop (0 <= coords < crop_size)
    int id_temp = id;
    const int c = id_temp % channels;
    id_temp /= channels;
    const int z = id_temp % crop_size;
    id_temp /= crop_size;
    const int y = id_temp % crop_size;
    const int x = id_temp / crop_size;

We reconstruct the coordinates of the current voxel inside the crop from the loop index. Note that this and all the following conversions between indices and coordinates depend on how the arrays were flattened.
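
As an aside, the modulo/division chain above simply inverts the flattening convention used for the crop array; a hypothetical helper (not part of the original kernel) makes the forward mapping explicit:

// Hypothetical helper, for illustration only: a voxel at (x, y, z, c) inside a
// crop is stored at this linear offset; the loop above inverts this mapping.
__device__ inline int FlatCropIndex(int x, int y, int z, int c,
                                    int crop_size, int channels) {
  return c + channels * (z + crop_size * (y + crop_size * x));
}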

    // Corresponding coordinates in original image
    int image_x = x + (center_x - crop_size / 2);
    int image_y = y + (center_y - crop_size / 2);
    int image_z = z + (center_z - crop_size / 2);
    int img_idx = c + channels * (image_z + image_size * (image_y + image_size * image_x ));

We then compute the corresponding coordinates in the original image: the crop is centered on (center_x, center_y, center_z), so its origin in the original image is at center - crop_size / 2 along each axis.

    if ((img_idx >= image_size * image_size * image_size * channels) || (img_idx < 0)) continue;

    int crop_idx = c + channels * (z + crop_size * (y + crop_size * (x + crop_size * crop_id)));
    crops_ptr[crop_idx] = image_ptr[img_idx];
  }
}

Finally, after checking that the flattened index falls inside the original image (it might not, for crops close to the border), we copy the voxel from the original image array to the output crops array.
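
For completeness, here is a hedged sketch of how this naive kernel could be launched from the host, assuming device arrays with the same names as the kernel parameters have already been allocated and filled:

// Hypothetical launch for the naive approach: one block per crop,
// with the maximum of 1024 threads per block.
const int threads_per_block = 1024;
CropCudaKernel<float><<<num_crops, threads_per_block>>>(
    image_ptr, crop_centers_ptr, image_size, channels,
    crop_size, num_crops, crops_ptr);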

More balanced approach

After profiling the previous CUDA kernel we found that it was not much faster than a NumPy version running on the CPU. The reason is that the number of crops (around 100 in our case) was not high enough to harness the power of the GPU, which relies on massive parallelism. A second, more refined approach is to set the number of blocks to $M \times D$: each block processes a 2D slice of a single crop, i.e. $D^2 \times C$ voxels.

template <typename T>
__global__ void CropCudaKernel2(
  const T* image_ptr,
  const int* crop_centers_ptr,
  const int image_size,
  const int channels,
  int crop_size,
  const int num_crops,
  T* crops_ptr
) {

The kernel declaration does not change.

  const int crop_id = blockIdx.x / crop_size;
  const int center_x = crop_centers_ptr[crop_id*3];
  const int center_y = crop_centers_ptr[1 + crop_id*3];
  const int center_z = crop_centers_ptr[2 + crop_id*3];
  int offset = (blockIdx.x % crop_size) * crop_size*crop_size*channels;

The main difference is bookkeeping: the crop index is no longer simply the block index, and we have to add an offset when computing the coordinates inside the crop. For example, with $D = 32$, block 70 works on crop 70 / 32 = 2 (integer division) and slice 70 % 32 = 6:

  for (int id = threadIdx.x; id < crop_size*crop_size*channels; id += blockDim.x) {
    // Coordinates inside the crop (0 <= coords < crop_size)
    int id_temp = offset + id;
    const int c = id_temp % channels;
    id_temp /= channels;
    const int z = id_temp % crop_size;
    id_temp /= crop_size;
    const int y = id_temp % crop_size;
    const int x = id_temp / crop_size;

    // Corresponding coordinates in original image
    int image_x = x + (center_x - crop_size / 2);
    int image_y = y + (center_y - crop_size / 2);
    int image_z = z + (center_z - crop_size / 2);
    int img_idx = c + channels * (image_z + image_size * (image_y + image_size * image_x ));

    if ((img_idx >= image_size * image_size * image_size * channels) || (img_idx < 0)) continue;

    int crop_idx = c + channels * (z + crop_size * (y + crop_size * (x + crop_size * crop_id)));
    crops_ptr[crop_idx] = image_ptr[img_idx];
  }
}
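
The corresponding launch then uses $M \times D$ blocks instead of $M$; again, a hedged sketch under the same assumptions as before:

// Hypothetical launch for the balanced approach: one block per 2D slice,
// i.e. num_crops * crop_size blocks in total.
const int threads_per_block = 1024;
CropCudaKernel2<float><<<num_crops * crop_size, threads_per_block>>>(
    image_ptr, crop_centers_ptr, image_size, channels,
    crop_size, num_crops, crops_ptr);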

This was a short introduction to CUDA kernels, with a custom one that crops a big image into smaller pieces. Now you probably want to compile it and integrate it into your main TensorFlow or PyTorch program, so continue with the next part of the tutorial.

References