diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCasts.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCasts.cpp | 206 |
1 files changed, 170 insertions, 36 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index dfdfd3e9da84..178c8eaf2502 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -235,8 +235,8 @@ Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1, Type *MidTy = CI1->getDestTy(); Type *DstTy = CI2->getDestTy(); - Instruction::CastOps firstOp = Instruction::CastOps(CI1->getOpcode()); - Instruction::CastOps secondOp = Instruction::CastOps(CI2->getOpcode()); + Instruction::CastOps firstOp = CI1->getOpcode(); + Instruction::CastOps secondOp = CI2->getOpcode(); Type *SrcIntPtrTy = SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr; Type *MidIntPtrTy = @@ -346,29 +346,50 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, } break; } - case Instruction::Shl: + case Instruction::Shl: { // If we are truncating the result of this SHL, and if it's a shift of a // constant amount, we can always perform a SHL in a smaller type. - if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (CI->getLimitedValue(BitWidth) < BitWidth) + if (Amt->getLimitedValue(BitWidth) < BitWidth) return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } break; - case Instruction::LShr: + } + case Instruction::LShr: { // If this is a truncate of a logical shr, we can truncate it to a smaller // lshr iff we know that the bits we would otherwise be shifting in are // already zeros. - if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); if (IC.MaskedValueIsZero(I->getOperand(0), APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) && - CI->getLimitedValue(BitWidth) < BitWidth) { + Amt->getLimitedValue(BitWidth) < BitWidth) { return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } } break; + } + case Instruction::AShr: { + // If this is a truncate of an arithmetic shr, we can truncate it to a + // smaller ashr iff we know that all the bits from the sign bit of the + // original type and the sign bit of the truncate type are similar. + // TODO: It is enough to check that the bits we would be shifting in are + // similar to sign bit of the truncate type. + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { + uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + if (Amt->getLimitedValue(BitWidth) < BitWidth && + OrigBitWidth - BitWidth < + IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI)) + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); + } + break; + } case Instruction::Trunc: // trunc(trunc(x)) -> trunc(x) return true; @@ -443,24 +464,130 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt)); } -/// Try to narrow the width of bitwise logic instructions with constants. -Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) { +/// Rotate left/right may occur in a wider type than necessary because of type +/// promotion rules. Try to narrow all of the component instructions. +Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { + assert((isa<VectorType>(Trunc.getSrcTy()) || + shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) && + "Don't narrow to an illegal scalar type"); + + // First, find an or'd pair of opposite shifts with the same shifted operand: + // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1)) + Value *Or0, *Or1; + if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) + return nullptr; + + Value *ShVal, *ShAmt0, *ShAmt1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1))))) + return nullptr; + + auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); + auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); + if (ShiftOpcode0 == ShiftOpcode1) + return nullptr; + + // The shift amounts must add up to the narrow bit width. + Value *ShAmt; + bool SubIsOnLHS; + Type *DestTy = Trunc.getType(); + unsigned NarrowWidth = DestTy->getScalarSizeInBits(); + if (match(ShAmt0, + m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) { + ShAmt = ShAmt1; + SubIsOnLHS = true; + } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), + m_Specific(ShAmt0))))) { + ShAmt = ShAmt0; + SubIsOnLHS = false; + } else { + return nullptr; + } + + // The shifted value must have high zeros in the wide type. Typically, this + // will be a zext, but it could also be the result of an 'and' or 'shift'. + unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits(); + APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth); + if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc)) + return nullptr; + + // We have an unnecessarily wide rotate! + // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt)) + // Narrow it down to eliminate the zext/trunc: + // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1') + Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); + Value *NegShAmt = Builder.CreateNeg(NarrowShAmt); + + // Mask both shift amounts to ensure there's no UB from oversized shifts. + Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1); + Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC); + Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC); + + // Truncate the original value and use narrow ops. + Value *X = Builder.CreateTrunc(ShVal, DestTy); + Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt; + Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt; + Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0); + Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1); + return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1); +} + +/// Try to narrow the width of math or bitwise logic instructions by pulling a +/// truncate ahead of binary operators. +/// TODO: Transforms for truncated shifts should be moved into here. +Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) { Type *SrcTy = Trunc.getSrcTy(); Type *DestTy = Trunc.getType(); - if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) + if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) return nullptr; - BinaryOperator *LogicOp; - Constant *C; - if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(LogicOp))) || - !LogicOp->isBitwiseLogicOp() || - !match(LogicOp->getOperand(1), m_Constant(C))) + BinaryOperator *BinOp; + if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(BinOp)))) return nullptr; - // trunc (logic X, C) --> logic (trunc X, C') - Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); - Value *NarrowOp0 = Builder.CreateTrunc(LogicOp->getOperand(0), DestTy); - return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC); + Value *BinOp0 = BinOp->getOperand(0); + Value *BinOp1 = BinOp->getOperand(1); + switch (BinOp->getOpcode()) { + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: { + Constant *C; + if (match(BinOp0, m_Constant(C))) { + // trunc (binop C, X) --> binop (trunc C', X) + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *TruncX = Builder.CreateTrunc(BinOp1, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), NarrowC, TruncX); + } + if (match(BinOp1, m_Constant(C))) { + // trunc (binop X, C) --> binop (trunc X, C') + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *TruncX = Builder.CreateTrunc(BinOp0, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), TruncX, NarrowC); + } + Value *X; + if (match(BinOp0, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) { + // trunc (binop (ext X), Y) --> binop X, (trunc Y) + Value *NarrowOp1 = Builder.CreateTrunc(BinOp1, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), X, NarrowOp1); + } + if (match(BinOp1, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) { + // trunc (binop Y, (ext X)) --> binop (trunc Y), X + Value *NarrowOp0 = Builder.CreateTrunc(BinOp0, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), NarrowOp0, X); + } + break; + } + + default: break; + } + + if (Instruction *NarrowOr = narrowRotate(Trunc)) + return NarrowOr; + + return nullptr; } /// Try to narrow the width of a splat shuffle. This could be generalized to any @@ -616,7 +743,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { } } - if (Instruction *I = shrinkBitwiseLogic(CI)) + if (Instruction *I = narrowBinOp(CI)) return I; if (Instruction *I = shrinkSplatShuffle(CI, Builder)) @@ -655,13 +782,13 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(ICI->getOperand(1))) { - const APInt &Op1CV = Op1C->getValue(); + const APInt *Op1CV; + if (match(ICI->getOperand(1), m_APInt(Op1CV))) { // zext (x <s 0) to i32 --> x>>u31 true if signbit set. // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. - if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV.isNullValue()) || - (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) { + if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) || + (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) { if (!DoTransform) return ICI; Value *In = ICI->getOperand(0); @@ -687,7 +814,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - if ((Op1CV.isNullValue() || Op1CV.isPowerOf2()) && + if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) && // This only works for EQ and NE ICI->isEquality()) { // If Op1C some other power of two, convert: @@ -698,12 +825,10 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, if (!DoTransform) return ICI; bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; - if (!Op1CV.isNullValue() && (Op1CV != KnownZeroMask)) { + if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) { // (X&4) == 2 --> false // (X&4) != 2 --> true - Constant *Res = ConstantInt::get(Type::getInt1Ty(CI.getContext()), - isNE); - Res = ConstantExpr::getZExt(Res, CI.getType()); + Constant *Res = ConstantInt::get(CI.getType(), isNE); return replaceInstUsesWith(CI, Res); } @@ -716,7 +841,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, In->getName() + ".lobit"); } - if (!Op1CV.isNullValue() == isNE) { // Toggle the low bit. + if (!Op1CV->isNullValue() == isNE) { // Toggle the low bit. Constant *One = ConstantInt::get(In->getType(), 1); In = Builder.CreateXor(In, One); } @@ -833,17 +958,23 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, unsigned VSize = V->getType()->getScalarSizeInBits(); if (IC.MaskedValueIsZero(I->getOperand(1), APInt::getHighBitsSet(VSize, BitsToClear), - 0, CxtI)) + 0, CxtI)) { + // If this is an And instruction and all of the BitsToClear are + // known to be zero we can reset BitsToClear. + if (Opc == Instruction::And) + BitsToClear = 0; return true; + } } // Otherwise, we don't know how to analyze this BitsToClear case yet. return false; - case Instruction::Shl: + case Instruction::Shl: { // We can promote shl(x, cst) if we can promote x. Since shl overwrites the // upper bits we can reduce BitsToClear by the shift amount. - if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; uint64_t ShiftAmt = Amt->getZExtValue(); @@ -851,10 +982,12 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, return true; } return false; - case Instruction::LShr: + } + case Instruction::LShr: { // We can promote lshr(x, cst) if we can promote x. This requires the // ultimate 'and' to clear out the high zero bits we're clearing out though. - if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; BitsToClear += Amt->getZExtValue(); @@ -864,6 +997,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, } // Cannot promote variable LSHR. return false; + } case Instruction::Select: if (!canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) || !canEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) || |