[QST] FP8 with row-wise scaling on Ada-Lovelace #1937

Open

vgoklani opened this issue Nov 11, 2024 · 6 comments

@vgoklani

I would like some clarity on this:

pytorch/pytorch#130359

It appears that CUTLASS does not support row-wise scaling on Ada Lovelace cards.

Is there a timetable for getting this resolved?

We purchased a bunch of these cards ($$$$$$$$) and this is very disappointing.

@jackkosaian
Contributor

jackkosaian commented Nov 12, 2024

Based on the comments for PyTorch's scaled_mm method here, we do have this functionality for Ada. Please see example 58.

@manishucsd
Contributor

manishucsd commented Nov 13, 2024

Based on the comments for PyTorch's scaled_mm method here

Is this really rowwise scaling in PyTorch? Please check here. For rowwise scaling, scale_a should be of shape Mx1 and scale_b should be of shape 1xN. Further, I don't think rowwise scaling needs any special feature from the CUTLASS side. You should be able to use this EVT construction to obtain a rowwise-scaling GEMM on Ada Lovelace. Let us know if it works for you on Ada Lovelace.

cc: @drisspg
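
For reference, a minimal sketch of what a row-wise-scaled call to torch._scaled_mm would look like with the scale shapes described above. This assumes a PyTorch build recent enough to accept (M, 1) and (1, N) scale tensors, and kernel/hardware support for the row-wise path (which is the open question in this thread); all sizes are illustrative:

import torch

M, K, N = 128, 256, 64

# FP8 operands; torch._scaled_mm expects the second operand in column-major layout
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

# Row-wise scaling: scale_a is Mx1, scale_b is 1xN, both float32
scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)
scale_b = torch.rand(1, N, device="cuda", dtype=torch.float32)

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)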

@drisspg
Contributor

drisspg commented Nov 13, 2024

So we do have the RowwiseScaled CUTLASS template here: https://github.com/pytorch/pytorch/blob/a8a1e58e24ab1b9a64c6c3be4adc5919a267b56b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L172
It is based on the one from FBGEMM, but with some tweaks for increased performance. I think the real issue is that this template is only stamped out for sm90: https://github.com/pytorch/pytorch/blob/a8a1e58e24ab1b9a64c6c3be4adc5919a267b56b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L172

We would accept a PR that also creates an sm89 specialization.
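
Until an sm89 specialization lands, user code can guard on the device's compute capability before taking the row-wise path. A minimal sketch, assuming a PyTorch build whose torch._scaled_mm accepts (M, 1)/(1, N) scale tensors; the dequantize-and-matmul fallback below is purely illustrative, not part of PyTorch:

import torch

def rowwise_scaled_mm_or_fallback(a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.bfloat16):
    # a_fp8: (M, K) row-major FP8, b_fp8: (K, N) column-major FP8
    # scale_a: (M, 1) float32, scale_b: (1, N) float32
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) >= (9, 0):
        # sm90: the row-wise CUTLASS kernel referenced above is instantiated
        return torch._scaled_mm(a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b,
                                out_dtype=out_dtype)
    # sm89 and below: dequantize and fall back to a plain matmul (slow, but correct)
    a = a_fp8.to(out_dtype) * scale_a.to(out_dtype)
    b = b_fp8.to(out_dtype) * scale_b.to(out_dtype)
    return a @ b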

@jackkosaian
Contributor

Is this really rowwise scaling in PyTorch?

I was looking at scaled_mm based on the API used in the linked PyTorch issue (pytorch/pytorch#130359). I agree that it does not appear to be row-wise scaling.

@vgoklani , can you please clarify whether you are interested in row-wise scaling or the calculation computed by torch._scaled_mm?

@vgoklani
Author

Thanks @jackkosaian, we are looking for row-wise scaling.

@jackkosaian
Contributor

I modified CUTLASS Python's example 4 (EVT) to generate a GEMM with row-wise scaling. Here's a Python script that does so:

import torch
import cutlass
from cutlass.epilogue import relu
from cutlass import Tensor as FakeTensor
from cutlass.utils.profiler import CUDAEventProfiler

# This controls whether the C++ GEMM declaration will be printed at each step. Set to `False` to
# omit this information.
print_module = True

# The Epilogue Visitor feature currently only works for SM80, SM89, and SM90
from cutlass.backend.utils.device import device_cc
if device_cc() not in [80, 89, 90]:
    import sys
    sys.exit()

m = 32
n = m
k = 32

type_A = torch.float16
type_B = torch.float16
type_C = torch.float16
type_D = torch.float16

torch.manual_seed(2023)
scope_min = -4
scope_max = 4
tensor_A = torch.ceil(torch.empty(size=(m, k), dtype=type_A, device="cuda").uniform_(scope_min, scope_max))
tensor_B = torch.ceil(torch.empty(size=(k, n), dtype=type_B, device="cuda").uniform_(scope_min, scope_max))
tensor_C = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device="cuda").uniform_(scope_min, scope_max))
tensor_D = torch.zeros_like(tensor_C)
# Overwrite with simple values (all-ones A, identity B, zero C) so the result is easy to verify
tensor_A = torch.ones((m, k), device="cuda", dtype=type_A)
tensor_B = torch.eye(m, device="cuda", dtype=type_B)
tensor_C = torch.zeros_like(tensor_D)

plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor, element_accumulator=torch.float32)

# ## Define the epilogue visitor functor
# The epilogue functor can be defined as a simple Python function together with a set of example
# tensors for its inputs and outputs. In the epilogue below, `alpha` and `beta` are scalars, `bias`
# is a column vector to broadcast (this is what provides the row-wise scaling), and `C` is a matrix.
# The epilogue supports various math operations, from basic arithmetic to built-in callables like
# `relu`, and can accommodate multiple outputs. Note that there are some restrictions on syntax:
# * Each named variable must be assigned exactly once and defined before it is used.
# * Reserved names: `accum`, `C`, and `D` are reserved for the accumulator, tensor_C, and tensor_D.
# * Return values must be named variables.
#
# The example tensors are given as a dictionary with tensor names as keys and reference tensors as
# values. The reference tensors can be `float`, `torch.Tensor`, `numpy.ndarray`, or our `FakeTensor`;
# they provide the shape and data type information for the inputs and outputs of the epilogue.
#
# The epilogue can be generated simply through `cutlass.epilogue.trace(<epilogue function>, <example tensors>)`.

# Define epilogue visitor
def example_epilogue(accum, alpha, C, beta, bias):
    D = (alpha * accum + (beta * C)) * bias
    return D

# Construct inputs and outputs
alpha = 1.0
beta = 0.0
aux = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device="cuda").uniform_(scope_min, scope_max))
bias = torch.ceil(torch.empty(size=(m, 1), dtype=type_C, device="cuda").uniform_(scope_min, scope_max))
tensor_F = torch.zeros_like(tensor_D)
examples_tensors = {
    "accum": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
    "alpha": alpha,
    "C": tensor_C,
    "beta": beta,
    #"aux": aux,
    "bias": bias,
    "D": tensor_D,
    #"F": tensor_F
}

# Trace the epilogue visitor
epilogue_visitor = cutlass.epilogue.trace(example_epilogue, examples_tensors)

# ## Run a GEMM with the epilogue visitor functor
# The `epilogue_visitor` can be used by setting the plan's `epilogue_visitor` field. The arguments for the epilogue visitor are provided as a `dict` through the `visitor_args` keyword argument.

visitor_args = {
    "alpha": alpha, "C": tensor_C, "beta": beta,
    #"aux": aux,
    "bias": bias,
    "D": tensor_D,
    #"F": tensor_F
}

plan.epilogue_visitor = epilogue_visitor
plan.run(
    tensor_A, tensor_B, tensor_C, tensor_D,
    visitor_args=visitor_args, print_module=print_module)

# The epilogue function `example_epilogue` can be used as a reference function. We can now verify the results simply with

class TorchReference(torch.nn.Module):
    def forward(self, A, B, alpha, C, beta, bias):
        accum = torch.matmul(A, B)
        return example_epilogue(accum, alpha, C, beta, bias)

torch_reference = TorchReference()
tensor_D_ref = torch_reference(tensor_A, tensor_B, alpha, tensor_C, beta, bias)
print(bias)
print(tensor_D)
assert torch.equal(tensor_D, tensor_D_ref)
#assert torch.equal(tensor_F, tensor_F_ref)

# The performance of the CUTLASS fused kernel can be profiled with:

warmup_iterations = 10
profile_iterations = 50
# Profile CUTLASS fused kernel
duration = CUDAEventProfiler(
    plan, warmup_iterations, profile_iterations,
    tensor_A, tensor_B, tensor_C, tensor_D,
    visitor_args=visitor_args)()

print(f"CUTLASS duration: {duration:.2f} ms")

This prints out the corresponding C++ for performing row-wise scaling for this GEMM:

using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    cutlass::gemm::GemmShape<256, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::half_t,
    8,
    1 /* epilogue stages */
>;


using C = cutlass::epilogue::threadblock::VisitorAuxLoad<
    OutputTileThreadMap, cutlass::half_t, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>
>;

using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

using Alpha = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
    float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
>;

using Beta = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
    float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
>;

using Bias = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>
>;

using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
    Compute0,
    Alpha,
    Accum>;

using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
    Compute1,
    Beta,
    C>;

using Compute2 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using EVTCompute2 = cutlass::epilogue::threadblock::Sm80EVT<
    Compute2,
    EVTCompute0,
    EVTCompute1>;

using Compute3 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, cutlass::half_t, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using EVTCompute3 = cutlass::epilogue::threadblock::Sm80EVT<
    Compute3,
    EVTCompute2,
    Bias>;

using D = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
    cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>
>;

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
    D,
    EVTCompute3>;


// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8
using cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    float,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    EVTD,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::arch::OpMultiplyAdd,
    1 /* epilogue stages */
>::GemmKernel;

// Define named type
struct cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type :
  public cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base { };

One could consider trying this out in a PyTorch module to ensure that it does everything that's desired.
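
As one possible starting point, here is a minimal, untested sketch of wrapping the GEMM above in a torch.nn.Module. It assumes the `plan` and `epilogue_visitor` objects constructed exactly as in the script earlier in this thread, and treats the column-broadcast `bias` input as the per-row scale:

import torch

class RowwiseScaledGemm(torch.nn.Module):
    def __init__(self, plan, epilogue_visitor):
        super().__init__()
        self.plan = plan
        self.plan.epilogue_visitor = epilogue_visitor

    def forward(self, A, B, row_scale, alpha=1.0, beta=0.0):
        # C is unused when beta == 0, but the traced epilogue still expects a tensor for it
        C = torch.zeros(A.shape[0], B.shape[1], dtype=A.dtype, device=A.device)
        D = torch.empty_like(C)
        visitor_args = {
            "alpha": alpha, "C": C, "beta": beta,
            "bias": row_scale,  # (M, 1); the column broadcast applies the row-wise scale
            "D": D,
        }
        self.plan.run(A, B, C, D, visitor_args=visitor_args)
        return D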
