Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VectorCombine.cpp')
-rw-r--r--  llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 290
1 file changed, 239 insertions(+), 51 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d18bcd34620c..57b11e9414ba 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -31,10 +31,12 @@
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"
+#define DEBUG_TYPE "vector-combine"
+#include "llvm/Transforms/Utils/InstructionWorklist.h"
+
using namespace llvm;
using namespace llvm::PatternMatch;
-#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
@@ -61,8 +63,10 @@ namespace {
class VectorCombine {
public:
VectorCombine(Function &F, const TargetTransformInfo &TTI,
- const DominatorTree &DT, AAResults &AA, AssumptionCache &AC)
- : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {}
+ const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
+ bool ScalarizationOnly)
+ : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC),
+ ScalarizationOnly(ScalarizationOnly) {}
bool run();
@@ -74,12 +78,18 @@ private:
AAResults &AA;
AssumptionCache &AC;
+ /// If true, only perform scalarization combines; do not introduce new
+ /// vector operations.
+ bool ScalarizationOnly;
+
+ InstructionWorklist Worklist;
+
bool vectorizeLoadInsert(Instruction &I);
ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
ExtractElementInst *Ext1,
unsigned PreferredExtractIndex) const;
bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
- unsigned Opcode,
+ const Instruction &I,
ExtractElementInst *&ConvertToShuffle,
unsigned PreferredExtractIndex);
void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
@@ -92,14 +102,27 @@ private:
bool foldExtractedCmps(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
+ bool foldShuffleOfBinops(Instruction &I);
+
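+ /// Replace \p Old with \p New in its users and queue \p New, its users, and
+ /// \p Old on the worklist so follow-up folds and dead-code erasure can run.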
+ void replaceValue(Value &Old, Value &New) {
+ Old.replaceAllUsesWith(&New);
+ New.takeName(&Old);
+ if (auto *NewI = dyn_cast<Instruction>(&New)) {
+ Worklist.pushUsersToWorkList(*NewI);
+ Worklist.pushValue(NewI);
+ }
+ Worklist.pushValue(&Old);
+ }
+
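+ /// Erase \p I and queue its operands, which may have become trivially dead.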
+ void eraseInstruction(Instruction &I) {
+ for (Value *Op : I.operands())
+ Worklist.pushValue(Op);
+ Worklist.remove(&I);
+ I.eraseFromParent();
+ }
};
} // namespace
-static void replaceValue(Value &Old, Value &New) {
- Old.replaceAllUsesWith(&New);
- New.takeName(&Old);
-}
-
bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
// Match insert into fixed vector of scalar value.
// TODO: Handle non-zero insert index.
@@ -284,12 +307,13 @@ ExtractElementInst *VectorCombine::getShuffleExtract(
/// \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
ExtractElementInst *Ext1,
- unsigned Opcode,
+ const Instruction &I,
ExtractElementInst *&ConvertToShuffle,
unsigned PreferredExtractIndex) {
assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
isa<ConstantInt>(Ext1->getOperand(1)) &&
"Expected constant extract indexes");
+ unsigned Opcode = I.getOpcode();
Type *ScalarTy = Ext0->getType();
auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
InstructionCost ScalarOpCost, VectorOpCost;
@@ -302,10 +326,11 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
} else {
assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
"Expected a compare");
- ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
- CmpInst::makeCmpResultType(ScalarTy));
- VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
- CmpInst::makeCmpResultType(VecTy));
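+ // Passing the predicate lets targets cost this specific compare rather
+ // than returning a generic cmp/select estimate.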
+ CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+ ScalarOpCost = TTI.getCmpSelInstrCost(
+ Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
+ VectorOpCost = TTI.getCmpSelInstrCost(
+ Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
}
// Get cost estimates for the extract elements. These costs will factor into
@@ -480,8 +505,7 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
ExtractElementInst *ExtractToChange;
- if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange,
- InsertIndex))
+ if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
return false;
if (ExtractToChange) {
@@ -501,6 +525,8 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
else
foldExtExtBinop(Ext0, Ext1, I);
+ Worklist.push(Ext0);
+ Worklist.push(Ext1);
return true;
}
@@ -623,8 +649,11 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
unsigned Opcode = I.getOpcode();
InstructionCost ScalarOpCost, VectorOpCost;
if (IsCmp) {
- ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
- VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
+ CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+ ScalarOpCost = TTI.getCmpSelInstrCost(
+ Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
+ VectorOpCost = TTI.getCmpSelInstrCost(
+ Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
} else {
ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
@@ -724,7 +753,10 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
InstructionCost OldCost =
TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
- OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
+ OldCost +=
+ TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(),
+ CmpInst::makeCmpResultType(I0->getType()), Pred) *
+ 2;
OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());
// The proposed vector pattern is:
@@ -733,7 +765,8 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
- InstructionCost NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
+ InstructionCost NewCost = TTI.getCmpSelInstrCost(
+ CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred);
SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
ShufMask[CheapIndex] = ExpensiveIndex;
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
@@ -774,18 +807,98 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin,
});
}
+/// Helper class to indicate whether a vector index can be safely scalarized
+/// and whether a freeze needs to be inserted.
+class ScalarizationResult {
+ enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
+
+ StatusTy Status;
+ Value *ToFreeze;
+
+ ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
+ : Status(Status), ToFreeze(ToFreeze) {}
+
+public:
+ ScalarizationResult(const ScalarizationResult &Other) = default;
+ ~ScalarizationResult() {
+ assert(!ToFreeze && "freeze() not called with ToFreeze being set");
+ }
+
+ static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
+ static ScalarizationResult safe() { return {StatusTy::Safe}; }
+ static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
+ return {StatusTy::SafeWithFreeze, ToFreeze};
+ }
+
+ /// Returns true if the index can be scalarized without requiring a freeze.
+ bool isSafe() const { return Status == StatusTy::Safe; }
+ /// Returns true if the index cannot be scalarized.
+ bool isUnsafe() const { return Status == StatusTy::Unsafe; }
+ /// Returns true if the index can be scalarized, but requires inserting a
+ /// freeze.
+ bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
+
+ /// Reset the state to Unsafe and clear ToFreeze if set.
+ void discard() {
+ ToFreeze = nullptr;
+ Status = StatusTy::Unsafe;
+ }
+
+ /// Freeze the ToFreeze value and update the use in \p UserI to use it.
+ void freeze(IRBuilder<> &Builder, Instruction &UserI) {
+ assert(isSafeWithFreeze() &&
+ "should only be used when freezing is required");
+ assert(is_contained(ToFreeze->users(), &UserI) &&
+ "UserI must be a user of ToFreeze");
+ IRBuilder<>::InsertPointGuard Guard(Builder);
+ Builder.SetInsertPoint(cast<Instruction>(&UserI));
+ Value *Frozen =
+ Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
+ for (Use &U : make_early_inc_range((UserI.operands())))
+ if (U.get() == ToFreeze)
+ U.set(Frozen);
+
+ ToFreeze = nullptr;
+ }
+};
+
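+// For example (hypothetical IR), indexing a <8 x i32> vector with
+//   %idx = and i64 %i, 7
+// yields a range that fits [0, 8), but %i itself may be poison. The result
+// is safeWithFreeze(%i): freezing %i lets the 'and' bound a non-poison index.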
/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
/// Idx. \p Idx must access a valid vector element.
-static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
- Instruction *CtxI, AssumptionCache &AC) {
- if (auto *C = dyn_cast<ConstantInt>(Idx))
- return C->getValue().ult(VecTy->getNumElements());
+static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
+ Value *Idx, Instruction *CtxI,
+ AssumptionCache &AC,
+ const DominatorTree &DT) {
+ if (auto *C = dyn_cast<ConstantInt>(Idx)) {
+ if (C->getValue().ult(VecTy->getNumElements()))
+ return ScalarizationResult::safe();
+ return ScalarizationResult::unsafe();
+ }
- APInt Zero(Idx->getType()->getScalarSizeInBits(), 0);
- APInt MaxElts(Idx->getType()->getScalarSizeInBits(), VecTy->getNumElements());
+ unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
+ APInt Zero(IntWidth, 0);
+ APInt MaxElts(IntWidth, VecTy->getNumElements());
ConstantRange ValidIndices(Zero, MaxElts);
- ConstantRange IdxRange = computeConstantRange(Idx, true, &AC, CtxI, 0);
- return ValidIndices.contains(IdxRange);
+ ConstantRange IdxRange(IntWidth, true);
+
+ if (isGuaranteedNotToBePoison(Idx, &AC)) {
+ if (ValidIndices.contains(computeConstantRange(Idx, true, &AC, CtxI, &DT)))
+ return ScalarizationResult::safe();
+ return ScalarizationResult::unsafe();
+ }
+
+ // If the index may be poison, check if we can insert a freeze before the
+ // range of the index is restricted.
+ Value *IdxBase;
+ ConstantInt *CI;
+ if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
+ IdxRange = IdxRange.binaryAnd(CI->getValue());
+ } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
+ IdxRange = IdxRange.urem(CI->getValue());
+ }
+
+ if (ValidIndices.contains(IdxRange))
+ return ScalarizationResult::safeWithFreeze(IdxBase);
+ return ScalarizationResult::unsafe();
}
/// The memory operation on a vector of \p ScalarType had alignment of
@@ -833,12 +946,17 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
// modified between, vector type matches store size, and index is inbounds.
if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
!DL.typeSizeEqualsStoreSize(Load->getType()) ||
- !canScalarizeAccess(VecTy, Idx, Load, AC) ||
- SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
+ SrcAddr != SI->getPointerOperand()->stripPointerCasts())
+ return false;
+
+ auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
+ if (ScalarizableIdx.isUnsafe() ||
isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
MemoryLocation::get(SI), AA))
return false;
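+ // The computed index range only holds for a non-poison index; freeze the
+ // index base first when required so the scalarized access stays in bounds.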
+ if (ScalarizableIdx.isSafeWithFreeze())
+ ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
Value *GEP = Builder.CreateInBoundsGEP(
SI->getValueOperand()->getType(), SI->getPointerOperand(),
{ConstantInt::get(Idx->getType(), 0), Idx});
@@ -849,8 +967,7 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
DL);
NSI->setAlignment(ScalarOpAlignment);
replaceValue(I, *NSI);
- // Need erasing the store manually.
- I.eraseFromParent();
+ eraseInstruction(I);
return true;
}
@@ -860,11 +977,10 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
Value *Ptr;
- Value *Idx;
- if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_Value(Idx))))
+ if (!match(&I, m_Load(m_Value(Ptr))))
return false;
- auto *LI = cast<LoadInst>(I.getOperand(0));
+ auto *LI = cast<LoadInst>(&I);
const DataLayout &DL = I.getModule()->getDataLayout();
if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(LI->getType()))
return false;
@@ -909,8 +1025,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
else if (LastCheckedInst->comesBefore(UI))
LastCheckedInst = UI;
- if (!canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC))
+ auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
+ if (!ScalarIdx.isSafe()) {
+ // TODO: Freeze index if it is safe to do so.
+ ScalarIdx.discard();
return false;
+ }
auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
OriginalCost +=
@@ -946,6 +1066,60 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
return true;
}
+/// Try to convert "shuffle (binop), (binop)" with a shared binop operand into
+/// "binop (shuffle), (shuffle)".
+bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
+ auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
+ if (!VecTy)
+ return false;
+
+ BinaryOperator *B0, *B1;
+ ArrayRef<int> Mask;
+ if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)),
+ m_Mask(Mask))) ||
+ B0->getOpcode() != B1->getOpcode() || B0->getType() != VecTy)
+ return false;
+
+ // Try to replace a binop with a shuffle if the shuffle is not costly.
+ // The new shuffle will choose from a single, common operand, so it may be
+ // cheaper than the existing two-operand shuffle.
+ SmallVector<int> UnaryMask = createUnaryMask(Mask, Mask.size());
+ Instruction::BinaryOps Opcode = B0->getOpcode();
+ InstructionCost BinopCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
+ InstructionCost ShufCost = TTI.getShuffleCost(
+ TargetTransformInfo::SK_PermuteSingleSrc, VecTy, UnaryMask);
+ if (ShufCost > BinopCost)
+ return false;
+
+ // If we have something like "add X, Y" and "add Z, X", swap ops to match.
+ Value *X = B0->getOperand(0), *Y = B0->getOperand(1);
+ Value *Z = B1->getOperand(0), *W = B1->getOperand(1);
+ if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W)
+ std::swap(X, Y);
+
+ Value *Shuf0, *Shuf1;
+ if (X == Z) {
+ // shuf (bo X, Y), (bo X, W) --> bo (shuf X), (shuf Y, W)
+ Shuf0 = Builder.CreateShuffleVector(X, UnaryMask);
+ Shuf1 = Builder.CreateShuffleVector(Y, W, Mask);
+ } else if (Y == W) {
+ // shuf (bo X, Y), (bo Z, Y) --> bo (shuf X, Z), (shuf Y)
+ Shuf0 = Builder.CreateShuffleVector(X, Z, Mask);
+ Shuf1 = Builder.CreateShuffleVector(Y, UnaryMask);
+ } else {
+ return false;
+ }
+
+ Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
+ // Intersect flags from the old binops.
+ if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
+ NewInst->copyIRFlags(B0);
+ NewInst->andIRFlags(B1);
+ }
+ replaceValue(I, *NewBO);
+ return true;
+}
+
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
@@ -957,29 +1131,43 @@ bool VectorCombine::run() {
return false;
bool MadeChange = false;
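+ // Apply the folds to I. With ScalarizationOnly set, skip any fold that
+ // could introduce new vector operations.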
+ auto FoldInst = [this, &MadeChange](Instruction &I) {
+ Builder.SetInsertPoint(&I);
+ if (!ScalarizationOnly) {
+ MadeChange |= vectorizeLoadInsert(I);
+ MadeChange |= foldExtractExtract(I);
+ MadeChange |= foldBitcastShuf(I);
+ MadeChange |= foldExtractedCmps(I);
+ MadeChange |= foldShuffleOfBinops(I);
+ }
+ MadeChange |= scalarizeBinopOrCmp(I);
+ MadeChange |= scalarizeLoadExtract(I);
+ MadeChange |= foldSingleElementStore(I);
+ };
for (BasicBlock &BB : F) {
// Ignore unreachable basic blocks.
if (!DT.isReachableFromEntry(&BB))
continue;
// Use early increment range so that we can erase instructions in the loop.
for (Instruction &I : make_early_inc_range(BB)) {
- if (isa<DbgInfoIntrinsic>(I))
+ if (I.isDebugOrPseudoInst())
continue;
- Builder.SetInsertPoint(&I);
- MadeChange |= vectorizeLoadInsert(I);
- MadeChange |= foldExtractExtract(I);
- MadeChange |= foldBitcastShuf(I);
- MadeChange |= scalarizeBinopOrCmp(I);
- MadeChange |= foldExtractedCmps(I);
- MadeChange |= scalarizeLoadExtract(I);
- MadeChange |= foldSingleElementStore(I);
+ FoldInst(I);
}
}
- // We're done with transforms, so remove dead instructions.
- if (MadeChange)
- for (BasicBlock &BB : F)
- SimplifyInstructionsInBlock(&BB);
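+ // Instead of a whole-function simplification sweep, drain the worklist:
+ // erase instructions that became trivially dead and re-run the folds on
+ // anything the earlier transforms touched.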
+ while (!Worklist.isEmpty()) {
+ Instruction *I = Worklist.removeOne();
+ if (!I)
+ continue;
+
+ if (isInstructionTriviallyDead(I)) {
+ eraseInstruction(*I);
+ continue;
+ }
+
+ FoldInst(*I);
+ }
return MadeChange;
}
@@ -1014,7 +1202,7 @@ public:
auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- VectorCombine Combiner(F, TTI, DT, AA, AC);
+ VectorCombine Combiner(F, TTI, DT, AA, AC, false);
return Combiner.run();
}
};
@@ -1038,7 +1226,7 @@ PreservedAnalyses VectorCombinePass::run(Function &F,
TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
AAResults &AA = FAM.getResult<AAManager>(F);
- VectorCombine Combiner(F, TTI, DT, AA, AC);
+ VectorCombine Combiner(F, TTI, DT, AA, AC, ScalarizationOnly);
if (!Combiner.run())
return PreservedAnalyses::all();
PreservedAnalyses PA;