author    | Dimitry Andric <dim@FreeBSD.org> | 2024-07-27 23:34:35 +0000
committer | Dimitry Andric <dim@FreeBSD.org> | 2024-10-23 18:26:01 +0000
commit    | 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583 (patch)
tree      | 6cf5ab1f05330c6773b1f3f64799d56a9c7a1faa /contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
parent    | 6b9f7133aba44189d9625c352bc2c2a59baf18ef (diff)
parent    | ac9a064cb179f3425b310fa2847f8764ac970a4d (diff)
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 342
1 file changed, 220 insertions(+), 122 deletions(-)
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 72b9db1e73d7..6a681fd93397 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/AliasAnalysis.h"
@@ -192,6 +193,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
   return VecStart;
 }
 
+namespace {
+struct ShapeInfo {
+  unsigned NumRows;
+  unsigned NumColumns;
+
+  bool IsColumnMajor;
+
+  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
+      : NumRows(NumRows), NumColumns(NumColumns),
+        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+
+  ShapeInfo(Value *NumRows, Value *NumColumns)
+      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
+                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}
+
+  bool operator==(const ShapeInfo &other) {
+    return NumRows == other.NumRows && NumColumns == other.NumColumns;
+  }
+  bool operator!=(const ShapeInfo &other) { return !(*this == other); }
+
+  /// Returns true if shape-information is defined, meaning both dimensions
+  /// are != 0.
+  operator bool() const {
+    assert(NumRows == 0 || NumColumns != 0);
+    return NumRows != 0;
+  }
+
+  unsigned getStride() const {
+    if (IsColumnMajor)
+      return NumRows;
+    return NumColumns;
+  }
+
+  unsigned getNumVectors() const {
+    if (IsColumnMajor)
+      return NumColumns;
+    return NumRows;
+  }
+
+  /// Returns the transposed shape.
+  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
+};
+} // namespace
+
+static bool isUniformShape(Value *V) {
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return true;
+
+  switch (I->getOpcode()) {
+  case Instruction::FAdd:
+  case Instruction::FSub:
+  case Instruction::FMul: // Scalar multiply.
+  case Instruction::FNeg:
+  case Instruction::Add:
+  case Instruction::Mul:
+  case Instruction::Sub:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Return the ShapeInfo for the result of \p I, it it can be determined.
+static std::optional<ShapeInfo>
+computeShapeInfoForInst(Instruction *I,
+                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {
+  Value *M;
+  Value *N;
+  Value *K;
+  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
+                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
+    return ShapeInfo(M, K);
+  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
+                                                        m_Value(N)))) {
+    // Flip dimensions.
+    return ShapeInfo(N, M);
+  }
+  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
+                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
+                   m_Value(N))))
+    return ShapeInfo(N, M);
+  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
+                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
+    return ShapeInfo(M, N);
+  Value *MatrixA;
+  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
+    auto OpShape = ShapeMap.find(MatrixA);
+    if (OpShape != ShapeMap.end())
+      return OpShape->second;
+  }
+
+  if (isUniformShape(I)) {
+    // Find the first operand that has a known shape and use that.
+    for (auto &Op : I->operands()) {
+      auto OpShape = ShapeMap.find(Op.get());
+      if (OpShape != ShapeMap.end())
+        return OpShape->second;
+    }
+  }
+  return std::nullopt;
+}
+
 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
 ///
 /// Currently, the lowering for each matrix intrinsic is done as follows:
@@ -383,48 +487,6 @@ class LowerMatrixIntrinsics {
     }
   };
 
-  struct ShapeInfo {
-    unsigned NumRows;
-    unsigned NumColumns;
-
-    bool IsColumnMajor;
-
-    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
-        : NumRows(NumRows), NumColumns(NumColumns),
-          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
-
-    ShapeInfo(Value *NumRows, Value *NumColumns)
-        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
-                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}
-
-    bool operator==(const ShapeInfo &other) {
-      return NumRows == other.NumRows && NumColumns == other.NumColumns;
-    }
-    bool operator!=(const ShapeInfo &other) { return !(*this == other); }
-
-    /// Returns true if shape-information is defined, meaning both dimensions
-    /// are != 0.
-    operator bool() const {
-      assert(NumRows == 0 || NumColumns != 0);
-      return NumRows != 0;
-    }
-
-    unsigned getStride() const {
-      if (IsColumnMajor)
-        return NumRows;
-      return NumColumns;
-    }
-
-    unsigned getNumVectors() const {
-      if (IsColumnMajor)
-        return NumColumns;
-      return NumRows;
-    }
-
-    /// Returns the transposed shape.
-    ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
-  };
-
   /// Maps instructions to their shape information. The shape information
   /// describes the shape to be used while lowering. This matches the shape of
   /// the result value of the instruction, with the only exceptions being store
@@ -459,7 +521,7 @@ public:
   LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                         AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                         OptimizationRemarkEmitter *ORE)
-      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
+      : Func(F), DL(F.getDataLayout()), TTI(TTI), AA(AA), DT(DT),
         LI(LI), ORE(ORE) {}
 
   unsigned getNumOps(Type *VT) {
@@ -554,25 +616,6 @@ public:
     return true;
   }
 
-  bool isUniformShape(Value *V) {
-    Instruction *I = dyn_cast<Instruction>(V);
-    if (!I)
-      return true;
-
-    switch (I->getOpcode()) {
-    case Instruction::FAdd:
-    case Instruction::FSub:
-    case Instruction::FMul: // Scalar multiply.
-    case Instruction::FNeg:
-    case Instruction::Add:
-    case Instruction::Mul:
-    case Instruction::Sub:
-      return true;
-    default:
-      return false;
-    }
-  }
-
   /// Returns true if shape information can be used for \p V. The supported
   /// instructions must match the instructions that can be lowered by this pass.
   bool supportsShapeInfo(Value *V) {
@@ -610,43 +653,8 @@ public:
 
       // New entry, set the value and insert operands
      bool Propagate = false;
-
-      Value *MatrixA;
-      Value *MatrixB;
-      Value *M;
-      Value *N;
-      Value *K;
-      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
-                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
-                          m_Value(N), m_Value(K)))) {
-        Propagate = setShapeInfo(Inst, {M, K});
-      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
-                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
-        // Flip dimensions.
-        Propagate = setShapeInfo(Inst, {N, M});
-      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
-                                 m_Value(MatrixA), m_Value(), m_Value(),
-                                 m_Value(), m_Value(M), m_Value(N)))) {
-        Propagate = setShapeInfo(Inst, {N, M});
-      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
-                                 m_Value(), m_Value(), m_Value(), m_Value(M),
-                                 m_Value(N)))) {
-        Propagate = setShapeInfo(Inst, {M, N});
-      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
-        auto OpShape = ShapeMap.find(MatrixA);
-        if (OpShape != ShapeMap.end())
-          setShapeInfo(Inst, OpShape->second);
-        continue;
-      } else if (isUniformShape(Inst)) {
-        // Find the first operand that has a known shape and use that.
-        for (auto &Op : Inst->operands()) {
-          auto OpShape = ShapeMap.find(Op.get());
-          if (OpShape != ShapeMap.end()) {
-            Propagate |= setShapeInfo(Inst, OpShape->second);
-            break;
-          }
-        }
-      }
+      if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
+        Propagate = setShapeInfo(Inst, *SI);
 
       if (Propagate) {
         NewWorkList.push_back(Inst);
@@ -891,20 +899,28 @@ public:
       updateShapeAndReplaceAllUsesWith(I, NewInst);
       CleanupBinOp(I, A, B);
     }
-    // A^t + B ^t -> (A + B)^t
+    // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
+    // the shape of the second transpose is different, there's a shape conflict
+    // which gets resolved by picking the shape of the first operand.
     else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
              match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
-                          m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
+                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
       IRBuilder<> Builder(&I);
-      Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
-      setShapeInfo(Add, {C, R});
+      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
+      setShapeInfo(Add, {R, C});
       MatrixBuilder MBuilder(Builder);
       Instruction *NewInst = MBuilder.CreateMatrixTranspose(
-          Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
+          Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
       updateShapeAndReplaceAllUsesWith(I, NewInst);
+      assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
+                 computeShapeInfoForInst(&I, ShapeMap) &&
+             "Shape of new instruction doesn't match original shape.");
       CleanupBinOp(I, A, B);
+      assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
+                 ShapeMap[Add] &&
+             "Shape of updated addition doesn't match cached shape.");
     }
   }
 
@@ -975,12 +991,15 @@ public:
     bool Changed = false;
     SmallVector<CallInst *, 16> MaybeFusableInsts;
     SmallVector<Instruction *, 16> MatrixInsts;
+    SmallVector<IntrinsicInst *, 16> LifetimeEnds;
 
     // First, collect all instructions with shape information and candidates for
     // fusion (currently only matrix multiplies).
     ReversePostOrderTraversal<Function *> RPOT(&Func);
     for (auto *BB : RPOT)
       for (Instruction &I : *BB) {
+        if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
+          LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
         if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
         if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
@@ -995,7 +1014,7 @@ public:
 
     // Third, try to fuse candidates.
     for (CallInst *CI : MaybeFusableInsts)
-      LowerMatrixMultiplyFused(CI, FusedInsts);
+      LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
 
     Changed = !FusedInsts.empty();
 
@@ -1332,8 +1351,8 @@ public:
     if (!IsIntVec && !FMF.allowReassoc())
       return;
 
-    auto CanBeFlattened = [this](Value *Op) {
-      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
+    auto CanBeFlattened = [](Value *Op) {
+      if (match(Op, m_BinOp()))
         return true;
       return match(
           Op, m_OneUse(m_CombineOr(
@@ -1346,6 +1365,9 @@ public:
     // the returned cost is < 0, the argument is cheaper to use in the
     // dot-product lowering.
     auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
+      if (ShapeMap.find(Op) == ShapeMap.end())
+        return InstructionCost::getInvalid();
+
       if (!isa<Instruction>(Op))
         return InstructionCost(0);
 
@@ -1356,7 +1378,7 @@ public:
         InstructionCost EmbedCost(0);
         // Roughly estimate the cost for embedding the columns into a vector.
         for (unsigned I = 1; I < N; ++I)
-          EmbedCost -=
+          EmbedCost +=
              TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
                                 std::nullopt, TTI::TCK_RecipThroughput);
         return EmbedCost;
@@ -1378,7 +1400,7 @@ public:
       // vector.
       InstructionCost EmbedCost(0);
       for (unsigned I = 1; I < N; ++I)
-        EmbedCost +=
+        EmbedCost -=
            TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
                               std::nullopt, TTI::TCK_RecipThroughput);
       return EmbedCost;
@@ -1391,7 +1413,29 @@ public:
       return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
              N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
     };
-    auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
+
+    // Iterate over LHS and operations feeding LHS and check if it is profitable
+    // to flatten the visited ops. For each op, we compute the difference
+    // between the flattened and matrix versions.
+    SmallPtrSet<Value *, 4> Seen;
+    SmallVector<Value *> WorkList;
+    SmallVector<Value *> ToFlatten;
+    WorkList.push_back(LHS);
+    InstructionCost LHSCost(0);
+    while (!WorkList.empty()) {
+      Value *Op = WorkList.pop_back_val();
+      if (!Seen.insert(Op).second)
+        continue;
+
+      InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
+      if (OpCost + LHSCost >= LHSCost)
+        continue;
+
+      LHSCost += OpCost;
+      ToFlatten.push_back(Op);
+      if (auto *I = dyn_cast<Instruction>(Op))
+        WorkList.append(I->op_begin(), I->op_end());
+    }
 
     // We compare the costs of a vector.reduce.add to sequential add.
     int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
@@ -1412,16 +1456,16 @@ public:
     FusedInsts.insert(MatMul);
     IRBuilder<> Builder(MatMul);
     auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
-                       this](Value *Op) -> Value * {
+                       this](Value *Op) {
       // Matmul must be the only user of loads because we don't use LowerLoad
       // for row vectors (LowerLoad results in scalar loads and shufflevectors
      // instead of single vector load).
       if (!CanBeFlattened(Op))
-        return Op;
+        return;
 
       if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
         ShapeMap[Op] = ShapeMap[Op].t();
-        return Op;
+        return;
       }
 
       FusedInsts.insert(cast<Instruction>(Op));
@@ -1432,16 +1476,19 @@ public:
         auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
         Op->replaceAllUsesWith(NewLoad);
         cast<Instruction>(Op)->eraseFromParent();
-        return NewLoad;
+        return;
       } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
                                m_Value(Arg)))) {
         ToRemove.push_back(cast<Instruction>(Op));
-        return Arg;
+        Op->replaceAllUsesWith(Arg);
+        return;
       }
-
-      return Op;
     };
-    LHS = FlattenArg(LHS);
+
+    for (auto *V : ToFlatten)
+      FlattenArg(V);
+
+    LHS = MatMul->getArgOperand(0);
 
     // Insert mul/fmul and llvm.vector.reduce.fadd
     Value *Mul =
@@ -1594,7 +1641,7 @@ public:
     IRBuilder<> Builder(MatMul);
     Check0->getTerminator()->eraseFromParent();
     Builder.SetInsertPoint(Check0);
-    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
+    Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
     Value *StoreBegin = Builder.CreatePtrToInt(
         const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
     Value *StoreEnd = Builder.CreateAdd(
@@ -1813,8 +1860,10 @@ public:
   ///
   /// Call finalizeLowering on lowered instructions. Instructions that are
   /// completely eliminated by fusion are added to \p FusedInsts.
-  void LowerMatrixMultiplyFused(CallInst *MatMul,
-                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
+  void
+  LowerMatrixMultiplyFused(CallInst *MatMul,
+                           SmallPtrSetImpl<Instruction *> &FusedInsts,
+                           SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
     if (!FuseMatrix || !DT)
       return;
 
@@ -1903,6 +1952,55 @@ public:
     for (Instruction *I : ToHoist)
       I->moveBefore(MatMul);
 
+    // Deal with lifetime.end calls that might be between Load0/Load1 and the
+    // store. To avoid introducing loads to dead objects (i.e. after the
+    // lifetime has been termined by @llvm.lifetime.end), either sink them
+    // after the store if in the same block, or remove the lifetime.end marker
+    // otherwise. This might pessimize further optimizations, by extending the
+    // lifetime of the object until the function returns, but should be
+    // conservatively correct.
+    MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
+    MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
+    BasicBlock *StoreParent = Store->getParent();
+    bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
+                                 LoadOp1->getParent() == StoreParent;
+    for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
+      IntrinsicInst *End = LifetimeEnds[Idx];
+      auto Inc = make_scope_exit([&Idx]() { Idx++; });
+      // If the lifetime.end is guaranteed to be before the loads or after the
+      // store, it won't interfere with fusion.
+      if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
+        continue;
+      if (DT->dominates(Store, End))
+        continue;
+      // If all fusable ops are in the same block and the lifetime.end is in a
+      // different block, it won't interfere with fusion.
+      if (FusableOpsInSameBlock && End->getParent() != StoreParent)
+        continue;
+
+      // If the loads don't alias the lifetime.end, it won't interfere with
+      // fusion.
+      MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
+      if (!EndLoc.Ptr)
+        continue;
+      if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
+        continue;
+
+      // If both lifetime.end and the store are in the same block, extend the
+      // lifetime until after the store, so the new lifetime covers the loads
+      // we introduce later.
+      if (End->getParent() == StoreParent) {
+        End->moveAfter(Store);
+        continue;
+      }
+
+      // Otherwise remove the conflicting lifetime.end marker.
+      ToRemove.push_back(End);
+      std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
+      LifetimeEnds.pop_back();
+      Inc.release();
+    }
+
     emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
     return;
   }
@@ -2364,7 +2462,7 @@ public:
     RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                     OptimizationRemarkEmitter &ORE, Function &Func)
         : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
-          DL(Func.getParent()->getDataLayout()) {}
+          DL(Func.getDataLayout()) {}
 
     /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
     /// instructions in Inst2Matrix returning void or without any users in
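The hunks above hoist ShapeInfo, isUniformShape and the new computeShapeInfoForInst out of the LowerMatrixIntrinsics class, so a single helper answers "what shape does this instruction produce" for both the propagation worklist and the new asserts in the transpose rewrite. The following is a minimal standalone sketch of that pattern, not LLVM code: Op, OpKind and the toy graph are illustrative stand-ins for the llvm.matrix.* intrinsics and the pass's ShapeMap.

```cpp
// Toy shape propagation: one helper derives the result shape of an operation
// from its own arguments or from shapes already recorded for its operands,
// mirroring how computeShapeInfoForInst() is used by the propagation loop.
#include <cstdio>
#include <map>
#include <optional>
#include <utility>
#include <vector>

struct ShapeInfo {
  unsigned NumRows = 0;
  unsigned NumColumns = 0;
};

enum class OpKind { Multiply, Transpose, ElementwiseAdd };

struct Op {
  OpKind Kind;
  std::vector<int> Operands;    // IDs of operand ops; -1 marks a leaf value.
  unsigned M = 0, N = 0, K = 0; // Static dimensions, used by Multiply only.
};

// Analogue of computeShapeInfoForInst: return the result shape of O, if it
// can be determined from the op itself or its already-shaped operands.
static std::optional<ShapeInfo>
computeShape(const Op &O, const std::map<int, ShapeInfo> &ShapeMap) {
  switch (O.Kind) {
  case OpKind::Multiply: // (M x N) * (N x K) yields an M x K result.
    return ShapeInfo{O.M, O.K};
  case OpKind::Transpose: { // Flip the operand's dimensions.
    auto It = ShapeMap.find(O.Operands[0]);
    if (It == ShapeMap.end())
      return std::nullopt;
    return ShapeInfo{It->second.NumColumns, It->second.NumRows};
  }
  case OpKind::ElementwiseAdd: // Use the first operand with a known shape.
    for (int Id : O.Operands) {
      auto It = ShapeMap.find(Id);
      if (It != ShapeMap.end())
        return It->second;
    }
    return std::nullopt;
  }
  return std::nullopt;
}

int main() {
  // op 0: multiply (2x3)*(3x4); op 1: transpose of op 0; op 2: add of op 1
  // and an unshaped leaf.
  std::vector<std::pair<int, Op>> Program = {
      {0, {OpKind::Multiply, {-1, -1}, 2, 3, 4}},
      {1, {OpKind::Transpose, {0}}},
      {2, {OpKind::ElementwiseAdd, {1, -1}}}};

  std::map<int, ShapeInfo> ShapeMap;
  for (auto &[Id, O] : Program)
    if (auto SI = computeShape(O, ShapeMap))
      ShapeMap[Id] = *SI; // Prints 2x4, 4x2, 4x2 for ops 0, 1, 2.

  for (auto &[Id, SI] : ShapeMap)
    std::printf("op %d: %u x %u\n", Id, SI.NumRows, SI.NumColumns);
  return 0;
}
```

Centralizing the shape query is what lets the rewritten A^t + B^t rule assert that the shape of the replacement instruction matches the shape the propagation logic would compute for the original, instead of duplicating the pattern matching in two places.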
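The changed dot-product lowering no longer costs only the immediate LHS of the multiply: it walks the values feeding the LHS and marks an operand for flattening only when its estimated cost delta is negative, i.e. when the flattened form looks cheaper than the matrix form. Below is a standalone sketch of that greedy walk only; Node and CostDelta are illustrative stand-ins, and the real pass derives the deltas from TTI cost queries rather than constants.

```cpp
// Greedy profitability walk: start from the LHS, visit each feeding value
// once, accept it only if it strictly lowers the accumulated cost, and only
// then continue into its operands (unprofitable ops are not expanded).
#include <cstdio>
#include <set>
#include <vector>

struct Node {
  int Id;
  int CostDelta;                // flattened cost - matrix cost (may be < 0).
  std::vector<Node *> Operands; // Values feeding this node.
};

int main() {
  Node Load{0, -2, {}};          // Cheaper when flattened.
  Node Transpose{1, -1, {&Load}};
  Node Expensive{2, +3, {&Transpose}};
  Node LHS{3, -1, {&Expensive}};

  std::set<int> Seen;
  std::vector<Node *> WorkList = {&LHS};
  std::vector<Node *> ToFlatten;
  int TotalCost = 0;

  while (!WorkList.empty()) {
    Node *Op = WorkList.back();
    WorkList.pop_back();
    if (!Seen.insert(Op->Id).second)
      continue;

    // Only accept ops that make the accumulated cost strictly smaller.
    if (TotalCost + Op->CostDelta >= TotalCost)
      continue;

    TotalCost += Op->CostDelta;
    ToFlatten.push_back(Op);
    for (Node *O : Op->Operands)
      WorkList.push_back(O);
  }

  std::printf("flattening %zu ops, estimated saving %d\n", ToFlatten.size(),
              -TotalCost);
  return 0;
}
```

With this toy graph only the LHS itself is flattened: the expensive op is rejected, so the load and transpose behind it are never visited, matching the pass's behaviour of not looking through operands of unprofitable ops.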
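The new LifetimeEnds loop in LowerMatrixMultiplyFused removes elements from the vector it is iterating by swapping the current element with the back and popping, while a scope-exit guard advances the index only when nothing was removed. A self-contained sketch of that idiom follows; the minimal ScopeExit type here is just a stand-in for llvm::make_scope_exit from llvm/ADT/ScopeExit.h, which the commit adds as an include.

```cpp
// Erase-while-iterating via swap-and-pop, with the index increment deferred
// to a scope-exit guard that is released whenever an element is removed.
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

struct ScopeExit {
  explicit ScopeExit(std::function<void()> F) : Fn(std::move(F)) {}
  ~ScopeExit() { if (Active) Fn(); }
  void release() { Active = false; } // Skip the deferred action.
  std::function<void()> Fn;
  bool Active = true;
};

int main() {
  std::vector<int> Values = {1, 2, 3, 4, 5, 6};

  for (unsigned Idx = 0; Idx != Values.size();) {
    ScopeExit Inc([&Idx] { ++Idx; });

    // Keep odd values; "continue" runs the guard and advances the index.
    if (Values[Idx] % 2 != 0)
      continue;

    // Remove even values: swap with the last element, shrink the vector, and
    // release the guard so the same index is examined again next iteration.
    std::swap(Values[Idx], Values.back());
    Values.pop_back();
    Inc.release();
  }

  for (int V : Values)
    std::printf("%d ", V); // Prints 1 5 3: order is not preserved.
  std::printf("\n");
  return 0;
}
```

In the pass itself Inc.release() is called right after LifetimeEnds.pop_back(), so the lifetime.end marker swapped into position Idx is examined on the next iteration rather than skipped.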