A100 vs MI250X conv performance #3310

Open
etiennemlb opened this issue Oct 11, 2024 · 8 comments


@etiennemlb

I would like to inquire about the performance of two kernels:
naive_conv_nonpacked_bwd_nchw_half_double_half
naive_conv_nonpacked_fwd_nchw_half_double_half

When are these kernels used when we call miopen_convolution_forward? I have a PyTorch model that is 6.4x slower on MI250X than on A100.
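The kernel names appear to encode the convolution configuration. The sketch below splits such a name into its parts. It assumes the naming pattern naive_conv_[packed_|nonpacked_]<direction>_<layout>_<input type>_<accumulator type>_<output type>. That pattern is our reading of the names, not a documented MIOpen API. Under this reading, both kernels above read half inputs, accumulate in double and write half outputs. The helper and type names are hypothetical.

```python
import re
from typing import NamedTuple, Optional

# Assumed (not documented) layout of MIOpen naive convolution kernel names:
# naive_conv_[packed_|nonpacked_]<direction>_<layout>_<input>_<accumulator>_<output>
_NAIVE_RE = re.compile(
    r"^naive_conv_(?:(?P<packing>packed|nonpacked)_)?"
    r"(?P<direction>fwd|bwd|wrw)_"
    r"(?P<layout>[a-z]+)_"
    r"(?P<src>[a-z0-9]+)_(?P<acc>[a-z0-9]+)_(?P<dst>[a-z0-9]+)$"
)


class NaiveKernel(NamedTuple):
    """Parts of a naive kernel name (hypothetical helper type)."""
    packing: Optional[str]
    direction: str
    layout: str
    src_type: str
    acc_type: str
    dst_type: str


def parse_naive_kernel(name: str) -> Optional[NaiveKernel]:
    """Split a naive kernel name into its parts; return None if it does not match."""
    m = _NAIVE_RE.match(name)
    if m is None:
        return None
    return NaiveKernel(m["packing"], m["direction"], m["layout"],
                       m["src"], m["acc"], m["dst"])


if __name__ == "__main__":
    for kernel in ("naive_conv_nonpacked_bwd_nchw_half_double_half",
                   "naive_conv_nonpacked_fwd_nchw_half_double_half"):
        print(kernel, "->", parse_naive_kernel(kernel))
```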

@averinevg
Collaborator

Hi @etiennemlb, naive kernels are the last resort when none of the other kernels are applicable. Could you provide more information about the tensor sizes? It would also be useful to have a minimal reproducer.
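Because the naive kernels are only a fallback, you can check whether a run hits them by scanning the kernel names that the profiler records. The sketch below assumes that (name, call count) pairs come from the PyTorch profiler's key_averages(), whose entries expose key and count, for example from the prof object in the reproducer further down this thread. The helper name naive_conv_calls is hypothetical.

```python
from typing import Dict, Iterable, Tuple


def naive_conv_calls(events: Iterable[Tuple[str, int]]) -> Dict[str, int]:
    """Sum call counts of every event whose name contains 'naive_conv'."""
    totals: Dict[str, int] = {}
    for name, count in events:
        if "naive_conv" in name:
            totals[name] = totals.get(name, 0) + count
    return totals


# Usage with a PyTorch profiler object (e.g. `prof` in the reproducer):
#   hits = naive_conv_calls((evt.key, evt.count) for evt in prof.key_averages())
#   print(hits or "no naive convolution kernels recorded")
```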

@formiel

formiel commented Oct 16, 2024

Hello @averinevg. Thanks a lot for your reply! I'm @etiennemlb's colleague. We have prepared a minimal code example to reproduce the speed differences we observed, along with profiling results for various configurations. These details are available via the link below. I’ve also included the code here for your convenience.

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.profiler import profile, ProfilerActivity, record_function


class TransposeLast(nn.Module):
    def __init__(self, transpose_dim=-2):
        super().__init__()
        self.transpose_dim = transpose_dim
    def forward(self, x):
        return x.transpose(self.transpose_dim, -1)
    

class MyAudioModel(nn.Module):
    def __init__(self, num_conv=2):
        super().__init__()

        conv_layers = []
        # first conv layer
        conv_layers.append(nn.Conv1d(1, 512, 10, 5))
        conv_layers.append(TransposeLast())
        conv_layers.append(nn.LayerNorm(512, elementwise_affine=True))
        conv_layers.append(TransposeLast())
        conv_layers.append(nn.GELU())

        for _ in range(num_conv - 1):
            conv_layers.append(nn.Conv1d(512, 512, 3, 2))
            conv_layers.append(TransposeLast())
            conv_layers.append(nn.LayerNorm(512, elementwise_affine=True))
            conv_layers.append(TransposeLast())
            conv_layers.append(nn.GELU())
        self.conv_layers = nn.Sequential(*conv_layers)

        self.proj = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(512),
            nn.Linear(512, 64),
            TransposeLast(),
        )

    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)
        x = self.conv_layers(x)
        x = self.proj(x)  # BxCxT -> BxDxT (D=64)
        return torch.mean(x, dim=-1)  # average over time -> BxD

def main():

    # Main params
    fp16_training = True
    input_size = "8_320000"
    num_conv = 6
    device = "cuda"
    suffix = "_mi250x"
    epochs = 3

    B, T = [int(i) for i in input_size.split("_")]
  
    # Create dummy dataset and data loader
    x = torch.randn((B, T))
    y = torch.randn(B, 64)
    dataset = TensorDataset(x, y)
    data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Initialize model and optimizer
    model = MyAudioModel(num_conv=int(num_conv))
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    model.to(device=device)
    if fp16_training:
        model = model.half()
  
    # Profiling
    folder_name = f"conv{num_conv}L_input{input_size}{suffix}"
    profile_dir = f"{os.environ.get('WORK', '.')}/profile-conv-ops/{folder_name}"  # current dir if WORK is unset
    os.makedirs(profile_dir, exist_ok=True)
  
    prof = profile(activities=[
                        ProfilerActivity.CPU, 
                        ProfilerActivity.CUDA
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
                    record_shapes=True, profile_memory=True, with_flops=True,
                    with_stack=True,
                    experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True),
                )
    prof.start()
  
    for _ in range(epochs):
        for batch_x, batch_y in data_loader:
            batch_x, batch_y = batch_x.to(device=device), batch_y.to(device=device)
            if fp16_training:
                batch_x, batch_y = batch_x.half(), batch_y.half()
            with record_function("model_forward"):
                output = model(batch_x)

            # Compute loss
            loss = criterion(output, batch_y)

            # Backward pass
            with record_function("backward_pass"):
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

    prof.stop()


if __name__ == "__main__":
    main()

We found that the MI250X is around 7 times slower than the A100 when using 6 convolutional layers on input tensors of shape (8, 320000). Disabling the direct convolution algorithm with export MIOPEN_DEBUG_CONV_DIRECT=0 prevents the naive_conv_packed kernels from being invoked, but setting this environment variable did not improve speed for our configuration. Could you please take a look and let us know if there is anything we can do to improve the training speed?
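You can derive the per-layer tensor sizes in this configuration from the standard nn.Conv1d output-length formula, with no padding and dilation 1. The helpers below are hypothetical. They hard-code the kernel sizes and strides of MyAudioModel: kernel 10 and stride 5 for the first layer, then kernel 3 and stride 2.

```python
from typing import List, Tuple


def conv1d_out_len(length: int, kernel: int, stride: int,
                   padding: int = 0, dilation: int = 1) -> int:
    """Output length of a 1-D convolution (formula from the nn.Conv1d docs)."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_stack_shapes(batch: int, length: int, num_conv: int) -> List[Tuple[int, int, int]]:
    """(B, C, T) after each conv layer of MyAudioModel; all layers output 512 channels."""
    layers = [(10, 5)] + [(3, 2)] * (num_conv - 1)
    shapes = []
    for kernel, stride in layers:
        length = conv1d_out_len(length, kernel, stride)
        shapes.append((batch, 512, length))
    return shapes


if __name__ == "__main__":
    for i, shape in enumerate(conv_stack_shapes(8, 320000, 6)):
        print(f"conv{i}: {shape}")
```

For input (8, 320000) with 6 layers, the time dimension after each convolution is 63999, 31999, 15999, 7999, 3999 and 1999. The first layer sees a single input channel; every later layer maps 512 channels to 512.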

Many thanks in advance for your help!

@etiennemlb
Author

@averinevg Hey, have you had time to take a look? The slowdown in the example above makes using MI250X cards impractical.

@huanrwan-amd

Hi @formiel and @etiennemlb, can you provide more information on your hardware and software (OS version, ROCm version), as well as the rocminfo output?
Also, please follow https://rocm.docs.amd.com/projects/install-on-linux/en/develop/install/3rd-party/pytorch-install.html#using-docker-with-pytorch-pre-installed to install PyTorch for ROCm. Thanks.
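Along with the full rocminfo output, it helps to pull out the GPU architecture of each agent. The sketch below assumes the usual rocminfo layout, in which each GPU agent has a line such as "Name: gfx90a". MI250X GCDs report gfx90a. The helper names are hypothetical.

```python
import re
import subprocess
from typing import List

# Matches agent lines like "  Name:   gfx90a"; "Marketing Name:" lines and
# ISA names such as "amdgcn-amd-amdhsa--gfx90a:..." are not matched.
_GFX_NAME = re.compile(r"^\s*Name:\s+(gfx[0-9a-z]+)\b", re.MULTILINE)


def gfx_targets(rocminfo_text: str) -> List[str]:
    """Return the gfx target of every GPU agent listed in rocminfo output."""
    return _GFX_NAME.findall(rocminfo_text)


def run_rocminfo() -> str:
    """Capture rocminfo output (requires ROCm's rocminfo on PATH)."""
    return subprocess.run(["rocminfo"], capture_output=True, text=True,
                          check=True).stdout
```

On a node with 4 MI250X, gfx_targets(run_rocminfo()) should list gfx90a once per GCD. On the PyTorch side, torch.version.hip reports the HIP version that a ROCm build of PyTorch was compiled against.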

@etiennemlb
Author

etiennemlb commented Oct 30, 2024

We run RHEL 8 with ROCm 5.7.1 or 6.0.0 on HPE Cray Bard Peak nodes, each equipped with 4 MI250X (8 GCDs). If you are familiar with Frontier's nodes, these are the same.

Using containers is not an option on this machine.

@huanrwan-amd

huanrwan-amd commented Oct 30, 2024

Hi @etiennemlb, have you tried upgrading ROCm to the latest release (https://rocm.docs.amd.com/en/latest/about/release-notes.html)? Enabling MIOpen logging (https://rocm.docs.amd.com/projects/MIOpen/en/latest/how-to/debug-log.html) would also help. Thanks.
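Following that logging page, one way to capture what MIOpen does for each convolution is to rerun the reproducer with MIOpen's logging variables set. The sketch below assumes the variables MIOPEN_ENABLE_LOGGING, MIOPEN_ENABLE_LOGGING_CMD and MIOPEN_LOG_LEVEL work as described there. Level 6 is chosen here as a verbose setting. The helper name and the script name in the usage comment are hypothetical.

```python
import os
from typing import Dict, Mapping, Optional


def miopen_debug_env(base: Optional[Mapping[str, str]] = None,
                     disable_direct: bool = False) -> Dict[str, str]:
    """Copy of an environment with MIOpen logging switched on."""
    env = dict(os.environ if base is None else base)
    env["MIOPEN_ENABLE_LOGGING"] = "1"      # log MIOpen API calls
    env["MIOPEN_ENABLE_LOGGING_CMD"] = "1"  # print equivalent MIOpenDriver command lines
    env["MIOPEN_LOG_LEVEL"] = "6"           # verbose informational output
    if disable_direct:
        env["MIOPEN_DEBUG_CONV_DIRECT"] = "0"  # the switch tried earlier in this thread
    return env


# Usage ("repro.py" is a placeholder for the reproducer script):
#   import subprocess, sys
#   subprocess.run([sys.executable, "repro.py"], env=miopen_debug_env(), check=True)
```

The MIOpenDriver command lines printed with MIOPEN_ENABLE_LOGGING_CMD=1 are a convenient way to hand over the exact convolution problems as a minimal reproducer.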

@RobQuistNL

@huanrwan-amd I'm experiencing this on ROCm 6.2.4 as well.

@huanrwan-amd
Copy link

Hi @etiennemlb and @formiel, any updates after updating ROCm? Please also consider using the PyTorch ROCm stack: https://github.com/ROCm/pytorch.

In general, MIOpen benchmarks various kernels to select the most efficient one for a given operation. Naive convolution kernels are a fallback for when the more optimized kernels are not applicable; they are generally less efficient and can lead to slower performance. If a naive kernel is being selected, the other kernels may be inapplicable to your specific configuration or may not perform well on it.

You can find more info on: https://rocm.docs.amd.com/projects/MIOpen/en/latest/conceptual/finddb.html
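The selection behaviour described above can be pictured as a toy model. Among the candidates applicable to a given problem, the fastest measured one wins, and the naive kernel is used only when nothing else applies. This is an illustrative sketch, not MIOpen's actual code, and the Candidate type and solver names are made up.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Candidate:
    """A solver as seen by this toy model (not an MIOpen type)."""
    name: str
    applicable: bool
    time_ms: Optional[float]  # benchmarked time, None if not measured


def pick_solver(candidates: Sequence[Candidate], fallback: str = "naive") -> str:
    """Fastest applicable, benchmarked candidate; otherwise the naive fallback."""
    timed = [c for c in candidates if c.applicable and c.time_ms is not None]
    if not timed:
        return fallback
    return min(timed, key=lambda c: c.time_ms).name
```

On the PyTorch side, setting torch.backends.cudnn.benchmark = True is, as far as we know, also honoured by ROCm builds, where it requests MIOpen's exhaustive search instead of the default heuristics.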
