Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[thunder] INTERNAL_ASSERT_FAILED #3461

Open
kshitij12345 opened this issue Nov 22, 2024 · 3 comments
Open

[thunder] INTERNAL_ASSERT_FAILED #3461

kshitij12345 opened this issue Nov 22, 2024 · 3 comments

Comments

@kshitij12345
Copy link

Repro Script

# CUDA devices:
#  0: NVIDIA GeForce RTX 3090
#  1: NVIDIA GeForce RTX 3090
# torch version: 2.5.1+cu121
# cuda version: 12.1
# nvfuser version: 0.2.22+gitc14d418
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id22(fd : FusionDefinition) -> None :
    """Auto-generated nvFuser fusion definition that reproduces the reported failure.

    Defines seven float32 CUDA input tensors — several with permuted
    ``stride_order`` (non-contiguous memory layouts) — builds a graph of
    reshape / sum / broadcast_in_dim / slice / pad / permute ops, and
    registers six 2-D tensors (three flattened results and their
    transposes) as fusion outputs.
    """
    # Inputs. Note the permuted stride_order on T0/T1/T3/T4/T5; the driver
    # below feeds matching as_strided tensors, so the layout may be
    # relevant to the failure. NOTE(review): unconfirmed.
    T0 = fd.define_tensor(shape=[1, 32, 1024, 128], contiguity=[None, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 1, 2, 0])
    T1 = fd.define_tensor(shape=[1, 1024, 128], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 0, 1])
    T2 = fd.define_tensor(shape=[1, 8, 1024, 128], contiguity=[None, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T3 = fd.define_tensor(shape=[1, 32, 1024, 128], contiguity=[None, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 1, 2, 0])
    T4 = fd.define_tensor(shape=[1, 32, 1024, 128], contiguity=[None, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 1, 2, 0])
    T5 = fd.define_tensor(shape=[1, 1024, 128], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 0, 1])
    T6 = fd.define_tensor(shape=[1, 8, 1024, 128], contiguity=[None, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    # Split the 32 heads of T0 into (8, 4) groups, reduce, and re-broadcast.
    T13 = fd.ops.reshape(T0, new_shape=[1, 8, 4, 1024, 128])
    T14 = fd.ops.sum(T13, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T21 = fd.ops.broadcast_in_dim(T14, shape=[1, 8, 1, 1024, 128], broadcast_dims=[1, 3, 4])
    T22 = fd.ops.sum(T21, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T28 = fd.ops.broadcast_in_dim(T1, shape=[1, 1, 1024, 128], broadcast_dims=[0, 2, 3])
    T34 = fd.ops.broadcast_in_dim(T22, shape=[1, 8, 1024, 128], broadcast_dims=[1, 2, 3])
    T40 = fd.ops.broadcast_in_dim(T28, shape=[1, 32, 1024, 128], broadcast_dims=[0, 1, 2, 3])
    T41 = fd.ops.add(T2, T34)
    T47 = fd.ops.broadcast_in_dim(T28, shape=[1, 8, 1024, 128], broadcast_dims=[0, 1, 2, 3])
    T48 = fd.ops.mul(T40, T3)
    T49 = fd.ops.mul(T47, T41)
    T56 = fd.ops.reshape(T4, new_shape=[1, 8, 4, 1024, 128])
    # Slice the last (128-wide) dim into two 64-wide halves, negate one half,
    # and pad each back to full width before recombining.
    # NOTE(review): this slice/neg/pad pattern resembles a rotary-embedding
    # backward, but that is inferred from shape manipulation only — confirm.
    T72 = fd.ops.slice(T48, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 1024, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T88 = fd.ops.slice(T49, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 1024, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T89 = fd.ops.sum(T56, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T95 = fd.ops.broadcast_in_dim(T5, shape=[1, 1, 1024, 128], broadcast_dims=[0, 2, 3])
    T111 = fd.ops.slice(T48, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 1024, 128], strides=[1, 1, 1, 1], manual_normalization=0)
    T112 = fd.ops.neg(T72)
    T128 = fd.ops.slice(T49, start_indices=[0, 0, 0, 64], end_indices=[1, 8, 1024, 128], strides=[1, 1, 1, 1], manual_normalization=0)
    T129 = fd.ops.neg(T88)
    T136 = fd.ops.broadcast_in_dim(T89, shape=[1, 8, 1, 1024, 128], broadcast_dims=[1, 3, 4])
    T142 = fd.ops.broadcast_in_dim(T95, shape=[1, 32, 1024, 128], broadcast_dims=[0, 1, 2, 3])
    S143 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T153 = fd.ops.pad(T111, [0, 64, 0, 0, 0, 0, 0, 0], S143)
    S154 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T164 = fd.ops.pad(T112, [64, 0, 0, 0, 0, 0, 0, 0], S154)
    T170 = fd.ops.broadcast_in_dim(T95, shape=[1, 8, 1024, 128], broadcast_dims=[0, 1, 2, 3])
    S171 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T181 = fd.ops.pad(T128, [0, 64, 0, 0, 0, 0, 0, 0], S171)
    S182 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T192 = fd.ops.pad(T129, [64, 0, 0, 0, 0, 0, 0, 0], S182)
    T193 = fd.ops.sum(T136, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T194 = fd.ops.mul(T142, T3)
    T195 = fd.ops.add(T164, T153)
    T196 = fd.ops.mul(T170, T41)
    T197 = fd.ops.add(T192, T181)
    T203 = fd.ops.broadcast_in_dim(T193, shape=[1, 8, 1024, 128], broadcast_dims=[1, 2, 3])
    T204 = fd.ops.add(T195, T194)
    T205 = fd.ops.add(T197, T196)
    T206 = fd.ops.add(T6, T203)
    # Swap head and sequence dims, then collapse to 2-D [seq, features].
    T207 = fd.ops.permute(T204, dims=[0, 2, 1, 3])
    T208 = fd.ops.permute(T205, dims=[0, 2, 1, 3])
    T209 = fd.ops.permute(T206, dims=[0, 2, 1, 3])
    T214 = fd.ops.reshape(T207, new_shape=[1, 1024, 4096])
    T219 = fd.ops.reshape(T208, new_shape=[1, 1024, 1024])
    T224 = fd.ops.reshape(T209, new_shape=[1, 1024, 1024])
    T228 = fd.ops.reshape(T214, new_shape=[1024, 4096])
    T232 = fd.ops.reshape(T219, new_shape=[1024, 1024])
    T236 = fd.ops.reshape(T224, new_shape=[1024, 1024])
    T237 = fd.ops.permute(T228, dims=[1, 0])
    T238 = fd.ops.permute(T232, dims=[1, 0])
    T239 = fd.ops.permute(T236, dims=[1, 0])
    # Each flattened result is emitted alongside its transpose
    # (T236/T239, T232/T238, T228/T237).
    fd.add_output(T236)
    fd.add_output(T239)
    fd.add_output(T232)
    fd.add_output(T238)
    fd.add_output(T228)
    fd.add_output(T237)

# Build the fusion and execute it with inputs whose strides match the
# stride_order declarations above (as_strided produces the permuted
# layouts; make_tensor produces the default-contiguous ones).
with FusionDefinition() as fd:
    nvfuser_fusion_id22(fd)

inputs = [
    torch.randn(4194304, dtype=torch.float32, device='cuda:0').as_strided((1, 32, 1024, 128), (4194304, 128, 4096, 1)),
    torch.randn(131072, dtype=torch.float32, device='cuda:0').as_strided((1, 1024, 128), (131072, 1, 1024)),
    torch.testing.make_tensor((1, 8, 1024, 128), dtype=torch.float32, device='cuda:0'),
    torch.randn(4194304, dtype=torch.float32, device='cuda:0').as_strided((1, 32, 1024, 128), (4194304, 128, 4096, 1)),
    torch.randn(4194304, dtype=torch.float32, device='cuda:0').as_strided((1, 32, 1024, 128), (4194304, 128, 4096, 1)),
    torch.randn(131072, dtype=torch.float32, device='cuda:0').as_strided((1, 1024, 128), (131072, 1, 1024)),
    torch.testing.make_tensor((1, 8, 1024, 128), dtype=torch.float32, device='cuda:0'),
]
# Executing the fusion is what triggers the reported INTERNAL_ASSERT_FAILED.
fd.execute(inputs)

Failing CI - https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=220387&view=logs&j=3f274fac-2e11-54ca-487e-194c91f3ae9f&t=244491d3-5bd5-5b27-6d81-66bb4c7264ae&l=375

CI Log - ci_log.txt

@naoyam
Copy link
Collaborator

naoyam commented Nov 22, 2024

There are some recent bug fixes that may be related. Could you try the latest version?

@kshitij12345
Copy link
Author

Updating to the latest version worked locally. I will wait for the thunder CI to pick up the update and then close the issue once CI is green. Thank you!!

@kshitij12345
Copy link
Author

On latest version, I am seeing an internal assert (this is using an internal image)

# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git45ed7c1
# nvfuser version: 0.2.23+git8546b62
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    """Auto-generated nvFuser fusion definition hitting an internal assert
    on nvfuser 0.2.23+git8546b62 (see issue comment).

    Takes four contiguous float32 CUDA inputs, applies slice / reshape /
    permute / cat / elementwise ops — including two zero-width slices
    (T222, T239) that are concatenated back in — and registers six
    tensors as outputs.
    """
    T0 = fd.define_tensor(shape=[128, 4], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[128, 4], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[5, 5, 288], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[5, 5, 1024], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    # Take the first 5 rows of the [128, 4] inputs.
    T13 = fd.ops.slice(T0, start_indices=[0, 0], end_indices=[5, 4], strides=[1, 1], manual_normalization=0)
    T23 = fd.ops.slice(T1, start_indices=[0, 0], end_indices=[5, 4], strides=[1, 1], manual_normalization=0)
    # Split T2's last dim into (4, 18, 4), then carve the 18-wide axis
    # into 16/1/1 pieces (T50/T69/T88).
    T30 = fd.ops.reshape(T2, new_shape=[5, 5, 4, 18, 4])
    T31 = fd.ops.permute(T30, dims=[0, 2, 3, 1, 4])
    T50 = fd.ops.slice(T31, start_indices=[0, 0, 0, 0, 0], end_indices=[5, 4, 16, 5, 4], strides=[1, 1, 1, 1, 1], manual_normalization=0)
    T69 = fd.ops.slice(T31, start_indices=[0, 0, 16, 0, 0], end_indices=[5, 4, 17, 5, 4], strides=[1, 1, 1, 1, 1], manual_normalization=0)
    T88 = fd.ops.slice(T31, start_indices=[0, 0, 17, 0, 0], end_indices=[5, 4, 18, 5, 4], strides=[1, 1, 1, 1, 1], manual_normalization=0)
    # Broadcast the size-1 slices back up to 16 along dim 2.
    T95 = fd.ops.broadcast_in_dim(T69, shape=[5, 4, 16, 5, 4], broadcast_dims=[0, 1, 2, 3, 4])
    T102 = fd.ops.broadcast_in_dim(T88, shape=[5, 4, 16, 5, 4], broadcast_dims=[0, 1, 2, 3, 4])
    T108 = fd.ops.reshape(T50, new_shape=[5, 64, 5, 4])
    T114 = fd.ops.reshape(T95, new_shape=[5, 64, 5, 4])
    T120 = fd.ops.reshape(T102, new_shape=[5, 64, 5, 4])
    # Rotate halves of the last dim (slice / neg / cat) and combine with
    # the broadcast T13/T23 factors. NOTE(review): resembles rotary
    # embeddings, inferred from shapes only — confirm against the caller.
    T136 = fd.ops.slice(T108, start_indices=[0, 0, 0, 0], end_indices=[5, 64, 5, 2], strides=[1, 1, 1, 1], manual_normalization=0)
    T152 = fd.ops.slice(T108, start_indices=[0, 0, 0, 2], end_indices=[5, 64, 5, 4], strides=[1, 1, 1, 1], manual_normalization=0)
    T153 = fd.ops.neg(T152)
    T154 = fd.ops.cat([T153, T136], dim=-1, manual_padding=0)
    T160 = fd.ops.broadcast_in_dim(T13, shape=[5, 64, 5, 4], broadcast_dims=[2, 3])
    T161 = fd.ops.mul(T108, T160)
    T167 = fd.ops.broadcast_in_dim(T23, shape=[5, 64, 5, 4], broadcast_dims=[2, 3])
    T168 = fd.ops.mul(T154, T167)
    T169 = fd.ops.add(T161, T168)
    T185 = fd.ops.slice(T114, start_indices=[0, 0, 0, 0], end_indices=[5, 64, 5, 2], strides=[1, 1, 1, 1], manual_normalization=0)
    T201 = fd.ops.slice(T114, start_indices=[0, 0, 0, 2], end_indices=[5, 64, 5, 4], strides=[1, 1, 1, 1], manual_normalization=0)
    T202 = fd.ops.neg(T201)
    T203 = fd.ops.cat([T202, T185], dim=-1, manual_padding=0)
    T204 = fd.ops.mul(T114, T160)
    T205 = fd.ops.mul(T203, T167)
    T206 = fd.ops.add(T204, T205)
    # Zero-width slices (end index 0 on the last dim) concatenated back on.
    # NOTE(review): these degenerate, size-0 concat operands look like the
    # most suspicious construct for the internal assert — unconfirmed.
    T222 = fd.ops.slice(T108, start_indices=[0, 0, 0, 0], end_indices=[5, 64, 5, 0], strides=[1, 1, 1, 1], manual_normalization=0)
    T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0)
    T239 = fd.ops.slice(T114, start_indices=[0, 0, 0, 0], end_indices=[5, 64, 5, 0], strides=[1, 1, 1, 1], manual_normalization=0)
    T240 = fd.ops.cat([T206, T239], dim=-1, manual_padding=0)
    S241 = fd.define_scalar(0.707107, dtype=DataType.Double)
    T242 = fd.ops.mul(T223, S241)
    T243 = fd.ops.permute(T240, dims=[0, 1, 3, 2])
    S244 = fd.define_scalar(0.707107, dtype=DataType.Double)
    T245 = fd.ops.mul(T243, S244)
    # GELU-style tail on T3: 0.5 * x * (1 + erf(x / sqrt(2))).
    S246 = fd.define_scalar(1.41421, dtype=DataType.Double)
    S247 = fd.ops.reciprocal(S246)
    T248 = fd.ops.mul(T3, S247)
    T249 = fd.ops.erf(T248)
    S250 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T251 = fd.ops.mul(S250, T249)
    S252 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T253 = fd.ops.add(S252, T251)
    T254 = fd.ops.mul(T3, T253)
    fd.add_output(T120)
    fd.add_output(T160)
    fd.add_output(T167)
    fd.add_output(T242)
    fd.add_output(T245)
    fd.add_output(T254)

# Build the second fusion and run it; all inputs are default-contiguous
# float32 CUDA tensors matching the define_tensor shapes above.
with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.testing.make_tensor((128, 4), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((128, 4), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((5, 5, 288), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((5, 5, 1024), dtype=torch.float32, device='cuda:0'),
]
# Executing the fusion triggers the internal assert on the reporter's setup.
fd.execute(inputs)

faillure_log.log

CI Failure Log - https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=220909&view=logs&j=2840892e-91ab-5245-da62-77ec9923516a&t=444f4171-6797-5730-4229-41427ed3bdc9&l=15340

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants