diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 63 |
1 files changed, 51 insertions, 12 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index b9674d85634dc..33951e66497a1 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -303,7 +303,7 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, /// We want to turn: /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) /// into: -/// (or (shl (and X, C1), C3), y) +/// (or (shl (and X, C1), C3), Y) /// iff: /// C1 and C2 are both powers of 2 /// where: @@ -317,19 +317,44 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, Value *FalseVal, InstCombiner::BuilderTy *Builder) { const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); - if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy()) + if (!IC || !SI.getType()->isIntegerTy()) return nullptr; Value *CmpLHS = IC->getOperand(0); Value *CmpRHS = IC->getOperand(1); - if (!match(CmpRHS, m_Zero())) - return nullptr; + Value *V; + unsigned C1Log; + bool IsEqualZero; + bool NeedAnd = false; + if (IC->isEquality()) { + if (!match(CmpRHS, m_Zero())) + return nullptr; + + const APInt *C1; + if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1)))) + return nullptr; + + V = CmpLHS; + C1Log = C1->logBase2(); + IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ; + } else if (IC->getPredicate() == ICmpInst::ICMP_SLT || + IC->getPredicate() == ICmpInst::ICMP_SGT) { + // We also need to recognize (icmp slt (trunc (X)), 0) and + // (icmp sgt (trunc (X)), -1). + IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT; + if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) || + (!IsEqualZero && !match(CmpRHS, m_Zero()))) + return nullptr; + + if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V))))) + return nullptr; - Value *X; - const APInt *C1; - if (!match(CmpLHS, m_And(m_Value(X), m_Power2(C1)))) + C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1; + NeedAnd = true; + } else { return nullptr; + } const APInt *C2; bool OrOnTrueVal = false; @@ -340,11 +365,27 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, if (!OrOnFalseVal && !OrOnTrueVal) return nullptr; - Value *V = CmpLHS; Value *Y = OrOnFalseVal ? TrueVal : FalseVal; - unsigned C1Log = C1->logBase2(); unsigned C2Log = C2->logBase2(); + + bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal); + bool NeedShift = C1Log != C2Log; + bool NeedZExtTrunc = Y->getType()->getIntegerBitWidth() != + V->getType()->getIntegerBitWidth(); + + // Make sure we don't create more instructions than we save. + Value *Or = OrOnFalseVal ? FalseVal : TrueVal; + if ((NeedShift + NeedXor + NeedZExtTrunc) > + (IC->hasOneUse() + Or->hasOneUse())) + return nullptr; + + if (NeedAnd) { + // Insert the AND instruction on the input to the truncate. + APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log); + V = Builder->CreateAnd(V, ConstantInt::get(V->getType(), C1)); + } + if (C2Log > C1Log) { V = Builder->CreateZExtOrTrunc(V, Y->getType()); V = Builder->CreateShl(V, C2Log - C1Log); @@ -354,9 +395,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, } else V = Builder->CreateZExtOrTrunc(V, Y->getType()); - ICmpInst::Predicate Pred = IC->getPredicate(); - if ((Pred == ICmpInst::ICMP_NE && OrOnFalseVal) || - (Pred == ICmpInst::ICMP_EQ && OrOnTrueVal)) + if (NeedXor) V = Builder->CreateXor(V, *C2); return Builder->CreateOr(V, Y); |