diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:04 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:11 +0000 |
| commit | e3b557809604d036af6e00c60f012c2025b59a5e (patch) | |
| tree | 8a11ba2269a3b669601e2fd41145b174008f4da8 /llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | |
| parent | 08e8dd7b9db7bb4a9de26d44c1cbfd24e869c014 (diff) | |
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 336 |
1 file changed, 230 insertions, 106 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index f1e1359255bd..17594b98c5bc 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -46,6 +46,8 @@ #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/MatrixUtils.h" +#include <cmath> + using namespace llvm; using namespace PatternMatch; @@ -80,6 +82,9 @@ static cl::opt<MatrixLayoutTy> MatrixLayout( clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout"))); +static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt", + cl::init(false)); + /// Helper function to either return Scope, if it is a subprogram or the /// attached subprogram for a local scope. static DISubprogram *getSubprogram(DIScope *Scope) { @@ -88,6 +93,39 @@ static DISubprogram *getSubprogram(DIScope *Scope) { return cast<DILocalScope>(Scope)->getSubprogram(); } +/// Erase \p V from \p BB and move \p II forward to avoid invalidating +/// iterators. +static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, + BasicBlock &BB) { + auto *Inst = cast<Instruction>(V); + // Still used, don't erase. + if (!Inst->use_empty()) + return; + if (II != BB.rend() && Inst == &*II) + ++II; + Inst->eraseFromParent(); +} + +/// Return true if V is a splat of a value (which is used when multiplying a +/// matrix with a scalar). +static bool isSplat(Value *V) { + if (auto *SV = dyn_cast<ShuffleVectorInst>(V)) + return SV->isZeroEltSplat(); + return false; +} + +/// Match any mul operation (fp or integer). +template <typename LTy, typename RTy> +auto m_AnyMul(const LTy &L, const RTy &R) { + return m_CombineOr(m_Mul(L, R), m_FMul(L, R)); +} + +/// Match any add operation (fp or integer).
+template <typename LTy, typename RTy> +auto m_AnyAdd(const LTy &L, const RTy &R) { + return m_CombineOr(m_Add(L, R), m_FAdd(L, R)); +} + namespace { // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute @@ -384,6 +422,9 @@ class LowerMatrixIntrinsics { return NumColumns; return NumRows; } + + /// Returns the transposed shape. + ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } }; /// Maps instructions to their shape information. The shape information @@ -437,10 +478,10 @@ public: /// Return the estimated number of vector ops required for an operation on /// \p VT * N. unsigned getNumOps(Type *ST, unsigned N) { - return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / + return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() / double(TTI.getRegisterBitWidth( TargetTransformInfo::RGK_FixedWidthVector) - .getFixedSize())); + .getFixedValue())); } /// Return the set of vectors that a matrix value is lowered to. @@ -684,115 +725,198 @@ public: return NewWorkList; } - /// Try moving transposes in order to fold them away or into multiplies. - void optimizeTransposes() { - auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) { - // We need to remove Old from the ShapeMap otherwise RAUW will replace it - // with New. We should only add New it it supportsShapeInfo so we insert - // it conditionally instead. - auto S = ShapeMap.find(&Old); - if (S != ShapeMap.end()) { - ShapeMap.erase(S); - if (supportsShapeInfo(New)) - ShapeMap.insert({New, S->second}); - } - Old.replaceAllUsesWith(New); - }; + /// (Op0 op Op1)^T -> Op0^T op Op1^T + /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use + /// them on both sides of \p Operation. 
+ Instruction *distributeTransposes( + Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1, + MatrixBuilder &Builder, + function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)> + Operation) { + Value *T0 = Builder.CreateMatrixTranspose( + Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t"); + // We are being run after shape prop, add shape for newly created + // instructions so that we lower them later. + setShapeInfo(T0, Shape0.t()); + Value *T1 = Builder.CreateMatrixTranspose( + Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t"); + setShapeInfo(T1, Shape1.t()); + return Operation(T0, Shape0.t(), T1, Shape1.t()); + } - // First sink all transposes inside matmuls, hoping that we end up with NN, - // NT or TN variants. - for (BasicBlock &BB : reverse(Func)) { - for (auto II = BB.rbegin(); II != BB.rend();) { - Instruction &I = *II; - // We may remove II. By default continue on the next/prev instruction. - ++II; - // If we were to erase II, move again. - auto EraseFromParent = [&II, &BB](Value *V) { - auto *Inst = cast<Instruction>(V); - if (Inst->use_empty()) { - if (II != BB.rend() && Inst == &*II) { - ++II; - } - Inst->eraseFromParent(); - } - }; + void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) { + // We need to remove Old from the ShapeMap otherwise RAUW will replace it + // with New. We should only add New if it supportsShapeInfo so we insert + // it conditionally instead. + auto S = ShapeMap.find(&Old); + if (S != ShapeMap.end()) { + ShapeMap.erase(S); + if (supportsShapeInfo(New)) + ShapeMap.insert({New, S->second}); + } + Old.replaceAllUsesWith(New); + } - // If we're creating a new instruction, continue from there. - Instruction *NewInst = nullptr; + /// Sink a top-level transpose inside matmuls and adds. + /// This creates and erases instructions as needed, and returns the newly + /// created instruction while updating the iterator to avoid invalidation.
If + /// this returns nullptr, no new instruction was created. + Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) { + BasicBlock &BB = *I.getParent(); + IRBuilder<> IB(&I); + MatrixBuilder Builder(IB); - IRBuilder<> IB(&I); - MatrixBuilder Builder(IB); + Value *TA, *TAMA, *TAMB; + ConstantInt *R, *K, *C; + if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>( + m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) + return nullptr; - Value *TA, *TAMA, *TAMB; - ConstantInt *R, *K, *C; - if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) { + // Transpose of a transpose is a nop + Value *TATA; + if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { + updateShapeAndReplaceAllUsesWith(I, TATA); + eraseFromParentAndMove(&I, II, BB); + eraseFromParentAndMove(TA, II, BB); + return nullptr; + } - // Transpose of a transpose is a nop - Value *TATA; - if (match(TA, - m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { - ReplaceAllUsesWith(I, TATA); - EraseFromParent(&I); - EraseFromParent(TA); - } + // k^T -> k + if (isSplat(TA)) { + updateShapeAndReplaceAllUsesWith(I, TA); + eraseFromParentAndMove(&I, II, BB); + return nullptr; + } - // (A * B)^t -> B^t * A^t - // RxK KxC CxK KxR - else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( - m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), - m_ConstantInt(K), m_ConstantInt(C)))) { - Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(), - C->getZExtValue(), - TAMB->getName() + "_t"); - // We are being run after shape prop, add shape for newly created - // instructions so that we lower them later. 
- setShapeInfo(T0, {C, K}); - Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(), - K->getZExtValue(), - TAMA->getName() + "_t"); - setShapeInfo(T1, {K, R}); - NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(), - K->getZExtValue(), - R->getZExtValue(), "mmul"); - ReplaceAllUsesWith(I, NewInst); - EraseFromParent(&I); - EraseFromParent(TA); - } - } + // (A * B)^t -> B^t * A^t + // RxK KxC CxK KxR + if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( + m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), + m_ConstantInt(K), m_ConstantInt(C)))) { + auto NewInst = distributeTransposes( + TAMB, {K, C}, TAMA, {R, K}, Builder, + [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { + return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows, + Shape0.NumColumns, + Shape1.NumColumns, "mmul"); + }); + updateShapeAndReplaceAllUsesWith(I, NewInst); + eraseFromParentAndMove(&I, II, BB); + eraseFromParentAndMove(TA, II, BB); + return NewInst; + } + + // Same as above, but with a mul, which occurs when multiplied + // with a scalar. + // (A * k)^t -> A^t * k + // R x C RxC + if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) && + (isSplat(TAMA) || isSplat(TAMB))) { + IRBuilder<> LocalBuilder(&I); + // We know that the transposed operand is of shape RxC. + // And when multiplied with a scalar, the shape is preserved. + auto NewInst = distributeTransposes( + TAMA, {R, C}, TAMB, {R, C}, Builder, + [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { + bool IsFP = I.getType()->isFPOrFPVectorTy(); + auto *Mul = IsFP ?
LocalBuilder.CreateFMul(T0, T1, "mmul") + : LocalBuilder.CreateMul(T0, T1, "mmul"); + auto *Result = cast<Instruction>(Mul); + setShapeInfo(Result, Shape0); + return Result; + }); + updateShapeAndReplaceAllUsesWith(I, NewInst); + eraseFromParentAndMove(&I, II, BB); + eraseFromParentAndMove(TA, II, BB); + return NewInst; + } + + // (A + B)^t -> A^t + B^t + // RxC RxC CxR CxR + if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) { + IRBuilder<> LocalBuilder(&I); + auto NewInst = distributeTransposes( + TAMA, {R, C}, TAMB, {R, C}, Builder, + [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { + auto *FAdd = + cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd")); + setShapeInfo(FAdd, Shape0); + return FAdd; + }); + updateShapeAndReplaceAllUsesWith(I, NewInst); + eraseFromParentAndMove(&I, II, BB); + eraseFromParentAndMove(TA, II, BB); + return NewInst; + } + + return nullptr; + } + + void liftTranspose(Instruction &I) { + // Erase dead Instructions after lifting transposes from binops. 
+ auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) { + if (T.use_empty()) + T.eraseFromParent(); + if (A->use_empty()) + cast<Instruction>(A)->eraseFromParent(); + if (A != B && B->use_empty()) + cast<Instruction>(B)->eraseFromParent(); + }; + + Value *A, *B, *AT, *BT; + ConstantInt *R, *K, *C; + // A^t * B ^t -> (B * A)^t + if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( + m_Value(A), m_Value(B), m_ConstantInt(R), + m_ConstantInt(K), m_ConstantInt(C))) && + match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && + match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { + IRBuilder<> IB(&I); + MatrixBuilder Builder(IB); + Value *M = Builder.CreateMatrixMultiply( + BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); + setShapeInfo(M, {C, R}); + Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(), + R->getZExtValue()); + updateShapeAndReplaceAllUsesWith(I, NewInst); + CleanupBinOp(I, A, B); + } + // A^t + B ^t -> (A + B)^t + else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) && + match(A, m_Intrinsic<Intrinsic::matrix_transpose>( + m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) && + match(B, m_Intrinsic<Intrinsic::matrix_transpose>( + m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) { + IRBuilder<> Builder(&I); + Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd")); + setShapeInfo(Add, {C, R}); + MatrixBuilder MBuilder(Builder); + Instruction *NewInst = MBuilder.CreateMatrixTranspose( + Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t"); + updateShapeAndReplaceAllUsesWith(I, NewInst); + CleanupBinOp(I, A, B); + } + } - // If we replaced I with a new instruction, continue from there. - if (NewInst) + /// Try moving transposes in order to fold them away or into multiplies. + void optimizeTransposes() { + // First sink all transposes inside matmuls and adds, hoping that we end up + // with NN, NT or TN variants. 
+ for (BasicBlock &BB : reverse(Func)) { + for (auto II = BB.rbegin(); II != BB.rend();) { + Instruction &I = *II; + // We may remove II. By default continue on the next/prev instruction. + ++II; + if (Instruction *NewInst = sinkTranspose(I, II)) II = std::next(BasicBlock::reverse_iterator(NewInst)); } } - // If we have a TT matmul, lift the transpose. We may be able to fold into - // consuming multiply. + // If we have a TT matmul or a TT add, lift the transpose. We may be able + // to fold into consuming multiply or add. for (BasicBlock &BB : Func) { for (Instruction &I : llvm::make_early_inc_range(BB)) { - Value *A, *B, *AT, *BT; - ConstantInt *R, *K, *C; - // A^t * B ^t -> (B * A)^t - if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( - m_Value(A), m_Value(B), m_ConstantInt(R), - m_ConstantInt(K), m_ConstantInt(C))) && - match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && - match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { - IRBuilder<> IB(&I); - MatrixBuilder Builder(IB); - Value *M = Builder.CreateMatrixMultiply( - BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); - setShapeInfo(M, {C, R}); - Instruction *NewInst = Builder.CreateMatrixTranspose( - M, C->getZExtValue(), R->getZExtValue()); - ReplaceAllUsesWith(I, NewInst); - if (I.use_empty()) - I.eraseFromParent(); - if (A->use_empty()) - cast<Instruction>(A)->eraseFromParent(); - if (A != B && B->use_empty()) - cast<Instruction>(B)->eraseFromParent(); - } + liftTranspose(I); } } } @@ -832,10 +956,10 @@ public: if (!isMinimal()) { optimizeTransposes(); - LLVM_DEBUG({ + if (PrintAfterTransposeOpt) { dbgs() << "Dump after matrix transpose optimization:\n"; - Func.dump(); - }); + Func.print(dbgs()); + } } bool Changed = false; @@ -1199,8 +1323,8 @@ public: bool IsScalarMatrixTransposed, FastMathFlags FMF) { const unsigned VF = std::max<unsigned>( TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) - .getFixedSize() / - 
Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), + .getFixedValue() / + Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(), 1U); unsigned R = Result.getNumRows(); unsigned C = Result.getNumColumns(); @@ -1378,8 +1502,8 @@ public: const unsigned VF = std::max<unsigned>( TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) - .getFixedSize() / - EltType->getPrimitiveSizeInBits().getFixedSize(), + .getFixedValue() / + EltType->getPrimitiveSizeInBits().getFixedValue(), 1U); // Cost model for tiling @@ -2160,7 +2284,7 @@ public: // the inlinedAt chain. If the function does not have a DISubprogram, we // only map them to the containing function. MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs; - for (auto &KV : Inst2Matrix) { + for (const auto &KV : Inst2Matrix) { if (Func.getSubprogram()) { auto *I = cast<Instruction>(KV.first); DILocation *Context = I->getDebugLoc(); |
