diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp | 135 |
1 files changed, 102 insertions, 33 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 351fc3b0174f..7f2018b3a199 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -411,6 +411,14 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) { getComplexity(I.getOperand(1))) Changed = !I.swapOperands(); + if (I.isCommutative()) { + if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) { + replaceOperand(I, 0, Pair->first); + replaceOperand(I, 1, Pair->second); + Changed = true; + } + } + BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0)); BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1)); @@ -1096,8 +1104,8 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) { return SimplifySelectsFeedingBinaryOp(I, LHS, RHS); } -std::optional<std::pair<Value *, Value *>> -InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) { +static std::optional<std::pair<Value *, Value *>> +matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) { if (LHS->getParent() != RHS->getParent()) return std::nullopt; @@ -1123,25 +1131,41 @@ InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) { return std::optional(std::pair(L0, R0)); } -Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, - Value *Op0, - Value *Op1) { - assert(I.isCommutative() && "Instruction should be commutative"); - - PHINode *LHS = dyn_cast<PHINode>(Op0); - PHINode *RHS = dyn_cast<PHINode>(Op1); - - if (!LHS || !RHS) - return nullptr; - - if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) { - Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second); - if (auto *BO = dyn_cast<BinaryOperator>(BI)) - BO->copyIRFlags(&I); - return BI; +std::optional<std::pair<Value *, Value *>> +InstCombinerImpl::matchSymmetricPair(Value *LHS, Value *RHS) { + Instruction *LHSInst = dyn_cast<Instruction>(LHS); + Instruction *RHSInst = dyn_cast<Instruction>(RHS); + if (!LHSInst || !RHSInst || LHSInst->getOpcode() != RHSInst->getOpcode()) + return std::nullopt; + switch (LHSInst->getOpcode()) { + case Instruction::PHI: + return matchSymmetricPhiNodesPair(cast<PHINode>(LHS), cast<PHINode>(RHS)); + case Instruction::Select: { + Value *Cond = LHSInst->getOperand(0); + Value *TrueVal = LHSInst->getOperand(1); + Value *FalseVal = LHSInst->getOperand(2); + if (Cond == RHSInst->getOperand(0) && TrueVal == RHSInst->getOperand(2) && + FalseVal == RHSInst->getOperand(1)) + return std::pair(TrueVal, FalseVal); + return std::nullopt; + } + case Instruction::Call: { + // Match min(a, b) and max(a, b) + MinMaxIntrinsic *LHSMinMax = dyn_cast<MinMaxIntrinsic>(LHSInst); + MinMaxIntrinsic *RHSMinMax = dyn_cast<MinMaxIntrinsic>(RHSInst); + if (LHSMinMax && RHSMinMax && + LHSMinMax->getPredicate() == + ICmpInst::getSwappedPredicate(RHSMinMax->getPredicate()) && + ((LHSMinMax->getLHS() == RHSMinMax->getLHS() && + LHSMinMax->getRHS() == RHSMinMax->getRHS()) || + (LHSMinMax->getLHS() == RHSMinMax->getRHS() && + LHSMinMax->getRHS() == RHSMinMax->getLHS()))) + return std::pair(LHSMinMax->getLHS(), LHSMinMax->getRHS()); + return std::nullopt; + } + default: + return std::nullopt; } - - return nullptr; } Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, @@ -1187,14 +1211,6 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, }; if (LHSIsSelect && RHSIsSelect && A == D) { - // op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y) - if (I.isCommutative() && B == F && C == E) { - Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E); - if (auto *BO = dyn_cast<BinaryOperator>(BI)) - BO->copyIRFlags(&I); - return BI; - } - // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F) Cond = A; True = simplifyBinOp(Opcode, B, E, FMF, Q); @@ -1577,11 +1593,6 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) { BO.getParent() != Phi1->getParent()) return nullptr; - if (BO.isCommutative()) { - if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1)) - return replaceInstUsesWith(BO, V); - } - // Fold if there is at least one specific constant value in phi0 or phi1's // incoming values that comes from the same block and this specific constant // value can be used to do optimization for specific binary operator. @@ -3197,6 +3208,64 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { return replaceOperand(SI, 0, Op0); } + ConstantInt *SubLHS; + if (match(Cond, m_Sub(m_ConstantInt(SubLHS), m_Value(Op0)))) { + // Change 'switch (1-X) case 1:' into 'switch (X) case 0'. + for (auto Case : SI.cases()) { + Constant *NewCase = ConstantExpr::getSub(SubLHS, Case.getCaseValue()); + assert(isa<ConstantInt>(NewCase) && + "Result of expression should be constant"); + Case.setValue(cast<ConstantInt>(NewCase)); + } + return replaceOperand(SI, 0, Op0); + } + + uint64_t ShiftAmt; + if (match(Cond, m_Shl(m_Value(Op0), m_ConstantInt(ShiftAmt))) && + ShiftAmt < Op0->getType()->getScalarSizeInBits() && + all_of(SI.cases(), [&](const auto &Case) { + return Case.getCaseValue()->getValue().countr_zero() >= ShiftAmt; + })) { + // Change 'switch (X << 2) case 4:' into 'switch (X) case 1:'. + OverflowingBinaryOperator *Shl = cast<OverflowingBinaryOperator>(Cond); + if (Shl->hasNoUnsignedWrap() || Shl->hasNoSignedWrap() || + Shl->hasOneUse()) { + Value *NewCond = Op0; + if (!Shl->hasNoUnsignedWrap() && !Shl->hasNoSignedWrap()) { + // If the shift may wrap, we need to mask off the shifted bits. + unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); + NewCond = Builder.CreateAnd( + Op0, APInt::getLowBitsSet(BitWidth, BitWidth - ShiftAmt)); + } + for (auto Case : SI.cases()) { + const APInt &CaseVal = Case.getCaseValue()->getValue(); + APInt ShiftedCase = Shl->hasNoSignedWrap() ? CaseVal.ashr(ShiftAmt) + : CaseVal.lshr(ShiftAmt); + Case.setValue(ConstantInt::get(SI.getContext(), ShiftedCase)); + } + return replaceOperand(SI, 0, NewCond); + } + } + + // Fold switch(zext/sext(X)) into switch(X) if possible. + if (match(Cond, m_ZExtOrSExt(m_Value(Op0)))) { + bool IsZExt = isa<ZExtInst>(Cond); + Type *SrcTy = Op0->getType(); + unsigned NewWidth = SrcTy->getScalarSizeInBits(); + + if (all_of(SI.cases(), [&](const auto &Case) { + const APInt &CaseVal = Case.getCaseValue()->getValue(); + return IsZExt ? CaseVal.isIntN(NewWidth) + : CaseVal.isSignedIntN(NewWidth); + })) { + for (auto &Case : SI.cases()) { + APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth); + Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); + } + return replaceOperand(SI, 0, Op0); + } + } + KnownBits Known = computeKnownBits(Cond, 0, &SI); unsigned LeadingKnownZeros = Known.countMinLeadingZeros(); unsigned LeadingKnownOnes = Known.countMinLeadingOnes(); |
