path: root/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
author     Dimitry Andric <dim@FreeBSD.org>  2024-07-27 23:34:35 +0000
committer  Dimitry Andric <dim@FreeBSD.org>  2024-10-23 18:26:01 +0000
commit     0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583 (patch)
tree       6cf5ab1f05330c6773b1f3f64799d56a9c7a1faa /contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
parent     6b9f7133aba44189d9625c352bc2c2a59baf18ef (diff)
parent     ac9a064cb179f3425b310fa2847f8764ac970a4d (diff)
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r--  contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp  |  342
1 file changed, 220 insertions(+), 122 deletions(-)
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 72b9db1e73d7..6a681fd93397 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -19,6 +19,7 @@
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
@@ -192,6 +193,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
return VecStart;
}
+namespace {
+struct ShapeInfo {
+ unsigned NumRows;
+ unsigned NumColumns;
+
+ bool IsColumnMajor;
+
+ ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
+ : NumRows(NumRows), NumColumns(NumColumns),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+
+ ShapeInfo(Value *NumRows, Value *NumColumns)
+ : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
+ cast<ConstantInt>(NumColumns)->getZExtValue()) {}
+
+ bool operator==(const ShapeInfo &other) {
+ return NumRows == other.NumRows && NumColumns == other.NumColumns;
+ }
+ bool operator!=(const ShapeInfo &other) { return !(*this == other); }
+
+ /// Returns true if shape-information is defined, meaning both dimensions
+ /// are != 0.
+ operator bool() const {
+ assert(NumRows == 0 || NumColumns != 0);
+ return NumRows != 0;
+ }
+
+ unsigned getStride() const {
+ if (IsColumnMajor)
+ return NumRows;
+ return NumColumns;
+ }
+
+ unsigned getNumVectors() const {
+ if (IsColumnMajor)
+ return NumColumns;
+ return NumRows;
+ }
+
+ /// Returns the transposed shape.
+ ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
+};
+} // namespace
+
+static bool isUniformShape(Value *V) {
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return true;
+
+ switch (I->getOpcode()) {
+ case Instruction::FAdd:
+ case Instruction::FSub:
+ case Instruction::FMul: // Scalar multiply.
+ case Instruction::FNeg:
+ case Instruction::Add:
+ case Instruction::Mul:
+ case Instruction::Sub:
+ return true;
+ default:
+ return false;
+ }
+}
+
+/// Return the ShapeInfo for the result of \p I, if it can be determined.
+static std::optional<ShapeInfo>
+computeShapeInfoForInst(Instruction *I,
+ const ValueMap<Value *, ShapeInfo> &ShapeMap) {
+ Value *M;
+ Value *N;
+ Value *K;
+ if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
+ m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
+ return ShapeInfo(M, K);
+ if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
+ m_Value(N)))) {
+ // Flip dimensions.
+ return ShapeInfo(N, M);
+ }
+ if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
+ m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
+ m_Value(N))))
+ return ShapeInfo(N, M);
+ if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
+ m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
+ return ShapeInfo(M, N);
+ Value *MatrixA;
+ if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
+ auto OpShape = ShapeMap.find(MatrixA);
+ if (OpShape != ShapeMap.end())
+ return OpShape->second;
+ }
+
+ if (isUniformShape(I)) {
+ // Find the first operand that has a known shape and use that.
+ for (auto &Op : I->operands()) {
+ auto OpShape = ShapeMap.find(Op.get());
+ if (OpShape != ShapeMap.end())
+ return OpShape->second;
+ }
+ }
+ return std::nullopt;
+}
+
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
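
For context on the layout handling in the new ShapeInfo struct: a column-major matrix is stored as NumColumns column vectors of NumRows elements each, so getStride() returns NumRows and getNumVectors() returns NumColumns, and the reverse holds for row-major. A minimal standalone sketch of that behavior, using a simplified stand-in struct rather than the real class (which reads the layout from the MatrixLayout command-line option):

#include <cassert>

// Simplified stand-in for the ShapeInfo struct added above (illustration only).
struct SimpleShape {
  unsigned NumRows;
  unsigned NumColumns;
  bool IsColumnMajor;

  // Column-major: NumColumns vectors of NumRows elements; row-major: the opposite.
  unsigned getStride() const { return IsColumnMajor ? NumRows : NumColumns; }
  unsigned getNumVectors() const { return IsColumnMajor ? NumColumns : NumRows; }
};

int main() {
  SimpleShape ColMajor{/*NumRows=*/4, /*NumColumns=*/2, /*IsColumnMajor=*/true};
  assert(ColMajor.getNumVectors() == 2 && ColMajor.getStride() == 4);

  SimpleShape RowMajor{/*NumRows=*/4, /*NumColumns=*/2, /*IsColumnMajor=*/false};
  assert(RowMajor.getNumVectors() == 4 && RowMajor.getStride() == 2);
  return 0;
}
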
@@ -383,48 +487,6 @@ class LowerMatrixIntrinsics {
}
};
- struct ShapeInfo {
- unsigned NumRows;
- unsigned NumColumns;
-
- bool IsColumnMajor;
-
- ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
- : NumRows(NumRows), NumColumns(NumColumns),
- IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
-
- ShapeInfo(Value *NumRows, Value *NumColumns)
- : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
- cast<ConstantInt>(NumColumns)->getZExtValue()) {}
-
- bool operator==(const ShapeInfo &other) {
- return NumRows == other.NumRows && NumColumns == other.NumColumns;
- }
- bool operator!=(const ShapeInfo &other) { return !(*this == other); }
-
- /// Returns true if shape-information is defined, meaning both dimensions
- /// are != 0.
- operator bool() const {
- assert(NumRows == 0 || NumColumns != 0);
- return NumRows != 0;
- }
-
- unsigned getStride() const {
- if (IsColumnMajor)
- return NumRows;
- return NumColumns;
- }
-
- unsigned getNumVectors() const {
- if (IsColumnMajor)
- return NumColumns;
- return NumRows;
- }
-
- /// Returns the transposed shape.
- ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
- };
-
/// Maps instructions to their shape information. The shape information
/// describes the shape to be used while lowering. This matches the shape of
/// the result value of the instruction, with the only exceptions being store
@@ -459,7 +521,7 @@ public:
LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
OptimizationRemarkEmitter *ORE)
- : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
+ : Func(F), DL(F.getDataLayout()), TTI(TTI), AA(AA), DT(DT),
LI(LI), ORE(ORE) {}
unsigned getNumOps(Type *VT) {
@@ -554,25 +616,6 @@ public:
return true;
}
- bool isUniformShape(Value *V) {
- Instruction *I = dyn_cast<Instruction>(V);
- if (!I)
- return true;
-
- switch (I->getOpcode()) {
- case Instruction::FAdd:
- case Instruction::FSub:
- case Instruction::FMul: // Scalar multiply.
- case Instruction::FNeg:
- case Instruction::Add:
- case Instruction::Mul:
- case Instruction::Sub:
- return true;
- default:
- return false;
- }
- }
-
/// Returns true if shape information can be used for \p V. The supported
/// instructions must match the instructions that can be lowered by this pass.
bool supportsShapeInfo(Value *V) {
@@ -610,43 +653,8 @@ public:
// New entry, set the value and insert operands
bool Propagate = false;
-
- Value *MatrixA;
- Value *MatrixB;
- Value *M;
- Value *N;
- Value *K;
- if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
- m_Value(N), m_Value(K)))) {
- Propagate = setShapeInfo(Inst, {M, K});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(MatrixA), m_Value(M), m_Value(N)))) {
- // Flip dimensions.
- Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
- m_Value(MatrixA), m_Value(), m_Value(),
- m_Value(), m_Value(M), m_Value(N)))) {
- Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
- m_Value(), m_Value(), m_Value(), m_Value(M),
- m_Value(N)))) {
- Propagate = setShapeInfo(Inst, {M, N});
- } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
- auto OpShape = ShapeMap.find(MatrixA);
- if (OpShape != ShapeMap.end())
- setShapeInfo(Inst, OpShape->second);
- continue;
- } else if (isUniformShape(Inst)) {
- // Find the first operand that has a known shape and use that.
- for (auto &Op : Inst->operands()) {
- auto OpShape = ShapeMap.find(Op.get());
- if (OpShape != ShapeMap.end()) {
- Propagate |= setShapeInfo(Inst, OpShape->second);
- break;
- }
- }
- }
+ if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
+ Propagate = setShapeInfo(Inst, *SI);
if (Propagate) {
NewWorkList.push_back(Inst);
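
With the per-instruction shape computation factored out into computeShapeInfoForInst, the propagation loop above reduces to: query the shape, record it, and re-enqueue the instruction so its users get visited. A stripped-down sketch of that forward-propagation pattern, using hypothetical plain-C++ types rather than LLVM's Instruction and ValueMap:

#include <map>
#include <optional>
#include <vector>

// Hypothetical stand-ins for Instruction and ShapeInfo, for illustration only.
struct Shape { unsigned Rows = 0, Cols = 0; };
struct Node {
  std::vector<Node *> Operands;
  std::vector<Node *> Users;
};

// Analogue of computeShapeInfoForInst for a "uniform shape" op: take the first
// operand whose shape is already known.
static std::optional<Shape> computeShape(Node *N,
                                         const std::map<Node *, Shape> &Shapes) {
  for (Node *Op : N->Operands) {
    auto It = Shapes.find(Op);
    if (It != Shapes.end())
      return It->second;
  }
  return std::nullopt;
}

// Forward propagation: when a node gains a shape, revisit its users.
static void propagate(std::vector<Node *> WorkList,
                      std::map<Node *, Shape> &Shapes) {
  while (!WorkList.empty()) {
    Node *N = WorkList.back();
    WorkList.pop_back();
    if (Shapes.count(N))
      continue;
    if (auto S = computeShape(N, Shapes)) {
      Shapes[N] = *S;
      for (Node *U : N->Users)
        WorkList.push_back(U);
    }
  }
}

int main() {
  Node A, B, C; // C = A + B; only A's shape is known up front.
  C.Operands = {&A, &B};
  A.Users = {&C};
  std::map<Node *, Shape> Shapes = {{&A, Shape{4, 2}}};
  propagate({&C}, Shapes);
  return Shapes.at(&C).Rows == 4 ? 0 : 1;
}
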
@@ -891,20 +899,28 @@ public:
updateShapeAndReplaceAllUsesWith(I, NewInst);
CleanupBinOp(I, A, B);
}
- // A^t + B ^t -> (A + B)^t
+  // A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose. If
+ // the shape of the second transpose is different, there's a shape conflict
+ // which gets resolved by picking the shape of the first operand.
else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
+ m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
IRBuilder<> Builder(&I);
- Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
- setShapeInfo(Add, {C, R});
+ auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
+ setShapeInfo(Add, {R, C});
MatrixBuilder MBuilder(Builder);
Instruction *NewInst = MBuilder.CreateMatrixTranspose(
- Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
+ Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
updateShapeAndReplaceAllUsesWith(I, NewInst);
+ assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
+ computeShapeInfoForInst(&I, ShapeMap) &&
+ "Shape of new instruction doesn't match original shape.");
CleanupBinOp(I, A, B);
+ assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
+ ShapeMap[Add] &&
+ "Shape of updated addition doesn't match cached shape.");
}
}
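
The rewrite above relies on the identity transpose(A) + transpose(B) == transpose(A + B), with the new code keeping the addition's shape {R, C} consistent with the first transpose's operands. A small standalone check of the identity on a 2x3 example (plain C++, not part of the pass):

#include <array>
#include <cassert>

using Mat2x3 = std::array<std::array<double, 3>, 2>;
using Mat3x2 = std::array<std::array<double, 2>, 3>;

static Mat3x2 transpose(const Mat2x3 &M) {
  Mat3x2 T{};
  for (int R = 0; R < 2; ++R)
    for (int C = 0; C < 3; ++C)
      T[C][R] = M[R][C];
  return T;
}

int main() {
  Mat2x3 A{{{1, 2, 3}, {4, 5, 6}}};
  Mat2x3 B{{{7, 8, 9}, {10, 11, 12}}};

  // Left-hand side of the rewrite: A^t + B^t.
  Mat3x2 LHS = transpose(A);
  Mat3x2 BT = transpose(B);
  for (int R = 0; R < 3; ++R)
    for (int C = 0; C < 2; ++C)
      LHS[R][C] += BT[R][C];

  // Right-hand side: (A + B)^t.
  Mat2x3 Sum = A;
  for (int R = 0; R < 2; ++R)
    for (int C = 0; C < 3; ++C)
      Sum[R][C] += B[R][C];
  Mat3x2 RHS = transpose(Sum);

  assert(LHS == RHS);
  return 0;
}
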
@@ -975,12 +991,15 @@ public:
bool Changed = false;
SmallVector<CallInst *, 16> MaybeFusableInsts;
SmallVector<Instruction *, 16> MatrixInsts;
+ SmallVector<IntrinsicInst *, 16> LifetimeEnds;
// First, collect all instructions with shape information and candidates for
// fusion (currently only matrix multiplies).
ReversePostOrderTraversal<Function *> RPOT(&Func);
for (auto *BB : RPOT)
for (Instruction &I : *BB) {
+ if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
+ LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
if (ShapeMap.find(&I) == ShapeMap.end())
continue;
if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
@@ -995,7 +1014,7 @@ public:
// Third, try to fuse candidates.
for (CallInst *CI : MaybeFusableInsts)
- LowerMatrixMultiplyFused(CI, FusedInsts);
+ LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
Changed = !FusedInsts.empty();
@@ -1332,8 +1351,8 @@ public:
if (!IsIntVec && !FMF.allowReassoc())
return;
- auto CanBeFlattened = [this](Value *Op) {
- if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
+ auto CanBeFlattened = [](Value *Op) {
+ if (match(Op, m_BinOp()))
return true;
return match(
Op, m_OneUse(m_CombineOr(
@@ -1346,6 +1365,9 @@ public:
// the returned cost is < 0, the argument is cheaper to use in the
// dot-product lowering.
auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
+ if (ShapeMap.find(Op) == ShapeMap.end())
+ return InstructionCost::getInvalid();
+
if (!isa<Instruction>(Op))
return InstructionCost(0);
@@ -1356,7 +1378,7 @@ public:
InstructionCost EmbedCost(0);
// Roughly estimate the cost for embedding the columns into a vector.
for (unsigned I = 1; I < N; ++I)
- EmbedCost -=
+ EmbedCost +=
TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
std::nullopt, TTI::TCK_RecipThroughput);
return EmbedCost;
@@ -1378,7 +1400,7 @@ public:
// vector.
InstructionCost EmbedCost(0);
for (unsigned I = 1; I < N; ++I)
- EmbedCost +=
+ EmbedCost -=
TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
std::nullopt, TTI::TCK_RecipThroughput);
return EmbedCost;
@@ -1391,7 +1413,29 @@ public:
return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
};
- auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
+
+ // Iterate over LHS and operations feeding LHS and check if it is profitable
+ // to flatten the visited ops. For each op, we compute the difference
+ // between the flattened and matrix versions.
+ SmallPtrSet<Value *, 4> Seen;
+ SmallVector<Value *> WorkList;
+ SmallVector<Value *> ToFlatten;
+ WorkList.push_back(LHS);
+ InstructionCost LHSCost(0);
+ while (!WorkList.empty()) {
+ Value *Op = WorkList.pop_back_val();
+ if (!Seen.insert(Op).second)
+ continue;
+
+ InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
+ if (OpCost + LHSCost >= LHSCost)
+ continue;
+
+ LHSCost += OpCost;
+ ToFlatten.push_back(Op);
+ if (auto *I = dyn_cast<Instruction>(Op))
+ WorkList.append(I->op_begin(), I->op_end());
+ }
// We compare the costs of a vector.reduce.add to sequential add.
int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
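
The loop above replaces the single GetCostForArg(LHS, ...) call with a greedy walk over the expression feeding LHS: an operand is flattened only if its cost delta keeps the running total from increasing, and only then are its own operands considered. A simplified sketch of that pattern, assuming hypothetical pre-computed per-node deltas instead of TTI cost queries:

#include <set>
#include <vector>

// Hypothetical expression node with a pre-computed flattening cost delta
// (negative means flattening this op is cheaper).
struct Expr {
  std::vector<Expr *> Operands;
  int FlattenDelta = 0;
};

// Greedy walk mirroring the loop above: an operand is flattened (and its own
// operands visited) only when doing so does not increase the running total.
static std::vector<Expr *> pickOpsToFlatten(Expr *Root) {
  std::set<Expr *> Seen;
  std::vector<Expr *> WorkList{Root};
  std::vector<Expr *> ToFlatten;
  int TotalDelta = 0;
  while (!WorkList.empty()) {
    Expr *Op = WorkList.back();
    WorkList.pop_back();
    if (!Seen.insert(Op).second)
      continue;
    if (TotalDelta + Op->FlattenDelta >= TotalDelta) // i.e. FlattenDelta >= 0
      continue;
    TotalDelta += Op->FlattenDelta;
    ToFlatten.push_back(Op);
    WorkList.insert(WorkList.end(), Op->Operands.begin(), Op->Operands.end());
  }
  return ToFlatten;
}

int main() {
  Expr Load{/*Operands=*/{}, /*FlattenDelta=*/-2}; // cheaper when flattened
  Expr Mul{/*Operands=*/{&Load}, /*FlattenDelta=*/1}; // not profitable; its
  Expr Add{/*Operands=*/{&Mul}, /*FlattenDelta=*/-1}; // operands stay unvisited
  return pickOpsToFlatten(&Add).size() == 1 ? 0 : 1;
}
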
@@ -1412,16 +1456,16 @@ public:
FusedInsts.insert(MatMul);
IRBuilder<> Builder(MatMul);
auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
- this](Value *Op) -> Value * {
+ this](Value *Op) {
// Matmul must be the only user of loads because we don't use LowerLoad
// for row vectors (LowerLoad results in scalar loads and shufflevectors
// instead of single vector load).
if (!CanBeFlattened(Op))
- return Op;
+ return;
if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
ShapeMap[Op] = ShapeMap[Op].t();
- return Op;
+ return;
}
FusedInsts.insert(cast<Instruction>(Op));
@@ -1432,16 +1476,19 @@ public:
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
Op->replaceAllUsesWith(NewLoad);
cast<Instruction>(Op)->eraseFromParent();
- return NewLoad;
+ return;
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(Arg)))) {
ToRemove.push_back(cast<Instruction>(Op));
- return Arg;
+ Op->replaceAllUsesWith(Arg);
+ return;
}
-
- return Op;
};
- LHS = FlattenArg(LHS);
+
+ for (auto *V : ToFlatten)
+ FlattenArg(V);
+
+ LHS = MatMul->getArgOperand(0);
// Insert mul/fmul and llvm.vector.reduce.fadd
Value *Mul =
@@ -1594,7 +1641,7 @@ public:
IRBuilder<> Builder(MatMul);
Check0->getTerminator()->eraseFromParent();
Builder.SetInsertPoint(Check0);
- Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
+ Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout());
Value *StoreBegin = Builder.CreatePtrToInt(
const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
Value *StoreEnd = Builder.CreateAdd(
@@ -1813,8 +1860,10 @@ public:
///
/// Call finalizeLowering on lowered instructions. Instructions that are
/// completely eliminated by fusion are added to \p FusedInsts.
- void LowerMatrixMultiplyFused(CallInst *MatMul,
- SmallPtrSetImpl<Instruction *> &FusedInsts) {
+ void
+ LowerMatrixMultiplyFused(CallInst *MatMul,
+ SmallPtrSetImpl<Instruction *> &FusedInsts,
+ SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
if (!FuseMatrix || !DT)
return;
@@ -1903,6 +1952,55 @@ public:
for (Instruction *I : ToHoist)
I->moveBefore(MatMul);
+ // Deal with lifetime.end calls that might be between Load0/Load1 and the
+ // store. To avoid introducing loads to dead objects (i.e. after the
+  // lifetime has been terminated by @llvm.lifetime.end), either sink them
+ // after the store if in the same block, or remove the lifetime.end marker
+ // otherwise. This might pessimize further optimizations, by extending the
+ // lifetime of the object until the function returns, but should be
+ // conservatively correct.
+ MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
+ MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
+ BasicBlock *StoreParent = Store->getParent();
+ bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
+ LoadOp1->getParent() == StoreParent;
+ for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
+ IntrinsicInst *End = LifetimeEnds[Idx];
+ auto Inc = make_scope_exit([&Idx]() { Idx++; });
+ // If the lifetime.end is guaranteed to be before the loads or after the
+ // store, it won't interfere with fusion.
+ if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
+ continue;
+ if (DT->dominates(Store, End))
+ continue;
+ // If all fusable ops are in the same block and the lifetime.end is in a
+ // different block, it won't interfere with fusion.
+ if (FusableOpsInSameBlock && End->getParent() != StoreParent)
+ continue;
+
+ // If the loads don't alias the lifetime.end, it won't interfere with
+ // fusion.
+ MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
+ if (!EndLoc.Ptr)
+ continue;
+ if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
+ continue;
+
+ // If both lifetime.end and the store are in the same block, extend the
+ // lifetime until after the store, so the new lifetime covers the loads
+ // we introduce later.
+ if (End->getParent() == StoreParent) {
+ End->moveAfter(Store);
+ continue;
+ }
+
+ // Otherwise remove the conflicting lifetime.end marker.
+ ToRemove.push_back(End);
+ std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
+ LifetimeEnds.pop_back();
+ Inc.release();
+ }
+
emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
return;
}
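
The index loop above uses llvm::make_scope_exit to advance Idx on every early continue, and Inc.release() to suppress the increment when the current lifetime.end is swap-and-popped out of LifetimeEnds, so the element swapped into slot Idx is examined on the next iteration. A self-contained illustration of the same erase-while-iterating idiom, written without the scope guard:

#include <cassert>
#include <utility>
#include <vector>

int main() {
  std::vector<int> Vals{1, 2, 3, 4, 5};

  // Remove even elements without invalidating the index-based walk: swap the
  // doomed element with the last one, pop it, and skip the index increment so
  // the swapped-in element is examined next.
  for (unsigned Idx = 0; Idx != Vals.size();) {
    if (Vals[Idx] % 2 != 0) {
      ++Idx;             // keep this element; advance (the scope-exit path)
      continue;
    }
    std::swap(Vals[Idx], Vals.back());
    Vals.pop_back();     // drop it; do not advance (the Inc.release() path)
  }

  assert(Vals == (std::vector<int>{1, 5, 3}));
  return 0;
}
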
@@ -2364,7 +2462,7 @@ public:
RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
OptimizationRemarkEmitter &ORE, Function &Func)
: Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
- DL(Func.getParent()->getDataLayout()) {}
+ DL(Func.getDataLayout()) {}
/// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
/// instructions in Inst2Matrix returning void or without any users in