aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r--llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp336
1 files changed, 230 insertions, 106 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index f1e1359255bd..17594b98c5bc 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -46,6 +46,8 @@
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"
+#include <cmath>
+
using namespace llvm;
using namespace PatternMatch;
@@ -80,6 +82,9 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
"Use row-major layout")));
+static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
+ cl::init(false));
+
/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
@@ -88,6 +93,39 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
return cast<DILocalScope>(Scope)->getSubprogram();
}
+/// Erase \p V from \p BB and move \p II forward to avoid invalidating
+/// iterators.
+static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
+ BasicBlock &BB) {
+ auto *Inst = cast<Instruction>(V);
+ // Still used, don't erase.
+ if (!Inst->use_empty())
+ return;
+ if (II != BB.rend() && Inst == &*II)
+ ++II;
+ Inst->eraseFromParent();
+}
+
+/// Return true if V is a splat of a value (which is used when multiplying a
+/// matrix with a scalar).
+static bool isSplat(Value *V) {
+ if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
+ return SV->isZeroEltSplat();
+ return false;
+}
+
+/// Match any mul operation (fp or integer).
+template <typename LTy, typename RTy>
+auto m_AnyMul(const LTy &L, const RTy &R) {
+ return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
+}
+
+/// Match any add operation (fp or integer).
+template <typename LTy, typename RTy>
+auto m_AnyAdd(const LTy &L, const RTy &R) {
+ return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
+}
+
namespace {
// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
@@ -384,6 +422,9 @@ class LowerMatrixIntrinsics {
return NumColumns;
return NumRows;
}
+
+ /// Returns the transposed shape.
+ ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};
/// Maps instructions to their shape information. The shape information
@@ -437,10 +478,10 @@ public:
/// Return the estimated number of vector ops required for an operation on
/// \p VT * N.
unsigned getNumOps(Type *ST, unsigned N) {
- return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
+ return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
double(TTI.getRegisterBitWidth(
TargetTransformInfo::RGK_FixedWidthVector)
- .getFixedSize()));
+ .getFixedValue()));
}
/// Return the set of vectors that a matrix value is lowered to.
@@ -684,115 +725,198 @@ public:
return NewWorkList;
}
- /// Try moving transposes in order to fold them away or into multiplies.
- void optimizeTransposes() {
- auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
- // We need to remove Old from the ShapeMap otherwise RAUW will replace it
- // with New. We should only add New it it supportsShapeInfo so we insert
- // it conditionally instead.
- auto S = ShapeMap.find(&Old);
- if (S != ShapeMap.end()) {
- ShapeMap.erase(S);
- if (supportsShapeInfo(New))
- ShapeMap.insert({New, S->second});
- }
- Old.replaceAllUsesWith(New);
- };
+ /// (Op0 op Op1)^T -> Op0^T op Op1^T
+ /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
+ /// them on both sides of \p Operation.
+ Instruction *distributeTransposes(
+ Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
+ MatrixBuilder &Builder,
+ function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
+ Operation) {
+ Value *T0 = Builder.CreateMatrixTranspose(
+ Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
+ // We are being run after shape prop, add shape for newly created
+ // instructions so that we lower them later.
+ setShapeInfo(T0, Shape0.t());
+ Value *T1 = Builder.CreateMatrixTranspose(
+ Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
+ setShapeInfo(T1, Shape1.t());
+ return Operation(T0, Shape0.t(), T1, Shape1.t());
+ }
- // First sink all transposes inside matmuls, hoping that we end up with NN,
- // NT or TN variants.
- for (BasicBlock &BB : reverse(Func)) {
- for (auto II = BB.rbegin(); II != BB.rend();) {
- Instruction &I = *II;
- // We may remove II. By default continue on the next/prev instruction.
- ++II;
- // If we were to erase II, move again.
- auto EraseFromParent = [&II, &BB](Value *V) {
- auto *Inst = cast<Instruction>(V);
- if (Inst->use_empty()) {
- if (II != BB.rend() && Inst == &*II) {
- ++II;
- }
- Inst->eraseFromParent();
- }
- };
+ void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
+ // We need to remove Old from the ShapeMap otherwise RAUW will replace it
+ // with New. We should only add New if it supportsShapeInfo so we insert
+ // it conditionally instead.
+ auto S = ShapeMap.find(&Old);
+ if (S != ShapeMap.end()) {
+ ShapeMap.erase(S);
+ if (supportsShapeInfo(New))
+ ShapeMap.insert({New, S->second});
+ }
+ Old.replaceAllUsesWith(New);
+ }
- // If we're creating a new instruction, continue from there.
- Instruction *NewInst = nullptr;
+ /// Sink a top-level transpose inside matmuls and adds.
+ /// This creates and erases instructions as needed, and returns the newly
+ /// created instruction while updating the iterator to avoid invalidation. If
+ /// this returns nullptr, no new instruction was created.
+ Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
+ BasicBlock &BB = *I.getParent();
+ IRBuilder<> IB(&I);
+ MatrixBuilder Builder(IB);
- IRBuilder<> IB(&I);
- MatrixBuilder Builder(IB);
+ Value *TA, *TAMA, *TAMB;
+ ConstantInt *R, *K, *C;
+ if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
+ m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
+ return nullptr;
- Value *TA, *TAMA, *TAMB;
- ConstantInt *R, *K, *C;
- if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
+ // Transpose of a transpose is a nop
+ Value *TATA;
+ if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
+ updateShapeAndReplaceAllUsesWith(I, TATA);
+ eraseFromParentAndMove(&I, II, BB);
+ eraseFromParentAndMove(TA, II, BB);
+ return nullptr;
+ }
- // Transpose of a transpose is a nop
- Value *TATA;
- if (match(TA,
- m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
- ReplaceAllUsesWith(I, TATA);
- EraseFromParent(&I);
- EraseFromParent(TA);
- }
+ // k^T -> k
+ if (isSplat(TA)) {
+ updateShapeAndReplaceAllUsesWith(I, TA);
+ eraseFromParentAndMove(&I, II, BB);
+ return nullptr;
+ }
- // (A * B)^t -> B^t * A^t
- // RxK KxC CxK KxR
- else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
- m_ConstantInt(K), m_ConstantInt(C)))) {
- Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
- C->getZExtValue(),
- TAMB->getName() + "_t");
- // We are being run after shape prop, add shape for newly created
- // instructions so that we lower them later.
- setShapeInfo(T0, {C, K});
- Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
- K->getZExtValue(),
- TAMA->getName() + "_t");
- setShapeInfo(T1, {K, R});
- NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
- K->getZExtValue(),
- R->getZExtValue(), "mmul");
- ReplaceAllUsesWith(I, NewInst);
- EraseFromParent(&I);
- EraseFromParent(TA);
- }
- }
+ // (A * B)^t -> B^t * A^t
+ // RxK KxC CxK KxR
+ if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
+ m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
+ m_ConstantInt(K), m_ConstantInt(C)))) {
+ auto NewInst = distributeTransposes(
+ TAMB, {K, C}, TAMA, {R, K}, Builder,
+ [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+ return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
+ Shape0.NumColumns,
+ Shape1.NumColumns, "mmul");
+ });
+ updateShapeAndReplaceAllUsesWith(I, NewInst);
+ eraseFromParentAndMove(&I, II, BB);
+ eraseFromParentAndMove(TA, II, BB);
+ return NewInst;
+ }
+
+ // Same as above, but with a mul, which occurs when multiplied
+ // with a scalar.
+ // (A * k)^t -> A^t * k
+ // R x C RxC
+ if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
+ (isSplat(TAMA) || isSplat(TAMB))) {
+ IRBuilder<> LocalBuilder(&I);
+ // We know that the transposed operand is of shape RxC.
+ // And when multiplied with a scalar, the shape is preserved.
+ auto NewInst = distributeTransposes(
+ TAMA, {R, C}, TAMB, {R, C}, Builder,
+ [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+ bool IsFP = I.getType()->isFPOrFPVectorTy();
+ auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
+ : LocalBuilder.CreateMul(T0, T1, "mmul");
+ auto *Result = cast<Instruction>(Mul);
+ setShapeInfo(Result, Shape0);
+ return Result;
+ });
+ updateShapeAndReplaceAllUsesWith(I, NewInst);
+ eraseFromParentAndMove(&I, II, BB);
+ eraseFromParentAndMove(TA, II, BB);
+ return NewInst;
+ }
+
+ // (A + B)^t -> A^t + B^t
+ // RxC RxC CxR CxR
+ if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
+ IRBuilder<> LocalBuilder(&I);
+ auto NewInst = distributeTransposes(
+ TAMA, {R, C}, TAMB, {R, C}, Builder,
+ [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+ auto *FAdd =
+ cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd"));
+ setShapeInfo(FAdd, Shape0);
+ return FAdd;
+ });
+ updateShapeAndReplaceAllUsesWith(I, NewInst);
+ eraseFromParentAndMove(&I, II, BB);
+ eraseFromParentAndMove(TA, II, BB);
+ return NewInst;
+ }
+
+ return nullptr;
+ }
+
+ void liftTranspose(Instruction &I) {
+ // Erase dead Instructions after lifting transposes from binops.
+ auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
+ if (T.use_empty())
+ T.eraseFromParent();
+ if (A->use_empty())
+ cast<Instruction>(A)->eraseFromParent();
+ if (A != B && B->use_empty())
+ cast<Instruction>(B)->eraseFromParent();
+ };
+
+ Value *A, *B, *AT, *BT;
+ ConstantInt *R, *K, *C;
+ // A^t * B^t -> (B * A)^t
+ if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
+ m_Value(A), m_Value(B), m_ConstantInt(R),
+ m_ConstantInt(K), m_ConstantInt(C))) &&
+ match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
+ match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
+ IRBuilder<> IB(&I);
+ MatrixBuilder Builder(IB);
+ Value *M = Builder.CreateMatrixMultiply(
+ BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
+ setShapeInfo(M, {C, R});
+ Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
+ R->getZExtValue());
+ updateShapeAndReplaceAllUsesWith(I, NewInst);
+ CleanupBinOp(I, A, B);
+ }
+ // A^t + B^t -> (A + B)^t
+ else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
+ match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
+ m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
+ match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
+ m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
+ IRBuilder<> Builder(&I);
+ Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
+ setShapeInfo(Add, {C, R});
+ MatrixBuilder MBuilder(Builder);
+ Instruction *NewInst = MBuilder.CreateMatrixTranspose(
+ Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
+ updateShapeAndReplaceAllUsesWith(I, NewInst);
+ CleanupBinOp(I, A, B);
+ }
+ }
- // If we replaced I with a new instruction, continue from there.
- if (NewInst)
+ /// Try moving transposes in order to fold them away or into multiplies.
+ void optimizeTransposes() {
+ // First sink all transposes inside matmuls and adds, hoping that we end up
+ // with NN, NT or TN variants.
+ for (BasicBlock &BB : reverse(Func)) {
+ for (auto II = BB.rbegin(); II != BB.rend();) {
+ Instruction &I = *II;
+ // We may remove II. By default continue on the next/prev instruction.
+ ++II;
+ if (Instruction *NewInst = sinkTranspose(I, II))
II = std::next(BasicBlock::reverse_iterator(NewInst));
}
}
- // If we have a TT matmul, lift the transpose. We may be able to fold into
- // consuming multiply.
+ // If we have a TT matmul or a TT add, lift the transpose. We may be able
+ // to fold into consuming multiply or add.
for (BasicBlock &BB : Func) {
for (Instruction &I : llvm::make_early_inc_range(BB)) {
- Value *A, *B, *AT, *BT;
- ConstantInt *R, *K, *C;
- // A^t * B ^t -> (B * A)^t
- if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(A), m_Value(B), m_ConstantInt(R),
- m_ConstantInt(K), m_ConstantInt(C))) &&
- match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
- match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
- IRBuilder<> IB(&I);
- MatrixBuilder Builder(IB);
- Value *M = Builder.CreateMatrixMultiply(
- BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
- setShapeInfo(M, {C, R});
- Instruction *NewInst = Builder.CreateMatrixTranspose(
- M, C->getZExtValue(), R->getZExtValue());
- ReplaceAllUsesWith(I, NewInst);
- if (I.use_empty())
- I.eraseFromParent();
- if (A->use_empty())
- cast<Instruction>(A)->eraseFromParent();
- if (A != B && B->use_empty())
- cast<Instruction>(B)->eraseFromParent();
- }
+ liftTranspose(I);
}
}
}
@@ -832,10 +956,10 @@ public:
if (!isMinimal()) {
optimizeTransposes();
- LLVM_DEBUG({
+ if (PrintAfterTransposeOpt) {
dbgs() << "Dump after matrix transpose optimization:\n";
- Func.dump();
- });
+ Func.print(dbgs());
+ }
}
bool Changed = false;
@@ -1199,8 +1323,8 @@ public:
bool IsScalarMatrixTransposed, FastMathFlags FMF) {
const unsigned VF = std::max<unsigned>(
TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
- .getFixedSize() /
- Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
+ .getFixedValue() /
+ Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1U);
unsigned R = Result.getNumRows();
unsigned C = Result.getNumColumns();
@@ -1378,8 +1502,8 @@ public:
const unsigned VF = std::max<unsigned>(
TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
- .getFixedSize() /
- EltType->getPrimitiveSizeInBits().getFixedSize(),
+ .getFixedValue() /
+ EltType->getPrimitiveSizeInBits().getFixedValue(),
1U);
// Cost model for tiling
@@ -2160,7 +2284,7 @@ public:
// the inlinedAt chain. If the function does not have a DISubprogram, we
// only map them to the containing function.
MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
- for (auto &KV : Inst2Matrix) {
+ for (const auto &KV : Inst2Matrix) {
if (Func.getSubprogram()) {
auto *I = cast<Instruction>(KV.first);
DILocation *Context = I->getDebugLoc();