Matrix Transpose in CUDA
I recently found a LeetCode-like site called LeetGPU. It lets you solve GPU problems using CUDA, PyTorch, and JAX. I had studied CUDA before but forgot most of it, and I had never seriously learned optimization. So I decided to solve problems one by one and document what I learn along the way.
Problem Description
For the first problem, I implemented matrix transpose. The task is to fill in matrix_transpose_kernel below:
#include <cuda_runtime.h>
__global__ void matrix_transpose_kernel(const float* input, float* output, int rows, int cols) {
}
// input, output are device pointers (i.e. pointers to memory on the GPU)
extern "C" void solve(const float* input, float* output, int rows, int cols) {
    dim3 threadsPerBlock(16, 16);
    dim3 blocksPerGrid((cols + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (rows + threadsPerBlock.y - 1) / threadsPerBlock.y);
    matrix_transpose_kernel<<<blocksPerGrid, threadsPerBlock>>>(input, output, rows, cols);
    cudaDeviceSynchronize();
}
The input matrix has shape (rows, cols) stored in row-major order, and the output should be (cols, rows).
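Before writing any kernel, I find it helpful to pin down the layout with a CPU reference. The sketch below is my own helper (the name transpose_reference is hypothetical and not part of the LeetGPU template); it assumes row-major storage exactly as described above and can be used to check GPU output on small inputs.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of the problem template: CPU reference transpose.
// `input` is rows x cols in row-major order; the result is cols x rows, row-major.
std::vector<float> transpose_reference(const std::vector<float>& input, int rows, int cols) {
    std::vector<float> output(static_cast<std::size_t>(rows) * cols);
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            // Element (r, c) of the input becomes element (c, r) of the output.
            output[static_cast<std::size_t>(c) * rows + r] =
                input[static_cast<std::size_t>(r) * cols + c];
        }
    }
    return output;
}
```

For a 2 x 3 input {1, 2, 3, 4, 5, 6}, this returns the 3 x 2 matrix {1, 4, 2, 5, 3, 6}.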
Naive Solution
The simplest approach is: each thread reads one element from the input and writes it to the transposed position in the output. Because the grid is defined like this:
dim3 blocksPerGrid((cols + threadsPerBlock.x - 1) / threadsPerBlock.x,
                   (rows + threadsPerBlock.y - 1) / threadsPerBlock.y);
we naturally interpret:
- x as the column index
- y as the row index
A straightforward kernel is:
__global__ void matrix_transpose_kernel(const float* input, float* output, int rows, int cols) {
    int x = blockIdx.x * blockDim.x + threadIdx.x; // col index in input
    int y = blockIdx.y * blockDim.y + threadIdx.y; // row index in input
    if (x < cols && y < rows) {
        // input(y, x) -> output(x, y)
        output[x * rows + y] = input[y * cols + x];
    }
}
The boundary check is required because blocksPerGrid is computed using ceiling division,
so some threads may map to out-of-range indices.
This kernel works, but it is not fast.
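To make the ceiling-division point concrete, here is a small host-side sketch. The helper names ceil_div and idle_threads are my own (hypothetical); they only restate the arithmetic used for blocksPerGrid.

```cpp
// Hypothetical helper: ceiling division, as used when computing blocksPerGrid.
constexpr int ceil_div(int n, int block) {
    return (n + block - 1) / block;
}

// Number of launched threads along one axis whose index falls outside [0, n).
// These are the threads the boundary check has to mask off.
constexpr int idle_threads(int n, int block) {
    return ceil_div(n, block) * block - n;
}
```

With cols = 100 and 16 threads per block in x, ceil_div gives 7 blocks, so 112 threads are launched along x and the last 12 of them must skip the memory access.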
Why it’s slow: one side is inevitably “non-coalesced”
Matrix transpose is almost always memory-bandwidth bound, so performance depends heavily on how efficiently warps access global memory.
Coalescing (for global memory) means the GPU can merge the memory requests of consecutive threads in a warp into a small number of memory transactions. This happens when consecutive threads access consecutive (or near-consecutive) addresses, improving bandwidth utilization. [1]
In the naive transpose:
- The load input[y * cols + x] is typically reasonably coalesced: threadIdx.x varies fastest within a warp, so consecutive threads read adjacent addresses in the same row.
- The store output[x * rows + y] is the real problem: when threads in a warp vary x, the output address jumps by rows elements each time (a large stride). That pattern usually forces the hardware to split the warp store into many transactions, wasting bandwidth.
That’s why the naive kernel tends to land around the middle of performance rankings on benchmarking/leaderboard-style sites.
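The coalescing argument above can be checked with a rough host-side model. This is only a sketch under my own assumptions, not a hardware simulator: a 16x16 block (so one 32-thread warp covers two threadIdx.y rows), 4-byte floats, a segment-aligned base pointer, and 32-byte memory segments by default. The names Access and segments_touched are hypothetical.

```cpp
#include <cstddef>
#include <set>

// Which global access of the naive kernel to model.
enum class Access { NaiveLoad, NaiveStore };

// Simplified model: count the distinct `segment_bytes`-sized memory segments
// touched by the first warp of block (0, 0). Fewer segments for the same
// amount of useful data means better coalescing.
int segments_touched(Access kind, int rows, int cols, int segment_bytes = 32) {
    std::set<std::size_t> segments;
    for (int lane = 0; lane < 32; ++lane) {
        int x = lane % 16;  // threadIdx.x varies fastest within a warp
        int y = lane / 16;  // threadIdx.y
        if (x >= cols || y >= rows) continue;  // same guard as the kernel
        std::size_t index = (kind == Access::NaiveLoad)
            ? static_cast<std::size_t>(y) * cols + x   // input[y * cols + x]
            : static_cast<std::size_t>(x) * rows + y;  // output[x * rows + y]
        segments.insert(index * sizeof(float) / static_cast<std::size_t>(segment_bytes));
    }
    return static_cast<int>(segments.size());
}
```

For a 1024 x 1024 matrix, this model counts 4 segments for the load and 16 for the store, even though both move the same 128 bytes of useful data per warp.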
Optimized Solution: Shared-Memory Tiling
To optimize transpose, we want both the global load and the global store to be coalesced. The standard trick is:
- Load a tile from global memory into shared memory using coalesced loads.
- Read the tile from shared memory in transposed order.
- Store it back to global memory using coalesced stores.
This is the classic CUDA transpose optimization described by NVIDIA, and with padding it can reach ~95% of a pure copy kernel’s throughput. [2] Here is the optimized kernel:
#include <cuda_runtime.h>
#define TILE 16
__global__ void matrix_transpose_tiled(const float* __restrict__ input,
                                       float* __restrict__ output,
                                       int rows, int cols)
{
    __shared__ float tile[TILE][TILE + 1]; // +1: avoid shared-memory bank conflicts
    int x = blockIdx.x * TILE + threadIdx.x; // input col
    int y = blockIdx.y * TILE + threadIdx.y; // input row
    // 1) global -> shared (coalesced load)
    if (x < cols && y < rows) {
        tile[threadIdx.y][threadIdx.x] = input[y * cols + x];
    }
    __syncthreads();
    // 2) shared -> global (coalesced store)
    int tx = blockIdx.y * TILE + threadIdx.x; // output col (= input row)
    int ty = blockIdx.x * TILE + threadIdx.y; // output row (= input col)
    if (tx < rows && ty < cols) {
        output[ty * rows + tx] = tile[threadIdx.x][threadIdx.y];
    }
}
extern "C" void solve(const float* input, float* output, int rows, int cols) {
    dim3 threadsPerBlock(TILE, TILE);
    dim3 blocksPerGrid((cols + TILE - 1) / TILE,
                       (rows + TILE - 1) / TILE);
    matrix_transpose_tiled<<<blocksPerGrid, threadsPerBlock>>>(input, output, rows, cols);
    cudaDeviceSynchronize();
}
What changed conceptually?
- We still read input[y * cols + x] in row-major order (good for coalescing).
- But instead of writing the transposed result directly to global memory (bad stride), we first write into tile[threadIdx.y][threadIdx.x] in shared memory.
- Then we swap the indices when reading from shared memory:
output[ty * rows + tx] = tile[threadIdx.x][threadIdx.y];
Because tx advances with threadIdx.x, consecutive threads in a warp now write consecutive addresses within an output row, so the global store coalesces about as well as the load does.
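The index swap is easy to get wrong, especially at the matrix edges, so I also wrote a CPU emulation of the tiled kernel's index math. This is a sketch of my own (emulate_tiled_transpose and kTile are hypothetical names): each (bx, by) iteration stands in for one thread block, and the two loop nests stand in for the code before and after __syncthreads(). It says nothing about performance; it only checks that the guards and indices produce a correct transpose.

```cpp
#include <cstddef>
#include <vector>

constexpr int kTile = 16;  // mirrors TILE in the kernel

// Hypothetical helper: CPU emulation of the tiled kernel's two phases.
std::vector<float> emulate_tiled_transpose(const std::vector<float>& input, int rows, int cols) {
    std::vector<float> output(static_cast<std::size_t>(rows) * cols, 0.0f);
    const int grid_x = (cols + kTile - 1) / kTile;
    const int grid_y = (rows + kTile - 1) / kTile;
    for (int by = 0; by < grid_y; ++by) {
        for (int bx = 0; bx < grid_x; ++bx) {
            float tile[kTile][kTile + 1] = {};  // per-block "shared memory"
            // Phase 1: global -> tile, indexed like tile[threadIdx.y][threadIdx.x].
            for (int ly = 0; ly < kTile; ++ly) {
                for (int lx = 0; lx < kTile; ++lx) {
                    const int x = bx * kTile + lx;  // input col
                    const int y = by * kTile + ly;  // input row
                    if (x < cols && y < rows) {
                        tile[ly][lx] = input[static_cast<std::size_t>(y) * cols + x];
                    }
                }
            }
            // Phase 2: tile -> global, indexed like tile[threadIdx.x][threadIdx.y].
            for (int ly = 0; ly < kTile; ++ly) {
                for (int lx = 0; lx < kTile; ++lx) {
                    const int tx = by * kTile + lx;  // output col (= input row)
                    const int ty = bx * kTile + ly;  // output row (= input col)
                    if (tx < rows && ty < cols) {
                        output[static_cast<std::size_t>(ty) * rows + tx] = tile[lx][ly];
                    }
                }
            }
        }
    }
    return output;
}
```

Running it on a shape that is not a multiple of 16 in either dimension (for example 5 x 37) exercises both boundary checks.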
Advanced Points
1) Why TILE + 1? (Shared memory bank conflicts)
Shared memory is fast on-chip memory, but it has its own performance hazard: bank conflicts.
On modern NVIDIA GPUs, shared memory is organized into (typically) 32 banks, and successive 32-bit words are assigned to successive banks. Bank conflicts can occur when multiple threads in the same warp access different addresses that map to the same bank, forcing serialization. [3]
A common simplified model is:
- Treat a float as a 32-bit word.
- Bank mapping behaves like “word index mod 32.” (Exact details vary by architecture, but this mental model matches the common optimization pattern.) [3]
Now look at the shared-memory read during the transpose step:
tile[threadIdx.x][threadIdx.y]
That is effectively a column-wise access pattern into a row-major 2D array. If the leading dimension is exactly 32 (e.g., tile[32][32]), every element of a column maps to the same bank, so a warp reading down a column is heavily serialized. An unpadded tile[16][16] has a milder version of the same problem: every second row starts in the same bank. This is exactly why the standard optimization pads the shared tile by one column: tile[TILE][TILE+1]. The extra column shifts the start of each row by one word, breaking the worst-case bank-alignment pattern and dramatically reducing conflicts. [2]
That’s why the +1 exists: it is not for global coalescing; it is specifically to prevent shared-memory bank conflicts during the transposed shared access.
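The "word index mod 32" model is simple enough to evaluate on the host. The sketch below is my own (max_bank_conflict is a hypothetical name) and assumes 32 banks of 4-byte words plus a 16x16 block, so the first warp holds threadIdx.y in {0, 1} and threadIdx.x in 0..15. It reports the largest number of lanes that land in the same bank for the transposed read.

```cpp
#include <algorithm>
#include <array>

// Simplified model: bank = word index mod 32.
// `row_words` is the length of one shared-memory tile row in 4-byte words
// (16 for tile[16][16], 17 for tile[16][17]).
int max_bank_conflict(int row_words) {
    std::array<int, 32> hits{};  // lanes per bank, zero-initialized
    for (int lane = 0; lane < 32; ++lane) {
        const int tx = lane % 16;  // threadIdx.x
        const int ty = lane / 16;  // threadIdx.y
        const int word = tx * row_words + ty;  // address of tile[tx][ty]
        ++hits[word % 32];
    }
    return *std::max_element(hits.begin(), hits.end());
}
```

Under these assumptions, the unpadded 16-word row gives an 8-way conflict, while the padded 17-word row leaves only a single 2-way conflict.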
2) Why __syncthreads()?
All threads in the block collaboratively fill tile[][]. If some threads start reading from tile (transpose phase) before other threads finish writing to it (load phase), they may read uninitialized or stale values, which is a race condition.
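CUDA provides a block-wide barrier, __syncthreads(), which ensures all threads in a block reach the same point before any proceed, and that shared-memory writes made before the barrier are visible to every thread in the block after it. This is the standard way to make shared-memory producer/consumer patterns correct. [4]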
Summary
In this post, I implemented matrix transpose in CUDA and then optimized it.
- The naive version is correct but slow because the global store is typically strided and fails to coalesce well.
- The optimized version uses shared-memory tiling to make both the global load and global store more coalesced.
- Padding the shared tile as TILE+1 avoids the common bank conflict pattern in the transposed shared-memory read.
- __syncthreads() is required to prevent races between the tile write phase and the tile read phase.
With these changes, my solution’s performance jumped significantly (into roughly the top few percentiles on the site), consistent with NVIDIA’s observation that the padded tiled transpose can approach copy-kernel throughput. [2]
References
[1] Unlock GPU Performance: Global Memory Access in CUDA - NVIDIA Developer Blog
[2] An Efficient Matrix Transpose in CUDA C/C++ - NVIDIA Developer Blog
[3] CUDA C++ Best Practices Guide - NVIDIA Corporation
[4] Using Shared Memory in CUDA C/C++ - NVIDIA Developer Blog