Assignment 3
In this assignment, you'll extend the work from Assignments 1 and 2 to speed up the transformer model for more efficient training and inference. You will focus on optimizing the Softmax and LayerNorm batch reduction operations by writing custom CUDA code.
The CUDA optimizations are based on methods from the LightSeq and LightSeq2 papers. We strongly encourage you to refer to the papers and the relevant lecture slides while working on this assignment. Before starting to write the CUDA code, make sure you have read through the write-up and understood what each kernel is doing.
Setting up the Code
The starting codebase is provided in the following repository:
https://github.com/llmsystem/llmsys_s25_hw3
You will need to merge it with your implementation from the previous assignments. Here are our suggested steps:
Step 1: Install Requirements
Install the required dependencies and miniTorch by running the following commands:
Step 2: Copy and Compile CUDA Kernels
Copy the CUDA kernel file combine.cu from Assignment 2 and compile it:

```bash
# From Assignment 2 to the current directory
cp <path_to_assignment_2>/combine.cu src/combine.cu
# Compile the CUDA kernels
bash compile_cuda.sh
```
Step 3: Copy Files from Assignment 1 and Assignment 2
Copy autodiff.py from Assignment 1 to the specified location:
Hints for Copying and Pasting:
When copying the backpropagate() function, make sure to double-check the implementation you wrote for Assignment 1. It's a good idea to set a default value of 0.0 before accumulating the gradients in backpropagate(). We recommend initializing the derivatives map for each unique_id to 0.0 outside the for loop.
Step 4: Incrementally Add Functions
Keep copying several other functions from Assignment 2 as needed to complete this assignment.

Copy minitorch/nn.py from Assignment 2 to the specified location:

Copy minitorch/modules_basic.py from Assignment 2 to the specified location:

Copy minitorch/modules_transfomer.py from Assignment 2 to the specified location:

Copy run_machine_translation.py from Assignment 2 to the specified location:
Problem 1.1: Softmax Forward (20 points)
In this part, you will implement a fused softmax kernel for the attention mechanism.

The softmax function for a vector \( \mathbf{x} \) is defined as:

\[
\mathrm{softmax}(\mathbf{x})_i = \frac{e^{x_i}}{\sum_{j} e^{x_j}},
\]

where \( x_i \) is the \( i \)-th element of \( \mathbf{x} \).

The kernel also incorporates the attention mask handling used in the attention mechanism: attention masks control which positions of the input the model may attend to (for example, padding or future tokens).
Instructions
- Implement the CUDA Kernel

  Write the CUDA kernel for softmax in src/softmax_kernel.cu:

  ```cuda
  template <typename T, int block_dim, int ele_per_thread>
  __global__ void ker_attn_softmax(T *inp, ...) { ... }

  template <typename T, int block_dim, int ele_per_thread>
  __global__ void ker_attn_softmax_lt32(T *inp, ...) { ... }
  ```

  - The ker_attn_softmax_lt32 kernel is already implemented for sequences shorter than 32 and does not require block-level parallelism.
  - Review the provided implementation of ker_attn_softmax_lt32 and use it as a reference to implement ker_attn_softmax.

- Compile the CUDA File

  Compile the CUDA file using the following command:

- Bind the Kernel to miniTorch

  In minitorch/cuda_kernel_ops.py, bind the kernel by passing the CUDA stream. In minitorch/tensor_functions.py, define the softmax forward function:

- Test the Implementation

  Run the provided test script and ensure your implementation achieves an average speedup of approximately 6.5×:
Understanding Softmax Forward Kernels
The ker_attn_softmax_lt32 and ker_attn_softmax kernels are optimized for different input sizes:

- ker_attn_softmax_lt32:
  - Utilizes warp-level primitives for reduction.
  - Suitable for smaller input sizes (sequence length < 32).
  - Performs efficient parallel reduction without block-wide synchronization.
- ker_attn_softmax:
  - Employs block-level reduction techniques.
  - Suitable for larger input sizes.
  - Includes two phases of reduction (max and sum) followed by a normalization step with synchronization.
Algorithmic Steps
The softmax computation in both kernels consists of three main steps:
- Compute Max
  - Identify the maximum value for normalization to avoid numerical overflow during exponentiation.
- Compute Exponentials and Sum
  - Calculate the exponentials of the normalized values and their sum for the final normalization.
- Compute Final Result
  - Normalize the exponentials by dividing by the sum to obtain softmax probabilities.
  - Use the CUB library's BlockStore to minimize memory transactions.
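To make these three phases concrete, here is a minimal, unfused sketch of a numerically stable row-wise softmax. It assumes one thread per row purely for clarity (the kernel name and launch setup are illustrative), and it omits the masking, register tiling, and CUB optimizations of the provided kernels, which parallelize each row across a warp or a block.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Illustrative only: one thread handles one entire row.
__global__ void naive_row_softmax(const float *inp, float *out, int n_rows,
                                  int row_len) {
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= n_rows) return;
  const float *x = inp + (size_t)row * row_len;
  float *y = out + (size_t)row * row_len;

  // Step 1: row maximum, so the exponentials cannot overflow.
  float max_val = -INFINITY;
  for (int i = 0; i < row_len; ++i) max_val = fmaxf(max_val, x[i]);

  // Step 2: exponentials of the shifted values and their running sum.
  float sum = 0.f;
  for (int i = 0; i < row_len; ++i) {
    y[i] = expf(x[i] - max_val);
    sum += y[i];
  }

  // Step 3: normalize by the sum to obtain probabilities.
  for (int i = 0; i < row_len; ++i) y[i] /= sum;
}
```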
Computing the Maximum Value for Normalization
Both kernels implement the maximum value computation in the same way. Study the implementation in ker_attn_softmax_lt32.

Steps to Compute the Maximum Value:

- Thread Local Max: Each thread computes a local maximum:
  - Intermediate values and attention mask adjustments are stored in val[token_per_reduce][ele_per_thread].
  - The thread-local maximum is recorded in l_max[token_per_reduce], initialized to REDUCE_FLOAT_INF_NEG.
- Future Token Masking: If future tokens are masked, their values are excluded from the maximum computation by setting them to REDUCE_FLOAT_INF_NEG.
- Attention Mask Adjustment: Adjust the input value by adding the corresponding attention mask value.
- Iterative Updates: Update the thread-local maximum using fmaxf.
Block-Level Reduction for Global Max:
- ker_attn_softmax_lt32: Uses a warp-level reduction with a custom warpReduce function.
- ker_attn_softmax: Uses a block-wide reduction with the CUB library's BlockLoad and shared memory for synchronization.
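As a concrete illustration of the warp-level pattern, a warp-wide max reduction can be written with shuffle intrinsics, as in the generic helper below; it is a sketch only and may differ from the warpReduce utility provided in the starter code.

```cuda
// Generic warp-wide max reduction using shuffle intrinsics. After the loop,
// every lane of the warp holds the maximum over all 32 lanes.
__device__ inline float warp_reduce_max_sketch(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
  }
  return val;
}
```

The block-wide variant in ker_attn_softmax performs the same reduction across multiple warps, which is why it also needs shared memory and block-level synchronization.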
Problem 1.2: Softmax Backward (20)
The gradient of the softmax function for a vector \( \mathbf{x} \) is given by:

\[
\frac{\partial\, \mathrm{softmax}(\mathbf{x})_i}{\partial x_j} = \mathrm{softmax}(\mathbf{x})_i \left( \delta_{ij} - \mathrm{softmax}(\mathbf{x})_j \right),
\]

where \( \delta_{ij} \) is the Kronecker delta.
- Implement launch_attn_softmax_bw in src/softmax_kernel.cu:

  ```cuda
  void launch_attn_softmax_bw(float *out_grad, const float *soft_inp, int rows,
                              int softmax_len, cudaStream_t stream)
  ```

  In lectures, we described the use of templates for tuning kernel parameters. When implementing launch_attn_softmax_bw, you should compute the ITERATIONS template parameter of ker_attn_softmax_bw for each of the supported max sequence lengths in {32, 64, 128, 256, 384, 512, 768, 1024, 2048}.

  Hint: Refer to the way templates are used in launch_attn_softmax. A sketch of this dispatch pattern is shown after this list.

- Compile the CUDA file:

- Bind the kernel with miniTorch in minitorch/cuda_kernel_ops.py.

  Hint: You should pass the CUDA stream to the function; define it with:

  ```python
  class CudaKernelOps(TensorOps):
      @staticmethod
      def attn_softmax_bw(out_grad: Tensor, soft_inp: Tensor):
          ...
  ```

  And in minitorch/tensor_functions.py:

- Pass the test and notice an average speedup of around 0.5× with our given default max lengths {32, 64, 128, 256, 384, 512, 768, 1024, 2048}. You can try other max-length setups and achieve a higher speedup, but it will not be graded.
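For illustration, the dispatch in launch_attn_softmax_bw might look like the sketch below. The ker_attn_softmax_bw template parameters, argument list, and launch configuration here are assumptions modeled on launch_attn_softmax; follow the actual declarations in src/softmax_kernel.cu.

```cuda
#include <cuda_runtime.h>

// Assumed declaration of the backward kernel; the real one in
// src/softmax_kernel.cu may take different template and function parameters.
template <typename T, int ITERATIONS>
__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length);

void launch_attn_softmax_bw(float *out_grad, const float *soft_inp, int rows,
                            int softmax_len, cudaStream_t stream) {
  // One warp per row is a common layout; the exact configuration is an
  // assumption in this sketch.
  const int warps_per_block = 4;
  dim3 grid((rows + warps_per_block - 1) / warps_per_block);
  dim3 block(32, warps_per_block);

  // Choose ITERATIONS = ceil(max_len / 32) at compile time for each supported
  // maximum sequence length in {32, 64, 128, 256, 384, 512, 768, 1024, 2048}.
  if (softmax_len <= 32)
    ker_attn_softmax_bw<float, 1>
        <<<grid, block, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 64)
    ker_attn_softmax_bw<float, 2>
        <<<grid, block, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 128)
    ker_attn_softmax_bw<float, 4>
        <<<grid, block, 0, stream>>>(out_grad, soft_inp, softmax_len);
  // ... continue for 256 -> 8, 384 -> 12, 512 -> 16, 768 -> 24,
  //     1024 -> 32, and 2048 -> 64 iterations per thread.
}
```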
Understanding Softmax Backward Kernel
The ker_attn_softmax_bw function is a CUDA kernel for computing the backward pass of the softmax function in self-attention mechanisms. Here are the steps:
Initialization
- The function calculates the backward pass for each element in the gradient and the output of the softmax forward pass.
- The grid and block dimensions are configured based on the batch size, number of heads, and sequence length.
Gradient Calculation
- The function iterates over the softmax length, with each thread handling a portion of the data.
- It loads the gradient and input (output of softmax forward) into registers.
- A local sum is computed for each thread, which is a key part of the gradient calculation for softmax.
Gradient Computation
- The sum is shared across the warp using warp shuffle operations.
- The final gradient for each element is computed by modifying the forward pass output with the computed sum.
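Written out, the per-element quantity each row of the kernel computes is the Jacobian-vector product from above collapsed into a single row-wise sum, where \( s_i \) is the stored softmax output and \( \nabla y_i \) the incoming gradient:

\[
\nabla x_i = \sum_{j} \frac{\partial\, \mathrm{softmax}(\mathbf{x})_j}{\partial x_i}\, \nabla y_j = s_i \left( \nabla y_i - \sum_{j} \nabla y_j\, s_j \right).
\]

The warp-shuffled local sum described above corresponds to \( \sum_{j} \nabla y_j\, s_j \), and the final step scales the remaining difference by the forward output \( s_i \).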
Problem 2.1: LayerNorm Forward (20)
LayerNorm normalizes the input \( \boldsymbol{x} \) by:

\[
\mathrm{LayerNorm}(\boldsymbol{x}) = \boldsymbol{\gamma} \odot \frac{\boldsymbol{x} - \mu(\boldsymbol{x})}{\sigma(\boldsymbol{x})} + \boldsymbol{\beta},
\]

where \( \mu(\boldsymbol{x}) \) and \( \sigma(\boldsymbol{x}) \) are the mean and the standard deviation of \( \boldsymbol{x} \) respectively, and \( \boldsymbol{\gamma} \) and \( \boldsymbol{\beta} \) are the learnable affine transform parameters of LayerNorm.

Note that the equation above requires two reduction operations (mean and standard deviation) that cannot be computed in parallel, since the standard deviation depends on the mean. A speedup can be achieved by computing the standard deviation as:

\[
\sigma(\boldsymbol{x}) = \sqrt{\mu(\boldsymbol{x}^{2}) - \mu(\boldsymbol{x})^{2} + \epsilon},
\]

where \( \epsilon = 1 \times 10^{-8} \) is a small value added to the variance for numerical stability. This approach allows concurrent computation of the means of \( \boldsymbol{x} \) and \( \boldsymbol{x}^{2} \).
Steps
- Implement the CUDA kernel of LayerNorm forward in src/layernorm_kernel.cu:

- Compile the CUDA file:

- Bind the kernel with miniTorch in minitorch/cuda_kernel_ops.py:

  Hint: You should pass the CUDA stream to the function, defining it as:

  ```python
  class CudaKernelOps(TensorOps):
      @staticmethod
      def layernorm_fw(inp: Tensor, gamma: Tensor, beta: Tensor):
          ...
  ```

  And in minitorch/tensor_functions.py:

- Pass the test and notice an average speedup of around \( 15.8\times \):
Understanding LayerNorm Forward Kernels
In this kernel, we use float4 to speed up computations. This approach enhances performance when handling large inputs by processing multiple data elements simultaneously, leveraging the SIMD (Single Instruction, Multiple Data) parallelism inherent in GPUs.

When using float4 in CUDA, reinterpret_cast is required to convert between types. For example, in src/layernorm_kernel.cu, the sum of \( \boldsymbol{x} \) in step 1 is computed as follows:

- reinterpret_cast is used to convert the float array inp to a float4 vector inp_f4.
- Each thread within a block calculates l_sum for its assigned elements in inp_f4.
Algorithmic Steps
- Compute the sums of \( \boldsymbol{x} \) and \( \boldsymbol{x}^{2} \):
  - Use reinterpret_cast to cast to float4 for speedup.
- Perform the reduction:
  - Compute the reduce sum with blockReduce and add epsilon (LN_EPSILON).
- Compute the LayerNorm result:
  - Use reinterpret_cast to cast to float4 for speedup.
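Continuing the step-1 sketch above, steps 2 and 3 might look as follows. The shared-memory helper is a plain stand-in for the provided blockReduce (it assumes blockDim.x is a power of two, at most 1024), and the epsilon follows the \( 1 \times 10^{-8} \) from the write-up; none of the names are the ones required by src/layernorm_kernel.cu.

```cuda
// Plain shared-memory stand-in for blockReduce: every thread returns the
// block-wide sum of `val` (assumes blockDim.x is a power of two <= 1024).
__device__ float block_reduce_sum_sketch(float val) {
  __shared__ float smem[1024];
  smem[threadIdx.x] = val;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) smem[threadIdx.x] += smem[threadIdx.x + stride];
    __syncthreads();
  }
  float sum = smem[0];
  __syncthreads();
  return sum;
}

// Step 2: turn the thread-local partial sums into the row mean and the
// inverse standard deviation, adding epsilon to the variance for stability.
__device__ void row_statistics_sketch(float l_sum, float l_sq_sum,
                                      int hidden_size, float *mean,
                                      float *rstd) {
  float sum = block_reduce_sum_sketch(l_sum);
  float sq_sum = block_reduce_sum_sketch(l_sq_sum);
  *mean = sum / hidden_size;
  *rstd = rsqrtf(sq_sum / hidden_size - (*mean) * (*mean) + 1e-8f);
}
```

Step 3 then walks the row again through float4 pointers and writes \( \boldsymbol{\gamma} \odot (\boldsymbol{x} - \mu)/\sigma + \boldsymbol{\beta} \) element by element.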
Problem 2.2: LayerNorm Backward (20)
Let \(\hat{\boldsymbol{x}}_{i} = \frac{\boldsymbol{x}_{i} - \mu{(\boldsymbol{x})}}{\sigma{(\boldsymbol{x})}}\); then the final gradient of \(\boldsymbol{x}_{i}\) can be re-written as:

\[
\nabla{\boldsymbol{x}_{i}} = \frac{1}{\sigma(\boldsymbol{x})} \left( \boldsymbol{\gamma}_{i} \nabla{\boldsymbol{y}_{i}} - \frac{1}{m} \left( \sum_{j=1}^{m} \boldsymbol{\gamma}_{j} \nabla{\boldsymbol{y}_{j}} + \hat{\boldsymbol{x}}_{i} \sum_{j=1}^{m} \boldsymbol{\gamma}_{j} \nabla{\boldsymbol{y}_{j}} \hat{\boldsymbol{x}}_{j} \right) \right),
\]

where \(m\) is the dimension of \(\boldsymbol{x}\), and \(\nabla{\boldsymbol{x}}\) and \(\nabla{\boldsymbol{y}}\) are the input and output gradients.

The speedup can be achieved by concurrently executing the two batch reduction operations in the parentheses above.

The gradients of \(\boldsymbol{\gamma}_{i}\) and \(\boldsymbol{\beta}_{i}\) are:

\[
\nabla{\boldsymbol{\gamma}_{i}} = \sum_{\text{batch}} \nabla{\boldsymbol{y}_{i}}\, \hat{\boldsymbol{x}}_{i}, \qquad \nabla{\boldsymbol{\beta}_{i}} = \sum_{\text{batch}} \nabla{\boldsymbol{y}_{i}},
\]

where the sums run over all rows in the batch (a batch reduction).
Steps to Implement
- Implement the CUDA kernel of LayerNorm backward in src/layernorm_kernel.cu:

- Compile the CUDA file:

- Bind the kernel with miniTorch in minitorch/cuda_kernel_ops.py:

  Hint: Pass the CUDA stream to the function, defining it as in the example implementation, then integrate it in minitorch/tensor_functions.py:

- Pass the test and notice an average speedup of approximately 3.7×:
Understanding LayerNorm Backward Kernels
Input Gradients
Initialization:
Each thread is responsible for a specific element in the inp_grad array.

Algorithmic Steps:
- Compute \(\nabla{\boldsymbol{y}_{i}}\boldsymbol{\gamma}_{i}\) with reinterpret_cast by casting to float4 for speedup.
- Compute \(\hat{\boldsymbol{x}}_{i}\) with reinterpret_cast by casting to float4 for speedup.
- Compute the reduce sums of \(\nabla{\boldsymbol{y}_{i}}\boldsymbol{\gamma}_{i}\) and \(\nabla{\boldsymbol{y}_{i}}\boldsymbol{\gamma}_{i}\hat{\boldsymbol{x}}_{i}\) with blockReduce.
- Compute the final gradient.
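Putting these steps together, here is a scalar (non-float4) sketch of the per-row input-gradient computation. It reuses the block_reduce_sum_sketch helper from the Problem 2.1 sketch in place of blockReduce, and all names are illustrative rather than the ones expected by the starter code.

```cuda
// One block per row of length m. out_grad holds dL/dy for this row, inp the
// forward input, and mean/rstd the row statistics from the forward pass.
__device__ void layernorm_bw_dx_sketch(float *inp_grad, const float *out_grad,
                                       const float *inp, const float *gamma,
                                       float mean, float rstd, int m) {
  // Thread-local partial sums of (dy * gamma) and (dy * gamma * xhat).
  float l_dg_sum = 0.f, l_dg_xhat_sum = 0.f;
  for (int i = threadIdx.x; i < m; i += blockDim.x) {
    float dg = out_grad[i] * gamma[i];
    float xhat = (inp[i] - mean) * rstd;
    l_dg_sum += dg;
    l_dg_xhat_sum += dg * xhat;
  }

  // The two reductions inside the parentheses of the gradient formula.
  float dg_sum = block_reduce_sum_sketch(l_dg_sum);
  float dg_xhat_sum = block_reduce_sum_sketch(l_dg_xhat_sum);

  // Final gradient: (1 / sigma) * (dy*gamma - (sum1 + xhat * sum2) / m).
  for (int i = threadIdx.x; i < m; i += blockDim.x) {
    float dg = out_grad[i] * gamma[i];
    float xhat = (inp[i] - mean) * rstd;
    inp_grad[i] = rstd * (dg - (dg_sum + xhat * dg_xhat_sum) / m);
  }
}
```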
Gamma and Beta Gradients
Initialization:
Shared memory arrays betta_buffer and gamma_buffer are declared to store intermediate results within the thread block. CUDA thread blocks (cg::thread_block) and thread block tiles (cg::thread_block_tile) are used to organize threads.

Loop Over Rows:
Threads in the y-dimension loop over rows, calculating partial gradients for each row based on the given inputs (out_grad, inp, means, vars).

Shared Memory Storage:
The computed partial gradient values are stored in the shared memory arrays betta_buffer and gamma_buffer in a tiled manner.

Reduction within Thread Block:
Threads cooperate to perform a reduction on betta_buffer and gamma_buffer using g.shfl_down (shuffle down) operations along threadIdx.y. This approach avoids bank conflicts and improves warp-level parallelism.

Final Result Assignment:
The final reduction result is assigned to the appropriate positions in the global output arrays (gamma_grad and betta_grad).

Algorithmic Steps:
- Compute the partial gradients by looping across the inp rows.
- Store the partial gradients in the shared memory arrays.
- Compute the reduce sum of the shared memory arrays with g.shfl_down.
- Assign the final result to the correct position in the global output.
More hints about g.shfl_down:

Read the following to understand what g.shfl_down is doing:
https://developer.nvidia.com/blog/cooperative-groups/#:~:text=Using%20thread_block_tile%3A%3Ashfl_down()%20to%20simplify%20our%20warp%2Dlevel%20reduction%20does%20benefit%20our%20code%3A%20it%20simplifies%20it%20and%20eliminates%20the%20need%20for%20shared%20memory.

Usually, the threads inside a block need to load everything into shared memory and work together to reduce the result (like what you implemented for the reduce function in Assignment 1). g.shfl_down lets you perform this reduction without consuming any shared memory, which makes it more efficient.
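As an illustration of this pattern (not the exact code in the starter kernel, which applies the same idea along threadIdx.y over the betta_buffer and gamma_buffer tiles), a tile-wide sum with cooperative groups can be written as:

```cuda
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Tile-wide sum using shfl_down: after the loop, lane 0 of the tile holds the
// sum of `val` over all 32 lanes, and no shared memory is used.
__device__ float tile_reduce_sum_sketch(cg::thread_block_tile<32> g, float val) {
  for (int offset = g.size() / 2; offset > 0; offset /= 2) {
    val += g.shfl_down(val, offset);  // add the value held by lane (rank + offset)
  }
  return val;  // valid on the lane with thread_rank() == 0
}

// Typical usage inside a kernel:
//   cg::thread_block block = cg::this_thread_block();
//   cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
//   float total = tile_reduce_sum_sketch(tile, partial);
```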
Problem 3: Adopt Fused Kernels in Transformer (20)
The improved CUDA kernels are now bound with the miniTorch library.
Integrate the improved CUDA kernels into the transformer from Assignment 2.
- Replace the softmax and layernorm operations in MultiHeadAttention, TransformerLayer, and DecoderLM with your accelerated kernels in minitorch/modules_transfomer.py.
- Train the transformer for one epoch, with and without the fused kernels, and record the running time:
- According to Amdahl's law, the improvement should not be significant since only softmax and layernorm are improved. However, you should still notice an average speedup of approximately 1.1×.
Submission
Please submit the entire llmsys_s25_hw3 folder as a zip file on Canvas.