summaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineShifts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineShifts.cpp370
1 files changed, 338 insertions, 32 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index c821292400cd..64294838644f 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -25,50 +25,275 @@ using namespace PatternMatch;
// we should rewrite it as
// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x)
// This is valid for any shift, but they must be identical.
-static Instruction *
-reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
- const SimplifyQuery &SQ) {
- // Look for: (x shiftopcode ShAmt0) shiftopcode ShAmt1
- Value *X, *ShAmt1, *ShAmt0;
+//
+// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
+// pattern has any 2 right-shifts that sum to 1 less than original bit width.
+Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
+ BinaryOperator *Sh0, const SimplifyQuery &SQ,
+ bool AnalyzeForSignBitExtraction) {
+ // Look for a shift of some instruction, ignore zext of shift amount if any.
+ Instruction *Sh0Op0;
+ Value *ShAmt0;
+ if (!match(Sh0,
+ m_Shift(m_Instruction(Sh0Op0), m_ZExtOrSelf(m_Value(ShAmt0)))))
+ return nullptr;
+
+ // If there is a truncation between the two shifts, we must make note of it
+ // and look through it. The truncation imposes additional constraints on the
+ // transform.
Instruction *Sh1;
- if (!match(Sh0, m_Shift(m_CombineAnd(m_Shift(m_Value(X), m_Value(ShAmt1)),
- m_Instruction(Sh1)),
- m_Value(ShAmt0))))
+ Value *Trunc = nullptr;
+ match(Sh0Op0,
+ m_CombineOr(m_CombineAnd(m_Trunc(m_Instruction(Sh1)), m_Value(Trunc)),
+ m_Instruction(Sh1)));
+
+ // Inner shift: (x shiftopcode ShAmt1)
+ // Like with other shift, ignore zext of shift amount if any.
+ Value *X, *ShAmt1;
+ if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1)))))
+ return nullptr;
+
+ // We have two shift amounts from two different shifts. The types of those
+ // shift amounts may not match. If that's the case let's bailout now..
+ if (ShAmt0->getType() != ShAmt1->getType())
+ return nullptr;
+
+ // We are only looking for signbit extraction if we have two right shifts.
+ bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
+ match(Sh1, m_Shr(m_Value(), m_Value()));
+ // ... and if it's not two right-shifts, we know the answer already.
+ if (AnalyzeForSignBitExtraction && !HadTwoRightShifts)
return nullptr;
- // The shift opcodes must be identical.
+ // The shift opcodes must be identical, unless we are just checking whether
+ // this pattern can be interpreted as a sign-bit-extraction.
Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
- if (ShiftOpcode != Sh1->getOpcode())
+ bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode();
+ if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction)
return nullptr;
+
+ // If we saw truncation, we'll need to produce extra instruction,
+ // and for that one of the operands of the shift must be one-use,
+ // unless of course we don't actually plan to produce any instructions here.
+ if (Trunc && !AnalyzeForSignBitExtraction &&
+ !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+ return nullptr;
+
// Can we fold (ShAmt0+ShAmt1) ?
- Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, ShAmt0, ShAmt1,
- SQ.getWithInstruction(Sh0));
+ auto *NewShAmt = dyn_cast_or_null<Constant>(
+ SimplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false,
+ SQ.getWithInstruction(Sh0)));
if (!NewShAmt)
return nullptr; // Did not simplify.
- // Is the new shift amount smaller than the bit width?
- // FIXME: could also rely on ConstantRange.
- unsigned BitWidth = X->getType()->getScalarSizeInBits();
+ unsigned NewShAmtBitWidth = NewShAmt->getType()->getScalarSizeInBits();
+ unsigned XBitWidth = X->getType()->getScalarSizeInBits();
+ // Is the new shift amount smaller than the bit width of inner/new shift?
if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
- APInt(BitWidth, BitWidth))))
- return nullptr;
+ APInt(NewShAmtBitWidth, XBitWidth))))
+ return nullptr; // FIXME: could perform constant-folding.
+
+ // If there was a truncation, and we have a right-shift, we can only fold if
+ // we are left with the original sign bit. Likewise, if we were just checking
+ // that this is a sighbit extraction, this is the place to check it.
+ // FIXME: zero shift amount is also legal here, but we can't *easily* check
+ // more than one predicate so it's not really worth it.
+ if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) {
+ // If it's not a sign bit extraction, then we're done.
+ if (!match(NewShAmt,
+ m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+ APInt(NewShAmtBitWidth, XBitWidth - 1))))
+ return nullptr;
+ // If it is, and that was the question, return the base value.
+ if (AnalyzeForSignBitExtraction)
+ return X;
+ }
+
+ assert(IdenticalShOpcodes && "Should not get here with different shifts.");
+
// All good, we can do this fold.
+ NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());
+
BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);
- // If both of the original shifts had the same flag set, preserve the flag.
- if (ShiftOpcode == Instruction::BinaryOps::Shl) {
- NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
- Sh1->hasNoUnsignedWrap());
- NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
- Sh1->hasNoSignedWrap());
- } else {
- NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
+
+ // The flags can only be propagated if there wasn't a trunc.
+ if (!Trunc) {
+ // If the pattern did not involve trunc, and both of the original shifts
+ // had the same flag set, preserve the flag.
+ if (ShiftOpcode == Instruction::BinaryOps::Shl) {
+ NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
+ Sh1->hasNoUnsignedWrap());
+ NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
+ Sh1->hasNoSignedWrap());
+ } else {
+ NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
+ }
+ }
+
+ Instruction *Ret = NewShift;
+ if (Trunc) {
+ Builder.Insert(NewShift);
+ Ret = CastInst::Create(Instruction::Trunc, NewShift, Sh0->getType());
+ }
+
+ return Ret;
+}
+
+// Try to replace `undef` constants in C with Replacement.
+static Constant *replaceUndefsWith(Constant *C, Constant *Replacement) {
+ if (C && match(C, m_Undef()))
+ return Replacement;
+
+ if (auto *CV = dyn_cast<ConstantVector>(C)) {
+ llvm::SmallVector<Constant *, 32> NewOps(CV->getNumOperands());
+ for (unsigned i = 0, NumElts = NewOps.size(); i != NumElts; ++i) {
+ Constant *EltC = CV->getOperand(i);
+ NewOps[i] = EltC && match(EltC, m_Undef()) ? Replacement : EltC;
+ }
+ return ConstantVector::get(NewOps);
+ }
+
+ // Don't know how to deal with this constant.
+ return C;
+}
+
+// If we have some pattern that leaves only some low bits set, and then performs
+// left-shift of those bits, if none of the bits that are left after the final
+// shift are modified by the mask, we can omit the mask.
+//
+// There are many variants to this pattern:
+// a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
+// b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt
+// c) (x & (-1 >> MaskShAmt)) << ShiftShAmt
+// d) (x & ((-1 << MaskShAmt) >> MaskShAmt)) << ShiftShAmt
+// e) ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt
+// f) ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt
+// All these patterns can be simplified to just:
+// x << ShiftShAmt
+// iff:
+// a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
+// c,d,e,f) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
+static Instruction *
+dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
+ const SimplifyQuery &Q,
+ InstCombiner::BuilderTy &Builder) {
+ assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl &&
+ "The input must be 'shl'!");
+
+ Value *Masked, *ShiftShAmt;
+ match(OuterShift, m_Shift(m_Value(Masked), m_Value(ShiftShAmt)));
+
+ Type *NarrowestTy = OuterShift->getType();
+ Type *WidestTy = Masked->getType();
+ // The mask must be computed in a type twice as wide to ensure
+ // that no bits are lost if the sum-of-shifts is wider than the base type.
+ Type *ExtendedTy = WidestTy->getExtendedType();
+
+ Value *MaskShAmt;
+
+ // ((1 << MaskShAmt) - 1)
+ auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
+ // (~(-1 << maskNbits))
+ auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
+ // (-1 >> MaskShAmt)
+ auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt));
+ // ((-1 << MaskShAmt) >> MaskShAmt)
+ auto MaskD =
+ m_Shr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt));
+
+ Value *X;
+ Constant *NewMask;
+
+ if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
+ // Can we simplify (MaskShAmt+ShiftShAmt) ?
+ auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst(
+ MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
+ if (!SumOfShAmts)
+ return nullptr; // Did not simplify.
+ // In this pattern SumOfShAmts correlates with the number of low bits
+ // that shall remain in the root value (OuterShift).
+
+ // An extend of an undef value becomes zero because the high bits are never
+ // completely unknown. Replace the the `undef` shift amounts with final
+ // shift bitwidth to ensure that the value remains undef when creating the
+ // subsequent shift op.
+ SumOfShAmts = replaceUndefsWith(
+ SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
+ ExtendedTy->getScalarSizeInBits()));
+ auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
+ // And compute the mask as usual: ~(-1 << (SumOfShAmts))
+ auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+ auto *ExtendedInvertedMask =
+ ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
+ NewMask = ConstantExpr::getNot(ExtendedInvertedMask);
+ } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
+ match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
+ m_Deferred(MaskShAmt)))) {
+ // Can we simplify (ShiftShAmt-MaskShAmt) ?
+ auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst(
+ ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
+ if (!ShAmtsDiff)
+ return nullptr; // Did not simplify.
+ // In this pattern ShAmtsDiff correlates with the number of high bits that
+ // shall be unset in the root value (OuterShift).
+
+ // An extend of an undef value becomes zero because the high bits are never
+ // completely unknown. Replace the the `undef` shift amounts with negated
+ // bitwidth of innermost shift to ensure that the value remains undef when
+ // creating the subsequent shift op.
+ unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits();
+ ShAmtsDiff = replaceUndefsWith(
+ ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
+ -WidestTyBitWidth));
+ auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
+ ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(),
+ WidestTyBitWidth,
+ /*isSigned=*/false),
+ ShAmtsDiff),
+ ExtendedTy);
+ // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
+ auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+ NewMask =
+ ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
+ } else
+ return nullptr; // Don't know anything about this pattern.
+
+ NewMask = ConstantExpr::getTrunc(NewMask, NarrowestTy);
+
+ // Does this mask has any unset bits? If not then we can just not apply it.
+ bool NeedMask = !match(NewMask, m_AllOnes());
+
+ // If we need to apply a mask, there are several more restrictions we have.
+ if (NeedMask) {
+ // The old masking instruction must go away.
+ if (!Masked->hasOneUse())
+ return nullptr;
+ // The original "masking" instruction must not have been`ashr`.
+ if (match(Masked, m_AShr(m_Value(), m_Value())))
+ return nullptr;
}
- return NewShift;
+
+ // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
+ auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X,
+ OuterShift->getOperand(1));
+
+ if (!NeedMask)
+ return NewShift;
+
+ Builder.Insert(NewShift);
+ return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
}
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
assert(Op0->getType() == Op1->getType());
+ // If the shift amount is a one-use `sext`, we can demote it to `zext`.
+ Value *Y;
+ if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) {
+ Value *NewExt = Builder.CreateZExt(Y, I.getType(), Op1->getName());
+ return BinaryOperator::Create(I.getOpcode(), Op0, NewExt);
+ }
+
// See if we can fold away this shift.
if (SimplifyDemandedInstructionBits(I))
return &I;
@@ -83,8 +308,8 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
return Res;
- if (Instruction *NewShift =
- reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ))
+ if (auto *NewShift = cast_or_null<Instruction>(
+ reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)))
return NewShift;
// (C1 shift (A add C2)) -> (C1 shift C2) shift A)
@@ -618,9 +843,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
}
Instruction *InstCombiner::visitShl(BinaryOperator &I) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&I);
+
if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
- I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
- SQ.getWithInstruction(&I)))
+ I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q))
return replaceInstUsesWith(I, V);
if (Instruction *X = foldVectorBinop(I))
@@ -629,6 +855,9 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
if (Instruction *V = commonShiftTransforms(I))
return V;
+ if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder))
+ return V;
+
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
unsigned BitWidth = Ty->getScalarSizeInBits();
@@ -636,12 +865,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
const APInt *ShAmtAPInt;
if (match(Op1, m_APInt(ShAmtAPInt))) {
unsigned ShAmt = ShAmtAPInt->getZExtValue();
- unsigned BitWidth = Ty->getScalarSizeInBits();
// shl (zext X), ShAmt --> zext (shl X, ShAmt)
// This is only valid if X would have zeros shifted out.
Value *X;
- if (match(Op0, m_ZExt(m_Value(X)))) {
+ if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
unsigned SrcWidth = X->getType()->getScalarSizeInBits();
if (ShAmt < SrcWidth &&
MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
@@ -719,6 +947,12 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
// (X * C2) << C1 --> X * (C2 << C1)
if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
+
+ // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
+ if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
+ auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1);
+ return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
+ }
}
// (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
@@ -859,6 +1093,75 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
return nullptr;
}
+Instruction *
+InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract(
+ BinaryOperator &OldAShr) {
+ assert(OldAShr.getOpcode() == Instruction::AShr &&
+ "Must be called with arithmetic right-shift instruction only.");
+
+ // Check that constant C is a splat of the element-wise bitwidth of V.
+ auto BitWidthSplat = [](Constant *C, Value *V) {
+ return match(
+ C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+ APInt(C->getType()->getScalarSizeInBits(),
+ V->getType()->getScalarSizeInBits())));
+ };
+
+ // It should look like variable-length sign-extension on the outside:
+ // (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits)
+ Value *NBits;
+ Instruction *MaybeTrunc;
+ Constant *C1, *C2;
+ if (!match(&OldAShr,
+ m_AShr(m_Shl(m_Instruction(MaybeTrunc),
+ m_ZExtOrSelf(m_Sub(m_Constant(C1),
+ m_ZExtOrSelf(m_Value(NBits))))),
+ m_ZExtOrSelf(m_Sub(m_Constant(C2),
+ m_ZExtOrSelf(m_Deferred(NBits)))))) ||
+ !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr))
+ return nullptr;
+
+ // There may or may not be a truncation after outer two shifts.
+ Instruction *HighBitExtract;
+ match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract)));
+ bool HadTrunc = MaybeTrunc != HighBitExtract;
+
+ // And finally, the innermost part of the pattern must be a right-shift.
+ Value *X, *NumLowBitsToSkip;
+ if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip))))
+ return nullptr;
+
+ // Said right-shift must extract high NBits bits - C0 must be it's bitwidth.
+ Constant *C0;
+ if (!match(NumLowBitsToSkip,
+ m_ZExtOrSelf(
+ m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) ||
+ !BitWidthSplat(C0, HighBitExtract))
+ return nullptr;
+
+ // Since the NBits is identical for all shifts, if the outermost and
+ // innermost shifts are identical, then outermost shifts are redundant.
+ // If we had truncation, do keep it though.
+ if (HighBitExtract->getOpcode() == OldAShr.getOpcode())
+ return replaceInstUsesWith(OldAShr, MaybeTrunc);
+
+ // Else, if there was a truncation, then we need to ensure that one
+ // instruction will go away.
+ if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+ return nullptr;
+
+ // Finally, bypass two innermost shifts, and perform the outermost shift on
+ // the operands of the innermost shift.
+ Instruction *NewAShr =
+ BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip);
+ NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness.
+ if (!HadTrunc)
+ return NewAShr;
+
+ Builder.Insert(NewAShr);
+ return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
+}
+
Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
SQ.getWithInstruction(&I)))
@@ -933,6 +1236,9 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
}
}
+ if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I))
+ return R;
+
// See if we can turn a signed shr into an unsigned shr.
if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
return BinaryOperator::CreateLShr(Op0, Op1);