summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp231
1 files changed, 188 insertions, 43 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index cc0a9127f8b18..d3c718a919c0a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -143,8 +143,7 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op,
// the XOR is to toggle the bit. If it is clear, then the ADD has
// no effect.
if ((AddRHS & AndRHSV).isNullValue()) { // Bit is not set, noop
- TheAnd.setOperand(0, X);
- return &TheAnd;
+ return replaceOperand(TheAnd, 0, X);
} else {
// Pull the XOR out of the AND.
Value *NewAnd = Builder.CreateAnd(X, AndRHS);
@@ -858,8 +857,10 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS,
// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
- bool JoinedByAnd,
- Instruction &CxtI) {
+ BinaryOperator &Logic) {
+ bool JoinedByAnd = Logic.getOpcode() == Instruction::And;
+ assert((JoinedByAnd || Logic.getOpcode() == Instruction::Or) &&
+ "Wrong opcode");
ICmpInst::Predicate Pred = LHS->getPredicate();
if (Pred != RHS->getPredicate())
return nullptr;
@@ -883,8 +884,8 @@ Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
std::swap(A, B);
if (A == C &&
- isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) &&
- isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) {
+ isKnownToBeAPowerOfTwo(B, false, 0, &Logic) &&
+ isKnownToBeAPowerOfTwo(D, false, 0, &Logic)) {
Value *Mask = Builder.CreateOr(B, D);
Value *Masked = Builder.CreateAnd(A, Mask);
auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
@@ -1072,9 +1073,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
m_c_ICmp(UnsignedPred, m_Specific(ZeroCmpOp), m_Value(A))) &&
match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) &&
(ZeroICmp->hasOneUse() || UnsignedICmp->hasOneUse())) {
- if (UnsignedICmp->getOperand(0) != ZeroCmpOp)
- UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
-
auto GetKnownNonZeroAndOther = [&](Value *&NonZero, Value *&Other) {
if (!IsKnownNonZero(NonZero))
std::swap(NonZero, Other);
@@ -1111,8 +1109,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) ||
!ICmpInst::isUnsigned(UnsignedPred))
return nullptr;
- if (UnsignedICmp->getOperand(0) != Base)
- UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
// Base >=/> Offset && (Base - Offset) != 0 <--> Base > Offset
// (no overflow and not null)
@@ -1141,14 +1137,59 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
return nullptr;
}
+/// Reduce logic-of-compares with equality to a constant by substituting a
+/// common operand with the constant. Callers are expected to call this with
+/// Cmp0/Cmp1 switched to handle logic op commutativity.
+static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
+ BinaryOperator &Logic,
+ InstCombiner::BuilderTy &Builder,
+ const SimplifyQuery &Q) {
+ bool IsAnd = Logic.getOpcode() == Instruction::And;
+ assert((IsAnd || Logic.getOpcode() == Instruction::Or) && "Wrong logic op");
+
+ // Match an equality compare with a non-poison constant as Cmp0.
+ ICmpInst::Predicate Pred0;
+ Value *X;
+ Constant *C;
+ if (!match(Cmp0, m_ICmp(Pred0, m_Value(X), m_Constant(C))) ||
+ !isGuaranteedNotToBeUndefOrPoison(C))
+ return nullptr;
+ if ((IsAnd && Pred0 != ICmpInst::ICMP_EQ) ||
+ (!IsAnd && Pred0 != ICmpInst::ICMP_NE))
+ return nullptr;
+
+ // The other compare must include a common operand (X). Canonicalize the
+ // common operand as operand 1 (Pred1 is swapped if the common operand was
+ // operand 0).
+ Value *Y;
+ ICmpInst::Predicate Pred1;
+ if (!match(Cmp1, m_c_ICmp(Pred1, m_Value(Y), m_Deferred(X))))
+ return nullptr;
+
+ // Replace variable with constant value equivalence to remove a variable use:
+ // (X == C) && (Y Pred1 X) --> (X == C) && (Y Pred1 C)
+ // (X != C) || (Y Pred1 X) --> (X != C) || (Y Pred1 C)
+ // Can think of the 'or' substitution with the 'and' bool equivalent:
+ // A || B --> A || (!A && B)
+ Value *SubstituteCmp = SimplifyICmpInst(Pred1, Y, C, Q);
+ if (!SubstituteCmp) {
+ // If we need to create a new instruction, require that the old compare can
+ // be removed.
+ if (!Cmp1->hasOneUse())
+ return nullptr;
+ SubstituteCmp = Builder.CreateICmp(Pred1, Y, C);
+ }
+ return Builder.CreateBinOp(Logic.getOpcode(), Cmp0, SubstituteCmp);
+}
+
/// Fold (icmp)&(icmp) if possible.
Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
- Instruction &CxtI) {
- const SimplifyQuery Q = SQ.getWithInstruction(&CxtI);
+ BinaryOperator &And) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&And);
// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
// if K1 and K2 are a one-bit mask.
- if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI))
+ if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, And))
return V;
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
@@ -1171,6 +1212,11 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, true, Builder))
return V;
+ if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, And, Builder, Q))
+ return V;
+ if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, And, Builder, Q))
+ return V;
+
// E.g. (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n
if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/false))
return V;
@@ -1182,7 +1228,7 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder))
return V;
- if (Value *V = foldSignedTruncationCheck(LHS, RHS, CxtI, Builder))
+ if (Value *V = foldSignedTruncationCheck(LHS, RHS, And, Builder))
return V;
if (Value *V = foldIsPowerOf2(LHS, RHS, true /* JoinedByAnd */, Builder))
@@ -1658,7 +1704,7 @@ static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) {
if (C->getType()->isVectorTy()) {
// Check each element of a constant vector.
- unsigned NumElts = C->getType()->getVectorNumElements();
+ unsigned NumElts = cast<VectorType>(C->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
@@ -1802,7 +1848,17 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
return BinaryOperator::Create(BinOp, NewLHS, Y);
}
}
-
+ const APInt *ShiftC;
+ if (match(Op0, m_OneUse(m_SExt(m_AShr(m_Value(X), m_APInt(ShiftC)))))) {
+ unsigned Width = I.getType()->getScalarSizeInBits();
+ if (*C == APInt::getLowBitsSet(Width, Width - ShiftC->getZExtValue())) {
+ // We are clearing high bits that were potentially set by sext+ashr:
+ // and (sext (ashr X, ShiftC)), C --> lshr (sext X), ShiftC
+ Value *Sext = Builder.CreateSExt(X, I.getType());
+ Constant *ShAmtC = ConstantInt::get(I.getType(), ShiftC->zext(Width));
+ return BinaryOperator::CreateLShr(Sext, ShAmtC);
+ }
+ }
}
if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) {
@@ -2020,7 +2076,7 @@ Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) {
LastInst->removeFromParent();
for (auto *Inst : Insts)
- Worklist.Add(Inst);
+ Worklist.push(Inst);
return LastInst;
}
@@ -2086,9 +2142,62 @@ static Instruction *matchRotate(Instruction &Or) {
return IntrinsicInst::Create(F, { ShVal, ShVal, ShAmt });
}
+/// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.
+static Instruction *matchOrConcat(Instruction &Or,
+ InstCombiner::BuilderTy &Builder) {
+ assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'");
+ Value *Op0 = Or.getOperand(0), *Op1 = Or.getOperand(1);
+ Type *Ty = Or.getType();
+
+ unsigned Width = Ty->getScalarSizeInBits();
+ if ((Width & 1) != 0)
+ return nullptr;
+ unsigned HalfWidth = Width / 2;
+
+ // Canonicalize zext (lower half) to LHS.
+ if (!isa<ZExtInst>(Op0))
+ std::swap(Op0, Op1);
+
+ // Find lower/upper half.
+ Value *LowerSrc, *ShlVal, *UpperSrc;
+ const APInt *C;
+ if (!match(Op0, m_OneUse(m_ZExt(m_Value(LowerSrc)))) ||
+ !match(Op1, m_OneUse(m_Shl(m_Value(ShlVal), m_APInt(C)))) ||
+ !match(ShlVal, m_OneUse(m_ZExt(m_Value(UpperSrc)))))
+ return nullptr;
+ if (*C != HalfWidth || LowerSrc->getType() != UpperSrc->getType() ||
+ LowerSrc->getType()->getScalarSizeInBits() != HalfWidth)
+ return nullptr;
+
+ auto ConcatIntrinsicCalls = [&](Intrinsic::ID id, Value *Lo, Value *Hi) {
+ Value *NewLower = Builder.CreateZExt(Lo, Ty);
+ Value *NewUpper = Builder.CreateZExt(Hi, Ty);
+ NewUpper = Builder.CreateShl(NewUpper, HalfWidth);
+ Value *BinOp = Builder.CreateOr(NewLower, NewUpper);
+ Function *F = Intrinsic::getDeclaration(Or.getModule(), id, Ty);
+ return Builder.CreateCall(F, BinOp);
+ };
+
+ // BSWAP: Push the concat down, swapping the lower/upper sources.
+ // concat(bswap(x),bswap(y)) -> bswap(concat(x,y))
+ Value *LowerBSwap, *UpperBSwap;
+ if (match(LowerSrc, m_BSwap(m_Value(LowerBSwap))) &&
+ match(UpperSrc, m_BSwap(m_Value(UpperBSwap))))
+ return ConcatIntrinsicCalls(Intrinsic::bswap, UpperBSwap, LowerBSwap);
+
+ // BITREVERSE: Push the concat down, swapping the lower/upper sources.
+ // concat(bitreverse(x),bitreverse(y)) -> bitreverse(concat(x,y))
+ Value *LowerBRev, *UpperBRev;
+ if (match(LowerSrc, m_BitReverse(m_Value(LowerBRev))) &&
+ match(UpperSrc, m_BitReverse(m_Value(UpperBRev))))
+ return ConcatIntrinsicCalls(Intrinsic::bitreverse, UpperBRev, LowerBRev);
+
+ return nullptr;
+}
+
/// If all elements of two constant vectors are 0/-1 and inverses, return true.
static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) {
- unsigned NumElts = C1->getType()->getVectorNumElements();
+ unsigned NumElts = cast<VectorType>(C1->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *EltC1 = C1->getAggregateElement(i);
Constant *EltC2 = C2->getAggregateElement(i);
@@ -2185,12 +2294,12 @@ Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B,
/// Fold (icmp)|(icmp) if possible.
Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
- Instruction &CxtI) {
- const SimplifyQuery Q = SQ.getWithInstruction(&CxtI);
+ BinaryOperator &Or) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&Or);
// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
// if K1 and K2 are a one-bit mask.
- if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI))
+ if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, Or))
return V;
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
@@ -2299,6 +2408,11 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
Builder.CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A);
}
+ if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, Or, Builder, Q))
+ return V;
+ if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, Or, Builder, Q))
+ return V;
+
// E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n
if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/true))
return V;
@@ -2481,6 +2595,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
if (Instruction *Rotate = matchRotate(I))
return Rotate;
+ if (Instruction *Concat = matchOrConcat(I, Builder))
+ return replaceInstUsesWith(I, Concat);
+
Value *X, *Y;
const APInt *CV;
if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&
@@ -2729,6 +2846,32 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
return V;
+ CmpInst::Predicate Pred;
+ Value *Mul, *Ov, *MulIsNotZero, *UMulWithOv;
+ // Check if the OR weakens the overflow condition for umul.with.overflow by
+ // treating any non-zero result as overflow. In that case, we overflow if both
+ // umul.with.overflow operands are != 0, as in that case the result can only
+ // be 0, iff the multiplication overflows.
+ if (match(&I,
+ m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_Value(UMulWithOv)),
+ m_Value(Ov)),
+ m_CombineAnd(m_ICmp(Pred,
+ m_CombineAnd(m_ExtractValue<0>(
+ m_Deferred(UMulWithOv)),
+ m_Value(Mul)),
+ m_ZeroInt()),
+ m_Value(MulIsNotZero)))) &&
+ (Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse())) &&
+ Pred == CmpInst::ICMP_NE) {
+ Value *A, *B;
+ if (match(UMulWithOv, m_Intrinsic<Intrinsic::umul_with_overflow>(
+ m_Value(A), m_Value(B)))) {
+ Value *NotNullA = Builder.CreateIsNotNull(A);
+ Value *NotNullB = Builder.CreateIsNotNull(B);
+ return BinaryOperator::CreateAnd(NotNullA, NotNullB);
+ }
+ }
+
return nullptr;
}
@@ -2748,33 +2891,24 @@ static Instruction *foldXorToXor(BinaryOperator &I,
// (A | B) ^ (A & B) -> A ^ B
// (A | B) ^ (B & A) -> A ^ B
if (match(&I, m_c_Xor(m_And(m_Value(A), m_Value(B)),
- m_c_Or(m_Deferred(A), m_Deferred(B))))) {
- I.setOperand(0, A);
- I.setOperand(1, B);
- return &I;
- }
+ m_c_Or(m_Deferred(A), m_Deferred(B)))))
+ return BinaryOperator::CreateXor(A, B);
// (A | ~B) ^ (~A | B) -> A ^ B
// (~B | A) ^ (~A | B) -> A ^ B
// (~A | B) ^ (A | ~B) -> A ^ B
// (B | ~A) ^ (A | ~B) -> A ^ B
if (match(&I, m_Xor(m_c_Or(m_Value(A), m_Not(m_Value(B))),
- m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) {
- I.setOperand(0, A);
- I.setOperand(1, B);
- return &I;
- }
+ m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B)))))
+ return BinaryOperator::CreateXor(A, B);
// (A & ~B) ^ (~A & B) -> A ^ B
// (~B & A) ^ (~A & B) -> A ^ B
// (~A & B) ^ (A & ~B) -> A ^ B
// (B & ~A) ^ (A & ~B) -> A ^ B
if (match(&I, m_Xor(m_c_And(m_Value(A), m_Not(m_Value(B))),
- m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) {
- I.setOperand(0, A);
- I.setOperand(1, B);
- return &I;
- }
+ m_c_And(m_Not(m_Deferred(A)), m_Deferred(B)))))
+ return BinaryOperator::CreateXor(A, B);
// For the remaining cases we need to get rid of one of the operands.
if (!Op0->hasOneUse() && !Op1->hasOneUse())
@@ -2878,6 +3012,7 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS,
Builder.SetInsertPoint(Y->getParent(), ++(Y->getIterator()));
Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not");
// Replace all uses of Y (excluding the one in NotY!) with NotY.
+ Worklist.pushUsersToWorkList(*Y);
Y->replaceUsesWithIf(NotY,
[NotY](Use &U) { return U.getUser() != NotY; });
}
@@ -2924,6 +3059,9 @@ static Instruction *visitMaskedMerge(BinaryOperator &I,
Constant *C;
if (D->hasOneUse() && match(M, m_Constant(C))) {
+ // Propagating undef is unsafe. Clamp undef elements to -1.
+ Type *EltTy = C->getType()->getScalarType();
+ C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy));
// Unfold.
Value *LHS = Builder.CreateAnd(X, C);
Value *NotC = Builder.CreateNot(C);
@@ -3058,13 +3196,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
// ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits)
Constant *C;
if (match(NotVal, m_AShr(m_Constant(C), m_Value(Y))) &&
- match(C, m_Negative()))
+ match(C, m_Negative())) {
+ // We matched a negative constant, so propagating undef is unsafe.
+ // Clamp undef elements to -1.
+ Type *EltTy = C->getType()->getScalarType();
+ C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy));
return BinaryOperator::CreateLShr(ConstantExpr::getNot(C), Y);
+ }
// ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits)
if (match(NotVal, m_LShr(m_Constant(C), m_Value(Y))) &&
- match(C, m_NonNegative()))
+ match(C, m_NonNegative())) {
+ // We matched a non-negative constant, so propagating undef is unsafe.
+ // Clamp undef elements to 0.
+ Type *EltTy = C->getType()->getScalarType();
+ C = Constant::replaceUndefsWith(C, ConstantInt::getNullValue(EltTy));
return BinaryOperator::CreateAShr(ConstantExpr::getNot(C), Y);
+ }
// ~(X + C) --> -(C + 1) - X
if (match(Op0, m_Add(m_Value(X), m_Constant(C))))
@@ -3114,10 +3262,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
if (match(Op0, m_Or(m_Value(X), m_APInt(C))) &&
MaskedValueIsZero(X, *C, 0, &I)) {
Constant *NewC = ConstantInt::get(I.getType(), *C ^ *RHSC);
- Worklist.Add(cast<Instruction>(Op0));
- I.setOperand(0, X);
- I.setOperand(1, NewC);
- return &I;
+ return BinaryOperator::CreateXor(X, NewC);
}
}
}