aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp135
1 files changed, 102 insertions, 33 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 351fc3b0174f..7f2018b3a199 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -411,6 +411,14 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
getComplexity(I.getOperand(1)))
Changed = !I.swapOperands();
+ if (I.isCommutative()) {
+ if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) {
+ replaceOperand(I, 0, Pair->first);
+ replaceOperand(I, 1, Pair->second);
+ Changed = true;
+ }
+ }
+
BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));
BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1));
@@ -1096,8 +1104,8 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
}
-std::optional<std::pair<Value *, Value *>>
-InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
+static std::optional<std::pair<Value *, Value *>>
+matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
if (LHS->getParent() != RHS->getParent())
return std::nullopt;
@@ -1123,25 +1131,41 @@ InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
return std::optional(std::pair(L0, R0));
}
-Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
- Value *Op0,
- Value *Op1) {
- assert(I.isCommutative() && "Instruction should be commutative");
-
- PHINode *LHS = dyn_cast<PHINode>(Op0);
- PHINode *RHS = dyn_cast<PHINode>(Op1);
-
- if (!LHS || !RHS)
- return nullptr;
-
- if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
- Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second);
- if (auto *BO = dyn_cast<BinaryOperator>(BI))
- BO->copyIRFlags(&I);
- return BI;
+std::optional<std::pair<Value *, Value *>>
+InstCombinerImpl::matchSymmetricPair(Value *LHS, Value *RHS) {
+ Instruction *LHSInst = dyn_cast<Instruction>(LHS);
+ Instruction *RHSInst = dyn_cast<Instruction>(RHS);
+ if (!LHSInst || !RHSInst || LHSInst->getOpcode() != RHSInst->getOpcode())
+ return std::nullopt;
+ switch (LHSInst->getOpcode()) {
+ case Instruction::PHI:
+ return matchSymmetricPhiNodesPair(cast<PHINode>(LHS), cast<PHINode>(RHS));
+ case Instruction::Select: {
+ Value *Cond = LHSInst->getOperand(0);
+ Value *TrueVal = LHSInst->getOperand(1);
+ Value *FalseVal = LHSInst->getOperand(2);
+ if (Cond == RHSInst->getOperand(0) && TrueVal == RHSInst->getOperand(2) &&
+ FalseVal == RHSInst->getOperand(1))
+ return std::pair(TrueVal, FalseVal);
+ return std::nullopt;
+ }
+ case Instruction::Call: {
+ // Match min(a, b) and max(a, b)
+ MinMaxIntrinsic *LHSMinMax = dyn_cast<MinMaxIntrinsic>(LHSInst);
+ MinMaxIntrinsic *RHSMinMax = dyn_cast<MinMaxIntrinsic>(RHSInst);
+ if (LHSMinMax && RHSMinMax &&
+ LHSMinMax->getPredicate() ==
+ ICmpInst::getSwappedPredicate(RHSMinMax->getPredicate()) &&
+ ((LHSMinMax->getLHS() == RHSMinMax->getLHS() &&
+ LHSMinMax->getRHS() == RHSMinMax->getRHS()) ||
+ (LHSMinMax->getLHS() == RHSMinMax->getRHS() &&
+ LHSMinMax->getRHS() == RHSMinMax->getLHS())))
+ return std::pair(LHSMinMax->getLHS(), LHSMinMax->getRHS());
+ return std::nullopt;
+ }
+ default:
+ return std::nullopt;
}
-
- return nullptr;
}
Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
@@ -1187,14 +1211,6 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
};
if (LHSIsSelect && RHSIsSelect && A == D) {
- // op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
- if (I.isCommutative() && B == F && C == E) {
- Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E);
- if (auto *BO = dyn_cast<BinaryOperator>(BI))
- BO->copyIRFlags(&I);
- return BI;
- }
-
// (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
Cond = A;
True = simplifyBinOp(Opcode, B, E, FMF, Q);
@@ -1577,11 +1593,6 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
BO.getParent() != Phi1->getParent())
return nullptr;
- if (BO.isCommutative()) {
- if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1))
- return replaceInstUsesWith(BO, V);
- }
-
// Fold if there is at least one specific constant value in phi0 or phi1's
// incoming values that comes from the same block and this specific constant
// value can be used to do optimization for specific binary operator.
@@ -3197,6 +3208,64 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
return replaceOperand(SI, 0, Op0);
}
+ ConstantInt *SubLHS;
+ if (match(Cond, m_Sub(m_ConstantInt(SubLHS), m_Value(Op0)))) {
+ // Change 'switch (1-X) case 1:' into 'switch (X) case 0'.
+ for (auto Case : SI.cases()) {
+ Constant *NewCase = ConstantExpr::getSub(SubLHS, Case.getCaseValue());
+ assert(isa<ConstantInt>(NewCase) &&
+ "Result of expression should be constant");
+ Case.setValue(cast<ConstantInt>(NewCase));
+ }
+ return replaceOperand(SI, 0, Op0);
+ }
+
+ uint64_t ShiftAmt;
+ if (match(Cond, m_Shl(m_Value(Op0), m_ConstantInt(ShiftAmt))) &&
+ ShiftAmt < Op0->getType()->getScalarSizeInBits() &&
+ all_of(SI.cases(), [&](const auto &Case) {
+ return Case.getCaseValue()->getValue().countr_zero() >= ShiftAmt;
+ })) {
+ // Change 'switch (X << 2) case 4:' into 'switch (X) case 1:'.
+ OverflowingBinaryOperator *Shl = cast<OverflowingBinaryOperator>(Cond);
+ if (Shl->hasNoUnsignedWrap() || Shl->hasNoSignedWrap() ||
+ Shl->hasOneUse()) {
+ Value *NewCond = Op0;
+ if (!Shl->hasNoUnsignedWrap() && !Shl->hasNoSignedWrap()) {
+ // If the shift may wrap, we need to mask off the shifted bits.
+ unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
+ NewCond = Builder.CreateAnd(
+ Op0, APInt::getLowBitsSet(BitWidth, BitWidth - ShiftAmt));
+ }
+ for (auto Case : SI.cases()) {
+ const APInt &CaseVal = Case.getCaseValue()->getValue();
+ APInt ShiftedCase = Shl->hasNoSignedWrap() ? CaseVal.ashr(ShiftAmt)
+ : CaseVal.lshr(ShiftAmt);
+ Case.setValue(ConstantInt::get(SI.getContext(), ShiftedCase));
+ }
+ return replaceOperand(SI, 0, NewCond);
+ }
+ }
+
+ // Fold switch(zext/sext(X)) into switch(X) if possible.
+ if (match(Cond, m_ZExtOrSExt(m_Value(Op0)))) {
+ bool IsZExt = isa<ZExtInst>(Cond);
+ Type *SrcTy = Op0->getType();
+ unsigned NewWidth = SrcTy->getScalarSizeInBits();
+
+ if (all_of(SI.cases(), [&](const auto &Case) {
+ const APInt &CaseVal = Case.getCaseValue()->getValue();
+ return IsZExt ? CaseVal.isIntN(NewWidth)
+ : CaseVal.isSignedIntN(NewWidth);
+ })) {
+ for (auto &Case : SI.cases()) {
+ APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth);
+ Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase));
+ }
+ return replaceOperand(SI, 0, Op0);
+ }
+ }
+
KnownBits Known = computeKnownBits(Cond, 0, &SI);
unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
unsigned LeadingKnownOnes = Known.countMinLeadingOnes();