[QST] FP8 with row-wise scaling on Ada-Lovelace #1937
Comments
Based on the comments for PyTorch's …
Is this really row-wise scaling in PyTorch? Please check here. For row-wise scaling, the … cc: @drisspg
So we do have the RowwiseScaled CUTLASS template here: https://github.com/pytorch/pytorch/blob/a8a1e58e24ab1b9a64c6c3be4adc5919a267b56b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L172 We would accept a PR to also create an sm89 specialization.
I was looking at … @vgoklani, can you please clarify whether you are interested in row-wise scaling or the calculation computed by …?
Thanks @jackkosaian, we are looking for row-wise scaling.
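For concreteness, here is a minimal PyTorch sketch of the row-wise-scaled FP8 GEMM semantics being asked about. This is an illustration only, not code from this thread or from PyTorch/CUTLASS; the shapes, tensor names, and the use of torch.float8_e4m3fn (and its 448 max value) are assumptions made for the example.
import torch
# Illustrative shapes only.
m, n, k = 32, 64, 128
A = torch.randn(m, k, device="cuda")
B = torch.randn(k, n, device="cuda")
# One scale per row of A and one per column of B (448 is the E4M3 max finite value).
scale_a = A.abs().amax(dim=1, keepdim=True) / 448.0   # shape (m, 1)
scale_b = B.abs().amax(dim=0, keepdim=True) / 448.0   # shape (1, n)
# Quantize to FP8 using the per-row / per-column scales.
A_fp8 = (A / scale_a).to(torch.float8_e4m3fn)
B_fp8 = (B / scale_b).to(torch.float8_e4m3fn)
# Reference semantics: accumulate in FP32, then apply the row/column scales in the
# epilogue. A row-wise-scaled CUTLASS GEMM would fuse this final scaling into the kernel.
accum = A_fp8.to(torch.float32) @ B_fp8.to(torch.float32)
D_ref = accum * scale_a * scale_b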
I modified CUTLASS Python's example 4 (EVT) to generate a GEMM with row-wise scaling. Here's a Python script that does so:
import torch
import cutlass
from cutlass.epilogue import relu
from cutlass import Tensor as FakeTensor
from cutlass.utils.profiler import CUDAEventProfiler
# This controls whether the C++ GEMM declaration will be printed at each step. Set to `False` to
# omit this information.
print_module = True
# The Epilogue Visitor feature is only exercised here on SM80, SM89, and SM90
from cutlass.backend.utils.device import device_cc
if device_cc() not in [80, 89, 90]:
import sys
sys.exit()
m = 32
n = m
k = 32
type_A = torch.float16
type_B = torch.float16
type_C = torch.float16
type_D = torch.float16
torch.manual_seed(2023)
scope_min = -4
scope_max = 4
tensor_A = torch.ceil(torch.empty(size=(m, k), dtype=type_A, device="cuda").uniform_(scope_min, scope_max))
tensor_B = torch.ceil(torch.empty(size=(k, n), dtype=type_B, device="cuda").uniform_(scope_min, scope_max))
tensor_C = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device="cuda").uniform_(scope_min, scope_max))
tensor_D = torch.zeros_like(tensor_C)
# Override the random inputs with simple values (A = all ones, B = identity, C = zeros)
# so the epilogue result is easy to verify by hand.
tensor_A = torch.ones((m, k), device="cuda", dtype=type_A)
tensor_B = torch.eye(m, device="cuda", dtype=type_B)
tensor_C = torch.zeros_like(tensor_D)
plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor, element_accumulator=torch.float32)
# ## Define the epilogue visitor functor
# The epilogue functor can be defined as a simple Python function together with a set of example tensors for its inputs and outputs. The epilogue used here takes source tensors of different ranks: `alpha` and `beta` are scalars, `bias` is a column vector to broadcast, and `C` is a matrix. An epilogue can contain various math operations, from basic arithmetic operations to built-in callable functions like `relu`, and it can accommodate multiple outputs. Note that there are some restrictions on syntax.
# * Each named variable must be assigned exactly once and defined before it is used.
# * Reserved names: `accum`, `C`, and `D` are reserved for accumulator, tensor_C, and tensor_D.
# * Return values must be a named variable.
#
# The example tensors are given as a dictionary with tensor names as keys and reference tensors as values. The reference tensors can be `float`, `torch.Tensor`, `numpy.ndarray`, or our `FakeTensor`. They provide the shape and data type information for the inputs and outputs of the epilogue.
#
# The epilogue can be generated simply through `cutlass.epilogue.trace(<epilogue function>, <example_tensors>)`.
# Define epilogue visitor
def example_epilogue(accum, alpha, C, beta, bias):
D = (alpha * accum + (beta * C)) * bias
return D
# Construct inputs and outputs
alpha = 1.0
beta = 0.0
aux = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device="cuda").uniform_(scope_min, scope_max))
bias = torch.ceil(torch.empty(size=(m, 1), dtype=type_C, device="cuda").uniform_(scope_min, scope_max))
tensor_F = torch.zeros_like(tensor_D)
examples_tensors = {
"accum": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
"alpha": alpha,
"C": tensor_C,
"beta": beta,
#"aux": aux,
"bias": bias,
"D": tensor_D,
#"F": tensor_F
}
# Trace the epilogue visitor
epilogue_visitor = cutlass.epilogue.trace(example_epilogue, examples_tensors)
# ## Run a GEMM with the epilogue visitor functor
# The `epilogue_visitor` can be used by setting the plan's `epilogue_visitor` field. The arguments for the epilogue visitor are provided as a `dict` through the `visitor_args` keyword argument.
visitor_args = {
"alpha": alpha, "C": tensor_C, "beta": beta,
#"aux": aux,
"bias": bias,
"D": tensor_D,
#"F": tensor_F
}
plan.epilogue_visitor = epilogue_visitor
plan.run(
tensor_A, tensor_B, tensor_C, tensor_D,
visitor_args=visitor_args, print_module=print_module)
# The epilogue function `example_epilogue` can be used as a reference function. We can now verify the results simply with
class TorchReference(torch.nn.Module):
def forward(self, A, B, alpha, C, beta, bias):
accum = torch.matmul(A, B)
return example_epilogue(accum, alpha, C, beta, bias)
torch_reference = TorchReference()
tensor_D_ref = torch_reference(tensor_A, tensor_B, alpha, tensor_C, beta, bias)
print(bias)
print(tensor_D)
assert torch.equal(tensor_D, tensor_D_ref)
#assert torch.equal(tensor_F, tensor_F_ref)
# The performance of CUTLASS fused kernel can be profiled with
warmup_iterations = 10
profile_iterations = 50
# Profile CUTLASS fused kernel
duration = CUDAEventProfiler(
plan, warmup_iterations, profile_iterations,
tensor_A, tensor_B, tensor_C, tensor_D,
visitor_args=visitor_args)()
print(f"CUTLASS duration: {duration:.2f} ms") This prints out the corresponding C++ for performing row-wise scaling for this GEMM: using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::half_t,
8,
1 /* epilogue stages */
>;
using C = cutlass::epilogue::threadblock::VisitorAuxLoad<
OutputTileThreadMap, cutlass::half_t, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>
>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using Alpha = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
>;
using Beta = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
>;
using Bias = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, cutlass::half_t,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>
>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
Compute0,
Alpha,
Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
Compute1,
Beta,
C>;
using Compute2 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::plus, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute2 = cutlass::epilogue::threadblock::Sm80EVT<
Compute2,
EVTCompute0,
EVTCompute1>;
using Compute3 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, cutlass::half_t, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute3 = cutlass::epilogue::threadblock::Sm80EVT<
Compute3,
EVTCompute2,
Bias>;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>
>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
D,
EVTCompute3>;
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8
using cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
float,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
EVTD,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd,
1 /* epilogue stages */
>::GemmKernel;
// Define named type
struct cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type :
  public cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base { };
One could consider trying this out in a PyTorch module to ensure that it does everything that's desired.
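As a rough illustration of what "trying this out in a PyTorch module" could look like, here is a minimal sketch. It assumes the `plan` (with its `epilogue_visitor` already set) from the script above is in scope; the module name, argument handling, and output allocation are hypothetical, not part of the CUTLASS Python API.
import torch
class CutlassEVTGemm(torch.nn.Module):
    # Hypothetical wrapper around the CUTLASS Python plan built above.
    def __init__(self, plan):
        super().__init__()
        self.plan = plan
    def forward(self, A, B, C, alpha, beta, bias):
        # Allocate the output and pass the epilogue inputs the same way the
        # script above does, via `visitor_args`.
        D = torch.empty(A.shape[0], B.shape[1], dtype=C.dtype, device=C.device)
        visitor_args = {"alpha": alpha, "C": C, "beta": beta, "bias": bias, "D": D}
        self.plan.run(A, B, C, D, visitor_args=visitor_args, print_module=False)
        return D
# Example usage with the tensors defined earlier:
# module = CutlassEVTGemm(plan)
# D_out = module(tensor_A, tensor_B, tensor_C, alpha, beta, bias)
This simply mirrors the `plan.run` call from the script; an FP8 version on sm89 would additionally need a kernel configured for that architecture and data type rather than the half-precision one generated here.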
I would like some clarity on this:
pytorch/pytorch#130359
It appears that CUTLASS does not support row-wise scaling on Ada Lovelace cards.
Is there a timetable to get this resolved?
We purchased a bunch of these cards ($$$$$$$$) and this is very disappointing.