| field | value | date |
|---|---|---|
| author | Dimitry Andric <dim@FreeBSD.org> | 2022-07-24 15:03:44 +0000 |
| committer | Dimitry Andric <dim@FreeBSD.org> | 2022-07-24 15:03:44 +0000 |
| commit | 4b4fe385e49bd883fd183b5f21c1ea486c722e61 (patch) | |
| tree | c3d8fdb355c9c73e57723718c22103aaf7d15aa6 /llvm/lib/Transforms/InstCombine | |
| parent | 1f917f69ff07f09b6dbb670971f57f8efe718b84 (diff) | |
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
7 files changed, 170 insertions(+), 69 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 535a7736454c..4a459ec6c550 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1966,12 +1966,14 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     return BinaryOperator::CreateAdd(X, ConstantExpr::getSub(C, C2));
   }
 
-  // If there's no chance any bit will need to borrow from an adjacent bit:
-  // sub C, X --> xor X, C
   const APInt *Op0C;
-  if (match(Op0, m_APInt(Op0C)) &&
-      (~computeKnownBits(Op1, 0, &I).Zero).isSubsetOf(*Op0C))
-    return BinaryOperator::CreateXor(Op1, Op0);
+  if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) {
+    // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
+    // zero.
+    KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
+    if ((*Op0C | RHSKnown.Zero).isAllOnes())
+      return BinaryOperator::CreateXor(Op1, Op0);
+  }
 
   {
     Value *Y;
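The visitSub hunk above restricts the `sub C, X --> xor X, C` fold to the case where C is a low-bit mask (2^n - 1) and the remaining bits of X are known zero: with an all-ones low field, no bit of the subtraction can borrow, so each bit is just `1 - b == 1 ^ b`. A minimal standalone C++ check of that identity (an editorial sketch, not part of the patch; the mask width is arbitrary):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t M = 0xFF; // 2^8 - 1, a low-bit mask
  for (uint32_t x = 0; x <= M; ++x) {
    assert((x & ~M) == 0);    // "the remaining bits are known zero"
    assert(M - x == (x ^ M)); // sub M, x --> xor x, M: no bit ever borrows
  }
  return 0;
}
```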
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index a8f2cd79830a..8253c575bc37 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2664,8 +2664,8 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   // Inverted form (example):
   // (icmp slt (X | Y), 0) & (icmp sgt (X & Y), -1) -> (icmp slt (X ^ Y), 0)
   bool TrueIfSignedL, TrueIfSignedR;
-  if (InstCombiner::isSignBitCheck(PredL, *LHSC, TrueIfSignedL) &&
-      InstCombiner::isSignBitCheck(PredR, *RHSC, TrueIfSignedR) &&
+  if (isSignBitCheck(PredL, *LHSC, TrueIfSignedL) &&
+      isSignBitCheck(PredR, *RHSC, TrueIfSignedR) &&
       (RHS->hasOneUse() || LHS->hasOneUse())) {
     Value *X, *Y;
     if (IsAnd) {
@@ -3202,25 +3202,38 @@ Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   // TODO: This can be generalized to compares of non-signbits using
   // decomposeBitTestICmp(). It could be enhanced more by using (something like)
   // foldLogOpOfMaskedICmps().
-  if ((LHS->hasOneUse() || RHS->hasOneUse()) &&
+  const APInt *LC, *RC;
+  if (match(LHS1, m_APInt(LC)) && match(RHS1, m_APInt(RC)) &&
       LHS0->getType() == RHS0->getType() &&
-      LHS0->getType()->isIntOrIntVectorTy()) {
+      LHS0->getType()->isIntOrIntVectorTy() &&
+      (LHS->hasOneUse() || RHS->hasOneUse())) {
+    // Convert xor of signbit tests to signbit test of xor'd values:
     // (X > -1) ^ (Y > -1) --> (X ^ Y) < 0
     // (X < 0) ^ (Y < 0) --> (X ^ Y) < 0
-    if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) &&
-         PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes())) ||
-        (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) &&
-         PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero())))
-      return Builder.CreateIsNeg(Builder.CreateXor(LHS0, RHS0));
-
     // (X > -1) ^ (Y < 0) --> (X ^ Y) > -1
     // (X < 0) ^ (Y > -1) --> (X ^ Y) > -1
-    if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) &&
-         PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero())) ||
-        (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) &&
-         PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes())))
-      return Builder.CreateIsNotNeg(Builder.CreateXor(LHS0, RHS0));
+    bool TrueIfSignedL, TrueIfSignedR;
+    if (isSignBitCheck(PredL, *LC, TrueIfSignedL) &&
+        isSignBitCheck(PredR, *RC, TrueIfSignedR)) {
+      Value *XorLR = Builder.CreateXor(LHS0, RHS0);
+      return TrueIfSignedL == TrueIfSignedR ? Builder.CreateIsNeg(XorLR)
+                                            : Builder.CreateIsNotNeg(XorLR);
+    }
+
+    // (X > C) ^ (X < C + 2) --> X != C + 1
+    // (X < C + 2) ^ (X > C) --> X != C + 1
+    // Considering the correctness of this pattern, we should avoid that C is
+    // non-negative and C + 2 is negative, although it will be matched by other
+    // patterns.
+    const APInt *C1, *C2;
+    if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_APInt(C1)) &&
+         PredR == CmpInst::ICMP_SLT && match(RHS1, m_APInt(C2))) ||
+        (PredL == CmpInst::ICMP_SLT && match(LHS1, m_APInt(C2)) &&
+         PredR == CmpInst::ICMP_SGT && match(RHS1, m_APInt(C1))))
+      if (LHS0 == RHS0 && *C1 + 2 == *C2 &&
+          (C1->isNegative() || C2->isNonNegative()))
+        return Builder.CreateICmpNE(LHS0,
+                                    ConstantInt::get(LHS0->getType(), *C1 + 1));
   }
 
   // Instead of trying to imitate the folds for and/or, decompose this 'xor'
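Both new folds in foldXorOfICmps above rest on small integer identities: the xor of two sign-bit tests equals a sign-bit test of the xor'd values, and the strict compares `X > C` and `X < C + 2` disagree everywhere except at `X == C + 1`, where both hold and the xor cancels (hence the guard against `C + 2` wrapping past the signed maximum). A standalone C++ check of both identities (illustration only; the constants are arbitrary):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // (X > -1) ^ (Y > -1) --> (X ^ Y) < 0: the xor of the sign bits of X and Y
  // is exactly the sign bit of X ^ Y.
  for (int32_t x : {-7, -1, 0, 1, 42})
    for (int32_t y : {-9, -1, 0, 3, 100})
      assert(((x > -1) ^ (y > -1)) == ((x ^ y) < 0));

  // (X > C) ^ (X < C + 2) --> X != C + 1: exactly one side is true unless
  // X == C + 1, where both are true.
  const int32_t C = 5;
  for (int32_t x = C - 3; x <= C + 4; ++x)
    assert(((x > C) ^ (x < C + 2)) == (x != C + 1));
  return 0;
}
```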
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index edfdf70c2b97..bc01d2ef7fe2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1140,8 +1140,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   if (Value *V = simplifyCall(&CI, SQ.getWithInstruction(&CI)))
     return replaceInstUsesWith(CI, V);
 
-  if (isFreeCall(&CI, &TLI))
-    return visitFree(CI);
+  if (Value *FreedOp = getFreedOperand(&CI, &TLI))
+    return visitFree(CI, FreedOp);
 
   // If the caller function (i.e. us, the function that contains this CallInst)
   // is nounwind, mark the call as nounwind, even if the callee isn't.
@@ -1539,8 +1539,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     Type *Ty = II->getType();
     unsigned BitWidth = Ty->getScalarSizeInBits();
     Constant *ShAmtC;
-    if (match(II->getArgOperand(2), m_ImmConstant(ShAmtC)) &&
-        !ShAmtC->containsConstantExpression()) {
+    if (match(II->getArgOperand(2), m_ImmConstant(ShAmtC))) {
      // Canonicalize a shift amount constant operand to modulo the bit-width.
      Constant *WidthC = ConstantInt::get(Ty, BitWidth);
      Constant *ModuloC =
@@ -2885,21 +2884,21 @@ bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call,
   // of the respective allocator declaration with generic attributes.
   bool Changed = false;
 
-  if (isAllocationFn(&Call, TLI)) {
-    uint64_t Size;
-    ObjectSizeOpts Opts;
-    if (getObjectSize(&Call, Size, DL, TLI, Opts) && Size > 0) {
-      // TODO: We really should just emit deref_or_null here and then
-      // let the generic inference code combine that with nonnull.
-      if (Call.hasRetAttr(Attribute::NonNull)) {
-        Changed = !Call.hasRetAttr(Attribute::Dereferenceable);
-        Call.addRetAttr(
-            Attribute::getWithDereferenceableBytes(Call.getContext(), Size));
-      } else {
-        Changed = !Call.hasRetAttr(Attribute::DereferenceableOrNull);
-        Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes(
-            Call.getContext(), Size));
-      }
+  if (!Call.getType()->isPointerTy())
+    return Changed;
+
+  Optional<APInt> Size = getAllocSize(&Call, TLI);
+  if (Size && *Size != 0) {
+    // TODO: We really should just emit deref_or_null here and then
+    // let the generic inference code combine that with nonnull.
+    if (Call.hasRetAttr(Attribute::NonNull)) {
+      Changed = !Call.hasRetAttr(Attribute::Dereferenceable);
+      Call.addRetAttr(Attribute::getWithDereferenceableBytes(
+          Call.getContext(), Size->getLimitedValue()));
+    } else {
+      Changed = !Call.hasRetAttr(Attribute::DereferenceableOrNull);
+      Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes(
+          Call.getContext(), Size->getLimitedValue()));
     }
   }
 
@@ -3079,8 +3078,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
         Call, Builder.CreateBitOrPointerCast(ReturnedArg, CallTy));
   }
 
-  if (isAllocationFn(&Call, &TLI) &&
-      isAllocRemovable(&cast<CallBase>(Call), &TLI))
+  if (isRemovableAlloc(&Call, &TLI))
     return visitAllocSite(Call);
 
   // Handle intrinsics which can be used in both call and invoke context.
@@ -3242,15 +3240,16 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
     // the call because there is no place to put the cast instruction (without
     // breaking the critical edge).  Bail out in this case.
     if (!Caller->use_empty()) {
-      if (InvokeInst *II = dyn_cast<InvokeInst>(Caller))
-        for (User *U : II->users())
+      BasicBlock *PhisNotSupportedBlock = nullptr;
+      if (auto *II = dyn_cast<InvokeInst>(Caller))
+        PhisNotSupportedBlock = II->getNormalDest();
+      if (auto *CB = dyn_cast<CallBrInst>(Caller))
+        PhisNotSupportedBlock = CB->getDefaultDest();
+      if (PhisNotSupportedBlock)
+        for (User *U : Caller->users())
           if (PHINode *PN = dyn_cast<PHINode>(U))
-            if (PN->getParent() == II->getNormalDest() ||
-                PN->getParent() == II->getUnwindDest())
+            if (PN->getParent() == PhisNotSupportedBlock)
              return false;
-      // FIXME: Be conservative for callbr to avoid a quadratic search.
-      if (isa<CallBrInst>(Caller))
-        return false;
     }
   }
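The shift-amount hunk above drops the `containsConstantExpression()` guard around the code that reduces a constant funnel-shift amount modulo the bit width; the underlying fact is that rotate and funnel-shift amounts are congruent mod the width. A standalone sketch of that fact (plain C++, not the LLVM intrinsic; `rotl32` is an illustrative helper):

```cpp
#include <cassert>
#include <cstdint>

// Rotate left on 32 bits; the amount is meaningful only modulo 32,
// mirroring the "shift amount modulo the bit-width" canonicalization.
static uint32_t rotl32(uint32_t v, uint32_t amt) {
  amt %= 32; // canonicalized amount: rotl(v, 37) == rotl(v, 5)
  return amt == 0 ? v : (v << amt) | (v >> (32 - amt));
}

int main() {
  assert(rotl32(0x80000001u, 37) == rotl32(0x80000001u, 5));
  assert(rotl32(0xDEADBEEFu, 32) == 0xDEADBEEFu);
  return 0;
}
```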
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 9f6d36b85522..158d2e8289e0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2002,9 +2002,12 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
                         Constant::getNullValue(Mul->getType()));
   }
 
+  if (MulC->isZero() || !(Mul->hasNoSignedWrap() || Mul->hasNoUnsignedWrap()))
+    return nullptr;
+
   // If the multiply does not wrap, try to divide the compare constant by the
   // multiplication factor.
-  if (Cmp.isEquality() && !MulC->isZero()) {
+  if (Cmp.isEquality()) {
     // (mul nsw X, MulC) == C --> X == C /s MulC
     if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) {
       Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC));
@@ -2017,7 +2020,40 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
     }
   }
 
-  return nullptr;
+  Constant *NewC = nullptr;
+
+  // FIXME: Add assert that Pred is not equal to ICMP_SGE, ICMP_SLE,
+  // ICMP_UGE, ICMP_ULE.
+
+  if (Mul->hasNoSignedWrap()) {
+    if (MulC->isNegative()) {
+      // MININT / -1 --> overflow.
+      if (C.isMinSignedValue() && MulC->isAllOnes())
+        return nullptr;
+      Pred = ICmpInst::getSwappedPredicate(Pred);
+    }
+    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE)
+      NewC = ConstantInt::get(
+          Mul->getType(),
+          APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP));
+    if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT)
+      NewC = ConstantInt::get(
+          Mul->getType(),
+          APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN));
+  }
+
+  if (Mul->hasNoUnsignedWrap()) {
+    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)
+      NewC = ConstantInt::get(
+          Mul->getType(),
+          APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP));
+    if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)
+      NewC = ConstantInt::get(
+          Mul->getType(),
+          APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN));
+  }
+
+  return NewC ? new ICmpInst(Pred, Mul->getOperand(0), NewC) : nullptr;
 }
 
 /// Fold icmp (shl 1, Y), C.
@@ -2235,13 +2271,22 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
   bool IsAShr = Shr->getOpcode() == Instruction::AShr;
 
   const APInt *ShiftValC;
-  if (match(Shr->getOperand(0), m_APInt(ShiftValC))) {
+  if (match(X, m_APInt(ShiftValC))) {
     if (Cmp.isEquality())
       return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftValC);
 
+    // (ShiftValC >> Y) >s -1 --> Y != 0 with ShiftValC < 0
+    // (ShiftValC >> Y) <s 0 --> Y == 0 with ShiftValC < 0
+    bool TrueIfSigned;
+    if (!IsAShr && ShiftValC->isNegative() &&
+        isSignBitCheck(Pred, C, TrueIfSigned))
+      return new ICmpInst(TrueIfSigned ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE,
+                          Shr->getOperand(1),
+                          ConstantInt::getNullValue(X->getType()));
+
     // If the shifted constant is a power-of-2, test the shift amount directly:
-    // (ShiftValC >> X) >u C --> X <u (LZ(C) - LZ(ShiftValC))
-    // (ShiftValC >> X) <u C --> X >=u (LZ(C-1) - LZ(ShiftValC))
+    // (ShiftValC >> Y) >u C --> X <u (LZ(C) - LZ(ShiftValC))
+    // (ShiftValC >> Y) <u C --> X >=u (LZ(C-1) - LZ(ShiftValC))
     if (!IsAShr && ShiftValC->isPowerOf2() &&
         (Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_ULT)) {
       bool IsUGT = Pred == CmpInst::ICMP_UGT;
@@ -2972,7 +3017,7 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) {
   const APInt *C;
   bool TrueIfSigned;
   if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() &&
-      InstCombiner::isSignBitCheck(Pred, *C, TrueIfSigned)) {
+      isSignBitCheck(Pred, *C, TrueIfSigned)) {
     if (match(BCSrcOp, m_FPExt(m_Value(X))) ||
         match(BCSrcOp, m_FPTrunc(m_Value(X)))) {
       // (bitcast (fpext/fptrunc X)) to iX) < 0 --> (bitcast X to iY) < 0
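The new non-equality cases in foldICmpMulConstant above divide the compare constant by the multiplier with direction-aware rounding: for a non-wrapping multiply by a positive MulC, `X * MulC <s C` is equivalent to `X <s ceil(C / MulC)`, while `>s` uses the floor (negative MulC first swaps the predicate). A brute-force C++ check with illustrative values MulC = 3, C = 10 (not LLVM code; the range is kept small so the multiply cannot wrap):

```cpp
#include <cassert>

int main() {
  const int MulC = 3, C = 10;
  const int CeilDiv = (C + MulC - 1) / MulC;  // RoundingSDiv(C, MulC, UP) = 4
  const int FloorDiv = C / MulC;              // RoundingSDiv(C, MulC, DOWN) = 3
  for (int x = -100; x <= 100; ++x) {
    assert((x * MulC < C) == (x < CeilDiv));  // (mul nsw X, 3) <s 10 --> X <s 4
    assert((x * MulC > C) == (x > FloorDiv)); // (mul nsw X, 3) >s 10 --> X >s 3
  }
  return 0;
}
```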
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 271154bb3f5a..827b25533513 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -152,7 +152,7 @@ public:
   Instruction *visitGEPOfBitcast(BitCastInst *BCI, GetElementPtrInst &GEP);
   Instruction *visitAllocaInst(AllocaInst &AI);
   Instruction *visitAllocSite(Instruction &FI);
-  Instruction *visitFree(CallInst &FI);
+  Instruction *visitFree(CallInst &FI, Value *FreedOp);
   Instruction *visitLoadInst(LoadInst &LI);
   Instruction *visitStoreInst(StoreInst &SI);
   Instruction *visitAtomicRMWInst(AtomicRMWInst &SI);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index f4e2d1239f0f..13c98b935adf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -566,6 +566,13 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
         return false;
     return true;
   }
+  case Instruction::Mul: {
+    const APInt *MulConst;
+    // We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C`)
+    return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) &&
+           MulConst->isNegatedPowerOf2() &&
+           MulConst->countTrailingZeros() == NumBits;
+  }
   }
 }
 
@@ -680,6 +687,17 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
                                isLeftShift, IC, DL));
     return PN;
   }
+  case Instruction::Mul: {
+    assert(!isLeftShift && "Unexpected shift direction!");
+    auto *Neg = BinaryOperator::CreateNeg(I->getOperand(0));
+    IC.InsertNewInstWith(Neg, *I);
+    unsigned TypeWidth = I->getType()->getScalarSizeInBits();
+    APInt Mask = APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits);
+    auto *And = BinaryOperator::CreateAnd(Neg,
+                                          ConstantInt::get(I->getType(), Mask));
+    And->takeName(I);
+    return IC.InsertNewInstWith(And, *I);
+  }
   }
 }
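The new `Instruction::Mul` cases above fold `(lshr (mul X, -(1 << C)), C)` into `(and (neg X), Mask)`: multiplying by -(2^C) is the same as shifting -X left by C, and the logical right shift by C then keeps exactly the low `W - C` bits of -X. An exhaustive standalone check at 8 bits with C = 3 (an editorial sketch, not LLVM code):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C = 3, W = 8;
  const uint8_t MulConst = static_cast<uint8_t>(-(1u << C)); // 0xF8 == -(2^3)
  const uint8_t Mask = (1u << (W - C)) - 1;                  // low W-C bits
  for (unsigned v = 0; v < 256; ++v) {
    uint8_t X = static_cast<uint8_t>(v);
    uint8_t Shr = static_cast<uint8_t>(X * MulConst) >> C; // lshr(mul X, -8), 3
    uint8_t AndNeg = static_cast<uint8_t>(-X) & Mask;      // and(neg X, 0x1F)
    assert(Shr == AndNeg);
  }
  return 0;
}
```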
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 75520a0c8d5f..71c763de43b4 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -994,6 +994,24 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) {
   return SelectInst::Create(X, TVal, FVal);
 }
 
+static Constant *constantFoldOperationIntoSelectOperand(
+    Instruction &I, SelectInst *SI, Value *SO) {
+  auto *ConstSO = dyn_cast<Constant>(SO);
+  if (!ConstSO)
+    return nullptr;
+
+  SmallVector<Constant *> ConstOps;
+  for (Value *Op : I.operands()) {
+    if (Op == SI)
+      ConstOps.push_back(ConstSO);
+    else if (auto *C = dyn_cast<Constant>(Op))
+      ConstOps.push_back(C);
+    else
+      llvm_unreachable("Operands should be select or constant");
+  }
+  return ConstantFoldInstOperands(&I, ConstOps,
+                                  I.getModule()->getDataLayout());
+}
+
 static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO,
                                              InstCombiner::BuilderTy &Builder) {
   if (auto *Cast = dyn_cast<CastInst>(&I))
@@ -1101,8 +1119,17 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
     }
   }
 
-  Value *NewTV = foldOperationIntoSelectOperand(Op, TV, Builder);
-  Value *NewFV = foldOperationIntoSelectOperand(Op, FV, Builder);
+  // Make sure that one of the select arms constant folds successfully.
+  Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, TV);
+  Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, FV);
+  if (!NewTV && !NewFV)
+    return nullptr;
+
+  // Create an instruction for the arm that did not fold.
+  if (!NewTV)
+    NewTV = foldOperationIntoSelectOperand(Op, TV, Builder);
+  if (!NewFV)
+    NewFV = foldOperationIntoSelectOperand(Op, FV, Builder);
   return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);
 }
 
@@ -2774,13 +2801,14 @@ static bool isAllocSiteRemovable(Instruction *AI,
         continue;
       }
 
-      if (isFreeCall(I, &TLI) && getAllocationFamily(I, &TLI) == Family) {
+      if (getFreedOperand(cast<CallBase>(I), &TLI) == PI &&
+          getAllocationFamily(I, &TLI) == Family) {
         assert(Family);
         Users.emplace_back(I);
         continue;
       }
 
-      if (isReallocLikeFn(I, &TLI) &&
+      if (getReallocatedOperand(cast<CallBase>(I), &TLI) == PI &&
           getAllocationFamily(I, &TLI) == Family) {
         assert(Family);
         Users.emplace_back(I);
@@ -2805,7 +2833,7 @@ static bool isAllocSiteRemovable(Instruction *AI,
 }
 
 Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) {
-  assert(isa<AllocaInst>(MI) || isAllocRemovable(&cast<CallBase>(MI), &TLI));
+  assert(isa<AllocaInst>(MI) || isRemovableAlloc(&cast<CallBase>(MI), &TLI));
 
   // If we have a malloc call which is only used in any amount of comparisons to
   // null and free calls, delete the calls and replace the comparisons with true
@@ -3007,9 +3035,7 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI,
   return &FI;
 }
 
-Instruction *InstCombinerImpl::visitFree(CallInst &FI) {
-  Value *Op = FI.getArgOperand(0);
-
+Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) {
   // free undef -> unreachable.
   if (isa<UndefValue>(Op)) {
     // Leave a marker since we can't modify the CFG here.
@@ -3024,12 +3050,10 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI) {
 
   // If we had free(realloc(...)) with no intervening uses, then eliminate the
   // realloc() entirely.
-  if (CallInst *CI = dyn_cast<CallInst>(Op)) {
-    if (CI->hasOneUse() && isReallocLikeFn(CI, &TLI)) {
-      return eraseInstFromFunction(
-          *replaceInstUsesWith(*CI, CI->getOperand(0)));
-    }
-  }
+  CallInst *CI = dyn_cast<CallInst>(Op);
+  if (CI && CI->hasOneUse())
+    if (Value *ReallocatedOp = getReallocatedOperand(CI, &TLI))
+      return eraseInstFromFunction(*replaceInstUsesWith(*CI, ReallocatedOp));
 
   // If we optimize for code size, try to move the call to free before the null
   // test so that simplify cfg can remove the empty block and dead code
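The FoldOpIntoSelect change above only pushes an operation into a select when at least one arm constant-folds, so the rewrite never creates two new instructions to replace one. A standalone C++ illustration of the shape of the transform (the `before`/`after` helpers are invented for this sketch): `(select c, 16, x) >> 2` becomes `select c, 4, (x >> 2)`, where the true arm folds to a constant and the false arm still needs one shift.

```cpp
#include <cassert>

static int before(bool c, int x) { return (c ? 16 : x) >> 2; } // op of select
static int after(bool c, int x) { return c ? 4 : (x >> 2); }   // select of ops

int main() {
  for (int x = 0; x <= 64; ++x) {
    assert(before(true, x) == after(true, x));
    assert(before(false, x) == after(false, x));
  }
  return 0;
}
```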
