diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp | 135 | 
1 files changed, 102 insertions, 33 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 351fc3b0174f..7f2018b3a199 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -411,6 +411,14 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {          getComplexity(I.getOperand(1)))        Changed = !I.swapOperands(); +    if (I.isCommutative()) { +      if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) { +        replaceOperand(I, 0, Pair->first); +        replaceOperand(I, 1, Pair->second); +        Changed = true; +      } +    } +      BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));      BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1)); @@ -1096,8 +1104,8 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {    return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);  } -std::optional<std::pair<Value *, Value *>> -InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) { +static std::optional<std::pair<Value *, Value *>> +matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {    if (LHS->getParent() != RHS->getParent())      return std::nullopt; @@ -1123,25 +1131,41 @@ InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {    return std::optional(std::pair(L0, R0));  } -Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, -                                                        Value *Op0, -                                                        Value *Op1) { -  assert(I.isCommutative() && "Instruction should be commutative"); - -  PHINode *LHS = dyn_cast<PHINode>(Op0); -  PHINode *RHS = dyn_cast<PHINode>(Op1); - -  if (!LHS || !RHS) -    return nullptr; - -  if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) { -    Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second); -    if (auto *BO = dyn_cast<BinaryOperator>(BI)) -      BO->copyIRFlags(&I); -    return BI; +std::optional<std::pair<Value *, Value *>> +InstCombinerImpl::matchSymmetricPair(Value *LHS, Value *RHS) { +  Instruction *LHSInst = dyn_cast<Instruction>(LHS); +  Instruction *RHSInst = dyn_cast<Instruction>(RHS); +  if (!LHSInst || !RHSInst || LHSInst->getOpcode() != RHSInst->getOpcode()) +    return std::nullopt; +  switch (LHSInst->getOpcode()) { +  case Instruction::PHI: +    return matchSymmetricPhiNodesPair(cast<PHINode>(LHS), cast<PHINode>(RHS)); +  case Instruction::Select: { +    Value *Cond = LHSInst->getOperand(0); +    Value *TrueVal = LHSInst->getOperand(1); +    Value *FalseVal = LHSInst->getOperand(2); +    if (Cond == RHSInst->getOperand(0) && TrueVal == RHSInst->getOperand(2) && +        FalseVal == RHSInst->getOperand(1)) +      return std::pair(TrueVal, FalseVal); +    return std::nullopt; +  } +  case Instruction::Call: { +    // Match min(a, b) and max(a, b) +    MinMaxIntrinsic *LHSMinMax = dyn_cast<MinMaxIntrinsic>(LHSInst); +    MinMaxIntrinsic *RHSMinMax = dyn_cast<MinMaxIntrinsic>(RHSInst); +    if (LHSMinMax && RHSMinMax && +        LHSMinMax->getPredicate() == +            ICmpInst::getSwappedPredicate(RHSMinMax->getPredicate()) && +        ((LHSMinMax->getLHS() == RHSMinMax->getLHS() && +          LHSMinMax->getRHS() == RHSMinMax->getRHS()) || +         (LHSMinMax->getLHS() == RHSMinMax->getRHS() && +          LHSMinMax->getRHS() == RHSMinMax->getLHS()))) +      return std::pair(LHSMinMax->getLHS(), LHSMinMax->getRHS()); +    return std::nullopt; +  } +  default: +    return std::nullopt;    } - -  return nullptr;  }  Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, @@ -1187,14 +1211,6 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,    };    if (LHSIsSelect && RHSIsSelect && A == D) { -    // op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y) -    if (I.isCommutative() && B == F && C == E) { -      Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E); -      if (auto *BO = dyn_cast<BinaryOperator>(BI)) -        BO->copyIRFlags(&I); -      return BI; -    } -      // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)      Cond = A;      True = simplifyBinOp(Opcode, B, E, FMF, Q); @@ -1577,11 +1593,6 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {        BO.getParent() != Phi1->getParent())      return nullptr; -  if (BO.isCommutative()) { -    if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1)) -      return replaceInstUsesWith(BO, V); -  } -    // Fold if there is at least one specific constant value in phi0 or phi1's    // incoming values that comes from the same block and this specific constant    // value can be used to do optimization for specific binary operator. @@ -3197,6 +3208,64 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {      return replaceOperand(SI, 0, Op0);    } +  ConstantInt *SubLHS; +  if (match(Cond, m_Sub(m_ConstantInt(SubLHS), m_Value(Op0)))) { +    // Change 'switch (1-X) case 1:' into 'switch (X) case 0'. +    for (auto Case : SI.cases()) { +      Constant *NewCase = ConstantExpr::getSub(SubLHS, Case.getCaseValue()); +      assert(isa<ConstantInt>(NewCase) && +             "Result of expression should be constant"); +      Case.setValue(cast<ConstantInt>(NewCase)); +    } +    return replaceOperand(SI, 0, Op0); +  } + +  uint64_t ShiftAmt; +  if (match(Cond, m_Shl(m_Value(Op0), m_ConstantInt(ShiftAmt))) && +      ShiftAmt < Op0->getType()->getScalarSizeInBits() && +      all_of(SI.cases(), [&](const auto &Case) { +        return Case.getCaseValue()->getValue().countr_zero() >= ShiftAmt; +      })) { +    // Change 'switch (X << 2) case 4:' into 'switch (X) case 1:'. +    OverflowingBinaryOperator *Shl = cast<OverflowingBinaryOperator>(Cond); +    if (Shl->hasNoUnsignedWrap() || Shl->hasNoSignedWrap() || +        Shl->hasOneUse()) { +      Value *NewCond = Op0; +      if (!Shl->hasNoUnsignedWrap() && !Shl->hasNoSignedWrap()) { +        // If the shift may wrap, we need to mask off the shifted bits. +        unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); +        NewCond = Builder.CreateAnd( +            Op0, APInt::getLowBitsSet(BitWidth, BitWidth - ShiftAmt)); +      } +      for (auto Case : SI.cases()) { +        const APInt &CaseVal = Case.getCaseValue()->getValue(); +        APInt ShiftedCase = Shl->hasNoSignedWrap() ? CaseVal.ashr(ShiftAmt) +                                                   : CaseVal.lshr(ShiftAmt); +        Case.setValue(ConstantInt::get(SI.getContext(), ShiftedCase)); +      } +      return replaceOperand(SI, 0, NewCond); +    } +  } + +  // Fold switch(zext/sext(X)) into switch(X) if possible. +  if (match(Cond, m_ZExtOrSExt(m_Value(Op0)))) { +    bool IsZExt = isa<ZExtInst>(Cond); +    Type *SrcTy = Op0->getType(); +    unsigned NewWidth = SrcTy->getScalarSizeInBits(); + +    if (all_of(SI.cases(), [&](const auto &Case) { +          const APInt &CaseVal = Case.getCaseValue()->getValue(); +          return IsZExt ? CaseVal.isIntN(NewWidth) +                        : CaseVal.isSignedIntN(NewWidth); +        })) { +      for (auto &Case : SI.cases()) { +        APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth); +        Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); +      } +      return replaceOperand(SI, 0, Op0); +    } +  } +    KnownBits Known = computeKnownBits(Cond, 0, &SI);    unsigned LeadingKnownZeros = Known.countMinLeadingZeros();    unsigned LeadingKnownOnes = Known.countMinLeadingOnes();  | 
