-
Notifications
You must be signed in to change notification settings - Fork 12.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Func] Delete DecomposeCallGraphTypes.cpp
#117424
base: users/matthias-springer/1n_pattern
Are you sure you want to change the base?
[mlir][Func] Delete DecomposeCallGraphTypes.cpp
#117424
Conversation
Apply suggestions from code review Co-authored-by: Markus Böck <[email protected]> address comments [WIP] 1:N conversion pattern update test cases Update mlir/lib/Transforms/Utils/DialectConversion.cpp Co-authored-by: Markus Böck <[email protected]> Update mlir/lib/Transforms/Utils/DialectConversion.cpp Co-authored-by: Markus Böck <[email protected]> address comments rollback unresolved materializations properly
@llvm/pr-subscribers-mlir-func @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) Changes
Note for LLVM integration: If you are using Full diff: https://github.com/llvm/llvm-project/pull/117424.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
deleted file mode 100644
index 1be406bf3adf92..00000000000000
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ /dev/null
@@ -1,34 +0,0 @@
-//===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Conversion patterns for decomposing types along call graph edges. That is,
-// decomposing types for calls, returns, and function args.
-//
-// TODO: Make this handle dialect-defined functions, calls, and returns.
-// Currently, the generic interfaces aren't sophisticated enough for the
-// types of mutations that we are doing here.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-#define MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-
-#include "mlir/Transforms/DialectConversion.h"
-#include <optional>
-
-namespace mlir {
-
-/// Populates the patterns needed to drive the conversion process for
-/// decomposing call graph types with the given `TypeConverter`.
-void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
- const TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index f8fb1f436a95b1..6384d25ee70273 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_dialect_library(MLIRFuncTransforms
- DecomposeCallGraphTypes.cpp
DuplicateFunctionElimination.cpp
FuncConversions.cpp
OneToNFuncConversions.cpp
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
deleted file mode 100644
index 03be00328bda33..00000000000000
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ /dev/null
@@ -1,136 +0,0 @@
-//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-
-using namespace mlir;
-using namespace mlir::func;
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForFuncArgs
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand function arguments according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForFuncArgs
- : public OpConversionPattern<func::FuncOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- auto functionType = op.getFunctionType();
-
- // Convert function arguments using the provided TypeConverter.
- TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
- for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
- SmallVector<Type, 2> decomposedTypes;
- if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
- return failure();
- if (!decomposedTypes.empty())
- conversion.addInputs(argType.index(), decomposedTypes);
- }
-
- // If the SignatureConversion doesn't apply, bail out.
- if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
- &conversion)))
- return failure();
-
- // Update the signature of the function.
- SmallVector<Type, 2> newResultTypes;
- if (failed(typeConverter->convertTypes(functionType.getResults(),
- newResultTypes)))
- return failure();
- rewriter.modifyOpInPlace(op, [&] {
- op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
- newResultTypes));
- });
- return success();
- }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForReturnOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand return operands according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForReturnOp
- : public OpConversionPattern<ReturnOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- SmallVector<Value, 2> newOperands;
- for (ValueRange operand : adaptor.getOperands())
- llvm::append_range(newOperands, operand);
- rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
- return success();
- }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForCallOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand call op operands and results according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
-
- // Create the operands list of the new `CallOp`.
- SmallVector<Value, 2> newOperands;
- for (ValueRange operand : adaptor.getOperands())
- llvm::append_range(newOperands, operand);
-
- // Create the new result types for the new `CallOp` and track the number of
- // replacement types for each original op result.
- SmallVector<Type, 2> newResultTypes;
- SmallVector<unsigned> expandedResultSizes;
- for (Type resultType : op.getResultTypes()) {
- unsigned oldSize = newResultTypes.size();
- if (failed(typeConverter->convertType(resultType, newResultTypes)))
- return failure();
- expandedResultSizes.push_back(newResultTypes.size() - oldSize);
- }
-
- CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
- newResultTypes, newOperands);
-
- // Build a replacement value for each result to replace its uses.
- SmallVector<ValueRange> replacedValues;
- replacedValues.reserve(op.getNumResults());
- unsigned startIdx = 0;
- for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
- ValueRange repl =
- newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
- replacedValues.push_back(repl);
- startIdx += expandedResultSizes[i];
- }
- rewriter.replaceOpWithMultiple(op, replacedValues);
- return success();
- }
-};
-} // namespace
-
-void mlir::populateDecomposeCallGraphTypesPatterns(
- MLIRContext *context, const TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns
- .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
- DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
-}
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index 9e7759bef6d8fd..d531960aa285d1 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -124,12 +124,9 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
using OpConversionPattern<ReturnOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- // For a return, all operands go to the results of the parent, so
- // rewrite them all.
- rewriter.modifyOpInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.replaceOpWithNewOp<ReturnOp>(op, flattenValues(adaptor.getOperands()));
return success();
}
};
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index de511c58ae6ee0..15c8bac61e38b0 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -9,7 +9,7 @@
#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -142,7 +142,9 @@ struct TestDecomposeCallGraphTypes
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildDecomposeTuple);
- populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
if (failed(applyPartialConversion(module, target, std::move(patterns))))
return signalPassFailure();
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
f8514c8
to
5d6e8e4
Compare
5d6e8e4
to
4e4a5c8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! LGTM!
47e321a
to
76a8541
Compare
76a8541
to
712c819
Compare
DecomposeCallGraphTypes.cpp
was a workaround around missing 1:N support in the dialect conversion. Now that 1:N support was added, the workaround can be deleted. The test remains in place, as an example for how to write such a transformation with the dialect conversion framework.Note for LLVM integration: If you are using
DecomposeCallGraphTypes.cpp
, switch to the patterns that are used inTestDecomposeCallGraphTypes.cpp
.