Fuse all attention related dispatches. #19175
First issue is this sequence ...
These could be folded into ...
If this was done during ...
```mlir
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x64xcomplex<f32>>) outs(%0 : tensor<4x32x?x64x2xf16>) {
^bb0(%in: complex<f32>, %out: f16):
%12 = linalg.index 0 : index
%13 = linalg.index 1 : index
%14 = linalg.index 2 : index
%15 = linalg.index 3 : index
%extracted = tensor.extract %arg2[%12, %13, %14, %15, %c0] : tensor<4x?x32x64x2xf16>
%extracted_6 = tensor.extract %arg2[%12, %13, %14, %15, %c1] : tensor<4x?x32x64x2xf16>
%16 = linalg.index 4 : index
%17 = arith.extf %extracted : f16 to f32
%18 = arith.extf %extracted_6 : f16 to f32
%19 = complex.create %17, %18 : complex<f32>
%20 = complex.mul %19, %in : complex<f32>
%21 = complex.re %20 : complex<f32>
%22 = complex.im %20 : complex<f32>
%23 = arith.cmpi eq, %16, %c0 : index
%24 = arith.select %23, %21, %22 : f32
%25 = arith.truncf %24 : f32 to f16
linalg.yield %25 : f16
} -> tensor<4x32x?x64x2xf16>
```

This linalg.generic should not have these tensor.extracts; instead, the extracted values should be passed to the linalg.generic as a tensor<4x?x32x64xcomplex<f32>>.
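For illustration, a rough sketch of what that suggested form could look like, assuming a hypothetical %arg2_complex : tensor<4x?x32x64xcomplex<f32>> operand that carries the real/imaginary halves of %arg2 (its indexing map is read off the tensor.extract indices above, and %c0 is the same index constant used in the original snippet):

```mlir
%1 = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>,
      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d3, d4)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
    ins(%arg1, %arg2_complex : tensor<?x64xcomplex<f32>>, tensor<4x?x32x64xcomplex<f32>>)
    outs(%0 : tensor<4x32x?x64x2xf16>) {
^bb0(%in: complex<f32>, %in_1: complex<f32>, %out: f16):
  // The complex value is now a regular operand, so no tensor.extract is needed.
  %idx4 = linalg.index 4 : index
  %prod = complex.mul %in_1, %in : complex<f32>
  %re = complex.re %prod : complex<f32>
  %im = complex.im %prod : complex<f32>
  // Interleave real/imag parts back into the trailing x2 dimension of the output.
  %is_re = arith.cmpi eq, %idx4, %c0 : index
  %sel = arith.select %is_re, %re, %im : f32
  %trunc = arith.truncf %sel : f32 to f16
  linalg.yield %trunc : f16
} -> tensor<4x32x?x64x2xf16>
```

Only the gather-style extracts are replaced by a regular operand here; the trailing x2 dimension of the f16 output is kept exactly as in the original op.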
@Groverkss good point. Do we know where this op originates from? It looks like it's 3 linalg.generic ops fused together: (float to complex) + (mul) + (complex to float).
Recording some things that I found while looking at llama around attention ops. Here is the IR snippet I see just before forming dispatch regions
Ideally this entire sequence needs to be fused into a single dispatch.
Looking further, there are a few issues that are problematic here:

1. The collapse_shape/expand_shape ops (see Fuse all attention related dispatches. #19175 (comment)). I think @IanWood1 already has a change for this. Would be good to verify that this case is handled; a sketch of the kind of reshape pair involved is included after this list.
2. The conversion of tensor<....xcomplex<f32>> into tensor<...x2xf16> seems to introduce some artifacts into the generated IR, which is making the fusion pretty involved.
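As referenced in point 1, here is a minimal sketch of the kind of collapse_shape/expand_shape pair that ends up sitting between the producers and the attention op. The shapes, the reassociation, and %producer are hypothetical and only illustrate the pattern that reshape folding/propagation needs to clean up before everything can land in one dispatch:

```mlir
// Hypothetical shapes: the producer result is collapsed and immediately
// re-expanded with the same reassociation before feeding the consumer.
// A pair like this can be folded away (or propagated across the surrounding
// elementwise ops) so that producer and consumer end up in the same dispatch.
%collapsed = tensor.collapse_shape %producer [[0, 1], [2], [3]]
    : tensor<4x32x128x64xf16> into tensor<128x128x64xf16>
%expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 128, 64]
    : tensor<128x128x64xf16> into tensor<4x32x128x64xf16>
```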