summaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineSelect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineSelect.cpp506
1 files changed, 352 insertions, 154 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4eebe8255998..6f26f7f5cd19 100644
--- a/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -12,12 +12,36 @@
//===----------------------------------------------------------------------===//
#include "InstCombineInternal.h"
-#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/ValueTracking.h"
-#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
+#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include <cassert>
+#include <utility>
+
using namespace llvm;
using namespace PatternMatch;
@@ -69,6 +93,111 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder,
return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B);
}
+/// If one of the constants is zero (we know they can't both be) and we have an
+/// icmp instruction with zero, and we have an 'and' with the non-constant value
+/// and a power of two we can turn the select into a shift on the result of the
+/// 'and'.
+/// This folds:
+/// select (icmp eq (and X, C1)), C2, C3
+/// iff C1 is a power 2 and the difference between C2 and C3 is a power of 2.
+/// To something like:
+/// (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3
+/// Or:
+/// (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3
+/// With some variations depending if C3 is larger than C2, or the shift
+/// isn't needed, or the bit widths don't match.
+static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC,
+ APInt TrueVal, APInt FalseVal,
+ InstCombiner::BuilderTy &Builder) {
+ assert(SelType->isIntOrIntVectorTy() && "Not an integer select?");
+
+ // If this is a vector select, we need a vector compare.
+ if (SelType->isVectorTy() != IC->getType()->isVectorTy())
+ return nullptr;
+
+ Value *V;
+ APInt AndMask;
+ bool CreateAnd = false;
+ ICmpInst::Predicate Pred = IC->getPredicate();
+ if (ICmpInst::isEquality(Pred)) {
+ if (!match(IC->getOperand(1), m_Zero()))
+ return nullptr;
+
+ V = IC->getOperand(0);
+
+ const APInt *AndRHS;
+ if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
+ return nullptr;
+
+ AndMask = *AndRHS;
+ } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1),
+ Pred, V, AndMask)) {
+ assert(ICmpInst::isEquality(Pred) && "Not equality test?");
+
+ if (!AndMask.isPowerOf2())
+ return nullptr;
+
+ CreateAnd = true;
+ } else {
+ return nullptr;
+ }
+
+ // If both select arms are non-zero see if we have a select of the form
+ // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic
+ // for 'x ? 2^n : 0' and fix the thing up at the end.
+ APInt Offset(TrueVal.getBitWidth(), 0);
+ if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) {
+ if ((TrueVal - FalseVal).isPowerOf2())
+ Offset = FalseVal;
+ else if ((FalseVal - TrueVal).isPowerOf2())
+ Offset = TrueVal;
+ else
+ return nullptr;
+
+ // Adjust TrueVal and FalseVal to the offset.
+ TrueVal -= Offset;
+ FalseVal -= Offset;
+ }
+
+ // Make sure one of the select arms is a power of 2.
+ if (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2())
+ return nullptr;
+
+ // Determine which shift is needed to transform result of the 'and' into the
+ // desired result.
+ const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal;
+ unsigned ValZeros = ValC.logBase2();
+ unsigned AndZeros = AndMask.logBase2();
+
+ if (CreateAnd) {
+ // Insert the AND instruction on the input to the truncate.
+ V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask));
+ }
+
+ // If types don't match we can still convert the select by introducing a zext
+ // or a trunc of the 'and'.
+ if (ValZeros > AndZeros) {
+ V = Builder.CreateZExtOrTrunc(V, SelType);
+ V = Builder.CreateShl(V, ValZeros - AndZeros);
+ } else if (ValZeros < AndZeros) {
+ V = Builder.CreateLShr(V, AndZeros - ValZeros);
+ V = Builder.CreateZExtOrTrunc(V, SelType);
+ } else
+ V = Builder.CreateZExtOrTrunc(V, SelType);
+
+ // Okay, now we know that everything is set up, we just don't know whether we
+ // have a icmp_ne or icmp_eq and whether the true or false val is the zero.
+ bool ShouldNotVal = !TrueVal.isNullValue();
+ ShouldNotVal ^= Pred == ICmpInst::ICMP_NE;
+ if (ShouldNotVal)
+ V = Builder.CreateXor(V, ValC);
+
+ // Apply an offset if needed.
+ if (!Offset.isNullValue())
+ V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset));
+ return V;
+}
+
/// We want to turn code that looks like this:
/// %C = or %A, %B
/// %D = select %cond, %C, %A
@@ -79,8 +208,7 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder,
/// Assuming that the specified instruction is an operand to the select, return
/// a bitmask indicating which operands of this instruction are foldable if they
/// equal the other incoming value of the select.
-///
-static unsigned getSelectFoldableOperands(Instruction *I) {
+static unsigned getSelectFoldableOperands(BinaryOperator *I) {
switch (I->getOpcode()) {
case Instruction::Add:
case Instruction::Mul:
@@ -100,7 +228,7 @@ static unsigned getSelectFoldableOperands(Instruction *I) {
/// For the same transformation as the previous function, return the identity
/// constant that goes into the select.
-static Constant *getSelectFoldableConstant(Instruction *I) {
+static APInt getSelectFoldableConstant(BinaryOperator *I) {
switch (I->getOpcode()) {
default: llvm_unreachable("This cannot happen!");
case Instruction::Add:
@@ -110,11 +238,11 @@ static Constant *getSelectFoldableConstant(Instruction *I) {
case Instruction::Shl:
case Instruction::LShr:
case Instruction::AShr:
- return Constant::getNullValue(I->getType());
+ return APInt::getNullValue(I->getType()->getScalarSizeInBits());
case Instruction::And:
- return Constant::getAllOnesValue(I->getType());
+ return APInt::getAllOnesValue(I->getType()->getScalarSizeInBits());
case Instruction::Mul:
- return ConstantInt::get(I->getType(), 1);
+ return APInt(I->getType()->getScalarSizeInBits(), 1);
}
}
@@ -157,7 +285,6 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
if (TI->getOpcode() != Instruction::BitCast &&
(!TI->hasOneUse() || !FI->hasOneUse()))
return nullptr;
-
} else if (!TI->hasOneUse() || !FI->hasOneUse()) {
// TODO: The one-use restrictions for a scalar select could be eased if
// the fold of a select in visitLoadInst() was enhanced to match a pattern
@@ -218,17 +345,11 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
return BinaryOperator::Create(BO->getOpcode(), Op0, Op1);
}
-static bool isSelect01(Constant *C1, Constant *C2) {
- ConstantInt *C1I = dyn_cast<ConstantInt>(C1);
- if (!C1I)
- return false;
- ConstantInt *C2I = dyn_cast<ConstantInt>(C2);
- if (!C2I)
- return false;
- if (!C1I->isZero() && !C2I->isZero()) // One side must be zero.
+static bool isSelect01(const APInt &C1I, const APInt &C2I) {
+ if (!C1I.isNullValue() && !C2I.isNullValue()) // One side must be zero.
return false;
- return C1I->isOne() || C1I->isMinusOne() ||
- C2I->isOne() || C2I->isMinusOne();
+ return C1I.isOneValue() || C1I.isAllOnesValue() ||
+ C2I.isOneValue() || C2I.isAllOnesValue();
}
/// Try to fold the select into one of the operands to allow further
@@ -237,9 +358,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
Value *FalseVal) {
// See the comment above GetSelectFoldableOperands for a description of the
// transformation we are doing here.
- if (Instruction *TVI = dyn_cast<Instruction>(TrueVal)) {
- if (TVI->hasOneUse() && TVI->getNumOperands() == 2 &&
- !isa<Constant>(FalseVal)) {
+ if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) {
+ if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) {
if (unsigned SFO = getSelectFoldableOperands(TVI)) {
unsigned OpToFold = 0;
if ((SFO & 1) && FalseVal == TVI->getOperand(0)) {
@@ -249,17 +369,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
}
if (OpToFold) {
- Constant *C = getSelectFoldableConstant(TVI);
+ APInt CI = getSelectFoldableConstant(TVI);
Value *OOp = TVI->getOperand(2-OpToFold);
// Avoid creating select between 2 constants unless it's selecting
// between 0, 1 and -1.
- if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) {
+ const APInt *OOpC;
+ bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
+ if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) {
+ Value *C = ConstantInt::get(OOp->getType(), CI);
Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C);
NewSel->takeName(TVI);
- BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI);
- BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(),
+ BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(),
FalseVal, NewSel);
- BO->copyIRFlags(TVI_BO);
+ BO->copyIRFlags(TVI);
return BO;
}
}
@@ -267,9 +389,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
}
}
- if (Instruction *FVI = dyn_cast<Instruction>(FalseVal)) {
- if (FVI->hasOneUse() && FVI->getNumOperands() == 2 &&
- !isa<Constant>(TrueVal)) {
+ if (auto *FVI = dyn_cast<BinaryOperator>(FalseVal)) {
+ if (FVI->hasOneUse() && !isa<Constant>(TrueVal)) {
if (unsigned SFO = getSelectFoldableOperands(FVI)) {
unsigned OpToFold = 0;
if ((SFO & 1) && TrueVal == FVI->getOperand(0)) {
@@ -279,17 +400,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
}
if (OpToFold) {
- Constant *C = getSelectFoldableConstant(FVI);
+ APInt CI = getSelectFoldableConstant(FVI);
Value *OOp = FVI->getOperand(2-OpToFold);
// Avoid creating select between 2 constants unless it's selecting
// between 0, 1 and -1.
- if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) {
+ const APInt *OOpC;
+ bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
+ if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) {
+ Value *C = ConstantInt::get(OOp->getType(), CI);
Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp);
NewSel->takeName(FVI);
- BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI);
- BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(),
+ BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(),
TrueVal, NewSel);
- BO->copyIRFlags(FVI_BO);
+ BO->copyIRFlags(FVI);
return BO;
}
}
@@ -313,11 +436,13 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
/// 1. The icmp predicate is inverted
/// 2. The select operands are reversed
/// 3. The magnitude of C2 and C1 are flipped
-static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,
+static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
Value *FalseVal,
InstCombiner::BuilderTy &Builder) {
- const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition());
- if (!IC || !SI.getType()->isIntegerTy())
+ // Only handle integer compares. Also, if this is a vector select, we need a
+ // vector compare.
+ if (!TrueVal->getType()->isIntOrIntVectorTy() ||
+ TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
return nullptr;
Value *CmpLHS = IC->getOperand(0);
@@ -371,8 +496,8 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,
bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal);
bool NeedShift = C1Log != C2Log;
- bool NeedZExtTrunc = Y->getType()->getIntegerBitWidth() !=
- V->getType()->getIntegerBitWidth();
+ bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() !=
+ V->getType()->getScalarSizeInBits();
// Make sure we don't create more instructions than we save.
Value *Or = OrOnFalseVal ? FalseVal : TrueVal;
@@ -447,8 +572,7 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
IntrinsicInst *II = cast<IntrinsicInst>(Count);
// Explicitly clear the 'undef_on_zero' flag.
IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone());
- Type *Ty = NewI->getArgOperand(1)->getType();
- NewI->setArgOperand(1, Constant::getNullValue(Ty));
+ NewI->setArgOperand(1, ConstantInt::getFalse(NewI->getContext()));
Builder.Insert(NewI);
return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType());
}
@@ -597,6 +721,9 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
/// Visit a SelectInst that has an ICmpInst as its first operand.
Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
ICmpInst *ICI) {
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+
if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder))
return NewSel;
@@ -605,40 +732,52 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
ICmpInst::Predicate Pred = ICI->getPredicate();
Value *CmpLHS = ICI->getOperand(0);
Value *CmpRHS = ICI->getOperand(1);
- Value *TrueVal = SI.getTrueValue();
- Value *FalseVal = SI.getFalseValue();
// Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1
// and (X <s 0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1
// FIXME: Type and constness constraints could be lifted, but we have to
// watch code size carefully. We should consider xor instead of
// sub/add when we decide to do that.
- if (IntegerType *Ty = dyn_cast<IntegerType>(CmpLHS->getType())) {
- if (TrueVal->getType() == Ty) {
- if (ConstantInt *Cmp = dyn_cast<ConstantInt>(CmpRHS)) {
- ConstantInt *C1 = nullptr, *C2 = nullptr;
- if (Pred == ICmpInst::ICMP_SGT && Cmp->isMinusOne()) {
- C1 = dyn_cast<ConstantInt>(TrueVal);
- C2 = dyn_cast<ConstantInt>(FalseVal);
- } else if (Pred == ICmpInst::ICMP_SLT && Cmp->isZero()) {
- C1 = dyn_cast<ConstantInt>(FalseVal);
- C2 = dyn_cast<ConstantInt>(TrueVal);
- }
- if (C1 && C2) {
+ // TODO: Merge this with foldSelectICmpAnd somehow.
+ if (CmpLHS->getType()->isIntOrIntVectorTy() &&
+ CmpLHS->getType() == TrueVal->getType()) {
+ const APInt *C1, *C2;
+ if (match(TrueVal, m_APInt(C1)) && match(FalseVal, m_APInt(C2))) {
+ ICmpInst::Predicate Pred = ICI->getPredicate();
+ Value *X;
+ APInt Mask;
+ if (decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask, false)) {
+ if (Mask.isSignMask()) {
+ assert(X == CmpLHS && "Expected to use the compare input directly");
+ assert(ICmpInst::isEquality(Pred) && "Expected equality predicate");
+
+ if (Pred == ICmpInst::ICMP_NE)
+ std::swap(C1, C2);
+
// This shift results in either -1 or 0.
- Value *AShr = Builder.CreateAShr(CmpLHS, Ty->getBitWidth() - 1);
+ Value *AShr = Builder.CreateAShr(X, Mask.getBitWidth() - 1);
// Check if we can express the operation with a single or.
- if (C2->isMinusOne())
- return replaceInstUsesWith(SI, Builder.CreateOr(AShr, C1));
+ if (C2->isAllOnesValue())
+ return replaceInstUsesWith(SI, Builder.CreateOr(AShr, *C1));
- Value *And = Builder.CreateAnd(AShr, C2->getValue() - C1->getValue());
- return replaceInstUsesWith(SI, Builder.CreateAdd(And, C1));
+ Value *And = Builder.CreateAnd(AShr, *C2 - *C1);
+ return replaceInstUsesWith(SI, Builder.CreateAdd(And,
+ ConstantInt::get(And->getType(), *C1)));
}
}
}
}
+ {
+ const APInt *TrueValC, *FalseValC;
+ if (match(TrueVal, m_APInt(TrueValC)) &&
+ match(FalseVal, m_APInt(FalseValC)))
+ if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC,
+ *FalseValC, Builder))
+ return replaceInstUsesWith(SI, V);
+ }
+
// NOTE: if we wanted to, this is where to detect integer MIN/MAX
if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
@@ -703,7 +842,7 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
}
}
- if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder))
+ if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder))
@@ -722,7 +861,6 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
/// Z = select X, Y, 0
///
/// because Y is not live in BB1/BB2.
-///
static bool canSelectOperandBeMappingIntoPredBlock(const Value *V,
const SelectInst &SI) {
// If the value is a non-instruction value like a constant or argument, it
@@ -864,78 +1002,6 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner,
return nullptr;
}
-/// If one of the constants is zero (we know they can't both be) and we have an
-/// icmp instruction with zero, and we have an 'and' with the non-constant value
-/// and a power of two we can turn the select into a shift on the result of the
-/// 'and'.
-static Value *foldSelectICmpAnd(const SelectInst &SI, APInt TrueVal,
- APInt FalseVal,
- InstCombiner::BuilderTy &Builder) {
- const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition());
- if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy())
- return nullptr;
-
- if (!match(IC->getOperand(1), m_Zero()))
- return nullptr;
-
- ConstantInt *AndRHS;
- Value *LHS = IC->getOperand(0);
- if (!match(LHS, m_And(m_Value(), m_ConstantInt(AndRHS))))
- return nullptr;
-
- // If both select arms are non-zero see if we have a select of the form
- // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic
- // for 'x ? 2^n : 0' and fix the thing up at the end.
- APInt Offset(TrueVal.getBitWidth(), 0);
- if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) {
- if ((TrueVal - FalseVal).isPowerOf2())
- Offset = FalseVal;
- else if ((FalseVal - TrueVal).isPowerOf2())
- Offset = TrueVal;
- else
- return nullptr;
-
- // Adjust TrueVal and FalseVal to the offset.
- TrueVal -= Offset;
- FalseVal -= Offset;
- }
-
- // Make sure the mask in the 'and' and one of the select arms is a power of 2.
- if (!AndRHS->getValue().isPowerOf2() ||
- (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2()))
- return nullptr;
-
- // Determine which shift is needed to transform result of the 'and' into the
- // desired result.
- const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal;
- unsigned ValZeros = ValC.logBase2();
- unsigned AndZeros = AndRHS->getValue().logBase2();
-
- // If types don't match we can still convert the select by introducing a zext
- // or a trunc of the 'and'. The trunc case requires that all of the truncated
- // bits are zero, we can figure that out by looking at the 'and' mask.
- if (AndZeros >= ValC.getBitWidth())
- return nullptr;
-
- Value *V = Builder.CreateZExtOrTrunc(LHS, SI.getType());
- if (ValZeros > AndZeros)
- V = Builder.CreateShl(V, ValZeros - AndZeros);
- else if (ValZeros < AndZeros)
- V = Builder.CreateLShr(V, AndZeros - ValZeros);
-
- // Okay, now we know that everything is set up, we just don't know whether we
- // have a icmp_ne or icmp_eq and whether the true or false val is the zero.
- bool ShouldNotVal = !TrueVal.isNullValue();
- ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE;
- if (ShouldNotVal)
- V = Builder.CreateXor(V, ValC);
-
- // Apply an offset if needed.
- if (!Offset.isNullValue())
- V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset));
- return V;
-}
-
/// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))).
/// This is even legal for FP.
static Instruction *foldAddSubSelect(SelectInst &SI,
@@ -1151,12 +1217,100 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel,
return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType());
}
+/// Try to eliminate select instructions that test the returned flag of cmpxchg
+/// instructions.
+///
+/// If a select instruction tests the returned flag of a cmpxchg instruction and
+/// selects between the returned value of the cmpxchg instruction its compare
+/// operand, the result of the select will always be equal to its false value.
+/// For example:
+///
+/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+/// %1 = extractvalue { i64, i1 } %0, 1
+/// %2 = extractvalue { i64, i1 } %0, 0
+/// %3 = select i1 %1, i64 %compare, i64 %2
+/// ret i64 %3
+///
+/// The returned value of the cmpxchg instruction (%2) is the original value
+/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2
+/// must have been equal to %compare. Thus, the result of the select is always
+/// equal to %2, and the code can be simplified to:
+///
+/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+/// %1 = extractvalue { i64, i1 } %0, 0
+/// ret i64 %1
+///
+static Instruction *foldSelectCmpXchg(SelectInst &SI) {
+ // A helper that determines if V is an extractvalue instruction whose
+ // aggregate operand is a cmpxchg instruction and whose single index is equal
+ // to I. If such conditions are true, the helper returns the cmpxchg
+ // instruction; otherwise, a nullptr is returned.
+ auto isExtractFromCmpXchg = [](Value *V, unsigned I) -> AtomicCmpXchgInst * {
+ auto *Extract = dyn_cast<ExtractValueInst>(V);
+ if (!Extract)
+ return nullptr;
+ if (Extract->getIndices()[0] != I)
+ return nullptr;
+ return dyn_cast<AtomicCmpXchgInst>(Extract->getAggregateOperand());
+ };
+
+ // If the select has a single user, and this user is a select instruction that
+ // we can simplify, skip the cmpxchg simplification for now.
+ if (SI.hasOneUse())
+ if (auto *Select = dyn_cast<SelectInst>(SI.user_back()))
+ if (Select->getCondition() == SI.getCondition())
+ if (Select->getFalseValue() == SI.getTrueValue() ||
+ Select->getTrueValue() == SI.getFalseValue())
+ return nullptr;
+
+ // Ensure the select condition is the returned flag of a cmpxchg instruction.
+ auto *CmpXchg = isExtractFromCmpXchg(SI.getCondition(), 1);
+ if (!CmpXchg)
+ return nullptr;
+
+ // Check the true value case: The true value of the select is the returned
+ // value of the same cmpxchg used by the condition, and the false value is the
+ // cmpxchg instruction's compare operand.
+ if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0))
+ if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) {
+ SI.setTrueValue(SI.getFalseValue());
+ return &SI;
+ }
+
+ // Check the false value case: The false value of the select is the returned
+ // value of the same cmpxchg used by the condition, and the true value is the
+ // cmpxchg instruction's compare operand.
+ if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0))
+ if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) {
+ SI.setTrueValue(SI.getFalseValue());
+ return &SI;
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
Type *SelType = SI.getType();
+ // FIXME: Remove this workaround when freeze related patches are done.
+ // For select with undef operand which feeds into an equality comparison,
+ // don't simplify it so loop unswitch can know the equality comparison
+ // may have an undef operand. This is a workaround for PR31652 caused by
+ // descrepancy about branch on undef between LoopUnswitch and GVN.
+ if (isa<UndefValue>(TrueVal) || isa<UndefValue>(FalseVal)) {
+ if (llvm::any_of(SI.users(), [&](User *U) {
+ ICmpInst *CI = dyn_cast<ICmpInst>(U);
+ if (CI && CI->isEquality())
+ return true;
+ return false;
+ })) {
+ return nullptr;
+ }
+ }
+
if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal,
SQ.getWithInstruction(&SI)))
return replaceInstUsesWith(SI, V);
@@ -1246,12 +1400,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
}
}
- if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal))
- if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal))
- if (Value *V = foldSelectICmpAnd(SI, TrueValC->getValue(),
- FalseValC->getValue(), Builder))
- return replaceInstUsesWith(SI, V);
-
// See if we are selecting two values based on a comparison of the two values.
if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) {
if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) {
@@ -1373,9 +1521,17 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
auto SPF = SPR.Flavor;
if (SelectPatternResult::isMinOrMax(SPF)) {
- // Canonicalize so that type casts are outside select patterns.
- if (LHS->getType()->getPrimitiveSizeInBits() !=
- SelType->getPrimitiveSizeInBits()) {
+ // Canonicalize so that
+ // - type casts are outside select patterns.
+ // - float clamp is transformed to min/max pattern
+
+ bool IsCastNeeded = LHS->getType() != SelType;
+ Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
+ Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
+ if (IsCastNeeded ||
+ (LHS->getType()->isFPOrFPVectorTy() &&
+ ((CmpLHS != LHS && CmpLHS != RHS) ||
+ (CmpRHS != LHS && CmpRHS != RHS)))) {
CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered);
Value *Cmp;
@@ -1388,10 +1544,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
Cmp = Builder.CreateFCmp(Pred, LHS, RHS);
}
- Value *NewSI = Builder.CreateCast(
- CastOp, Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI),
- SelType);
- return replaceInstUsesWith(SI, NewSI);
+ Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI);
+ if (!IsCastNeeded)
+ return replaceInstUsesWith(SI, NewSI);
+
+ Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType);
+ return replaceInstUsesWith(SI, NewCast);
}
}
@@ -1485,6 +1643,46 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
}
}
+ // Try to simplify a binop sandwiched between 2 selects with the same
+ // condition.
+ // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z)
+ BinaryOperator *TrueBO;
+ if (match(TrueVal, m_OneUse(m_BinOp(TrueBO)))) {
+ if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) {
+ if (TrueBOSI->getCondition() == CondVal) {
+ TrueBO->setOperand(0, TrueBOSI->getTrueValue());
+ Worklist.Add(TrueBO);
+ return &SI;
+ }
+ }
+ if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) {
+ if (TrueBOSI->getCondition() == CondVal) {
+ TrueBO->setOperand(1, TrueBOSI->getTrueValue());
+ Worklist.Add(TrueBO);
+ return &SI;
+ }
+ }
+ }
+
+ // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W))
+ BinaryOperator *FalseBO;
+ if (match(FalseVal, m_OneUse(m_BinOp(FalseBO)))) {
+ if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) {
+ if (FalseBOSI->getCondition() == CondVal) {
+ FalseBO->setOperand(0, FalseBOSI->getFalseValue());
+ Worklist.Add(FalseBO);
+ return &SI;
+ }
+ }
+ if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) {
+ if (FalseBOSI->getCondition() == CondVal) {
+ FalseBO->setOperand(1, FalseBOSI->getFalseValue());
+ Worklist.Add(FalseBO);
+ return &SI;
+ }
+ }
+ }
+
if (BinaryOperator::isNot(CondVal)) {
SI.setOperand(0, BinaryOperator::getNotArgument(CondVal));
SI.setOperand(1, FalseVal);
@@ -1501,10 +1699,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
return replaceInstUsesWith(SI, V);
return &SI;
}
-
- if (isa<ConstantAggregateZero>(CondVal)) {
- return replaceInstUsesWith(SI, FalseVal);
- }
}
// See if we can determine the result of this select based on a dominating
@@ -1515,9 +1709,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (PBI && PBI->isConditional() &&
PBI->getSuccessor(0) != PBI->getSuccessor(1) &&
(PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) {
- bool CondIsFalse = PBI->getSuccessor(1) == Parent;
+ bool CondIsTrue = PBI->getSuccessor(0) == Parent;
Optional<bool> Implication = isImpliedCondition(
- PBI->getCondition(), SI.getCondition(), DL, CondIsFalse);
+ PBI->getCondition(), SI.getCondition(), DL, CondIsTrue);
if (Implication) {
Value *V = *Implication ? TrueVal : FalseVal;
return replaceInstUsesWith(SI, V);
@@ -1542,5 +1736,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, Builder))
return BitCastSel;
+ // Simplify selects that test the returned flag of cmpxchg instructions.
+ if (Instruction *Select = foldSelectCmpXchg(SI))
+ return Select;
+
return nullptr;
}