Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 104 |
1 file changed, 82 insertions, 22 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 64294838644f..fbff5dd4a8cd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -138,24 +138,6 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
   return Ret;
 }
 
-// Try to replace `undef` constants in C with Replacement.
-static Constant *replaceUndefsWith(Constant *C, Constant *Replacement) {
-  if (C && match(C, m_Undef()))
-    return Replacement;
-
-  if (auto *CV = dyn_cast<ConstantVector>(C)) {
-    llvm::SmallVector<Constant *, 32> NewOps(CV->getNumOperands());
-    for (unsigned i = 0, NumElts = NewOps.size(); i != NumElts; ++i) {
-      Constant *EltC = CV->getOperand(i);
-      NewOps[i] = EltC && match(EltC, m_Undef()) ? Replacement : EltC;
-    }
-    return ConstantVector::get(NewOps);
-  }
-
-  // Don't know how to deal with this constant.
-  return C;
-}
-
 // If we have some pattern that leaves only some low bits set, and then performs
 // left-shift of those bits, if none of the bits that are left after the final
 // shift are modified by the mask, we can omit the mask.
@@ -180,10 +162,20 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
          "The input must be 'shl'!");
 
   Value *Masked, *ShiftShAmt;
-  match(OuterShift, m_Shift(m_Value(Masked), m_Value(ShiftShAmt)));
+  match(OuterShift,
+        m_Shift(m_Value(Masked), m_ZExtOrSelf(m_Value(ShiftShAmt))));
+
+  // *If* there is a truncation between an outer shift and a possibly-mask,
+  // then said truncation *must* be one-use, else we can't perform the fold.
+  Value *Trunc;
+  if (match(Masked, m_CombineAnd(m_Trunc(m_Value(Masked)), m_Value(Trunc))) &&
+      !Trunc->hasOneUse())
+    return nullptr;
 
   Type *NarrowestTy = OuterShift->getType();
   Type *WidestTy = Masked->getType();
+  bool HadTrunc = WidestTy != NarrowestTy;
+
   // The mask must be computed in a type twice as wide to ensure
   // that no bits are lost if the sum-of-shifts is wider than the base type.
   Type *ExtendedTy = WidestTy->getExtendedType();
@@ -204,6 +196,14 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
   Constant *NewMask;
 
   if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
+    // Peek through an optional zext of the shift amount.
+    match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));
+
+    // We have two shift amounts from two different shifts. The types of those
+    // shift amounts may not match. If that's the case let's bailout now.
+    if (MaskShAmt->getType() != ShiftShAmt->getType())
+      return nullptr;
+
     // Can we simplify (MaskShAmt+ShiftShAmt) ?
     auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst(
         MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
@@ -216,7 +216,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
     // completely unknown. Replace the the `undef` shift amounts with final
     // shift bitwidth to ensure that the value remains undef when creating the
     // subsequent shift op.
-    SumOfShAmts = replaceUndefsWith(
+    SumOfShAmts = Constant::replaceUndefsWith(
         SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
                                       ExtendedTy->getScalarSizeInBits()));
     auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
@@ -228,6 +228,14 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
   } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
              match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
                                  m_Deferred(MaskShAmt)))) {
+    // Peek through an optional zext of the shift amount.
+    match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));
+
+    // We have two shift amounts from two different shifts. The types of those
+    // shift amounts may not match. If that's the case let's bailout now.
+    if (MaskShAmt->getType() != ShiftShAmt->getType())
+      return nullptr;
+
     // Can we simplify (ShiftShAmt-MaskShAmt) ?
     auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst(
         ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
@@ -241,7 +249,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
     // bitwidth of innermost shift to ensure that the value remains undef when
     // creating the subsequent shift op.
     unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits();
-    ShAmtsDiff = replaceUndefsWith(
+    ShAmtsDiff = Constant::replaceUndefsWith(
         ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
                                      -WidestTyBitWidth));
     auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
@@ -272,10 +280,15 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
       return nullptr;
   }
 
+  // If we need to apply truncation, let's do it first, since we can.
+  // We have already ensured that the old truncation will go away.
+  if (HadTrunc)
+    X = Builder.CreateTrunc(X, NarrowestTy);
+
   // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
+  // We didn't change the Type of this outermost shift, so we can just do it.
   auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X,
                                           OuterShift->getOperand(1));
-
   if (!NeedMask)
     return NewShift;
 
@@ -283,6 +296,50 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
   return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
 }
 
+/// If we have a shift-by-constant of a bitwise logic op that itself has a
+/// shift-by-constant operand with identical opcode, we may be able to convert
+/// that into 2 independent shifts followed by the logic op. This eliminates a
+/// a use of an intermediate value (reduces dependency chain).
+static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
+                                            InstCombiner::BuilderTy &Builder) {
+  assert(I.isShift() && "Expected a shift as input");
+  auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0));
+  if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse())
+    return nullptr;
+
+  const APInt *C0, *C1;
+  if (!match(I.getOperand(1), m_APInt(C1)))
+    return nullptr;
+
+  Instruction::BinaryOps ShiftOpcode = I.getOpcode();
+  Type *Ty = I.getType();
+
+  // Find a matching one-use shift by constant. The fold is not valid if the sum
+  // of the shift values equals or exceeds bitwidth.
+  // TODO: Remove the one-use check if the other logic operand (Y) is constant.
+  Value *X, *Y;
+  auto matchFirstShift = [&](Value *V) {
+    return !isa<ConstantExpr>(V) &&
+           match(V, m_OneUse(m_Shift(m_Value(X), m_APInt(C0)))) &&
+           cast<BinaryOperator>(V)->getOpcode() == ShiftOpcode &&
+           (*C0 + *C1).ult(Ty->getScalarSizeInBits());
+  };
+
+  // Logic ops are commutative, so check each operand for a match.
+  if (matchFirstShift(LogicInst->getOperand(0)))
+    Y = LogicInst->getOperand(1);
+  else if (matchFirstShift(LogicInst->getOperand(1)))
+    Y = LogicInst->getOperand(0);
+  else
+    return nullptr;
+
+  // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
+  Constant *ShiftSumC = ConstantInt::get(Ty, *C0 + *C1);
+  Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
+  Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1));
+  return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
+}
+
 Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   assert(Op0->getType() == Op1->getType());
@@ -335,6 +392,9 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
     return &I;
   }
 
+  if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder))
+    return Logic;
+
   return nullptr;
 }
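For reference, the dropRedundantMaskingOfLeftShiftInput changes above let the existing mask-dropping fold look through a one-use trunc of the masked value and a zext of either shift amount. A standalone brute-force check of the identity that makes the low-bit mask redundant in the no-new-mask case is sketched below; it is not taken from this patch or its tests, and the i32 wide type, i16 narrow type, and constant names C and S are illustrative assumptions only.

// Standalone sketch, not part of the patch: once C + S >= 16, every bit the
// mask clears is discarded by the trunc or shifted out by the shl, so
//   shl (trunc (and X, ((1 << C) - 1))), S  ==  shl (trunc X), S
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Xs[] = {0u, 1u, 0x8000u, 0xDEADBEEFu, 0xFFFFFFFFu};
  for (unsigned C = 1; C < 32; ++C) {       // mask keeps the low C bits
    uint32_t Mask = (1u << C) - 1u;         // the ((1 << C) - 1) pattern
    unsigned MinS = C >= 16 ? 0 : 16 - C;   // ensure C + S >= 16
    for (unsigned S = MinS; S < 16; ++S)
      for (uint32_t X : Xs) {
        uint16_t Masked = uint16_t(uint16_t(X & Mask) << S); // with the mask
        uint16_t Plain = uint16_t(uint16_t(X) << S);         // mask dropped
        assert(Masked == Plain && "mask must be redundant");
      }
  }
  return 0;
}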
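Similarly, the identity behind the new foldShiftOfShiftedLogic can be sanity-checked outside of LLVM. Below is a standalone sketch for the shl/and case on 32-bit values; the guard C0 + C1 < bitwidth mirrors the (*C0 + *C1).ult(...) check in the patch, and the value sets are illustrative assumptions, not LLVM test vectors.

// Standalone sketch, not part of the patch: brute-force check of
//   shl (and (shl X, C0), Y), C1  ==  and (shl X, C0+C1), (shl Y, C1)
// whenever C0 + C1 stays below the bit width.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Xs[] = {0u, 1u, 0x80000000u, 0xDEADBEEFu, 0xFFFFFFFFu};
  const uint32_t Ys[] = {0u, 1u, 0x0000FFFFu, 0x12345678u, 0xFFFFFFFFu};
  for (unsigned C0 = 0; C0 < 32; ++C0)
    for (unsigned C1 = 0; C0 + C1 < 32; ++C1)
      for (uint32_t X : Xs)
        for (uint32_t Y : Ys) {
          uint32_t Before = ((X << C0) & Y) << C1;        // original form
          uint32_t After = (X << (C0 + C1)) & (Y << C1);  // folded form
          assert(Before == After && "fold must preserve the value");
        }
  return 0;
}

The fold reuses the same shift opcode for both new shifts, which is what the cast<BinaryOperator>(V)->getOpcode() == ShiftOpcode check in the lambda enforces.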