summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp165
1 files changed, 108 insertions, 57 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 04877bec94ec..ca87477c5d81 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -333,7 +333,7 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
SrcTy->getNumElements() == DestTy->getNumElements() &&
SrcTy->getPrimitiveSizeInBits() == DestTy->getPrimitiveSizeInBits()) {
Value *CastX = Builder.CreateCast(CI.getOpcode(), X, DestTy);
- return new ShuffleVectorInst(CastX, UndefValue::get(DestTy), Mask);
+ return new ShuffleVectorInst(CastX, Mask);
}
}
@@ -701,10 +701,10 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc,
if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) &&
is_splat(Shuf->getShuffleMask()) &&
Shuf->getType() == Shuf->getOperand(0)->getType()) {
- // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask
- Constant *NarrowUndef = UndefValue::get(Trunc.getType());
+ // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask
+ // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask
Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType());
- return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getShuffleMask());
+ return new ShuffleVectorInst(NarrowOp, Shuf->getShuffleMask());
}
return nullptr;
@@ -961,14 +961,25 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
return BinaryOperator::CreateAdd(NarrowCtlz, WidthDiff);
}
}
+
+ if (match(Src, m_VScale(DL))) {
+ if (Trunc.getFunction() &&
+ Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
+ unsigned MaxVScale = Trunc.getFunction()
+ ->getFnAttribute(Attribute::VScaleRange)
+ .getVScaleRangeArgs()
+ .second;
+ if (MaxVScale > 0 && Log2_32(MaxVScale) < DestWidth) {
+ Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
+ return replaceInstUsesWith(Trunc, VScale);
+ }
+ }
+ }
+
return nullptr;
}
-/// Transform (zext icmp) to bitwise / integer operations in order to
-/// eliminate it. If DoTransform is false, just test whether the given
-/// (zext icmp) can be transformed.
-Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
- bool DoTransform) {
+Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) {
// If we are just checking for a icmp eq of a single bit and zext'ing it
// to an integer, then shift the bit to the appropriate place and then
// cast to integer to avoid the comparison.
@@ -977,10 +988,8 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
// zext (x <s 0) to i32 --> x>>u31 true if signbit set.
// zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear.
- if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) ||
- (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) {
- if (!DoTransform) return Cmp;
-
+ if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isZero()) ||
+ (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnes())) {
Value *In = Cmp->getOperand(0);
Value *Sh = ConstantInt::get(In->getType(),
In->getType()->getScalarSizeInBits() - 1);
@@ -1004,7 +1013,7 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
// zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set.
// zext (X != 1) to i32 --> X^1 iff X has only the low bit set.
// zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
- if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) &&
+ if ((Op1CV->isZero() || Op1CV->isPowerOf2()) &&
// This only works for EQ and NE
Cmp->isEquality()) {
// If Op1C some other power of two, convert:
@@ -1012,10 +1021,8 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
APInt KnownZeroMask(~Known.Zero);
if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1?
- if (!DoTransform) return Cmp;
-
bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE;
- if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) {
+ if (!Op1CV->isZero() && (*Op1CV != KnownZeroMask)) {
// (X&4) == 2 --> false
// (X&4) != 2 --> true
Constant *Res = ConstantInt::get(Zext.getType(), isNE);
@@ -1031,7 +1038,7 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
In->getName() + ".lobit");
}
- if (!Op1CV->isNullValue() == isNE) { // Toggle the low bit.
+ if (!Op1CV->isZero() == isNE) { // Toggle the low bit.
Constant *One = ConstantInt::get(In->getType(), 1);
In = Builder.CreateXor(In, One);
}
@@ -1053,9 +1060,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
if (Cmp->hasOneUse() && match(Cmp->getOperand(1), m_ZeroInt()) &&
match(Cmp->getOperand(0),
m_OneUse(m_c_And(m_Shl(m_One(), m_Value(ShAmt)), m_Value(X))))) {
- if (!DoTransform)
- return Cmp;
-
if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
X = Builder.CreateNot(X);
Value *Lshr = Builder.CreateLShr(X, ShAmt);
@@ -1077,8 +1081,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
APInt KnownBits = KnownLHS.Zero | KnownLHS.One;
APInt UnknownBit = ~KnownBits;
if (UnknownBit.countPopulation() == 1) {
- if (!DoTransform) return Cmp;
-
Value *Result = Builder.CreateXor(LHS, RHS);
// Mask off any bits that are set and won't be shifted away.
@@ -1316,51 +1318,37 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src))
return transformZExtICmp(Cmp, CI);
- BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src);
- if (SrcI && SrcI->getOpcode() == Instruction::Or) {
- // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) if at least one
- // of the (zext icmp) can be eliminated. If so, immediately perform the
- // according elimination.
- ICmpInst *LHS = dyn_cast<ICmpInst>(SrcI->getOperand(0));
- ICmpInst *RHS = dyn_cast<ICmpInst>(SrcI->getOperand(1));
- if (LHS && RHS && LHS->hasOneUse() && RHS->hasOneUse() &&
- LHS->getOperand(0)->getType() == RHS->getOperand(0)->getType() &&
- (transformZExtICmp(LHS, CI, false) ||
- transformZExtICmp(RHS, CI, false))) {
- // zext (or icmp, icmp) -> or (zext icmp), (zext icmp)
- Value *LCast = Builder.CreateZExt(LHS, CI.getType(), LHS->getName());
- Value *RCast = Builder.CreateZExt(RHS, CI.getType(), RHS->getName());
- Value *Or = Builder.CreateOr(LCast, RCast, CI.getName());
- if (auto *OrInst = dyn_cast<Instruction>(Or))
- Builder.SetInsertPoint(OrInst);
-
- // Perform the elimination.
- if (auto *LZExt = dyn_cast<ZExtInst>(LCast))
- transformZExtICmp(LHS, *LZExt);
- if (auto *RZExt = dyn_cast<ZExtInst>(RCast))
- transformZExtICmp(RHS, *RZExt);
-
- return replaceInstUsesWith(CI, Or);
- }
- }
-
// zext(trunc(X) & C) -> (X & zext(C)).
Constant *C;
Value *X;
- if (SrcI &&
- match(SrcI, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) &&
+ if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) &&
X->getType() == CI.getType())
return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, CI.getType()));
// zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)).
Value *And;
- if (SrcI && match(SrcI, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) &&
+ if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) &&
match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) &&
X->getType() == CI.getType()) {
Constant *ZC = ConstantExpr::getZExt(C, CI.getType());
return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC);
}
+ if (match(Src, m_VScale(DL))) {
+ if (CI.getFunction() &&
+ CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
+ unsigned MaxVScale = CI.getFunction()
+ ->getFnAttribute(Attribute::VScaleRange)
+ .getVScaleRangeArgs()
+ .second;
+ unsigned TypeWidth = Src->getType()->getScalarSizeInBits();
+ if (MaxVScale > 0 && Log2_32(MaxVScale) < TypeWidth) {
+ Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
+ return replaceInstUsesWith(CI, VScale);
+ }
+ }
+ }
+
return nullptr;
}
@@ -1605,6 +1593,32 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
return BinaryOperator::CreateAShr(A, NewShAmt);
}
+ // Splatting a bit of constant-index across a value:
+ // sext (ashr (trunc iN X to iM), M-1) to iN --> ashr (shl X, N-M), N-1
+ // TODO: If the dest type is different, use a cast (adjust use check).
+ if (match(Src, m_OneUse(m_AShr(m_Trunc(m_Value(X)),
+ m_SpecificInt(SrcBitSize - 1)))) &&
+ X->getType() == DestTy) {
+ Constant *ShlAmtC = ConstantInt::get(DestTy, DestBitSize - SrcBitSize);
+ Constant *AshrAmtC = ConstantInt::get(DestTy, DestBitSize - 1);
+ Value *Shl = Builder.CreateShl(X, ShlAmtC);
+ return BinaryOperator::CreateAShr(Shl, AshrAmtC);
+ }
+
+ if (match(Src, m_VScale(DL))) {
+ if (CI.getFunction() &&
+ CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
+ unsigned MaxVScale = CI.getFunction()
+ ->getFnAttribute(Attribute::VScaleRange)
+ .getVScaleRangeArgs()
+ .second;
+ if (MaxVScale > 0 && Log2_32(MaxVScale) < (SrcBitSize - 1)) {
+ Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
+ return replaceInstUsesWith(CI, VScale);
+ }
+ }
+ }
+
return nullptr;
}
@@ -2060,6 +2074,19 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
}
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) {
+ // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use.
+ // While this can increase the number of instructions it doesn't actually
+ // increase the overall complexity since the arithmetic is just part of
+ // the GEP otherwise.
+ if (GEP->hasOneUse() &&
+ isa<ConstantPointerNull>(GEP->getPointerOperand())) {
+ return replaceInstUsesWith(CI,
+ Builder.CreateIntCast(EmitGEPOffset(GEP), Ty,
+ /*isSigned=*/false));
+ }
+ }
+
Value *Vec, *Scalar, *Index;
if (match(SrcOp, m_OneUse(m_InsertElt(m_IntToPtr(m_Value(Vec)),
m_Value(Scalar), m_Value(Index)))) &&
@@ -2133,9 +2160,9 @@ optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy,
if (SrcElts > DestElts) {
// If we're shrinking the number of elements (rewriting an integer
// truncate), just shuffle in the elements corresponding to the least
- // significant bits from the input and use undef as the second shuffle
+ // significant bits from the input and use poison as the second shuffle
// input.
- V2 = UndefValue::get(SrcTy);
+ V2 = PoisonValue::get(SrcTy);
// Make sure the shuffle mask selects the "least significant bits" by
// keeping elements from back of the src vector for big endian, and from the
// front for little endian.
@@ -2528,7 +2555,7 @@ Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI,
// As long as the user is another old PHI node, then even if we don't
// rewrite it, the PHI web we're considering won't have any users
// outside itself, so it'll be dead.
- if (OldPhiNodes.count(PHI) == 0)
+ if (!OldPhiNodes.contains(PHI))
return nullptr;
} else {
return nullptr;
@@ -2736,6 +2763,30 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
if (auto *InsElt = dyn_cast<InsertElementInst>(Src))
return new BitCastInst(InsElt->getOperand(1), DestTy);
}
+
+ // Convert an artificial vector insert into more analyzable bitwise logic.
+ unsigned BitWidth = DestTy->getScalarSizeInBits();
+ Value *X, *Y;
+ uint64_t IndexC;
+ if (match(Src, m_OneUse(m_InsertElt(m_OneUse(m_BitCast(m_Value(X))),
+ m_Value(Y), m_ConstantInt(IndexC)))) &&
+ DestTy->isIntegerTy() && X->getType() == DestTy &&
+ isDesirableIntType(BitWidth)) {
+ // Adjust for big endian - the LSBs are at the high index.
+ if (DL.isBigEndian())
+ IndexC = SrcVTy->getNumElements() - 1 - IndexC;
+
+ // We only handle (endian-normalized) insert to index 0. Any other insert
+ // would require a left-shift, so that is an extra instruction.
+ if (IndexC == 0) {
+ // bitcast (inselt (bitcast X), Y, 0) --> or (and X, MaskC), (zext Y)
+ unsigned EltWidth = Y->getType()->getScalarSizeInBits();
+ APInt MaskC = APInt::getHighBitsSet(BitWidth, BitWidth - EltWidth);
+ Value *AndX = Builder.CreateAnd(X, MaskC);
+ Value *ZextY = Builder.CreateZExt(Y, DestTy);
+ return BinaryOperator::CreateOr(AndX, ZextY);
+ }
+ }
}
if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Src)) {