Support computation pipelining after SWP refactoring #5185
With the recent SWP refactoring, it is much easier to support arbitrary stage assignments, where computations (not just loads) can be placed into different stages. Computation pipelining essentially splits the computations themselves across pipeline stages. Take flash attention as an example:
Currently the two loads are in stage 0 (S0) and all other ops are in the last stage (S2), so the steady-state loop body looks like:

```
MMA0(i)
Softmax(i)
MUL(i)
MMA1(i)
loadV(i+2)
loadK(i+2)
```
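For reference, here is a minimal Triton sketch of the inner loop these labels refer to. It is illustrative rather than the exact tutorial kernel: `q`, `acc`, `m_i`, `l_i`, `qk_scale`, and the `K_ptrs`/`V_ptrs` tile pointers are assumed to be set up as in the Triton flash attention tutorial.

```python
# Sketch of the flash-attention inner loop (illustrative, not the exact
# tutorial kernel). Assumes q, acc, m_i, l_i, qk_scale and the
# K_ptrs/V_ptrs tile pointers are already set up, with k tiles laid out
# [HEAD_DIM, BLOCK_N] and v tiles [BLOCK_N, HEAD_DIM].
for start_n in range(0, seq_len, BLOCK_N):
    k = tl.load(K_ptrs)                                # loadK
    qk = tl.dot(q, k)                                  # MMA0: first dot, Q @ K^T
    m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)   # Softmax: running max
    p = tl.math.exp2(qk * qk_scale - m_ij[:, None])    # Softmax: exponentials
    alpha = tl.math.exp2((m_i - m_ij) * qk_scale)      # Softmax: rescale factor
    acc = acc * alpha[:, None]                         # MUL: rescale accumulator
    v = tl.load(V_ptrs)                                # loadV
    acc = tl.dot(p.to(v.dtype), v, acc)                # MMA1: second dot, P @ V
    l_i = l_i * alpha + tl.sum(p, 1)
    m_i = m_ij
    K_ptrs += BLOCK_N * stride_kn                      # advance K/V tiles
    V_ptrs += BLOCK_N * stride_vn
```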
This patch defines two different pipeline schedules for attention-like kernels. Both use four stages (S0-S3) and split the computation across S2 and S3, so that one dot and the softmax of a neighboring iteration land in the same loop body and can overlap:
1. Put the first dot in S2 and the other computations in S3, with loadK in S0 and loadV in S1. MMA0 of iteration i+1 then overlaps with the softmax of iteration i:

```
MMA0(i+1)
Softmax(i)
MUL(i)
MMA1(i)
loadK(i+3)
loadV(i+2)
```
2. Put the second dot in S3 and the other computations in S2, with loadK in S0 and loadV in S1. MMA1 of iteration i then overlaps with the softmax of iteration i+1:

```
MMA0(i+1)
MMA1(i)
Softmax(i+1)
MUL(i+1)
loadK(i+3)
loadV(i+2)
```
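Since the frontend is still under discussion, the sketch below shows one hypothetical way the schedule choice could be surfaced to users, as a per-loop hint on `tl.range`. The `loop_schedule` keyword and the `"FA_firstDot"`/`"FA_secondDot"` values are illustrative placeholders for this discussion, not an agreed-upon API; only `num_stages` is an existing `tl.range` option.

```python
import triton
import triton.language as tl

@triton.jit
def attn_fwd_inner(acc, l_i, m_i, q, K_ptrs, V_ptrs, qk_scale,
                   seq_len, BLOCK_N: tl.constexpr):
    # Hypothetical hint: loop_schedule is a placeholder keyword, not a
    # finalized API. Schedule 1 ("FA_firstDot") keeps the first dot in S2
    # and the remaining computations in S3.
    for start_n in tl.range(0, seq_len, BLOCK_N, num_stages=4,
                            loop_schedule="FA_firstDot"):
        # body as in the inner-loop sketch above; the scheduler assigns
        # loadK/loadV/MMA0/Softmax/MUL/MMA1 to their stages
        pass
```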
Preliminary performance numbers on H100 for flash attention (TFLOPS):

| (Batch, Heads, SeqLen, Dhead) | triton_tutorial_flash_v2_opt | triton_tutorial_flash_v2_tma | triton_tutorial_flash_v2 |
| --- | --- | --- | --- |
The implementation and the frontend are preliminary and are posted for discussion.