diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 506 |
1 files changed, 352 insertions, 154 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 4eebe8255998..6f26f7f5cd19 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -12,12 +12,36 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include <cassert> +#include <utility> + using namespace llvm; using namespace PatternMatch; @@ -69,6 +93,111 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } +/// If one of the constants is zero (we know they can't both be) and we have an +/// icmp instruction with zero, and we have an 'and' with the non-constant value +/// and a power of two we can turn the select into a shift on the result of the +/// 'and'. +/// This folds: +/// select (icmp eq (and X, C1)), C2, C3 +/// iff C1 is a power 2 and the difference between C2 and C3 is a power of 2. +/// To something like: +/// (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3 +/// Or: +/// (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3 +/// With some variations depending if C3 is larger than C2, or the shift +/// isn't needed, or the bit widths don't match. +static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, + APInt TrueVal, APInt FalseVal, + InstCombiner::BuilderTy &Builder) { + assert(SelType->isIntOrIntVectorTy() && "Not an integer select?"); + + // If this is a vector select, we need a vector compare. + if (SelType->isVectorTy() != IC->getType()->isVectorTy()) + return nullptr; + + Value *V; + APInt AndMask; + bool CreateAnd = false; + ICmpInst::Predicate Pred = IC->getPredicate(); + if (ICmpInst::isEquality(Pred)) { + if (!match(IC->getOperand(1), m_Zero())) + return nullptr; + + V = IC->getOperand(0); + + const APInt *AndRHS; + if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) + return nullptr; + + AndMask = *AndRHS; + } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1), + Pred, V, AndMask)) { + assert(ICmpInst::isEquality(Pred) && "Not equality test?"); + + if (!AndMask.isPowerOf2()) + return nullptr; + + CreateAnd = true; + } else { + return nullptr; + } + + // If both select arms are non-zero see if we have a select of the form + // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic + // for 'x ? 2^n : 0' and fix the thing up at the end. + APInt Offset(TrueVal.getBitWidth(), 0); + if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { + if ((TrueVal - FalseVal).isPowerOf2()) + Offset = FalseVal; + else if ((FalseVal - TrueVal).isPowerOf2()) + Offset = TrueVal; + else + return nullptr; + + // Adjust TrueVal and FalseVal to the offset. + TrueVal -= Offset; + FalseVal -= Offset; + } + + // Make sure one of the select arms is a power of 2. + if (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2()) + return nullptr; + + // Determine which shift is needed to transform result of the 'and' into the + // desired result. + const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; + unsigned ValZeros = ValC.logBase2(); + unsigned AndZeros = AndMask.logBase2(); + + if (CreateAnd) { + // Insert the AND instruction on the input to the truncate. + V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); + } + + // If types don't match we can still convert the select by introducing a zext + // or a trunc of the 'and'. + if (ValZeros > AndZeros) { + V = Builder.CreateZExtOrTrunc(V, SelType); + V = Builder.CreateShl(V, ValZeros - AndZeros); + } else if (ValZeros < AndZeros) { + V = Builder.CreateLShr(V, AndZeros - ValZeros); + V = Builder.CreateZExtOrTrunc(V, SelType); + } else + V = Builder.CreateZExtOrTrunc(V, SelType); + + // Okay, now we know that everything is set up, we just don't know whether we + // have a icmp_ne or icmp_eq and whether the true or false val is the zero. + bool ShouldNotVal = !TrueVal.isNullValue(); + ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; + if (ShouldNotVal) + V = Builder.CreateXor(V, ValC); + + // Apply an offset if needed. + if (!Offset.isNullValue()) + V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); + return V; +} + /// We want to turn code that looks like this: /// %C = or %A, %B /// %D = select %cond, %C, %A @@ -79,8 +208,7 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, /// Assuming that the specified instruction is an operand to the select, return /// a bitmask indicating which operands of this instruction are foldable if they /// equal the other incoming value of the select. -/// -static unsigned getSelectFoldableOperands(Instruction *I) { +static unsigned getSelectFoldableOperands(BinaryOperator *I) { switch (I->getOpcode()) { case Instruction::Add: case Instruction::Mul: @@ -100,7 +228,7 @@ static unsigned getSelectFoldableOperands(Instruction *I) { /// For the same transformation as the previous function, return the identity /// constant that goes into the select. -static Constant *getSelectFoldableConstant(Instruction *I) { +static APInt getSelectFoldableConstant(BinaryOperator *I) { switch (I->getOpcode()) { default: llvm_unreachable("This cannot happen!"); case Instruction::Add: @@ -110,11 +238,11 @@ static Constant *getSelectFoldableConstant(Instruction *I) { case Instruction::Shl: case Instruction::LShr: case Instruction::AShr: - return Constant::getNullValue(I->getType()); + return APInt::getNullValue(I->getType()->getScalarSizeInBits()); case Instruction::And: - return Constant::getAllOnesValue(I->getType()); + return APInt::getAllOnesValue(I->getType()->getScalarSizeInBits()); case Instruction::Mul: - return ConstantInt::get(I->getType(), 1); + return APInt(I->getType()->getScalarSizeInBits(), 1); } } @@ -157,7 +285,6 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, if (TI->getOpcode() != Instruction::BitCast && (!TI->hasOneUse() || !FI->hasOneUse())) return nullptr; - } else if (!TI->hasOneUse() || !FI->hasOneUse()) { // TODO: The one-use restrictions for a scalar select could be eased if // the fold of a select in visitLoadInst() was enhanced to match a pattern @@ -218,17 +345,11 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); } -static bool isSelect01(Constant *C1, Constant *C2) { - ConstantInt *C1I = dyn_cast<ConstantInt>(C1); - if (!C1I) - return false; - ConstantInt *C2I = dyn_cast<ConstantInt>(C2); - if (!C2I) - return false; - if (!C1I->isZero() && !C2I->isZero()) // One side must be zero. +static bool isSelect01(const APInt &C1I, const APInt &C2I) { + if (!C1I.isNullValue() && !C2I.isNullValue()) // One side must be zero. return false; - return C1I->isOne() || C1I->isMinusOne() || - C2I->isOne() || C2I->isMinusOne(); + return C1I.isOneValue() || C1I.isAllOnesValue() || + C2I.isOneValue() || C2I.isAllOnesValue(); } /// Try to fold the select into one of the operands to allow further @@ -237,9 +358,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, Value *FalseVal) { // See the comment above GetSelectFoldableOperands for a description of the // transformation we are doing here. - if (Instruction *TVI = dyn_cast<Instruction>(TrueVal)) { - if (TVI->hasOneUse() && TVI->getNumOperands() == 2 && - !isa<Constant>(FalseVal)) { + if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { + if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) { if (unsigned SFO = getSelectFoldableOperands(TVI)) { unsigned OpToFold = 0; if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { @@ -249,17 +369,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = getSelectFoldableConstant(TVI); + APInt CI = getSelectFoldableConstant(TVI); Value *OOp = TVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. - if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { + Value *C = ConstantInt::get(OOp->getType(), CI); Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C); NewSel->takeName(TVI); - BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI); - BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(), + BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); - BO->copyIRFlags(TVI_BO); + BO->copyIRFlags(TVI); return BO; } } @@ -267,9 +389,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } } - if (Instruction *FVI = dyn_cast<Instruction>(FalseVal)) { - if (FVI->hasOneUse() && FVI->getNumOperands() == 2 && - !isa<Constant>(TrueVal)) { + if (auto *FVI = dyn_cast<BinaryOperator>(FalseVal)) { + if (FVI->hasOneUse() && !isa<Constant>(TrueVal)) { if (unsigned SFO = getSelectFoldableOperands(FVI)) { unsigned OpToFold = 0; if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { @@ -279,17 +400,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = getSelectFoldableConstant(FVI); + APInt CI = getSelectFoldableConstant(FVI); Value *OOp = FVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. - if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { + Value *C = ConstantInt::get(OOp->getType(), CI); Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp); NewSel->takeName(FVI); - BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI); - BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(), + BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(), TrueVal, NewSel); - BO->copyIRFlags(FVI_BO); + BO->copyIRFlags(FVI); return BO; } } @@ -313,11 +436,13 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, /// 1. The icmp predicate is inverted /// 2. The select operands are reversed /// 3. The magnitude of C2 and C1 are flipped -static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, +static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, Value *FalseVal, InstCombiner::BuilderTy &Builder) { - const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); - if (!IC || !SI.getType()->isIntegerTy()) + // Only handle integer compares. Also, if this is a vector select, we need a + // vector compare. + if (!TrueVal->getType()->isIntOrIntVectorTy() || + TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) return nullptr; Value *CmpLHS = IC->getOperand(0); @@ -371,8 +496,8 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal); bool NeedShift = C1Log != C2Log; - bool NeedZExtTrunc = Y->getType()->getIntegerBitWidth() != - V->getType()->getIntegerBitWidth(); + bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() != + V->getType()->getScalarSizeInBits(); // Make sure we don't create more instructions than we save. Value *Or = OrOnFalseVal ? FalseVal : TrueVal; @@ -447,8 +572,7 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, IntrinsicInst *II = cast<IntrinsicInst>(Count); // Explicitly clear the 'undef_on_zero' flag. IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); - Type *Ty = NewI->getArgOperand(1)->getType(); - NewI->setArgOperand(1, Constant::getNullValue(Ty)); + NewI->setArgOperand(1, ConstantInt::getFalse(NewI->getContext())); Builder.Insert(NewI); return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); } @@ -597,6 +721,9 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; @@ -605,40 +732,52 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - Value *TrueVal = SI.getTrueValue(); - Value *FalseVal = SI.getFalseValue(); // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 // and (X <s 0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1 // FIXME: Type and constness constraints could be lifted, but we have to // watch code size carefully. We should consider xor instead of // sub/add when we decide to do that. - if (IntegerType *Ty = dyn_cast<IntegerType>(CmpLHS->getType())) { - if (TrueVal->getType() == Ty) { - if (ConstantInt *Cmp = dyn_cast<ConstantInt>(CmpRHS)) { - ConstantInt *C1 = nullptr, *C2 = nullptr; - if (Pred == ICmpInst::ICMP_SGT && Cmp->isMinusOne()) { - C1 = dyn_cast<ConstantInt>(TrueVal); - C2 = dyn_cast<ConstantInt>(FalseVal); - } else if (Pred == ICmpInst::ICMP_SLT && Cmp->isZero()) { - C1 = dyn_cast<ConstantInt>(FalseVal); - C2 = dyn_cast<ConstantInt>(TrueVal); - } - if (C1 && C2) { + // TODO: Merge this with foldSelectICmpAnd somehow. + if (CmpLHS->getType()->isIntOrIntVectorTy() && + CmpLHS->getType() == TrueVal->getType()) { + const APInt *C1, *C2; + if (match(TrueVal, m_APInt(C1)) && match(FalseVal, m_APInt(C2))) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *X; + APInt Mask; + if (decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask, false)) { + if (Mask.isSignMask()) { + assert(X == CmpLHS && "Expected to use the compare input directly"); + assert(ICmpInst::isEquality(Pred) && "Expected equality predicate"); + + if (Pred == ICmpInst::ICMP_NE) + std::swap(C1, C2); + // This shift results in either -1 or 0. - Value *AShr = Builder.CreateAShr(CmpLHS, Ty->getBitWidth() - 1); + Value *AShr = Builder.CreateAShr(X, Mask.getBitWidth() - 1); // Check if we can express the operation with a single or. - if (C2->isMinusOne()) - return replaceInstUsesWith(SI, Builder.CreateOr(AShr, C1)); + if (C2->isAllOnesValue()) + return replaceInstUsesWith(SI, Builder.CreateOr(AShr, *C1)); - Value *And = Builder.CreateAnd(AShr, C2->getValue() - C1->getValue()); - return replaceInstUsesWith(SI, Builder.CreateAdd(And, C1)); + Value *And = Builder.CreateAnd(AShr, *C2 - *C1); + return replaceInstUsesWith(SI, Builder.CreateAdd(And, + ConstantInt::get(And->getType(), *C1))); } } } } + { + const APInt *TrueValC, *FalseValC; + if (match(TrueVal, m_APInt(TrueValC)) && + match(FalseVal, m_APInt(FalseValC))) + if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, + *FalseValC, Builder)) + return replaceInstUsesWith(SI, V); + } + // NOTE: if we wanted to, this is where to detect integer MIN/MAX if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { @@ -703,7 +842,7 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, } } - if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) @@ -722,7 +861,6 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, /// Z = select X, Y, 0 /// /// because Y is not live in BB1/BB2. -/// static bool canSelectOperandBeMappingIntoPredBlock(const Value *V, const SelectInst &SI) { // If the value is a non-instruction value like a constant or argument, it @@ -864,78 +1002,6 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, return nullptr; } -/// If one of the constants is zero (we know they can't both be) and we have an -/// icmp instruction with zero, and we have an 'and' with the non-constant value -/// and a power of two we can turn the select into a shift on the result of the -/// 'and'. -static Value *foldSelectICmpAnd(const SelectInst &SI, APInt TrueVal, - APInt FalseVal, - InstCombiner::BuilderTy &Builder) { - const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); - if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy()) - return nullptr; - - if (!match(IC->getOperand(1), m_Zero())) - return nullptr; - - ConstantInt *AndRHS; - Value *LHS = IC->getOperand(0); - if (!match(LHS, m_And(m_Value(), m_ConstantInt(AndRHS)))) - return nullptr; - - // If both select arms are non-zero see if we have a select of the form - // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic - // for 'x ? 2^n : 0' and fix the thing up at the end. - APInt Offset(TrueVal.getBitWidth(), 0); - if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { - if ((TrueVal - FalseVal).isPowerOf2()) - Offset = FalseVal; - else if ((FalseVal - TrueVal).isPowerOf2()) - Offset = TrueVal; - else - return nullptr; - - // Adjust TrueVal and FalseVal to the offset. - TrueVal -= Offset; - FalseVal -= Offset; - } - - // Make sure the mask in the 'and' and one of the select arms is a power of 2. - if (!AndRHS->getValue().isPowerOf2() || - (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2())) - return nullptr; - - // Determine which shift is needed to transform result of the 'and' into the - // desired result. - const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; - unsigned ValZeros = ValC.logBase2(); - unsigned AndZeros = AndRHS->getValue().logBase2(); - - // If types don't match we can still convert the select by introducing a zext - // or a trunc of the 'and'. The trunc case requires that all of the truncated - // bits are zero, we can figure that out by looking at the 'and' mask. - if (AndZeros >= ValC.getBitWidth()) - return nullptr; - - Value *V = Builder.CreateZExtOrTrunc(LHS, SI.getType()); - if (ValZeros > AndZeros) - V = Builder.CreateShl(V, ValZeros - AndZeros); - else if (ValZeros < AndZeros) - V = Builder.CreateLShr(V, AndZeros - ValZeros); - - // Okay, now we know that everything is set up, we just don't know whether we - // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TrueVal.isNullValue(); - ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE; - if (ShouldNotVal) - V = Builder.CreateXor(V, ValC); - - // Apply an offset if needed. - if (!Offset.isNullValue()) - V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); - return V; -} - /// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))). /// This is even legal for FP. static Instruction *foldAddSubSelect(SelectInst &SI, @@ -1151,12 +1217,100 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType()); } +/// Try to eliminate select instructions that test the returned flag of cmpxchg +/// instructions. +/// +/// If a select instruction tests the returned flag of a cmpxchg instruction and +/// selects between the returned value of the cmpxchg instruction its compare +/// operand, the result of the select will always be equal to its false value. +/// For example: +/// +/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %1 = extractvalue { i64, i1 } %0, 1 +/// %2 = extractvalue { i64, i1 } %0, 0 +/// %3 = select i1 %1, i64 %compare, i64 %2 +/// ret i64 %3 +/// +/// The returned value of the cmpxchg instruction (%2) is the original value +/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2 +/// must have been equal to %compare. Thus, the result of the select is always +/// equal to %2, and the code can be simplified to: +/// +/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %1 = extractvalue { i64, i1 } %0, 0 +/// ret i64 %1 +/// +static Instruction *foldSelectCmpXchg(SelectInst &SI) { + // A helper that determines if V is an extractvalue instruction whose + // aggregate operand is a cmpxchg instruction and whose single index is equal + // to I. If such conditions are true, the helper returns the cmpxchg + // instruction; otherwise, a nullptr is returned. + auto isExtractFromCmpXchg = [](Value *V, unsigned I) -> AtomicCmpXchgInst * { + auto *Extract = dyn_cast<ExtractValueInst>(V); + if (!Extract) + return nullptr; + if (Extract->getIndices()[0] != I) + return nullptr; + return dyn_cast<AtomicCmpXchgInst>(Extract->getAggregateOperand()); + }; + + // If the select has a single user, and this user is a select instruction that + // we can simplify, skip the cmpxchg simplification for now. + if (SI.hasOneUse()) + if (auto *Select = dyn_cast<SelectInst>(SI.user_back())) + if (Select->getCondition() == SI.getCondition()) + if (Select->getFalseValue() == SI.getTrueValue() || + Select->getTrueValue() == SI.getFalseValue()) + return nullptr; + + // Ensure the select condition is the returned flag of a cmpxchg instruction. + auto *CmpXchg = isExtractFromCmpXchg(SI.getCondition(), 1); + if (!CmpXchg) + return nullptr; + + // Check the true value case: The true value of the select is the returned + // value of the same cmpxchg used by the condition, and the false value is the + // cmpxchg instruction's compare operand. + if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0)) + if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) { + SI.setTrueValue(SI.getFalseValue()); + return &SI; + } + + // Check the false value case: The false value of the select is the returned + // value of the same cmpxchg used by the condition, and the true value is the + // cmpxchg instruction's compare operand. + if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0)) + if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) { + SI.setTrueValue(SI.getFalseValue()); + return &SI; + } + + return nullptr; +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); Type *SelType = SI.getType(); + // FIXME: Remove this workaround when freeze related patches are done. + // For select with undef operand which feeds into an equality comparison, + // don't simplify it so loop unswitch can know the equality comparison + // may have an undef operand. This is a workaround for PR31652 caused by + // descrepancy about branch on undef between LoopUnswitch and GVN. + if (isa<UndefValue>(TrueVal) || isa<UndefValue>(FalseVal)) { + if (llvm::any_of(SI.users(), [&](User *U) { + ICmpInst *CI = dyn_cast<ICmpInst>(U); + if (CI && CI->isEquality()) + return true; + return false; + })) { + return nullptr; + } + } + if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, SQ.getWithInstruction(&SI))) return replaceInstUsesWith(SI, V); @@ -1246,12 +1400,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } - if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) - if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) - if (Value *V = foldSelectICmpAnd(SI, TrueValC->getValue(), - FalseValC->getValue(), Builder)) - return replaceInstUsesWith(SI, V); - // See if we are selecting two values based on a comparison of the two values. if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { @@ -1373,9 +1521,17 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { auto SPF = SPR.Flavor; if (SelectPatternResult::isMinOrMax(SPF)) { - // Canonicalize so that type casts are outside select patterns. - if (LHS->getType()->getPrimitiveSizeInBits() != - SelType->getPrimitiveSizeInBits()) { + // Canonicalize so that + // - type casts are outside select patterns. + // - float clamp is transformed to min/max pattern + + bool IsCastNeeded = LHS->getType() != SelType; + Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0); + Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1); + if (IsCastNeeded || + (LHS->getType()->isFPOrFPVectorTy() && + ((CmpLHS != LHS && CmpLHS != RHS) || + (CmpRHS != LHS && CmpRHS != RHS)))) { CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); Value *Cmp; @@ -1388,10 +1544,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Cmp = Builder.CreateFCmp(Pred, LHS, RHS); } - Value *NewSI = Builder.CreateCast( - CastOp, Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI), - SelType); - return replaceInstUsesWith(SI, NewSI); + Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); + if (!IsCastNeeded) + return replaceInstUsesWith(SI, NewSI); + + Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); + return replaceInstUsesWith(SI, NewCast); } } @@ -1485,6 +1643,46 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + // Try to simplify a binop sandwiched between 2 selects with the same + // condition. + // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) + BinaryOperator *TrueBO; + if (match(TrueVal, m_OneUse(m_BinOp(TrueBO)))) { + if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { + if (TrueBOSI->getCondition() == CondVal) { + TrueBO->setOperand(0, TrueBOSI->getTrueValue()); + Worklist.Add(TrueBO); + return &SI; + } + } + if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) { + if (TrueBOSI->getCondition() == CondVal) { + TrueBO->setOperand(1, TrueBOSI->getTrueValue()); + Worklist.Add(TrueBO); + return &SI; + } + } + } + + // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) + BinaryOperator *FalseBO; + if (match(FalseVal, m_OneUse(m_BinOp(FalseBO)))) { + if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { + if (FalseBOSI->getCondition() == CondVal) { + FalseBO->setOperand(0, FalseBOSI->getFalseValue()); + Worklist.Add(FalseBO); + return &SI; + } + } + if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) { + if (FalseBOSI->getCondition() == CondVal) { + FalseBO->setOperand(1, FalseBOSI->getFalseValue()); + Worklist.Add(FalseBO); + return &SI; + } + } + } + if (BinaryOperator::isNot(CondVal)) { SI.setOperand(0, BinaryOperator::getNotArgument(CondVal)); SI.setOperand(1, FalseVal); @@ -1501,10 +1699,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return replaceInstUsesWith(SI, V); return &SI; } - - if (isa<ConstantAggregateZero>(CondVal)) { - return replaceInstUsesWith(SI, FalseVal); - } } // See if we can determine the result of this select based on a dominating @@ -1515,9 +1709,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (PBI && PBI->isConditional() && PBI->getSuccessor(0) != PBI->getSuccessor(1) && (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) { - bool CondIsFalse = PBI->getSuccessor(1) == Parent; + bool CondIsTrue = PBI->getSuccessor(0) == Parent; Optional<bool> Implication = isImpliedCondition( - PBI->getCondition(), SI.getCondition(), DL, CondIsFalse); + PBI->getCondition(), SI.getCondition(), DL, CondIsTrue); if (Implication) { Value *V = *Implication ? TrueVal : FalseVal; return replaceInstUsesWith(SI, V); @@ -1542,5 +1736,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, Builder)) return BitCastSel; + // Simplify selects that test the returned flag of cmpxchg instructions. + if (Instruction *Select = foldSelectCmpXchg(SI)) + return Select; + return nullptr; } |