diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 231 |
1 file changed, 188 insertions, 43 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index cc0a9127f8b18..d3c718a919c0a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -143,8 +143,7 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, // the XOR is to toggle the bit. If it is clear, then the ADD has // no effect. if ((AddRHS & AndRHSV).isNullValue()) { // Bit is not set, noop - TheAnd.setOperand(0, X); - return &TheAnd; + return replaceOperand(TheAnd, 0, X); } else { // Pull the XOR out of the AND. Value *NewAnd = Builder.CreateAnd(X, AndRHS); @@ -858,8 +857,10 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, - bool JoinedByAnd, - Instruction &CxtI) { + BinaryOperator &Logic) { + bool JoinedByAnd = Logic.getOpcode() == Instruction::And; + assert((JoinedByAnd || Logic.getOpcode() == Instruction::Or) && + "Wrong opcode"); ICmpInst::Predicate Pred = LHS->getPredicate(); if (Pred != RHS->getPredicate()) return nullptr; @@ -883,8 +884,8 @@ Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, std::swap(A, B); if (A == C && - isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) && - isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) { + isKnownToBeAPowerOfTwo(B, false, 0, &Logic) && + isKnownToBeAPowerOfTwo(D, false, 0, &Logic)) { Value *Mask = Builder.CreateOr(B, D); Value *Masked = Builder.CreateAnd(A, Mask); auto NewPred = JoinedByAnd ? 
ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; @@ -1072,9 +1073,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, m_c_ICmp(UnsignedPred, m_Specific(ZeroCmpOp), m_Value(A))) && match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) && (ZeroICmp->hasOneUse() || UnsignedICmp->hasOneUse())) { - if (UnsignedICmp->getOperand(0) != ZeroCmpOp) - UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); - auto GetKnownNonZeroAndOther = [&](Value *&NonZero, Value *&Other) { if (!IsKnownNonZero(NonZero)) std::swap(NonZero, Other); @@ -1111,8 +1109,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) || !ICmpInst::isUnsigned(UnsignedPred)) return nullptr; - if (UnsignedICmp->getOperand(0) != Base) - UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); // Base >=/> Offset && (Base - Offset) != 0 <--> Base > Offset // (no overflow and not null) @@ -1141,14 +1137,59 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, return nullptr; } +/// Reduce logic-of-compares with equality to a constant by substituting a +/// common operand with the constant. Callers are expected to call this with +/// Cmp0/Cmp1 switched to handle logic op commutativity. +static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, + BinaryOperator &Logic, + InstCombiner::BuilderTy &Builder, + const SimplifyQuery &Q) { + bool IsAnd = Logic.getOpcode() == Instruction::And; + assert((IsAnd || Logic.getOpcode() == Instruction::Or) && "Wrong logic op"); + + // Match an equality compare with a non-poison constant as Cmp0. + ICmpInst::Predicate Pred0; + Value *X; + Constant *C; + if (!match(Cmp0, m_ICmp(Pred0, m_Value(X), m_Constant(C))) || + !isGuaranteedNotToBeUndefOrPoison(C)) + return nullptr; + if ((IsAnd && Pred0 != ICmpInst::ICMP_EQ) || + (!IsAnd && Pred0 != ICmpInst::ICMP_NE)) + return nullptr; + + // The other compare must include a common operand (X). 
Canonicalize the + // common operand as operand 1 (Pred1 is swapped if the common operand was + // operand 0). + Value *Y; + ICmpInst::Predicate Pred1; + if (!match(Cmp1, m_c_ICmp(Pred1, m_Value(Y), m_Deferred(X)))) + return nullptr; + + // Replace variable with constant value equivalence to remove a variable use: + // (X == C) && (Y Pred1 X) --> (X == C) && (Y Pred1 C) + // (X != C) || (Y Pred1 X) --> (X != C) || (Y Pred1 C) + // Can think of the 'or' substitution with the 'and' bool equivalent: + // A || B --> A || (!A && B) + Value *SubstituteCmp = SimplifyICmpInst(Pred1, Y, C, Q); + if (!SubstituteCmp) { + // If we need to create a new instruction, require that the old compare can + // be removed. + if (!Cmp1->hasOneUse()) + return nullptr; + SubstituteCmp = Builder.CreateICmp(Pred1, Y, C); + } + return Builder.CreateBinOp(Logic.getOpcode(), Cmp0, SubstituteCmp); +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, - Instruction &CxtI) { - const SimplifyQuery Q = SQ.getWithInstruction(&CxtI); + BinaryOperator &And) { + const SimplifyQuery Q = SQ.getWithInstruction(&And); // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) // if K1 and K2 are a one-bit mask. - if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI)) + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, And)) return V; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); @@ -1171,6 +1212,11 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, true, Builder)) return V; + if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, And, Builder, Q)) + return V; + if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, And, Builder, Q)) + return V; + // E.g. 
(icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/false)) return V; @@ -1182,7 +1228,7 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder)) return V; - if (Value *V = foldSignedTruncationCheck(LHS, RHS, CxtI, Builder)) + if (Value *V = foldSignedTruncationCheck(LHS, RHS, And, Builder)) return V; if (Value *V = foldIsPowerOf2(LHS, RHS, true /* JoinedByAnd */, Builder)) @@ -1658,7 +1704,7 @@ static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) { if (C->getType()->isVectorTy()) { // Check each element of a constant vector. - unsigned NumElts = C->getType()->getVectorNumElements(); + unsigned NumElts = cast<VectorType>(C->getType())->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = C->getAggregateElement(i); if (!Elt) @@ -1802,7 +1848,17 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return BinaryOperator::Create(BinOp, NewLHS, Y); } } - + const APInt *ShiftC; + if (match(Op0, m_OneUse(m_SExt(m_AShr(m_Value(X), m_APInt(ShiftC)))))) { + unsigned Width = I.getType()->getScalarSizeInBits(); + if (*C == APInt::getLowBitsSet(Width, Width - ShiftC->getZExtValue())) { + // We are clearing high bits that were potentially set by sext+ashr: + // and (sext (ashr X, ShiftC)), C --> lshr (sext X), ShiftC + Value *Sext = Builder.CreateSExt(X, I.getType()); + Constant *ShAmtC = ConstantInt::get(I.getType(), ShiftC->zext(Width)); + return BinaryOperator::CreateLShr(Sext, ShAmtC); + } + } } if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { @@ -2020,7 +2076,7 @@ Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) { LastInst->removeFromParent(); for (auto *Inst : Insts) - Worklist.Add(Inst); + Worklist.push(Inst); return LastInst; } @@ -2086,9 +2142,62 @@ static Instruction *matchRotate(Instruction &Or) { return IntrinsicInst::Create(F, { ShVal, ShVal, ShAmt }); } +/// Attempt 
to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns. +static Instruction *matchOrConcat(Instruction &Or, + InstCombiner::BuilderTy &Builder) { + assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'"); + Value *Op0 = Or.getOperand(0), *Op1 = Or.getOperand(1); + Type *Ty = Or.getType(); + + unsigned Width = Ty->getScalarSizeInBits(); + if ((Width & 1) != 0) + return nullptr; + unsigned HalfWidth = Width / 2; + + // Canonicalize zext (lower half) to LHS. + if (!isa<ZExtInst>(Op0)) + std::swap(Op0, Op1); + + // Find lower/upper half. + Value *LowerSrc, *ShlVal, *UpperSrc; + const APInt *C; + if (!match(Op0, m_OneUse(m_ZExt(m_Value(LowerSrc)))) || + !match(Op1, m_OneUse(m_Shl(m_Value(ShlVal), m_APInt(C)))) || + !match(ShlVal, m_OneUse(m_ZExt(m_Value(UpperSrc))))) + return nullptr; + if (*C != HalfWidth || LowerSrc->getType() != UpperSrc->getType() || + LowerSrc->getType()->getScalarSizeInBits() != HalfWidth) + return nullptr; + + auto ConcatIntrinsicCalls = [&](Intrinsic::ID id, Value *Lo, Value *Hi) { + Value *NewLower = Builder.CreateZExt(Lo, Ty); + Value *NewUpper = Builder.CreateZExt(Hi, Ty); + NewUpper = Builder.CreateShl(NewUpper, HalfWidth); + Value *BinOp = Builder.CreateOr(NewLower, NewUpper); + Function *F = Intrinsic::getDeclaration(Or.getModule(), id, Ty); + return Builder.CreateCall(F, BinOp); + }; + + // BSWAP: Push the concat down, swapping the lower/upper sources. + // concat(bswap(x),bswap(y)) -> bswap(concat(x,y)) + Value *LowerBSwap, *UpperBSwap; + if (match(LowerSrc, m_BSwap(m_Value(LowerBSwap))) && + match(UpperSrc, m_BSwap(m_Value(UpperBSwap)))) + return ConcatIntrinsicCalls(Intrinsic::bswap, UpperBSwap, LowerBSwap); + + // BITREVERSE: Push the concat down, swapping the lower/upper sources. 
+ // concat(bitreverse(x),bitreverse(y)) -> bitreverse(concat(x,y)) + Value *LowerBRev, *UpperBRev; + if (match(LowerSrc, m_BitReverse(m_Value(LowerBRev))) && + match(UpperSrc, m_BitReverse(m_Value(UpperBRev)))) + return ConcatIntrinsicCalls(Intrinsic::bitreverse, UpperBRev, LowerBRev); + + return nullptr; +} + /// If all elements of two constant vectors are 0/-1 and inverses, return true. static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { - unsigned NumElts = C1->getType()->getVectorNumElements(); + unsigned NumElts = cast<VectorType>(C1->getType())->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *EltC1 = C1->getAggregateElement(i); Constant *EltC2 = C2->getAggregateElement(i); @@ -2185,12 +2294,12 @@ Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B, /// Fold (icmp)|(icmp) if possible. Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, - Instruction &CxtI) { - const SimplifyQuery Q = SQ.getWithInstruction(&CxtI); + BinaryOperator &Or) { + const SimplifyQuery Q = SQ.getWithInstruction(&Or); // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // if K1 and K2 are a one-bit mask. - if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI)) + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, Or)) return V; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); @@ -2299,6 +2408,11 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Builder.CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A); } + if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, Or, Builder, Q)) + return V; + if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, Or, Builder, Q)) + return V; + // E.g. 
(icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/true)) return V; @@ -2481,6 +2595,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Instruction *Rotate = matchRotate(I)) return Rotate; + if (Instruction *Concat = matchOrConcat(I, Builder)) + return replaceInstUsesWith(I, Concat); + Value *X, *Y; const APInt *CV; if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) && @@ -2729,6 +2846,32 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) return V; + CmpInst::Predicate Pred; + Value *Mul, *Ov, *MulIsNotZero, *UMulWithOv; + // Check if the OR weakens the overflow condition for umul.with.overflow by + // treating any non-zero result as overflow. In that case, we overflow if both + // umul.with.overflow operands are != 0, as in that case the result can only + // be 0, iff the multiplication overflows. + if (match(&I, + m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_Value(UMulWithOv)), + m_Value(Ov)), + m_CombineAnd(m_ICmp(Pred, + m_CombineAnd(m_ExtractValue<0>( + m_Deferred(UMulWithOv)), + m_Value(Mul)), + m_ZeroInt()), + m_Value(MulIsNotZero)))) && + (Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse())) && + Pred == CmpInst::ICMP_NE) { + Value *A, *B; + if (match(UMulWithOv, m_Intrinsic<Intrinsic::umul_with_overflow>( + m_Value(A), m_Value(B)))) { + Value *NotNullA = Builder.CreateIsNotNull(A); + Value *NotNullB = Builder.CreateIsNotNull(B); + return BinaryOperator::CreateAnd(NotNullA, NotNullB); + } + } + return nullptr; } @@ -2748,33 +2891,24 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (A | B) ^ (A & B) -> A ^ B // (A | B) ^ (B & A) -> A ^ B if (match(&I, m_c_Xor(m_And(m_Value(A), m_Value(B)), - m_c_Or(m_Deferred(A), m_Deferred(B))))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } + m_c_Or(m_Deferred(A), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, 
B); // (A | ~B) ^ (~A | B) -> A ^ B // (~B | A) ^ (~A | B) -> A ^ B // (~A | B) ^ (A | ~B) -> A ^ B // (B | ~A) ^ (A | ~B) -> A ^ B if (match(&I, m_Xor(m_c_Or(m_Value(A), m_Not(m_Value(B))), - m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); // (A & ~B) ^ (~A & B) -> A ^ B // (~B & A) ^ (~A & B) -> A ^ B // (~A & B) ^ (A & ~B) -> A ^ B // (B & ~A) ^ (A & ~B) -> A ^ B if (match(&I, m_Xor(m_c_And(m_Value(A), m_Not(m_Value(B))), - m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } + m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); // For the remaining cases we need to get rid of one of the operands. if (!Op0->hasOneUse() && !Op1->hasOneUse()) @@ -2878,6 +3012,7 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, Builder.SetInsertPoint(Y->getParent(), ++(Y->getIterator())); Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); // Replace all uses of Y (excluding the one in NotY!) with NotY. + Worklist.pushUsersToWorkList(*Y); Y->replaceUsesWithIf(NotY, [NotY](Use &U) { return U.getUser() != NotY; }); } @@ -2924,6 +3059,9 @@ static Instruction *visitMaskedMerge(BinaryOperator &I, Constant *C; if (D->hasOneUse() && match(M, m_Constant(C))) { + // Propagating undef is unsafe. Clamp undef elements to -1. + Type *EltTy = C->getType()->getScalarType(); + C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy)); // Unfold. 
Value *LHS = Builder.CreateAnd(X, C); Value *NotC = Builder.CreateNot(C); @@ -3058,13 +3196,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits) Constant *C; if (match(NotVal, m_AShr(m_Constant(C), m_Value(Y))) && - match(C, m_Negative())) + match(C, m_Negative())) { + // We matched a negative constant, so propagating undef is unsafe. + // Clamp undef elements to -1. + Type *EltTy = C->getType()->getScalarType(); + C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy)); return BinaryOperator::CreateLShr(ConstantExpr::getNot(C), Y); + } // ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits) if (match(NotVal, m_LShr(m_Constant(C), m_Value(Y))) && - match(C, m_NonNegative())) + match(C, m_NonNegative())) { + // We matched a non-negative constant, so propagating undef is unsafe. + // Clamp undef elements to 0. + Type *EltTy = C->getType()->getScalarType(); + C = Constant::replaceUndefsWith(C, ConstantInt::getNullValue(EltTy)); return BinaryOperator::CreateAShr(ConstantExpr::getNot(C), Y); + } // ~(X + C) --> -(C + 1) - X if (match(Op0, m_Add(m_Value(X), m_Constant(C)))) @@ -3114,10 +3262,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (match(Op0, m_Or(m_Value(X), m_APInt(C))) && MaskedValueIsZero(X, *C, 0, &I)) { Constant *NewC = ConstantInt::get(I.getType(), *C ^ *RHSC); - Worklist.Add(cast<Instruction>(Op0)); - I.setOperand(0, X); - I.setOperand(1, NewC); - return &I; + return BinaryOperator::CreateXor(X, NewC); } } } |