Hand‑Writing a Triton Softmax Kernel: Program Instances, Block Size, Masking & Pointer Arithmetic
This article walks through implementing a row-wise softmax kernel in Triton, covering program-instance mapping, block-size selection, mask handling, pointer arithmetic and resource-usage analysis, and closes with an RTX 5090 benchmark that reveals performance cliffs compared with PyTorch.
Softmax: mathematical definition and memory considerations
Softmax is applied row-wise: each row of a matrix is transformed independently into a probability distribution. A naïve PyTorch implementation splits the operation into separate tensor ops (max, subtraction, exponentiation, sum, division), each of which reads its input from global memory and writes its result back to it.
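For row i of an M × N input x, the definition and the numerically stable form computed by the kernel later in this article are given below. Subtracting the row maximum m_i does not change the result, because the factor e^{-m_i} cancels between numerator and denominator. It does keep every exponent at most 0, which prevents overflow in exp.

$$
\mathrm{softmax}(x)_{ij} = \frac{e^{x_{ij}}}{\sum_{k=1}^{N} e^{x_{ik}}} = \frac{e^{x_{ij} - m_i}}{\sum_{k=1}^{N} e^{x_{ik} - m_i}}, \qquad m_i = \max_{1 \le k \le N} x_{ik}
$$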
Triton fuses these steps into a single kernel that loads a row, performs all arithmetic while the data resides in registers/shared memory, and writes the final result once.
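A rough count of global-memory traffic makes the benefit of fusion concrete. The sketch assumes the five-op decomposition listed above for an M × N input. It counts the elements read and written by each op and compares the total with a fused kernel that reads the input once and writes the output once. The function names are illustrative.

```python
def unfused_traffic(m: int, n: int) -> int:
    """Elements read + written by five separate ops on an m x n input."""
    ops = [
        (m * n, m),          # row max: read x, write one max per row
        (m * n + m, m * n),  # x - max: read x and the maxima, write z
        (m * n, m * n),      # exp(z): read z, write the numerator
        (m * n, m),          # row sum: read the numerator, write one sum per row
        (m * n + m, m * n),  # divide: read numerator and sums, write the output
    ]
    return sum(reads + writes for reads, writes in ops)


def fused_traffic(m: int, n: int) -> int:
    """A fused kernel reads each input element once and writes each output once."""
    return 2 * m * n


m, n = 4096, 1024
ratio = unfused_traffic(m, n) / fused_traffic(m, n)
print(f"unfused / fused traffic: {ratio:.3f}")
```

Under these assumptions the separate-op version moves 8MN + 4M elements versus 2MN, roughly four times as much data. This matters because softmax does little arithmetic per element and is largely limited by memory bandwidth.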
Simple Triton model
Consider a vector of length 3072 in which every element is decremented by 1. On the CPU, a single loop processes the whole vector. In Triton, each program instance processes BLOCK_SIZE = 1024 elements, so three instances (program IDs 0, 1 and 2) cover the vector. Each instance calls tl.program_id(0) to locate its slice: instance pid handles elements pid × 1024 through pid × 1024 + 1023.
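The mapping from program IDs to slices can be simulated on the CPU. In this sketch a plain Python loop over pid stands in for the parallel launch of a 1D grid. The helper names run_block and launch, and the use of a Python list, are illustrative assumptions rather than Triton APIs. Each simulated instance computes its offsets the way a Triton kernel would. A mask lets a vector whose length is not a multiple of BLOCK_SIZE be handled safely.

```python
BLOCK_SIZE = 1024


def run_block(pid: int, x: list, n: int) -> None:
    """Work done by one simulated program instance: decrement its slice."""
    # Equivalent of pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offsets = [pid * BLOCK_SIZE + i for i in range(BLOCK_SIZE)]
    for off in offsets:
        if off < n:  # mask: lanes past the end of the vector do nothing
            x[off] = x[off] - 1


def launch(x: list) -> int:
    """Run every instance of a 1D grid sequentially; return the grid size."""
    n = len(x)
    grid = (n + BLOCK_SIZE - 1) // BLOCK_SIZE  # same as triton.cdiv(n, BLOCK_SIZE)
    for pid in range(grid):  # on a GPU these instances run in parallel
        run_block(pid, x, n)
    return grid


x = [float(i) for i in range(3072)]
print(launch(x), x[0], x[-1])

y = [0.0] * 3000  # length not a multiple of 1024: instance 2 is partly masked
print(launch(y), y[-1])
```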
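For reference, the sequential CPU loop is:

```python
X = [float(i) for i in range(3072)]  # example input
for i in range(3072):
    X[i] = X[i] - 1
```

Row-wise softmax kernel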
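```python
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
    output_ptr, input_ptr,
    input_row_stride, output_row_stride,
    n_rows, n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_stages: tl.constexpr,
):
    # This program's ID and the total number of programs in the grid
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    # Persistent loop: program p handles rows p, p + row_step, p + 2 * row_step, ...
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # Strides are in elements, so this points at the first element of the row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # BLOCK_SIZE may exceed n_cols; masked lanes are filled with -inf
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        # Subtract the row max for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write only the valid columns of the output row
        tl.store(output_ptr + row_idx * output_row_stride + col_offsets,
                 softmax_output, mask=mask)
```

The kernel uses tl.program_id(0) to obtain the current instance ID and tl.num_programs(0) for the total number of instances. The loop starts at the program's own ID and advances by the number of programs. A grid with fewer programs than rows therefore still covers every row, with each program processing several.

Pointer arithmetic is explicit. input_ptr + row_idx * input_row_stride points at the start of a row, and adding col_offsets produces a block of BLOCK_SIZE pointers, one per column.

The mask prevents out-of-bounds loads and stores. Padded lanes receive -inf, which cannot raise the row maximum and contributes exp(-inf) = 0 to the denominator.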
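To check the kernel's logic without a GPU, this pure-Python sketch mirrors it step by step:

- a persistent loop in which program p handles rows p, p + num_programs, and so on;
- a mask over a power-of-two block, with -inf padding for masked lanes;
- the max, subtract, exp, sum, divide sequence.

It is a CPU reference written for illustration, not Triton code, and the names softmax_program and softmax_reference are hypothetical. A flat list with an explicit row stride stands in for GPU pointer arithmetic.

```python
import math


def softmax_program(pid, num_programs, inp, out, row_stride, n_rows, n_cols, block_size):
    """One simulated program instance of the persistent row-wise softmax."""
    for row_idx in range(pid, n_rows, num_programs):  # grid-stride loop over rows
        row_start = row_idx * row_stride  # "pointer" to the start of the row
        # Masked lanes load -inf, like other=-float("inf") in tl.load
        row = [inp[row_start + col] if col < n_cols else -math.inf
               for col in range(block_size)]
        row_max = max(row)
        numerator = [math.exp(v - row_max) for v in row]  # exp(-inf) == 0.0
        denominator = sum(numerator)
        for col in range(block_size):
            if col < n_cols:  # masked store
                out[row_start + col] = numerator[col] / denominator


def softmax_reference(inp, n_rows, n_cols, num_programs=2):
    block_size = 1 << (n_cols - 1).bit_length()  # next power of two >= n_cols
    out = [0.0] * (n_rows * n_cols)
    for pid in range(num_programs):  # sequential stand-in for parallel programs
        softmax_program(pid, num_programs, inp, out, n_cols, n_rows, n_cols, block_size)
    return out


x = [1.0, 2.0, 3.0, 1000.0, 1000.0, 1000.0, 0.0, 0.0, 0.0]  # 3 rows x 3 cols
y = softmax_reference(x, n_rows=3, n_cols=3)
print([round(v, 4) for v in y])
```

With three columns the block size is 4, so each row carries one padded -inf lane. Without the max subtraction, the second row would raise OverflowError in math.exp, since e^1000 exceeds the float range.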
Python wrapper and launch configuration
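```python
import torch
import triton
from triton.runtime import driver

properties = driver.active.utils.get_device_properties(torch.cuda.current_device())
NUM_SM, NUM_REGS = properties["multiprocessor_count"], properties["max_num_regs"]
SIZE_SMEM, WARP_SIZE = properties["max_shared_mem"], properties["warpSize"]

def softmax(x):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)  # excess lanes are masked
    num_warps = 8
    num_stages = 4 if SIZE_SMEM > 200000 else 2  # deeper pipelining on large-SMEM GPUs
    y = torch.empty_like(x)
    # Compile without launching, to read register/shared-memory usage
    kernel = softmax_kernel.warmup(
        y, x, x.stride(0), y.stride(0),
        n_rows, n_cols,
        BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages,
        num_warps=num_warps, grid=(1,))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    # Persistent launch: sized by occupancy (next section), capped at n_rows
    occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = min(NUM_SM * occupancy, n_rows)
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0),
                                 n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
```

BLOCK_SIZE is the smallest power of two greater than or equal to n_cols, because tl.arange requires a power-of-two length. For a row with 3000 columns, the kernel uses 4096 and masks the remaining 1096 lanes.

num_warps = 8 assigns eight warps (256 threads) to each program instance.

num_stages sets the software-pipelining depth of the tl.range loop, letting loads for upcoming rows overlap with computation on the current one. More stages consume more on-chip memory and are not universally beneficial.

warmup compiles the kernel without launching it, so its register count (n_regs) and shared-memory footprint (metadata.shared) are known before the grid is sized. The last lines use that footprint to choose the number of persistent programs and then launch the kernel.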
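The cost of rounding BLOCK_SIZE up can be quantified in a few lines. The helper below reimplements power-of-two rounding with an integer bit trick. It is a stand-in for triton.next_power_of_2, not its source. For a given n_cols it reports:

- the block size;
- the fraction of masked lanes;
- roughly how many elements each thread holds when a row is spread over num_warps × 32 threads.

The per-thread figure assumes an even split of the block across threads, which simplifies Triton's actual layouts.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def block_stats(n_cols: int, num_warps: int = 8, warp_size: int = 32):
    block = next_power_of_2(n_cols)
    masked_fraction = (block - n_cols) / block
    per_thread = block // (num_warps * warp_size)  # assumes an even split
    return block, masked_fraction, per_thread


for n_cols in (3000, 8192, 8193):
    block, masked, per_thread = block_stats(n_cols)
    print(f"n_cols={n_cols}: BLOCK_SIZE={block}, masked={masked:.1%}, ~{per_thread} elems/thread")
```

Going from 8192 to 8193 columns doubles BLOCK_SIZE to 16384, leaves almost half of the lanes masked, and doubles each thread's share of the row.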
Occupancy and resource budgeting
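```python
# NUM_SM, NUM_REGS, SIZE_SMEM, WARP_SIZE are per-device limits;
# n_regs and size_smem come from the compiled kernel
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
# One wave of programs across all SMs, but never more programs than rows
num_programs = NUM_SM * occupancy
num_programs = min(num_programs, n_rows)
```

Each SM has a fixed budget of registers (NUM_REGS) and shared memory (SIZE_SMEM). One program instance needs n_regs × WARP_SIZE × num_warps registers (registers per thread times threads per program) and size_smem bytes of shared memory. Each budget therefore bounds how many instances can be resident on an SM at once, and the tighter bound is the occupancy. Launching NUM_SM × occupancy programs fills the GPU in a single wave. In this persistent-style kernel, each program then loops over multiple rows instead of one program being launched per row.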
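The budgeting arithmetic can be tried with illustrative numbers. The function below restates the occupancy lines as a standalone helper. Its name and the guard for a kernel that uses no shared memory are additions for this sketch. The example values (such as 32 registers per thread and 8 KB of shared memory per program) are hypothetical, not measured from the compiled kernel. Like the original formula, the helper ignores other per-SM limits such as the maximum number of resident threads.

```python
def persistent_grid_size(num_sm, num_regs, size_smem_per_sm,
                         n_regs, size_smem, num_warps, n_rows, warp_size=32):
    """Number of persistent programs, limited by registers and shared memory."""
    regs_per_program = n_regs * warp_size * num_warps  # registers/thread x threads
    occupancy = num_regs // regs_per_program
    if size_smem > 0:  # a kernel using no shared memory is not limited by it
        occupancy = min(occupancy, size_smem_per_sm // size_smem)
    return min(num_sm * occupancy, n_rows)


# Hypothetical values for illustration only
num_programs = persistent_grid_size(
    num_sm=170, num_regs=65536, size_smem_per_sm=101376,
    n_regs=32, size_smem=8192, num_warps=8, n_rows=4096)
print(num_programs)
```

With these numbers registers are the binding limit: they allow 8 programs per SM, against 12 from shared memory. That gives 1360 programs, each looping over three or four of the 4096 rows. Doubling the register count per thread would halve the occupancy.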
Benchmark on RTX 5090
Figure: RTX 5090 row-wise softmax benchmark with M = 4096 rows fixed and the number of columns N varied.
For small and medium row sizes, PyTorch is faster. Around N ≈ 8700 both implementations hit a performance cliff. For Triton, the cause is the block size: once N exceeds 8192, BLOCK_SIZE jumps from 8192 to 16384. That doubles the per-program tile and the resource pressure it creates, so throughput drops abruptly. Beyond that point Triton overtakes PyTorch. The y-axis shows effective bandwidth, computed from the sizes of the input and output tensors and the measured runtime.
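The effective-bandwidth metric can be reproduced from a measured runtime. The sketch assumes the common convention of counting one full read of the input and one full write of the output (2 × M × N × element size bytes) and ignores any extra traffic. The runtime used here is a made-up value, not a result from the benchmark above.

```python
def effective_bandwidth_gbps(m: int, n: int, elem_bytes: int, ms: float) -> float:
    """GB/s, assuming the input is read once and the output written once."""
    bytes_moved = 2 * m * n * elem_bytes
    return bytes_moved * 1e-9 / (ms * 1e-3)


# Hypothetical timing: fp32 input, M = 4096 rows, N = 8192 columns, 0.1 ms
print(f"{effective_bandwidth_gbps(4096, 8192, 4, 0.1):.1f} GB/s")
```

Because the byte count is fixed by the tensor shapes, any extra passes over memory show up as lower effective bandwidth rather than as more bytes moved.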
Conclusion
Triton enables writing GPU kernels at a Python‑like level, but speed‑ups are not guaranteed. Highly optimized PyTorch kernels can be faster for many shapes. The source code for this example is available at:
https://github.com/lounishamroun/optimization_sandbox/blob/main/triton_kernels/basics/softmax_kernel.py
This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please contact us and we will review it promptly.
DeepHub IMBA
