Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 295 ++++++++++++-------
 1 file changed, 192 insertions(+), 103 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 17594b98c5bc..f46ea6a20afa 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -72,6 +72,11 @@ static cl::opt<bool> AllowContractEnabled(
cl::desc("Allow the use of FMAs if available and profitable. This may "
"result in different results, due to less rounding error."));
+static cl::opt<bool>
+ VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
+ cl::desc("Enable/disable matrix shape verification."),
+ cl::init(false));
+
enum class MatrixLayoutTy { ColumnMajor, RowMajor };
static cl::opt<MatrixLayoutTy> MatrixLayout(
@@ -267,7 +272,7 @@ class LowerMatrixIntrinsics {
unsigned D = isColumnMajor() ? NumColumns : NumRows;
for (unsigned J = 0; J < D; ++J)
- addVector(UndefValue::get(FixedVectorType::get(
+ addVector(PoisonValue::get(FixedVectorType::get(
EltTy, isColumnMajor() ? NumRows : NumColumns)));
}
@@ -535,6 +540,15 @@ public:
auto SIter = ShapeMap.find(V);
if (SIter != ShapeMap.end()) {
+ if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
+ SIter->second.NumColumns != Shape.NumColumns)) {
+ errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
+ << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
+ << Shape.NumColumns << ") for " << *V << "\n";
+ report_fatal_error(
+ "Matrix shape verification failed, compilation aborted!");
+ }
+
LLVM_DEBUG(dbgs() << " not overriding existing shape: "
<< SIter->second.NumRows << " "
<< SIter->second.NumColumns << " for " << *V << "\n");
@@ -838,10 +852,13 @@ public:
auto NewInst = distributeTransposes(
TAMA, {R, C}, TAMB, {R, C}, Builder,
[&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
- auto *FAdd =
- cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd"));
- setShapeInfo(FAdd, Shape0);
- return FAdd;
+ bool IsFP = I.getType()->isFPOrFPVectorTy();
+ auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
+ : LocalBuilder.CreateAdd(T0, T1, "madd");
+
+ auto *Result = cast<Instruction>(Add);
+ setShapeInfo(Result, Shape0);
+ return Result;
});
updateShapeAndReplaceAllUsesWith(I, NewInst);
eraseFromParentAndMove(&I, II, BB);
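
This hunk generalizes the transpose-distribution rewrite to integer matrices: the callback now emits a plain add when the element type is not floating point, where the old code unconditionally created an fadd. The rewrite itself rests on transposition distributing over addition, for operands of matching shape:

    t(A) + t(B) = t(A + B)

so the operands can be combined in their original layout and transposed once.
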
@@ -978,13 +995,18 @@ public:
MatrixInsts.push_back(&I);
}
- // Second, try to fuse candidates.
+ // Second, try to lower any dot products.
SmallPtrSet<Instruction *, 16> FusedInsts;
for (CallInst *CI : MaybeFusableInsts)
+ lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
+
+ // Third, try to fuse candidates.
+ for (CallInst *CI : MaybeFusableInsts)
LowerMatrixMultiplyFused(CI, FusedInsts);
+
Changed = !FusedInsts.empty();
- // Third, lower remaining instructions with shape information.
+ // Fourth, lower remaining instructions with shape information.
for (Instruction *Inst : MatrixInsts) {
if (FusedInsts.count(Inst))
continue;
@@ -1311,6 +1333,165 @@ public:
}
}
+ /// Special case for MatMul lowering. Prevents scalar loads of row-major
+ /// vectors. Lowers to a vector reduction add instead of sequential adds if
+ /// reassociation is enabled.
+ void lowerDotProduct(CallInst *MatMul,
+ SmallPtrSet<Instruction *, 16> &FusedInsts,
+ FastMathFlags FMF) {
+ if (FusedInsts.contains(MatMul) ||
+ MatrixLayout != MatrixLayoutTy::ColumnMajor)
+ return;
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+ if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
+ return;
+
+ Value *LHS = MatMul->getArgOperand(0);
+ Value *RHS = MatMul->getArgOperand(1);
+
+ Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
+ bool IsIntVec = ElementType->isIntegerTy();
+
+ // Floating-point reductions require reassociation.
+ if (!IsIntVec && !FMF.allowReassoc())
+ return;
+
+ auto CanBeFlattened = [this](Value *Op) {
+ if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
+ return true;
+ return match(
+ Op, m_OneUse(m_CombineOr(
+ m_Load(m_Value()),
+ m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
+ m_Intrinsic<Intrinsic::matrix_column_major_load>(
+ m_Value(), m_SpecificInt(1))))));
+ };
+ // Returns the cost benefit of using \p Op with the dot product lowering. If
+ // the returned cost is < 0, the argument is cheaper to use in the
+ // dot-product lowering.
+ auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
+ if (!isa<Instruction>(Op))
+ return InstructionCost(0);
+
+ FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
+ Type *EltTy = VecTy->getElementType();
+
+ if (!CanBeFlattened(Op)) {
+ InstructionCost EmbedCost(0);
+ // Roughly estimate the cost for embedding the columns into a vector.
+ for (unsigned I = 1; I < N; ++I)
+ EmbedCost -=
+ TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+ std::nullopt, TTI::TCK_RecipThroughput);
+ return EmbedCost;
+ }
+
+ if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
+ InstructionCost OriginalCost =
+ TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
+ EltTy) *
+ N;
+ InstructionCost NewCost = TTI.getArithmeticInstrCost(
+ cast<Instruction>(Op)->getOpcode(), VecTy);
+ return NewCost - OriginalCost;
+ }
+
+ if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
+ // The transpose can be skipped for the dot product lowering, roughly
+ // estimate the savings as the cost of embedding the columns in a
+ // vector.
+ InstructionCost EmbedCost(0);
+ for (unsigned I = 1; I < N; ++I)
+ EmbedCost +=
+ TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+ std::nullopt, TTI::TCK_RecipThroughput);
+ return EmbedCost;
+ }
+
+ // Costs for loads.
+ if (N == 1)
+ return InstructionCost(0);
+
+ return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
+ N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
+ };
+ auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
+
+ // We compare the costs of a vector.reduce.add to sequential add.
+ int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
+ int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
+ InstructionCost ReductionCost =
+ TTI.getArithmeticReductionCost(
+ AddOpCode, cast<VectorType>(LHS->getType()),
+ IsIntVec ? std::nullopt : std::optional(FMF)) +
+ TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
+ InstructionCost SequentialAddCost =
+ TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
+ (LShape.NumColumns - 1) +
+ TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
+ (LShape.NumColumns);
+ if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
+ return;
+
+ FusedInsts.insert(MatMul);
+ IRBuilder<> Builder(MatMul);
+ auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
+ this](Value *Op) -> Value * {
+ // Matmul must be the only user of loads because we don't use LowerLoad
+ // for row vectors (LowerLoad results in scalar loads and shufflevectors
+ // instead of a single vector load).
+ if (!CanBeFlattened(Op))
+ return Op;
+
+ if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
+ ShapeMap[Op] = ShapeMap[Op].t();
+ return Op;
+ }
+
+ FusedInsts.insert(cast<Instruction>(Op));
+ // If the operand uses the builtin matrix load, lower it to a plain LoadInst.
+ Value *Arg;
+ if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
+ m_Value(Arg)))) {
+ auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
+ Op->replaceAllUsesWith(NewLoad);
+ cast<Instruction>(Op)->eraseFromParent();
+ return NewLoad;
+ } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
+ m_Value(Arg)))) {
+ ToRemove.push_back(cast<Instruction>(Op));
+ return Arg;
+ }
+
+ return Op;
+ };
+ LHS = FlattenArg(LHS);
+
+ // Insert mul/fmul and llvm.vector.reduce.add/fadd.
+ Value *Mul =
+ IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);
+
+ Value *Result;
+ if (IsIntVec)
+ Result = Builder.CreateAddReduce(Mul);
+ else {
+ Result = Builder.CreateFAddReduce(
+ ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
+ 0.0),
+ Mul);
+ cast<Instruction>(Result)->setFastMathFlags(FMF);
+ }
+
+ // Pack the scalar back into a matrix and then replace the matmul inst.
+ Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
+ Result, uint64_t(0));
+ MatMul->replaceAllUsesWith(Result);
+ FusedInsts.insert(MatMul);
+ ToRemove.push_back(MatMul);
+ }
+
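
The fused path recognizes a 1xN times Nx1 multiply and lowers it to an elementwise multiply feeding a horizontal reduction instead of the general lowering. A scalar C++ model of the value it computes (a hedged illustration, not the emitted IR):

// Hedged model of the fused dot product: one elementwise multiply,
// then a reduction. The FP reduction starts at 0.0, matching the
// constant passed to CreateFAddReduce above, and reordering the
// additions is only legal because the reassoc flag is required.
float dotProduct(const float *L, const float *R, unsigned N) {
  float Acc = 0.0f;
  for (unsigned I = 0; I < N; ++I)
    Acc += L[I] * R[I];
  return Acc;
}

The cost gate compares this against the scalar expansion: SequentialAddCost charges N multiplies and N - 1 additions on the element type, so under hypothetical unit costs and N = 4 the fused path (operand flattening plus vector multiply plus reduction) must come in at 4 + 3 = 7 or less to fire.
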
/// Compute \p Result += \p A * \p B for input matrices with left-associating
/// addition.
///
@@ -1469,15 +1650,14 @@ public:
auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
AllocaInst *Alloca =
Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
- Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo());
- Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(),
+ Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
Load->getAlign(), LoadLoc.Size.getValue());
Builder.SetInsertPoint(Fusion, Fusion->begin());
PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
PHI->addIncoming(Load->getPointerOperand(), Check0);
PHI->addIncoming(Load->getPointerOperand(), Check1);
- PHI->addIncoming(BC, Copy);
+ PHI->addIncoming(Alloca, Copy);
// Adjust DT.
DTUpdates.push_back({DT->Insert, Check0, Check1});
@@ -2397,99 +2577,8 @@ void LowerMatrixIntrinsicsPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
- OS << "<";
+ OS << '<';
if (Minimal)
OS << "minimal";
- OS << ">";
-}
-
-namespace {
-
-class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
- initializeLowerMatrixIntrinsicsLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
- bool C = LMT.Visit();
- return C;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- }
-};
-} // namespace
-
-static const char pass_name[] = "Lower the matrix intrinsics";
-char LowerMatrixIntrinsicsLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
- false, false)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
- false, false)
-
-Pass *llvm::createLowerMatrixIntrinsicsPass() {
- return new LowerMatrixIntrinsicsLegacyPass();
-}
-
-namespace {
-
-/// A lightweight version of the matrix lowering pass that only requires TTI.
-/// Advanced features that require DT, AA or ORE like tiling are disabled. This
-/// is used to lower matrix intrinsics if the main lowering pass is not run, for
-/// example with -O0.
-class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
- initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
- bool C = LMT.Visit();
- return C;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.setPreservesCFG();
- }
-};
-} // namespace
-
-static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
-char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
- "lower-matrix-intrinsics-minimal", pass_name_minimal,
- false, false)
-INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
- "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
- false)
-
-Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
- return new LowerMatrixIntrinsicsMinimalLegacyPass();
+ OS << '>';
}
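
With both legacy wrappers deleted, the lowering is reachable only through the new pass manager. A hedged sketch of wiring up the two configurations, assuming the in-tree LowerMatrixIntrinsicsPass constructor that takes a Minimal flag:

#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
using namespace llvm;

void addMatrixLowering(FunctionPassManager &FPM, bool Minimal) {
  // Minimal = true selects the TTI-only variant (no tiling, AA, DT or
  // ORE), the configuration -O0 pipelines used the deleted minimal
  // legacy pass for.
  FPM.addPass(LowerMatrixIntrinsicsPass(Minimal));
}

Given the printPipeline change above, the minimal configuration round-trips in -passes strings as lower-matrix-intrinsics<minimal>.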