Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 484 |
1 file changed, 293 insertions, 191 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 71b7f279e5fa5..3639edb5df4d1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -85,16 +85,16 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI) { PointerType *PTy = cast<PointerType>(CI.getType()); - BuilderTy AllocaBuilder(Builder); - AllocaBuilder.SetInsertPoint(&AI); + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(&AI); // Get the type really allocated and the type casted to. Type *AllocElTy = AI.getAllocatedType(); Type *CastElTy = PTy->getElementType(); if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr; - unsigned AllocElTyAlign = DL.getABITypeAlignment(AllocElTy); - unsigned CastElTyAlign = DL.getABITypeAlignment(CastElTy); + Align AllocElTyAlign = DL.getABITypeAlign(AllocElTy); + Align CastElTyAlign = DL.getABITypeAlign(CastElTy); if (CastElTyAlign < AllocElTyAlign) return nullptr; // If the allocation has multiple uses, only promote it if we are strictly @@ -131,17 +131,17 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, } else { Amt = ConstantInt::get(AI.getArraySize()->getType(), Scale); // Insert before the alloca, not before the cast. - Amt = AllocaBuilder.CreateMul(Amt, NumElements); + Amt = Builder.CreateMul(Amt, NumElements); } if (uint64_t Offset = (AllocElTySize*ArrayOffset)/CastElTySize) { Value *Off = ConstantInt::get(AI.getArraySize()->getType(), Offset, true); - Amt = AllocaBuilder.CreateAdd(Amt, Off); + Amt = Builder.CreateAdd(Amt, Off); } - AllocaInst *New = AllocaBuilder.CreateAlloca(CastElTy, Amt); - New->setAlignment(MaybeAlign(AI.getAlignment())); + AllocaInst *New = Builder.CreateAlloca(CastElTy, Amt); + New->setAlignment(AI.getAlign()); New->takeName(&AI); New->setUsedWithInAlloca(AI.isUsedWithInAlloca()); @@ -151,8 +151,9 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, if (!AI.hasOneUse()) { // New is the allocation instruction, pointer typed. AI is the original // allocation instruction, also pointer typed. Thus, cast to use is BitCast. - Value *NewCast = AllocaBuilder.CreateBitCast(New, AI.getType(), "tmpcast"); + Value *NewCast = Builder.CreateBitCast(New, AI.getType(), "tmpcast"); replaceInstUsesWith(AI, NewCast); + eraseInstFromFunction(AI); } return replaceInstUsesWith(CI, New); } @@ -164,9 +165,7 @@ Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty, if (Constant *C = dyn_cast<Constant>(V)) { C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); // If we got a constantexpr back, try to simplify it with DL info. - if (Constant *FoldedC = ConstantFoldConstant(C, DL, &TLI)) - C = FoldedC; - return C; + return ConstantFoldConstant(C, DL, &TLI); } // Otherwise, it must be an instruction. @@ -276,16 +275,20 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { } if (auto *Sel = dyn_cast<SelectInst>(Src)) { - // We are casting a select. Try to fold the cast into the select, but only - // if the select does not have a compare instruction with matching operand - // types. Creating a select with operands that are different sizes than its + // We are casting a select. Try to fold the cast into the select if the + // select does not have a compare instruction with matching operand types + // or the select is likely better done in a narrow type. 
+ // Creating a select with operands that are different sizes than its // condition may inhibit other folds and lead to worse codegen. auto *Cmp = dyn_cast<CmpInst>(Sel->getCondition()); - if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType()) + if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() || + (CI.getOpcode() == Instruction::Trunc && + shouldChangeType(CI.getSrcTy(), CI.getType()))) { if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { replaceAllDbgUsesWith(*Sel, *NV, CI, DT); return NV; } + } } // If we are casting a PHI, then fold the cast into the PHI. @@ -293,7 +296,7 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { // Don't do this if it would create a PHI node with an illegal type from a // legal type. if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() || - shouldChangeType(CI.getType(), Src->getType())) + shouldChangeType(CI.getSrcTy(), CI.getType())) if (Instruction *NV = foldOpIntoPhi(CI, PN)) return NV; } @@ -374,29 +377,31 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, break; } case Instruction::Shl: { - // If we are truncating the result of this SHL, and if it's a shift of a - // constant amount, we can always perform a SHL in a smaller type. - const APInt *Amt; - if (match(I->getOperand(1), m_APInt(Amt))) { - uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (Amt->getLimitedValue(BitWidth) < BitWidth) - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); - } + // If we are truncating the result of this SHL, and if it's a shift of an + // inrange amount, we can always perform a SHL in a smaller type. + uint32_t BitWidth = Ty->getScalarSizeInBits(); + KnownBits AmtKnownBits = + llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); + if (AmtKnownBits.getMaxValue().ult(BitWidth)) + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); break; } case Instruction::LShr: { // If this is a truncate of a logical shr, we can truncate it to a smaller // lshr iff we know that the bits we would otherwise be shifting in are // already zeros. - const APInt *Amt; - if (match(I->getOperand(1), m_APInt(Amt))) { - uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); - uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (Amt->getLimitedValue(BitWidth) < BitWidth && - IC.MaskedValueIsZero(I->getOperand(0), - APInt::getBitsSetFrom(OrigBitWidth, BitWidth), 0, CxtI)) { - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); - } + // TODO: It is enough to check that the bits we would be shifting in are + // zero - use AmtKnownBits.getMaxValue(). + uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + KnownBits AmtKnownBits = + llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); + APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + if (AmtKnownBits.getMaxValue().ult(BitWidth) && + IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) { + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); } break; } @@ -406,15 +411,15 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, // original type and the sign bit of the truncate type are similar. // TODO: It is enough to check that the bits we would be shifting in are // similar to sign bit of the truncate type. 
- const APInt *Amt; - if (match(I->getOperand(1), m_APInt(Amt))) { - uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); - uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (Amt->getLimitedValue(BitWidth) < BitWidth && - OrigBitWidth - BitWidth < - IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI)) - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); - } + uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + KnownBits AmtKnownBits = + llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); + unsigned ShiftedBits = OrigBitWidth - BitWidth; + if (AmtKnownBits.getMaxValue().ult(BitWidth) && + ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI)) + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); break; } case Instruction::Trunc: @@ -480,7 +485,7 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { // bitcast it to a vector type that we can extract from. unsigned NumVecElts = VecWidth / DestWidth; if (VecType->getElementType() != DestType) { - VecType = VectorType::get(DestType, NumVecElts); + VecType = FixedVectorType::get(DestType, NumVecElts); VecInput = IC.Builder.CreateBitCast(VecInput, VecType, "bc"); } @@ -639,12 +644,12 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc, InstCombiner::BuilderTy &Builder) { auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0)); if (Shuf && Shuf->hasOneUse() && isa<UndefValue>(Shuf->getOperand(1)) && - Shuf->getMask()->getSplatValue() && + is_splat(Shuf->getShuffleMask()) && Shuf->getType() == Shuf->getOperand(0)->getType()) { // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask Constant *NarrowUndef = UndefValue::get(Trunc.getType()); Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType()); - return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getMask()); + return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getShuffleMask()); } return nullptr; @@ -682,29 +687,51 @@ static Instruction *shrinkInsertElt(CastInst &Trunc, return nullptr; } -Instruction *InstCombiner::visitTrunc(TruncInst &CI) { - if (Instruction *Result = commonCastTransforms(CI)) +Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) { + if (Instruction *Result = commonCastTransforms(Trunc)) return Result; - Value *Src = CI.getOperand(0); - Type *DestTy = CI.getType(), *SrcTy = Src->getType(); + Value *Src = Trunc.getOperand(0); + Type *DestTy = Trunc.getType(), *SrcTy = Src->getType(); + unsigned DestWidth = DestTy->getScalarSizeInBits(); + unsigned SrcWidth = SrcTy->getScalarSizeInBits(); + ConstantInt *Cst; // Attempt to truncate the entire input expression tree to the destination // type. Only do this if the dest type is a simple type, don't convert the // expression tree to something weird like i93 unless the source is also // strange. if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && - canEvaluateTruncated(Src, DestTy, *this, &CI)) { + canEvaluateTruncated(Src, DestTy, *this, &Trunc)) { // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. 
LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" " to avoid cast: " - << CI << '\n'); + << Trunc << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); - return replaceInstUsesWith(CI, Res); + return replaceInstUsesWith(Trunc, Res); + } + + // For integer types, check if we can shorten the entire input expression to + // DestWidth * 2, which won't allow removing the truncate, but reducing the + // width may enable further optimizations, e.g. allowing for larger + // vectorization factors. + if (auto *DestITy = dyn_cast<IntegerType>(DestTy)) { + if (DestWidth * 2 < SrcWidth) { + auto *NewDestTy = DestITy->getExtendedType(); + if (shouldChangeType(SrcTy, NewDestTy) && + canEvaluateTruncated(Src, NewDestTy, *this, &Trunc)) { + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to reduce the width of operand of" + << Trunc << '\n'); + Value *Res = EvaluateInDifferentType(Src, NewDestTy, false); + return new TruncInst(Res, DestTy); + } + } } // Test if the trunc is the user of a select which is part of a @@ -712,17 +739,17 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // Even simplifying demanded bits can break the canonical form of a // min/max. Value *LHS, *RHS; - if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0))) - if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN) + if (SelectInst *Sel = dyn_cast<SelectInst>(Src)) + if (matchSelectPattern(Sel, LHS, RHS).Flavor != SPF_UNKNOWN) return nullptr; // See if we can simplify any instructions used by the input whose sole // purpose is to compute bits we don't care about. - if (SimplifyDemandedInstructionBits(CI)) - return &CI; + if (SimplifyDemandedInstructionBits(Trunc)) + return &Trunc; - if (DestTy->getScalarSizeInBits() == 1) { - Value *Zero = Constant::getNullValue(Src->getType()); + if (DestWidth == 1) { + Value *Zero = Constant::getNullValue(SrcTy); if (DestTy->isIntegerTy()) { // Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only). // TODO: We canonicalize to more instructions here because we are probably @@ -736,18 +763,21 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // For vectors, we do not canonicalize all truncs to icmp, so optimize // patterns that would be covered within visitICmpInst. 
Value *X; - const APInt *C; - if (match(Src, m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + Constant *C; + if (match(Src, m_OneUse(m_LShr(m_Value(X), m_Constant(C))))) { // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0 - APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C); - Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC)); + Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); + Constant *MaskC = ConstantExpr::getShl(One, C); + Value *And = Builder.CreateAnd(X, MaskC); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } - if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_APInt(C)), + if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_Constant(C)), m_Deferred(X))))) { // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0 - APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C) | 1; - Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC)); + Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); + Constant *MaskC = ConstantExpr::getShl(One, C); + MaskC = ConstantExpr::getOr(MaskC, One); + Value *And = Builder.CreateAnd(X, MaskC); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } } @@ -756,7 +786,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // more efficiently. Support vector types. Cleanup code by using m_OneUse. // Transform trunc(lshr (zext A), Cst) to eliminate one type conversion. - Value *A = nullptr; ConstantInt *Cst = nullptr; + Value *A = nullptr; if (Src->hasOneUse() && match(Src, m_LShr(m_ZExt(m_Value(A)), m_ConstantInt(Cst)))) { // We have three types to worry about here, the type of A, the source of @@ -768,7 +798,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // If the shift amount is larger than the size of A, then the result is // known to be zero because all the input bits got shifted out. if (Cst->getZExtValue() >= ASize) - return replaceInstUsesWith(CI, Constant::getNullValue(DestTy)); + return replaceInstUsesWith(Trunc, Constant::getNullValue(DestTy)); // Since we're doing an lshr and a zero extend, and know that the shift // amount is smaller than ASize, it is always safe to do the shift in A's @@ -778,45 +808,37 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { return CastInst::CreateIntegerCast(Shift, DestTy, false); } - // FIXME: We should canonicalize to zext/trunc and remove this transform. - // Transform trunc(lshr (sext A), Cst) to ashr A, Cst to eliminate type - // conversion. - // It works because bits coming from sign extension have the same value as - // the sign bit of the original value; performing ashr instead of lshr - // generates bits of the same value as the sign bit. - if (Src->hasOneUse() && - match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst)))) { - Value *SExt = cast<Instruction>(Src)->getOperand(0); - const unsigned SExtSize = SExt->getType()->getPrimitiveSizeInBits(); - const unsigned ASize = A->getType()->getPrimitiveSizeInBits(); - const unsigned CISize = CI.getType()->getPrimitiveSizeInBits(); - const unsigned MaxAmt = SExtSize - std::max(CISize, ASize); - unsigned ShiftAmt = Cst->getZExtValue(); - - // This optimization can be only performed when zero bits generated by - // the original lshr aren't pulled into the value after truncation, so we - // can only shift by values no larger than the number of extension bits. - // FIXME: Instead of bailing when the shift is too large, use and to clear - // the extra bits. 
- if (ShiftAmt <= MaxAmt) { - if (CISize == ASize) - return BinaryOperator::CreateAShr(A, ConstantInt::get(CI.getType(), - std::min(ShiftAmt, ASize - 1))); - if (SExt->hasOneUse()) { - Value *Shift = Builder.CreateAShr(A, std::min(ShiftAmt, ASize - 1)); - Shift->takeName(Src); - return CastInst::CreateIntegerCast(Shift, CI.getType(), true); + const APInt *C; + if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C)))) { + unsigned AWidth = A->getType()->getScalarSizeInBits(); + unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth); + + // If the shift is small enough, all zero bits created by the shift are + // removed by the trunc. + if (C->getZExtValue() <= MaxShiftAmt) { + // trunc (lshr (sext A), C) --> ashr A, C + if (A->getType() == DestTy) { + unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1); + return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt)); + } + // The types are mismatched, so create a cast after shifting: + // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C) + if (Src->hasOneUse()) { + unsigned ShAmt = std::min((unsigned)C->getZExtValue(), AWidth - 1); + Value *Shift = Builder.CreateAShr(A, ShAmt); + return CastInst::CreateIntegerCast(Shift, DestTy, true); } } + // TODO: Mask high bits with 'and'. } - if (Instruction *I = narrowBinOp(CI)) + if (Instruction *I = narrowBinOp(Trunc)) return I; - if (Instruction *I = shrinkSplatShuffle(CI, Builder)) + if (Instruction *I = shrinkSplatShuffle(Trunc, Builder)) return I; - if (Instruction *I = shrinkInsertElt(CI, Builder)) + if (Instruction *I = shrinkInsertElt(Trunc, Builder)) return I; if (Src->hasOneUse() && isa<IntegerType>(SrcTy) && @@ -827,20 +849,48 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { !match(A, m_Shr(m_Value(), m_Constant()))) { // Skip shifts of shift by constants. It undoes a combine in // FoldShiftByConstant and is the extend in reg pattern. - const unsigned DestSize = DestTy->getScalarSizeInBits(); - if (Cst->getValue().ult(DestSize)) { + if (Cst->getValue().ult(DestWidth)) { Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr"); return BinaryOperator::Create( Instruction::Shl, NewTrunc, - ConstantInt::get(DestTy, Cst->getValue().trunc(DestSize))); + ConstantInt::get(DestTy, Cst->getValue().trunc(DestWidth))); } } } - if (Instruction *I = foldVecTruncToExtElt(CI, *this)) + if (Instruction *I = foldVecTruncToExtElt(Trunc, *this)) return I; + // Whenever an element is extracted from a vector, and then truncated, + // canonicalize by converting it to a bitcast followed by an + // extractelement. + // + // Example (little endian): + // trunc (extractelement <4 x i64> %X, 0) to i32 + // ---> + // extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0 + Value *VecOp; + if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) { + auto *VecOpTy = cast<VectorType>(VecOp->getType()); + unsigned VecNumElts = VecOpTy->getNumElements(); + + // A badly fit destination size would result in an invalid cast. + if (SrcWidth % DestWidth == 0) { + uint64_t TruncRatio = SrcWidth / DestWidth; + uint64_t BitCastNumElts = VecNumElts * TruncRatio; + uint64_t VecOpIdx = Cst->getZExtValue(); + uint64_t NewIdx = DL.isBigEndian() ? 
(VecOpIdx + 1) * TruncRatio - 1 + : VecOpIdx * TruncRatio; + assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() && + "overflow 32-bits"); + + auto *BitCastTo = FixedVectorType::get(DestTy, BitCastNumElts); + Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo); + return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx)); + } + } + return nullptr; } @@ -1431,16 +1481,17 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // %d = ashr i32 %a, 30 Value *A = nullptr; // TODO: Eventually this could be subsumed by EvaluateInDifferentType. - ConstantInt *BA = nullptr, *CA = nullptr; - if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_ConstantInt(BA)), - m_ConstantInt(CA))) && + Constant *BA = nullptr, *CA = nullptr; + if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)), + m_Constant(CA))) && BA == CA && A->getType() == CI.getType()) { unsigned MidSize = Src->getType()->getScalarSizeInBits(); unsigned SrcDstSize = CI.getType()->getScalarSizeInBits(); - unsigned ShAmt = CA->getZExtValue()+SrcDstSize-MidSize; - Constant *ShAmtV = ConstantInt::get(CI.getType(), ShAmt); - A = Builder.CreateShl(A, ShAmtV, CI.getName()); - return BinaryOperator::CreateAShr(A, ShAmtV); + Constant *SizeDiff = ConstantInt::get(CA->getType(), SrcDstSize - MidSize); + Constant *ShAmt = ConstantExpr::getAdd(CA, SizeDiff); + Constant *ShAmtExt = ConstantExpr::getSExt(ShAmt, CI.getType()); + A = Builder.CreateShl(A, ShAmtExt, CI.getName()); + return BinaryOperator::CreateAShr(A, ShAmtExt); } return nullptr; @@ -1478,12 +1529,13 @@ static Type *shrinkFPConstant(ConstantFP *CFP) { // TODO: Make these support undef elements. static Type *shrinkFPConstantVector(Value *V) { auto *CV = dyn_cast<Constant>(V); - if (!CV || !CV->getType()->isVectorTy()) + auto *CVVTy = dyn_cast<VectorType>(V->getType()); + if (!CV || !CVVTy) return nullptr; Type *MinType = nullptr; - unsigned NumElts = CV->getType()->getVectorNumElements(); + unsigned NumElts = CVVTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); if (!CFP) @@ -1500,7 +1552,7 @@ static Type *shrinkFPConstantVector(Value *V) { } // Make a vector type from the minimal type. - return VectorType::get(MinType, NumElts); + return FixedVectorType::get(MinType, NumElts); } /// Find the minimum FP type we can safely truncate to. @@ -1522,6 +1574,48 @@ static Type *getMinimumFPType(Value *V) { return V->getType(); } +/// Return true if the cast from integer to FP can be proven to be exact for all +/// possible inputs (the conversion does not lose any precision). +static bool isKnownExactCastIntToFP(CastInst &I) { + CastInst::CastOps Opcode = I.getOpcode(); + assert((Opcode == CastInst::SIToFP || Opcode == CastInst::UIToFP) && + "Unexpected cast"); + Value *Src = I.getOperand(0); + Type *SrcTy = Src->getType(); + Type *FPTy = I.getType(); + bool IsSigned = Opcode == Instruction::SIToFP; + int SrcSize = (int)SrcTy->getScalarSizeInBits() - IsSigned; + + // Easy case - if the source integer type has less bits than the FP mantissa, + // then the cast must be exact. + int DestNumSigBits = FPTy->getFPMantissaWidth(); + if (SrcSize <= DestNumSigBits) + return true; + + // Cast from FP to integer and back to FP is independent of the intermediate + // integer width because of poison on overflow. 
+ Value *F; + if (match(Src, m_FPToSI(m_Value(F))) || match(Src, m_FPToUI(m_Value(F)))) { + // If this is uitofp (fptosi F), the source needs an extra bit to avoid + // potential rounding of negative FP input values. + int SrcNumSigBits = F->getType()->getFPMantissaWidth(); + if (!IsSigned && match(Src, m_FPToSI(m_Value()))) + SrcNumSigBits++; + + // [su]itofp (fpto[su]i F) --> exact if the source type has less or equal + // significant bits than the destination (and make sure neither type is + // weird -- ppc_fp128). + if (SrcNumSigBits > 0 && DestNumSigBits > 0 && + SrcNumSigBits <= DestNumSigBits) + return true; + } + + // TODO: + // Try harder to find if the source integer type has less significant bits. + // For example, compute number of sign bits or compute low bit mask. + return false; +} + Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { if (Instruction *I = commonCastTransforms(FPT)) return I; @@ -1632,10 +1726,6 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { if (match(Op, m_FNeg(m_Value(X)))) { Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty); - // FIXME: Once we're sure that unary FNeg optimizations are on par with - // binary FNeg, this should always return a unary operator. - if (isa<BinaryOperator>(Op)) - return BinaryOperator::CreateFNegFMF(InnerTrunc, Op); return UnaryOperator::CreateFNegFMF(InnerTrunc, Op); } @@ -1667,6 +1757,7 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { case Intrinsic::nearbyint: case Intrinsic::rint: case Intrinsic::round: + case Intrinsic::roundeven: case Intrinsic::trunc: { Value *Src = II->getArgOperand(0); if (!Src->hasOneUse()) @@ -1699,74 +1790,83 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { if (Instruction *I = shrinkInsertElt(FPT, Builder)) return I; + Value *Src = FPT.getOperand(0); + if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) { + auto *FPCast = cast<CastInst>(Src); + if (isKnownExactCastIntToFP(*FPCast)) + return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty); + } + return nullptr; } -Instruction *InstCombiner::visitFPExt(CastInst &CI) { - return commonCastTransforms(CI); +Instruction *InstCombiner::visitFPExt(CastInst &FPExt) { + // If the source operand is a cast from integer to FP and known exact, then + // cast the integer operand directly to the destination type. + Type *Ty = FPExt.getType(); + Value *Src = FPExt.getOperand(0); + if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) { + auto *FPCast = cast<CastInst>(Src); + if (isKnownExactCastIntToFP(*FPCast)) + return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty); + } + + return commonCastTransforms(FPExt); } -// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X) -// This is safe if the intermediate type has enough bits in its mantissa to -// accurately represent all values of X. For example, this won't work with -// i64 -> float -> i64. -Instruction *InstCombiner::FoldItoFPtoI(Instruction &FI) { +/// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X) +/// This is safe if the intermediate type has enough bits in its mantissa to +/// accurately represent all values of X. For example, this won't work with +/// i64 -> float -> i64. 
+Instruction *InstCombiner::foldItoFPtoI(CastInst &FI) { if (!isa<UIToFPInst>(FI.getOperand(0)) && !isa<SIToFPInst>(FI.getOperand(0))) return nullptr; - Instruction *OpI = cast<Instruction>(FI.getOperand(0)); - Value *SrcI = OpI->getOperand(0); - Type *FITy = FI.getType(); - Type *OpITy = OpI->getType(); - Type *SrcTy = SrcI->getType(); - bool IsInputSigned = isa<SIToFPInst>(OpI); + auto *OpI = cast<CastInst>(FI.getOperand(0)); + Value *X = OpI->getOperand(0); + Type *XType = X->getType(); + Type *DestType = FI.getType(); bool IsOutputSigned = isa<FPToSIInst>(FI); - // We can safely assume the conversion won't overflow the output range, - // because (for example) (uint8_t)18293.f is undefined behavior. - // Since we can assume the conversion won't overflow, our decision as to // whether the input will fit in the float should depend on the minimum // of the input range and output range. // This means this is also safe for a signed input and unsigned output, since // a negative input would lead to undefined behavior. - int InputSize = (int)SrcTy->getScalarSizeInBits() - IsInputSigned; - int OutputSize = (int)FITy->getScalarSizeInBits() - IsOutputSigned; - int ActualSize = std::min(InputSize, OutputSize); - - if (ActualSize <= OpITy->getFPMantissaWidth()) { - if (FITy->getScalarSizeInBits() > SrcTy->getScalarSizeInBits()) { - if (IsInputSigned && IsOutputSigned) - return new SExtInst(SrcI, FITy); - return new ZExtInst(SrcI, FITy); - } - if (FITy->getScalarSizeInBits() < SrcTy->getScalarSizeInBits()) - return new TruncInst(SrcI, FITy); - if (SrcTy == FITy) - return replaceInstUsesWith(FI, SrcI); - return new BitCastInst(SrcI, FITy); + if (!isKnownExactCastIntToFP(*OpI)) { + // The first cast may not round exactly based on the source integer width + // and FP width, but the overflow UB rules can still allow this to fold. + // If the destination type is narrow, that means the intermediate FP value + // must be large enough to hold the source value exactly. + // For example, (uint8_t)((float)(uint32_t 16777217) is undefined behavior. 
+ int OutputSize = (int)DestType->getScalarSizeInBits() - IsOutputSigned; + if (OutputSize > OpI->getType()->getFPMantissaWidth()) + return nullptr; } - return nullptr; + + if (DestType->getScalarSizeInBits() > XType->getScalarSizeInBits()) { + bool IsInputSigned = isa<SIToFPInst>(OpI); + if (IsInputSigned && IsOutputSigned) + return new SExtInst(X, DestType); + return new ZExtInst(X, DestType); + } + if (DestType->getScalarSizeInBits() < XType->getScalarSizeInBits()) + return new TruncInst(X, DestType); + + assert(XType == DestType && "Unexpected types for int to FP to int casts"); + return replaceInstUsesWith(FI, X); } Instruction *InstCombiner::visitFPToUI(FPToUIInst &FI) { - Instruction *OpI = dyn_cast<Instruction>(FI.getOperand(0)); - if (!OpI) - return commonCastTransforms(FI); - - if (Instruction *I = FoldItoFPtoI(FI)) + if (Instruction *I = foldItoFPtoI(FI)) return I; return commonCastTransforms(FI); } Instruction *InstCombiner::visitFPToSI(FPToSIInst &FI) { - Instruction *OpI = dyn_cast<Instruction>(FI.getOperand(0)); - if (!OpI) - return commonCastTransforms(FI); - - if (Instruction *I = FoldItoFPtoI(FI)) + if (Instruction *I = foldItoFPtoI(FI)) return I; return commonCastTransforms(FI); @@ -1788,8 +1888,9 @@ Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) { if (CI.getOperand(0)->getType()->getScalarSizeInBits() != DL.getPointerSizeInBits(AS)) { Type *Ty = DL.getIntPtrType(CI.getContext(), AS); - if (CI.getType()->isVectorTy()) // Handle vectors of pointers. - Ty = VectorType::get(Ty, CI.getType()->getVectorNumElements()); + // Handle vectors of pointers. + if (auto *CIVTy = dyn_cast<VectorType>(CI.getType())) + Ty = VectorType::get(Ty, CIVTy->getElementCount()); Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty); return new IntToPtrInst(P, CI.getType()); @@ -1817,9 +1918,7 @@ Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) { // Changing the cast operand is usually not a good idea but it is safe // here because the pointer operand is being replaced with another // pointer operand so the opcode doesn't need to change. - Worklist.Add(GEP); - CI.setOperand(0, GEP->getOperand(0)); - return &CI; + return replaceOperand(CI, 0, GEP->getOperand(0)); } } @@ -1838,8 +1937,11 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) { return commonPointerCastTransforms(CI); Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS); - if (Ty->isVectorTy()) // Handle vectors of pointers. - PtrTy = VectorType::get(PtrTy, Ty->getVectorNumElements()); + if (auto *VTy = dyn_cast<VectorType>(Ty)) { + // Handle vectors of pointers. + // FIXME: what should happen for scalable vectors? + PtrTy = FixedVectorType::get(PtrTy, VTy->getNumElements()); + } Value *P = Builder.CreatePtrToInt(CI.getOperand(0), PtrTy); return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); @@ -1878,7 +1980,8 @@ static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal, DestTy->getElementType()->getPrimitiveSizeInBits()) return nullptr; - SrcTy = VectorType::get(DestTy->getElementType(), SrcTy->getNumElements()); + SrcTy = + FixedVectorType::get(DestTy->getElementType(), SrcTy->getNumElements()); InVal = IC.Builder.CreateBitCast(InVal, SrcTy); } @@ -1891,8 +1994,8 @@ static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal, // Now that the element types match, get the shuffle mask and RHS of the // shuffle to use, which depends on whether we're increasing or decreasing the // size of the input. 
- SmallVector<uint32_t, 16> ShuffleMaskStorage; - ArrayRef<uint32_t> ShuffleMask; + SmallVector<int, 16> ShuffleMaskStorage; + ArrayRef<int> ShuffleMask; Value *V2; // Produce an identify shuffle mask for the src vector. @@ -1931,9 +2034,7 @@ static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal, ShuffleMask = ShuffleMaskStorage; } - return new ShuffleVectorInst(InVal, V2, - ConstantDataVector::get(V2->getContext(), - ShuffleMask)); + return new ShuffleVectorInst(InVal, V2, ShuffleMask); } static bool isMultipleOfTypeSize(unsigned Value, Type *Ty) { @@ -2106,7 +2207,7 @@ static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, return nullptr; unsigned NumElts = ExtElt->getVectorOperandType()->getNumElements(); - auto *NewVecType = VectorType::get(DestType, NumElts); + auto *NewVecType = FixedVectorType::get(DestType, NumElts); auto *NewBC = IC.Builder.CreateBitCast(ExtElt->getVectorOperand(), NewVecType, "bc"); return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand()); @@ -2151,7 +2252,7 @@ static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, if (match(BO->getOperand(1), m_Constant(C))) { // bitcast (logic X, C) --> logic (bitcast X, C') Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy); - Value *CastedC = ConstantExpr::getBitCast(C, DestTy); + Value *CastedC = Builder.CreateBitCast(C, DestTy); return BinaryOperator::Create(BO->getOpcode(), CastedOp0, CastedC); } @@ -2169,10 +2270,10 @@ static Instruction *foldBitCastSelect(BitCastInst &BitCast, // A vector select must maintain the same number of elements in its operands. Type *CondTy = Cond->getType(); Type *DestTy = BitCast.getType(); - if (CondTy->isVectorTy()) { + if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) { if (!DestTy->isVectorTy()) return nullptr; - if (DestTy->getVectorNumElements() != CondTy->getVectorNumElements()) + if (cast<VectorType>(DestTy)->getNumElements() != CondVTy->getNumElements()) return nullptr; } @@ -2359,7 +2460,7 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { auto *NewBC = cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy)); SI->setOperand(0, NewBC); - Worklist.Add(SI); + Worklist.push(SI); assert(hasStoreUsersOnly(*NewBC)); } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { @@ -2395,8 +2496,9 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { if (DestTy == Src->getType()) return replaceInstUsesWith(CI, Src); - if (PointerType *DstPTy = dyn_cast<PointerType>(DestTy)) { + if (isa<PointerType>(SrcTy) && isa<PointerType>(DestTy)) { PointerType *SrcPTy = cast<PointerType>(SrcTy); + PointerType *DstPTy = cast<PointerType>(DestTy); Type *DstElTy = DstPTy->getElementType(); Type *SrcElTy = SrcPTy->getElementType(); @@ -2425,10 +2527,8 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep. // This can enhance SROA and other transforms that want type-safe pointers. 
unsigned NumZeros = 0; - while (SrcElTy != DstElTy && - isa<CompositeType>(SrcElTy) && !SrcElTy->isPointerTy() && - SrcElTy->getNumContainedTypes() /* not "{}" */) { - SrcElTy = cast<CompositeType>(SrcElTy)->getTypeAtIndex(0U); + while (SrcElTy && SrcElTy != DstElTy) { + SrcElTy = GetElementPtrInst::getTypeAtIndex(SrcElTy, (uint64_t)0); ++NumZeros; } @@ -2455,12 +2555,12 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { } } - if (VectorType *DestVTy = dyn_cast<VectorType>(DestTy)) { - if (DestVTy->getNumElements() == 1 && !SrcTy->isVectorTy()) { + if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(DestTy)) { + // Beware: messing with this target-specific oddity may cause trouble. + if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) { Value *Elem = Builder.CreateBitCast(Src, DestVTy->getElementType()); return InsertElementInst::Create(UndefValue::get(DestTy), Elem, Constant::getNullValue(Type::getInt32Ty(CI.getContext()))); - // FIXME: Canonicalize bitcast(insertelement) -> insertelement(bitcast) } if (isa<IntegerType>(SrcTy)) { @@ -2484,7 +2584,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { } } - if (VectorType *SrcVTy = dyn_cast<VectorType>(SrcTy)) { + if (FixedVectorType *SrcVTy = dyn_cast<FixedVectorType>(SrcTy)) { if (SrcVTy->getNumElements() == 1) { // If our destination is not a vector, then make this a straight // scalar-scalar cast. @@ -2508,10 +2608,11 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // a bitcast to a vector with the same # elts. Value *ShufOp0 = Shuf->getOperand(0); Value *ShufOp1 = Shuf->getOperand(1); - unsigned NumShufElts = Shuf->getType()->getVectorNumElements(); - unsigned NumSrcVecElts = ShufOp0->getType()->getVectorNumElements(); + unsigned NumShufElts = Shuf->getType()->getNumElements(); + unsigned NumSrcVecElts = + cast<VectorType>(ShufOp0->getType())->getNumElements(); if (Shuf->hasOneUse() && DestTy->isVectorTy() && - DestTy->getVectorNumElements() == NumShufElts && + cast<VectorType>(DestTy)->getNumElements() == NumShufElts && NumShufElts == NumSrcVecElts) { BitCastInst *Tmp; // If either of the operands is a cast from CI.getType(), then @@ -2525,7 +2626,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { Value *RHS = Builder.CreateBitCast(ShufOp1, DestTy); // Return a new shuffle vector. Use the same element ID's, as we // know the vector types match #elts. - return new ShuffleVectorInst(LHS, RHS, Shuf->getOperand(2)); + return new ShuffleVectorInst(LHS, RHS, Shuf->getShuffleMask()); } } @@ -2578,7 +2679,8 @@ Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) { Type *MidTy = PointerType::get(DestElemTy, SrcTy->getAddressSpace()); if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) { // Handle vectors of pointers. - MidTy = VectorType::get(MidTy, VT->getNumElements()); + // FIXME: what should happen for scalable vectors? + MidTy = FixedVectorType::get(MidTy, VT->getNumElements()); } Value *NewBitCast = Builder.CreateBitCast(Src, MidTy); |
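
A few worked examples for the transforms this patch touches follow; they are illustrative sketches in plain C++ (plus one LLVM-API snippet), not part of the commit itself.

The PromoteCastOfAllocation hunk drops the dedicated AllocaBuilder and instead scopes the shared Builder with an insert-point guard. A minimal sketch of that RAII pattern, using the same API the patch uses (the helper name and body are hypothetical and only show the save/restore behaviour):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

// Hypothetical helper: emit a replacement alloca in front of AI without
// disturbing whatever the caller's Builder currently points at.
static llvm::AllocaInst *emitAllocaBefore(llvm::IRBuilder<> &Builder,
                                          llvm::AllocaInst &AI,
                                          llvm::Type *CastElTy,
                                          llvm::Value *ArraySize) {
  // Saves the insertion point (and debug location); both are restored
  // automatically when Guard goes out of scope.
  llvm::IRBuilderBase::InsertPointGuard Guard(Builder);
  Builder.SetInsertPoint(&AI);                 // insert before the alloca
  llvm::AllocaInst *New = Builder.CreateAlloca(CastElTy, ArraySize);
  New->setAlignment(AI.getAlign());
  New->takeName(&AI);
  return New;
}                                              // insertion point restored here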
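
The canEvaluateTruncated() shift cases no longer require a constant shift amount; computeKnownBits() only has to prove the amount is smaller than the narrow width, and for lshr the bits that would be shifted in must additionally be known zero (for ashr, copies of the sign). The arithmetic fact behind this, shown for a u64 -> u16 narrowing (illustrative, not from the patch):

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t X = 0x123456789abcdef0ULL;
  for (unsigned A = 0; A < 16; ++A) {          // amount known to be < 16
    // shl: truncation commutes with the shift for any in-range amount,
    // because bits at or above position 16 can never reach the low 16 bits.
    assert(uint16_t(X << A) == uint16_t(uint16_t(X) << A));

    // lshr: also needs the shifted-in bits (bits [16, 16+A) of X) to be zero.
    const uint64_t Y = X & 0xffffULL;          // those bits are zero here
    assert(uint16_t(Y >> A) == uint16_t(uint16_t(Y) >> A));
  }
  // Counterexample for lshr when a shifted-in bit is set: bit 16 of 0x10004
  // lands in the low 16 bits after shifting right by 1.
  assert(uint16_t(0x10004ULL >> 1) != uint16_t(uint16_t(0x10004ULL) >> 1));
  return 0;
}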
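
The new i1 special cases in visitTrunc() turn a truncate-to-bool of a shifted value into a bit test: trunc(lshr X, C) to i1 becomes icmp ne (and X, 1<<C), 0, and the or-with-X variant also tests bit 0. A quick check of both equivalences (illustrative):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 1024; ++X) {
    for (unsigned C = 0; C < 32; ++C) {
      // trunc(lshr X, C) to i1   <=>   (X & (1 << C)) != 0
      bool Trunc = (X >> C) & 1;
      bool Icmp = (X & (1u << C)) != 0;
      assert(Trunc == Icmp);

      // trunc(or(lshr X, C), X) to i1   <=>   (X & ((1 << C) | 1)) != 0
      bool TruncOr = ((X >> C) | X) & 1;
      bool IcmpOr = (X & ((1u << C) | 1u)) != 0;
      assert(TruncOr == IcmpOr);
    }
  }
  return 0;
}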
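
The rewritten trunc(lshr(sext A), C) fold now accepts splat-vector constants and clamps the final shift amount to DestWidth - 1. The scalar identity it relies on, for i8 sign-extended to i32 and truncated back to i8 (illustrative; assumes two's-complement narrowing and arithmetic >>, which C++20 guarantees):

#include <cassert>
#include <cstdint>

int main() {
  for (int V = -128; V <= 127; ++V) {
    int8_t A = (int8_t)V;
    for (unsigned C = 0; C <= 24; ++C) {          // C <= SrcWidth - DestWidth
      uint32_t Wide = (uint32_t)(int32_t)A;       // sext i8 -> i32
      int8_t ViaTrunc = (int8_t)(Wide >> C);      // lshr i32 + trunc to i8
      unsigned ShAmt = C < 7 ? C : 7;             // clamp to DestWidth - 1
      int8_t ViaAShr = (int8_t)(A >> ShAmt);      // ashr i8
      assert(ViaTrunc == ViaAShr);
    }
  }
  return 0;
}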
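
The new trunc(extractelement) canonicalization bitcasts the source vector to the narrow element type and rescales the index by the truncation ratio, using (Idx + 1) * Ratio - 1 on big-endian targets. The little-endian mapping, checked with a plain in-memory reinterpretation (illustrative; assumes a little-endian host):

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // trunc(extractelement <4 x i64> %X, Idx) to i32
  //   --> extractelement(bitcast <4 x i64> %X to <8 x i32>), Idx * 2
  const uint64_t Wide[4] = {0x1111222233334444ULL, 0x5555666677778888ULL,
                            0x9999aaaabbbbccccULL, 0xddddeeeeffff0000ULL};
  uint32_t Narrow[8];
  std::memcpy(Narrow, Wide, sizeof(Wide));     // the "bitcast" view

  const uint64_t TruncRatio = 64 / 32;         // SrcWidth / DestWidth
  for (uint64_t Idx = 0; Idx < 4; ++Idx) {
    uint32_t Trunc = (uint32_t)Wide[Idx];      // trunc(extractelement %X, Idx)
    uint64_t NewIdx = Idx * TruncRatio;        // little-endian index
    assert(Trunc == Narrow[NewIdx]);
  }
  return 0;
}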
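
The visitSExt() change folds sext(ashr(shl(trunc A, C), C)) by performing both shifts in the wide type with an amount of C + (SrcDstSize - MidSize), and now also accepts splat-vector shift amounts. The scalar identity for i32 -> i8 -> i32 (illustrative; assumes two's-complement narrowing and arithmetic >>):

#include <cassert>
#include <cstdint>

int main() {
  for (int I = -100000; I <= 100000; I += 7) {
    for (unsigned C = 0; C < 8; ++C) {
      // sext(ashr(shl(trunc i32 I to i8), C), C) to i32 ...
      int8_t Trunc = (int8_t)I;
      int8_t Shl = (int8_t)((uint8_t)Trunc << C);    // shl i8 (done unsigned to
                                                     // avoid signed-overflow UB)
      int32_t Narrow = (int32_t)(int8_t)(Shl >> C);  // ashr i8 + sext

      // ... equals ashr(shl(I, C + 24), C + 24); both sides sign-extend the
      // low (8 - C) bits of I.
      unsigned Sh = C + 24;                          // C + (32 - 8)
      int32_t WideResult = (int32_t)((uint32_t)I << Sh) >> Sh;
      assert(Narrow == WideResult);
    }
  }
  return 0;
}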
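
isKnownExactCastIntToFP() treats an int-to-FP cast as exact when the source's value bits fit in the destination's significand, with one bit fewer needed for signed sources; fpext and fptrunc of such casts can then be folded into a direct conversion. The numeric facts behind the check, for float (24-bit significand) and double (53-bit):

#include <cassert>
#include <cfloat>
#include <cstdint>

static_assert(FLT_MANT_DIG == 24 && DBL_MANT_DIG == 53,
              "assumes IEEE-754 float and double");

int main() {
  // Any u32 converts exactly to double (32 <= 53) but not always to float:
  uint32_t U = (1u << 24) + 1;               // 16777217
  assert((double)U == 16777217.0);           // exact
  assert((float)U == 16777216.0f);           // rounded to the nearest float

  // Signed sources get one bit for free (the patch subtracts IsSigned):
  // every value of a signed 25-bit integer has magnitude <= 2^24 and is
  // therefore exact in float.
  int32_t S = -(1 << 24);                    // most negative i25 value
  assert((float)S == -16777216.0f);
  return 0;
}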
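
foldItoFPtoI() keeps folding fpto[su]i([us]itofp X) to X (or an ext/trunc of X) when min(input, output) bits fit in the intermediate FP type's significand; otherwise the inner conversion may round and the round trip is not an identity. For example:

#include <cassert>
#include <cstdint>

int main() {
  // i32 -> double -> i32 is always lossless (32 <= 53 significand bits),
  // so fptosi(sitofp X) can be replaced by X.
  int32_t A = -123456789;
  assert((int32_t)(double)A == A);

  // i32 -> float -> i32 is not: float's 24-bit significand rounds the
  // intermediate value, so this case must not be folded.
  int32_t B = (1 << 24) + 1;                 // 16777217
  assert((int32_t)(float)B == (1 << 24));    // comes back as 16777216
  return 0;
}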