Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp            | 121
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp          |  17
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp             | 175
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp             | 249
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp          |  70
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineInternal.h            |   9
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp   |  50
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp         |  14
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp               |  12
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp            | 181
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp            | 104
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp  |  99
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp         |  95
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstructionCombining.cpp         | 201
14 files changed, 1088 insertions, 309 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 8bc34825f8a7..ec976a971e3c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -890,6 +890,10 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) {
if (match(Op0, m_ZExt(m_Value(X))) &&
X->getType()->getScalarSizeInBits() == 1)
return SelectInst::Create(X, AddOne(Op1C), Op1);
+ // sext(bool) + C -> bool ? C - 1 : C
+ if (match(Op0, m_SExt(m_Value(X))) &&
+ X->getType()->getScalarSizeInBits() == 1)
+ return SelectInst::Create(X, SubOne(Op1C), Op1);
// ~X + C --> (C-1) - X
if (match(Op0, m_Not(m_Value(X))))
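As a quick illustration of the new fold (a sketch, not part of the patch; names are hypothetical), opt -instcombine should now turn an add of a sign-extended bool and a constant into a select, mirroring the zext case directly above:

    define i32 @sext_bool_add(i1 %b) {
      %s = sext i1 %b to i32        ; -1 when %b is true, 0 when false
      %r = add i32 %s, 42
      ret i32 %r
    }
    ; expected after the fold: %r = select i1 %b, i32 41, i32 42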
@@ -1288,12 +1292,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
return BinaryOperator::CreateSub(RHS, A);
}
- // Canonicalize sext to zext for better value tracking potential.
- // add A, sext(B) --> sub A, zext(B)
- if (match(&I, m_c_Add(m_Value(A), m_OneUse(m_SExt(m_Value(B))))) &&
- B->getType()->isIntOrIntVectorTy(1))
- return BinaryOperator::CreateSub(A, Builder.CreateZExt(B, Ty));
-
// A + -B --> A - B
if (match(RHS, m_Neg(m_Value(B))))
return BinaryOperator::CreateSub(LHS, B);
@@ -1587,7 +1585,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
/// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer
/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
- Type *Ty) {
+ Type *Ty, bool IsNUW) {
// If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
// this.
bool Swapped = false;
@@ -1655,6 +1653,15 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
// Emit the offset of the GEP as an intptr_t.
Value *Result = EmitGEPOffset(GEP1);
+ // If this is a single inbounds GEP and the original sub was nuw,
+ // then the final multiplication is also nuw. We match an extra add zero
+ // here, because that's what EmitGEPOffset() generates.
+ Instruction *I;
+ if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() &&
+ match(Result, m_Add(m_Instruction(I), m_Zero())) &&
+ I->getOpcode() == Instruction::Mul)
+ I->setHasNoUnsignedWrap();
+
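A rough sketch of the case this targets, assuming a single inbounds GEP whose offset is emitted as a multiply (hypothetical names):

    define i64 @gep_diff(i32* %base, i64 %idx) {
      %p  = getelementptr inbounds i32, i32* %base, i64 %idx
      %pi = ptrtoint i32* %p to i64
      %bi = ptrtoint i32* %base to i64
      %d  = sub nuw i64 %pi, %bi
      ret i64 %d
    }
    ; EmitGEPOffset produces roughly (mul i64 %idx, 4) plus a trailing add of zero;
    ; with this change that multiply can now carry the nuw flag as well.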
// If we had a constant expression GEP on the other side offsetting the
// pointer, subtract it from the offset we have.
if (GEP2) {
@@ -1881,6 +1888,74 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
Y, Builder.CreateNot(Op1, Op1->getName() + ".not"));
}
+ {
+ // (sub (and Op1, (neg X)), Op1) --> neg (and Op1, (add X, -1))
+ Value *X;
+ if (match(Op0, m_OneUse(m_c_And(m_Specific(Op1),
+ m_OneUse(m_Neg(m_Value(X))))))) {
+ return BinaryOperator::CreateNeg(Builder.CreateAnd(
+ Op1, Builder.CreateAdd(X, Constant::getAllOnesValue(I.getType()))));
+ }
+ }
+
+ {
+ // (sub (and Op1, C), Op1) --> neg (and Op1, ~C)
+ Constant *C;
+ if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_Constant(C))))) {
+ return BinaryOperator::CreateNeg(
+ Builder.CreateAnd(Op1, Builder.CreateNot(C)));
+ }
+ }
+
+ {
+ // If we have a subtraction between some value and a select between
+ // said value and something else, sink subtraction into select hands, i.e.:
+ // sub (select %Cond, %TrueVal, %FalseVal), %Op1
+ // ->
+ // select %Cond, (sub %TrueVal, %Op1), (sub %FalseVal, %Op1)
+ // or
+ // sub %Op0, (select %Cond, %TrueVal, %FalseVal)
+ // ->
+ // select %Cond, (sub %Op0, %TrueVal), (sub %Op0, %FalseVal)
+  // This will result in a select between the new subtraction and 0.
+ auto SinkSubIntoSelect =
+ [Ty = I.getType()](Value *Select, Value *OtherHandOfSub,
+ auto SubBuilder) -> Instruction * {
+ Value *Cond, *TrueVal, *FalseVal;
+ if (!match(Select, m_OneUse(m_Select(m_Value(Cond), m_Value(TrueVal),
+ m_Value(FalseVal)))))
+ return nullptr;
+ if (OtherHandOfSub != TrueVal && OtherHandOfSub != FalseVal)
+ return nullptr;
+ // While it is really tempting to just create two subtractions and let
+ // InstCombine fold one of those to 0, it isn't possible to do so
+ // because of worklist visitation order. So ugly it is.
+ bool OtherHandOfSubIsTrueVal = OtherHandOfSub == TrueVal;
+ Value *NewSub = SubBuilder(OtherHandOfSubIsTrueVal ? FalseVal : TrueVal);
+ Constant *Zero = Constant::getNullValue(Ty);
+ SelectInst *NewSel =
+ SelectInst::Create(Cond, OtherHandOfSubIsTrueVal ? Zero : NewSub,
+ OtherHandOfSubIsTrueVal ? NewSub : Zero);
+ // Preserve prof metadata if any.
+ NewSel->copyMetadata(cast<Instruction>(*Select));
+ return NewSel;
+ };
+ if (Instruction *NewSel = SinkSubIntoSelect(
+ /*Select=*/Op0, /*OtherHandOfSub=*/Op1,
+ [Builder = &Builder, Op1](Value *OtherHandOfSelect) {
+ return Builder->CreateSub(OtherHandOfSelect,
+ /*OtherHandOfSub=*/Op1);
+ }))
+ return NewSel;
+ if (Instruction *NewSel = SinkSubIntoSelect(
+ /*Select=*/Op1, /*OtherHandOfSub=*/Op0,
+ [Builder = &Builder, Op0](Value *OtherHandOfSelect) {
+ return Builder->CreateSub(/*OtherHandOfSub=*/Op0,
+ OtherHandOfSelect);
+ }))
+ return NewSel;
+ }
+
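A minimal sketch of the sinking transform described in the comment above (hypothetical values):

    define i32 @sink_sub(i1 %c, i32 %x, i32 %y) {
      %sel = select i1 %c, i32 %x, i32 %y
      %r   = sub i32 %sel, %y
      ret i32 %r
    }
    ; expected after the fold:
    ;   %sub = sub i32 %x, %y
    ;   %r   = select i1 %c, i32 %sub, i32 0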
if (Op1->hasOneUse()) {
Value *X = nullptr, *Y = nullptr, *Z = nullptr;
Constant *C = nullptr;
@@ -1896,14 +1971,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
Builder.CreateNot(Y, Y->getName() + ".not"));
// 0 - (X sdiv C) -> (X sdiv -C) provided the negation doesn't overflow.
- // TODO: This could be extended to match arbitrary vector constants.
- const APInt *DivC;
- if (match(Op0, m_Zero()) && match(Op1, m_SDiv(m_Value(X), m_APInt(DivC))) &&
- !DivC->isMinSignedValue() && *DivC != 1) {
- Constant *NegDivC = ConstantInt::get(I.getType(), -(*DivC));
- Instruction *BO = BinaryOperator::CreateSDiv(X, NegDivC);
- BO->setIsExact(cast<BinaryOperator>(Op1)->isExact());
- return BO;
+ if (match(Op0, m_Zero())) {
+ Constant *Op11C;
+ if (match(Op1, m_SDiv(m_Value(X), m_Constant(Op11C))) &&
+ !Op11C->containsUndefElement() && Op11C->isNotMinSignedValue() &&
+ Op11C->isNotOneValue()) {
+ Instruction *BO =
+ BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(Op11C));
+ BO->setIsExact(cast<BinaryOperator>(Op1)->isExact());
+ return BO;
+ }
}
// 0 - (X << Y) -> (-X << Y) when X is freely negatable.
@@ -1921,6 +1998,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
Add->setHasNoSignedWrap(I.hasNoSignedWrap());
return Add;
}
+ // sub [nsw] X, zext(bool Y) -> add [nsw] X, sext(bool Y)
+ // 'nuw' is dropped in favor of the canonical form.
+ if (match(Op1, m_ZExt(m_Value(Y))) && Y->getType()->isIntOrIntVectorTy(1)) {
+ Value *Sext = Builder.CreateSExt(Y, I.getType());
+ BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Sext);
+ Add->setHasNoSignedWrap(I.hasNoSignedWrap());
+ return Add;
+ }
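A small before/after sketch of this canonicalization (hypothetical names):

    define i32 @sub_zext_bool(i32 %x, i1 %b) {
      %z = zext i1 %b to i32
      %r = sub nsw i32 %x, %z
      ret i32 %r
    }
    ; expected after the fold:
    ;   %s = sext i1 %b to i32
    ;   %r = add nsw i32 %x, %s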
// X - A*-B -> X + A*B
// X - -A*B -> X + A*B
@@ -1975,13 +2060,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
Value *LHSOp, *RHSOp;
if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
match(Op1, m_PtrToInt(m_Value(RHSOp))))
- if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
+ if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(),
+ I.hasNoUnsignedWrap()))
return replaceInstUsesWith(I, Res);
// trunc(p)-trunc(q) -> trunc(p-q)
if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
- if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
+ if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(),
+ /* IsNUW */ false))
return replaceInstUsesWith(I, Res);
// Canonicalize a shifty way to code absolute value to the common pattern.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 4a30b60ca931..cc0a9127f8b1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3279,6 +3279,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
NotLHS, NotRHS);
}
}
+
+ // Pull 'not' into operands of select if both operands are one-use compares.
+ // Inverting the predicates eliminates the 'not' operation.
+ // Example:
+  // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?)) -->
+ // select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?)
+ // TODO: Canonicalize by hoisting 'not' into an arm of the select if only
+ // 1 select operand is a cmp?
+ if (auto *Sel = dyn_cast<SelectInst>(Op0)) {
+ auto *CmpT = dyn_cast<CmpInst>(Sel->getTrueValue());
+ auto *CmpF = dyn_cast<CmpInst>(Sel->getFalseValue());
+ if (CmpT && CmpF && CmpT->hasOneUse() && CmpF->hasOneUse()) {
+ CmpT->setPredicate(CmpT->getInversePredicate());
+ CmpF->setPredicate(CmpF->getInversePredicate());
+ return replaceInstUsesWith(I, Sel);
+ }
+ }
}
if (Instruction *NewXor = sinkNotIntoXor(I, Builder))
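Roughly, the new fold rewrites a 'not' of a select of two single-use compares into a select of the inverted compares (a sketch with hypothetical operands):

    define i1 @not_of_select_of_cmps(i1 %p, i32 %a, i32 %b, i32 %c, i32 %d) {
      %c1 = icmp slt i32 %a, %b
      %c2 = icmp ugt i32 %c, %d
      %s  = select i1 %p, i1 %c1, i1 %c2
      %n  = xor i1 %s, true
      ret i1 %n
    }
    ; expected after the fold:
    ;   %c1 = icmp sge i32 %a, %b
    ;   %c2 = icmp ule i32 %c, %d
    ;   %n  = select i1 %p, i1 %c1, i1 %c2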
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index c650d242cd50..f463c5fa1138 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -40,6 +40,12 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+#include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/IR/IntrinsicsARM.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
+#include "llvm/IR/IntrinsicsPowerPC.h"
+#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PatternMatch.h"
@@ -2279,6 +2285,35 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
+ case Intrinsic::copysign: {
+ if (SignBitMustBeZero(II->getArgOperand(1), &TLI)) {
+ // If we know that the sign argument is positive, reduce to FABS:
+ // copysign X, Pos --> fabs X
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs,
+ II->getArgOperand(0), II);
+ return replaceInstUsesWith(*II, Fabs);
+ }
+ // TODO: There should be a ValueTracking sibling like SignBitMustBeOne.
+ const APFloat *C;
+ if (match(II->getArgOperand(1), m_APFloat(C)) && C->isNegative()) {
+ // If we know that the sign argument is negative, reduce to FNABS:
+ // copysign X, Neg --> fneg (fabs X)
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs,
+ II->getArgOperand(0), II);
+ return replaceInstUsesWith(*II, Builder.CreateFNegFMF(Fabs, II));
+ }
+
+ // Propagate sign argument through nested calls:
+ // copysign X, (copysign ?, SignArg) --> copysign X, SignArg
+ Value *SignArg;
+ if (match(II->getArgOperand(1),
+ m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(SignArg)))) {
+ II->setArgOperand(1, SignArg);
+ return II;
+ }
+
+ break;
+ }
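Illustrative sketches of the three copysign folds (IR fragments, assuming float operands; names are hypothetical):

    %a = call float @llvm.copysign.f32(float %x, float 2.000000e+00)
    ; sign argument known non-negative  -->  call float @llvm.fabs.f32(float %x)
    %b = call float @llvm.copysign.f32(float %x, float -3.000000e+00)
    ; sign argument known negative      -->  fneg of the fabs of %x
    %inner = call float @llvm.copysign.f32(float %y, float %sign)
    %c = call float @llvm.copysign.f32(float %x, float %inner)
    ; nested sign argument              -->  copysign(%x, %sign)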
case Intrinsic::fabs: {
Value *Cond;
Constant *LHS, *RHS;
@@ -2452,6 +2487,64 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// TODO should we convert this to an AND if the RHS is constant?
}
break;
+ case Intrinsic::x86_bmi_pext_32:
+ case Intrinsic::x86_bmi_pext_64:
+ if (auto *MaskC = dyn_cast<ConstantInt>(II->getArgOperand(1))) {
+ if (MaskC->isNullValue())
+ return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0));
+ if (MaskC->isAllOnesValue())
+ return replaceInstUsesWith(CI, II->getArgOperand(0));
+
+ if (auto *SrcC = dyn_cast<ConstantInt>(II->getArgOperand(0))) {
+ uint64_t Src = SrcC->getZExtValue();
+ uint64_t Mask = MaskC->getZExtValue();
+ uint64_t Result = 0;
+ uint64_t BitToSet = 1;
+
+ while (Mask) {
+ // Isolate lowest set bit.
+ uint64_t BitToTest = Mask & -Mask;
+ if (BitToTest & Src)
+ Result |= BitToSet;
+
+ BitToSet <<= 1;
+ // Clear lowest set bit.
+ Mask &= Mask - 1;
+ }
+
+ return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result));
+ }
+ }
+ break;
+ case Intrinsic::x86_bmi_pdep_32:
+ case Intrinsic::x86_bmi_pdep_64:
+ if (auto *MaskC = dyn_cast<ConstantInt>(II->getArgOperand(1))) {
+ if (MaskC->isNullValue())
+ return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0));
+ if (MaskC->isAllOnesValue())
+ return replaceInstUsesWith(CI, II->getArgOperand(0));
+
+ if (auto *SrcC = dyn_cast<ConstantInt>(II->getArgOperand(0))) {
+ uint64_t Src = SrcC->getZExtValue();
+ uint64_t Mask = MaskC->getZExtValue();
+ uint64_t Result = 0;
+ uint64_t BitToTest = 1;
+
+ while (Mask) {
+ // Isolate lowest set bit.
+ uint64_t BitToSet = Mask & -Mask;
+ if (BitToTest & Src)
+ Result |= BitToSet;
+
+ BitToTest <<= 1;
+          // Clear lowest set bit.
+ Mask &= Mask - 1;
+ }
+
+ return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result));
+ }
+ }
+ break;
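A worked example of the new constant folding for the 32-bit variants (a sketch; the binary comments spell out the bit movement):

    %r1 = call i32 @llvm.x86.bmi.pext.32(i32 6, i32 10)
    ; src = 0b0110, mask = 0b1010: the bits of src at mask positions 1 and 3
    ; (1 and 0) are packed low-first, so this folds to i32 1
    %r2 = call i32 @llvm.x86.bmi.pdep.32(i32 3, i32 10)
    ; src = 0b0011, mask = 0b1010: the low bits of src are scattered to mask
    ; positions 1 and 3, so this folds to i32 10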
case Intrinsic::x86_vcvtph2ps_128:
case Intrinsic::x86_vcvtph2ps_256: {
@@ -3308,6 +3401,60 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
}
break;
}
+ case Intrinsic::arm_mve_pred_i2v: {
+ Value *Arg = II->getArgOperand(0);
+ Value *ArgArg;
+ if (match(Arg, m_Intrinsic<Intrinsic::arm_mve_pred_v2i>(m_Value(ArgArg))) &&
+ II->getType() == ArgArg->getType())
+ return replaceInstUsesWith(*II, ArgArg);
+ Constant *XorMask;
+ if (match(Arg,
+ m_Xor(m_Intrinsic<Intrinsic::arm_mve_pred_v2i>(m_Value(ArgArg)),
+ m_Constant(XorMask))) &&
+ II->getType() == ArgArg->getType()) {
+ if (auto *CI = dyn_cast<ConstantInt>(XorMask)) {
+ if (CI->getValue().trunc(16).isAllOnesValue()) {
+ auto TrueVector = Builder.CreateVectorSplat(
+ II->getType()->getVectorNumElements(), Builder.getTrue());
+ return BinaryOperator::Create(Instruction::Xor, ArgArg, TrueVector);
+ }
+ }
+ }
+ KnownBits ScalarKnown(32);
+ if (SimplifyDemandedBits(II, 0, APInt::getLowBitsSet(32, 16),
+ ScalarKnown, 0))
+ return II;
+ break;
+ }
+ case Intrinsic::arm_mve_pred_v2i: {
+ Value *Arg = II->getArgOperand(0);
+ Value *ArgArg;
+ if (match(Arg, m_Intrinsic<Intrinsic::arm_mve_pred_i2v>(m_Value(ArgArg))))
+ return replaceInstUsesWith(*II, ArgArg);
+ if (!II->getMetadata(LLVMContext::MD_range)) {
+ Type *IntTy32 = Type::getInt32Ty(II->getContext());
+ Metadata *M[] = {
+ ConstantAsMetadata::get(ConstantInt::get(IntTy32, 0)),
+ ConstantAsMetadata::get(ConstantInt::get(IntTy32, 0xFFFF))
+ };
+ II->setMetadata(LLVMContext::MD_range, MDNode::get(II->getContext(), M));
+ return II;
+ }
+ break;
+ }
+ case Intrinsic::arm_mve_vadc:
+ case Intrinsic::arm_mve_vadc_predicated: {
+ unsigned CarryOp =
+ (II->getIntrinsicID() == Intrinsic::arm_mve_vadc_predicated) ? 3 : 2;
+ assert(II->getArgOperand(CarryOp)->getType()->getScalarSizeInBits() == 32 &&
+ "Bad type for intrinsic!");
+
+ KnownBits CarryKnown(32);
+ if (SimplifyDemandedBits(II, CarryOp, APInt::getOneBitSet(32, 29),
+ CarryKnown))
+ return II;
+ break;
+ }
case Intrinsic::amdgcn_rcp: {
Value *Src = II->getArgOperand(0);
@@ -3317,7 +3464,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
const APFloat &ArgVal = C->getValueAPF();
- APFloat Val(ArgVal.getSemantics(), 1.0);
+ APFloat Val(ArgVal.getSemantics(), 1);
APFloat::opStatus Status = Val.divide(ArgVal,
APFloat::rmNearestTiesToEven);
// Only do this if it was exact and therefore not dependent on the
@@ -3872,7 +4019,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return eraseInstFromFunction(CI);
// Bail if we cross over an intrinsic with side effects, such as
- // llvm.stacksave, llvm.read_register, or llvm.setjmp.
+  // llvm.stacksave or llvm.read_register.
if (II2->mayHaveSideEffects()) {
CannotRemove = true;
break;
@@ -4019,12 +4166,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// Is this guard followed by another guard? We scan forward over a small
// fixed window of instructions to handle common cases with conditions
// computed between guards.
- Instruction *NextInst = II->getNextNode();
+ Instruction *NextInst = II->getNextNonDebugInstruction();
for (unsigned i = 0; i < GuardWideningWindow; i++) {
// Note: Using context-free form to avoid compile time blow up
if (!isSafeToSpeculativelyExecute(NextInst))
break;
- NextInst = NextInst->getNextNode();
+ NextInst = NextInst->getNextNonDebugInstruction();
}
Value *NextCond = nullptr;
if (match(NextInst,
@@ -4032,18 +4179,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
Value *CurrCond = II->getArgOperand(0);
// Remove a guard that is immediately preceded by an identical guard.
- if (CurrCond == NextCond)
- return eraseInstFromFunction(*NextInst);
-
// Otherwise canonicalize guard(a); guard(b) -> guard(a & b).
- Instruction* MoveI = II->getNextNode();
- while (MoveI != NextInst) {
- auto *Temp = MoveI;
- MoveI = MoveI->getNextNode();
- Temp->moveBefore(II);
+ if (CurrCond != NextCond) {
+ Instruction *MoveI = II->getNextNonDebugInstruction();
+ while (MoveI != NextInst) {
+ auto *Temp = MoveI;
+ MoveI = MoveI->getNextNonDebugInstruction();
+ Temp->moveBefore(II);
+ }
+ II->setArgOperand(0, Builder.CreateAnd(CurrCond, NextCond));
}
- II->setArgOperand(0, Builder.CreateAnd(CurrCond, NextCond));
- return eraseInstFromFunction(*NextInst);
+ eraseInstFromFunction(*NextInst);
+ return II;
}
break;
}
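A rough sketch of the guard merging performed here (hypothetical conditions); with the switch to getNextNonDebugInstruction, debug intrinsics between the two guards no longer block the scan:

    call void (i1, ...) @llvm.experimental.guard(i1 %a) [ "deopt"() ]
    call void (i1, ...) @llvm.experimental.guard(i1 %b) [ "deopt"() ]
    ; the second guard is erased; when %a and %b differ, the first is rewritten to:
    ;   %wide = and i1 %a, %b
    ;   call void (i1, ...) @llvm.experimental.guard(i1 %wide) [ "deopt"() ]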
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 65aaef28d87a..71b7f279e5fa 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -18,6 +18,7 @@
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
+#include <numeric>
using namespace llvm;
using namespace PatternMatch;
@@ -843,33 +844,33 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
return nullptr;
}
-Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
+Instruction *InstCombiner::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
bool DoTransform) {
// If we are just checking for a icmp eq of a single bit and zext'ing it
// to an integer, then shift the bit to the appropriate place and then
// cast to integer to avoid the comparison.
const APInt *Op1CV;
- if (match(ICI->getOperand(1), m_APInt(Op1CV))) {
+ if (match(Cmp->getOperand(1), m_APInt(Op1CV))) {
// zext (x <s 0) to i32 --> x>>u31 true if signbit set.
// zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear.
- if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) ||
- (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) {
- if (!DoTransform) return ICI;
+ if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) ||
+ (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) {
+ if (!DoTransform) return Cmp;
- Value *In = ICI->getOperand(0);
+ Value *In = Cmp->getOperand(0);
Value *Sh = ConstantInt::get(In->getType(),
In->getType()->getScalarSizeInBits() - 1);
In = Builder.CreateLShr(In, Sh, In->getName() + ".lobit");
- if (In->getType() != CI.getType())
- In = Builder.CreateIntCast(In, CI.getType(), false /*ZExt*/);
+ if (In->getType() != Zext.getType())
+ In = Builder.CreateIntCast(In, Zext.getType(), false /*ZExt*/);
- if (ICI->getPredicate() == ICmpInst::ICMP_SGT) {
+ if (Cmp->getPredicate() == ICmpInst::ICMP_SGT) {
Constant *One = ConstantInt::get(In->getType(), 1);
In = Builder.CreateXor(In, One, In->getName() + ".not");
}
- return replaceInstUsesWith(CI, In);
+ return replaceInstUsesWith(Zext, In);
}
// zext (X == 0) to i32 --> X^1 iff X has only the low bit set.
@@ -882,24 +883,24 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
// zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) &&
// This only works for EQ and NE
- ICI->isEquality()) {
+ Cmp->isEquality()) {
// If Op1C some other power of two, convert:
- KnownBits Known = computeKnownBits(ICI->getOperand(0), 0, &CI);
+ KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext);
APInt KnownZeroMask(~Known.Zero);
if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1?
- if (!DoTransform) return ICI;
+ if (!DoTransform) return Cmp;
- bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE;
+ bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE;
if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) {
// (X&4) == 2 --> false
// (X&4) != 2 --> true
- Constant *Res = ConstantInt::get(CI.getType(), isNE);
- return replaceInstUsesWith(CI, Res);
+ Constant *Res = ConstantInt::get(Zext.getType(), isNE);
+ return replaceInstUsesWith(Zext, Res);
}
uint32_t ShAmt = KnownZeroMask.logBase2();
- Value *In = ICI->getOperand(0);
+ Value *In = Cmp->getOperand(0);
if (ShAmt) {
// Perform a logical shr by shiftamt.
// Insert the shift to put the result in the low bit.
@@ -912,11 +913,11 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
In = Builder.CreateXor(In, One);
}
- if (CI.getType() == In->getType())
- return replaceInstUsesWith(CI, In);
+ if (Zext.getType() == In->getType())
+ return replaceInstUsesWith(Zext, In);
- Value *IntCast = Builder.CreateIntCast(In, CI.getType(), false);
- return replaceInstUsesWith(CI, IntCast);
+ Value *IntCast = Builder.CreateIntCast(In, Zext.getType(), false);
+ return replaceInstUsesWith(Zext, IntCast);
}
}
}
@@ -924,19 +925,19 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
// icmp ne A, B is equal to xor A, B when A and B only really have one bit.
// It is also profitable to transform icmp eq into not(xor(A, B)) because that
// may lead to additional simplifications.
- if (ICI->isEquality() && CI.getType() == ICI->getOperand(0)->getType()) {
- if (IntegerType *ITy = dyn_cast<IntegerType>(CI.getType())) {
- Value *LHS = ICI->getOperand(0);
- Value *RHS = ICI->getOperand(1);
+ if (Cmp->isEquality() && Zext.getType() == Cmp->getOperand(0)->getType()) {
+ if (IntegerType *ITy = dyn_cast<IntegerType>(Zext.getType())) {
+ Value *LHS = Cmp->getOperand(0);
+ Value *RHS = Cmp->getOperand(1);
- KnownBits KnownLHS = computeKnownBits(LHS, 0, &CI);
- KnownBits KnownRHS = computeKnownBits(RHS, 0, &CI);
+ KnownBits KnownLHS = computeKnownBits(LHS, 0, &Zext);
+ KnownBits KnownRHS = computeKnownBits(RHS, 0, &Zext);
if (KnownLHS.Zero == KnownRHS.Zero && KnownLHS.One == KnownRHS.One) {
APInt KnownBits = KnownLHS.Zero | KnownLHS.One;
APInt UnknownBit = ~KnownBits;
if (UnknownBit.countPopulation() == 1) {
- if (!DoTransform) return ICI;
+ if (!DoTransform) return Cmp;
Value *Result = Builder.CreateXor(LHS, RHS);
@@ -949,10 +950,10 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
Result = Builder.CreateLShr(
Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros()));
- if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
+ if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
Result = Builder.CreateXor(Result, ConstantInt::get(ITy, 1));
- Result->takeName(ICI);
- return replaceInstUsesWith(CI, Result);
+ Result->takeName(Cmp);
+ return replaceInstUsesWith(Zext, Result);
}
}
}
@@ -1172,8 +1173,8 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
}
}
- if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src))
- return transformZExtICmp(ICI, CI);
+ if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src))
+ return transformZExtICmp(Cmp, CI);
BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src);
if (SrcI && SrcI->getOpcode() == Instruction::Or) {
@@ -1188,7 +1189,9 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
// zext (or icmp, icmp) -> or (zext icmp), (zext icmp)
Value *LCast = Builder.CreateZExt(LHS, CI.getType(), LHS->getName());
Value *RCast = Builder.CreateZExt(RHS, CI.getType(), RHS->getName());
- BinaryOperator *Or = BinaryOperator::Create(Instruction::Or, LCast, RCast);
+ Value *Or = Builder.CreateOr(LCast, RCast, CI.getName());
+ if (auto *OrInst = dyn_cast<Instruction>(Or))
+ Builder.SetInsertPoint(OrInst);
// Perform the elimination.
if (auto *LZExt = dyn_cast<ZExtInst>(LCast))
@@ -1196,7 +1199,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
if (auto *RZExt = dyn_cast<ZExtInst>(RCast))
transformZExtICmp(RHS, *RZExt);
- return Or;
+ return replaceInstUsesWith(CI, Or);
}
}
@@ -1621,6 +1624,11 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
Value *X;
Instruction *Op = dyn_cast<Instruction>(FPT.getOperand(0));
if (Op && Op->hasOneUse()) {
+ // FIXME: The FMF should propagate from the fptrunc, not the source op.
+ IRBuilder<>::FastMathFlagGuard FMFG(Builder);
+ if (isa<FPMathOperator>(Op))
+ Builder.setFastMathFlags(Op->getFastMathFlags());
+
if (match(Op, m_FNeg(m_Value(X)))) {
Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty);
@@ -1630,6 +1638,24 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
return BinaryOperator::CreateFNegFMF(InnerTrunc, Op);
return UnaryOperator::CreateFNegFMF(InnerTrunc, Op);
}
+
+ // If we are truncating a select that has an extended operand, we can
+ // narrow the other operand and do the select as a narrow op.
+ Value *Cond, *X, *Y;
+ if (match(Op, m_Select(m_Value(Cond), m_FPExt(m_Value(X)), m_Value(Y))) &&
+ X->getType() == Ty) {
+ // fptrunc (select Cond, (fpext X), Y --> select Cond, X, (fptrunc Y)
+ Value *NarrowY = Builder.CreateFPTrunc(Y, Ty);
+ Value *Sel = Builder.CreateSelect(Cond, X, NarrowY, "narrow.sel", Op);
+ return replaceInstUsesWith(FPT, Sel);
+ }
+ if (match(Op, m_Select(m_Value(Cond), m_Value(Y), m_FPExt(m_Value(X)))) &&
+ X->getType() == Ty) {
+ // fptrunc (select Cond, Y, (fpext X) --> select Cond, (fptrunc Y), X
+ Value *NarrowY = Builder.CreateFPTrunc(Y, Ty);
+ Value *Sel = Builder.CreateSelect(Cond, NarrowY, X, "narrow.sel", Op);
+ return replaceInstUsesWith(FPT, Sel);
+ }
}
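A brief sketch of the new narrowing fold for selects (an IR fragment; types and names are hypothetical):

    %e = fpext float %x to double
    %s = select i1 %c, double %e, double %y
    %t = fptrunc double %s to float
    ; expected after the fold:
    ;   %narrow.y = fptrunc double %y to float
    ;   %t        = select i1 %c, float %x, float %narrow.y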
if (auto *II = dyn_cast<IntrinsicInst>(FPT.getOperand(0))) {
@@ -1808,7 +1834,7 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
Type *Ty = CI.getType();
unsigned AS = CI.getPointerAddressSpace();
- if (Ty->getScalarSizeInBits() == DL.getIndexSizeInBits(AS))
+ if (Ty->getScalarSizeInBits() == DL.getPointerSizeInBits(AS))
return commonPointerCastTransforms(CI);
Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS);
@@ -1820,12 +1846,24 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
}
/// This input value (which is known to have vector type) is being zero extended
-/// or truncated to the specified vector type.
+/// or truncated to the specified vector type. Since the zext/trunc is done
+/// using an integer type, we have a (bitcast(cast(bitcast))) pattern, and
+/// endianness will impact which end of the vector is extended or truncated.
+///
+/// A vector is always stored with index 0 at the lowest address, which
+/// corresponds to the most significant bits for a big endian stored integer and
+/// the least significant bits for little endian. A trunc/zext of an integer
+/// impacts the big end of the integer. Thus, we need to add/remove elements at
+/// the front of the vector for big endian targets, and the back of the vector
+/// for little endian targets.
+///
/// Try to replace it with a shuffle (and vector/vector bitcast) if possible.
///
/// The source and destination vector types may have different element types.
-static Instruction *optimizeVectorResize(Value *InVal, VectorType *DestTy,
- InstCombiner &IC) {
+static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal,
+ VectorType *DestTy,
+ InstCombiner &IC) {
// We can only do this optimization if the output is a multiple of the input
// element size, or the input is a multiple of the output element size.
// Convert the input type to have the same element type as the output.
@@ -1844,31 +1882,53 @@ static Instruction *optimizeVectorResize(Value *InVal, VectorType *DestTy,
InVal = IC.Builder.CreateBitCast(InVal, SrcTy);
}
+ bool IsBigEndian = IC.getDataLayout().isBigEndian();
+ unsigned SrcElts = SrcTy->getNumElements();
+ unsigned DestElts = DestTy->getNumElements();
+
+ assert(SrcElts != DestElts && "Element counts should be different.");
+
// Now that the element types match, get the shuffle mask and RHS of the
// shuffle to use, which depends on whether we're increasing or decreasing the
// size of the input.
- SmallVector<uint32_t, 16> ShuffleMask;
+ SmallVector<uint32_t, 16> ShuffleMaskStorage;
+ ArrayRef<uint32_t> ShuffleMask;
Value *V2;
- if (SrcTy->getNumElements() > DestTy->getNumElements()) {
- // If we're shrinking the number of elements, just shuffle in the low
- // elements from the input and use undef as the second shuffle input.
- V2 = UndefValue::get(SrcTy);
- for (unsigned i = 0, e = DestTy->getNumElements(); i != e; ++i)
- ShuffleMask.push_back(i);
+  // Produce an identity shuffle mask for the src vector.
+ ShuffleMaskStorage.resize(SrcElts);
+ std::iota(ShuffleMaskStorage.begin(), ShuffleMaskStorage.end(), 0);
+ if (SrcElts > DestElts) {
+ // If we're shrinking the number of elements (rewriting an integer
+ // truncate), just shuffle in the elements corresponding to the least
+ // significant bits from the input and use undef as the second shuffle
+ // input.
+ V2 = UndefValue::get(SrcTy);
+ // Make sure the shuffle mask selects the "least significant bits" by
+  // keeping elements from the back of the src vector for big endian, and from
+  // the front for little endian.
+ ShuffleMask = ShuffleMaskStorage;
+ if (IsBigEndian)
+ ShuffleMask = ShuffleMask.take_back(DestElts);
+ else
+ ShuffleMask = ShuffleMask.take_front(DestElts);
} else {
- // If we're increasing the number of elements, shuffle in all of the
- // elements from InVal and fill the rest of the result elements with zeros
- // from a constant zero.
+ // If we're increasing the number of elements (rewriting an integer zext),
+ // shuffle in all of the elements from InVal. Fill the rest of the result
+ // elements with zeros from a constant zero.
V2 = Constant::getNullValue(SrcTy);
- unsigned SrcElts = SrcTy->getNumElements();
- for (unsigned i = 0, e = SrcElts; i != e; ++i)
- ShuffleMask.push_back(i);
-
- // The excess elements reference the first element of the zero input.
- for (unsigned i = 0, e = DestTy->getNumElements()-SrcElts; i != e; ++i)
- ShuffleMask.push_back(SrcElts);
+ // Use first elt from V2 when indicating zero in the shuffle mask.
+ uint32_t NullElt = SrcElts;
+ // Extend with null values in the "most significant bits" by adding elements
+ // in front of the src vector for big endian, and at the back for little
+ // endian.
+ unsigned DeltaElts = DestElts - SrcElts;
+ if (IsBigEndian)
+ ShuffleMaskStorage.insert(ShuffleMaskStorage.begin(), DeltaElts, NullElt);
+ else
+ ShuffleMaskStorage.append(DeltaElts, NullElt);
+ ShuffleMask = ShuffleMaskStorage;
}
return new ShuffleVectorInst(InVal, V2,
@@ -2217,6 +2277,31 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
}
}
+ // Check that each user of each old PHI node is something that we can
+ // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
+ for (auto *OldPN : OldPhiNodes) {
+ for (User *V : OldPN->users()) {
+ if (auto *SI = dyn_cast<StoreInst>(V)) {
+ if (!SI->isSimple() || SI->getOperand(0) != OldPN)
+ return nullptr;
+ } else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
+ // Verify it's a B->A cast.
+ Type *TyB = BCI->getOperand(0)->getType();
+ Type *TyA = BCI->getType();
+ if (TyA != DestTy || TyB != SrcTy)
+ return nullptr;
+ } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+ // As long as the user is another old PHI node, then even if we don't
+ // rewrite it, the PHI web we're considering won't have any users
+ // outside itself, so it'll be dead.
+ if (OldPhiNodes.count(PHI) == 0)
+ return nullptr;
+ } else {
+ return nullptr;
+ }
+ }
+ }
+
// For each old PHI node, create a corresponding new PHI node with a type A.
SmallDenseMap<PHINode *, PHINode *> NewPNodes;
for (auto *OldPN : OldPhiNodes) {
@@ -2234,9 +2319,14 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
if (auto *C = dyn_cast<Constant>(V)) {
NewV = ConstantExpr::getBitCast(C, DestTy);
} else if (auto *LI = dyn_cast<LoadInst>(V)) {
- Builder.SetInsertPoint(LI->getNextNode());
- NewV = Builder.CreateBitCast(LI, DestTy);
- Worklist.Add(LI);
+ // Explicitly perform load combine to make sure no opposing transform
+ // can remove the bitcast in the meantime and trigger an infinite loop.
+ Builder.SetInsertPoint(LI);
+ NewV = combineLoadToNewType(*LI, DestTy);
+ // Remove the old load and its use in the old phi, which itself becomes
+ // dead once the whole transform finishes.
+ replaceInstUsesWith(*LI, UndefValue::get(LI->getType()));
+ eraseInstFromFunction(*LI);
} else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
NewV = BCI->getOperand(0);
} else if (auto *PrevPN = dyn_cast<PHINode>(V)) {
@@ -2259,26 +2349,33 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
Instruction *RetVal = nullptr;
for (auto *OldPN : OldPhiNodes) {
PHINode *NewPN = NewPNodes[OldPN];
- for (User *V : OldPN->users()) {
+ for (auto It = OldPN->user_begin(), End = OldPN->user_end(); It != End; ) {
+ User *V = *It;
+ // We may remove this user, advance to avoid iterator invalidation.
+ ++It;
if (auto *SI = dyn_cast<StoreInst>(V)) {
- if (SI->isSimple() && SI->getOperand(0) == OldPN) {
- Builder.SetInsertPoint(SI);
- auto *NewBC =
- cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy));
- SI->setOperand(0, NewBC);
- Worklist.Add(SI);
- assert(hasStoreUsersOnly(*NewBC));
- }
+ assert(SI->isSimple() && SI->getOperand(0) == OldPN);
+ Builder.SetInsertPoint(SI);
+ auto *NewBC =
+ cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy));
+ SI->setOperand(0, NewBC);
+ Worklist.Add(SI);
+ assert(hasStoreUsersOnly(*NewBC));
}
else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
- // Verify it's a B->A cast.
Type *TyB = BCI->getOperand(0)->getType();
Type *TyA = BCI->getType();
- if (TyA == DestTy && TyB == SrcTy) {
- Instruction *I = replaceInstUsesWith(*BCI, NewPN);
- if (BCI == &CI)
- RetVal = I;
- }
+ assert(TyA == DestTy && TyB == SrcTy);
+ (void) TyA;
+ (void) TyB;
+ Instruction *I = replaceInstUsesWith(*BCI, NewPN);
+ if (BCI == &CI)
+ RetVal = I;
+ } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+ assert(OldPhiNodes.count(PHI) > 0);
+ (void) PHI;
+ } else {
+ llvm_unreachable("all uses should be handled");
}
}
}
@@ -2374,8 +2471,8 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
CastInst *SrcCast = cast<CastInst>(Src);
if (BitCastInst *BCIn = dyn_cast<BitCastInst>(SrcCast->getOperand(0)))
if (isa<VectorType>(BCIn->getOperand(0)->getType()))
- if (Instruction *I = optimizeVectorResize(BCIn->getOperand(0),
- cast<VectorType>(DestTy), *this))
+ if (Instruction *I = optimizeVectorResizeWithIntegerBitCasts(
+ BCIn->getOperand(0), cast<VectorType>(DestTy), *this))
return I;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index a9f64feb600c..f38dc436722d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2566,9 +2566,6 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
Type *Ty = Add->getType();
CmpInst::Predicate Pred = Cmp.getPredicate();
- if (!Add->hasOneUse())
- return nullptr;
-
// If the add does not wrap, we can always adjust the compare by subtracting
// the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
// are canonicalized to SGT/SLT/UGT/ULT.
@@ -2602,6 +2599,9 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower));
}
+ if (!Add->hasOneUse())
+ return nullptr;
+
// X+C <u C2 -> (X & -C2) == C
// iff C & (C2-1) == 0
// C2 is a power of 2
@@ -3364,6 +3364,23 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
llvm_unreachable("All possible folds are handled.");
}
+ // The mask value may be a vector constant that has undefined elements. But it
+ // may not be safe to propagate those undefs into the new compare, so replace
+ // those elements by copying an existing, defined, and safe scalar constant.
+ Type *OpTy = M->getType();
+ auto *VecC = dyn_cast<Constant>(M);
+ if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) {
+ Constant *SafeReplacementConstant = nullptr;
+ for (unsigned i = 0, e = OpTy->getVectorNumElements(); i != e; ++i) {
+ if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
+ SafeReplacementConstant = VecC->getAggregateElement(i);
+ break;
+ }
+ }
+ assert(SafeReplacementConstant && "Failed to find undef replacement");
+ M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
+ }
+
return Builder.CreateICmp(DstPred, X, M);
}
@@ -4930,7 +4947,7 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
// Get scalar or pointer size.
unsigned BitWidth = Ty->isIntOrIntVectorTy()
? Ty->getScalarSizeInBits()
- : DL.getIndexTypeSizeInBits(Ty->getScalarType());
+ : DL.getPointerTypeSizeInBits(Ty->getScalarType());
if (!BitWidth)
return nullptr;
@@ -5167,6 +5184,7 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
};
+ Constant *SafeReplacementConstant = nullptr;
if (auto *CI = dyn_cast<ConstantInt>(C)) {
// Bail out if the constant can't be safely incremented/decremented.
if (!ConstantIsOk(CI))
@@ -5186,12 +5204,23 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
auto *CI = dyn_cast<ConstantInt>(Elt);
if (!CI || !ConstantIsOk(CI))
return llvm::None;
+
+ if (!SafeReplacementConstant)
+ SafeReplacementConstant = CI;
}
} else {
// ConstantExpr?
return llvm::None;
}
+ // It may not be safe to change a compare predicate in the presence of
+ // undefined elements, so replace those elements with the first safe constant
+ // that we found.
+ if (C->containsUndefElement()) {
+ assert(SafeReplacementConstant && "Replacement constant not set");
+ C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
+ }
+
CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
// Increment or decrement the constant.
@@ -5374,6 +5403,36 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
return nullptr;
}
+// extract(uadd.with.overflow(A, B), 0) ult A
+// -> extract(uadd.with.overflow(A, B), 1)
+static Instruction *foldICmpOfUAddOv(ICmpInst &I) {
+ CmpInst::Predicate Pred = I.getPredicate();
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+
+ Value *UAddOv;
+ Value *A, *B;
+ auto UAddOvResultPat = m_ExtractValue<0>(
+ m_Intrinsic<Intrinsic::uadd_with_overflow>(m_Value(A), m_Value(B)));
+ if (match(Op0, UAddOvResultPat) &&
+ ((Pred == ICmpInst::ICMP_ULT && (Op1 == A || Op1 == B)) ||
+ (Pred == ICmpInst::ICMP_EQ && match(Op1, m_ZeroInt()) &&
+ (match(A, m_One()) || match(B, m_One()))) ||
+ (Pred == ICmpInst::ICMP_NE && match(Op1, m_AllOnes()) &&
+ (match(A, m_AllOnes()) || match(B, m_AllOnes())))))
+ // extract(uadd.with.overflow(A, B), 0) < A
+ // extract(uadd.with.overflow(A, 1), 0) == 0
+ // extract(uadd.with.overflow(A, -1), 0) != -1
+ UAddOv = cast<ExtractValueInst>(Op0)->getAggregateOperand();
+ else if (match(Op1, UAddOvResultPat) &&
+ Pred == ICmpInst::ICMP_UGT && (Op0 == A || Op0 == B))
+ // A > extract(uadd.with.overflow(A, B), 0)
+ UAddOv = cast<ExtractValueInst>(Op1)->getAggregateOperand();
+ else
+ return nullptr;
+
+ return ExtractValueInst::Create(UAddOv, 1);
+}
+
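A minimal sketch of the primary pattern this helper recognizes (an IR fragment; names are hypothetical):

    %agg = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %a, i32 %b)
    %sum = extractvalue { i32, i1 } %agg, 0
    %cmp = icmp ult i32 %sum, %a
    ; expected after the fold:
    ;   %cmp = extractvalue { i32, i1 } %agg, 1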
Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
bool Changed = false;
const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -5562,6 +5621,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldICmpEquality(I))
return Res;
+ if (Instruction *Res = foldICmpOfUAddOv(I))
+ return Res;
+
// The 'cmpxchg' instruction returns an aggregate containing the old value and
// an i1 which indicates whether or not we successfully did the swap.
//
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 1dbc06d92e7a..1a746cb87abb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -369,7 +369,8 @@ public:
Instruction *visitFNeg(UnaryOperator &I);
Instruction *visitAdd(BinaryOperator &I);
Instruction *visitFAdd(BinaryOperator &I);
- Value *OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty);
+ Value *OptimizePointerDifference(
+ Value *LHS, Value *RHS, Type *Ty, bool isNUW);
Instruction *visitSub(BinaryOperator &I);
Instruction *visitFSub(BinaryOperator &I);
Instruction *visitMul(BinaryOperator &I);
@@ -446,6 +447,7 @@ public:
Instruction *visitLandingPadInst(LandingPadInst &LI);
Instruction *visitVAStartInst(VAStartInst &I);
Instruction *visitVACopyInst(VACopyInst &I);
+ Instruction *visitFreeze(FreezeInst &I);
/// Specify what to return for unhandled instructions.
Instruction *visitInstruction(Instruction &I) { return nullptr; }
@@ -465,6 +467,9 @@ public:
/// \return true if successful.
bool replacePointer(Instruction &I, Value *V);
+ LoadInst *combineLoadToNewType(LoadInst &LI, Type *NewTy,
+ const Twine &Suffix = "");
+
private:
bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
bool shouldChangeType(Type *From, Type *To) const;
@@ -705,7 +710,7 @@ public:
Instruction *eraseInstFromFunction(Instruction &I) {
LLVM_DEBUG(dbgs() << "IC: ERASE " << I << '\n');
assert(I.use_empty() && "Cannot erase instruction that is used!");
- salvageDebugInfo(I);
+ salvageDebugInfoOrMarkUndef(I);
// Make sure that we reprocess all operands now that we reduced their
// use counts.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 3a0e05832fcb..ebf9d24eecc4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -449,8 +449,8 @@ static bool isSupportedAtomicType(Type *Ty) {
///
/// Note that this will create all of the instructions with whatever insert
/// point the \c InstCombiner currently is using.
-static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewTy,
- const Twine &Suffix = "") {
+LoadInst *InstCombiner::combineLoadToNewType(LoadInst &LI, Type *NewTy,
+ const Twine &Suffix) {
assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) &&
"can't fold an atomic load to requested type");
@@ -460,10 +460,17 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT
if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) &&
NewPtr->getType()->getPointerElementType() == NewTy &&
NewPtr->getType()->getPointerAddressSpace() == AS))
- NewPtr = IC.Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS));
+ NewPtr = Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS));
- LoadInst *NewLoad = IC.Builder.CreateAlignedLoad(
- NewTy, NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix);
+ unsigned Align = LI.getAlignment();
+ if (!Align)
+    // If the old load did not have an explicit alignment specified,
+    // manually preserve the implied (ABI) alignment of the load.
+    // Else we may inadvertently over-promise the alignment.
+ Align = getDataLayout().getABITypeAlignment(LI.getType());
+
+ LoadInst *NewLoad = Builder.CreateAlignedLoad(
+ NewTy, NewPtr, Align, LI.isVolatile(), LI.getName() + Suffix);
NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
copyMetadataForLoad(*NewLoad, LI);
return NewLoad;
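To sketch why the explicit alignment matters (assuming a hypothetical target where double has 4-byte ABI alignment but i64 has 8-byte ABI alignment):

    %v = load double, double* %p    ; no explicit 'align': the ABI alignment of double (4 here) is implied
    ; the rewritten load now pins that alignment explicitly:
    ;   %v1 = load i64, i64* %pc, align 4
    ; rather than emitting an unannotated load i64, which would implicitly claim
    ; the larger ABI alignment of i64 and over-promise.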
@@ -526,7 +533,7 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value
/// Returns true if instruction represent minmax pattern like:
/// select ((cmp load V1, load V2), V1, V2).
-static bool isMinMaxWithLoads(Value *V) {
+static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) {
assert(V->getType()->isPointerTy() && "Expected pointer type.");
// Ignore possible ty* to ixx* bitcast.
V = peekThroughBitcast(V);
@@ -540,6 +547,7 @@ static bool isMinMaxWithLoads(Value *V) {
if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)),
m_Value(LHS), m_Value(RHS))))
return false;
+ LoadTy = L1->getType();
return (match(L1, m_Load(m_Specific(LHS))) &&
match(L2, m_Load(m_Specific(RHS)))) ||
(match(L1, m_Load(m_Specific(RHS))) &&
@@ -585,20 +593,22 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) {
// size is a legal integer type.
// Do not perform canonicalization if minmax pattern is found (to avoid
// infinite loop).
+ Type *Dummy;
if (!Ty->isIntegerTy() && Ty->isSized() &&
+ !(Ty->isVectorTy() && Ty->getVectorIsScalable()) &&
DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) &&
DL.typeSizeEqualsStoreSize(Ty) &&
!DL.isNonIntegralPointerType(Ty) &&
!isMinMaxWithLoads(
- peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true))) {
+ peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true),
+ Dummy)) {
if (all_of(LI.users(), [&LI](User *U) {
auto *SI = dyn_cast<StoreInst>(U);
return SI && SI->getPointerOperand() != &LI &&
!SI->getPointerOperand()->isSwiftError();
})) {
- LoadInst *NewLoad = combineLoadToNewType(
- IC, LI,
- Type::getIntNTy(LI.getContext(), DL.getTypeStoreSizeInBits(Ty)));
+ LoadInst *NewLoad = IC.combineLoadToNewType(
+ LI, Type::getIntNTy(LI.getContext(), DL.getTypeStoreSizeInBits(Ty)));
// Replace all the stores with stores of the newly loaded value.
for (auto UI = LI.user_begin(), UE = LI.user_end(); UI != UE;) {
auto *SI = cast<StoreInst>(*UI++);
@@ -620,7 +630,7 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) {
if (auto* CI = dyn_cast<CastInst>(LI.user_back()))
if (CI->isNoopCast(DL))
if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) {
- LoadInst *NewLoad = combineLoadToNewType(IC, LI, CI->getDestTy());
+ LoadInst *NewLoad = IC.combineLoadToNewType(LI, CI->getDestTy());
CI->replaceAllUsesWith(NewLoad);
IC.eraseInstFromFunction(*CI);
return &LI;
@@ -648,8 +658,8 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) {
// If the struct only has one element, we unpack.
auto NumElements = ST->getNumElements();
if (NumElements == 1) {
- LoadInst *NewLoad = combineLoadToNewType(IC, LI, ST->getTypeAtIndex(0U),
- ".unpack");
+ LoadInst *NewLoad = IC.combineLoadToNewType(LI, ST->getTypeAtIndex(0U),
+ ".unpack");
AAMDNodes AAMD;
LI.getAAMetadata(AAMD);
NewLoad->setAAMetadata(AAMD);
@@ -698,7 +708,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) {
auto *ET = AT->getElementType();
auto NumElements = AT->getNumElements();
if (NumElements == 1) {
- LoadInst *NewLoad = combineLoadToNewType(IC, LI, ET, ".unpack");
+ LoadInst *NewLoad = IC.combineLoadToNewType(LI, ET, ".unpack");
AAMDNodes AAMD;
LI.getAAMetadata(AAMD);
NewLoad->setAAMetadata(AAMD);
@@ -1322,7 +1332,14 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC,
auto *LI = cast<LoadInst>(SI.getValueOperand());
if (!LI->getType()->isIntegerTy())
return false;
- if (!isMinMaxWithLoads(LoadAddr))
+ Type *CmpLoadTy;
+ if (!isMinMaxWithLoads(LoadAddr, CmpLoadTy))
+ return false;
+
+ // Make sure we're not changing the size of the load/store.
+ const auto &DL = IC.getDataLayout();
+ if (DL.getTypeStoreSizeInBits(LI->getType()) !=
+ DL.getTypeStoreSizeInBits(CmpLoadTy))
return false;
if (!all_of(LI->users(), [LI, LoadAddr](User *U) {
@@ -1334,8 +1351,7 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC,
return false;
IC.Builder.SetInsertPoint(LI);
- LoadInst *NewLI = combineLoadToNewType(
- IC, *LI, LoadAddr->getType()->getPointerElementType());
+ LoadInst *NewLI = IC.combineLoadToNewType(*LI, CmpLoadTy);
// Replace all the stores with stores of the newly loaded value.
for (auto *UI : LI->users()) {
auto *USI = cast<StoreInst>(UI);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 0b9128a9f5a1..2774e46151fa 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1239,6 +1239,14 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
Value *YZ = Builder.CreateFMulFMF(Y, Op0, &I);
return BinaryOperator::CreateFDivFMF(YZ, X, &I);
}
+ // Z / (1.0 / Y) => (Y * Z)
+ //
+ // This is a special case of Z / (X / Y) => (Y * Z) / X, with X = 1.0. The
+ // m_OneUse check is avoided because even in the case of the multiple uses
+  // for 1.0/Y, the number of instructions remains the same and a division is
+ // replaced by a multiplication.
+ if (match(Op1, m_FDiv(m_SpecificFP(1.0), m_Value(Y))))
+ return BinaryOperator::CreateFMulFMF(Y, Op0, &I);
}
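A sketch of the new reciprocal fold (it fires under the reassociation fast-math flags; names are hypothetical):

    %recip = fdiv fast float 1.000000e+00, %y
    %r     = fdiv fast float %z, %recip
    ; expected after the fold: %r = fmul fast float %y, %z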
if (I.hasAllowReassoc() && Op0->hasOneUse() && Op1->hasOneUse()) {
@@ -1368,8 +1376,10 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
}
// 1 urem X -> zext(X != 1)
- if (match(Op0, m_One()))
- return CastInst::CreateZExtOrBitCast(Builder.CreateICmpNE(Op1, Op0), Ty);
+ if (match(Op0, m_One())) {
+ Value *Cmp = Builder.CreateICmpNE(Op1, ConstantInt::get(Ty, 1));
+ return CastInst::CreateZExtOrBitCast(Cmp, Ty);
+ }
// X urem C -> X < C ? X : X - C, where C >= signbit.
if (match(Op1, m_Negative())) {
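The '1 urem X' rewrite itself is unchanged in shape; a sketch (hypothetical names):

    %r = urem i32 1, %x
    ; expected after the fold:
    ;   %cmp = icmp ne i32 %x, 1
    ;   %r   = zext i1 %cmp to i32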
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index e0376b7582f3..74e015a4f1d4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -14,9 +14,10 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/InstructionSimplify.h"
-#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
using namespace llvm::PatternMatch;
@@ -180,13 +181,14 @@ Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) {
"Not enough available ptr typed incoming values");
PHINode *MatchingPtrPHI = nullptr;
unsigned NumPhis = 0;
- for (auto II = BB->begin(), EI = BasicBlock::iterator(BB->getFirstNonPHI());
- II != EI; II++, NumPhis++) {
+ for (auto II = BB->begin(); II != BB->end(); II++, NumPhis++) {
// FIXME: consider handling this in AggressiveInstCombine
+ PHINode *PtrPHI = dyn_cast<PHINode>(II);
+ if (!PtrPHI)
+ break;
if (NumPhis > MaxNumPhis)
return nullptr;
- PHINode *PtrPHI = dyn_cast<PHINode>(II);
- if (!PtrPHI || PtrPHI == &PN || PtrPHI->getType() != IntToPtr->getType())
+ if (PtrPHI == &PN || PtrPHI->getType() != IntToPtr->getType())
continue;
MatchingPtrPHI = PtrPHI;
for (unsigned i = 0; i != PtrPHI->getNumIncomingValues(); ++i) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 9fc871e49b30..05a624fde86b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -704,16 +704,24 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,
assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) &&
"Unexpected isUnsigned predicate!");
- // Account for swapped form of subtraction: ((a > b) ? b - a : 0).
+ // Ensure the sub is of the form:
+ // (a > b) ? a - b : 0 -> usub.sat(a, b)
+ // (a > b) ? b - a : 0 -> -usub.sat(a, b)
+  // The subtraction may appear as a - b, or as a + (-b) when b is a constant.
bool IsNegative = false;
- if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))))
+ const APInt *C;
+ if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))) ||
+ (match(A, m_APInt(C)) &&
+ match(TrueVal, m_Add(m_Specific(B), m_SpecificInt(-*C)))))
IsNegative = true;
- else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))))
+ else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))) &&
+ !(match(B, m_APInt(C)) &&
+ match(TrueVal, m_Add(m_Specific(A), m_SpecificInt(-*C)))))
return nullptr;
- // If sub is used anywhere else, we wouldn't be able to eliminate it
- // afterwards.
- if (!TrueVal->hasOneUse())
+ // If we are adding a negate and the sub and icmp are used anywhere else, we
+ // would end up with more instructions.
+ if (IsNegative && !TrueVal->hasOneUse() && !ICI->hasOneUse())
return nullptr;
// (a > b) ? a - b : 0 -> usub.sat(a, b)
@@ -781,6 +789,13 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return Builder.CreateBinaryIntrinsic(
Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1));
}
+  // The overflow may be detected via the add wrapping around.
+ if (match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) &&
+ match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) {
+ // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y)
+ // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y)
+ return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp1, Y);
+ }
return nullptr;
}
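A minimal sketch of the wrap-around form now recognized (an IR fragment; names are hypothetical):

    %sum = add i32 %x, %y
    %ovf = icmp ult i32 %sum, %x
    %r   = select i1 %ovf, i32 -1, i32 %sum
    ; expected after the fold: %r = call i32 @llvm.uadd.sat.i32(i32 %x, i32 %y)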
@@ -1725,6 +1740,128 @@ static Instruction *foldAddSubSelect(SelectInst &SI,
return nullptr;
}
+/// Turn X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
+/// And X - Y overflows ? 0 : X - Y -> usub_sat X, Y
+/// Along with a number of patterns similar to:
+/// X + Y overflows ? (X < 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+/// X - Y overflows ? (X > 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+static Instruction *
+foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
+ Value *CondVal = SI.getCondition();
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+
+ WithOverflowInst *II;
+ if (!match(CondVal, m_ExtractValue<1>(m_WithOverflowInst(II))) ||
+ !match(FalseVal, m_ExtractValue<0>(m_Specific(II))))
+ return nullptr;
+
+ Value *X = II->getLHS();
+ Value *Y = II->getRHS();
+
+ auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) {
+ Type *Ty = Limit->getType();
+
+ ICmpInst::Predicate Pred;
+ Value *TrueVal, *FalseVal, *Op;
+ const APInt *C;
+ if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)),
+ m_Value(TrueVal), m_Value(FalseVal))))
+ return false;
+
+ auto IsZeroOrOne = [](const APInt &C) {
+ return C.isNullValue() || C.isOneValue();
+ };
+ auto IsMinMax = [&](Value *Min, Value *Max) {
+ APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
+ APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits());
+ return match(Min, m_SpecificInt(MinVal)) &&
+ match(Max, m_SpecificInt(MaxVal));
+ };
+
+ if (Op != X && Op != Y)
+ return false;
+
+ if (IsAdd) {
+ // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ if (Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) &&
+ IsMinMax(TrueVal, FalseVal))
+ return true;
+ // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ if (Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) &&
+ IsMinMax(FalseVal, TrueVal))
+ return true;
+ } else {
+ // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ if (Op == X && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C + 1) &&
+ IsMinMax(TrueVal, FalseVal))
+ return true;
+ // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ if (Op == X && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 2) &&
+ IsMinMax(FalseVal, TrueVal))
+ return true;
+ // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ if (Op == Y && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) &&
+ IsMinMax(FalseVal, TrueVal))
+ return true;
+ // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ if (Op == Y && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) &&
+ IsMinMax(TrueVal, FalseVal))
+ return true;
+ }
+
+ return false;
+ };
+
+ Intrinsic::ID NewIntrinsicID;
+ if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow &&
+ match(TrueVal, m_AllOnes()))
+ // X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
+ NewIntrinsicID = Intrinsic::uadd_sat;
+ else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow &&
+ match(TrueVal, m_Zero()))
+ // X - Y overflows ? 0 : X - Y -> usub_sat X, Y
+ NewIntrinsicID = Intrinsic::usub_sat;
+ else if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow &&
+ IsSignedSaturateLimit(TrueVal, /*IsAdd=*/true))
+ // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
+ NewIntrinsicID = Intrinsic::sadd_sat;
+ else if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow &&
+ IsSignedSaturateLimit(TrueVal, /*IsAdd=*/false))
+ // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
+ NewIntrinsicID = Intrinsic::ssub_sat;
+ else
+ return nullptr;
+
+ Function *F =
+ Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType());
+ return CallInst::Create(F, {X, Y});
+}
+
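As a rough illustration of the signed case handled by IsSignedSaturateLimit (a sketch only, assuming the GCC/Clang __builtin_add_overflow builtin; the function name is made up), this is the kind of source that produces the sadd.with.overflow plus INT_MIN/INT_MAX clamp select:

#include <cassert>
#include <climits>

// Saturating signed add: on overflow, clamp toward the sign of X. The
// extractvalue of the overflow bit feeding a select of INT_MIN/INT_MAX is the
// shape matched above and folded to llvm.sadd.sat.
static int sat_add_i32(int X, int Y) {
  int Sum;
  if (__builtin_add_overflow(X, Y, &Sum))
    return X < 0 ? INT_MIN : INT_MAX;
  return Sum;
}

int main() {
  assert(sat_add_i32(INT_MAX, 1) == INT_MAX);
  assert(sat_add_i32(INT_MIN, -1) == INT_MIN);
  assert(sat_add_i32(3, 4) == 7);
}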
Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) {
Constant *C;
if (!match(Sel.getTrueValue(), m_Constant(C)) &&
@@ -2296,7 +2433,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
// See if we are selecting two values based on a comparison of the two values.
if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) {
- if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) {
+ Value *Cmp0 = FCI->getOperand(0), *Cmp1 = FCI->getOperand(1);
+ if ((Cmp0 == TrueVal && Cmp1 == FalseVal) ||
+ (Cmp0 == FalseVal && Cmp1 == TrueVal)) {
// Canonicalize to use ordered comparisons by swapping the select
// operands.
//
@@ -2305,30 +2444,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) {
FCmpInst::Predicate InvPred = FCI->getInversePredicate();
IRBuilder<>::FastMathFlagGuard FMFG(Builder);
+ // FIXME: The FMF should propagate from the select, not the fcmp.
Builder.setFastMathFlags(FCI->getFastMathFlags());
- Value *NewCond = Builder.CreateFCmp(InvPred, TrueVal, FalseVal,
- FCI->getName() + ".inv");
-
- return SelectInst::Create(NewCond, FalseVal, TrueVal,
- SI.getName() + ".p");
- }
-
- // NOTE: if we wanted to, this is where to detect MIN/MAX
- } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){
- // Canonicalize to use ordered comparisons by swapping the select
- // operands.
- //
- // e.g.
- // (X ugt Y) ? X : Y -> (X ole Y) ? X : Y
- if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) {
- FCmpInst::Predicate InvPred = FCI->getInversePredicate();
- IRBuilder<>::FastMathFlagGuard FMFG(Builder);
- Builder.setFastMathFlags(FCI->getFastMathFlags());
- Value *NewCond = Builder.CreateFCmp(InvPred, FalseVal, TrueVal,
+ Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1,
FCI->getName() + ".inv");
-
- return SelectInst::Create(NewCond, FalseVal, TrueVal,
- SI.getName() + ".p");
+ Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal);
+ return replaceInstUsesWith(SI, NewSel);
}
// NOTE: if we wanted to, this is where to detect MIN/MAX
@@ -2391,6 +2512,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *Add = foldAddSubSelect(SI, Builder))
return Add;
+ if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder))
+ return Add;
// Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
auto *TI = dyn_cast<Instruction>(TrueVal);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 64294838644f..fbff5dd4a8cd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -138,24 +138,6 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
return Ret;
}
-// Try to replace `undef` constants in C with Replacement.
-static Constant *replaceUndefsWith(Constant *C, Constant *Replacement) {
- if (C && match(C, m_Undef()))
- return Replacement;
-
- if (auto *CV = dyn_cast<ConstantVector>(C)) {
- llvm::SmallVector<Constant *, 32> NewOps(CV->getNumOperands());
- for (unsigned i = 0, NumElts = NewOps.size(); i != NumElts; ++i) {
- Constant *EltC = CV->getOperand(i);
- NewOps[i] = EltC && match(EltC, m_Undef()) ? Replacement : EltC;
- }
- return ConstantVector::get(NewOps);
- }
-
- // Don't know how to deal with this constant.
- return C;
-}
-
 // If we have a pattern that leaves only some low bits set and then performs a
 // left-shift of those bits, and if none of the bits that remain after the final
 // shift are modified by the mask, we can omit the mask.
@@ -180,10 +162,20 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
"The input must be 'shl'!");
Value *Masked, *ShiftShAmt;
- match(OuterShift, m_Shift(m_Value(Masked), m_Value(ShiftShAmt)));
+ match(OuterShift,
+ m_Shift(m_Value(Masked), m_ZExtOrSelf(m_Value(ShiftShAmt))));
+
+  // *If* there is a truncation between an outer shift and a possible mask,
+ // then said truncation *must* be one-use, else we can't perform the fold.
+ Value *Trunc;
+ if (match(Masked, m_CombineAnd(m_Trunc(m_Value(Masked)), m_Value(Trunc))) &&
+ !Trunc->hasOneUse())
+ return nullptr;
Type *NarrowestTy = OuterShift->getType();
Type *WidestTy = Masked->getType();
+ bool HadTrunc = WidestTy != NarrowestTy;
+
// The mask must be computed in a type twice as wide to ensure
// that no bits are lost if the sum-of-shifts is wider than the base type.
Type *ExtendedTy = WidestTy->getExtendedType();
@@ -204,6 +196,14 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
Constant *NewMask;
if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
+ // Peek through an optional zext of the shift amount.
+ match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));
+
+ // We have two shift amounts from two different shifts. The types of those
+  // shift amounts may not match. If that's the case, let's bail out now.
+ if (MaskShAmt->getType() != ShiftShAmt->getType())
+ return nullptr;
+
// Can we simplify (MaskShAmt+ShiftShAmt) ?
auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst(
MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
@@ -216,7 +216,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
     // completely unknown. Replace the `undef` shift amounts with final
// shift bitwidth to ensure that the value remains undef when creating the
// subsequent shift op.
- SumOfShAmts = replaceUndefsWith(
+ SumOfShAmts = Constant::replaceUndefsWith(
SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
ExtendedTy->getScalarSizeInBits()));
auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
@@ -228,6 +228,14 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
} else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
m_Deferred(MaskShAmt)))) {
+ // Peek through an optional zext of the shift amount.
+ match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));
+
+ // We have two shift amounts from two different shifts. The types of those
+  // shift amounts may not match. If that's the case, let's bail out now.
+ if (MaskShAmt->getType() != ShiftShAmt->getType())
+ return nullptr;
+
// Can we simplify (ShiftShAmt-MaskShAmt) ?
auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst(
ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
@@ -241,7 +249,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
// bitwidth of innermost shift to ensure that the value remains undef when
// creating the subsequent shift op.
unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits();
- ShAmtsDiff = replaceUndefsWith(
+ ShAmtsDiff = Constant::replaceUndefsWith(
ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
-WidestTyBitWidth));
auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
@@ -272,10 +280,15 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
return nullptr;
}
+ // If we need to apply truncation, let's do it first, since we can.
+ // We have already ensured that the old truncation will go away.
+ if (HadTrunc)
+ X = Builder.CreateTrunc(X, NarrowestTy);
+
// No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
+ // We didn't change the Type of this outermost shift, so we can just do it.
auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X,
OuterShift->getOperand(1));
-
if (!NeedMask)
return NewShift;
@@ -283,6 +296,50 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
}
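A self-contained sanity check of the underlying mask-dropping fold that this function implements (independent of the truncation handling added in this patch); the constants are chosen purely for illustration:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0u, 1u, 0xDEADBEEFu, 0xFFFFFFFFu}) {
    // (X & 0x00FFFFFF) << 8: the top 8 bits cleared by the mask are shifted
    // out by the final shl anyway, so the 'and' can be omitted.
    assert(((X & 0x00FFFFFFu) << 8) == (X << 8));
  }
}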
+/// If we have a shift-by-constant of a bitwise logic op that itself has a
+/// shift-by-constant operand with identical opcode, we may be able to convert
+/// that into 2 independent shifts followed by the logic op. This eliminates a
+/// use of an intermediate value (reduces the dependency chain).
+static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ assert(I.isShift() && "Expected a shift as input");
+ auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0));
+ if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse())
+ return nullptr;
+
+ const APInt *C0, *C1;
+ if (!match(I.getOperand(1), m_APInt(C1)))
+ return nullptr;
+
+ Instruction::BinaryOps ShiftOpcode = I.getOpcode();
+ Type *Ty = I.getType();
+
+ // Find a matching one-use shift by constant. The fold is not valid if the sum
+ // of the shift values equals or exceeds bitwidth.
+ // TODO: Remove the one-use check if the other logic operand (Y) is constant.
+ Value *X, *Y;
+ auto matchFirstShift = [&](Value *V) {
+ return !isa<ConstantExpr>(V) &&
+ match(V, m_OneUse(m_Shift(m_Value(X), m_APInt(C0)))) &&
+ cast<BinaryOperator>(V)->getOpcode() == ShiftOpcode &&
+ (*C0 + *C1).ult(Ty->getScalarSizeInBits());
+ };
+
+ // Logic ops are commutative, so check each operand for a match.
+ if (matchFirstShift(LogicInst->getOperand(0)))
+ Y = LogicInst->getOperand(1);
+ else if (matchFirstShift(LogicInst->getOperand(1)))
+ Y = LogicInst->getOperand(0);
+ else
+ return nullptr;
+
+ // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
+ Constant *ShiftSumC = ConstantInt::get(Ty, *C0 + *C1);
+ Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
+ Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1));
+ return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
+}
+
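The rewrite is easy to sanity-check with concrete shift amounts whose sum stays below the bit width (illustrative values only):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {1u, 0x1234u, 0xFFFFFFFFu})
    for (uint32_t Y : {0u, 7u, 0xABCDu})
      // shift (logic (shift X, C0), Y), C1 == logic (shift X, C0+C1), (shift Y, C1)
      assert((((X << 2) | Y) << 3) == ((X << 5) | (Y << 3)));
}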
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
assert(Op0->getType() == Op1->getType());
@@ -335,6 +392,9 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
return &I;
}
+ if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder))
+ return Logic;
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index d30ab8001897..47ce83974c8d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -14,6 +14,8 @@
#include "InstCombineInternal.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
@@ -348,8 +350,36 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
// If the operands are constants, see if we can simplify them.
- if (ShrinkDemandedConstant(I, 1, DemandedMask) ||
- ShrinkDemandedConstant(I, 2, DemandedMask))
+ // This is similar to ShrinkDemandedConstant, but for a select we want to
+    // try to keep the selected constants the same as the icmp constants, if
+    // we can. This helps avoid breaking apart (or helps put back together)
+ // canonical patterns like min and max.
+ auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo,
+ APInt DemandedMask) {
+ const APInt *SelC;
+ if (!match(I->getOperand(OpNo), m_APInt(SelC)))
+ return false;
+
+ // Get the constant out of the ICmp, if there is one.
+ const APInt *CmpC;
+ ICmpInst::Predicate Pred;
+ if (!match(I->getOperand(0), m_c_ICmp(Pred, m_APInt(CmpC), m_Value())) ||
+ CmpC->getBitWidth() != SelC->getBitWidth())
+ return ShrinkDemandedConstant(I, OpNo, DemandedMask);
+
+ // If the constant is already the same as the ICmp, leave it as-is.
+ if (*CmpC == *SelC)
+ return false;
+ // If the constants are not already the same, but can be with the demand
+ // mask, use the constant value from the ICmp.
+ if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) {
+ I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC));
+ return true;
+ }
+ return ShrinkDemandedConstant(I, OpNo, DemandedMask);
+ };
+ if (CanonicalizeSelectConstant(I, 1, DemandedMask) ||
+ CanonicalizeSelectConstant(I, 2, DemandedMask))
return I;
// Only known if known in both the LHS and RHS.
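A standalone sketch of the decision rule the lambda above applies, using plain integers rather than APInt (illustrative only, not the actual implementation):

#include <cassert>
#include <cstdint>

// Given the icmp constant CmpC, the select-arm constant SelC, and the demanded
// bit mask, prefer CmpC whenever the two constants agree on every demanded bit
// so that min/max-style select+icmp pairs keep matching constants.
static uint64_t canonicalSelectConstant(uint64_t CmpC, uint64_t SelC,
                                        uint64_t DemandedMask) {
  if (CmpC == SelC)
    return SelC;                      // already canonical, leave it alone
  if ((CmpC & DemandedMask) == (SelC & DemandedMask))
    return CmpC;                      // equivalent under the demanded bits
  return SelC & DemandedMask;         // otherwise shrink as before
}

int main() {
  assert(canonicalSelectConstant(0xFF, 0xFF, 0x0F) == 0xFF); // unchanged
  assert(canonicalSelectConstant(0xFF, 0x3F, 0x0F) == 0xFF); // take icmp constant
  assert(canonicalSelectConstant(0xF0, 0x0F, 0x0F) == 0x0F); // plain shrink
}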
@@ -1247,30 +1277,57 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
break;
}
case Instruction::ShuffleVector: {
- ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
- unsigned LHSVWidth =
- Shuffle->getOperand(0)->getType()->getVectorNumElements();
- APInt LeftDemanded(LHSVWidth, 0), RightDemanded(LHSVWidth, 0);
+ auto *Shuffle = cast<ShuffleVectorInst>(I);
+ assert(Shuffle->getOperand(0)->getType() ==
+ Shuffle->getOperand(1)->getType() &&
+ "Expected shuffle operands to have same type");
+ unsigned OpWidth =
+ Shuffle->getOperand(0)->getType()->getVectorNumElements();
+ APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
for (unsigned i = 0; i < VWidth; i++) {
if (DemandedElts[i]) {
unsigned MaskVal = Shuffle->getMaskValue(i);
if (MaskVal != -1u) {
- assert(MaskVal < LHSVWidth * 2 &&
+ assert(MaskVal < OpWidth * 2 &&
"shufflevector mask index out of range!");
- if (MaskVal < LHSVWidth)
+ if (MaskVal < OpWidth)
LeftDemanded.setBit(MaskVal);
else
- RightDemanded.setBit(MaskVal - LHSVWidth);
+ RightDemanded.setBit(MaskVal - OpWidth);
}
}
}
- APInt LHSUndefElts(LHSVWidth, 0);
+ APInt LHSUndefElts(OpWidth, 0);
simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts);
- APInt RHSUndefElts(LHSVWidth, 0);
+ APInt RHSUndefElts(OpWidth, 0);
simplifyAndSetOp(I, 1, RightDemanded, RHSUndefElts);
+ // If this shuffle does not change the vector length and the elements
+ // demanded by this shuffle are an identity mask, then this shuffle is
+ // unnecessary.
+ //
+ // We are assuming canonical form for the mask, so the source vector is
+ // operand 0 and operand 1 is not used.
+ //
+ // Note that if an element is demanded and this shuffle mask is undefined
+ // for that element, then the shuffle is not considered an identity
+ // operation. The shuffle prevents poison from the operand vector from
+ // leaking to the result by replacing poison with an undefined value.
+ if (VWidth == OpWidth) {
+ bool IsIdentityShuffle = true;
+ for (unsigned i = 0; i < VWidth; i++) {
+ unsigned MaskVal = Shuffle->getMaskValue(i);
+ if (DemandedElts[i] && i != MaskVal) {
+ IsIdentityShuffle = false;
+ break;
+ }
+ }
+ if (IsIdentityShuffle)
+ return Shuffle->getOperand(0);
+ }
+
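A standalone sketch of the demanded-elements identity test added above, with std::vector standing in for LLVM's types (-1 denotes an undef mask element; illustrative only):

#include <cassert>
#include <vector>

// True if every demanded lane i simply selects lane i of operand 0, in which
// case the shuffle can be replaced by operand 0 outright.
static bool isIdentityForDemanded(const std::vector<int> &Mask,
                                  const std::vector<bool> &Demanded) {
  for (size_t i = 0; i != Mask.size(); ++i)
    if (Demanded[i] && Mask[i] != (int)i)
      return false;
  return true;
}

int main() {
  // Lane 2 is not demanded, so its undef (-1) entry does not block the fold.
  assert(isIdentityForDemanded({0, 1, -1, 3}, {true, true, false, true}));
  // Lane 1 is demanded but pulls from lane 0: not an identity.
  assert(!isIdentityForDemanded({0, 0, 2, 3}, {true, true, true, true}));
}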
bool NewUndefElts = false;
unsigned LHSIdx = -1u, LHSValIdx = -1u;
unsigned RHSIdx = -1u, RHSValIdx = -1u;
@@ -1283,23 +1340,23 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
} else if (!DemandedElts[i]) {
NewUndefElts = true;
UndefElts.setBit(i);
- } else if (MaskVal < LHSVWidth) {
+ } else if (MaskVal < OpWidth) {
if (LHSUndefElts[MaskVal]) {
NewUndefElts = true;
UndefElts.setBit(i);
} else {
- LHSIdx = LHSIdx == -1u ? i : LHSVWidth;
- LHSValIdx = LHSValIdx == -1u ? MaskVal : LHSVWidth;
+ LHSIdx = LHSIdx == -1u ? i : OpWidth;
+ LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth;
LHSUniform = LHSUniform && (MaskVal == i);
}
} else {
- if (RHSUndefElts[MaskVal - LHSVWidth]) {
+ if (RHSUndefElts[MaskVal - OpWidth]) {
NewUndefElts = true;
UndefElts.setBit(i);
} else {
- RHSIdx = RHSIdx == -1u ? i : LHSVWidth;
- RHSValIdx = RHSValIdx == -1u ? MaskVal - LHSVWidth : LHSVWidth;
- RHSUniform = RHSUniform && (MaskVal - LHSVWidth == i);
+ RHSIdx = RHSIdx == -1u ? i : OpWidth;
+ RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth;
+ RHSUniform = RHSUniform && (MaskVal - OpWidth == i);
}
}
}
@@ -1308,20 +1365,20 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
// this constant vector to single insertelement instruction.
// shufflevector V, C, <v1, v2, .., ci, .., vm> ->
// insertelement V, C[ci], ci-n
- if (LHSVWidth == Shuffle->getType()->getNumElements()) {
+ if (OpWidth == Shuffle->getType()->getNumElements()) {
Value *Op = nullptr;
Constant *Value = nullptr;
unsigned Idx = -1u;
// Find constant vector with the single element in shuffle (LHS or RHS).
- if (LHSIdx < LHSVWidth && RHSUniform) {
+ if (LHSIdx < OpWidth && RHSUniform) {
if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) {
Op = Shuffle->getOperand(1);
Value = CV->getOperand(LHSValIdx);
Idx = LHSIdx;
}
}
- if (RHSIdx < LHSVWidth && LHSUniform) {
+ if (RHSIdx < OpWidth && LHSUniform) {
if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) {
Op = Shuffle->getOperand(0);
Value = CV->getOperand(RHSValIdx);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 9c890748e5ab..f604c9dc32ca 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -1390,20 +1390,6 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) {
llvm_unreachable("failed to reorder elements of vector instruction!");
}
-static void recognizeIdentityMask(const SmallVectorImpl<int> &Mask,
- bool &isLHSID, bool &isRHSID) {
- isLHSID = isRHSID = true;
-
- for (unsigned i = 0, e = Mask.size(); i != e; ++i) {
- if (Mask[i] < 0) continue; // Ignore undef values.
- // Is this an identity shuffle of the LHS value?
- isLHSID &= (Mask[i] == (int)i);
-
- // Is this an identity shuffle of the RHS value?
- isRHSID &= (Mask[i]-e == i);
- }
-}
-
// Returns true if the shuffle is extracting a contiguous range of values from
// LHS, for example:
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
@@ -1560,9 +1546,11 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf,
if (!Shuf.isSelect())
return nullptr;
- // Canonicalize to choose from operand 0 first.
+ // Canonicalize to choose from operand 0 first unless operand 1 is undefined.
+ // Commuting undef to operand 0 conflicts with another canonicalization.
unsigned NumElts = Shuf.getType()->getVectorNumElements();
- if (Shuf.getMaskValue(0) >= (int)NumElts) {
+ if (!isa<UndefValue>(Shuf.getOperand(1)) &&
+ Shuf.getMaskValue(0) >= (int)NumElts) {
// TODO: Can we assert that both operands of a shuffle-select are not undef
    // (otherwise, it would have been folded by instsimplify)?
Shuf.commute();
@@ -1753,7 +1741,8 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) {
return new ShuffleVectorInst(X, Y, ConstantVector::get(NewMask));
}
-/// Try to replace a shuffle with an insertelement.
+/// Try to replace a shuffle with an insertelement or try to replace a shuffle
+/// operand with the operand of an insertelement.
static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf) {
Value *V0 = Shuf.getOperand(0), *V1 = Shuf.getOperand(1);
SmallVector<int, 16> Mask = Shuf.getShuffleMask();
@@ -1765,6 +1754,31 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf) {
if (NumElts != (int)(V0->getType()->getVectorNumElements()))
return nullptr;
+ // This is a specialization of a fold in SimplifyDemandedVectorElts. We may
+ // not be able to handle it there if the insertelement has >1 use.
+ // If the shuffle has an insertelement operand but does not choose the
+ // inserted scalar element from that value, then we can replace that shuffle
+ // operand with the source vector of the insertelement.
+ Value *X;
+ uint64_t IdxC;
+ if (match(V0, m_InsertElement(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) {
+ // shuf (inselt X, ?, IdxC), ?, Mask --> shuf X, ?, Mask
+ if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) {
+ Shuf.setOperand(0, X);
+ return &Shuf;
+ }
+ }
+ if (match(V1, m_InsertElement(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) {
+ // Offset the index constant by the vector width because we are checking for
+ // accesses to the 2nd vector input of the shuffle.
+ IdxC += NumElts;
+ // shuf ?, (inselt X, ?, IdxC), Mask --> shuf ?, X, Mask
+ if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) {
+ Shuf.setOperand(1, X);
+ return &Shuf;
+ }
+ }
+
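The guard in both blocks is just a none_of over the mask; a standalone sketch of that lane check with plain containers (illustrative):

#include <algorithm>
#include <cassert>
#include <vector>

// True if no mask entry reads lane IdxC, i.e. an insertelement into that lane
// is invisible to this shuffle and its source vector can be used directly.
static bool shuffleIgnoresLane(const std::vector<int> &Mask, int IdxC) {
  return std::none_of(Mask.begin(), Mask.end(),
                      [IdxC](int MaskElt) { return MaskElt == IdxC; });
}

int main() {
  assert(shuffleIgnoresLane({0, 1, 0, 1}, 3));   // lane 3 never read
  assert(!shuffleIgnoresLane({0, 3, 1, 2}, 3));  // lane 3 is read
}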
// shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC'
auto isShufflingScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) {
// We need an insertelement with a constant index.
@@ -1891,29 +1905,21 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI)))
return replaceInstUsesWith(SVI, V);
- // Canonicalize shuffle(x ,x,mask) -> shuffle(x, undef,mask')
- // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask').
+ // shuffle x, x, mask --> shuffle x, undef, mask'
unsigned VWidth = SVI.getType()->getVectorNumElements();
unsigned LHSWidth = LHS->getType()->getVectorNumElements();
SmallVector<int, 16> Mask = SVI.getShuffleMask();
Type *Int32Ty = Type::getInt32Ty(SVI.getContext());
- if (LHS == RHS || isa<UndefValue>(LHS)) {
+ if (LHS == RHS) {
+ assert(!isa<UndefValue>(RHS) && "Shuffle with 2 undef ops not simplified?");
// Remap any references to RHS to use LHS.
SmallVector<Constant*, 16> Elts;
- for (unsigned i = 0, e = LHSWidth; i != VWidth; ++i) {
- if (Mask[i] < 0) {
- Elts.push_back(UndefValue::get(Int32Ty));
- continue;
- }
-
- if ((Mask[i] >= (int)e && isa<UndefValue>(RHS)) ||
- (Mask[i] < (int)e && isa<UndefValue>(LHS))) {
- Mask[i] = -1; // Turn into undef.
+ for (unsigned i = 0; i != VWidth; ++i) {
+ // Propagate undef elements or force mask to LHS.
+ if (Mask[i] < 0)
Elts.push_back(UndefValue::get(Int32Ty));
- } else {
- Mask[i] = Mask[i] % e; // Force to LHS.
- Elts.push_back(ConstantInt::get(Int32Ty, Mask[i]));
- }
+ else
+ Elts.push_back(ConstantInt::get(Int32Ty, Mask[i] % LHSWidth));
}
SVI.setOperand(0, SVI.getOperand(1));
SVI.setOperand(1, UndefValue::get(RHS->getType()));
@@ -1921,6 +1927,12 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
return &SVI;
}
+ // shuffle undef, x, mask --> shuffle x, undef, mask'
+ if (isa<UndefValue>(LHS)) {
+ SVI.commute();
+ return &SVI;
+ }
+
if (Instruction *I = canonicalizeInsertSplat(SVI, Builder))
return I;
@@ -1948,16 +1960,6 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
if (Instruction *I = foldIdentityPaddedShuffles(SVI))
return I;
- if (VWidth == LHSWidth) {
- // Analyze the shuffle, are the LHS or RHS and identity shuffles?
- bool isLHSID, isRHSID;
- recognizeIdentityMask(Mask, isLHSID, isRHSID);
-
- // Eliminate identity shuffles.
- if (isLHSID) return replaceInstUsesWith(SVI, LHS);
- if (isRHSID) return replaceInstUsesWith(SVI, RHS);
- }
-
if (isa<UndefValue>(RHS) && canEvaluateShuffled(LHS, Mask)) {
Value *V = evaluateInDifferentElementOrder(LHS, Mask);
return replaceInstUsesWith(SVI, V);
@@ -2235,12 +2237,5 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
return new ShuffleVectorInst(newLHS, newRHS, ConstantVector::get(Elts));
}
- // If the result mask is an identity, replace uses of this instruction with
- // corresponding argument.
- bool isLHSID, isRHSID;
- recognizeIdentityMask(newMask, isLHSID, isRHSID);
- if (isLHSID && VWidth == LHSOp0Width) return replaceInstUsesWith(SVI, newLHS);
- if (isRHSID && VWidth == RHSOp0Width) return replaceInstUsesWith(SVI, newRHS);
-
return MadeChange ? &SVI : nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index ecb486c544e0..801c09a317a7 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -86,6 +86,7 @@
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
+#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CBindingWrapping.h"
#include "llvm/Support/Casting.h"
@@ -121,6 +122,9 @@ STATISTIC(NumReassoc , "Number of reassociations");
DEBUG_COUNTER(VisitCounter, "instcombine-visit",
"Controls which instructions are visited");
+static constexpr unsigned InstCombineDefaultMaxIterations = 1000;
+static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000;
+
static cl::opt<bool>
EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"),
cl::init(true));
@@ -129,6 +133,17 @@ static cl::opt<bool>
EnableExpensiveCombines("expensive-combines",
cl::desc("Enable expensive instruction combines"));
+static cl::opt<unsigned> LimitMaxIterations(
+ "instcombine-max-iterations",
+ cl::desc("Limit the maximum number of instruction combining iterations"),
+ cl::init(InstCombineDefaultMaxIterations));
+
+static cl::opt<unsigned> InfiniteLoopDetectionThreshold(
+ "instcombine-infinite-loop-threshold",
+ cl::desc("Number of instruction combining iterations considered an "
+ "infinite loop"),
+ cl::init(InstCombineDefaultInfiniteLoopThreshold), cl::Hidden);
+
static cl::opt<unsigned>
MaxArraySize("instcombine-maxarray-size", cl::init(1024),
cl::desc("Maximum array size considered when doing a combine"));
@@ -759,35 +774,52 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
Value *LHS, Value *RHS) {
- Instruction::BinaryOps Opcode = I.getOpcode();
- // (op (select (a, b, c)), (select (a, d, e))) -> (select (a, (op b, d), (op
- // c, e)))
- Value *A, *B, *C, *D, *E;
- Value *SI = nullptr;
- if (match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
- match(RHS, m_Select(m_Specific(A), m_Value(D), m_Value(E)))) {
- bool SelectsHaveOneUse = LHS->hasOneUse() && RHS->hasOneUse();
-
- FastMathFlags FMF;
- BuilderTy::FastMathFlagGuard Guard(Builder);
- if (isa<FPMathOperator>(&I)) {
- FMF = I.getFastMathFlags();
- Builder.setFastMathFlags(FMF);
- }
+ Value *A, *B, *C, *D, *E, *F;
+ bool LHSIsSelect = match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C)));
+ bool RHSIsSelect = match(RHS, m_Select(m_Value(D), m_Value(E), m_Value(F)));
+ if (!LHSIsSelect && !RHSIsSelect)
+ return nullptr;
- Value *V1 = SimplifyBinOp(Opcode, C, E, FMF, SQ.getWithInstruction(&I));
- Value *V2 = SimplifyBinOp(Opcode, B, D, FMF, SQ.getWithInstruction(&I));
- if (V1 && V2)
- SI = Builder.CreateSelect(A, V2, V1);
- else if (V2 && SelectsHaveOneUse)
- SI = Builder.CreateSelect(A, V2, Builder.CreateBinOp(Opcode, C, E));
- else if (V1 && SelectsHaveOneUse)
- SI = Builder.CreateSelect(A, Builder.CreateBinOp(Opcode, B, D), V1);
+ FastMathFlags FMF;
+ BuilderTy::FastMathFlagGuard Guard(Builder);
+ if (isa<FPMathOperator>(&I)) {
+ FMF = I.getFastMathFlags();
+ Builder.setFastMathFlags(FMF);
+ }
- if (SI)
- SI->takeName(&I);
+ Instruction::BinaryOps Opcode = I.getOpcode();
+ SimplifyQuery Q = SQ.getWithInstruction(&I);
+
+ Value *Cond, *True = nullptr, *False = nullptr;
+ if (LHSIsSelect && RHSIsSelect && A == D) {
+ // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
+ Cond = A;
+ True = SimplifyBinOp(Opcode, B, E, FMF, Q);
+ False = SimplifyBinOp(Opcode, C, F, FMF, Q);
+
+ if (LHS->hasOneUse() && RHS->hasOneUse()) {
+ if (False && !True)
+ True = Builder.CreateBinOp(Opcode, B, E);
+ else if (True && !False)
+ False = Builder.CreateBinOp(Opcode, C, F);
+ }
+ } else if (LHSIsSelect && LHS->hasOneUse()) {
+ // (A ? B : C) op Y -> A ? (B op Y) : (C op Y)
+ Cond = A;
+ True = SimplifyBinOp(Opcode, B, RHS, FMF, Q);
+ False = SimplifyBinOp(Opcode, C, RHS, FMF, Q);
+ } else if (RHSIsSelect && RHS->hasOneUse()) {
+ // X op (D ? E : F) -> D ? (X op E) : (X op F)
+ Cond = D;
+ True = SimplifyBinOp(Opcode, LHS, E, FMF, Q);
+ False = SimplifyBinOp(Opcode, LHS, F, FMF, Q);
}
+ if (!True || !False)
+ return nullptr;
+
+ Value *SI = Builder.CreateSelect(Cond, True, False);
+ SI->takeName(&I);
return SI;
}
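All three shapes handled above distribute the binop over a select; with concrete constants the equivalence is easy to check (illustrative values only):

#include <cassert>

static int f(bool A, int Y) {
  // (A ? 4 : 8) * Y  ->  A ? (4 * Y) : (8 * Y); the two-select form with a
  // shared condition folds the same way.
  return (A ? 4 : 8) * Y;
}

int main() {
  for (bool A : {false, true})
    for (int Y : {-3, 0, 7})
      assert(f(A, Y) == (A ? 4 * Y : 8 * Y));
}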
@@ -1526,11 +1558,13 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) {
// If this is a widening shuffle, we must be able to extend with undef
// elements. If the original binop does not produce an undef in the high
// lanes, then this transform is not safe.
+ // Similarly for undef lanes due to the shuffle mask, we can only
+ // transform binops that preserve undef.
// TODO: We could shuffle those non-undef constant values into the
// result by using a constant vector (rather than an undef vector)
// as operand 1 of the new binop, but that might be too aggressive
// for target-independent shuffle creation.
- if (I >= SrcVecNumElts) {
+ if (I >= SrcVecNumElts || ShMask[I] < 0) {
Constant *MaybeUndef =
ConstOp1 ? ConstantExpr::get(Opcode, UndefScalar, CElt)
: ConstantExpr::get(Opcode, CElt, UndefScalar);
@@ -1615,6 +1649,15 @@ Instruction *InstCombiner::narrowMathIfNoOverflow(BinaryOperator &BO) {
return CastInst::Create(CastOpc, NarrowBO, BO.getType());
}
+static bool isMergedGEPInBounds(GEPOperator &GEP1, GEPOperator &GEP2) {
+ // At least one GEP must be inbounds.
+ if (!GEP1.isInBounds() && !GEP2.isInBounds())
+ return false;
+
+ return (GEP1.isInBounds() || GEP1.hasAllZeroIndices()) &&
+ (GEP2.isInBounds() || GEP2.hasAllZeroIndices());
+}
+
Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end());
Type *GEPType = GEP.getType();
@@ -1724,8 +1767,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// The first two arguments can vary for any GEP, the rest have to be
// static for struct slots
- if (J > 1 && CurTy->isStructTy())
- return nullptr;
+ if (J > 1) {
+ assert(CurTy && "No current type?");
+ if (CurTy->isStructTy())
+ return nullptr;
+ }
DI = J;
} else {
@@ -1885,6 +1931,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// Update the GEP in place if possible.
if (Src->getNumOperands() == 2) {
+ GEP.setIsInBounds(isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)));
GEP.setOperand(0, Src->getOperand(0));
GEP.setOperand(1, Sum);
return &GEP;
@@ -1901,7 +1948,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
}
if (!Indices.empty())
- return GEP.isInBounds() && Src->isInBounds()
+ return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP))
? GetElementPtrInst::CreateInBounds(
Src->getSourceElementType(), Src->getOperand(0), Indices,
GEP.getName())
@@ -2154,15 +2201,17 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// of a bitcasted pointer to vector or array of the same dimensions:
// gep (bitcast <c x ty>* X to [c x ty]*), Y, Z --> gep X, Y, Z
// gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z
- auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy) {
+ auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy,
+ const DataLayout &DL) {
return ArrTy->getArrayElementType() == VecTy->getVectorElementType() &&
- ArrTy->getArrayNumElements() == VecTy->getVectorNumElements();
+ ArrTy->getArrayNumElements() == VecTy->getVectorNumElements() &&
+ DL.getTypeAllocSize(ArrTy) == DL.getTypeAllocSize(VecTy);
};
if (GEP.getNumOperands() == 3 &&
((GEPEltType->isArrayTy() && SrcEltType->isVectorTy() &&
- areMatchingArrayAndVecTypes(GEPEltType, SrcEltType)) ||
+ areMatchingArrayAndVecTypes(GEPEltType, SrcEltType, DL)) ||
(GEPEltType->isVectorTy() && SrcEltType->isArrayTy() &&
- areMatchingArrayAndVecTypes(SrcEltType, GEPEltType)))) {
+ areMatchingArrayAndVecTypes(SrcEltType, GEPEltType, DL)))) {
// Create a new GEP here, as using `setOperand()` followed by
// `setSourceElementType()` won't actually update the type of the
@@ -2401,12 +2450,13 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) {
replaceInstUsesWith(*C,
ConstantInt::get(Type::getInt1Ty(C->getContext()),
C->isFalseWhenEqual()));
- } else if (isa<BitCastInst>(I) || isa<GetElementPtrInst>(I) ||
- isa<AddrSpaceCastInst>(I)) {
- replaceInstUsesWith(*I, UndefValue::get(I->getType()));
} else if (auto *SI = dyn_cast<StoreInst>(I)) {
for (auto *DII : DIIs)
ConvertDebugDeclareToDebugValue(DII, SI, *DIB);
+ } else {
+ // Casts, GEP, or anything else: we're about to delete this instruction,
+      // so it cannot have any valid uses.
+ replaceInstUsesWith(*I, UndefValue::get(I->getType()));
}
eraseInstFromFunction(*I);
}
@@ -3111,6 +3161,15 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) {
return nullptr;
}
+Instruction *InstCombiner::visitFreeze(FreezeInst &I) {
+ Value *Op0 = I.getOperand(0);
+
+ if (Value *V = SimplifyFreezeInst(Op0, SQ.getWithInstruction(&I)))
+ return replaceInstUsesWith(I, V);
+
+ return nullptr;
+}
+
/// Try to move the specified instruction from its current block into the
/// beginning of DestBlock, which can only happen if it's safe to move the
/// instruction past all of the instructions between it and the end of its
@@ -3322,10 +3381,6 @@ bool InstCombiner::run() {
// Move the name to the new instruction first.
Result->takeName(I);
- // Push the new instruction and any users onto the worklist.
- Worklist.AddUsersToWorkList(*Result);
- Worklist.Add(Result);
-
// Insert the new instruction into the basic block...
BasicBlock *InstParent = I->getParent();
BasicBlock::iterator InsertPos = I->getIterator();
@@ -3337,6 +3392,10 @@ bool InstCombiner::run() {
InstParent->getInstList().insert(InsertPos, Result);
+ // Push the new instruction and any users onto the worklist.
+ Worklist.AddUsersToWorkList(*Result);
+ Worklist.Add(Result);
+
eraseInstFromFunction(*I);
} else {
LLVM_DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n'
@@ -3392,8 +3451,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,
if (isInstructionTriviallyDead(Inst, TLI)) {
++NumDeadInst;
LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n');
- if (!salvageDebugInfo(*Inst))
- replaceDbgUsesWithUndef(Inst);
+ salvageDebugInfoOrMarkUndef(*Inst);
Inst->eraseFromParent();
MadeIRChange = true;
continue;
@@ -3507,10 +3565,11 @@ static bool combineInstructionsOverFunction(
Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA,
AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT,
OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
- ProfileSummaryInfo *PSI, bool ExpensiveCombines = true,
- LoopInfo *LI = nullptr) {
+ ProfileSummaryInfo *PSI, bool ExpensiveCombines, unsigned MaxIterations,
+ LoopInfo *LI) {
auto &DL = F.getParent()->getDataLayout();
ExpensiveCombines |= EnableExpensiveCombines;
+ MaxIterations = std::min(MaxIterations, LimitMaxIterations.getValue());
/// Builder - This is an IRBuilder that automatically inserts new
/// instructions into the worklist when they are created.
@@ -3529,9 +3588,23 @@ static bool combineInstructionsOverFunction(
MadeIRChange = LowerDbgDeclare(F);
// Iterate while there is work to do.
- int Iteration = 0;
+ unsigned Iteration = 0;
while (true) {
++Iteration;
+
+ if (Iteration > InfiniteLoopDetectionThreshold) {
+ report_fatal_error(
+ "Instruction Combining seems stuck in an infinite loop after " +
+ Twine(InfiniteLoopDetectionThreshold) + " iterations.");
+ }
+
+ if (Iteration > MaxIterations) {
+ LLVM_DEBUG(dbgs() << "\n\n[IC] Iteration limit #" << MaxIterations
+ << " on " << F.getName()
+ << " reached; stopping before reaching a fixpoint\n");
+ break;
+ }
+
LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on "
<< F.getName() << "\n");
@@ -3543,11 +3616,19 @@ static bool combineInstructionsOverFunction(
if (!IC.run())
break;
+
+ MadeIRChange = true;
}
- return MadeIRChange || Iteration > 1;
+ return MadeIRChange;
}
+InstCombinePass::InstCombinePass(bool ExpensiveCombines)
+ : ExpensiveCombines(ExpensiveCombines), MaxIterations(LimitMaxIterations) {}
+
+InstCombinePass::InstCombinePass(bool ExpensiveCombines, unsigned MaxIterations)
+ : ExpensiveCombines(ExpensiveCombines), MaxIterations(MaxIterations) {}
+
PreservedAnalyses InstCombinePass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &AC = AM.getResult<AssumptionAnalysis>(F);
@@ -3565,8 +3646,9 @@ PreservedAnalyses InstCombinePass::run(Function &F,
auto *BFI = (PSI && PSI->hasProfileSummary()) ?
&AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
- if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE,
- BFI, PSI, ExpensiveCombines, LI))
+ if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI,
+ PSI, ExpensiveCombines, MaxIterations,
+ LI))
// No changes, all analyses are preserved.
return PreservedAnalyses::all();
@@ -3615,12 +3697,26 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
&getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() :
nullptr;
- return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE,
- BFI, PSI, ExpensiveCombines, LI);
+ return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI,
+ PSI, ExpensiveCombines, MaxIterations,
+ LI);
}
char InstructionCombiningPass::ID = 0;
+InstructionCombiningPass::InstructionCombiningPass(bool ExpensiveCombines)
+ : FunctionPass(ID), ExpensiveCombines(ExpensiveCombines),
+ MaxIterations(InstCombineDefaultMaxIterations) {
+ initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry());
+}
+
+InstructionCombiningPass::InstructionCombiningPass(bool ExpensiveCombines,
+ unsigned MaxIterations)
+ : FunctionPass(ID), ExpensiveCombines(ExpensiveCombines),
+ MaxIterations(MaxIterations) {
+ initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry());
+}
+
INITIALIZE_PASS_BEGIN(InstructionCombiningPass, "instcombine",
"Combine redundant instructions", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
@@ -3647,6 +3743,11 @@ FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines) {
return new InstructionCombiningPass(ExpensiveCombines);
}
+FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines,
+ unsigned MaxIterations) {
+ return new InstructionCombiningPass(ExpensiveCombines, MaxIterations);
+}
+
void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createInstructionCombiningPass());
}