diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCasts.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCasts.cpp | 126 |
1 files changed, 79 insertions, 47 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index fd59c3a7c0c3..1201ac196ec0 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -492,12 +492,19 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { } /// Rotate left/right may occur in a wider type than necessary because of type -/// promotion rules. Try to narrow all of the component instructions. +/// promotion rules. Try to narrow the inputs and convert to funnel shift. Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { assert((isa<VectorType>(Trunc.getSrcTy()) || shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) && "Don't narrow to an illegal scalar type"); + // Bail out on strange types. It is possible to handle some of these patterns + // even with non-power-of-2 sizes, but it is not a likely scenario. + Type *DestTy = Trunc.getType(); + unsigned NarrowWidth = DestTy->getScalarSizeInBits(); + if (!isPowerOf2_32(NarrowWidth)) + return nullptr; + // First, find an or'd pair of opposite shifts with the same shifted operand: // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1)) Value *Or0, *Or1; @@ -514,22 +521,38 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { if (ShiftOpcode0 == ShiftOpcode1) return nullptr; - // The shift amounts must add up to the narrow bit width. - Value *ShAmt; - bool SubIsOnLHS; - Type *DestTy = Trunc.getType(); - unsigned NarrowWidth = DestTy->getScalarSizeInBits(); - if (match(ShAmt0, - m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) { - ShAmt = ShAmt1; - SubIsOnLHS = true; - } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), - m_Specific(ShAmt0))))) { - ShAmt = ShAmt0; - SubIsOnLHS = false; - } else { + // Match the shift amount operands for a rotate pattern. This always matches + // a subtraction on the R operand. + auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * { + // The shift amounts may add up to the narrow bit width: + // (shl ShVal, L) | (lshr ShVal, Width - L) + if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) + return L; + + // The shift amount may be masked with negation: + // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1))) + Value *X; + unsigned Mask = Width - 1; + if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) && + match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))) + return X; + + // Same as above, but the shift amount may be extended after masking: + if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && + match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))) + return X; + return nullptr; + }; + + Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth); + bool SubIsOnLHS = false; + if (!ShAmt) { + ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth); + SubIsOnLHS = true; } + if (!ShAmt) + return nullptr; // The shifted value must have high zeros in the wide type. Typically, this // will be a zext, but it could also be the result of an 'and' or 'shift'. @@ -540,23 +563,15 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { // We have an unnecessarily wide rotate! // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt)) - // Narrow it down to eliminate the zext/trunc: - // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1') + // Narrow the inputs and convert to funnel shift intrinsic: + // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt)) Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); - Value *NegShAmt = Builder.CreateNeg(NarrowShAmt); - - // Mask both shift amounts to ensure there's no UB from oversized shifts. - Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1); - Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC); - Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC); - - // Truncate the original value and use narrow ops. Value *X = Builder.CreateTrunc(ShVal, DestTy); - Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt; - Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt; - Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0); - Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1); - return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1); + bool IsFshl = (!SubIsOnLHS && ShiftOpcode0 == BinaryOperator::Shl) || + (SubIsOnLHS && ShiftOpcode1 == BinaryOperator::Shl); + Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; + Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy); + return IntrinsicInst::Create(F, { X, X, NarrowShAmt }); } /// Try to narrow the width of math or bitwise logic instructions by pulling a @@ -706,12 +721,35 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (SimplifyDemandedInstructionBits(CI)) return &CI; - // Canonicalize trunc x to i1 -> (icmp ne (and x, 1), 0), likewise for vector. if (DestTy->getScalarSizeInBits() == 1) { - Constant *One = ConstantInt::get(SrcTy, 1); - Src = Builder.CreateAnd(Src, One); Value *Zero = Constant::getNullValue(Src->getType()); - return new ICmpInst(ICmpInst::ICMP_NE, Src, Zero); + if (DestTy->isIntegerTy()) { + // Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only). + // TODO: We canonicalize to more instructions here because we are probably + // lacking equivalent analysis for trunc relative to icmp. There may also + // be codegen concerns. If those trunc limitations were removed, we could + // remove this transform. + Value *And = Builder.CreateAnd(Src, ConstantInt::get(SrcTy, 1)); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + + // For vectors, we do not canonicalize all truncs to icmp, so optimize + // patterns that would be covered within visitICmpInst. + Value *X; + const APInt *C; + if (match(Src, m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0 + APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C); + Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC)); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_APInt(C)), + m_Deferred(X))))) { + // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0 + APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C) | 1; + Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC)); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } } // FIXME: Maybe combine the next two transforms to handle the no cast case @@ -1061,12 +1099,9 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { Value *Src = CI.getOperand(0); Type *SrcTy = Src->getType(), *DestTy = CI.getType(); - // Attempt to extend the entire input expression tree to the destination - // type. Only do this if the dest type is a simple type, don't convert the - // expression tree to something weird like i93 unless the source is also - // strange. + // Try to extend the entire expression tree to the wide destination type. unsigned BitsToClear; - if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && + if (shouldChangeType(SrcTy, DestTy) && canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { assert(BitsToClear <= SrcTy->getScalarSizeInBits() && "Can't clear more bits than in SrcTy"); @@ -1343,12 +1378,8 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { return replaceInstUsesWith(CI, ZExt); } - // Attempt to extend the entire input expression tree to the destination - // type. Only do this if the dest type is a simple type, don't convert the - // expression tree to something weird like i93 unless the source is also - // strange. - if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && - canEvaluateSExtd(Src, DestTy)) { + // Try to extend the entire expression tree to the wide destination type. + if (shouldChangeType(SrcTy, DestTy) && canEvaluateSExtd(Src, DestTy)) { // Okay, we can transform this! Insert the new expression now. LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" @@ -1589,8 +1620,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { } // (fptrunc (fneg x)) -> (fneg (fptrunc x)) - if (BinaryOperator::isFNeg(OpI)) { - Value *InnerTrunc = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + Value *X; + if (match(OpI, m_FNeg(m_Value(X)))) { + Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty); return BinaryOperator::CreateFNegFMF(InnerTrunc, OpI); } } |