diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2017-12-18 20:10:56 +0000 | 
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2017-12-18 20:10:56 +0000 | 
| commit | 044eb2f6afba375a914ac9d8024f8f5142bb912e (patch) | |
| tree | 1475247dc9f9fe5be155ebd4c9069c75aadf8c20 /lib/Transforms/InstCombine/InstCombineSelect.cpp | |
| parent | eb70dddbd77e120e5d490bd8fbe7ff3f8fa81c6b (diff) | |
Notes
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineSelect.cpp')
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 506 | 
1 files changed, 352 insertions, 154 deletions
| diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 4eebe8255998..6f26f7f5cd19 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -12,12 +12,36 @@  //===----------------------------------------------------------------------===//  #include "InstCombineInternal.h" -#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CmpInstAnalysis.h"  #include "llvm/Analysis/InstructionSimplify.h"  #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h"  #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h"  #include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include <cassert> +#include <utility> +  using namespace llvm;  using namespace PatternMatch; @@ -69,6 +93,111 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder,    return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B);  } +/// If one of the constants is zero (we know they can't both be) and we have an +/// icmp instruction with zero, and we have an 'and' with the non-constant value +/// and a power of two we can turn the select into a shift on the result of the +/// 'and'. +/// This folds: +///  select (icmp eq (and X, C1)), C2, C3 +///    iff C1 is a power 2 and the difference between C2 and C3 is a power of 2. +/// To something like: +///  (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3 +/// Or: +///  (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3 +/// With some variations depending if C3 is larger than C2, or the shift +/// isn't needed, or the bit widths don't match. +static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, +                                APInt TrueVal, APInt FalseVal, +                                InstCombiner::BuilderTy &Builder) { +  assert(SelType->isIntOrIntVectorTy() && "Not an integer select?"); + +  // If this is a vector select, we need a vector compare. +  if (SelType->isVectorTy() != IC->getType()->isVectorTy()) +    return nullptr; + +  Value *V; +  APInt AndMask; +  bool CreateAnd = false; +  ICmpInst::Predicate Pred = IC->getPredicate(); +  if (ICmpInst::isEquality(Pred)) { +    if (!match(IC->getOperand(1), m_Zero())) +      return nullptr; + +    V = IC->getOperand(0); + +    const APInt *AndRHS; +    if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) +      return nullptr; + +    AndMask = *AndRHS; +  } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1), +                                  Pred, V, AndMask)) { +    assert(ICmpInst::isEquality(Pred) && "Not equality test?"); + +    if (!AndMask.isPowerOf2()) +      return nullptr; + +    CreateAnd = true; +  } else { +    return nullptr; +  } + +  // If both select arms are non-zero see if we have a select of the form +  // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic +  // for 'x ? 2^n : 0' and fix the thing up at the end. +  APInt Offset(TrueVal.getBitWidth(), 0); +  if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { +    if ((TrueVal - FalseVal).isPowerOf2()) +      Offset = FalseVal; +    else if ((FalseVal - TrueVal).isPowerOf2()) +      Offset = TrueVal; +    else +      return nullptr; + +    // Adjust TrueVal and FalseVal to the offset. +    TrueVal -= Offset; +    FalseVal -= Offset; +  } + +  // Make sure one of the select arms is a power of 2. +  if (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2()) +    return nullptr; + +  // Determine which shift is needed to transform result of the 'and' into the +  // desired result. +  const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; +  unsigned ValZeros = ValC.logBase2(); +  unsigned AndZeros = AndMask.logBase2(); + +  if (CreateAnd) { +    // Insert the AND instruction on the input to the truncate. +    V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); +  } + +  // If types don't match we can still convert the select by introducing a zext +  // or a trunc of the 'and'. +  if (ValZeros > AndZeros) { +    V = Builder.CreateZExtOrTrunc(V, SelType); +    V = Builder.CreateShl(V, ValZeros - AndZeros); +  } else if (ValZeros < AndZeros) { +    V = Builder.CreateLShr(V, AndZeros - ValZeros); +    V = Builder.CreateZExtOrTrunc(V, SelType); +  } else +    V = Builder.CreateZExtOrTrunc(V, SelType); + +  // Okay, now we know that everything is set up, we just don't know whether we +  // have a icmp_ne or icmp_eq and whether the true or false val is the zero. +  bool ShouldNotVal = !TrueVal.isNullValue(); +  ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; +  if (ShouldNotVal) +    V = Builder.CreateXor(V, ValC); + +  // Apply an offset if needed. +  if (!Offset.isNullValue()) +    V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); +  return V; +} +  /// We want to turn code that looks like this:  ///   %C = or %A, %B  ///   %D = select %cond, %C, %A @@ -79,8 +208,7 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder,  /// Assuming that the specified instruction is an operand to the select, return  /// a bitmask indicating which operands of this instruction are foldable if they  /// equal the other incoming value of the select. -/// -static unsigned getSelectFoldableOperands(Instruction *I) { +static unsigned getSelectFoldableOperands(BinaryOperator *I) {    switch (I->getOpcode()) {    case Instruction::Add:    case Instruction::Mul: @@ -100,7 +228,7 @@ static unsigned getSelectFoldableOperands(Instruction *I) {  /// For the same transformation as the previous function, return the identity  /// constant that goes into the select. -static Constant *getSelectFoldableConstant(Instruction *I) { +static APInt getSelectFoldableConstant(BinaryOperator *I) {    switch (I->getOpcode()) {    default: llvm_unreachable("This cannot happen!");    case Instruction::Add: @@ -110,11 +238,11 @@ static Constant *getSelectFoldableConstant(Instruction *I) {    case Instruction::Shl:    case Instruction::LShr:    case Instruction::AShr: -    return Constant::getNullValue(I->getType()); +    return APInt::getNullValue(I->getType()->getScalarSizeInBits());    case Instruction::And: -    return Constant::getAllOnesValue(I->getType()); +    return APInt::getAllOnesValue(I->getType()->getScalarSizeInBits());    case Instruction::Mul: -    return ConstantInt::get(I->getType(), 1); +    return APInt(I->getType()->getScalarSizeInBits(), 1);    }  } @@ -157,7 +285,6 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,        if (TI->getOpcode() != Instruction::BitCast &&            (!TI->hasOneUse() || !FI->hasOneUse()))          return nullptr; -      } else if (!TI->hasOneUse() || !FI->hasOneUse()) {        // TODO: The one-use restrictions for a scalar select could be eased if        // the fold of a select in visitLoadInst() was enhanced to match a pattern @@ -218,17 +345,11 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,    return BinaryOperator::Create(BO->getOpcode(), Op0, Op1);  } -static bool isSelect01(Constant *C1, Constant *C2) { -  ConstantInt *C1I = dyn_cast<ConstantInt>(C1); -  if (!C1I) -    return false; -  ConstantInt *C2I = dyn_cast<ConstantInt>(C2); -  if (!C2I) -    return false; -  if (!C1I->isZero() && !C2I->isZero()) // One side must be zero. +static bool isSelect01(const APInt &C1I, const APInt &C2I) { +  if (!C1I.isNullValue() && !C2I.isNullValue()) // One side must be zero.      return false; -  return C1I->isOne() || C1I->isMinusOne() || -         C2I->isOne() || C2I->isMinusOne(); +  return C1I.isOneValue() || C1I.isAllOnesValue() || +         C2I.isOneValue() || C2I.isAllOnesValue();  }  /// Try to fold the select into one of the operands to allow further @@ -237,9 +358,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,                                              Value *FalseVal) {    // See the comment above GetSelectFoldableOperands for a description of the    // transformation we are doing here. -  if (Instruction *TVI = dyn_cast<Instruction>(TrueVal)) { -    if (TVI->hasOneUse() && TVI->getNumOperands() == 2 && -        !isa<Constant>(FalseVal)) { +  if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { +    if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) {        if (unsigned SFO = getSelectFoldableOperands(TVI)) {          unsigned OpToFold = 0;          if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { @@ -249,17 +369,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,          }          if (OpToFold) { -          Constant *C = getSelectFoldableConstant(TVI); +          APInt CI = getSelectFoldableConstant(TVI);            Value *OOp = TVI->getOperand(2-OpToFold);            // Avoid creating select between 2 constants unless it's selecting            // between 0, 1 and -1. -          if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { +          const APInt *OOpC; +          bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); +          if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { +            Value *C = ConstantInt::get(OOp->getType(), CI);              Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C);              NewSel->takeName(TVI); -            BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI); -            BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(), +            BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(),                                                          FalseVal, NewSel); -            BO->copyIRFlags(TVI_BO); +            BO->copyIRFlags(TVI);              return BO;            }          } @@ -267,9 +389,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,      }    } -  if (Instruction *FVI = dyn_cast<Instruction>(FalseVal)) { -    if (FVI->hasOneUse() && FVI->getNumOperands() == 2 && -        !isa<Constant>(TrueVal)) { +  if (auto *FVI = dyn_cast<BinaryOperator>(FalseVal)) { +    if (FVI->hasOneUse() && !isa<Constant>(TrueVal)) {        if (unsigned SFO = getSelectFoldableOperands(FVI)) {          unsigned OpToFold = 0;          if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { @@ -279,17 +400,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,          }          if (OpToFold) { -          Constant *C = getSelectFoldableConstant(FVI); +          APInt CI = getSelectFoldableConstant(FVI);            Value *OOp = FVI->getOperand(2-OpToFold);            // Avoid creating select between 2 constants unless it's selecting            // between 0, 1 and -1. -          if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { +          const APInt *OOpC; +          bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); +          if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { +            Value *C = ConstantInt::get(OOp->getType(), CI);              Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp);              NewSel->takeName(FVI); -            BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI); -            BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(), +            BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(),                                                          TrueVal, NewSel); -            BO->copyIRFlags(FVI_BO); +            BO->copyIRFlags(FVI);              return BO;            }          } @@ -313,11 +436,13 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,  /// 1. The icmp predicate is inverted  /// 2. The select operands are reversed  /// 3. The magnitude of C2 and C1 are flipped -static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, +static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,                                    Value *FalseVal,                                    InstCombiner::BuilderTy &Builder) { -  const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); -  if (!IC || !SI.getType()->isIntegerTy()) +  // Only handle integer compares. Also, if this is a vector select, we need a +  // vector compare. +  if (!TrueVal->getType()->isIntOrIntVectorTy() || +      TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())      return nullptr;    Value *CmpLHS = IC->getOperand(0); @@ -371,8 +496,8 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,    bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal);    bool NeedShift = C1Log != C2Log; -  bool NeedZExtTrunc = Y->getType()->getIntegerBitWidth() != -                       V->getType()->getIntegerBitWidth(); +  bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() != +                       V->getType()->getScalarSizeInBits();    // Make sure we don't create more instructions than we save.    Value *Or = OrOnFalseVal ? FalseVal : TrueVal; @@ -447,8 +572,7 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,      IntrinsicInst *II = cast<IntrinsicInst>(Count);      // Explicitly clear the 'undef_on_zero' flag.      IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); -    Type *Ty = NewI->getArgOperand(1)->getType(); -    NewI->setArgOperand(1, Constant::getNullValue(Ty)); +    NewI->setArgOperand(1, ConstantInt::getFalse(NewI->getContext()));      Builder.Insert(NewI);      return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType());    } @@ -597,6 +721,9 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,  /// Visit a SelectInst that has an ICmpInst as its first operand.  Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,                                                    ICmpInst *ICI) { +  Value *TrueVal = SI.getTrueValue(); +  Value *FalseVal = SI.getFalseValue(); +    if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder))      return NewSel; @@ -605,40 +732,52 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,    ICmpInst::Predicate Pred = ICI->getPredicate();    Value *CmpLHS = ICI->getOperand(0);    Value *CmpRHS = ICI->getOperand(1); -  Value *TrueVal = SI.getTrueValue(); -  Value *FalseVal = SI.getFalseValue();    // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1    // and       (X <s  0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1    // FIXME: Type and constness constraints could be lifted, but we have to    //        watch code size carefully. We should consider xor instead of    //        sub/add when we decide to do that. -  if (IntegerType *Ty = dyn_cast<IntegerType>(CmpLHS->getType())) { -    if (TrueVal->getType() == Ty) { -      if (ConstantInt *Cmp = dyn_cast<ConstantInt>(CmpRHS)) { -        ConstantInt *C1 = nullptr, *C2 = nullptr; -        if (Pred == ICmpInst::ICMP_SGT && Cmp->isMinusOne()) { -          C1 = dyn_cast<ConstantInt>(TrueVal); -          C2 = dyn_cast<ConstantInt>(FalseVal); -        } else if (Pred == ICmpInst::ICMP_SLT && Cmp->isZero()) { -          C1 = dyn_cast<ConstantInt>(FalseVal); -          C2 = dyn_cast<ConstantInt>(TrueVal); -        } -        if (C1 && C2) { +  // TODO: Merge this with foldSelectICmpAnd somehow. +  if (CmpLHS->getType()->isIntOrIntVectorTy() && +      CmpLHS->getType() == TrueVal->getType()) { +    const APInt *C1, *C2; +    if (match(TrueVal, m_APInt(C1)) && match(FalseVal, m_APInt(C2))) { +      ICmpInst::Predicate Pred = ICI->getPredicate(); +      Value *X; +      APInt Mask; +      if (decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask, false)) { +        if (Mask.isSignMask()) { +          assert(X == CmpLHS && "Expected to use the compare input directly"); +          assert(ICmpInst::isEquality(Pred) && "Expected equality predicate"); + +          if (Pred == ICmpInst::ICMP_NE) +            std::swap(C1, C2); +            // This shift results in either -1 or 0. -          Value *AShr = Builder.CreateAShr(CmpLHS, Ty->getBitWidth() - 1); +          Value *AShr = Builder.CreateAShr(X, Mask.getBitWidth() - 1);            // Check if we can express the operation with a single or. -          if (C2->isMinusOne()) -            return replaceInstUsesWith(SI, Builder.CreateOr(AShr, C1)); +          if (C2->isAllOnesValue()) +            return replaceInstUsesWith(SI, Builder.CreateOr(AShr, *C1)); -          Value *And = Builder.CreateAnd(AShr, C2->getValue() - C1->getValue()); -          return replaceInstUsesWith(SI, Builder.CreateAdd(And, C1)); +          Value *And = Builder.CreateAnd(AShr, *C2 - *C1); +          return replaceInstUsesWith(SI, Builder.CreateAdd(And, +                                        ConstantInt::get(And->getType(), *C1)));          }        }      }    } +  { +    const APInt *TrueValC, *FalseValC; +    if (match(TrueVal, m_APInt(TrueValC)) && +        match(FalseVal, m_APInt(FalseValC))) +      if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, +                                       *FalseValC, Builder)) +        return replaceInstUsesWith(SI, V); +  } +    // NOTE: if we wanted to, this is where to detect integer MIN/MAX    if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { @@ -703,7 +842,7 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,      }    } -  if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) +  if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder))      return replaceInstUsesWith(SI, V);    if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) @@ -722,7 +861,6 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,  ///   Z = select X, Y, 0  ///  /// because Y is not live in BB1/BB2. -///  static bool canSelectOperandBeMappingIntoPredBlock(const Value *V,                                                     const SelectInst &SI) {    // If the value is a non-instruction value like a constant or argument, it @@ -864,78 +1002,6 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner,    return nullptr;  } -/// If one of the constants is zero (we know they can't both be) and we have an -/// icmp instruction with zero, and we have an 'and' with the non-constant value -/// and a power of two we can turn the select into a shift on the result of the -/// 'and'. -static Value *foldSelectICmpAnd(const SelectInst &SI, APInt TrueVal, -                                APInt FalseVal, -                                InstCombiner::BuilderTy &Builder) { -  const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); -  if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy()) -    return nullptr; - -  if (!match(IC->getOperand(1), m_Zero())) -    return nullptr; - -  ConstantInt *AndRHS; -  Value *LHS = IC->getOperand(0); -  if (!match(LHS, m_And(m_Value(), m_ConstantInt(AndRHS)))) -    return nullptr; - -  // If both select arms are non-zero see if we have a select of the form -  // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic -  // for 'x ? 2^n : 0' and fix the thing up at the end. -  APInt Offset(TrueVal.getBitWidth(), 0); -  if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { -    if ((TrueVal - FalseVal).isPowerOf2()) -      Offset = FalseVal; -    else if ((FalseVal - TrueVal).isPowerOf2()) -      Offset = TrueVal; -    else -      return nullptr; - -    // Adjust TrueVal and FalseVal to the offset. -    TrueVal -= Offset; -    FalseVal -= Offset; -  } - -  // Make sure the mask in the 'and' and one of the select arms is a power of 2. -  if (!AndRHS->getValue().isPowerOf2() || -      (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2())) -    return nullptr; - -  // Determine which shift is needed to transform result of the 'and' into the -  // desired result. -  const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; -  unsigned ValZeros = ValC.logBase2(); -  unsigned AndZeros = AndRHS->getValue().logBase2(); - -  // If types don't match we can still convert the select by introducing a zext -  // or a trunc of the 'and'. The trunc case requires that all of the truncated -  // bits are zero, we can figure that out by looking at the 'and' mask. -  if (AndZeros >= ValC.getBitWidth()) -    return nullptr; - -  Value *V = Builder.CreateZExtOrTrunc(LHS, SI.getType()); -  if (ValZeros > AndZeros) -    V = Builder.CreateShl(V, ValZeros - AndZeros); -  else if (ValZeros < AndZeros) -    V = Builder.CreateLShr(V, AndZeros - ValZeros); - -  // Okay, now we know that everything is set up, we just don't know whether we -  // have a icmp_ne or icmp_eq and whether the true or false val is the zero. -  bool ShouldNotVal = !TrueVal.isNullValue(); -  ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE; -  if (ShouldNotVal) -    V = Builder.CreateXor(V, ValC); - -  // Apply an offset if needed. -  if (!Offset.isNullValue()) -    V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); -  return V; -} -  /// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))).  /// This is even legal for FP.  static Instruction *foldAddSubSelect(SelectInst &SI, @@ -1151,12 +1217,100 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel,    return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType());  } +/// Try to eliminate select instructions that test the returned flag of cmpxchg +/// instructions. +/// +/// If a select instruction tests the returned flag of a cmpxchg instruction and +/// selects between the returned value of the cmpxchg instruction its compare +/// operand, the result of the select will always be equal to its false value. +/// For example: +/// +///   %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +///   %1 = extractvalue { i64, i1 } %0, 1 +///   %2 = extractvalue { i64, i1 } %0, 0 +///   %3 = select i1 %1, i64 %compare, i64 %2 +///   ret i64 %3 +/// +/// The returned value of the cmpxchg instruction (%2) is the original value +/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2 +/// must have been equal to %compare. Thus, the result of the select is always +/// equal to %2, and the code can be simplified to: +/// +///   %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +///   %1 = extractvalue { i64, i1 } %0, 0 +///   ret i64 %1 +/// +static Instruction *foldSelectCmpXchg(SelectInst &SI) { +  // A helper that determines if V is an extractvalue instruction whose +  // aggregate operand is a cmpxchg instruction and whose single index is equal +  // to I. If such conditions are true, the helper returns the cmpxchg +  // instruction; otherwise, a nullptr is returned. +  auto isExtractFromCmpXchg = [](Value *V, unsigned I) -> AtomicCmpXchgInst * { +    auto *Extract = dyn_cast<ExtractValueInst>(V); +    if (!Extract) +      return nullptr; +    if (Extract->getIndices()[0] != I) +      return nullptr; +    return dyn_cast<AtomicCmpXchgInst>(Extract->getAggregateOperand()); +  }; + +  // If the select has a single user, and this user is a select instruction that +  // we can simplify, skip the cmpxchg simplification for now. +  if (SI.hasOneUse()) +    if (auto *Select = dyn_cast<SelectInst>(SI.user_back())) +      if (Select->getCondition() == SI.getCondition()) +        if (Select->getFalseValue() == SI.getTrueValue() || +            Select->getTrueValue() == SI.getFalseValue()) +          return nullptr; + +  // Ensure the select condition is the returned flag of a cmpxchg instruction. +  auto *CmpXchg = isExtractFromCmpXchg(SI.getCondition(), 1); +  if (!CmpXchg) +    return nullptr; + +  // Check the true value case: The true value of the select is the returned +  // value of the same cmpxchg used by the condition, and the false value is the +  // cmpxchg instruction's compare operand. +  if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0)) +    if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) { +      SI.setTrueValue(SI.getFalseValue()); +      return &SI; +    } + +  // Check the false value case: The false value of the select is the returned +  // value of the same cmpxchg used by the condition, and the true value is the +  // cmpxchg instruction's compare operand. +  if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0)) +    if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) { +      SI.setTrueValue(SI.getFalseValue()); +      return &SI; +    } + +  return nullptr; +} +  Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {    Value *CondVal = SI.getCondition();    Value *TrueVal = SI.getTrueValue();    Value *FalseVal = SI.getFalseValue();    Type *SelType = SI.getType(); +  // FIXME: Remove this workaround when freeze related patches are done. +  // For select with undef operand which feeds into an equality comparison, +  // don't simplify it so loop unswitch can know the equality comparison +  // may have an undef operand. This is a workaround for PR31652 caused by +  // descrepancy about branch on undef between LoopUnswitch and GVN. +  if (isa<UndefValue>(TrueVal) || isa<UndefValue>(FalseVal)) { +    if (llvm::any_of(SI.users(), [&](User *U) { +          ICmpInst *CI = dyn_cast<ICmpInst>(U); +          if (CI && CI->isEquality()) +            return true; +          return false; +        })) { +      return nullptr; +    } +  } +    if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal,                                      SQ.getWithInstruction(&SI)))      return replaceInstUsesWith(SI, V); @@ -1246,12 +1400,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {      }    } -  if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) -    if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) -      if (Value *V = foldSelectICmpAnd(SI, TrueValC->getValue(), -                                       FalseValC->getValue(), Builder)) -        return replaceInstUsesWith(SI, V); -    // See if we are selecting two values based on a comparison of the two values.    if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) {      if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { @@ -1373,9 +1521,17 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {      auto SPF = SPR.Flavor;      if (SelectPatternResult::isMinOrMax(SPF)) { -      // Canonicalize so that type casts are outside select patterns. -      if (LHS->getType()->getPrimitiveSizeInBits() != -          SelType->getPrimitiveSizeInBits()) { +      // Canonicalize so that +      // - type casts are outside select patterns. +      // - float clamp is transformed to min/max pattern + +      bool IsCastNeeded = LHS->getType() != SelType; +      Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0); +      Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1); +      if (IsCastNeeded || +          (LHS->getType()->isFPOrFPVectorTy() && +           ((CmpLHS != LHS && CmpLHS != RHS) || +            (CmpRHS != LHS && CmpRHS != RHS)))) {          CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered);          Value *Cmp; @@ -1388,10 +1544,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {            Cmp = Builder.CreateFCmp(Pred, LHS, RHS);          } -        Value *NewSI = Builder.CreateCast( -            CastOp, Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI), -            SelType); -        return replaceInstUsesWith(SI, NewSI); +        Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); +        if (!IsCastNeeded) +          return replaceInstUsesWith(SI, NewSI); + +        Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); +        return replaceInstUsesWith(SI, NewCast);        }      } @@ -1485,6 +1643,46 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {      }    } +  // Try to simplify a binop sandwiched between 2 selects with the same +  // condition. +  // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) +  BinaryOperator *TrueBO; +  if (match(TrueVal, m_OneUse(m_BinOp(TrueBO)))) { +    if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { +      if (TrueBOSI->getCondition() == CondVal) { +        TrueBO->setOperand(0, TrueBOSI->getTrueValue()); +        Worklist.Add(TrueBO); +        return &SI; +      } +    } +    if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) { +      if (TrueBOSI->getCondition() == CondVal) { +        TrueBO->setOperand(1, TrueBOSI->getTrueValue()); +        Worklist.Add(TrueBO); +        return &SI; +      } +    } +  } + +  // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) +  BinaryOperator *FalseBO; +  if (match(FalseVal, m_OneUse(m_BinOp(FalseBO)))) { +    if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { +      if (FalseBOSI->getCondition() == CondVal) { +        FalseBO->setOperand(0, FalseBOSI->getFalseValue()); +        Worklist.Add(FalseBO); +        return &SI; +      } +    } +    if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) { +      if (FalseBOSI->getCondition() == CondVal) { +        FalseBO->setOperand(1, FalseBOSI->getFalseValue()); +        Worklist.Add(FalseBO); +        return &SI; +      } +    } +  } +    if (BinaryOperator::isNot(CondVal)) {      SI.setOperand(0, BinaryOperator::getNotArgument(CondVal));      SI.setOperand(1, FalseVal); @@ -1501,10 +1699,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {          return replaceInstUsesWith(SI, V);        return &SI;      } - -    if (isa<ConstantAggregateZero>(CondVal)) { -      return replaceInstUsesWith(SI, FalseVal); -    }    }    // See if we can determine the result of this select based on a dominating @@ -1515,9 +1709,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {      if (PBI && PBI->isConditional() &&          PBI->getSuccessor(0) != PBI->getSuccessor(1) &&          (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) { -      bool CondIsFalse = PBI->getSuccessor(1) == Parent; +      bool CondIsTrue = PBI->getSuccessor(0) == Parent;        Optional<bool> Implication = isImpliedCondition( -        PBI->getCondition(), SI.getCondition(), DL, CondIsFalse); +          PBI->getCondition(), SI.getCondition(), DL, CondIsTrue);        if (Implication) {          Value *V = *Implication ? TrueVal : FalseVal;          return replaceInstUsesWith(SI, V); @@ -1542,5 +1736,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {    if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, Builder))      return BitCastSel; +  // Simplify selects that test the returned flag of cmpxchg instructions. +  if (Instruction *Select = foldSelectCmpXchg(SI)) +    return Select; +    return nullptr;  } | 
