Fuse all attention related dispatches. #19175

Open

MaheshRavishankar opened this issue Nov 17, 2024 · 4 comments

@MaheshRavishankar (Contributor) commented Nov 17, 2024

Recording some things I found while looking at Llama around the attention ops. Here is the IR snippet I see just before dispatch regions are formed:

module {
  func.func @attention_fusion(%arg0: index, %arg1: tensor<?x64xcomplex<f32>>, %arg2: tensor<4x?x32x64x2xf16>, %arg3: tensor<4x?x8x64x2xf16>, %arg4: tensor<4x?x8x128xf16>, %arg5: tensor<?xi64>, %arg6: tensor<?xi64>, %arg7: tensor<4xi64>) -> tensor<4x?x32x128xf16> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %cst = arith.constant 0xFF800000 : f32
    %cst_0 = arith.constant 8.837890e-02 : f16
    %cst_1 = arith.constant 0.000000e+00 : f32
    %false = arith.constant false
    %0 = tensor.empty(%arg0) : tensor<4x32x?x64x2xf16>
    %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x64xcomplex<f32>>) outs(%0 : tensor<4x32x?x64x2xf16>) {
    ^bb0(%in: complex<f32>, %out: f16):
      %12 = linalg.index 0 : index
      %13 = linalg.index 1 : index
      %14 = linalg.index 2 : index
      %15 = linalg.index 3 : index
      %extracted = tensor.extract %arg2[%12, %13, %14, %15, %c0] : tensor<4x?x32x64x2xf16>
      %extracted_6 = tensor.extract %arg2[%12, %13, %14, %15, %c1] : tensor<4x?x32x64x2xf16>
      %16 = linalg.index 4 : index
      %17 = arith.extf %extracted : f16 to f32
      %18 = arith.extf %extracted_6 : f16 to f32
      %19 = complex.create %17, %18 : complex<f32>
      %20 = complex.mul %19, %in : complex<f32>
      %21 = complex.re %20 : complex<f32>
      %22 = complex.im %20 : complex<f32>
      %23 = arith.cmpi eq, %16, %c0 : index
      %24 = arith.select %23, %21, %22 : f32
      %25 = arith.truncf %24 : f32 to f16
      linalg.yield %25 : f16
    } -> tensor<4x32x?x64x2xf16>
    %2 = tensor.empty(%arg0) : tensor<4x8x4x?x128xf16>
    %3 = tensor.empty(%arg0) : tensor<4x8x4x?x64x2xf16>
    %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d1, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg3 : tensor<4x?x8x64x2xf16>) outs(%3 : tensor<4x8x4x?x64x2xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<4x8x4x?x64x2xf16>
    %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d1, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg4 : tensor<4x?x8x128xf16>) outs(%2 : tensor<4x8x4x?x128xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<4x8x4x?x128xf16>
    %6 = tensor.empty(%arg0, %arg0) : tensor<4x32x?x?xf16>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d2)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg5, %arg6, %arg7 : tensor<?xi64>, tensor<?xi64>, tensor<4xi64>) outs(%6 : tensor<4x32x?x?xf16>) {
    ^bb0(%in: i64, %in_6: i64, %in_7: i64, %out: f16):
      %12 = arith.cmpi sge, %in, %in_6 : i64
      %13 = linalg.index 3 : index
      %14 = arith.index_cast %13 : index to i64
      %15 = arith.cmpi sge, %14, %in_7 : i64
      %16 = arith.cmpi ne, %12, %false : i1
      %17 = arith.cmpi ne, %15, %false : i1
      %18 = arith.ori %16, %17 : i1
      %19 = arith.select %18, %cst, %cst_1 : f32
      %20 = arith.truncf %19 : f32 to f16
      linalg.yield %20 : f16
    } -> tensor<4x32x?x?xf16>
    %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3, 4]] : tensor<4x32x?x64x2xf16> into tensor<128x?x128xf16>
    %collapsed_2 = tensor.collapse_shape %4 [[0, 1, 2], [3], [4, 5]] : tensor<4x8x4x?x64x2xf16> into tensor<128x?x128xf16>
    %collapsed_3 = tensor.collapse_shape %5 [[0, 1, 2], [3], [4]] : tensor<4x8x4x?x128xf16> into tensor<128x?x128xf16>
    %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, %arg0, 128] : tensor<128x?x128xf16> into tensor<4x32x?x128xf16>
    %expanded_4 = tensor.expand_shape %collapsed_2 [[0, 1], [2], [3]] output_shape [4, 32, %arg0, 128] : tensor<128x?x128xf16> into tensor<4x32x?x128xf16>
    %expanded_5 = tensor.expand_shape %collapsed_3 [[0, 1], [2], [3]] output_shape [4, 32, %arg0, 128] : tensor<128x?x128xf16> into tensor<4x32x?x128xf16>
    %8 = tensor.empty(%arg0) : tensor<4x32x?x128xf16>
    %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]} ins(%expanded, %expanded_4, %expanded_5, %cst_0, %7 : tensor<4x32x?x128xf16>, tensor<4x32x?x128xf16>, tensor<4x32x?x128xf16>, f16, tensor<4x32x?x?xf16>) outs(%8 : tensor<4x32x?x128xf16>) {
    ^bb0(%arg8: f32):
      iree_linalg_ext.yield %arg8 : f32
    } -> tensor<4x32x?x128xf16>
    %10 = tensor.empty(%arg0) : tensor<4x?x32x128xf16>
    %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<4x32x?x128xf16>) outs(%10 : tensor<4x?x32x128xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<4x?x32x128xf16>
    return %11 : tensor<4x?x32x128xf16>
  }
}

Ideally, this entire sequence needs to be fused into a single dispatch.
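
For concreteness, here is a sketch of the desired end state, reusing the values from the snippet above: the attention and the trailing transpose land in a single flow.dispatch.region (ideally %1, %4, %5, %7 and the reshapes would be moved into the same region as well). This is only an illustration of the goal, not the exact output of dispatch region formation:

%fused = flow.dispatch.region -> (tensor<4x?x32x128xf16>{%arg0}) {
  // The attention op, exactly as in the snippet above.
  %att = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]} ins(%expanded, %expanded_4, %expanded_5, %cst_0, %7 : tensor<4x32x?x128xf16>, tensor<4x32x?x128xf16>, tensor<4x32x?x128xf16>, f16, tensor<4x32x?x?xf16>) outs(%8 : tensor<4x32x?x128xf16>) {
  ^bb0(%score: f32):
    iree_linalg_ext.yield %score : f32
  } -> tensor<4x32x?x128xf16>
  // The trailing transpose back to the 4x?x32x128 layout.
  %t = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%att : tensor<4x32x?x128xf16>) outs(%10 : tensor<4x?x32x128xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<4x?x32x128xf16>
  flow.return %t : tensor<4x?x32x128xf16>
}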

Looking further, there are a few issues that are problematic here:

  1. There need to be better folders for collapse_shape/expand_shape (see #19175 (comment) below). I think @IanWood1 already has a change for this. It would be good to verify that this case is handled.
  2. We don't fuse attention with producers (for the same reason we don't fuse matmul with producers), but we should be able to fuse with producers that are broadcasts/transposes, just as we already fuse broadcasts/transposes into matmuls. We should be able to add these patterns (see the sketch after this list).
  3. The main issue I see, though, is that the conversion from tensor<...xcomplex<f32>> into tensor<...x2xf16> seems to introduce some artifacts into the generated IR, which makes the fusion pretty involved.
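
For point 2, here is a minimal sketch of the kind of producer that should be fusable. The names (%v, %v_init) and shapes are hypothetical, not taken from the IR above: a pure transpose feeding the attention's V operand, mirroring what dispatch formation already does for matmul producers.

%v_t = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%v : tensor<?x128x128xf16>) outs(%v_init : tensor<128x?x128xf16>) {
^bb0(%in: f16, %out: f16):
  linalg.yield %in : f16
} -> tensor<128x?x128xf16>
// %v_t should end up in the same dispatch region as the iree_linalg_ext.attention
// that consumes it as the V operand, instead of getting its own dispatch.
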
@MaheshRavishankar (Contributor, Author) commented:

The first issue is this sequence:

    %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3, 4]] : tensor<4x32x?x64x2xf16> into tensor<128x?x128xf16>
    %collapsed_2 = tensor.collapse_shape %4 [[0, 1, 2], [3], [4, 5]] : tensor<4x8x4x?x64x2xf16> into tensor<128x?x128xf16>
    %collapsed_3 = tensor.collapse_shape %5 [[0, 1, 2], [3], [4]] : tensor<4x8x4x?x128xf16> into tensor<128x?x128xf16>
    %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, %arg0, 128] : tensor<128x?x128xf16> into tensor<4x32x?x128xf16>
    %expanded_4 = tensor.expand_shape %collapsed_2 [[0, 1], [2], [3]] output_shape [4, 32, %arg0, 128] : tensor<128x?x128xf16> into tensor<4x32x?x128xf16>
    %expanded_5 = tensor.expand_shape %collapsed_3 [[0, 1], [2], [3]] output_shape [4, 32, %arg0, 128] : tensor<128x?x128xf16> into tensor<4x32x?x128xf16>

These could be folded into:

%collapsed = tensor.collapse_shape %1[[0], [1], [2], [3, 4]] : tensor<4x32x?x64x2xf16> into tensor<4x32x?x128xf16>
%collapsed_2 = tensor.collapse_shape %4[[0], [1, 2], [3], [4, 5]] : tensor<4x8x4x?x64x2xf16> into tensor<4x32x?x128xf16>
%collapsed_3 = tensor.collapse_shape %5[[0], [1, 2], [3], [4]] : tensor<4x8x4x?x128xf16> into tensor<4x32x?x128xf16>

If this were done during the BubbleUpExpandShapes pass, the collapse_shape ops would propagate across the attention, allowing the attention to be fused with its producers:

module {
  func.func @attention_fusion(%arg0: index, %arg1: tensor<?x64xcomplex<f32>>, %arg2: tensor<4x?x32x64x2xf16>, %arg3: tensor<4x?x8x64x2xf16>, %arg4: tensor<4x?x8x128xf16>, %arg5: tensor<?xi64>, %arg6: tensor<?xi64>, %arg7: tensor<4xi64>) -> tensor<4x?x32x128xf16> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %cst = arith.constant 0xFF800000 : f32
    %cst_0 = arith.constant 8.837890e-02 : f16
    %cst_1 = arith.constant 0.000000e+00 : f32
    %false = arith.constant false
    %0 = tensor.empty(%arg0) : tensor<4x32x?x64x2xf16>
    %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x64xcomplex<f32>>) outs(%0 : tensor<4x32x?x64x2xf16>) {
    ^bb0(%in: complex<f32>, %out: f16):
      %12 = linalg.index 0 : index
      %13 = linalg.index 1 : index
      %14 = linalg.index 2 : index
      %15 = linalg.index 3 : index
      %extracted = tensor.extract %arg2[%12, %13, %14, %15, %c0] : tensor<4x?x32x64x2xf16>
      %extracted_6 = tensor.extract %arg2[%12, %13, %14, %15, %c1] : tensor<4x?x32x64x2xf16>
      %16 = linalg.index 4 : index
      %17 = arith.extf %extracted : f16 to f32
      %18 = arith.extf %extracted_6 : f16 to f32
      %19 = complex.create %17, %18 : complex<f32>
      %20 = complex.mul %19, %in : complex<f32>
      %21 = complex.re %20 : complex<f32>
      %22 = complex.im %20 : complex<f32>
      %23 = arith.cmpi eq, %16, %c0 : index
      %24 = arith.select %23, %21, %22 : f32
      %25 = arith.truncf %24 : f32 to f16
      linalg.yield %25 : f16
    } -> tensor<4x32x?x64x2xf16>
    %2 = tensor.empty(%arg0) : tensor<4x8x4x?x128xf16>
    %3 = tensor.empty(%arg0) : tensor<4x8x4x?x64x2xf16>
    %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d1, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg3 : tensor<4x?x8x64x2xf16>) outs(%3 : tensor<4x8x4x?x64x2xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<4x8x4x?x64x2xf16>
    %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d1, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg4 : tensor<4x?x8x128xf16>) outs(%2 : tensor<4x8x4x?x128xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<4x8x4x?x128xf16>
    %6 = tensor.empty(%arg0, %arg0) : tensor<4x32x?x?xf16>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d2)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg5, %arg6, %arg7 : tensor<?xi64>, tensor<?xi64>, tensor<4xi64>) outs(%6 : tensor<4x32x?x?xf16>) {
    ^bb0(%in: i64, %in_6: i64, %in_7: i64, %out: f16):
      %12 = arith.cmpi sge, %in, %in_6 : i64
      %13 = linalg.index 3 : index
      %14 = arith.index_cast %13 : index to i64
      %15 = arith.cmpi sge, %14, %in_7 : i64
      %16 = arith.cmpi ne, %12, %false : i1
      %17 = arith.cmpi ne, %15, %false : i1
      %18 = arith.ori %16, %17 : i1
      %19 = arith.select %18, %cst, %cst_1 : f32
      %20 = arith.truncf %19 : f32 to f16
      linalg.yield %20 : f16
    } -> tensor<4x32x?x?xf16>
    %collapsed = tensor.collapse_shape %1[[0], [1], [2], [3, 4]] : tensor<4x32x?x64x2xf16> into tensor<4x32x?x128xf16>
    %collapsed_2 = tensor.collapse_shape %4[[0], [1, 2], [3], [4, 5]] : tensor<4x8x4x?x64x2xf16> into tensor<4x32x?x128xf16>
    %collapsed_3 = tensor.collapse_shape %5[[0], [1, 2], [3], [4]] : tensor<4x8x4x?x128xf16> into tensor<4x32x?x128xf16>
    %8 = tensor.empty(%arg0) : tensor<4x32x?x128xf16>
    %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]} ins(%collapsed, %collapsed_2, %collapsed_3, %cst_0, %7 : tensor<4x32x?x128xf16>, tensor<4x32x?x128xf16>, tensor<4x32x?x128xf16>, f16, tensor<4x32x?x?xf16>) outs(%8 : tensor<4x32x?x128xf16>) {
    ^bb0(%arg8: f32):
      iree_linalg_ext.yield %arg8 : f32
    } -> tensor<4x32x?x128xf16>
    %10 = tensor.empty(%arg0) : tensor<4x?x32x128xf16>
    %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<4x32x?x128xf16>) outs(%10 : tensor<4x?x32x128xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<4x?x32x128xf16>
    return %11 : tensor<4x?x32x128xf16>
  }
}

@MaheshRavishankar (Contributor, Author) commented:

cc @Groverkss @IanWood1 @manupak

@Groverkss (Contributor) commented:

    %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x64xcomplex<f32>>) outs(%0 : tensor<4x32x?x64x2xf16>) {
    ^bb0(%in: complex<f32>, %out: f16):
      %12 = linalg.index 0 : index
      %13 = linalg.index 1 : index
      %14 = linalg.index 2 : index
      %15 = linalg.index 3 : index
      %extracted = tensor.extract %arg2[%12, %13, %14, %15, %c0] : tensor<4x?x32x64x2xf16>
      %extracted_6 = tensor.extract %arg2[%12, %13, %14, %15, %c1] : tensor<4x?x32x64x2xf16>
      %16 = linalg.index 4 : index
      %17 = arith.extf %extracted : f16 to f32
      %18 = arith.extf %extracted_6 : f16 to f32
      %19 = complex.create %17, %18 : complex<f32>
      %20 = complex.mul %19, %in : complex<f32>
      %21 = complex.re %20 : complex<f32>
      %22 = complex.im %20 : complex<f32>
      %23 = arith.cmpi eq, %16, %c0 : index
      %24 = arith.select %23, %21, %22 : f32
      %25 = arith.truncf %24 : f32 to f16
      linalg.yield %25 : f16
    } -> tensor<4x32x?x64x2xf16>

This linalg.generic should not have these tensor.extracts; instead, that data should be passed to the linalg.generic as a tensor<4x?x32x64x2xcomplex>.
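
As a sketch, the extract-free form could look like the generic below, assuming the interleaved pairs in %arg2 are first repacked into a complex-typed tensor %q : tensor<4x?x32x64xcomplex<f32>> (the name %q is hypothetical; %arg1 and %0 are the values from the IR above):

%c0 = arith.constant 0 : index
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg1, %q : tensor<?x64xcomplex<f32>>, tensor<4x?x32x64xcomplex<f32>>) outs(%0 : tensor<4x32x?x64x2xf16>) {
^bb0(%in: complex<f32>, %in_q: complex<f32>, %out: f16):
  %m = complex.mul %in_q, %in : complex<f32>
  %re = complex.re %m : complex<f32>
  %im = complex.im %m : complex<f32>
  %i4 = linalg.index 4 : index
  %is_re = arith.cmpi eq, %i4, %c0 : index
  %sel = arith.select %is_re, %re, %im : f32
  %r = arith.truncf %sel : f32 to f16
  linalg.yield %r : f16
} -> tensor<4x32x?x64x2xf16>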

@IanWood1 (Contributor) commented:

@Groverkss good point. Do we know where this op originates from? It looks like it's 3 linalg.generic ops fused together: (float to complex) + (mul) + (complex to float).
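
If that reading is right, the pre-fusion IR may have looked roughly like the three generics below. This is only a sketch under that assumption; %q, %table, and the *_init empties are hypothetical names standing in for the query tensor (%arg2 above), the rotary table (%arg1 above), and tensor.empty results:

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// (a) interleaved f16 pairs -> complex<f32> (this step needs the extracts)
%qc = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%qc_init : tensor<4x?x32x64xcomplex<f32>>) {
^bb0(%out: complex<f32>):
  %i0 = linalg.index 0 : index
  %i1 = linalg.index 1 : index
  %i2 = linalg.index 2 : index
  %i3 = linalg.index 3 : index
  %re16 = tensor.extract %q[%i0, %i1, %i2, %i3, %c0] : tensor<4x?x32x64x2xf16>
  %im16 = tensor.extract %q[%i0, %i1, %i2, %i3, %c1] : tensor<4x?x32x64x2xf16>
  %re = arith.extf %re16 : f16 to f32
  %im = arith.extf %im16 : f16 to f32
  %c = complex.create %re, %im : complex<f32>
  linalg.yield %c : complex<f32>
} -> tensor<4x?x32x64xcomplex<f32>>
// (b) multiply by the rotary table (with the transpose to 4x32x?x64)
%rot = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%qc, %table : tensor<4x?x32x64xcomplex<f32>>, tensor<?x64xcomplex<f32>>) outs(%rot_init : tensor<4x32x?x64xcomplex<f32>>) {
^bb0(%a: complex<f32>, %b: complex<f32>, %out: complex<f32>):
  %m = complex.mul %a, %b : complex<f32>
  linalg.yield %m : complex<f32>
} -> tensor<4x32x?x64xcomplex<f32>>
// (c) complex<f32> -> interleaved f16 pairs
%back = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%rot : tensor<4x32x?x64xcomplex<f32>>) outs(%back_init : tensor<4x32x?x64x2xf16>) {
^bb0(%in: complex<f32>, %out: f16):
  %re2 = complex.re %in : complex<f32>
  %im2 = complex.im %in : complex<f32>
  %i4 = linalg.index 4 : index
  %is_re = arith.cmpi eq, %i4, %c0 : index
  %sel = arith.select %is_re, %re2, %im2 : f32
  %f = arith.truncf %sel : f32 to f16
  linalg.yield %f : f16
} -> tensor<4x32x?x64x2xf16>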
