summaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineCasts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCasts.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineCasts.cpp206
1 files changed, 170 insertions, 36 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp
index dfdfd3e9da84..178c8eaf2502 100644
--- a/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -235,8 +235,8 @@ Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1,
Type *MidTy = CI1->getDestTy();
Type *DstTy = CI2->getDestTy();
- Instruction::CastOps firstOp = Instruction::CastOps(CI1->getOpcode());
- Instruction::CastOps secondOp = Instruction::CastOps(CI2->getOpcode());
+ Instruction::CastOps firstOp = CI1->getOpcode();
+ Instruction::CastOps secondOp = CI2->getOpcode();
Type *SrcIntPtrTy =
SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr;
Type *MidIntPtrTy =
@@ -346,29 +346,50 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,
}
break;
}
- case Instruction::Shl:
+ case Instruction::Shl: {
// If we are truncating the result of this SHL, and if it's a shift of a
// constant amount, we can always perform a SHL in a smaller type.
- if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ const APInt *Amt;
+ if (match(I->getOperand(1), m_APInt(Amt))) {
uint32_t BitWidth = Ty->getScalarSizeInBits();
- if (CI->getLimitedValue(BitWidth) < BitWidth)
+ if (Amt->getLimitedValue(BitWidth) < BitWidth)
return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
}
break;
- case Instruction::LShr:
+ }
+ case Instruction::LShr: {
// If this is a truncate of a logical shr, we can truncate it to a smaller
// lshr iff we know that the bits we would otherwise be shifting in are
// already zeros.
- if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ const APInt *Amt;
+ if (match(I->getOperand(1), m_APInt(Amt))) {
uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
uint32_t BitWidth = Ty->getScalarSizeInBits();
if (IC.MaskedValueIsZero(I->getOperand(0),
APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) &&
- CI->getLimitedValue(BitWidth) < BitWidth) {
+ Amt->getLimitedValue(BitWidth) < BitWidth) {
return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
}
}
break;
+ }
+ case Instruction::AShr: {
+ // If this is a truncate of an arithmetic shr, we can truncate it to a
+ // smaller ashr iff we know that all the bits from the sign bit of the
+ // original type and the sign bit of the truncate type are similar.
+ // TODO: It is enough to check that the bits we would be shifting in are
+ // similar to sign bit of the truncate type.
+ const APInt *Amt;
+ if (match(I->getOperand(1), m_APInt(Amt))) {
+ uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+ uint32_t BitWidth = Ty->getScalarSizeInBits();
+ if (Amt->getLimitedValue(BitWidth) < BitWidth &&
+ OrigBitWidth - BitWidth <
+ IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
+ return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
+ }
+ break;
+ }
case Instruction::Trunc:
// trunc(trunc(x)) -> trunc(x)
return true;
@@ -443,24 +464,130 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) {
return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
}
-/// Try to narrow the width of bitwise logic instructions with constants.
-Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) {
+/// Rotate left/right may occur in a wider type than necessary because of type
+/// promotion rules. Try to narrow all of the component instructions.
+Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) {
+ assert((isa<VectorType>(Trunc.getSrcTy()) ||
+ shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) &&
+ "Don't narrow to an illegal scalar type");
+
+ // First, find an or'd pair of opposite shifts with the same shifted operand:
+ // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1))
+ Value *Or0, *Or1;
+ if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+ return nullptr;
+
+ Value *ShVal, *ShAmt0, *ShAmt1;
+ if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) ||
+ !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1)))))
+ return nullptr;
+
+ auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
+ auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
+ if (ShiftOpcode0 == ShiftOpcode1)
+ return nullptr;
+
+ // The shift amounts must add up to the narrow bit width.
+ Value *ShAmt;
+ bool SubIsOnLHS;
+ Type *DestTy = Trunc.getType();
+ unsigned NarrowWidth = DestTy->getScalarSizeInBits();
+ if (match(ShAmt0,
+ m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) {
+ ShAmt = ShAmt1;
+ SubIsOnLHS = true;
+ } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth),
+ m_Specific(ShAmt0))))) {
+ ShAmt = ShAmt0;
+ SubIsOnLHS = false;
+ } else {
+ return nullptr;
+ }
+
+ // The shifted value must have high zeros in the wide type. Typically, this
+ // will be a zext, but it could also be the result of an 'and' or 'shift'.
+ unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
+ APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
+ if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc))
+ return nullptr;
+
+ // We have an unnecessarily wide rotate!
+ // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt))
+ // Narrow it down to eliminate the zext/trunc:
+ // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1')
+ Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
+ Value *NegShAmt = Builder.CreateNeg(NarrowShAmt);
+
+ // Mask both shift amounts to ensure there's no UB from oversized shifts.
+ Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1);
+ Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC);
+ Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC);
+
+ // Truncate the original value and use narrow ops.
+ Value *X = Builder.CreateTrunc(ShVal, DestTy);
+ Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt;
+ Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt;
+ Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0);
+ Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1);
+ return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1);
+}
+
+/// Try to narrow the width of math or bitwise logic instructions by pulling a
+/// truncate ahead of binary operators.
+/// TODO: Transforms for truncated shifts should be moved into here.
+Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) {
Type *SrcTy = Trunc.getSrcTy();
Type *DestTy = Trunc.getType();
- if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
+ if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
return nullptr;
- BinaryOperator *LogicOp;
- Constant *C;
- if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(LogicOp))) ||
- !LogicOp->isBitwiseLogicOp() ||
- !match(LogicOp->getOperand(1), m_Constant(C)))
+ BinaryOperator *BinOp;
+ if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(BinOp))))
return nullptr;
- // trunc (logic X, C) --> logic (trunc X, C')
- Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy);
- Value *NarrowOp0 = Builder.CreateTrunc(LogicOp->getOperand(0), DestTy);
- return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC);
+ Value *BinOp0 = BinOp->getOperand(0);
+ Value *BinOp1 = BinOp->getOperand(1);
+ switch (BinOp->getOpcode()) {
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor:
+ case Instruction::Add:
+ case Instruction::Sub:
+ case Instruction::Mul: {
+ Constant *C;
+ if (match(BinOp0, m_Constant(C))) {
+ // trunc (binop C, X) --> binop (trunc C', X)
+ Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy);
+ Value *TruncX = Builder.CreateTrunc(BinOp1, DestTy);
+ return BinaryOperator::Create(BinOp->getOpcode(), NarrowC, TruncX);
+ }
+ if (match(BinOp1, m_Constant(C))) {
+ // trunc (binop X, C) --> binop (trunc X, C')
+ Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy);
+ Value *TruncX = Builder.CreateTrunc(BinOp0, DestTy);
+ return BinaryOperator::Create(BinOp->getOpcode(), TruncX, NarrowC);
+ }
+ Value *X;
+ if (match(BinOp0, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) {
+ // trunc (binop (ext X), Y) --> binop X, (trunc Y)
+ Value *NarrowOp1 = Builder.CreateTrunc(BinOp1, DestTy);
+ return BinaryOperator::Create(BinOp->getOpcode(), X, NarrowOp1);
+ }
+ if (match(BinOp1, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) {
+ // trunc (binop Y, (ext X)) --> binop (trunc Y), X
+ Value *NarrowOp0 = Builder.CreateTrunc(BinOp0, DestTy);
+ return BinaryOperator::Create(BinOp->getOpcode(), NarrowOp0, X);
+ }
+ break;
+ }
+
+ default: break;
+ }
+
+ if (Instruction *NarrowOr = narrowRotate(Trunc))
+ return NarrowOr;
+
+ return nullptr;
}
/// Try to narrow the width of a splat shuffle. This could be generalized to any
@@ -616,7 +743,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
}
}
- if (Instruction *I = shrinkBitwiseLogic(CI))
+ if (Instruction *I = narrowBinOp(CI))
return I;
if (Instruction *I = shrinkSplatShuffle(CI, Builder))
@@ -655,13 +782,13 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
// If we are just checking for a icmp eq of a single bit and zext'ing it
// to an integer, then shift the bit to the appropriate place and then
// cast to integer to avoid the comparison.
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(ICI->getOperand(1))) {
- const APInt &Op1CV = Op1C->getValue();
+ const APInt *Op1CV;
+ if (match(ICI->getOperand(1), m_APInt(Op1CV))) {
// zext (x <s 0) to i32 --> x>>u31 true if signbit set.
// zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear.
- if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV.isNullValue()) ||
- (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) {
+ if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) ||
+ (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) {
if (!DoTransform) return ICI;
Value *In = ICI->getOperand(0);
@@ -687,7 +814,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
// zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set.
// zext (X != 1) to i32 --> X^1 iff X has only the low bit set.
// zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
- if ((Op1CV.isNullValue() || Op1CV.isPowerOf2()) &&
+ if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) &&
// This only works for EQ and NE
ICI->isEquality()) {
// If Op1C some other power of two, convert:
@@ -698,12 +825,10 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
if (!DoTransform) return ICI;
bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE;
- if (!Op1CV.isNullValue() && (Op1CV != KnownZeroMask)) {
+ if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) {
// (X&4) == 2 --> false
// (X&4) != 2 --> true
- Constant *Res = ConstantInt::get(Type::getInt1Ty(CI.getContext()),
- isNE);
- Res = ConstantExpr::getZExt(Res, CI.getType());
+ Constant *Res = ConstantInt::get(CI.getType(), isNE);
return replaceInstUsesWith(CI, Res);
}
@@ -716,7 +841,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
In->getName() + ".lobit");
}
- if (!Op1CV.isNullValue() == isNE) { // Toggle the low bit.
+ if (!Op1CV->isNullValue() == isNE) { // Toggle the low bit.
Constant *One = ConstantInt::get(In->getType(), 1);
In = Builder.CreateXor(In, One);
}
@@ -833,17 +958,23 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
unsigned VSize = V->getType()->getScalarSizeInBits();
if (IC.MaskedValueIsZero(I->getOperand(1),
APInt::getHighBitsSet(VSize, BitsToClear),
- 0, CxtI))
+ 0, CxtI)) {
+ // If this is an And instruction and all of the BitsToClear are
+ // known to be zero we can reset BitsToClear.
+ if (Opc == Instruction::And)
+ BitsToClear = 0;
return true;
+ }
}
// Otherwise, we don't know how to analyze this BitsToClear case yet.
return false;
- case Instruction::Shl:
+ case Instruction::Shl: {
// We can promote shl(x, cst) if we can promote x. Since shl overwrites the
// upper bits we can reduce BitsToClear by the shift amount.
- if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ const APInt *Amt;
+ if (match(I->getOperand(1), m_APInt(Amt))) {
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
return false;
uint64_t ShiftAmt = Amt->getZExtValue();
@@ -851,10 +982,12 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
return true;
}
return false;
- case Instruction::LShr:
+ }
+ case Instruction::LShr: {
// We can promote lshr(x, cst) if we can promote x. This requires the
// ultimate 'and' to clear out the high zero bits we're clearing out though.
- if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ const APInt *Amt;
+ if (match(I->getOperand(1), m_APInt(Amt))) {
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
return false;
BitsToClear += Amt->getZExtValue();
@@ -864,6 +997,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
}
// Cannot promote variable LSHR.
return false;
+ }
case Instruction::Select:
if (!canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) ||
!canEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) ||