aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp59
1 files changed, 49 insertions, 10 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 1539fa9a3269..3b7fe7fa2266 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -357,9 +357,9 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
// Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
- APInt UndefElts(DemandedElts.getBitWidth(), 0);
- if (Value *V =
- SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts))
+ APInt PoisonElts(DemandedElts.getBitWidth(), 0);
+ if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts,
+ PoisonElts))
return replaceOperand(II, 0, V);
return nullptr;
@@ -439,12 +439,12 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
// Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
- APInt UndefElts(DemandedElts.getBitWidth(), 0);
- if (Value *V =
- SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts))
+ APInt PoisonElts(DemandedElts.getBitWidth(), 0);
+ if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts,
+ PoisonElts))
return replaceOperand(II, 0, V);
- if (Value *V =
- SimplifyDemandedVectorElts(II.getOperand(1), DemandedElts, UndefElts))
+ if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1), DemandedElts,
+ PoisonElts))
return replaceOperand(II, 1, V);
return nullptr;
@@ -1526,9 +1526,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// support.
if (auto *IIFVTy = dyn_cast<FixedVectorType>(II->getType())) {
auto VWidth = IIFVTy->getNumElements();
- APInt UndefElts(VWidth, 0);
+ APInt PoisonElts(VWidth, 0);
APInt AllOnesEltMask(APInt::getAllOnes(VWidth));
- if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) {
+ if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, PoisonElts)) {
if (V != II)
return replaceInstUsesWith(*II, V);
return II;
@@ -1539,6 +1539,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
return I;
+ if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
+ return I;
+
if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
return NewCall;
}
@@ -1793,6 +1796,23 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *NewMinMax = factorizeMinMaxTree(II))
return NewMinMax;
+ // Try to fold minmax with constant RHS based on range information
+ const APInt *RHSC;
+ if (match(I1, m_APIntAllowUndef(RHSC))) {
+ ICmpInst::Predicate Pred =
+ ICmpInst::getNonStrictPredicate(MinMaxIntrinsic::getPredicate(IID));
+ bool IsSigned = MinMaxIntrinsic::isSigned(IID);
+ ConstantRange LHS_CR = computeConstantRangeIncludingKnownBits(
+ I0, IsSigned, SQ.getWithInstruction(II));
+ if (!LHS_CR.isFullSet()) {
+ if (LHS_CR.icmp(Pred, *RHSC))
+ return replaceInstUsesWith(*II, I0);
+ if (LHS_CR.icmp(ICmpInst::getSwappedPredicate(Pred), *RHSC))
+ return replaceInstUsesWith(*II,
+ ConstantInt::get(II->getType(), *RHSC));
+ }
+ }
+
break;
}
case Intrinsic::bitreverse: {
@@ -4237,3 +4257,22 @@ InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
return nullptr;
}
+
+Instruction *
+InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
+ assert(II.isCommutative() && "Instruction should be commutative");
+
+ PHINode *LHS = dyn_cast<PHINode>(II.getOperand(0));
+ PHINode *RHS = dyn_cast<PHINode>(II.getOperand(1));
+
+ if (!LHS || !RHS)
+ return nullptr;
+
+ if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
+ replaceOperand(II, 0, P->first);
+ replaceOperand(II, 1, P->second);
+ return &II;
+ }
+
+ return nullptr;
+}