author     Dimitry Andric <dim@FreeBSD.org>    2018-07-28 10:51:19 +0000
committer  Dimitry Andric <dim@FreeBSD.org>    2018-07-28 10:51:19 +0000
commit     eb11fae6d08f479c0799db45860a98af528fa6e7 (patch)
tree       44d492a50c8c1a7eb8e2d17ea3360ec4d066f042 /lib/Transforms/InstCombine
parent     b8a2042aa938069e862750553db0e4d82d25822c (diff)
Diffstat (limited to 'lib/Transforms/InstCombine')
16 files changed, 3571 insertions, 2325 deletions
diff --git a/lib/Transforms/InstCombine/CMakeLists.txt b/lib/Transforms/InstCombine/CMakeLists.txt index 5cbe804ce3ec..8a3a58e9ecc9 100644 --- a/lib/Transforms/InstCombine/CMakeLists.txt +++ b/lib/Transforms/InstCombine/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS InstCombineTables.td) +tablegen(LLVM InstCombineTables.inc -gen-searchable-tables) +add_public_tablegen_target(InstCombineTableGen) + add_llvm_library(LLVMInstCombine InstructionCombining.cpp InstCombineAddSub.cpp diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 688897644848..aa31e0d850dd 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -511,7 +511,8 @@ Value *FAddCombine::performFactorization(Instruction *I) { } Value *FAddCombine::simplify(Instruction *I) { - assert(I->isFast() && "Expected 'fast' instruction"); + assert(I->hasAllowReassoc() && I->hasNoSignedZeros() && + "Expected 'reassoc'+'nsz' instruction"); // Currently we are not able to handle vector type. if (I->getType()->isVectorTy()) @@ -855,48 +856,6 @@ Value *FAddCombine::createAddendVal(const FAddend &Opnd, bool &NeedNeg) { return createFMul(OpndVal, Coeff.getValue(Instr->getType())); } -/// \brief Return true if we can prove that: -/// (sub LHS, RHS) === (sub nsw LHS, RHS) -/// This basically requires proving that the add in the original type would not -/// overflow to change the sign bit or have a carry out. -/// TODO: Handle this for Vectors. -bool InstCombiner::willNotOverflowSignedSub(const Value *LHS, - const Value *RHS, - const Instruction &CxtI) const { - // If LHS and RHS each have at least two sign bits, the subtraction - // cannot overflow. - if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 && - ComputeNumSignBits(RHS, 0, &CxtI) > 1) - return true; - - KnownBits LHSKnown = computeKnownBits(LHS, 0, &CxtI); - - KnownBits RHSKnown = computeKnownBits(RHS, 0, &CxtI); - - // Subtraction of two 2's complement numbers having identical signs will - // never overflow. - if ((LHSKnown.isNegative() && RHSKnown.isNegative()) || - (LHSKnown.isNonNegative() && RHSKnown.isNonNegative())) - return true; - - // TODO: implement logic similar to checkRippleForAdd - return false; -} - -/// \brief Return true if we can prove that: -/// (sub LHS, RHS) === (sub nuw LHS, RHS) -bool InstCombiner::willNotOverflowUnsignedSub(const Value *LHS, - const Value *RHS, - const Instruction &CxtI) const { - // If the LHS is negative and the RHS is non-negative, no unsigned wrap. - KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, &CxtI); - KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, &CxtI); - if (LHSKnown.isNegative() && RHSKnown.isNonNegative()) - return true; - - return false; -} - // Checks if any operand is negative and we can convert add to sub. 
// This function checks for following negative patterns // ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C)) @@ -964,7 +923,7 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { if (!match(Op1, m_Constant(Op1C))) return nullptr; - if (Instruction *NV = foldOpWithConstantIntoOperand(Add)) + if (Instruction *NV = foldBinOpIntoSelectOrPhi(Add)) return NV; Value *X; @@ -1031,17 +990,148 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { return nullptr; } -Instruction *InstCombiner::visitAdd(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); +// Matches multiplication expression Op * C where C is a constant. Returns the +// constant value in C and the other operand in Op. Returns true if such a +// match is found. +static bool MatchMul(Value *E, Value *&Op, APInt &C) { + const APInt *AI; + if (match(E, m_Mul(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (match(E, m_Shl(m_Value(Op), m_APInt(AI)))) { + C = APInt(AI->getBitWidth(), 1); + C <<= *AI; + return true; + } + return false; +} + +// Matches remainder expression Op % C where C is a constant. Returns the +// constant value in C and the other operand in Op. Returns the signedness of +// the remainder operation in IsSigned. Returns true if such a match is +// found. +static bool MatchRem(Value *E, Value *&Op, APInt &C, bool &IsSigned) { + const APInt *AI; + IsSigned = false; + if (match(E, m_SRem(m_Value(Op), m_APInt(AI)))) { + IsSigned = true; + C = *AI; + return true; + } + if (match(E, m_URem(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (match(E, m_And(m_Value(Op), m_APInt(AI))) && (*AI + 1).isPowerOf2()) { + C = *AI + 1; + return true; + } + return false; +} +// Matches division expression Op / C with the given signedness as indicated +// by IsSigned, where C is a constant. Returns the constant value in C and the +// other operand in Op. Returns true if such a match is found. +static bool MatchDiv(Value *E, Value *&Op, APInt &C, bool IsSigned) { + const APInt *AI; + if (IsSigned && match(E, m_SDiv(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (!IsSigned) { + if (match(E, m_UDiv(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (match(E, m_LShr(m_Value(Op), m_APInt(AI)))) { + C = APInt(AI->getBitWidth(), 1); + C <<= *AI; + return true; + } + } + return false; +} + +// Returns whether C0 * C1 with the given signedness overflows. +static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) { + bool overflow; + if (IsSigned) + (void)C0.smul_ov(C1, overflow); + else + (void)C0.umul_ov(C1, overflow); + return overflow; +} + +// Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1) +// does not overflow. 
+Value *InstCombiner::SimplifyAddWithRemainder(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - if (Value *V = - SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - SQ.getWithInstruction(&I))) + Value *X, *MulOpV; + APInt C0, MulOpC; + bool IsSigned; + // Match I = X % C0 + MulOpV * C0 + if (((MatchRem(LHS, X, C0, IsSigned) && MatchMul(RHS, MulOpV, MulOpC)) || + (MatchRem(RHS, X, C0, IsSigned) && MatchMul(LHS, MulOpV, MulOpC))) && + C0 == MulOpC) { + Value *RemOpV; + APInt C1; + bool Rem2IsSigned; + // Match MulOpC = RemOpV % C1 + if (MatchRem(MulOpV, RemOpV, C1, Rem2IsSigned) && + IsSigned == Rem2IsSigned) { + Value *DivOpV; + APInt DivOpC; + // Match RemOpV = X / C0 + if (MatchDiv(RemOpV, DivOpV, DivOpC, IsSigned) && X == DivOpV && + C0 == DivOpC && !MulWillOverflow(C0, C1, IsSigned)) { + Value *NewDivisor = + ConstantInt::get(X->getType()->getContext(), C0 * C1); + return IsSigned ? Builder.CreateSRem(X, NewDivisor, "srem") + : Builder.CreateURem(X, NewDivisor, "urem"); + } + } + } + + return nullptr; +} + +/// Fold +/// (1 << NBits) - 1 +/// Into: +/// ~(-(1 << NBits)) +/// Because a 'not' is better for bit-tracking analysis and other transforms +/// than an 'add'. The new shl is always nsw, and is nuw if old `and` was. +static Instruction *canonicalizeLowbitMask(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *NBits; + if (!match(&I, m_Add(m_OneUse(m_Shl(m_One(), m_Value(NBits))), m_AllOnes()))) + return nullptr; + + Constant *MinusOne = Constant::getAllOnesValue(NBits->getType()); + Value *NotMask = Builder.CreateShl(MinusOne, NBits, "notmask"); + // Be wary of constant folding. + if (auto *BOp = dyn_cast<BinaryOperator>(NotMask)) { + // Always NSW. But NUW propagates from `add`. + BOp->setHasNoSignedWrap(); + BOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + } + + return BinaryOperator::CreateNot(NotMask, I.getName()); +} + +Instruction *InstCombiner::visitAdd(BinaryOperator &I) { + if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; + // (A*B)+(A*C) -> A*(B+C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); @@ -1051,6 +1141,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // FIXME: This should be moved into the above helper function to allow these // transforms for general constant or constant splat vectors. + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Type *Ty = I.getType(); if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr; @@ -1123,6 +1214,14 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = checkForNegativeOperand(I, Builder)) return replaceInstUsesWith(I, V); + // (A + 1) + ~B --> A - B + // ~B + (A + 1) --> A - B + if (match(&I, m_c_BinOp(m_Add(m_Value(A), m_One()), m_Not(m_Value(B))))) + return BinaryOperator::CreateSub(A, B); + + // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) + if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); + // A+B --> A|B iff A and B have no bits set in common. 
if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); @@ -1253,26 +1352,15 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } // (add (xor A, B) (and A, B)) --> (or A, B) - if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - // (add (and A, B) (xor A, B)) --> (or A, B) - if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) + if (match(&I, m_c_BinOp(m_Xor(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) return BinaryOperator::CreateOr(A, B); // (add (or A, B) (and A, B)) --> (add A, B) - if (match(LHS, m_Or(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } - // (add (and A, B) (or A, B)) --> (add A, B) - if (match(RHS, m_Or(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { + if (match(&I, m_c_BinOp(m_Or(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -1281,6 +1369,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // TODO(jingyue): Consider willNotOverflowSignedAdd and // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. + bool Changed = false; if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) { Changed = true; I.setHasNoSignedWrap(true); @@ -1290,39 +1379,35 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { I.setHasNoUnsignedWrap(true); } + if (Instruction *V = canonicalizeLowbitMask(I, Builder)) + return V; + return Changed ? &I : nullptr; } Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); - - if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), + if (Value *V = SimplifyFAddInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (isa<Constant>(RHS)) - if (Instruction *FoldedFAdd = foldOpWithConstantIntoOperand(I)) - return FoldedFAdd; + if (SimplifyAssociativeOrCommutative(I)) + return &I; - // -A + B --> B - A - // -A + -B --> -(A + B) - if (Value *LHSV = dyn_castFNegVal(LHS)) { - Instruction *RI = BinaryOperator::CreateFSub(RHS, LHSV); - RI->copyFastMathFlags(&I); - return RI; - } + if (Instruction *X = foldShuffledBinop(I)) + return X; - // A + -B --> A - B - if (!isa<Constant>(RHS)) - if (Value *V = dyn_castFNegVal(RHS)) { - Instruction *RI = BinaryOperator::CreateFSub(LHS, V); - RI->copyFastMathFlags(&I); - return RI; - } + if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I)) + return FoldedFAdd; + + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + Value *X; + // (-X) + Y --> Y - X + if (match(LHS, m_FNeg(m_Value(X)))) + return BinaryOperator::CreateFSubFMF(RHS, X, &I); + // Y + (-X) --> Y - X + if (match(RHS, m_FNeg(m_Value(X)))) + return BinaryOperator::CreateFSubFMF(LHS, X, &I); // Check for (fadd double (sitofp x), y), see if we can merge this into an // integer add followed by a promotion. 
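For reference, the integer identities behind the visitAdd folds above (the X % C0 + ((X / C0) % C1) * C0 recombination, the xor/and and or/and reassociations, and the low-bit-mask canonicalization) can be spot-checked with a small standalone program. This is an illustrative sketch, not code from the commit; the constants C0 and C1 are arbitrary picks that satisfy the no-overflow precondition.

    #include <cassert>
    #include <cstdint>

    int main() {
      // X % C0 + ((X / C0) % C1) * C0 == X % (C0 * C1)  (unsigned, C0 * C1 does not overflow)
      const uint32_t C0 = 6, C1 = 7;
      for (uint32_t X = 0; X < 100000; ++X)
        assert(X % C0 + ((X / C0) % C1) * C0 == X % (C0 * C1));

      for (uint32_t A = 0; A < 256; ++A)
        for (uint32_t B = 0; B < 256; ++B) {
          // (add (xor A, B), (and A, B)) --> (or A, B)
          assert(((A ^ B) + (A & B)) == (A | B));
          // (add (or A, B), (and A, B)) --> (add A, B)
          assert(((A | B) + (A & B)) == A + B);
        }

      // (1 << NBits) - 1 == ~(-(1 << NBits))  (the low-bit-mask canonicalization)
      for (uint32_t NBits = 0; NBits < 32; ++NBits)
        assert(((1u << NBits) - 1) == ~(0u - (1u << NBits)));
      return 0;
    }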
@@ -1386,12 +1471,12 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS)) return replaceInstUsesWith(I, V); - if (I.isFast()) { + if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } - return Changed ? &I : nullptr; + return nullptr; } /// Optimize pointer differences into the same array into a size. Consider: @@ -1481,21 +1566,20 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, } Instruction *InstCombiner::visitSub(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifySubInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = - SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; // (A*B)-(A*C) -> A*(B-C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); // If this is a 'B = x-(-A)', change to B = x+A. + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = dyn_castNegVal(Op1)) { BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V); @@ -1519,12 +1603,28 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (match(Op0, m_AllOnes())) return BinaryOperator::CreateNot(Op1); + // (~X) - (~Y) --> Y - X + Value *X, *Y; + if (match(Op0, m_Not(m_Value(X))) && match(Op1, m_Not(m_Value(Y)))) + return BinaryOperator::CreateSub(Y, X); + if (Constant *C = dyn_cast<Constant>(Op0)) { + bool IsNegate = match(C, m_ZeroInt()); Value *X; - // C - zext(bool) -> bool ? C - 1 : C - if (match(Op1, m_ZExt(m_Value(X))) && - X->getType()->getScalarSizeInBits() == 1) + if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + // 0 - (zext bool) --> sext bool + // C - (zext bool) --> bool ? C - 1 : C + if (IsNegate) + return CastInst::CreateSExtOrBitCast(X, I.getType()); return SelectInst::Create(X, SubOne(C), C); + } + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + // 0 - (sext bool) --> zext bool + // C - (sext bool) --> bool ? 
C + 1 : C + if (IsNegate) + return CastInst::CreateZExtOrBitCast(X, I.getType()); + return SelectInst::Create(X, AddOne(C), C); + } // C - ~X == X + (1+C) if (match(Op1, m_Not(m_Value(X)))) @@ -1544,16 +1644,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Constant *C2; if (match(Op1, m_Add(m_Value(X), m_Constant(C2)))) return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); - - // Fold (sub 0, (zext bool to B)) --> (sext bool to B) - if (C->isNullValue() && match(Op1, m_ZExt(m_Value(X)))) - if (X->getType()->isIntOrIntVectorTy(1)) - return CastInst::CreateSExtOrBitCast(X, Op1->getType()); - - // Fold (sub 0, (sext bool to B)) --> (zext bool to B) - if (C->isNullValue() && match(Op1, m_SExt(m_Value(X)))) - if (X->getType()->isIntOrIntVectorTy(1)) - return CastInst::CreateZExtOrBitCast(X, Op1->getType()); } const APInt *Op0C; @@ -1575,6 +1665,22 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); return BinaryOperator::CreateLShr(X, ShAmtOp); } + + if (Op1->hasOneUse()) { + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor; + if (SPF == SPF_ABS || SPF == SPF_NABS) { + // This is a negate of an ABS/NABS pattern. Just swap the operands + // of the select. + SelectInst *SI = cast<SelectInst>(Op1); + Value *TrueVal = SI->getTrueValue(); + Value *FalseVal = SI->getFalseValue(); + SI->setTrueValue(FalseVal); + SI->setFalseValue(TrueVal); + // Don't swap prof metadata, we didn't change the branch behavior. + return replaceInstUsesWith(I, SI); + } + } } // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known @@ -1678,6 +1784,27 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) return replaceInstUsesWith(I, Res); + // Canonicalize a shifty way to code absolute value to the common pattern. + // There are 2 potential commuted variants. + // We're relying on the fact that we only do this transform when the shift has + // exactly 2 uses and the xor has exactly 1 use (otherwise, we might increase + // instructions). + Value *A; + const APInt *ShAmt; + Type *Ty = I.getType(); + if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && + Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 && + match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) { + // B = ashr i32 A, 31 ; smear the sign bit + // sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1) + // --> (A < 0) ? -A : A + Value *Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty)); + // Copy the nuw/nsw flags from the sub to the negate. + Value *Neg = Builder.CreateNeg(A, "", I.hasNoUnsignedWrap(), + I.hasNoSignedWrap()); + return SelectInst::Create(Cmp, Neg, A); + } + bool Changed = false; if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) { Changed = true; @@ -1692,21 +1819,32 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { } Instruction *InstCombiner::visitFSub(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); - - if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), + if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; + + // Subtraction from -0.0 is the canonical form of fneg. 
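As a side note on the 'shifty abs' pattern that visitSub now recognizes above, the underlying identity can be checked directly. This is a hedged sketch assuming two's-complement int32_t with arithmetic right shift of negative values; INT32_MIN is left out because its absolute value is not representable.

    #include <cassert>
    #include <cstdint>

    int main() {
      const int32_t Tests[] = {0, 1, -1, 7, -7, 123456, -123456,
                               INT32_MAX, INT32_MIN + 1};
      for (int32_t A : Tests) {
        int32_t B = A >> 31;           // smear the sign bit
        int32_t Abs = A < 0 ? -A : A;  // (A < 0) ? -A : A
        // sub (xor A, B), B  -->  abs(A)
        assert(((A ^ B) - B) == Abs);
      }
      return 0;
    }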
// fsub nsz 0, X ==> fsub nsz -0.0, X - if (I.getFastMathFlags().noSignedZeros() && match(Op0, m_Zero())) { - // Subtraction from -0.0 is the canonical form of fneg. - Instruction *NewI = BinaryOperator::CreateFNeg(Op1); - NewI->copyFastMathFlags(&I); - return NewI; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) + return BinaryOperator::CreateFNegFMF(Op1, &I); + + // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) + // Canonicalize to fadd to make analysis easier. + // This can also help codegen because fadd is commutative. + // Note that if this fsub was really an fneg, the fadd with -0.0 will get + // killed later. We still limit that particular transform with 'hasOneUse' + // because an fneg is assumed better/cheaper than a generic fsub. + Value *X, *Y; + if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { + if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); + return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); + } } if (isa<Constant>(Op0)) @@ -1714,34 +1852,34 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Instruction *NV = FoldOpIntoSelect(I, SI)) return NV; - // If this is a 'B = x-(-A)', change to B = x+A, potentially looking - // through FP extensions/truncations along the way. - if (Value *V = dyn_castFNegVal(Op1)) { - Instruction *NewI = BinaryOperator::CreateFAdd(Op0, V); - NewI->copyFastMathFlags(&I); - return NewI; - } - if (FPTruncInst *FPTI = dyn_cast<FPTruncInst>(Op1)) { - if (Value *V = dyn_castFNegVal(FPTI->getOperand(0))) { - Value *NewTrunc = Builder.CreateFPTrunc(V, I.getType()); - Instruction *NewI = BinaryOperator::CreateFAdd(Op0, NewTrunc); - NewI->copyFastMathFlags(&I); - return NewI; - } - } else if (FPExtInst *FPEI = dyn_cast<FPExtInst>(Op1)) { - if (Value *V = dyn_castFNegVal(FPEI->getOperand(0))) { - Value *NewExt = Builder.CreateFPExt(V, I.getType()); - Instruction *NewI = BinaryOperator::CreateFAdd(Op0, NewExt); - NewI->copyFastMathFlags(&I); - return NewI; - } + // X - C --> X + (-C) + // But don't transform constant expressions because there's an inverse fold + // for X + (-Y) --> X - Y. 
+ Constant *C; + if (match(Op1, m_Constant(C)) && !isa<ConstantExpr>(Op1)) + return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + + // X - (-Y) --> X + Y + if (match(Op1, m_FNeg(m_Value(Y)))) + return BinaryOperator::CreateFAddFMF(Op0, Y, &I); + + // Similar to above, but look through a cast of the negated value: + // X - (fptrunc(-Y)) --> X + fptrunc(Y) + if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) { + Value *TruncY = Builder.CreateFPTrunc(Y, I.getType()); + return BinaryOperator::CreateFAddFMF(Op0, TruncY, &I); + } + // X - (fpext(-Y)) --> X + fpext(Y) + if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) { + Value *ExtY = Builder.CreateFPExt(Y, I.getType()); + return BinaryOperator::CreateFAddFMF(Op0, ExtY, &I); } // Handle specials cases for FSub with selects feeding the operation if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); - if (I.isFast()) { + if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 2364202e5b69..372bc41f780e 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -14,10 +14,10 @@ #include "InstCombineInternal.h" #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace PatternMatch; @@ -75,7 +75,7 @@ static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, return Builder.CreateFCmp(Pred, LHS, RHS); } -/// \brief Transform BITWISE_OP(BSWAP(A),BSWAP(B)) or +/// Transform BITWISE_OP(BSWAP(A),BSWAP(B)) or /// BITWISE_OP(BSWAP(A), Constant) to BSWAP(BITWISE_OP(A, B)) /// \param I Binary operator to transform. /// \return Pointer to node that must replace the original binary operator, or @@ -305,17 +305,21 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre } /// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E). -/// Return the set of pattern classes (from MaskedICmpType) that both LHS and -/// RHS satisfy. -static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, - Value *&D, Value *&E, ICmpInst *LHS, - ICmpInst *RHS, - ICmpInst::Predicate &PredL, - ICmpInst::Predicate &PredR) { +/// Return the pattern classes (from MaskedICmpType) for the left hand side and +/// the right hand side as a pair. +/// LHS and RHS are the left hand side and the right hand side ICmps and PredL +/// and PredR are their predicates, respectively. +static +Optional<std::pair<unsigned, unsigned>> +getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, + Value *&D, Value *&E, ICmpInst *LHS, + ICmpInst *RHS, + ICmpInst::Predicate &PredL, + ICmpInst::Predicate &PredR) { // vectors are not (yet?) supported. Don't support pointers either. if (!LHS->getOperand(0)->getType()->isIntegerTy() || !RHS->getOperand(0)->getType()->isIntegerTy()) - return 0; + return None; // Here comes the tricky part: // LHS might be of the form L11 & L12 == X, X == L21 & L22, @@ -346,7 +350,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, // Bail if LHS was a icmp that can't be decomposed into an equality. 
if (!ICmpInst::isEquality(PredL)) - return 0; + return None; Value *R1 = RHS->getOperand(0); Value *R2 = RHS->getOperand(1); @@ -360,7 +364,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, A = R12; D = R11; } else { - return 0; + return None; } E = R2; R1 = nullptr; @@ -388,7 +392,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, // Bail if RHS was a icmp that can't be decomposed into an equality. if (!ICmpInst::isEquality(PredR)) - return 0; + return None; // Look for ANDs on the right side of the RHS icmp. if (!Ok) { @@ -408,11 +412,11 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, E = R1; Ok = true; } else { - return 0; + return None; } } if (!Ok) - return 0; + return None; if (L11 == A) { B = L12; @@ -430,7 +434,174 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, unsigned LeftType = getMaskedICmpType(A, B, C, PredL); unsigned RightType = getMaskedICmpType(A, D, E, PredR); - return LeftType & RightType; + return Optional<std::pair<unsigned, unsigned>>(std::make_pair(LeftType, RightType)); +} + +/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) into a single +/// (icmp(A & X) ==/!= Y), where the left-hand side is of type Mask_NotAllZeros +/// and the right hand side is of type BMask_Mixed. For example, +/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8). +static Value * foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( + ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, + Value *A, Value *B, Value *C, Value *D, Value *E, + ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, + llvm::InstCombiner::BuilderTy &Builder) { + // We are given the canonical form: + // (icmp ne (A & B), 0) & (icmp eq (A & D), E). + // where D & E == E. + // + // If IsAnd is false, we get it in negated form: + // (icmp eq (A & B), 0) | (icmp ne (A & D), E) -> + // !((icmp ne (A & B), 0) & (icmp eq (A & D), E)). + // + // We currently handle the case of B, C, D, E are constant. + // + ConstantInt *BCst = dyn_cast<ConstantInt>(B); + if (!BCst) + return nullptr; + ConstantInt *CCst = dyn_cast<ConstantInt>(C); + if (!CCst) + return nullptr; + ConstantInt *DCst = dyn_cast<ConstantInt>(D); + if (!DCst) + return nullptr; + ConstantInt *ECst = dyn_cast<ConstantInt>(E); + if (!ECst) + return nullptr; + + ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; + + // Update E to the canonical form when D is a power of two and RHS is + // canonicalized as, + // (icmp ne (A & D), 0) -> (icmp eq (A & D), D) or + // (icmp ne (A & D), D) -> (icmp eq (A & D), 0). + if (PredR != NewCC) + ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); + + // If B or D is zero, skip because if LHS or RHS can be trivially folded by + // other folding rules and this pattern won't apply any more. + if (BCst->getValue() == 0 || DCst->getValue() == 0) + return nullptr; + + // If B and D don't intersect, ie. (B & D) == 0, no folding because we can't + // deduce anything from it. + // For example, + // (icmp ne (A & 12), 0) & (icmp eq (A & 3), 1) -> no folding. + if ((BCst->getValue() & DCst->getValue()) == 0) + return nullptr; + + // If the following two conditions are met: + // + // 1. mask B covers only a single bit that's not covered by mask D, that is, + // (B & (B ^ D)) is a power of 2 (in other words, B minus the intersection of + // B and D has only one bit set) and, + // + // 2. 
RHS (and E) indicates that the rest of B's bits are zero (in other + // words, the intersection of B and D is zero), that is, ((B & D) & E) == 0 + // + // then that single bit in B must be one and thus the whole expression can be + // folded to + // (A & (B | D)) == (B & (B ^ D)) | E. + // + // For example, + // (icmp ne (A & 12), 0) & (icmp eq (A & 7), 1) -> (icmp eq (A & 15), 9) + // (icmp ne (A & 15), 0) & (icmp eq (A & 7), 0) -> (icmp eq (A & 15), 8) + if ((((BCst->getValue() & DCst->getValue()) & ECst->getValue()) == 0) && + (BCst->getValue() & (BCst->getValue() ^ DCst->getValue())).isPowerOf2()) { + APInt BorD = BCst->getValue() | DCst->getValue(); + APInt BandBxorDorE = (BCst->getValue() & (BCst->getValue() ^ DCst->getValue())) | + ECst->getValue(); + Value *NewMask = ConstantInt::get(BCst->getType(), BorD); + Value *NewMaskedValue = ConstantInt::get(BCst->getType(), BandBxorDorE); + Value *NewAnd = Builder.CreateAnd(A, NewMask); + return Builder.CreateICmp(NewCC, NewAnd, NewMaskedValue); + } + + auto IsSubSetOrEqual = [](ConstantInt *C1, ConstantInt *C2) { + return (C1->getValue() & C2->getValue()) == C1->getValue(); + }; + auto IsSuperSetOrEqual = [](ConstantInt *C1, ConstantInt *C2) { + return (C1->getValue() & C2->getValue()) == C2->getValue(); + }; + + // In the following, we consider only the cases where B is a superset of D, B + // is a subset of D, or B == D because otherwise there's at least one bit + // covered by B but not D, in which case we can't deduce much from it, so + // no folding (aside from the single must-be-one bit case right above.) + // For example, + // (icmp ne (A & 14), 0) & (icmp eq (A & 3), 1) -> no folding. + if (!IsSubSetOrEqual(BCst, DCst) && !IsSuperSetOrEqual(BCst, DCst)) + return nullptr; + + // At this point, either B is a superset of D, B is a subset of D or B == D. + + // If E is zero, if B is a subset of (or equal to) D, LHS and RHS contradict + // and the whole expression becomes false (or true if negated), otherwise, no + // folding. + // For example, + // (icmp ne (A & 3), 0) & (icmp eq (A & 7), 0) -> false. + // (icmp ne (A & 15), 0) & (icmp eq (A & 3), 0) -> no folding. + if (ECst->isZero()) { + if (IsSubSetOrEqual(BCst, DCst)) + return ConstantInt::get(LHS->getType(), !IsAnd); + return nullptr; + } + + // At this point, B, D, E aren't zero and (B & D) == B, (B & D) == D or B == + // D. If B is a superset of (or equal to) D, since E is not zero, LHS is + // subsumed by RHS (RHS implies LHS.) So the whole expression becomes + // RHS. For example, + // (icmp ne (A & 255), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). + // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). + if (IsSuperSetOrEqual(BCst, DCst)) + return RHS; + // Otherwise, B is a subset of D. If B and E have a common bit set, + // ie. (B & E) != 0, then LHS is subsumed by RHS. For example. + // (icmp ne (A & 12), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). + assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code"); + if ((BCst->getValue() & ECst->getValue()) != 0) + return RHS; + // Otherwise, LHS and RHS contradict and the whole expression becomes false + // (or true if negated.) For example, + // (icmp ne (A & 7), 0) & (icmp eq (A & 15), 8) -> false. + // (icmp ne (A & 6), 0) & (icmp eq (A & 15), 8) -> false. 
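The example folds quoted in the comments above can be confirmed exhaustively over the low byte, which is all these masks inspect. A short illustrative check (not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t A = 0; A < 256; ++A) {
        // (icmp ne (A & 12), 0) & (icmp eq (A & 7), 1) -> (icmp eq (A & 15), 9)
        assert((((A & 12) != 0) && ((A & 7) == 1)) == ((A & 15) == 9));
        // (icmp ne (A & 15), 0) & (icmp eq (A & 7), 0) -> (icmp eq (A & 15), 8)
        assert((((A & 15) != 0) && ((A & 7) == 0)) == ((A & 15) == 8));
      }
      return 0;
    }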
+ return ConstantInt::get(LHS->getType(), !IsAnd); +} + +/// Try to fold (icmp(A & B) ==/!= 0) &/| (icmp(A & D) ==/!= E) into a single +/// (icmp(A & X) ==/!= Y), where the left-hand side and the right hand side +/// aren't of the common mask pattern type. +static Value *foldLogOpOfMaskedICmpsAsymmetric( + ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, + Value *A, Value *B, Value *C, Value *D, Value *E, + ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, + unsigned LHSMask, unsigned RHSMask, + llvm::InstCombiner::BuilderTy &Builder) { + assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && + "Expected equality predicates for masked type of icmps."); + // Handle Mask_NotAllZeros-BMask_Mixed cases. + // (icmp ne/eq (A & B), C) &/| (icmp eq/ne (A & D), E), or + // (icmp eq/ne (A & B), C) &/| (icmp ne/eq (A & D), E) + // which gets swapped to + // (icmp ne/eq (A & D), E) &/| (icmp eq/ne (A & B), C). + if (!IsAnd) { + LHSMask = conjugateICmpMask(LHSMask); + RHSMask = conjugateICmpMask(RHSMask); + } + if ((LHSMask & Mask_NotAllZeros) && (RHSMask & BMask_Mixed)) { + if (Value *V = foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( + LHS, RHS, IsAnd, A, B, C, D, E, + PredL, PredR, Builder)) { + return V; + } + } else if ((LHSMask & BMask_Mixed) && (RHSMask & Mask_NotAllZeros)) { + if (Value *V = foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( + RHS, LHS, IsAnd, A, D, E, B, C, + PredR, PredL, Builder)) { + return V; + } + } + return nullptr; } /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) @@ -439,13 +610,24 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, llvm::InstCombiner::BuilderTy &Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - unsigned Mask = + Optional<std::pair<unsigned, unsigned>> MaskPair = getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); - if (Mask == 0) + if (!MaskPair) return nullptr; - assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && "Expected equality predicates for masked type of icmps."); + unsigned LHSMask = MaskPair->first; + unsigned RHSMask = MaskPair->second; + unsigned Mask = LHSMask & RHSMask; + if (Mask == 0) { + // Even if the two sides don't share a common pattern, check if folding can + // still happen. + if (Value *V = foldLogOpOfMaskedICmpsAsymmetric( + LHS, RHS, IsAnd, A, B, C, D, E, PredL, PredR, LHSMask, RHSMask, + Builder)) + return V; + return nullptr; + } // In full generality: // (icmp (A & B) Op C) | (icmp (A & D) Op E) @@ -939,8 +1121,8 @@ Value *InstCombiner::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) return nullptr; // FCmp canonicalization ensures that (fcmp ord/uno X, X) and - // (fcmp ord/uno X, C) will be transformed to (fcmp X, 0.0). - if (match(LHS1, m_Zero()) && LHS1 == RHS1) + // (fcmp ord/uno X, C) will be transformed to (fcmp X, +0.0). + if (match(LHS1, m_PosZeroFP()) && match(RHS1, m_PosZeroFP())) // Ignore the constants because they are obviously not NANs: // (fcmp ord x, 0.0) & (fcmp ord y, 0.0) -> (fcmp ord x, y) // (fcmp uno x, 0.0) | (fcmp uno y, 0.0) -> (fcmp uno x, y) @@ -1106,8 +1288,8 @@ static Instruction *foldAndToXor(BinaryOperator &I, // Operand complexity canonicalization guarantees that the 'or' is Op0. 
// (A | B) & ~(A & B) --> A ^ B // (A | B) & ~(B & A) --> A ^ B - if (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_Not(m_c_And(m_Specific(A), m_Specific(B))))) + if (match(&I, m_BinOp(m_Or(m_Value(A), m_Value(B)), + m_Not(m_c_And(m_Deferred(A), m_Deferred(B)))))) return BinaryOperator::CreateXor(A, B); // (A | ~B) & (~A | B) --> ~(A ^ B) @@ -1115,8 +1297,8 @@ static Instruction *foldAndToXor(BinaryOperator &I, // (~B | A) & (~A | B) --> ~(A ^ B) // (~B | A) & (B | ~A) --> ~(A ^ B) if (Op0->hasOneUse() || Op1->hasOneUse()) - if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) + if (match(&I, m_BinOp(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); return nullptr; @@ -1148,18 +1330,86 @@ static Instruction *foldOrToXor(BinaryOperator &I, return nullptr; } +/// Return true if a constant shift amount is always less than the specified +/// bit-width. If not, the shift could create poison in the narrower type. +static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) { + if (auto *ScalarC = dyn_cast<ConstantInt>(C)) + return ScalarC->getZExtValue() < BitWidth; + + if (C->getType()->isVectorTy()) { + // Check each element of a constant vector. + unsigned NumElts = C->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = C->getAggregateElement(i); + if (!Elt) + return false; + if (isa<UndefValue>(Elt)) + continue; + auto *CI = dyn_cast<ConstantInt>(Elt); + if (!CI || CI->getZExtValue() >= BitWidth) + return false; + } + return true; + } + + // The constant is a constant expression or unknown. + return false; +} + +/// Try to use narrower ops (sink zext ops) for an 'and' with binop operand and +/// a common zext operand: and (binop (zext X), C), (zext X). +Instruction *InstCombiner::narrowMaskedBinOp(BinaryOperator &And) { + // This transform could also apply to {or, and, xor}, but there are better + // folds for those cases, so we don't expect those patterns here. AShr is not + // handled because it should always be transformed to LShr in this sequence. + // The subtract transform is different because it has a constant on the left. + // Add/mul commute the constant to RHS; sub with constant RHS becomes add. + Value *Op0 = And.getOperand(0), *Op1 = And.getOperand(1); + Constant *C; + if (!match(Op0, m_OneUse(m_Add(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_Mul(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_LShr(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_Shl(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_Sub(m_Constant(C), m_Specific(Op1))))) + return nullptr; + + Value *X; + if (!match(Op1, m_ZExt(m_Value(X))) || Op1->hasNUsesOrMore(3)) + return nullptr; + + Type *Ty = And.getType(); + if (!isa<VectorType>(Ty) && !shouldChangeType(Ty, X->getType())) + return nullptr; + + // If we're narrowing a shift, the shift amount must be safe (less than the + // width) in the narrower type. If the shift amount is greater, instsimplify + // usually handles that case, but we can't guarantee/assert it. 
+ Instruction::BinaryOps Opc = cast<BinaryOperator>(Op0)->getOpcode(); + if (Opc == Instruction::LShr || Opc == Instruction::Shl) + if (!canNarrowShiftAmt(C, X->getType()->getScalarSizeInBits())) + return nullptr; + + // and (sub C, (zext X)), (zext X) --> zext (and (sub C', X), X) + // and (binop (zext X), C), (zext X) --> zext (and (binop X, C'), X) + Value *NewC = ConstantExpr::getTrunc(C, X->getType()); + Value *NewBO = Opc == Instruction::Sub ? Builder.CreateBinOp(Opc, NewC, X) + : Builder.CreateBinOp(Opc, X, NewC); + return new ZExtInst(Builder.CreateAnd(NewBO, X), Ty); +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitAnd(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyAndInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyAndInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -1177,6 +1427,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); const APInt *C; if (match(Op1, m_APInt(C))) { Value *X, *Y; @@ -1289,9 +1540,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { } } - if (isa<Constant>(Op1)) - if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) - return FoldedLogic; + if (Instruction *Z = narrowMaskedBinOp(I)) + return Z; + + if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) + return FoldedLogic; if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) return DeMorgan; @@ -1397,7 +1650,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, Op0, Constant::getNullValue(I.getType())); - return Changed ? &I : nullptr; + return nullptr; } /// Given an OR instruction, check to see if this is a bswap idiom. If so, @@ -1424,7 +1677,18 @@ Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) && match(Op1, m_And(m_Value(), m_Value())); - if (!OrOfOrs && !OrOfShifts && !OrOfAnds) + // (A << B) | (C & D) -> bswap if possible. + // The bigger pattern here is ((A & C1) << C2) | ((B >> C2) & C1), which is a + // part of the bswap idiom for specific values of C1, C2 (e.g. C1 = 16711935, + // C2 = 8 for i32). + // This pattern can occur when the operands of the 'or' are not canonicalized + // for some reason (not having only one use, for example). 
+ bool OrOfAndAndSh = (match(Op0, m_LogicalShift(m_Value(), m_Value())) && + match(Op1, m_And(m_Value(), m_Value()))) || + (match(Op0, m_And(m_Value(), m_Value())) && + match(Op1, m_LogicalShift(m_Value(), m_Value()))); + + if (!OrOfOrs && !OrOfShifts && !OrOfAnds && !OrOfAndAndSh) return nullptr; SmallVector<Instruction*, 4> Insts; @@ -1448,7 +1712,6 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { return false; // One element must be all ones, and the other must be all zeros. - // FIXME: Allow undef elements. if (!((match(EltC1, m_Zero()) && match(EltC2, m_AllOnes())) || (match(EltC2, m_Zero()) && match(EltC1, m_AllOnes())))) return false; @@ -1755,14 +2018,15 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitOr(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyOrInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyOrInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -1780,14 +2044,14 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); - if (isa<Constant>(Op1)) - if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) - return FoldedLogic; + if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) + return FoldedLogic; // Given an OR instruction, check to see if this is a bswap. if (Instruction *BSwap = MatchBSwap(I)) return BSwap; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); { Value *A; const APInt *C; @@ -2027,7 +2291,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } - return Changed ? &I : nullptr; + return nullptr; } /// A ^ B can be specified using other logic ops in a variety of patterns. 
We @@ -2045,10 +2309,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (A & B) ^ (B | A) -> A ^ B // (A | B) ^ (A & B) -> A ^ B // (A | B) ^ (B & A) -> A ^ B - if ((match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) || - (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_c_And(m_Specific(A), m_Specific(B))))) { + if (match(&I, m_c_Xor(m_And(m_Value(A), m_Value(B)), + m_c_Or(m_Deferred(A), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2058,10 +2320,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (~B | A) ^ (~A | B) -> A ^ B // (~A | B) ^ (A | ~B) -> A ^ B // (B | ~A) ^ (A | ~B) -> A ^ B - if ((match(Op0, m_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) || - (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_c_Or(m_Specific(A), m_Not(m_Specific(B)))))) { + if (match(&I, m_Xor(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2071,10 +2331,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (~B & A) ^ (~A & B) -> A ^ B // (~A & B) ^ (A & ~B) -> A ^ B // (B & ~A) ^ (A & ~B) -> A ^ B - if ((match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))) || - (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))))) { + if (match(&I, m_Xor(m_c_And(m_Value(A), m_Not(m_Value(B))), + m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2113,6 +2371,34 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { } } + // TODO: This can be generalized to compares of non-signbits using + // decomposeBitTestICmp(). It could be enhanced more by using (something like) + // foldLogOpOfMaskedICmps(). + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); + Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + if ((LHS->hasOneUse() || RHS->hasOneUse()) && + LHS0->getType() == RHS0->getType()) { + // (X > -1) ^ (Y > -1) --> (X ^ Y) < 0 + // (X < 0) ^ (Y < 0) --> (X ^ Y) < 0 + if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) && + PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes())) || + (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) && + PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero()))) { + Value *Zero = ConstantInt::getNullValue(LHS0->getType()); + return Builder.CreateICmpSLT(Builder.CreateXor(LHS0, RHS0), Zero); + } + // (X > -1) ^ (Y < 0) --> (X ^ Y) > -1 + // (X < 0) ^ (Y > -1) --> (X ^ Y) > -1 + if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) && + PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero())) || + (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) && + PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes()))) { + Value *MinusOne = ConstantInt::getAllOnesValue(LHS0->getType()); + return Builder.CreateICmpSGT(Builder.CreateXor(LHS0, RHS0), MinusOne); + } + } + // Instead of trying to imitate the folds for and/or, decompose this 'xor' // into those logic ops. That is, try to turn this into an and-of-icmps // because we have many folds for that pattern. 
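The sign-bit compare folds added to foldXorOfICmps above follow from the fact that xor combines the sign bits of its operands; a quick standalone check, assuming two's-complement 32-bit integers:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int32_t Vals[] = {0, 1, -1, 42, -42, INT32_MAX, INT32_MIN};
      for (int32_t X : Vals)
        for (int32_t Y : Vals) {
          // (X > -1) ^ (Y > -1) --> (X ^ Y) < 0
          assert(((X > -1) != (Y > -1)) == ((X ^ Y) < 0));
          // (X > -1) ^ (Y < 0)  --> (X ^ Y) > -1
          assert(((X > -1) != (Y < 0)) == ((X ^ Y) > -1));
        }
      return 0;
    }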
@@ -2140,18 +2426,63 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { return nullptr; } +/// If we have a masked merge, in the canonical form of: +/// (assuming that A only has one use.) +/// | A | |B| +/// ((x ^ y) & M) ^ y +/// | D | +/// * If M is inverted: +/// | D | +/// ((x ^ y) & ~M) ^ y +/// We can canonicalize by swapping the final xor operand +/// to eliminate the 'not' of the mask. +/// ((x ^ y) & M) ^ x +/// * If M is a constant, and D has one use, we transform to 'and' / 'or' ops +/// because that shortens the dependency chain and improves analysis: +/// (x & M) | (y & ~M) +static Instruction *visitMaskedMerge(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *B, *X, *D; + Value *M; + if (!match(&I, m_c_Xor(m_Value(B), + m_OneUse(m_c_And( + m_CombineAnd(m_c_Xor(m_Deferred(B), m_Value(X)), + m_Value(D)), + m_Value(M)))))) + return nullptr; + + Value *NotM; + if (match(M, m_Not(m_Value(NotM)))) { + // De-invert the mask and swap the value in B part. + Value *NewA = Builder.CreateAnd(D, NotM); + return BinaryOperator::CreateXor(NewA, X); + } + + Constant *C; + if (D->hasOneUse() && match(M, m_Constant(C))) { + // Unfold. + Value *LHS = Builder.CreateAnd(X, C); + Value *NotC = Builder.CreateNot(C); + Value *RHS = Builder.CreateAnd(B, NotC); + return BinaryOperator::CreateOr(LHS, RHS); + } + + return nullptr; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitXor(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyXorInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyXorInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *NewXor = foldXorToXor(I, Builder)) return NewXor; @@ -2168,6 +2499,11 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); + // A^B --> A|B iff A and B have no bits set in common. + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (haveNoCommonBitsSet(Op0, Op1, DL, &AC, &I, &DT)) + return BinaryOperator::CreateOr(Op0, Op1); + // Apply DeMorgan's Law for 'nand' / 'nor' logic with an inverted operand. Value *X, *Y; @@ -2186,6 +2522,9 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return BinaryOperator::CreateAnd(X, NotY); } + if (Instruction *Xor = visitMaskedMerge(I, Builder)) + return Xor; + // Is this a 'not' (~) fed by a binary operator? BinaryOperator *NotVal; if (match(&I, m_Not(m_BinOp(NotVal)))) { @@ -2206,6 +2545,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } + // ~(X - Y) --> ~X + Y + if (match(NotVal, m_OneUse(m_Sub(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateAdd(Builder.CreateNot(X), Y); + // ~(~X >>s Y) --> (X >>s Y) if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y)))) return BinaryOperator::CreateAShr(X, Y); @@ -2214,16 +2557,18 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // the 'not' by inverting the constant and using the opposite shift type. 
// Canonicalization rules ensure that only a negative constant uses 'ashr', // but we must check that in case that transform has not fired yet. - const APInt *C; - if (match(NotVal, m_AShr(m_APInt(C), m_Value(Y))) && C->isNegative()) { + Constant *C; + if (match(NotVal, m_AShr(m_Constant(C), m_Value(Y))) && + match(C, m_Negative())) { // ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits) - Constant *NotC = ConstantInt::get(I.getType(), ~(*C)); + Constant *NotC = ConstantExpr::getNot(C); return BinaryOperator::CreateLShr(NotC, Y); } - if (match(NotVal, m_LShr(m_APInt(C), m_Value(Y))) && C->isNonNegative()) { + if (match(NotVal, m_LShr(m_Constant(C), m_Value(Y))) && + match(C, m_NonNegative())) { // ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits) - Constant *NotC = ConstantInt::get(I.getType(), ~(*C)); + Constant *NotC = ConstantExpr::getNot(C); return BinaryOperator::CreateAShr(NotC, Y); } } @@ -2305,9 +2650,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - if (isa<Constant>(Op1)) - if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) - return FoldedLogic; + if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) + return FoldedLogic; { Value *A, *B; @@ -2397,25 +2741,59 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) return CastedXor; - // Canonicalize the shifty way to code absolute value to the common pattern. + // Canonicalize a shifty way to code absolute value to the common pattern. // There are 4 potential commuted variants. Move the 'ashr' candidate to Op1. // We're relying on the fact that we only do this transform when the shift has // exactly 2 uses and the add has exactly 1 use (otherwise, we might increase // instructions). - if (Op0->getNumUses() == 2) + if (Op0->hasNUses(2)) std::swap(Op0, Op1); const APInt *ShAmt; Type *Ty = I.getType(); if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && - Op1->getNumUses() == 2 && *ShAmt == Ty->getScalarSizeInBits() - 1 && + Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 && match(Op0, m_OneUse(m_c_Add(m_Specific(A), m_Specific(Op1))))) { // B = ashr i32 A, 31 ; smear the sign bit // xor (add A, B), B ; add -1 and flip bits if negative // --> (A < 0) ? -A : A Value *Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty)); - return SelectInst::Create(Cmp, Builder.CreateNeg(A), A); + // Copy the nuw/nsw flags from the add to the negate. + auto *Add = cast<BinaryOperator>(Op0); + Value *Neg = Builder.CreateNeg(A, "", Add->hasNoUnsignedWrap(), + Add->hasNoSignedWrap()); + return SelectInst::Create(Cmp, Neg, A); + } + + // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: + // + // %notx = xor i32 %x, -1 + // %cmp1 = icmp sgt i32 %notx, %y + // %smax = select i1 %cmp1, i32 %notx, i32 %y + // %res = xor i32 %smax, -1 + // => + // %noty = xor i32 %y, -1 + // %cmp2 = icmp slt %x, %noty + // %res = select i1 %cmp2, i32 %x, i32 %noty + // + // Same is applicable for smin/umax/umin. + { + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(Op0, LHS, RHS).Flavor; + if (Op0->hasOneUse() && SelectPatternResult::isMinOrMax(SPF) && + match(Op1, m_AllOnes())) { + + Value *X; + if (match(RHS, m_Not(m_Value(X)))) + std::swap(RHS, LHS); + + if (match(LHS, m_Not(m_Value(X)))) { + Value *NotY = Builder.CreateNot(RHS); + return SelectInst::Create( + Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY); + } + } } - return Changed ? 
&I : nullptr; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 40e52ee755e5..cbfbd8a53993 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -24,6 +24,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -57,7 +58,6 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include <algorithm> #include <cassert> @@ -73,11 +73,11 @@ using namespace PatternMatch; STATISTIC(NumSimplified, "Number of library calls simplified"); -static cl::opt<unsigned> UnfoldElementAtomicMemcpyMaxElements( - "unfold-element-atomic-memcpy-max-elements", - cl::init(16), - cl::desc("Maximum number of elements in atomic memcpy the optimizer is " - "allowed to unfold")); +static cl::opt<unsigned> GuardWideningWindow( + "instcombine-guard-widening-window", + cl::init(3), + cl::desc("How wide an instruction window to bypass looking for " + "another guard")); /// Return the specified type promoted as it would be to pass though a va_arg /// area. @@ -106,97 +106,24 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { return ConstantVector::get(BoolVec); } -Instruction * -InstCombiner::SimplifyElementUnorderedAtomicMemCpy(AtomicMemCpyInst *AMI) { - // Try to unfold this intrinsic into sequence of explicit atomic loads and - // stores. - // First check that number of elements is compile time constant. - auto *LengthCI = dyn_cast<ConstantInt>(AMI->getLength()); - if (!LengthCI) - return nullptr; - - // Check that there are not too many elements. - uint64_t LengthInBytes = LengthCI->getZExtValue(); - uint32_t ElementSizeInBytes = AMI->getElementSizeInBytes(); - uint64_t NumElements = LengthInBytes / ElementSizeInBytes; - if (NumElements >= UnfoldElementAtomicMemcpyMaxElements) - return nullptr; - - // Only expand if there are elements to copy. - if (NumElements > 0) { - // Don't unfold into illegal integers - uint64_t ElementSizeInBits = ElementSizeInBytes * 8; - if (!getDataLayout().isLegalInteger(ElementSizeInBits)) - return nullptr; - - // Cast source and destination to the correct type. Intrinsic input - // arguments are usually represented as i8*. Often operands will be - // explicitly casted to i8* and we can just strip those casts instead of - // inserting new ones. However it's easier to rely on other InstCombine - // rules which will cover trivial cases anyway. 
- Value *Src = AMI->getRawSource(); - Value *Dst = AMI->getRawDest(); - Type *ElementPointerType = - Type::getIntNPtrTy(AMI->getContext(), ElementSizeInBits, - Src->getType()->getPointerAddressSpace()); - - Value *SrcCasted = Builder.CreatePointerCast(Src, ElementPointerType, - "memcpy_unfold.src_casted"); - Value *DstCasted = Builder.CreatePointerCast(Dst, ElementPointerType, - "memcpy_unfold.dst_casted"); - - for (uint64_t i = 0; i < NumElements; ++i) { - // Get current element addresses - ConstantInt *ElementIdxCI = - ConstantInt::get(AMI->getContext(), APInt(64, i)); - Value *SrcElementAddr = - Builder.CreateGEP(SrcCasted, ElementIdxCI, "memcpy_unfold.src_addr"); - Value *DstElementAddr = - Builder.CreateGEP(DstCasted, ElementIdxCI, "memcpy_unfold.dst_addr"); - - // Load from the source. Transfer alignment information and mark load as - // unordered atomic. - LoadInst *Load = Builder.CreateLoad(SrcElementAddr, "memcpy_unfold.val"); - Load->setOrdering(AtomicOrdering::Unordered); - // We know alignment of the first element. It is also guaranteed by the - // verifier that element size is less or equal than first element - // alignment and both of this values are powers of two. This means that - // all subsequent accesses are at least element size aligned. - // TODO: We can infer better alignment but there is no evidence that this - // will matter. - Load->setAlignment(i == 0 ? AMI->getParamAlignment(1) - : ElementSizeInBytes); - Load->setDebugLoc(AMI->getDebugLoc()); - - // Store loaded value via unordered atomic store. - StoreInst *Store = Builder.CreateStore(Load, DstElementAddr); - Store->setOrdering(AtomicOrdering::Unordered); - Store->setAlignment(i == 0 ? AMI->getParamAlignment(0) - : ElementSizeInBytes); - Store->setDebugLoc(AMI->getDebugLoc()); - } +Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { + unsigned DstAlign = getKnownAlignment(MI->getRawDest(), DL, MI, &AC, &DT); + unsigned CopyDstAlign = MI->getDestAlignment(); + if (CopyDstAlign < DstAlign){ + MI->setDestAlignment(DstAlign); + return MI; } - // Set the number of elements of the copy to 0, it will be deleted on the - // next iteration. - AMI->setLength(Constant::getNullValue(LengthCI->getType())); - return AMI; -} - -Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { - unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, &AC, &DT); - unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, &AC, &DT); - unsigned MinAlign = std::min(DstAlign, SrcAlign); - unsigned CopyAlign = MI->getAlignment(); - - if (CopyAlign < MinAlign) { - MI->setAlignment(ConstantInt::get(MI->getAlignmentType(), MinAlign, false)); + unsigned SrcAlign = getKnownAlignment(MI->getRawSource(), DL, MI, &AC, &DT); + unsigned CopySrcAlign = MI->getSourceAlignment(); + if (CopySrcAlign < SrcAlign) { + MI->setSourceAlignment(SrcAlign); return MI; } // If MemCpyInst length is 1/2/4/8 bytes then replace memcpy with // load/store. - ConstantInt *MemOpLength = dyn_cast<ConstantInt>(MI->getArgOperand(2)); + ConstantInt *MemOpLength = dyn_cast<ConstantInt>(MI->getLength()); if (!MemOpLength) return nullptr; // Source and destination pointer types are always "i8*" for intrinsic. See @@ -222,7 +149,9 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { // If the memcpy has metadata describing the members, see if we can get the // TBAA tag describing our copy. 
MDNode *CopyMD = nullptr; - if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { + if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa)) { + CopyMD = M; + } else if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { if (M->getNumOperands() == 3 && M->getOperand(0) && mdconst::hasa<ConstantInt>(M->getOperand(0)) && mdconst::extract<ConstantInt>(M->getOperand(0))->isZero() && @@ -234,15 +163,11 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { CopyMD = cast<MDNode>(M->getOperand(2)); } - // If the memcpy/memmove provides better alignment info than we can - // infer, use it. - SrcAlign = std::max(SrcAlign, CopyAlign); - DstAlign = std::max(DstAlign, CopyAlign); - Value *Src = Builder.CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy); Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy); - LoadInst *L = Builder.CreateLoad(Src, MI->isVolatile()); - L->setAlignment(SrcAlign); + LoadInst *L = Builder.CreateLoad(Src); + // Alignment from the mem intrinsic will be better, so use it. + L->setAlignment(CopySrcAlign); if (CopyMD) L->setMetadata(LLVMContext::MD_tbaa, CopyMD); MDNode *LoopMemParallelMD = @@ -250,23 +175,34 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { if (LoopMemParallelMD) L->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); - StoreInst *S = Builder.CreateStore(L, Dest, MI->isVolatile()); - S->setAlignment(DstAlign); + StoreInst *S = Builder.CreateStore(L, Dest); + // Alignment from the mem intrinsic will be better, so use it. + S->setAlignment(CopyDstAlign); if (CopyMD) S->setMetadata(LLVMContext::MD_tbaa, CopyMD); if (LoopMemParallelMD) S->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); + if (auto *MT = dyn_cast<MemTransferInst>(MI)) { + // non-atomics can be volatile + L->setVolatile(MT->isVolatile()); + S->setVolatile(MT->isVolatile()); + } + if (isa<AtomicMemTransferInst>(MI)) { + // atomics have to be unordered + L->setOrdering(AtomicOrdering::Unordered); + S->setOrdering(AtomicOrdering::Unordered); + } + // Set the size of the copy to 0, it will be deleted on the next iteration. - MI->setArgOperand(2, Constant::getNullValue(MemOpLength->getType())); + MI->setLength(Constant::getNullValue(MemOpLength->getType())); return MI; } -Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { +Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { unsigned Alignment = getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT); - if (MI->getAlignment() < Alignment) { - MI->setAlignment(ConstantInt::get(MI->getAlignmentType(), - Alignment, false)); + if (MI->getDestAlignment() < Alignment) { + MI->setDestAlignment(Alignment); return MI; } @@ -276,7 +212,7 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { if (!LenC || !FillC || !FillC->getType()->isIntegerTy(8)) return nullptr; uint64_t Len = LenC->getLimitedValue(); - Alignment = MI->getAlignment(); + Alignment = MI->getDestAlignment(); assert(Len && "0-sized memory setting should be removed already."); // memset(s,c,n) -> store s, c (for n=1,2,4,8) @@ -296,6 +232,8 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { StoreInst *S = Builder.CreateStore(ConstantInt::get(ITy, Fill), Dest, MI->isVolatile()); S->setAlignment(Alignment); + if (isa<AtomicMemSetInst>(MI)) + S->setOrdering(AtomicOrdering::Unordered); // Set the size of the copy to 0, it will be deleted on the next iteration. 
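// Illustrative example (assumed values): for a 4-byte memset with fill byte
// 0xAB, the fold above stores the byte splatted across the full width,
// roughly
//   call void @llvm.memset.p0i8.i32(i8* align 4 %p, i8 171, i32 4, i1 false)
//     -->  store i32 0xABABABAB, i32* %p.cast, align 4
// and the element-atomic variant gets the same store marked 'unordered'.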
MI->setLength(Constant::getNullValue(LenC->getType())); @@ -563,55 +501,6 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } -static Value *simplifyX86muldq(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - Value *Arg0 = II.getArgOperand(0); - Value *Arg1 = II.getArgOperand(1); - Type *ResTy = II.getType(); - assert(Arg0->getType()->getScalarSizeInBits() == 32 && - Arg1->getType()->getScalarSizeInBits() == 32 && - ResTy->getScalarSizeInBits() == 64 && "Unexpected muldq/muludq types"); - - // muldq/muludq(undef, undef) -> zero (matches generic mul behavior) - if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1)) - return ConstantAggregateZero::get(ResTy); - - // Constant folding. - // PMULDQ = (mul(vXi64 sext(shuffle<0,2,..>(Arg0)), - // vXi64 sext(shuffle<0,2,..>(Arg1)))) - // PMULUDQ = (mul(vXi64 zext(shuffle<0,2,..>(Arg0)), - // vXi64 zext(shuffle<0,2,..>(Arg1)))) - if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1)) - return nullptr; - - unsigned NumElts = ResTy->getVectorNumElements(); - assert(Arg0->getType()->getVectorNumElements() == (2 * NumElts) && - Arg1->getType()->getVectorNumElements() == (2 * NumElts) && - "Unexpected muldq/muludq types"); - - unsigned IntrinsicID = II.getIntrinsicID(); - bool IsSigned = (Intrinsic::x86_sse41_pmuldq == IntrinsicID || - Intrinsic::x86_avx2_pmul_dq == IntrinsicID || - Intrinsic::x86_avx512_pmul_dq_512 == IntrinsicID); - - SmallVector<unsigned, 16> ShuffleMask; - for (unsigned i = 0; i != NumElts; ++i) - ShuffleMask.push_back(i * 2); - - auto *LHS = Builder.CreateShuffleVector(Arg0, Arg0, ShuffleMask); - auto *RHS = Builder.CreateShuffleVector(Arg1, Arg1, ShuffleMask); - - if (IsSigned) { - LHS = Builder.CreateSExt(LHS, ResTy); - RHS = Builder.CreateSExt(RHS, ResTy); - } else { - LHS = Builder.CreateZExt(LHS, ResTy); - RHS = Builder.CreateZExt(RHS, ResTy); - } - - return Builder.CreateMul(LHS, RHS); -} - static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { Value *Arg0 = II.getArgOperand(0); Value *Arg1 = II.getArgOperand(1); @@ -687,6 +576,105 @@ static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { return ConstantVector::get(Vals); } +// Replace X86-specific intrinsics with generic floor-ceil where applicable. 
+static Value *simplifyX86round(IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + ConstantInt *Arg = nullptr; + Intrinsic::ID IntrinsicID = II.getIntrinsicID(); + + if (IntrinsicID == Intrinsic::x86_sse41_round_ss || + IntrinsicID == Intrinsic::x86_sse41_round_sd) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(2)); + else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(4)); + else + Arg = dyn_cast<ConstantInt>(II.getArgOperand(1)); + if (!Arg) + return nullptr; + unsigned RoundControl = Arg->getZExtValue(); + + Arg = nullptr; + unsigned SAE = 0; + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(4)); + else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(5)); + else + SAE = 4; + if (!SAE) { + if (!Arg) + return nullptr; + SAE = Arg->getZExtValue(); + } + + if (SAE != 4 || (RoundControl != 2 /*ceil*/ && RoundControl != 1 /*floor*/)) + return nullptr; + + Value *Src, *Dst, *Mask; + bool IsScalar = false; + if (IntrinsicID == Intrinsic::x86_sse41_round_ss || + IntrinsicID == Intrinsic::x86_sse41_round_sd || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { + IsScalar = true; + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { + Mask = II.getArgOperand(3); + Value *Zero = Constant::getNullValue(Mask->getType()); + Mask = Builder.CreateAnd(Mask, 1); + Mask = Builder.CreateICmp(ICmpInst::ICMP_NE, Mask, Zero); + Dst = II.getArgOperand(2); + } else + Dst = II.getArgOperand(0); + Src = Builder.CreateExtractElement(II.getArgOperand(1), (uint64_t)0); + } else { + Src = II.getArgOperand(0); + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_128 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_256 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_128 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_256 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) { + Dst = II.getArgOperand(2); + Mask = II.getArgOperand(3); + } else { + Dst = Src; + Mask = ConstantInt::getAllOnesValue( + Builder.getIntNTy(Src->getType()->getVectorNumElements())); + } + } + + Intrinsic::ID ID = (RoundControl == 2) ? 
Intrinsic::ceil : Intrinsic::floor; + Value *Res = Builder.CreateIntrinsic(ID, {Src}, &II); + if (!IsScalar) { + if (auto *C = dyn_cast<Constant>(Mask)) + if (C->isAllOnesValue()) + return Res; + auto *MaskTy = VectorType::get( + Builder.getInt1Ty(), cast<IntegerType>(Mask->getType())->getBitWidth()); + Mask = Builder.CreateBitCast(Mask, MaskTy); + unsigned Width = Src->getType()->getVectorNumElements(); + if (MaskTy->getVectorNumElements() > Width) { + uint32_t Indices[4]; + for (unsigned i = 0; i != Width; ++i) + Indices[i] = i; + Mask = Builder.CreateShuffleVector(Mask, Mask, + makeArrayRef(Indices, Width)); + } + return Builder.CreateSelect(Mask, Res, Dst); + } + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { + Dst = Builder.CreateExtractElement(Dst, (uint64_t)0); + Res = Builder.CreateSelect(Mask, Res, Dst); + Dst = II.getArgOperand(0); + } + return Builder.CreateInsertElement(Dst, Res, (uint64_t)0); +} + static Value *simplifyX86movmsk(const IntrinsicInst &II) { Value *Arg = II.getArgOperand(0); Type *ResTy = II.getType(); @@ -1145,36 +1133,6 @@ static Value *simplifyX86vpcom(const IntrinsicInst &II, return nullptr; } -// Emit a select instruction and appropriate bitcasts to help simplify -// masked intrinsics. -static Value *emitX86MaskSelect(Value *Mask, Value *Op0, Value *Op1, - InstCombiner::BuilderTy &Builder) { - unsigned VWidth = Op0->getType()->getVectorNumElements(); - - // If the mask is all ones we don't need the select. But we need to check - // only the bit thats will be used in case VWidth is less than 8. - if (auto *C = dyn_cast<ConstantInt>(Mask)) - if (C->getValue().zextOrTrunc(VWidth).isAllOnesValue()) - return Op0; - - auto *MaskTy = VectorType::get(Builder.getInt1Ty(), - cast<IntegerType>(Mask->getType())->getBitWidth()); - Mask = Builder.CreateBitCast(Mask, MaskTy); - - // If we have less than 8 elements, then the starting mask was an i8 and - // we need to extract down to the right number of elements. - if (VWidth < 8) { - uint32_t Indices[4]; - for (unsigned i = 0; i != VWidth; ++i) - Indices[i] = i; - Mask = Builder.CreateShuffleVector(Mask, Mask, - makeArrayRef(Indices, VWidth), - "extract"); - } - - return Builder.CreateSelect(Mask, Op0, Op1); -} - static Value *simplifyMinnumMaxnum(const IntrinsicInst &II) { Value *Arg0 = II.getArgOperand(0); Value *Arg1 = II.getArgOperand(1); @@ -1308,6 +1266,40 @@ static Instruction *simplifyMaskedGather(IntrinsicInst &II, InstCombiner &IC) { return nullptr; } +/// This function transforms launder.invariant.group and strip.invariant.group +/// like: +/// launder(launder(%x)) -> launder(%x) (the result is not the argument) +/// launder(strip(%x)) -> launder(%x) +/// strip(strip(%x)) -> strip(%x) (the result is not the argument) +/// strip(launder(%x)) -> strip(%x) +/// This is legal because it preserves the most recent information about +/// the presence or absence of invariant.group. +static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II, + InstCombiner &IC) { + auto *Arg = II.getArgOperand(0); + auto *StrippedArg = Arg->stripPointerCasts(); + auto *StrippedInvariantGroupsArg = Arg->stripPointerCastsAndInvariantGroups(); + if (StrippedArg == StrippedInvariantGroupsArg) + return nullptr; // No launders/strips to remove. 
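// A concrete instance of the folds listed above (illustrative IR, assumed
// pointer types):
//   %a = call i8* @llvm.launder.invariant.group.p0i8(i8* %x)
//   %b = call i8* @llvm.strip.invariant.group.p0i8(i8* %a)
// strip(launder(%x)) carries no more information than strip(%x), so %b is
// rebuilt as a single strip of %x; the code below emits that replacement and
// inserts addrspacecast/bitcast instructions if the stripped pointer's type
// differs from the intrinsic's result type.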
+ + Value *Result = nullptr; + + if (II.getIntrinsicID() == Intrinsic::launder_invariant_group) + Result = IC.Builder.CreateLaunderInvariantGroup(StrippedInvariantGroupsArg); + else if (II.getIntrinsicID() == Intrinsic::strip_invariant_group) + Result = IC.Builder.CreateStripInvariantGroup(StrippedInvariantGroupsArg); + else + llvm_unreachable( + "simplifyInvariantGroupIntrinsic only handles launder and strip"); + if (Result->getType()->getPointerAddressSpace() != + II.getType()->getPointerAddressSpace()) + Result = IC.Builder.CreateAddrSpaceCast(Result, II.getType()); + if (Result->getType() != II.getType()) + Result = IC.Builder.CreateBitCast(Result, II.getType()); + + return cast<Instruction>(Result); +} + static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) { // If the mask is all zeros, a scatter does nothing. auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); @@ -1498,6 +1490,68 @@ static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1, return maxnum(Src0, Src1); } +/// Convert a table lookup to shufflevector if the mask is constant. +/// This could benefit tbl1 if the mask is { 7,6,5,4,3,2,1,0 }, in +/// which case we could lower the shufflevector with rev64 instructions +/// as it's actually a byte reverse. +static Value *simplifyNeonTbl1(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + // Bail out if the mask is not a constant. + auto *C = dyn_cast<Constant>(II.getArgOperand(1)); + if (!C) + return nullptr; + + auto *VecTy = cast<VectorType>(II.getType()); + unsigned NumElts = VecTy->getNumElements(); + + // Only perform this transformation for <8 x i8> vector types. + if (!VecTy->getElementType()->isIntegerTy(8) || NumElts != 8) + return nullptr; + + uint32_t Indexes[8]; + + for (unsigned I = 0; I < NumElts; ++I) { + Constant *COp = C->getAggregateElement(I); + + if (!COp || !isa<ConstantInt>(COp)) + return nullptr; + + Indexes[I] = cast<ConstantInt>(COp)->getLimitedValue(); + + // Make sure the mask indices are in range. + if (Indexes[I] >= NumElts) + return nullptr; + } + + auto *ShuffleMask = ConstantDataVector::get(II.getContext(), + makeArrayRef(Indexes)); + auto *V1 = II.getArgOperand(0); + auto *V2 = Constant::getNullValue(V1->getType()); + return Builder.CreateShuffleVector(V1, V2, ShuffleMask); +} + +/// Convert a vector load intrinsic into a simple llvm load instruction. +/// This is beneficial when the underlying object being addressed comes +/// from a constant, since we get constant-folding for free. +static Value *simplifyNeonVld1(const IntrinsicInst &II, + unsigned MemAlign, + InstCombiner::BuilderTy &Builder) { + auto *IntrAlign = dyn_cast<ConstantInt>(II.getArgOperand(1)); + + if (!IntrAlign) + return nullptr; + + unsigned Alignment = IntrAlign->getLimitedValue() < MemAlign ? + MemAlign : IntrAlign->getLimitedValue(); + + if (!isPowerOf2_32(Alignment)) + return nullptr; + + auto *BCastInst = Builder.CreateBitCast(II.getArgOperand(0), + PointerType::get(II.getType(), 0)); + return Builder.CreateAlignedLoad(BCastInst, Alignment); +} + // Returns true iff the 2 intrinsics have the same operands, limiting the // comparison to the first NumOperands. static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, @@ -1820,7 +1874,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Intrinsics cannot occur in an invoke, so handle them here instead of in // visitCallSite. 
- if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) { + if (auto *MI = dyn_cast<AnyMemIntrinsic>(II)) { bool Changed = false; // memmove/cpy/set of zero bytes is a noop. @@ -1837,17 +1891,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } // No other transformations apply to volatile transfers. - if (MI->isVolatile()) - return nullptr; + if (auto *M = dyn_cast<MemIntrinsic>(MI)) + if (M->isVolatile()) + return nullptr; // If we have a memmove and the source operation is a constant global, // then the source and dest pointers can't alias, so we can change this // into a call to memcpy. - if (MemMoveInst *MMI = dyn_cast<MemMoveInst>(MI)) { + if (auto *MMI = dyn_cast<AnyMemMoveInst>(MI)) { if (GlobalVariable *GVSrc = dyn_cast<GlobalVariable>(MMI->getSource())) if (GVSrc->isConstant()) { Module *M = CI.getModule(); - Intrinsic::ID MemCpyID = Intrinsic::memcpy; + Intrinsic::ID MemCpyID = + isa<AtomicMemMoveInst>(MMI) + ? Intrinsic::memcpy_element_unordered_atomic + : Intrinsic::memcpy; Type *Tys[3] = { CI.getArgOperand(0)->getType(), CI.getArgOperand(1)->getType(), CI.getArgOperand(2)->getType() }; @@ -1856,7 +1914,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } } - if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { + if (AnyMemTransferInst *MTI = dyn_cast<AnyMemTransferInst>(MI)) { // memmove(x,x,size) -> noop. if (MTI->getSource() == MTI->getDest()) return eraseInstFromFunction(CI); @@ -1864,26 +1922,17 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // If we can determine a pointer alignment that is bigger than currently // set, update the alignment. - if (isa<MemTransferInst>(MI)) { - if (Instruction *I = SimplifyMemTransfer(MI)) + if (auto *MTI = dyn_cast<AnyMemTransferInst>(MI)) { + if (Instruction *I = SimplifyAnyMemTransfer(MTI)) return I; - } else if (MemSetInst *MSI = dyn_cast<MemSetInst>(MI)) { - if (Instruction *I = SimplifyMemSet(MSI)) + } else if (auto *MSI = dyn_cast<AnyMemSetInst>(MI)) { + if (Instruction *I = SimplifyAnyMemSet(MSI)) return I; } if (Changed) return II; } - if (auto *AMI = dyn_cast<AtomicMemCpyInst>(II)) { - if (Constant *C = dyn_cast<Constant>(AMI->getLength())) - if (C->isNullValue()) - return eraseInstFromFunction(*AMI); - - if (Instruction *I = SimplifyElementUnorderedAtomicMemCpy(AMI)) - return I; - } - if (Instruction *I = SimplifyNVVMIntrinsic(II, *this)) return I; @@ -1925,7 +1974,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return simplifyMaskedGather(*II, *this); case Intrinsic::masked_scatter: return simplifyMaskedScatter(*II, *this); - + case Intrinsic::launder_invariant_group: + case Intrinsic::strip_invariant_group: + if (auto *SkippedBarrier = simplifyInvariantGroupIntrinsic(*II, *this)) + return replaceInstUsesWith(*II, SkippedBarrier); + break; case Intrinsic::powi: if (ConstantInt *Power = dyn_cast<ConstantInt>(II->getArgOperand(1))) { // 0 and 1 are handled in instsimplify @@ -1991,8 +2044,24 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->setArgOperand(1, Arg0); return II; } + + // FIXME: Simplifications should be in instsimplify. 
if (Value *V = simplifyMinnumMaxnum(*II)) return replaceInstUsesWith(*II, V); + + Value *X, *Y; + if (match(Arg0, m_FNeg(m_Value(X))) && match(Arg1, m_FNeg(m_Value(Y))) && + (Arg0->hasOneUse() || Arg1->hasOneUse())) { + // If both operands are negated, invert the call and negate the result: + // minnum(-X, -Y) --> -(maxnum(X, Y)) + // maxnum(-X, -Y) --> -(minnum(X, Y)) + Intrinsic::ID NewIID = II->getIntrinsicID() == Intrinsic::maxnum ? + Intrinsic::minnum : Intrinsic::maxnum; + Value *NewCall = Builder.CreateIntrinsic(NewIID, { X, Y }, II); + Instruction *FNeg = BinaryOperator::CreateFNeg(NewCall); + FNeg->copyIRFlags(II); + return FNeg; + } break; } case Intrinsic::fmuladd: { @@ -2013,37 +2082,34 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Src0 = II->getArgOperand(0); Value *Src1 = II->getArgOperand(1); - // Canonicalize constants into the RHS. + // Canonicalize constant multiply operand to Src1. if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { II->setArgOperand(0, Src1); II->setArgOperand(1, Src0); std::swap(Src0, Src1); } - Value *LHS = nullptr; - Value *RHS = nullptr; - // fma fneg(x), fneg(y), z -> fma x, y, z - if (match(Src0, m_FNeg(m_Value(LHS))) && - match(Src1, m_FNeg(m_Value(RHS)))) { - II->setArgOperand(0, LHS); - II->setArgOperand(1, RHS); + Value *X, *Y; + if (match(Src0, m_FNeg(m_Value(X))) && match(Src1, m_FNeg(m_Value(Y)))) { + II->setArgOperand(0, X); + II->setArgOperand(1, Y); return II; } // fma fabs(x), fabs(x), z -> fma x, x, z - if (match(Src0, m_Intrinsic<Intrinsic::fabs>(m_Value(LHS))) && - match(Src1, m_Intrinsic<Intrinsic::fabs>(m_Value(RHS))) && LHS == RHS) { - II->setArgOperand(0, LHS); - II->setArgOperand(1, RHS); + if (match(Src0, m_FAbs(m_Value(X))) && + match(Src1, m_FAbs(m_Specific(X)))) { + II->setArgOperand(0, X); + II->setArgOperand(1, X); return II; } // fma x, 1, z -> fadd x, z if (match(Src1, m_FPOne())) { - Instruction *RI = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2)); - RI->copyFastMathFlags(II); - return RI; + auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2)); + FAdd->copyFastMathFlags(II); + return FAdd; } break; @@ -2067,17 +2133,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::rint: case Intrinsic::trunc: { Value *ExtSrc; - if (match(II->getArgOperand(0), m_FPExt(m_Value(ExtSrc))) && - II->getArgOperand(0)->hasOneUse()) { - // fabs (fpext x) -> fpext (fabs x) - Value *F = Intrinsic::getDeclaration(II->getModule(), II->getIntrinsicID(), - { ExtSrc->getType() }); - CallInst *NewFabs = Builder.CreateCall(F, ExtSrc); - NewFabs->copyFastMathFlags(II); - NewFabs->takeName(II); - return new FPExtInst(NewFabs, II->getType()); + if (match(II->getArgOperand(0), m_OneUse(m_FPExt(m_Value(ExtSrc))))) { + // Narrow the call: intrinsic (fpext x) -> fpext (intrinsic x) + Value *NarrowII = Builder.CreateIntrinsic(II->getIntrinsicID(), + { ExtSrc }, II); + return new FPExtInst(NarrowII, II->getType()); } - break; } case Intrinsic::cos: @@ -2085,7 +2146,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *SrcSrc; Value *Src = II->getArgOperand(0); if (match(Src, m_FNeg(m_Value(SrcSrc))) || - match(Src, m_Intrinsic<Intrinsic::fabs>(m_Value(SrcSrc)))) { + match(Src, m_FAbs(m_Value(SrcSrc)))) { // cos(-x) -> cos(x) // cos(fabs(x)) -> cos(x) II->setArgOperand(0, SrcSrc); @@ -2298,6 +2359,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::x86_sse41_round_ps: + case Intrinsic::x86_sse41_round_pd: + case 
Intrinsic::x86_avx_round_ps_256: + case Intrinsic::x86_avx_round_pd_256: + case Intrinsic::x86_avx512_mask_rndscale_ps_128: + case Intrinsic::x86_avx512_mask_rndscale_ps_256: + case Intrinsic::x86_avx512_mask_rndscale_ps_512: + case Intrinsic::x86_avx512_mask_rndscale_pd_128: + case Intrinsic::x86_avx512_mask_rndscale_pd_256: + case Intrinsic::x86_avx512_mask_rndscale_pd_512: + case Intrinsic::x86_avx512_mask_rndscale_ss: + case Intrinsic::x86_avx512_mask_rndscale_sd: + if (Value *V = simplifyX86round(*II, Builder)) + return replaceInstUsesWith(*II, V); + break; + case Intrinsic::x86_mmx_pmovmskb: case Intrinsic::x86_sse_movmsk_ps: case Intrinsic::x86_sse2_movmsk_pd: @@ -2355,16 +2432,16 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return II; break; } - case Intrinsic::x86_avx512_mask_cmp_pd_128: - case Intrinsic::x86_avx512_mask_cmp_pd_256: - case Intrinsic::x86_avx512_mask_cmp_pd_512: - case Intrinsic::x86_avx512_mask_cmp_ps_128: - case Intrinsic::x86_avx512_mask_cmp_ps_256: - case Intrinsic::x86_avx512_mask_cmp_ps_512: { + case Intrinsic::x86_avx512_cmp_pd_128: + case Intrinsic::x86_avx512_cmp_pd_256: + case Intrinsic::x86_avx512_cmp_pd_512: + case Intrinsic::x86_avx512_cmp_ps_128: + case Intrinsic::x86_avx512_cmp_ps_256: + case Intrinsic::x86_avx512_cmp_ps_512: { // Folding cmp(sub(a,b),0) -> cmp(a,b) and cmp(0,sub(a,b)) -> cmp(b,a) Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); - bool Arg0IsZero = match(Arg0, m_Zero()); + bool Arg0IsZero = match(Arg0, m_PosZeroFP()); if (Arg0IsZero) std::swap(Arg0, Arg1); Value *A, *B; @@ -2376,7 +2453,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // The compare intrinsic uses the above assumptions and therefore // doesn't require additional flags. if ((match(Arg0, m_OneUse(m_FSub(m_Value(A), m_Value(B)))) && - match(Arg1, m_Zero()) && isa<Instruction>(Arg0) && + match(Arg1, m_PosZeroFP()) && isa<Instruction>(Arg0) && cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) { if (Arg0IsZero) std::swap(A, B); @@ -2387,17 +2464,17 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::x86_avx512_mask_add_ps_512: - case Intrinsic::x86_avx512_mask_div_ps_512: - case Intrinsic::x86_avx512_mask_mul_ps_512: - case Intrinsic::x86_avx512_mask_sub_ps_512: - case Intrinsic::x86_avx512_mask_add_pd_512: - case Intrinsic::x86_avx512_mask_div_pd_512: - case Intrinsic::x86_avx512_mask_mul_pd_512: - case Intrinsic::x86_avx512_mask_sub_pd_512: + case Intrinsic::x86_avx512_add_ps_512: + case Intrinsic::x86_avx512_div_ps_512: + case Intrinsic::x86_avx512_mul_ps_512: + case Intrinsic::x86_avx512_sub_ps_512: + case Intrinsic::x86_avx512_add_pd_512: + case Intrinsic::x86_avx512_div_pd_512: + case Intrinsic::x86_avx512_mul_pd_512: + case Intrinsic::x86_avx512_sub_pd_512: // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular // IR operations. 
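// For example (assumed operands, illustrative only): with the rounding-mode
// argument equal to 4,
//   %r = call <16 x float> @llvm.x86.avx512.add.ps.512(<16 x float> %a, <16 x float> %b, i32 4)
// performs an ordinary IEEE add in the current rounding mode, so it can be
// rebuilt as
//   %r = fadd <16 x float> %a, %b
// which is exactly what the code below does for add/sub/mul/div.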
- if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) { + if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(2))) { if (R->getValue() == 4) { Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); @@ -2405,27 +2482,24 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *V; switch (II->getIntrinsicID()) { default: llvm_unreachable("Case stmts out of sync!"); - case Intrinsic::x86_avx512_mask_add_ps_512: - case Intrinsic::x86_avx512_mask_add_pd_512: + case Intrinsic::x86_avx512_add_ps_512: + case Intrinsic::x86_avx512_add_pd_512: V = Builder.CreateFAdd(Arg0, Arg1); break; - case Intrinsic::x86_avx512_mask_sub_ps_512: - case Intrinsic::x86_avx512_mask_sub_pd_512: + case Intrinsic::x86_avx512_sub_ps_512: + case Intrinsic::x86_avx512_sub_pd_512: V = Builder.CreateFSub(Arg0, Arg1); break; - case Intrinsic::x86_avx512_mask_mul_ps_512: - case Intrinsic::x86_avx512_mask_mul_pd_512: + case Intrinsic::x86_avx512_mul_ps_512: + case Intrinsic::x86_avx512_mul_pd_512: V = Builder.CreateFMul(Arg0, Arg1); break; - case Intrinsic::x86_avx512_mask_div_ps_512: - case Intrinsic::x86_avx512_mask_div_pd_512: + case Intrinsic::x86_avx512_div_ps_512: + case Intrinsic::x86_avx512_div_pd_512: V = Builder.CreateFDiv(Arg0, Arg1); break; } - // Create a select for the masking. - V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), - Builder); return replaceInstUsesWith(*II, V); } } @@ -2499,32 +2573,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx512_mask_min_ss_round: case Intrinsic::x86_avx512_mask_max_sd_round: case Intrinsic::x86_avx512_mask_min_sd_round: - case Intrinsic::x86_avx512_mask_vfmadd_ss: - case Intrinsic::x86_avx512_mask_vfmadd_sd: - case Intrinsic::x86_avx512_maskz_vfmadd_ss: - case Intrinsic::x86_avx512_maskz_vfmadd_sd: - case Intrinsic::x86_avx512_mask3_vfmadd_ss: - case Intrinsic::x86_avx512_mask3_vfmadd_sd: - case Intrinsic::x86_avx512_mask3_vfmsub_ss: - case Intrinsic::x86_avx512_mask3_vfmsub_sd: - case Intrinsic::x86_avx512_mask3_vfnmsub_ss: - case Intrinsic::x86_avx512_mask3_vfnmsub_sd: - case Intrinsic::x86_fma_vfmadd_ss: - case Intrinsic::x86_fma_vfmsub_ss: - case Intrinsic::x86_fma_vfnmadd_ss: - case Intrinsic::x86_fma_vfnmsub_ss: - case Intrinsic::x86_fma_vfmadd_sd: - case Intrinsic::x86_fma_vfmsub_sd: - case Intrinsic::x86_fma_vfnmadd_sd: - case Intrinsic::x86_fma_vfnmsub_sd: case Intrinsic::x86_sse_cmp_ss: case Intrinsic::x86_sse_min_ss: case Intrinsic::x86_sse_max_ss: case Intrinsic::x86_sse2_cmp_sd: case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: - case Intrinsic::x86_sse41_round_ss: - case Intrinsic::x86_sse41_round_sd: case Intrinsic::x86_xop_vfrcz_ss: case Intrinsic::x86_xop_vfrcz_sd: { unsigned VWidth = II->getType()->getVectorNumElements(); @@ -2537,6 +2591,19 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::x86_sse41_round_ss: + case Intrinsic::x86_sse41_round_sd: { + unsigned VWidth = II->getType()->getVectorNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { + if (V != II) + return replaceInstUsesWith(*II, V); + return II; + } else if (Value *V = simplifyX86round(*II, Builder)) + return replaceInstUsesWith(*II, V); + break; + } // Constant fold ashr( <A x Bi>, Ci ). // Constant fold lshr( <A x Bi>, Ci ). 
@@ -2647,26 +2714,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); break; - case Intrinsic::x86_sse2_pmulu_dq: - case Intrinsic::x86_sse41_pmuldq: - case Intrinsic::x86_avx2_pmul_dq: - case Intrinsic::x86_avx2_pmulu_dq: - case Intrinsic::x86_avx512_pmul_dq_512: - case Intrinsic::x86_avx512_pmulu_dq_512: { - if (Value *V = simplifyX86muldq(*II, Builder)) - return replaceInstUsesWith(*II, V); - - unsigned VWidth = II->getType()->getVectorNumElements(); - APInt UndefElts(VWidth, 0); - APInt DemandedElts = APInt::getAllOnesValue(VWidth); - if (Value *V = SimplifyDemandedVectorElts(II, DemandedElts, UndefElts)) { - if (V != II) - return replaceInstUsesWith(*II, V); - return II; - } - break; - } - case Intrinsic::x86_sse2_packssdw_128: case Intrinsic::x86_sse2_packsswb_128: case Intrinsic::x86_avx2_packssdw: @@ -2687,7 +2734,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); break; - case Intrinsic::x86_pclmulqdq: { + case Intrinsic::x86_pclmulqdq: + case Intrinsic::x86_pclmulqdq_256: + case Intrinsic::x86_pclmulqdq_512: { if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(2))) { unsigned Imm = C->getZExtValue(); @@ -2695,27 +2744,28 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); unsigned VWidth = Arg0->getType()->getVectorNumElements(); - APInt DemandedElts(VWidth, 0); APInt UndefElts1(VWidth, 0); - DemandedElts = (Imm & 0x01) ? 2 : 1; - if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts, + APInt DemandedElts1 = APInt::getSplat(VWidth, + APInt(2, (Imm & 0x01) ? 2 : 1)); + if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts1, UndefElts1)) { II->setArgOperand(0, V); MadeChange = true; } APInt UndefElts2(VWidth, 0); - DemandedElts = (Imm & 0x10) ? 2 : 1; - if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts, + APInt DemandedElts2 = APInt::getSplat(VWidth, + APInt(2, (Imm & 0x10) ? 2 : 1)); + if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts2, UndefElts2)) { II->setArgOperand(1, V); MadeChange = true; } - // If both input elements are undef, the result is undef. - if (UndefElts1[(Imm & 0x01) ? 1 : 0] || - UndefElts2[(Imm & 0x10) ? 1 : 0]) + // If either input elements are undef, the result is zero. 
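// (Illustrative, assumed values: for the 128-bit form VWidth == 2, so
// APInt::getSplat(2, APInt(2, 1)) == 0b01 demands only element 0 of an
// operand when the corresponding immediate bit is clear, while
// APInt::getSplat(2, APInt(2, 2)) == 0b10 demands only element 1 when it is
// set; for the 256/512-bit forms the same two-bit pattern repeats once per
// 128-bit lane. If every demanded element of either operand is undef, the
// product is known to be zero, which is what the check below returns.)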
+ if (DemandedElts1.isSubsetOf(UndefElts1) || + DemandedElts2.isSubsetOf(UndefElts2)) return replaceInstUsesWith(*II, ConstantAggregateZero::get(II->getType())); @@ -2916,32 +2966,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_permd: case Intrinsic::x86_avx2_permps: + case Intrinsic::x86_avx512_permvar_df_256: + case Intrinsic::x86_avx512_permvar_df_512: + case Intrinsic::x86_avx512_permvar_di_256: + case Intrinsic::x86_avx512_permvar_di_512: + case Intrinsic::x86_avx512_permvar_hi_128: + case Intrinsic::x86_avx512_permvar_hi_256: + case Intrinsic::x86_avx512_permvar_hi_512: + case Intrinsic::x86_avx512_permvar_qi_128: + case Intrinsic::x86_avx512_permvar_qi_256: + case Intrinsic::x86_avx512_permvar_qi_512: + case Intrinsic::x86_avx512_permvar_sf_512: + case Intrinsic::x86_avx512_permvar_si_512: if (Value *V = simplifyX86vpermv(*II, Builder)) return replaceInstUsesWith(*II, V); break; - case Intrinsic::x86_avx512_mask_permvar_df_256: - case Intrinsic::x86_avx512_mask_permvar_df_512: - case Intrinsic::x86_avx512_mask_permvar_di_256: - case Intrinsic::x86_avx512_mask_permvar_di_512: - case Intrinsic::x86_avx512_mask_permvar_hi_128: - case Intrinsic::x86_avx512_mask_permvar_hi_256: - case Intrinsic::x86_avx512_mask_permvar_hi_512: - case Intrinsic::x86_avx512_mask_permvar_qi_128: - case Intrinsic::x86_avx512_mask_permvar_qi_256: - case Intrinsic::x86_avx512_mask_permvar_qi_512: - case Intrinsic::x86_avx512_mask_permvar_sf_256: - case Intrinsic::x86_avx512_mask_permvar_sf_512: - case Intrinsic::x86_avx512_mask_permvar_si_256: - case Intrinsic::x86_avx512_mask_permvar_si_512: - if (Value *V = simplifyX86vpermv(*II, Builder)) { - // We simplified the permuting, now create a select for the masking. - V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), - Builder); - return replaceInstUsesWith(*II, V); - } - break; - case Intrinsic::x86_avx_maskload_ps: case Intrinsic::x86_avx_maskload_pd: case Intrinsic::x86_avx_maskload_ps_256: @@ -3042,7 +3082,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; - case Intrinsic::arm_neon_vld1: + case Intrinsic::arm_neon_vld1: { + unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), + DL, II, &AC, &DT); + if (Value *V = simplifyNeonVld1(*II, MemAlign, Builder)) + return replaceInstUsesWith(*II, V); + break; + } + case Intrinsic::arm_neon_vld2: case Intrinsic::arm_neon_vld3: case Intrinsic::arm_neon_vld4: @@ -3069,6 +3116,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::arm_neon_vtbl1: + case Intrinsic::aarch64_neon_tbl1: + if (Value *V = simplifyNeonTbl1(*II, Builder)) + return replaceInstUsesWith(*II, V); + break; + case Intrinsic::arm_neon_vmulls: case Intrinsic::arm_neon_vmullu: case Intrinsic::aarch64_neon_smull: @@ -3107,6 +3160,23 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::arm_neon_aesd: + case Intrinsic::arm_neon_aese: + case Intrinsic::aarch64_crypto_aesd: + case Intrinsic::aarch64_crypto_aese: { + Value *DataArg = II->getArgOperand(0); + Value *KeyArg = II->getArgOperand(1); + + // Try to use the builtin XOR in AESE and AESD to eliminate a prior XOR + Value *Data, *Key; + if (match(KeyArg, m_ZeroInt()) && + match(DataArg, m_Xor(m_Value(Data), m_Value(Key)))) { + II->setArgOperand(0, Data); + II->setArgOperand(1, Key); + return II; + } + break; + } case Intrinsic::amdgcn_rcp: { Value *Src = II->getArgOperand(0); @@ -3264,6 +3334,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) 
{ break; } + case Intrinsic::amdgcn_cvt_pknorm_i16: + case Intrinsic::amdgcn_cvt_pknorm_u16: + case Intrinsic::amdgcn_cvt_pk_i16: + case Intrinsic::amdgcn_cvt_pk_u16: { + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + + if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); + + break; + } case Intrinsic::amdgcn_ubfe: case Intrinsic::amdgcn_sbfe: { // Decompose simple cases into standard shifts. @@ -3370,6 +3452,24 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Src1 = II->getArgOperand(1); Value *Src2 = II->getArgOperand(2); + // Checking for NaN before canonicalization provides better fidelity when + // mapping other operations onto fmed3 since the order of operands is + // unchanged. + CallInst *NewCall = nullptr; + if (match(Src0, m_NaN()) || isa<UndefValue>(Src0)) { + NewCall = Builder.CreateMinNum(Src1, Src2); + } else if (match(Src1, m_NaN()) || isa<UndefValue>(Src1)) { + NewCall = Builder.CreateMinNum(Src0, Src2); + } else if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) { + NewCall = Builder.CreateMaxNum(Src0, Src1); + } + + if (NewCall) { + NewCall->copyFastMathFlags(II); + NewCall->takeName(II); + return replaceInstUsesWith(*II, NewCall); + } + bool Swap = false; // Canonicalize constants to RHS operands. // @@ -3396,13 +3496,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return II; } - if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) { - CallInst *NewCall = Builder.CreateMinNum(Src0, Src1); - NewCall->copyFastMathFlags(II); - NewCall->takeName(II); - return replaceInstUsesWith(*II, NewCall); - } - if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) { @@ -3536,13 +3629,32 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // amdgcn.kill(i1 1) is a no-op return eraseInstFromFunction(CI); } + case Intrinsic::amdgcn_update_dpp: { + Value *Old = II->getArgOperand(0); + + auto BC = dyn_cast<ConstantInt>(II->getArgOperand(5)); + auto RM = dyn_cast<ConstantInt>(II->getArgOperand(3)); + auto BM = dyn_cast<ConstantInt>(II->getArgOperand(4)); + if (!BC || !RM || !BM || + BC->isZeroValue() || + RM->getZExtValue() != 0xF || + BM->getZExtValue() != 0xF || + isa<UndefValue>(Old)) + break; + + // If bound_ctrl = 1, row mask = bank mask = 0xf we can omit old value. + II->setOperand(0, UndefValue::get(Old->getType())); + return II; + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. if (IntrinsicInst *SS = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { if (SS->getIntrinsicID() == Intrinsic::stacksave) { - if (&*++SS->getIterator() == II) + // Skip over debug info. + if (SS->getNextNonDebugInstruction() == II) { return eraseInstFromFunction(CI); + } } } @@ -3597,9 +3709,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::assume: { Value *IIOperand = II->getArgOperand(0); - // Remove an assume if it is immediately followed by an identical assume. - if (match(II->getNextNode(), - m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand)))) + // Remove an assume if it is followed by an identical assume. + // TODO: Do we need this? Unless there are conflicting assumptions, the + // computeKnownBits(IIOperand) below here eliminates redundant assumes. 
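// Illustrative (assumed IR): with the non-debug scan below,
//   call void @llvm.assume(i1 %c)
//   call void @llvm.dbg.value(...)
//   call void @llvm.assume(i1 %c)
// now folds to a single assume, whereas the previous getNextNode() check
// required the two assumes to be literally adjacent.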
+ Instruction *Next = II->getNextNonDebugInstruction(); + if (match(Next, m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand)))) return eraseInstFromFunction(CI); // Canonicalize assume(a && b) -> assume(a); assume(b); @@ -3686,8 +3800,16 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } case Intrinsic::experimental_guard: { - // Is this guard followed by another guard? + // Is this guard followed by another guard? We scan forward over a small + // fixed window of instructions to handle common cases with conditions + // computed between guards. Instruction *NextInst = II->getNextNode(); + for (unsigned i = 0; i < GuardWideningWindow; i++) { + // Note: Using context-free form to avoid compile time blow up + if (!isSafeToSpeculativelyExecute(NextInst)) + break; + NextInst = NextInst->getNextNode(); + } Value *NextCond = nullptr; if (match(NextInst, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(NextCond)))) { @@ -3698,6 +3820,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return eraseInstFromFunction(*NextInst); // Otherwise canonicalize guard(a); guard(b) -> guard(a & b). + Instruction* MoveI = II->getNextNode(); + while (MoveI != NextInst) { + auto *Temp = MoveI; + MoveI = MoveI->getNextNode(); + Temp->moveBefore(II); + } II->setArgOperand(0, Builder.CreateAnd(CurrCond, NextCond)); return eraseInstFromFunction(*NextInst); } @@ -3710,7 +3838,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Fence instruction simplification Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { // Remove identical consecutive fences. - if (auto *NFI = dyn_cast<FenceInst>(FI.getNextNode())) + Instruction *Next = FI.getNextNonDebugInstruction(); + if (auto *NFI = dyn_cast<FenceInst>(Next)) if (FI.isIdenticalTo(NFI)) return eraseInstFromFunction(FI); return nullptr; @@ -3887,8 +4016,8 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { // Remove the convergent attr on calls when the callee is not convergent. if (CS.isConvergent() && !CalleeF->isConvergent() && !CalleeF->isIntrinsic()) { - DEBUG(dbgs() << "Removing convergent attr from instr " - << CS.getInstruction() << "\n"); + LLVM_DEBUG(dbgs() << "Removing convergent attr from instr " + << CS.getInstruction() << "\n"); CS.setNotConvergent(); return CS.getInstruction(); } @@ -3919,7 +4048,9 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { } } - if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { + if ((isa<ConstantPointerNull>(Callee) && + !NullPointerIsDefined(CS.getInstruction()->getFunction())) || + isa<UndefValue>(Callee)) { // If CS does not return void then replaceAllUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. if (!CS.getInstruction()->getType()->isVoidTy()) @@ -3986,10 +4117,19 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { if (!Callee) return false; - // The prototype of a thunk is a lie. Don't directly call such a function. + // If this is a call to a thunk function, don't remove the cast. Thunks are + // used to transparently forward all incoming parameters and outgoing return + // values, so it's important to leave the cast in place. if (Callee->hasFnAttribute("thunk")) return false; + // If this is a musttail call, the callee's prototype must match the caller's + // prototype with the exception of pointee types. The code below doesn't + // implement that, so we can't do this transform. + // TODO: Do the transform if it only requires adding pointer casts. 
+ if (CS.isMustTailCall()) + return false; + Instruction *Caller = CS.getInstruction(); const AttributeList &CallerPAL = CS.getAttributes(); diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 178c8eaf2502..e8ea7396a96a 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -16,6 +16,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" using namespace llvm; @@ -256,7 +257,7 @@ Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1, return Instruction::CastOps(Res); } -/// @brief Implement the transforms common to all CastInst visitors. +/// Implement the transforms common to all CastInst visitors. Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); @@ -265,14 +266,27 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) { // The first cast (CSrc) is eliminable so we need to fix up or replace // the second cast (CI). CSrc will then have a good chance of being dead. - return CastInst::Create(NewOpc, CSrc->getOperand(0), CI.getType()); + auto *Ty = CI.getType(); + auto *Res = CastInst::Create(NewOpc, CSrc->getOperand(0), Ty); + // Point debug users of the dying cast to the new one. + if (CSrc->hasOneUse()) + replaceAllDbgUsesWith(*CSrc, *Res, CI, DT); + return Res; } } - // If we are casting a select, then fold the cast into the select. - if (auto *SI = dyn_cast<SelectInst>(Src)) - if (Instruction *NV = FoldOpIntoSelect(CI, SI)) - return NV; + if (auto *Sel = dyn_cast<SelectInst>(Src)) { + // We are casting a select. Try to fold the cast into the select, but only + // if the select does not have a compare instruction with matching operand + // types. Creating a select with operands that are different sizes than its + // condition may inhibit other folds and lead to worse codegen. + auto *Cmp = dyn_cast<CmpInst>(Sel->getCondition()); + if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType()) + if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { + replaceAllDbgUsesWith(*Sel, *NV, CI, DT); + return NV; + } + } // If we are casting a PHI, then fold the cast into the PHI. if (auto *PN = dyn_cast<PHINode>(Src)) { @@ -287,6 +301,33 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { return nullptr; } +/// Constants and extensions/truncates from the destination type are always +/// free to be evaluated in that type. This is a helper for canEvaluate*. +static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { + if (isa<Constant>(V)) + return true; + Value *X; + if ((match(V, m_ZExtOrSExt(m_Value(X))) || match(V, m_Trunc(m_Value(X)))) && + X->getType() == Ty) + return true; + + return false; +} + +/// Filter out values that we can not evaluate in the destination type for free. +/// This is a helper for canEvaluate*. +static bool canNotEvaluateInType(Value *V, Type *Ty) { + assert(!isa<Constant>(V) && "Constant should already be handled."); + if (!isa<Instruction>(V)) + return true; + // We don't extend or shrink something that has multiple uses -- doing so + // would require duplicating the instruction which isn't profitable. 
+ if (!V->hasOneUse()) + return true; + + return false; +} + /// Return true if we can evaluate the specified expression tree as type Ty /// instead of its larger type, and arrive with the same value. /// This is used by code that tries to eliminate truncates. @@ -300,27 +341,14 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { /// static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, Instruction *CxtI) { - // We can always evaluate constants in another type. - if (isa<Constant>(V)) + if (canAlwaysEvaluateInType(V, Ty)) return true; + if (canNotEvaluateInType(V, Ty)) + return false; - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return false; - + auto *I = cast<Instruction>(V); Type *OrigTy = V->getType(); - - // If this is an extension from the dest type, we can eliminate it, even if it - // has multiple uses. - if ((isa<ZExtInst>(I) || isa<SExtInst>(I)) && - I->getOperand(0)->getType() == Ty) - return true; - - // We can't extend or shrink something that has multiple uses: doing so would - // require duplicating the instruction in general, which isn't profitable. - if (!I->hasOneUse()) return false; - - unsigned Opc = I->getOpcode(); - switch (Opc) { + switch (I->getOpcode()) { case Instruction::Add: case Instruction::Sub: case Instruction::Mul: @@ -336,13 +364,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, // UDiv and URem can be truncated if all the truncated bits are zero. uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (BitWidth < OrigBitWidth) { - APInt Mask = APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth); - if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) && - IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) { - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); - } + assert(BitWidth < OrigBitWidth && "Unexpected bitwidths!"); + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) && + IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) { + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); } break; } @@ -365,9 +392,9 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, if (match(I->getOperand(1), m_APInt(Amt))) { uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (IC.MaskedValueIsZero(I->getOperand(0), - APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) && - Amt->getLimitedValue(BitWidth) < BitWidth) { + if (Amt->getLimitedValue(BitWidth) < BitWidth && + IC.MaskedValueIsZero(I->getOperand(0), + APInt::getBitsSetFrom(OrigBitWidth, BitWidth), 0, CxtI)) { return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } } @@ -644,20 +671,6 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (Instruction *Result = commonCastTransforms(CI)) return Result; - // Test if the trunc is the user of a select which is part of a - // minimum or maximum operation. If so, don't do any more simplification. - // Even simplifying demanded bits can break the canonical form of a - // min/max. 
- Value *LHS, *RHS; - if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0))) - if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN) - return nullptr; - - // See if we can simplify any instructions used by the input whose sole - // purpose is to compute bits we don't care about. - if (SimplifyDemandedInstructionBits(CI)) - return &CI; - Value *Src = CI.getOperand(0); Type *DestTy = CI.getType(), *SrcTy = Src->getType(); @@ -670,13 +683,29 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. - DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" - " to avoid cast: " << CI << '\n'); + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid cast: " + << CI << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); return replaceInstUsesWith(CI, Res); } + // Test if the trunc is the user of a select which is part of a + // minimum or maximum operation. If so, don't do any more simplification. + // Even simplifying demanded bits can break the canonical form of a + // min/max. + Value *LHS, *RHS; + if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0))) + if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN) + return nullptr; + + // See if we can simplify any instructions used by the input whose sole + // purpose is to compute bits we don't care about. + if (SimplifyDemandedInstructionBits(CI)) + return &CI; + // Canonicalize trunc x to i1 -> (icmp ne (and x, 1), 0), likewise for vector. if (DestTy->getScalarSizeInBits() == 1) { Constant *One = ConstantInt::get(SrcTy, 1); @@ -916,23 +945,14 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, InstCombiner &IC, Instruction *CxtI) { BitsToClear = 0; - if (isa<Constant>(V)) - return true; - - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return false; - - // If the input is a truncate from the destination type, we can trivially - // eliminate it. - if (isa<TruncInst>(I) && I->getOperand(0)->getType() == Ty) + if (canAlwaysEvaluateInType(V, Ty)) return true; + if (canNotEvaluateInType(V, Ty)) + return false; - // We can't extend or shrink something that has multiple uses: doing so would - // require duplicating the instruction in general, which isn't profitable. - if (!I->hasOneUse()) return false; - - unsigned Opc = I->getOpcode(), Tmp; - switch (Opc) { + auto *I = cast<Instruction>(V); + unsigned Tmp; + switch (I->getOpcode()) { case Instruction::ZExt: // zext(zext(x)) -> zext(x). case Instruction::SExt: // zext(sext(x)) -> sext(x). case Instruction::Trunc: // zext(trunc(x)) -> trunc(x) or zext(x) @@ -961,7 +981,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, 0, CxtI)) { // If this is an And instruction and all of the BitsToClear are // known to be zero we can reset BitsToClear. - if (Opc == Instruction::And) + if (I->getOpcode() == Instruction::And) BitsToClear = 0; return true; } @@ -1052,11 +1072,18 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { "Can't clear more bits than in SrcTy"); // Okay, we can transform this! Insert the new expression now. 
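// Illustrative (assumed IR): when canEvaluateZExtd succeeds for a
// single-use operand, e.g.
//   %t = trunc i32 %x to i8
//   %z = zext i8 %t to i32
// the expression is evaluated directly in i32 and the high BitsToClear bits
// are masked off, giving
//   %z = and i32 %x, 255
// instead of materializing the trunc/zext pair.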
- DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" - " to avoid zero extend: " << CI << '\n'); + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid zero extend: " + << CI << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); + // Preserve debug values referring to Src if the zext is its last use. + if (auto *SrcOp = dyn_cast<Instruction>(Src)) + if (SrcOp->hasOneUse()) + replaceAllDbgUsesWith(*SrcOp, *Res, CI, DT); + uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits()-BitsToClear; uint32_t DestBitSize = DestTy->getScalarSizeInBits(); @@ -1168,22 +1195,19 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { if (!Op1->getType()->isIntOrIntVectorTy()) return nullptr; - if (Constant *Op1C = dyn_cast<Constant>(Op1)) { + if ((Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) || + (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))) { // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if negative // (x >s -1) ? -1 : 0 -> not (ashr x, 31) -> all ones if positive - if ((Pred == ICmpInst::ICMP_SLT && Op1C->isNullValue()) || - (Pred == ICmpInst::ICMP_SGT && Op1C->isAllOnesValue())) { + Value *Sh = ConstantInt::get(Op0->getType(), + Op0->getType()->getScalarSizeInBits() - 1); + Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit"); + if (In->getType() != CI.getType()) + In = Builder.CreateIntCast(In, CI.getType(), true /*SExt*/); - Value *Sh = ConstantInt::get(Op0->getType(), - Op0->getType()->getScalarSizeInBits()-1); - Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit"); - if (In->getType() != CI.getType()) - In = Builder.CreateIntCast(In, CI.getType(), true /*SExt*/); - - if (Pred == ICmpInst::ICMP_SGT) - In = Builder.CreateNot(In, In->getName() + ".not"); - return replaceInstUsesWith(CI, In); - } + if (Pred == ICmpInst::ICMP_SGT) + In = Builder.CreateNot(In, In->getName() + ".not"); + return replaceInstUsesWith(CI, In); } if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { @@ -1254,21 +1278,12 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { static bool canEvaluateSExtd(Value *V, Type *Ty) { assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() && "Can't sign extend type to a smaller type"); - // If this is a constant, it can be trivially promoted. - if (isa<Constant>(V)) + if (canAlwaysEvaluateInType(V, Ty)) return true; + if (canNotEvaluateInType(V, Ty)) + return false; - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return false; - - // If this is a truncate from the dest type, we can trivially eliminate it. - if (isa<TruncInst>(I) && I->getOperand(0)->getType() == Ty) - return true; - - // We can't extend or shrink something that has multiple uses: doing so would - // require duplicating the instruction in general, which isn't profitable. - if (!I->hasOneUse()) return false; - + auto *I = cast<Instruction>(V); switch (I->getOpcode()) { case Instruction::SExt: // sext(sext(x)) -> sext(x) case Instruction::ZExt: // sext(zext(x)) -> zext(x) @@ -1335,8 +1350,10 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && canEvaluateSExtd(Src, DestTy)) { // Okay, we can transform this! Insert the new expression now. 
- DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" - " to avoid sign extend: " << CI << '\n'); + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid sign extend: " + << CI << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, true); assert(Res->getType() == DestTy); @@ -1401,45 +1418,83 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { /// Return a Constant* for the specified floating-point constant if it fits /// in the specified FP type without changing its value. -static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { +static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { bool losesInfo; APFloat F = CFP->getValueAPF(); (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo); - if (!losesInfo) - return ConstantFP::get(CFP->getContext(), F); + return !losesInfo; +} + +static Type *shrinkFPConstant(ConstantFP *CFP) { + if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext())) + return nullptr; // No constant folding of this. + // See if the value can be truncated to half and then reextended. + if (fitsInFPType(CFP, APFloat::IEEEhalf())) + return Type::getHalfTy(CFP->getContext()); + // See if the value can be truncated to float and then reextended. + if (fitsInFPType(CFP, APFloat::IEEEsingle())) + return Type::getFloatTy(CFP->getContext()); + if (CFP->getType()->isDoubleTy()) + return nullptr; // Won't shrink. + if (fitsInFPType(CFP, APFloat::IEEEdouble())) + return Type::getDoubleTy(CFP->getContext()); + // Don't try to shrink to various long double types. return nullptr; } -/// Look through floating-point extensions until we get the source value. -static Value *lookThroughFPExtensions(Value *V) { - while (auto *FPExt = dyn_cast<FPExtInst>(V)) - V = FPExt->getOperand(0); +// Determine if this is a vector of ConstantFPs and if so, return the minimal +// type we can safely truncate all elements to. +// TODO: Make these support undef elements. +static Type *shrinkFPConstantVector(Value *V) { + auto *CV = dyn_cast<Constant>(V); + if (!CV || !CV->getType()->isVectorTy()) + return nullptr; + + Type *MinType = nullptr; + + unsigned NumElts = CV->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); + if (!CFP) + return nullptr; + + Type *T = shrinkFPConstant(CFP); + if (!T) + return nullptr; + + // If we haven't found a type yet or this type has a larger mantissa than + // our previous type, this is our new minimal type. + if (!MinType || T->getFPMantissaWidth() > MinType->getFPMantissaWidth()) + MinType = T; + } + + // Make a vector type from the minimal type. + return VectorType::get(MinType, NumElts); +} + +/// Find the minimum FP type we can safely truncate to. +static Type *getMinimumFPType(Value *V) { + if (auto *FPExt = dyn_cast<FPExtInst>(V)) + return FPExt->getOperand(0)->getType(); // If this value is a constant, return the constant in the smallest FP type // that can accurately represent it. This allows us to turn // (float)((double)X+2.0) into x+2.0f. - if (auto *CFP = dyn_cast<ConstantFP>(V)) { - if (CFP->getType() == Type::getPPC_FP128Ty(V->getContext())) - return V; // No constant folding of this. - // See if the value can be truncated to half and then reextended. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEhalf())) - return V; - // See if the value can be truncated to float and then reextended. 
- if (Value *V = fitsInFPType(CFP, APFloat::IEEEsingle())) - return V; - if (CFP->getType()->isDoubleTy()) - return V; // Won't shrink. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEdouble())) - return V; - // Don't try to shrink to various long double types. - } - - return V; + if (auto *CFP = dyn_cast<ConstantFP>(V)) + if (Type *T = shrinkFPConstant(CFP)) + return T; + + // Try to shrink a vector of FP constants. + if (Type *T = shrinkFPConstantVector(V)) + return T; + + return V->getType(); } -Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { - if (Instruction *I = commonCastTransforms(CI)) +Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { + if (Instruction *I = commonCastTransforms(FPT)) return I; + // If we have fptrunc(OpI (fpextend x), (fpextend y)), we would like to // simplify this expression to avoid one or more of the trunc/extend // operations if we can do so without changing the numerical results. @@ -1447,15 +1502,16 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // The exact manner in which the widths of the operands interact to limit // what we can and cannot do safely varies from operation to operation, and // is explained below in the various case statements. - BinaryOperator *OpI = dyn_cast<BinaryOperator>(CI.getOperand(0)); + Type *Ty = FPT.getType(); + BinaryOperator *OpI = dyn_cast<BinaryOperator>(FPT.getOperand(0)); if (OpI && OpI->hasOneUse()) { - Value *LHSOrig = lookThroughFPExtensions(OpI->getOperand(0)); - Value *RHSOrig = lookThroughFPExtensions(OpI->getOperand(1)); + Type *LHSMinType = getMinimumFPType(OpI->getOperand(0)); + Type *RHSMinType = getMinimumFPType(OpI->getOperand(1)); unsigned OpWidth = OpI->getType()->getFPMantissaWidth(); - unsigned LHSWidth = LHSOrig->getType()->getFPMantissaWidth(); - unsigned RHSWidth = RHSOrig->getType()->getFPMantissaWidth(); + unsigned LHSWidth = LHSMinType->getFPMantissaWidth(); + unsigned RHSWidth = RHSMinType->getFPMantissaWidth(); unsigned SrcWidth = std::max(LHSWidth, RHSWidth); - unsigned DstWidth = CI.getType()->getFPMantissaWidth(); + unsigned DstWidth = Ty->getFPMantissaWidth(); switch (OpI->getOpcode()) { default: break; case Instruction::FAdd: @@ -1479,12 +1535,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // could be tightened for those cases, but they are rare (the main // case of interest here is (float)((double)float + float)). if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) { - if (LHSOrig->getType() != CI.getType()) - LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); - if (RHSOrig->getType() != CI.getType()) - RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType()); - Instruction *RI = - BinaryOperator::Create(OpI->getOpcode(), LHSOrig, RHSOrig); + Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + Instruction *RI = BinaryOperator::Create(OpI->getOpcode(), LHS, RHS); RI->copyFastMathFlags(OpI); return RI; } @@ -1496,14 +1549,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // rounding can possibly occur; we can safely perform the operation // in the destination format if it can represent both sources. 
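For the fadd/fsub case above, the guard OpWidth >= 2*DstWidth+1 is what makes the extra rounding harmless. For the common float/double pair (53 >= 2*24+1), truncating the double-precision sum of two floats gives the same result as adding them directly in float, so the fpext/fadd/fptrunc chain can become a single float fadd. A randomized spot check, assuming the default round-to-nearest mode and that float/double expressions are evaluated at their nominal precision (FLT_EVAL_METHOD == 0):

    #include <cassert>
    #include <random>

    int main() {
      std::mt19937 rng(1);
      std::uniform_real_distribution<float> dist(-1.0e30f, 1.0e30f);
      for (int i = 0; i < 1000000; ++i) {
        float a = dist(rng), b = dist(rng);
        float viaDouble = static_cast<float>(static_cast<double>(a) +
                                             static_cast<double>(b));
        assert(viaDouble == a + b);  // one float fadd matches ext/add/trunc
      }
      return 0;
    }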
if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) { - if (LHSOrig->getType() != CI.getType()) - LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); - if (RHSOrig->getType() != CI.getType()) - RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType()); - Instruction *RI = - BinaryOperator::CreateFMul(LHSOrig, RHSOrig); - RI->copyFastMathFlags(OpI); - return RI; + Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + return BinaryOperator::CreateFMulFMF(LHS, RHS, OpI); } break; case Instruction::FDiv: @@ -1514,72 +1562,48 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // condition used here is a good conservative first pass. // TODO: Tighten bound via rigorous analysis of the unbalanced case. if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) { - if (LHSOrig->getType() != CI.getType()) - LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); - if (RHSOrig->getType() != CI.getType()) - RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType()); - Instruction *RI = - BinaryOperator::CreateFDiv(LHSOrig, RHSOrig); - RI->copyFastMathFlags(OpI); - return RI; + Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + return BinaryOperator::CreateFDivFMF(LHS, RHS, OpI); } break; - case Instruction::FRem: + case Instruction::FRem: { // Remainder is straightforward. Remainder is always exact, so the // type of OpI doesn't enter into things at all. We simply evaluate // in whichever source type is larger, then convert to the // destination type. if (SrcWidth == OpWidth) break; - if (LHSWidth < SrcWidth) - LHSOrig = Builder.CreateFPExt(LHSOrig, RHSOrig->getType()); - else if (RHSWidth <= SrcWidth) - RHSOrig = Builder.CreateFPExt(RHSOrig, LHSOrig->getType()); - if (LHSOrig != OpI->getOperand(0) || RHSOrig != OpI->getOperand(1)) { - Value *ExactResult = Builder.CreateFRem(LHSOrig, RHSOrig); - if (Instruction *RI = dyn_cast<Instruction>(ExactResult)) - RI->copyFastMathFlags(OpI); - return CastInst::CreateFPCast(ExactResult, CI.getType()); + Value *LHS, *RHS; + if (LHSWidth == SrcWidth) { + LHS = Builder.CreateFPTrunc(OpI->getOperand(0), LHSMinType); + RHS = Builder.CreateFPTrunc(OpI->getOperand(1), LHSMinType); + } else { + LHS = Builder.CreateFPTrunc(OpI->getOperand(0), RHSMinType); + RHS = Builder.CreateFPTrunc(OpI->getOperand(1), RHSMinType); } + + Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, OpI); + return CastInst::CreateFPCast(ExactResult, Ty); + } } // (fptrunc (fneg x)) -> (fneg (fptrunc x)) if (BinaryOperator::isFNeg(OpI)) { - Value *InnerTrunc = Builder.CreateFPTrunc(OpI->getOperand(1), - CI.getType()); - Instruction *RI = BinaryOperator::CreateFNeg(InnerTrunc); - RI->copyFastMathFlags(OpI); - return RI; + Value *InnerTrunc = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + return BinaryOperator::CreateFNegFMF(InnerTrunc, OpI); } } - // (fptrunc (select cond, R1, Cst)) --> - // (select cond, (fptrunc R1), (fptrunc Cst)) - // - // - but only if this isn't part of a min/max operation, else we'll - // ruin min/max canonical form which is to have the select and - // compare's operands be of the same type with no casts to look through. 
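The fmul case above uses a different argument: when OpWidth >= LHSWidth + RHSWidth, the product of the two narrower values is exact in the operation type (24 + 24 <= 53 bits for float operands multiplied in double), so only the final truncation rounds and the result matches a multiply done directly in the destination type. A spot check under the same assumptions as the fadd sketch, with the range chosen so products stay inside float range:

    #include <cassert>
    #include <random>

    int main() {
      std::mt19937 rng(2);
      std::uniform_real_distribution<float> dist(-1.8e19f, 1.8e19f);
      for (int i = 0; i < 1000000; ++i) {
        float a = dist(rng), b = dist(rng);
        // The double product of two floats is exact, so truncating it rounds
        // the mathematically exact product once, just like 'a * b' does.
        float viaDouble = static_cast<float>(static_cast<double>(a) *
                                             static_cast<double>(b));
        assert(viaDouble == a * b);
      }
      return 0;
    }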
- Value *LHS, *RHS; - SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0)); - if (SI && - (isa<ConstantFP>(SI->getOperand(1)) || - isa<ConstantFP>(SI->getOperand(2))) && - matchSelectPattern(SI, LHS, RHS).Flavor == SPF_UNKNOWN) { - Value *LHSTrunc = Builder.CreateFPTrunc(SI->getOperand(1), CI.getType()); - Value *RHSTrunc = Builder.CreateFPTrunc(SI->getOperand(2), CI.getType()); - return SelectInst::Create(SI->getOperand(0), LHSTrunc, RHSTrunc); - } - - IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI.getOperand(0)); - if (II) { + if (auto *II = dyn_cast<IntrinsicInst>(FPT.getOperand(0))) { switch (II->getIntrinsicID()) { default: break; - case Intrinsic::fabs: case Intrinsic::ceil: + case Intrinsic::fabs: case Intrinsic::floor: + case Intrinsic::nearbyint: case Intrinsic::rint: case Intrinsic::round: - case Intrinsic::nearbyint: case Intrinsic::trunc: { Value *Src = II->getArgOperand(0); if (!Src->hasOneUse()) @@ -1590,30 +1614,26 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // truncating. if (II->getIntrinsicID() != Intrinsic::fabs) { FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Src); - if (!FPExtSrc || FPExtSrc->getOperand(0)->getType() != CI.getType()) + if (!FPExtSrc || FPExtSrc->getSrcTy() != Ty) break; } // Do unary FP operation on smaller type. // (fptrunc (fabs x)) -> (fabs (fptrunc x)) - Value *InnerTrunc = Builder.CreateFPTrunc(Src, CI.getType()); - Type *IntrinsicType[] = { CI.getType() }; - Function *Overload = Intrinsic::getDeclaration( - CI.getModule(), II->getIntrinsicID(), IntrinsicType); - + Value *InnerTrunc = Builder.CreateFPTrunc(Src, Ty); + Function *Overload = Intrinsic::getDeclaration(FPT.getModule(), + II->getIntrinsicID(), Ty); SmallVector<OperandBundleDef, 1> OpBundles; II->getOperandBundlesAsDefs(OpBundles); - - Value *Args[] = { InnerTrunc }; - CallInst *NewCI = CallInst::Create(Overload, Args, - OpBundles, II->getName()); + CallInst *NewCI = CallInst::Create(Overload, { InnerTrunc }, OpBundles, + II->getName()); NewCI->copyFastMathFlags(II); return NewCI; } } } - if (Instruction *I = shrinkInsertElt(CI, Builder)) + if (Instruction *I = shrinkInsertElt(FPT, Builder)) return I; return nullptr; @@ -1718,7 +1738,7 @@ Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) { return nullptr; } -/// @brief Implement the transforms for cast of pointer (bitcast/ptrtoint) +/// Implement the transforms for cast of pointer (bitcast/ptrtoint) Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); @@ -1751,7 +1771,7 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) { Type *Ty = CI.getType(); unsigned AS = CI.getPointerAddressSpace(); - if (Ty->getScalarSizeInBits() == DL.getPointerSizeInBits(AS)) + if (Ty->getScalarSizeInBits() == DL.getIndexSizeInBits(AS)) return commonPointerCastTransforms(CI); Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS); @@ -2004,13 +2024,13 @@ static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, !match(BitCast.getOperand(0), m_OneUse(m_BinOp(BO))) || !BO->isBitwiseLogicOp()) return nullptr; - + // FIXME: This transform is restricted to vector types to avoid backend // problems caused by creating potentially illegal operations. If a fix-up is // added to handle that situation, we can remove this check. 
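The intrinsic handling above exempts fabs from the fpext-source restriction because taking the absolute value commutes with rounding: round-to-nearest is symmetric about zero, so (fptrunc (fabs x)) and (fabs (fptrunc x)) agree for any non-NaN x. A randomized check in plain C++ (values kept inside float range so the narrowing conversion is well defined):

    #include <cassert>
    #include <cmath>
    #include <random>

    int main() {
      std::mt19937_64 rng(7);
      std::uniform_real_distribution<double> dist(-3.0e38, 3.0e38);
      for (int i = 0; i < 1000000; ++i) {
        double d = dist(rng);
        float truncOfAbs = static_cast<float>(std::fabs(d)); // fabs, then fptrunc
        float absOfTrunc = std::fabs(static_cast<float>(d)); // fptrunc, then fabs
        assert(truncOfAbs == absOfTrunc);
      }
      return 0;
    }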
if (!DestTy->isVectorTy() || !BO->getType()->isVectorTy()) return nullptr; - + Value *X; if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy && !isa<Constant>(X)) { diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 3bc7fae77cb1..6de92a4842ab 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -682,7 +682,7 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, // 4. Emit GEPs to get the original pointers. // 5. Remove the original instructions. Type *IndexType = IntegerType::get( - Base->getContext(), DL.getPointerTypeSizeInBits(Start->getType())); + Base->getContext(), DL.getIndexTypeSizeInBits(Start->getType())); DenseMap<Value *, Value *> NewInsts; NewInsts[Base] = ConstantInt::getNullValue(IndexType); @@ -723,7 +723,7 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, } auto *Op = NewInsts[GEP->getOperand(0)]; - if (isa<ConstantInt>(Op) && dyn_cast<ConstantInt>(Op)->isZero()) + if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero()) NewInsts[GEP] = Index; else NewInsts[GEP] = Builder.CreateNSWAdd( @@ -790,7 +790,7 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, static std::pair<Value *, Value *> getAsConstantIndexedAddress(Value *V, const DataLayout &DL) { Type *IndexType = IntegerType::get(V->getContext(), - DL.getPointerTypeSizeInBits(V->getType())); + DL.getIndexTypeSizeInBits(V->getType())); Constant *Index = ConstantInt::getNullValue(IndexType); while (true) { @@ -1893,11 +1893,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } - if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { - // This is the same code as the SGT case, but assert the pre-condition - // that is needed for this to work with equality predicates. - assert(C.ashr(*ShiftAmt).shl(*ShiftAmt) == C && - "Compare known true or false was not folded"); + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + C.ashr(*ShiftAmt).shl(*ShiftAmt) == C) { APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } @@ -1926,11 +1923,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } - if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { - // This is the same code as the UGT case, but assert the pre-condition - // that is needed for this to work with equality predicates. 
- assert(C.lshr(*ShiftAmt).shl(*ShiftAmt) == C && - "Compare known true or false was not folded"); + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + C.lshr(*ShiftAmt).shl(*ShiftAmt) == C) { APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } @@ -2463,6 +2457,45 @@ Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, return nullptr; } +Instruction *InstCombiner::foldICmpBitCastConstant(ICmpInst &Cmp, + BitCastInst *Bitcast, + const APInt &C) { + // Folding: icmp <pred> iN X, C + // where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN + // and C is a splat of a K-bit pattern + // and SC is a constant vector = <C', C', C', ..., C'> + // Into: + // %E = extractelement <M x iK> %vec, i32 C' + // icmp <pred> iK %E, trunc(C) + if (!Bitcast->getType()->isIntegerTy() || + !Bitcast->getSrcTy()->isIntOrIntVectorTy()) + return nullptr; + + Value *BCIOp = Bitcast->getOperand(0); + Value *Vec = nullptr; // 1st vector arg of the shufflevector + Constant *Mask = nullptr; // Mask arg of the shufflevector + if (match(BCIOp, + m_ShuffleVector(m_Value(Vec), m_Undef(), m_Constant(Mask)))) { + // Check whether every element of Mask is the same constant + if (auto *Elem = dyn_cast_or_null<ConstantInt>(Mask->getSplatValue())) { + auto *VecTy = cast<VectorType>(BCIOp->getType()); + auto *EltTy = cast<IntegerType>(VecTy->getElementType()); + auto Pred = Cmp.getPredicate(); + if (C.isSplat(EltTy->getBitWidth())) { + // Fold the icmp based on the value of C + // If C is M copies of an iK sized bit pattern, + // then: + // => %E = extractelement <N x iK> %vec, i32 Elem + // icmp <pred> iK %SplatVal, <pattern> + Value *Extract = Builder.CreateExtractElement(Vec, Elem); + Value *NewC = ConstantInt::get(EltTy, C.trunc(EltTy->getBitWidth())); + return new ICmpInst(Pred, Extract, NewC); + } + } + } + return nullptr; +} + /// Try to fold integer comparisons with a constant operand: icmp Pred X, C /// where X is some kind of instruction. Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { @@ -2537,6 +2570,11 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { return I; } + if (auto *BCI = dyn_cast<BitCastInst>(Cmp.getOperand(0))) { + if (Instruction *I = foldICmpBitCastConstant(Cmp, BCI, *C)) + return I; + } + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, *C)) return I; @@ -2828,6 +2866,160 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { return nullptr; } +/// Some comparisons can be simplified. +/// In this case, we are looking for comparisons that look like +/// a check for a lossy truncation. +/// Folds: +/// x & (-1 >> y) SrcPred x to x DstPred (-1 >> y) +/// The Mask can be a constant, too. +/// For some predicates, the operands are commutative. +/// For others, x can only be on a specific side. 
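A miniature model of the foldICmpBitCastConstant transform introduced above: when an integer is the bitcast of a vector whose lanes all hold the same element, comparing it for equality against a splatted constant is equivalent to comparing a single lane against the truncated constant. The sketch below packs four copies of a byte into a 32-bit value (plain C++, equality predicates only; endianness does not matter because every lane is identical):

    #include <cassert>
    #include <cstdint>

    static uint32_t splat4(uint8_t b) {
      return 0x01010101u * b;  // <4 x i8> <b, b, b, b> viewed as an i32
    }

    int main() {
      for (int e = 0; e < 256; ++e)
        for (int c = 0; c < 256; ++c) {
          bool wide   = splat4(uint8_t(e)) == splat4(uint8_t(c)); // icmp eq i32
          bool narrow = uint8_t(e) == uint8_t(c);                 // icmp eq i8
          assert(wide == narrow);
        }
      return 0;
    }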
+static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate SrcPred; + Value *X, *M; + auto m_Mask = m_CombineOr(m_LShr(m_AllOnes(), m_Value()), m_LowBitMask()); + if (!match(&I, m_c_ICmp(SrcPred, + m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Value(X)), + m_Deferred(X)))) + return nullptr; + + ICmpInst::Predicate DstPred; + switch (SrcPred) { + case ICmpInst::Predicate::ICMP_EQ: + // x & (-1 >> y) == x -> x u<= (-1 >> y) + DstPred = ICmpInst::Predicate::ICMP_ULE; + break; + case ICmpInst::Predicate::ICMP_NE: + // x & (-1 >> y) != x -> x u> (-1 >> y) + DstPred = ICmpInst::Predicate::ICMP_UGT; + break; + case ICmpInst::Predicate::ICMP_UGT: + // x u> x & (-1 >> y) -> x u> (-1 >> y) + assert(X == I.getOperand(0) && "instsimplify took care of commut. variant"); + DstPred = ICmpInst::Predicate::ICMP_UGT; + break; + case ICmpInst::Predicate::ICMP_UGE: + // x & (-1 >> y) u>= x -> x u<= (-1 >> y) + assert(X == I.getOperand(1) && "instsimplify took care of commut. variant"); + DstPred = ICmpInst::Predicate::ICMP_ULE; + break; + case ICmpInst::Predicate::ICMP_ULT: + // x & (-1 >> y) u< x -> x u> (-1 >> y) + assert(X == I.getOperand(1) && "instsimplify took care of commut. variant"); + DstPred = ICmpInst::Predicate::ICMP_UGT; + break; + case ICmpInst::Predicate::ICMP_ULE: + // x u<= x & (-1 >> y) -> x u<= (-1 >> y) + assert(X == I.getOperand(0) && "instsimplify took care of commut. variant"); + DstPred = ICmpInst::Predicate::ICMP_ULE; + break; + case ICmpInst::Predicate::ICMP_SGT: + // x s> x & (-1 >> y) -> x s> (-1 >> y) + if (X != I.getOperand(0)) // X must be on LHS of comparison! + return nullptr; // Ignore the other case. + DstPred = ICmpInst::Predicate::ICMP_SGT; + break; + case ICmpInst::Predicate::ICMP_SGE: + // x & (-1 >> y) s>= x -> x s<= (-1 >> y) + if (X != I.getOperand(1)) // X must be on RHS of comparison! + return nullptr; // Ignore the other case. + DstPred = ICmpInst::Predicate::ICMP_SLE; + break; + case ICmpInst::Predicate::ICMP_SLT: + // x & (-1 >> y) s< x -> x s> (-1 >> y) + if (X != I.getOperand(1)) // X must be on RHS of comparison! + return nullptr; // Ignore the other case. + DstPred = ICmpInst::Predicate::ICMP_SGT; + break; + case ICmpInst::Predicate::ICMP_SLE: + // x s<= x & (-1 >> y) -> x s<= (-1 >> y) + if (X != I.getOperand(0)) // X must be on LHS of comparison! + return nullptr; // Ignore the other case. + DstPred = ICmpInst::Predicate::ICMP_SLE; + break; + default: + llvm_unreachable("All possible folds are handled."); + } + + return Builder.CreateICmp(DstPred, X, M); +} + +/// Some comparisons can be simplified. +/// In this case, we are looking for comparisons that look like +/// a check for a lossy signed truncation. +/// Folds: (MaskedBits is a constant.) +/// ((%x << MaskedBits) a>> MaskedBits) SrcPred %x +/// Into: +/// (add %x, (1 << (KeptBits-1))) DstPred (1 << KeptBits) +/// Where KeptBits = bitwidth(%x) - MaskedBits +static Value * +foldICmpWithTruncSignExtendedVal(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate SrcPred; + Value *X; + const APInt *C0, *C1; // FIXME: non-splats, potentially with undef. + // We are ok with 'shl' having multiple uses, but 'ashr' must be one-use. + if (!match(&I, m_c_ICmp(SrcPred, + m_OneUse(m_AShr(m_Shl(m_Value(X), m_APInt(C0)), + m_APInt(C1))), + m_Deferred(X)))) + return nullptr; + + // Potential handling of non-splats: for each element: + // * if both are undef, replace with constant 0. 
+ // Because (1<<0) is OK and is 1, and ((1<<0)>>1) is also OK and is 0. + // * if both are not undef, and are different, bailout. + // * else, only one is undef, then pick the non-undef one. + + // The shift amount must be equal. + if (*C0 != *C1) + return nullptr; + const APInt &MaskedBits = *C0; + assert(MaskedBits != 0 && "shift by zero should be folded away already."); + + ICmpInst::Predicate DstPred; + switch (SrcPred) { + case ICmpInst::Predicate::ICMP_EQ: + // ((%x << MaskedBits) a>> MaskedBits) == %x + // => + // (add %x, (1 << (KeptBits-1))) u< (1 << KeptBits) + DstPred = ICmpInst::Predicate::ICMP_ULT; + break; + case ICmpInst::Predicate::ICMP_NE: + // ((%x << MaskedBits) a>> MaskedBits) != %x + // => + // (add %x, (1 << (KeptBits-1))) u>= (1 << KeptBits) + DstPred = ICmpInst::Predicate::ICMP_UGE; + break; + // FIXME: are more folds possible? + default: + return nullptr; + } + + auto *XType = X->getType(); + const unsigned XBitWidth = XType->getScalarSizeInBits(); + const APInt BitWidth = APInt(XBitWidth, XBitWidth); + assert(BitWidth.ugt(MaskedBits) && "shifts should leave some bits untouched"); + + // KeptBits = bitwidth(%x) - MaskedBits + const APInt KeptBits = BitWidth - MaskedBits; + assert(KeptBits.ugt(0) && KeptBits.ult(BitWidth) && "unreachable"); + // ICmpCst = (1 << KeptBits) + const APInt ICmpCst = APInt(XBitWidth, 1).shl(KeptBits); + assert(ICmpCst.isPowerOf2()); + // AddCst = (1 << (KeptBits-1)) + const APInt AddCst = ICmpCst.lshr(1); + assert(AddCst.ult(ICmpCst) && AddCst.isPowerOf2()); + + // T0 = add %x, AddCst + Value *T0 = Builder.CreateAdd(X, ConstantInt::get(XType, AddCst)); + // T1 = T0 DstPred ICmpCst + Value *T1 = Builder.CreateICmp(DstPred, T0, ConstantInt::get(XType, ICmpCst)); + + return T1; +} + /// Try to fold icmp (binop), X or icmp X, (binop). /// TODO: A large part of this logic is duplicated in InstSimplify's /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code @@ -3011,17 +3203,22 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. if (A == Op1 && NoOp0WrapProblem) return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); - // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. if (C == Op0 && NoOp1WrapProblem) return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); + // (A - B) >u A --> A <u B + if (A == Op1 && Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_ULT, A, B); + // C <u (C - D) --> C <u D + if (C == Op0 && Pred == ICmpInst::ICMP_ULT) + return new ICmpInst(ICmpInst::ICMP_ULT, C, D); + // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && // Try not to increase register pressure. BO0->hasOneUse() && BO1->hasOneUse()) return new ICmpInst(Pred, A, C); - // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && // Try not to increase register pressure. 
@@ -3032,8 +3229,8 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) { Value *X; if (match(BO0, m_Neg(m_Value(X)))) - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) - if (!RHSC->isMinValue(/*isSigned=*/true)) + if (Constant *RHSC = dyn_cast<Constant>(Op1)) + if (RHSC->isNotMinSignedValue()) return new ICmpInst(I.getSwappedPredicate(), X, ConstantExpr::getNeg(RHSC)); } @@ -3160,6 +3357,12 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { } } + if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) + return replaceInstUsesWith(I, V); + + if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder)) + return replaceInstUsesWith(I, V); + return nullptr; } @@ -3414,8 +3617,15 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the // integer type is the same size as the pointer type. + const auto& CompatibleSizes = [&](Type* SrcTy, Type* DestTy) -> bool { + if (isa<VectorType>(SrcTy)) { + SrcTy = cast<VectorType>(SrcTy)->getElementType(); + DestTy = cast<VectorType>(DestTy)->getElementType(); + } + return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); + }; if (LHSCI->getOpcode() == Instruction::PtrToInt && - DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) { + CompatibleSizes(SrcTy, DestTy)) { Value *RHSOp = nullptr; if (auto *RHSC = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { Value *RHSCIOp = RHSC->getOperand(0); @@ -3618,7 +3828,7 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, return false; } -/// \brief Recognize and process idiom involving test for multiplication +/// Recognize and process idiom involving test for multiplication /// overflow. /// /// The caller has matched a pattern of the form: @@ -3799,7 +4009,8 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, // mul.with.overflow and adjust properly mask/size. if (MulVal->hasNUsesOrMore(2)) { Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value"); - for (User *U : MulVal->users()) { + for (auto UI = MulVal->user_begin(), UE = MulVal->user_end(); UI != UE;) { + User *U = *UI++; if (U == &I || U == OtherVal) continue; if (TruncInst *TI = dyn_cast<TruncInst>(U)) { @@ -3890,48 +4101,33 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { } } -/// \brief Check if the order of \p Op0 and \p Op1 as operand in an ICmpInst +/// Check if the order of \p Op0 and \p Op1 as operands in an ICmpInst /// should be swapped. /// The decision is based on how many times these two operands are reused /// as subtract operands and their positions in those instructions. -/// The rational is that several architectures use the same instruction for -/// both subtract and cmp, thus it is better if the order of those operands +/// The rationale is that several architectures use the same instruction for +/// both subtract and cmp. Thus, it is better if the order of those operands /// match. /// \return true if Op0 and Op1 should be swapped. -static bool swapMayExposeCSEOpportunities(const Value * Op0, - const Value * Op1) { - // Filter out pointer value as those cannot appears directly in subtract. +static bool swapMayExposeCSEOpportunities(const Value *Op0, const Value *Op1) { + // Filter out pointer values as those cannot appear directly in subtract. // FIXME: we may want to go through inttoptrs or bitcasts. 
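The three icmp additions above (foldICmpWithLowBitMaskedVal, foldICmpWithTruncSignExtendedVal, and the new (A - B) u> A case in foldICmpBinOp) all reduce to small bit-level identities that can be verified exhaustively at 8 bits. A standalone sketch, not LLVM code:

    #include <cassert>
    #include <cstdint>

    int main() {
      // 1) x & (-1 u>> y) pred x  ->  x pred' (-1 u>> y)
      for (int y = 0; y < 8; ++y) {
        uint8_t m = uint8_t(0xFFu >> y);          // low-bit mask
        for (int x = 0; x < 256; ++x) {
          uint8_t xv = uint8_t(x);
          assert(((xv & m) == xv) == (xv <= m));  // ICMP_EQ  -> ICMP_ULE
          assert(((xv & m) != xv) == (xv > m));   // ICMP_NE  -> ICMP_UGT
          assert((xv > (xv & m)) == (xv > m));    // ICMP_UGT -> ICMP_UGT
        }
      }

      // 2) ((x << M) a>> M) == x  ->  (x + (1 << (K-1))) u< (1 << K), K = 8 - M
      for (int M = 1; M < 8; ++M) {
        int K = 8 - M;
        for (int xi = 0; xi < 256; ++xi) {
          int x = xi < 128 ? xi : xi - 256;                    // xi as signed i8
          int low = xi & ((1 << K) - 1);                       // low K bits of x
          int sext = (low ^ (1 << (K - 1))) - (1 << (K - 1));  // == (x << M) a>> M
          bool src = (sext == x);
          bool dst = ((xi + (1 << (K - 1))) & 0xFF) < (1 << K);
          assert(src == dst);
        }
      }

      // 3) (A - B) u> A  <=>  A u< B  (the subtraction wraps exactly when A u< B)
      for (int a = 0; a < 256; ++a)
        for (int b = 0; b < 256; ++b) {
          uint8_t diff = uint8_t(a - b);
          assert((diff > uint8_t(a)) == (uint8_t(a) < uint8_t(b)));
        }
      return 0;
    }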
if (Op0->getType()->isPointerTy()) return false; - // Count every uses of both Op0 and Op1 in a subtract. - // Each time Op0 is the first operand, count -1: swapping is bad, the - // subtract has already the same layout as the compare. - // Each time Op0 is the second operand, count +1: swapping is good, the - // subtract has a different layout as the compare. - // At the end, if the benefit is greater than 0, Op0 should come second to - // expose more CSE opportunities. - int GlobalSwapBenefits = 0; + // If a subtract already has the same operands as a compare, swapping would be + // bad. If a subtract has the same operands as a compare but in reverse order, + // then swapping is good. + int GoodToSwap = 0; for (const User *U : Op0->users()) { - const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(U); - if (!BinOp || BinOp->getOpcode() != Instruction::Sub) - continue; - // If Op0 is the first argument, this is not beneficial to swap the - // arguments. - int LocalSwapBenefits = -1; - unsigned Op1Idx = 1; - if (BinOp->getOperand(Op1Idx) == Op0) { - Op1Idx = 0; - LocalSwapBenefits = 1; - } - if (BinOp->getOperand(Op1Idx) != Op1) - continue; - GlobalSwapBenefits += LocalSwapBenefits; + if (match(U, m_Sub(m_Specific(Op1), m_Specific(Op0)))) + GoodToSwap++; + else if (match(U, m_Sub(m_Specific(Op0), m_Specific(Op1)))) + GoodToSwap--; } - return GlobalSwapBenefits > 0; + return GoodToSwap > 0; } -/// \brief Check that one use is in the same block as the definition and all +/// Check that one use is in the same block as the definition and all /// other uses are in blocks dominated by a given block. /// /// \param DI Definition @@ -3976,7 +4172,7 @@ static bool isChainSelectCmpBranch(const SelectInst *SI) { return true; } -/// \brief True when a select result is replaced by one of its operands +/// True when a select result is replaced by one of its operands /// in select-icmp sequence. This will eventually result in the elimination /// of the select. /// @@ -4052,7 +4248,7 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { // Get scalar or pointer size. unsigned BitWidth = Ty->isIntOrIntVectorTy() ? Ty->getScalarSizeInBits() - : DL.getTypeSizeInBits(Ty->getScalarType()); + : DL.getIndexTypeSizeInBits(Ty->getScalarType()); if (!BitWidth) return nullptr; @@ -4082,13 +4278,13 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { computeUnsignedMinMaxValuesFromKnownBits(Op1Known, Op1Min, Op1Max); } - // If Min and Max are known to be the same, then SimplifyDemandedBits - // figured out that the LHS is a constant. Constant fold this now, so that + // If Min and Max are known to be the same, then SimplifyDemandedBits figured + // out that the LHS or RHS is a constant. Constant fold this now, so that // code below can assume that Min != Max. if (!isa<Constant>(Op0) && Op0Min == Op0Max) - return new ICmpInst(Pred, ConstantInt::get(Op0->getType(), Op0Min), Op1); + return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1); if (!isa<Constant>(Op1) && Op1Min == Op1Max) - return new ICmpInst(Pred, Op0, ConstantInt::get(Op1->getType(), Op1Min)); + return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min)); // Based on the range information we know about the LHS, see if we can // simplify this comparison. For example, (x&4) < 8 is always true. @@ -4520,6 +4716,34 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return New; } + // Zero-equality and sign-bit checks are preserved through sitofp + bitcast. 
+ Value *X; + if (match(Op0, m_BitCast(m_SIToFP(m_Value(X))))) { + // icmp eq (bitcast (sitofp X)), 0 --> icmp eq X, 0 + // icmp ne (bitcast (sitofp X)), 0 --> icmp ne X, 0 + // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0 + // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0 + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT || + Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) && + match(Op1, m_Zero())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + + // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1 + if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One())) + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1)); + + // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1 + if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes())) + return new ICmpInst(Pred, X, ConstantInt::getAllOnesValue(X->getType())); + } + + // Zero-equality checks are preserved through unsigned floating-point casts: + // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0 + // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0 + if (match(Op0, m_BitCast(m_UIToFP(m_Value(X))))) + if (I.isEquality() && match(Op1, m_Zero())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + // Test to see if the operands of the icmp are casted versions of other // values. If the ptr->ptr cast can be stripped off both arguments, we do so // now. @@ -4642,6 +4866,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X) return foldICmpAddOpConst(X, Cst, I.getSwappedPredicate()); } + return Changed ? &I : nullptr; } @@ -4928,11 +5153,11 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand, // then canonicalize the operand to 0.0. 
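The icmp (bitcast (sitofp X)) folds above rest on simple facts about signed integer-to-FP conversion: zero maps to +0.0 (an all-zero encoding), nonzero integers map to nonzero encodings, and the sign is preserved, so the top bit of the raw FP bits equals (X < 0). A small check using memcpy as the bitcast (assumes IEEE-754 doubles):

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    static uint64_t bitsOf(double d) {
      uint64_t u;
      std::memcpy(&u, &d, sizeof u);  // the 'bitcast double to i64'
      return u;
    }

    int main() {
      const int32_t samples[] = {0, 1, -1, 7, -7, 2147483647, -2147483647 - 1};
      for (int32_t x : samples) {
        uint64_t b = bitsOf(static_cast<double>(x));  // sitofp i32 -> double
        assert((b == 0) == (x == 0));                 // icmp eq/ne ..., 0
        assert(((b >> 63) != 0) == (x < 0));          // icmp slt ..., 0
      }
      return 0;
    }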
if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { - if (!match(Op0, m_Zero()) && isKnownNeverNaN(Op0)) { + if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0)) { I.setOperand(0, ConstantFP::getNullValue(Op0->getType())); return &I; } - if (!match(Op1, m_Zero()) && isKnownNeverNaN(Op1)) { + if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1)) { I.setOperand(1, ConstantFP::getNullValue(Op0->getType())); return &I; } diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index f1f66d86cb73..58ef3d41415c 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -20,6 +20,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetFolder.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -40,7 +41,6 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> @@ -122,17 +122,17 @@ static inline Value *peekThroughBitcast(Value *V, bool OneUseOnly = false) { return V; } -/// \brief Add one to a Constant +/// Add one to a Constant static inline Constant *AddOne(Constant *C) { return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); } -/// \brief Subtract one from a Constant +/// Subtract one from a Constant static inline Constant *SubOne(Constant *C) { return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); } -/// \brief Return true if the specified value is free to invert (apply ~ to). +/// Return true if the specified value is free to invert (apply ~ to). /// This happens in cases where the ~ can be eliminated. If WillInvertAllUses /// is true, work under the assumption that the caller intends to remove all /// uses of V and only keep uses of ~V. @@ -178,7 +178,7 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { return false; } -/// \brief Specific patterns of overflow check idioms that we match. +/// Specific patterns of overflow check idioms that we match. enum OverflowCheckFlavor { OCF_UNSIGNED_ADD, OCF_SIGNED_ADD, @@ -190,7 +190,7 @@ enum OverflowCheckFlavor { OCF_INVALID }; -/// \brief Returns the OverflowCheckFlavor corresponding to a overflow_with_op +/// Returns the OverflowCheckFlavor corresponding to a overflow_with_op /// intrinsic. static inline OverflowCheckFlavor IntrinsicIDToOverflowCheckFlavor(unsigned ID) { @@ -212,7 +212,62 @@ IntrinsicIDToOverflowCheckFlavor(unsigned ID) { } } -/// \brief The core instruction combiner logic. +/// Some binary operators require special handling to avoid poison and undefined +/// behavior. If a constant vector has undef elements, replace those undefs with +/// identity constants if possible because those are always safe to execute. +/// If no identity constant exists, replace undef with some other safe constant. +static inline Constant *getSafeVectorConstantForBinop( + BinaryOperator::BinaryOps Opcode, Constant *In, bool IsRHSConstant) { + assert(In->getType()->isVectorTy() && "Not expecting scalars here"); + + Type *EltTy = In->getType()->getVectorElementType(); + auto *SafeC = ConstantExpr::getBinOpIdentity(Opcode, EltTy, IsRHSConstant); + if (!SafeC) { + // TODO: Should this be available as a constant utility function? 
It is + // similar to getBinOpAbsorber(). + if (IsRHSConstant) { + switch (Opcode) { + case Instruction::SRem: // X % 1 = 0 + case Instruction::URem: // X %u 1 = 0 + SafeC = ConstantInt::get(EltTy, 1); + break; + case Instruction::FRem: // X % 1.0 (doesn't simplify, but it is safe) + SafeC = ConstantFP::get(EltTy, 1.0); + break; + default: + llvm_unreachable("Only rem opcodes have no identity constant for RHS"); + } + } else { + switch (Opcode) { + case Instruction::Shl: // 0 << X = 0 + case Instruction::LShr: // 0 >>u X = 0 + case Instruction::AShr: // 0 >> X = 0 + case Instruction::SDiv: // 0 / X = 0 + case Instruction::UDiv: // 0 /u X = 0 + case Instruction::SRem: // 0 % X = 0 + case Instruction::URem: // 0 %u X = 0 + case Instruction::Sub: // 0 - X (doesn't simplify, but it is safe) + case Instruction::FSub: // 0.0 - X (doesn't simplify, but it is safe) + case Instruction::FDiv: // 0.0 / X (doesn't simplify, but it is safe) + case Instruction::FRem: // 0.0 % X = 0 + SafeC = Constant::getNullValue(EltTy); + break; + default: + llvm_unreachable("Expected to find identity constant for opcode"); + } + } + } + assert(SafeC && "Must have safe constant for binop"); + unsigned NumElts = In->getType()->getVectorNumElements(); + SmallVector<Constant *, 16> Out(NumElts); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *C = In->getAggregateElement(i); + Out[i] = isa<UndefValue>(C) ? SafeC : C; + } + return ConstantVector::get(Out); +} + +/// The core instruction combiner logic. /// /// This class provides both the logic to recursively visit instructions and /// combine them. @@ -220,10 +275,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner : public InstVisitor<InstCombiner, Instruction *> { // FIXME: These members shouldn't be public. public: - /// \brief A worklist of the instructions that need to be simplified. + /// A worklist of the instructions that need to be simplified. InstCombineWorklist &Worklist; - /// \brief An IRBuilder that automatically inserts new instructions into the + /// An IRBuilder that automatically inserts new instructions into the /// worklist. using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>; BuilderTy &Builder; @@ -261,7 +316,7 @@ public: ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), LI(LI) {} - /// \brief Run the combiner over the entire worklist until it is empty. + /// Run the combiner over the entire worklist until it is empty. /// /// \returns true if the IR is changed. bool run(); @@ -289,8 +344,6 @@ public: Instruction *visitSub(BinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); - Value *foldFMulConst(Instruction *FMulOrDiv, Constant *C, - Instruction *InsertBefore); Instruction *visitFMul(BinaryOperator &I); Instruction *visitURem(BinaryOperator &I); Instruction *visitSRem(BinaryOperator &I); @@ -378,7 +431,6 @@ private: bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; bool shouldChangeType(Type *From, Type *To) const; Value *dyn_castNegVal(Value *V) const; - Value *dyn_castFNegVal(Value *V, bool NoSignedZero = false) const; Type *FindElementAtOffset(PointerType *PtrTy, int64_t Offset, SmallVectorImpl<Value *> &NewIndices); @@ -393,7 +445,7 @@ private: /// if it cannot already be eliminated by some other transformation. 
bool shouldOptimizeCast(CastInst *CI); - /// \brief Try to optimize a sequence of instructions checking if an operation + /// Try to optimize a sequence of instructions checking if an operation /// on LHS and RHS overflows. /// /// If this overflow check is done via one of the overflow check intrinsics, @@ -445,11 +497,22 @@ private: } bool willNotOverflowSignedSub(const Value *LHS, const Value *RHS, - const Instruction &CxtI) const; + const Instruction &CxtI) const { + return computeOverflowForSignedSub(LHS, RHS, &CxtI) == + OverflowResult::NeverOverflows; + } + bool willNotOverflowUnsignedSub(const Value *LHS, const Value *RHS, - const Instruction &CxtI) const; + const Instruction &CxtI) const { + return computeOverflowForUnsignedSub(LHS, RHS, &CxtI) == + OverflowResult::NeverOverflows; + } + bool willNotOverflowSignedMul(const Value *LHS, const Value *RHS, - const Instruction &CxtI) const; + const Instruction &CxtI) const { + return computeOverflowForSignedMul(LHS, RHS, &CxtI) == + OverflowResult::NeverOverflows; + } bool willNotOverflowUnsignedMul(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { @@ -462,6 +525,7 @@ private: Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); Instruction *narrowBinOp(TruncInst &Trunc); + Instruction *narrowMaskedBinOp(BinaryOperator &And); Instruction *narrowRotate(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); @@ -490,7 +554,7 @@ private: Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, bool JoinedByAnd, Instruction &CxtI); public: - /// \brief Inserts an instruction \p New before instruction \p Old + /// Inserts an instruction \p New before instruction \p Old /// /// Also adds the new instruction to the worklist and returns \p New so that /// it is suitable for use as the return from the visitation patterns. @@ -503,13 +567,13 @@ public: return New; } - /// \brief Same as InsertNewInstBefore, but also sets the debug loc. + /// Same as InsertNewInstBefore, but also sets the debug loc. Instruction *InsertNewInstWith(Instruction *New, Instruction &Old) { New->setDebugLoc(Old.getDebugLoc()); return InsertNewInstBefore(New, Old); } - /// \brief A combiner-aware RAUW-like routine. + /// A combiner-aware RAUW-like routine. /// /// This method is to be used when an instruction is found to be dead, /// replaceable with another preexisting expression. Here we add all uses of @@ -527,8 +591,8 @@ public: if (&I == V) V = UndefValue::get(I.getType()); - DEBUG(dbgs() << "IC: Replacing " << I << "\n" - << " with " << *V << '\n'); + LLVM_DEBUG(dbgs() << "IC: Replacing " << I << "\n" + << " with " << *V << '\n'); I.replaceAllUsesWith(V); return &I; @@ -544,13 +608,13 @@ public: return InsertValueInst::Create(Struct, Result, 0); } - /// \brief Combiner aware instruction erasure. + /// Combiner aware instruction erasure. /// /// When dealing with an instruction that has side effects or produces a void /// value, we can't rely on DCE to delete the instruction. Instead, visit /// methods should return the value returned by this function. 
Instruction *eraseInstFromFunction(Instruction &I) { - DEBUG(dbgs() << "IC: ERASE " << I << '\n'); + LLVM_DEBUG(dbgs() << "IC: ERASE " << I << '\n'); assert(I.use_empty() && "Cannot erase instruction that is used!"); salvageDebugInfo(I); @@ -599,6 +663,12 @@ public: return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflowForSignedMul(const Value *LHS, + const Value *RHS, + const Instruction *CxtI) const { + return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT); + } + OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS, const Instruction *CxtI) const { @@ -611,15 +681,26 @@ public: return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflowForUnsignedSub(const Value *LHS, + const Value *RHS, + const Instruction *CxtI) const { + return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT); + } + + OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS, + const Instruction *CxtI) const { + return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT); + } + /// Maximum size of array considered when transforming. uint64_t MaxArraySizeForCombine; private: - /// \brief Performs a few simplifications for operators which are associative + /// Performs a few simplifications for operators which are associative /// or commutative. bool SimplifyAssociativeOrCommutative(BinaryOperator &I); - /// \brief Tries to simplify binary operations which some other binary + /// Tries to simplify binary operations which some other binary /// operation distributes over. /// /// It does this by either by factorizing out common terms (eg "(A*B)+(A*C)" @@ -628,6 +709,13 @@ private: /// value, or null if it didn't simplify. Value *SimplifyUsingDistributiveLaws(BinaryOperator &I); + /// Tries to simplify add operations using the definition of remainder. + /// + /// The definition of remainder is X % C = X - (X / C ) * C. The add + /// expression X % C0 + (( X / C0 ) % C1) * C0 can be simplified to + /// X % (C0 * C1) + Value *SimplifyAddWithRemainder(BinaryOperator &I); + // Binary Op helper for select operations where the expression can be // efficiently reorganized. Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS, @@ -647,7 +735,7 @@ private: ConstantInt *&Less, ConstantInt *&Equal, ConstantInt *&Greater); - /// \brief Attempts to replace V with a simpler value based on the demanded + /// Attempts to replace V with a simpler value based on the demanded /// bits. Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits &Known, unsigned Depth, Instruction *CxtI); @@ -669,15 +757,19 @@ private: Instruction *Shr, const APInt &ShrOp1, Instruction *Shl, const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known); - /// \brief Tries to simplify operands to an integer instruction based on its + /// Tries to simplify operands to an integer instruction based on its /// demanded bits. bool SimplifyDemandedInstructionBits(Instruction &Inst); + Value *simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, + APInt DemandedElts, + int DmaskIdx = -1); + Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth = 0); - Value *SimplifyVectorOp(BinaryOperator &Inst); - + /// Canonicalize the position of binops relative to shufflevector. 
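The identity documented for SimplifyAddWithRemainder above, X % C0 + ((X / C0) % C1) * C0 == X % (C0 * C1), can be checked exhaustively at 8 bits for the unsigned case (the fold additionally requires that C0 * C1 not overflow, which the loop bounds below enforce). Standalone sketch, not LLVM code:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned c0 = 1; c0 < 256; ++c0)
        for (unsigned c1 = 1; c0 * c1 < 256; ++c1)    // C0 * C1 must not overflow
          for (unsigned x = 0; x < 256; ++x) {
            uint8_t lhs = uint8_t(x % c0 + ((x / c0) % c1) * c0);
            uint8_t rhs = uint8_t(x % (c0 * c1));
            assert(lhs == rhs);
          }
      return 0;
    }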
+ Instruction *foldShuffledBinop(BinaryOperator &Inst); /// Given a binary operator, cast instruction, or select which has a PHI node /// as operand #0, see if we can fold the instruction into the PHI (which is @@ -691,11 +783,11 @@ private: Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI); /// This is a convenience wrapper function for the above two functions. - Instruction *foldOpWithConstantIntoOperand(BinaryOperator &I); + Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I); Instruction *foldAddWithConstant(BinaryOperator &Add); - /// \brief Try to rotate an operation below a PHI node, using PHI nodes for + /// Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. Instruction *FoldPHIArgOpIntoPHI(PHINode &PN); Instruction *FoldPHIArgBinOpIntoPHI(PHINode &PN); @@ -735,6 +827,8 @@ private: Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); + Instruction *foldICmpBitCastConstant(ICmpInst &Cmp, BitCastInst *Bitcast, + const APInt &C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, const APInt &C); Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, @@ -789,13 +883,12 @@ private: Instruction *MatchBSwap(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); - Instruction *SimplifyElementUnorderedAtomicMemCpy(AtomicMemCpyInst *AMI); - Instruction *SimplifyMemTransfer(MemIntrinsic *MI); - Instruction *SimplifyMemSet(MemSetInst *MI); + Instruction *SimplifyAnyMemTransfer(AnyMemTransferInst *MI); + Instruction *SimplifyAnyMemSet(AnyMemSetInst *MI); Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned); - /// \brief Returns a value X such that Val = X * Scale, or null if none. + /// Returns a value X such that Val = X * Scale, or null if none. /// /// If the multiplication is known not to overflow then NoSignedWrap is set. 
Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap); diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index d4f06e18b957..742caf649007 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/Loads.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" @@ -23,7 +24,6 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace PatternMatch; @@ -270,7 +270,7 @@ void PointerReplacer::findLoadAndReplace(Instruction &I) { auto *Inst = dyn_cast<Instruction>(&*U); if (!Inst) return; - DEBUG(dbgs() << "Found pointer user: " << *U << '\n'); + LLVM_DEBUG(dbgs() << "Found pointer user: " << *U << '\n'); if (isa<LoadInst>(Inst)) { for (auto P : Path) replace(P); @@ -405,8 +405,8 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { Copy->getSource(), AI.getAlignment(), DL, &AI, &AC, &DT); if (AI.getAlignment() <= SourceAlign && isDereferenceableForAllocaSize(Copy->getSource(), &AI, DL)) { - DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); - DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); + LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); + LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); for (unsigned i = 0, e = ToDelete.size(); i != e; ++i) eraseInstFromFunction(*ToDelete[i]); Constant *TheSrc = cast<Constant>(Copy->getSource()); @@ -437,10 +437,10 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { // Are we allowed to form a atomic load or store of this type? static bool isSupportedAtomicType(Type *Ty) { - return Ty->isIntegerTy() || Ty->isPointerTy() || Ty->isFloatingPointTy(); + return Ty->isIntOrPtrTy() || Ty->isFloatingPointTy(); } -/// \brief Helper to combine a load to a new type. +/// Helper to combine a load to a new type. /// /// This just does the work of combining a load to a new type. It handles /// metadata, etc., and returns the new instruction. 
The \c NewTy should be the @@ -453,15 +453,20 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT const Twine &Suffix = "") { assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) && "can't fold an atomic load to requested type"); - + Value *Ptr = LI.getPointerOperand(); unsigned AS = LI.getPointerAddressSpace(); SmallVector<std::pair<unsigned, MDNode *>, 8> MD; LI.getAllMetadata(MD); + Value *NewPtr = nullptr; + if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) && + NewPtr->getType()->getPointerElementType() == NewTy && + NewPtr->getType()->getPointerAddressSpace() == AS)) + NewPtr = IC.Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS)); + LoadInst *NewLoad = IC.Builder.CreateAlignedLoad( - IC.Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS)), - LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); + NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); MDBuilder MDB(NewLoad->getContext()); for (const auto &MDPair : MD) { @@ -507,7 +512,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT return NewLoad; } -/// \brief Combine a store to a new type. +/// Combine a store to a new type. /// /// Returns the newly created store instruction. static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value *V) { @@ -584,7 +589,7 @@ static bool isMinMaxWithLoads(Value *V) { match(L2, m_Load(m_Specific(LHS)))); } -/// \brief Combine loads to match the type of their uses' value after looking +/// Combine loads to match the type of their uses' value after looking /// through intervening bitcasts. /// /// The core idea here is that if the result of a load is used in an operation, @@ -959,23 +964,26 @@ static Instruction *replaceGEPIdxWithZero(InstCombiner &IC, Value *Ptr, } static bool canSimplifyNullStoreOrGEP(StoreInst &SI) { - if (SI.getPointerAddressSpace() != 0) + if (NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())) return false; auto *Ptr = SI.getPointerOperand(); if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) Ptr = GEPI->getOperand(0); - return isa<ConstantPointerNull>(Ptr); + return (isa<ConstantPointerNull>(Ptr) && + !NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())); } static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) { if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { const Value *GEPI0 = GEPI->getOperand(0); - if (isa<ConstantPointerNull>(GEPI0) && GEPI->getPointerAddressSpace() == 0) + if (isa<ConstantPointerNull>(GEPI0) && + !NullPointerIsDefined(LI.getFunction(), GEPI->getPointerAddressSpace())) return true; } if (isa<UndefValue>(Op) || - (isa<ConstantPointerNull>(Op) && LI.getPointerAddressSpace() == 0)) + (isa<ConstantPointerNull>(Op) && + !NullPointerIsDefined(LI.getFunction(), LI.getPointerAddressSpace()))) return true; return false; } @@ -1071,14 +1079,16 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // load (select (cond, null, P)) -> load P if (isa<ConstantPointerNull>(SI->getOperand(1)) && - LI.getPointerAddressSpace() == 0) { + !NullPointerIsDefined(SI->getFunction(), + LI.getPointerAddressSpace())) { LI.setOperand(0, SI->getOperand(2)); return &LI; } // load (select (cond, P, null)) -> load P if (isa<ConstantPointerNull>(SI->getOperand(2)) && - LI.getPointerAddressSpace() == 0) { + !NullPointerIsDefined(SI->getFunction(), + LI.getPointerAddressSpace())) { LI.setOperand(0, SI->getOperand(1)); return &LI; } @@ -1087,7 +1097,7 
@@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { return nullptr; } -/// \brief Look for extractelement/insertvalue sequence that acts like a bitcast. +/// Look for extractelement/insertvalue sequence that acts like a bitcast. /// /// \returns underlying value that was "cast", or nullptr otherwise. /// @@ -1142,7 +1152,7 @@ static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) { return U; } -/// \brief Combine stores to match the type of value being stored. +/// Combine stores to match the type of value being stored. /// /// The core idea here is that the memory does not have any intrinsic type and /// where we can we should match the type of a store to the type of value being diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 541dde6c47d2..63761d427235 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -33,6 +33,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/Utils/BuildLibCalls.h" #include <cassert> #include <cstddef> #include <cstdint> @@ -94,115 +95,52 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, return MadeChange ? V : nullptr; } -/// True if the multiply can not be expressed in an int this size. -static bool MultiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, - bool IsSigned) { - bool Overflow; - if (IsSigned) - Product = C1.smul_ov(C2, Overflow); - else - Product = C1.umul_ov(C2, Overflow); - - return Overflow; -} - -/// \brief True if C2 is a multiple of C1. Quotient contains C2/C1. -static bool IsMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, - bool IsSigned) { - assert(C1.getBitWidth() == C2.getBitWidth() && - "Inconsistent width of constants!"); - - // Bail if we will divide by zero. - if (C2.isMinValue()) - return false; - - // Bail if we would divide INT_MIN by -1. - if (IsSigned && C1.isMinSignedValue() && C2.isAllOnesValue()) - return false; - - APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned); - if (IsSigned) - APInt::sdivrem(C1, C2, Quotient, Remainder); - else - APInt::udivrem(C1, C2, Quotient, Remainder); - - return Remainder.isMinValue(); -} - -/// \brief A helper routine of InstCombiner::visitMul(). +/// A helper routine of InstCombiner::visitMul(). /// -/// If C is a vector of known powers of 2, then this function returns -/// a new vector obtained from C replacing each element with its logBase2. +/// If C is a scalar/vector of known powers of 2, then this function returns +/// a new scalar/vector obtained from logBase2 of C. /// Return a null pointer otherwise. 
-static Constant *getLogBase2Vector(ConstantDataVector *CV) { +static Constant *getLogBase2(Type *Ty, Constant *C) { const APInt *IVal; - SmallVector<Constant *, 4> Elts; + if (match(C, m_APInt(IVal)) && IVal->isPowerOf2()) + return ConstantInt::get(Ty, IVal->logBase2()); + + if (!Ty->isVectorTy()) + return nullptr; - for (unsigned I = 0, E = CV->getNumElements(); I != E; ++I) { - Constant *Elt = CV->getElementAsConstant(I); + SmallVector<Constant *, 4> Elts; + for (unsigned I = 0, E = Ty->getVectorNumElements(); I != E; ++I) { + Constant *Elt = C->getAggregateElement(I); + if (!Elt) + return nullptr; + if (isa<UndefValue>(Elt)) { + Elts.push_back(UndefValue::get(Ty->getScalarType())); + continue; + } if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2()) return nullptr; - Elts.push_back(ConstantInt::get(Elt->getType(), IVal->logBase2())); + Elts.push_back(ConstantInt::get(Ty->getScalarType(), IVal->logBase2())); } return ConstantVector::get(Elts); } -/// \brief Return true if we can prove that: -/// (mul LHS, RHS) === (mul nsw LHS, RHS) -bool InstCombiner::willNotOverflowSignedMul(const Value *LHS, - const Value *RHS, - const Instruction &CxtI) const { - // Multiplying n * m significant bits yields a result of n + m significant - // bits. If the total number of significant bits does not exceed the - // result bit width (minus 1), there is no overflow. - // This means if we have enough leading sign bits in the operands - // we can guarantee that the result does not overflow. - // Ref: "Hacker's Delight" by Henry Warren - unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); - - // Note that underestimating the number of sign bits gives a more - // conservative answer. - unsigned SignBits = - ComputeNumSignBits(LHS, 0, &CxtI) + ComputeNumSignBits(RHS, 0, &CxtI); - - // First handle the easy case: if we have enough sign bits there's - // definitely no overflow. - if (SignBits > BitWidth + 1) - return true; - - // There are two ambiguous cases where there can be no overflow: - // SignBits == BitWidth + 1 and - // SignBits == BitWidth - // The second case is difficult to check, therefore we only handle the - // first case. - if (SignBits == BitWidth + 1) { - // It overflows only when both arguments are negative and the true - // product is exactly the minimum negative number. - // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 - // For simplicity we just check if at least one side is not negative. 
- KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, &CxtI); - KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, &CxtI); - if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) - return true; - } - return false; -} - Instruction *InstCombiner::visitMul(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyMulInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); // X * -1 == 0 - X + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (match(Op1, m_AllOnes())) { BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName()); if (I.hasNoSignedWrap()) @@ -231,16 +169,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) { - Constant *NewCst = nullptr; - if (match(C1, m_APInt(IVal)) && IVal->isPowerOf2()) - // Replace X*(2^C) with X << C, where C is either a scalar or a splat. - NewCst = ConstantInt::get(NewOp->getType(), IVal->logBase2()); - else if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(C1)) - // Replace X*(2^C) with X << C, where C is a vector of known - // constant powers of 2. - NewCst = getLogBase2Vector(CV); - - if (NewCst) { + // Replace X*(2^C) with X << C, where C is either a scalar or a vector. + if (Constant *NewCst = getLogBase2(NewOp->getType(), C1)) { unsigned Width = NewCst->getType()->getPrimitiveSizeInBits(); BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); @@ -282,34 +212,37 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) + return FoldedMul; + // Simplify mul instructions with a constant RHS. if (isa<Constant>(Op1)) { - if (Instruction *FoldedMul = foldOpWithConstantIntoOperand(I)) - return FoldedMul; - // Canonicalize (X+C1)*CI -> X*CI+C1*CI. - { - Value *X; - Constant *C1; - if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) { - Value *Mul = Builder.CreateMul(C1, Op1); - // Only go forward with the transform if C1*CI simplifies to a tidier - // constant. - if (!match(Mul, m_Mul(m_Value(), m_Value()))) - return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul); - } + Value *X; + Constant *C1; + if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) { + Value *Mul = Builder.CreateMul(C1, Op1); + // Only go forward with the transform if C1*CI simplifies to a tidier + // constant. 
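// Illustrative sketch, separate from the patch: the (X + C1) * C2 -->
// X * C2 + C1 * C2 canonicalization guarded by the match check just below is
// an exact identity under the wraparound semantics of fixed-width integers,
// so it is safe whenever C1 * C2 folds to a plain constant.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C1 = 7, C2 = 5;
  for (uint32_t X : {0u, 1u, 41u, 0xFFFFFFF0u})
    assert((X + C1) * C2 == X * C2 + C1 * C2); // both sides wrap identically
  return 0;
}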
+ if (!match(Mul, m_Mul(m_Value(), m_Value()))) + return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul); } } - if (Value *Op0v = dyn_castNegVal(Op0)) { // -X * -Y = X*Y - if (Value *Op1v = dyn_castNegVal(Op1)) { - BinaryOperator *BO = BinaryOperator::CreateMul(Op0v, Op1v); - if (I.hasNoSignedWrap() && - match(Op0, m_NSWSub(m_Value(), m_Value())) && - match(Op1, m_NSWSub(m_Value(), m_Value()))) - BO->setHasNoSignedWrap(); - return BO; - } + // -X * C --> X * -C + Value *X, *Y; + Constant *Op1C; + if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Constant(Op1C))) + return BinaryOperator::CreateMul(X, ConstantExpr::getNeg(Op1C)); + + // -X * -Y --> X * Y + if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Neg(m_Value(Y)))) { + auto *NewMul = BinaryOperator::CreateMul(X, Y); + if (I.hasNoSignedWrap() && + cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() && + cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap()) + NewMul->setHasNoSignedWrap(); + return NewMul; } // (X / Y) * Y = X - (X % Y) @@ -371,28 +304,24 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } - // If one of the operands of the multiply is a cast from a boolean value, then - // we know the bool is either zero or one, so this is a 'masking' multiply. - // X * Y (where Y is 0 or 1) -> X & (0-Y) - if (!I.getType()->isVectorTy()) { - // -2 is "-1 << 1" so it is all bits set except the low one. - APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true); - - Value *BoolCast = nullptr, *OtherOp = nullptr; - if (MaskedValueIsZero(Op0, Negative2, 0, &I)) { - BoolCast = Op0; - OtherOp = Op1; - } else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) { - BoolCast = Op1; - OtherOp = Op0; - } - - if (BoolCast) { - Value *V = Builder.CreateSub(Constant::getNullValue(I.getType()), - BoolCast); - return BinaryOperator::CreateAnd(V, OtherOp); - } - } + // (bool X) * Y --> X ? Y : 0 + // Y * (bool X) --> X ? Y : 0 + if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(X, Op1, ConstantInt::get(I.getType(), 0)); + if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(X, Op0, ConstantInt::get(I.getType(), 0)); + + // (lshr X, 31) * Y --> (ashr X, 31) & Y + // Y * (lshr X, 31) --> (ashr X, 31) & Y + // TODO: We are not checking one-use because the elimination of the multiply + // is better for analysis? + // TODO: Should we canonicalize to '(X < 0) ? Y : 0' instead? That would be + // more similar to what we're doing above. + const APInt *C; + if (match(Op0, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1) + return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op1); + if (match(Op1, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1) + return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op0); // Check for (mul (sext x), y), see if we can merge this into an // integer mul followed by a sext. @@ -466,6 +395,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + bool Changed = false; if (!I.hasNoSignedWrap() && willNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoSignedWrap(true); @@ -479,286 +409,103 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { return Changed ? &I : nullptr; } -/// Detect pattern log2(Y * 0.5) with corresponding fast math flags. 
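// Standalone sketch (not part of the patch) of the two masking folds added
// above: a zext'd i1 multiplier selects between Y and 0, and a sign-bit lshr
// multiplier is equivalent to masking Y with the sign-extended bit.
#include <cassert>
#include <cstdint>

int main() {
  uint32_t Y = 0x1234ABCDu;
  for (bool X : {false, true})                 // (zext i1 X) * Y
    assert(uint32_t(X) * Y == (X ? Y : 0u));   // --> select X, Y, 0

  for (uint32_t X : {5u, 0x80000000u, 0xFFFFFFFFu}) {
    uint32_t LShr = X >> 31;                   // (lshr X, 31) * Y
    uint32_t AShr = LShr ? 0xFFFFFFFFu : 0u;   // value of (ashr X, 31)
    assert(LShr * Y == (AShr & Y));            // --> (ashr X, 31) & Y
  }
  return 0;
}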
-static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { - if (!Op->hasOneUse()) - return; - - IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); - if (!II) - return; - if (II->getIntrinsicID() != Intrinsic::log2 || !II->isFast()) - return; - Log2 = II; - - Value *OpLog2Of = II->getArgOperand(0); - if (!OpLog2Of->hasOneUse()) - return; - - Instruction *I = dyn_cast<Instruction>(OpLog2Of); - if (!I) - return; - - if (I->getOpcode() != Instruction::FMul || !I->isFast()) - return; - - if (match(I->getOperand(0), m_SpecificFP(0.5))) - Y = I->getOperand(1); - else if (match(I->getOperand(1), m_SpecificFP(0.5))) - Y = I->getOperand(0); -} - -static bool isFiniteNonZeroFp(Constant *C) { - if (C->getType()->isVectorTy()) { - for (unsigned I = 0, E = C->getType()->getVectorNumElements(); I != E; - ++I) { - ConstantFP *CFP = dyn_cast_or_null<ConstantFP>(C->getAggregateElement(I)); - if (!CFP || !CFP->getValueAPF().isFiniteNonZero()) - return false; - } - return true; - } - - return isa<ConstantFP>(C) && - cast<ConstantFP>(C)->getValueAPF().isFiniteNonZero(); -} - -static bool isNormalFp(Constant *C) { - if (C->getType()->isVectorTy()) { - for (unsigned I = 0, E = C->getType()->getVectorNumElements(); I != E; - ++I) { - ConstantFP *CFP = dyn_cast_or_null<ConstantFP>(C->getAggregateElement(I)); - if (!CFP || !CFP->getValueAPF().isNormal()) - return false; - } - return true; - } - - return isa<ConstantFP>(C) && cast<ConstantFP>(C)->getValueAPF().isNormal(); -} - -/// Helper function of InstCombiner::visitFMul(BinaryOperator(). It returns -/// true iff the given value is FMul or FDiv with one and only one operand -/// being a normal constant (i.e. not Zero/NaN/Infinity). -static bool isFMulOrFDivWithConstant(Value *V) { - Instruction *I = dyn_cast<Instruction>(V); - if (!I || (I->getOpcode() != Instruction::FMul && - I->getOpcode() != Instruction::FDiv)) - return false; - - Constant *C0 = dyn_cast<Constant>(I->getOperand(0)); - Constant *C1 = dyn_cast<Constant>(I->getOperand(1)); - - if (C0 && C1) - return false; - - return (C0 && isFiniteNonZeroFp(C0)) || (C1 && isFiniteNonZeroFp(C1)); -} +Instruction *InstCombiner::visitFMul(BinaryOperator &I) { + if (Value *V = SimplifyFMulInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); -/// foldFMulConst() is a helper routine of InstCombiner::visitFMul(). -/// The input \p FMulOrDiv is a FMul/FDiv with one and only one operand -/// being a constant (i.e. isFMulOrFDivWithConstant(FMulOrDiv) == true). -/// This function is to simplify "FMulOrDiv * C" and returns the -/// resulting expression. Note that this function could return NULL in -/// case the constants cannot be folded into a normal floating-point. -Value *InstCombiner::foldFMulConst(Instruction *FMulOrDiv, Constant *C, - Instruction *InsertBefore) { - assert(isFMulOrFDivWithConstant(FMulOrDiv) && "V is invalid"); - - Value *Opnd0 = FMulOrDiv->getOperand(0); - Value *Opnd1 = FMulOrDiv->getOperand(1); - - Constant *C0 = dyn_cast<Constant>(Opnd0); - Constant *C1 = dyn_cast<Constant>(Opnd1); - - BinaryOperator *R = nullptr; - - // (X * C0) * C => X * (C0*C) - if (FMulOrDiv->getOpcode() == Instruction::FMul) { - Constant *F = ConstantExpr::getFMul(C1 ? C1 : C0, C); - if (isNormalFp(F)) - R = BinaryOperator::CreateFMul(C1 ? Opnd0 : Opnd1, F); - } else { - if (C0) { - // (C0 / X) * C => (C0 * C) / X - if (FMulOrDiv->hasOneUse()) { - // It would otherwise introduce another div. 
- Constant *F = ConstantExpr::getFMul(C0, C); - if (isNormalFp(F)) - R = BinaryOperator::CreateFDiv(F, Opnd1); - } - } else { - // (X / C1) * C => X * (C/C1) if C/C1 is not a denormal - Constant *F = ConstantExpr::getFDiv(C, C1); - if (isNormalFp(F)) { - R = BinaryOperator::CreateFMul(Opnd0, F); - } else { - // (X / C1) * C => X / (C1/C) - Constant *F = ConstantExpr::getFDiv(C1, C); - if (isNormalFp(F)) - R = BinaryOperator::CreateFDiv(Opnd0, F); - } - } - } + if (SimplifyAssociativeOrCommutative(I)) + return &I; - if (R) { - R->setFast(true); - InsertNewInstWith(R, *InsertBefore); - } + if (Instruction *X = foldShuffledBinop(I)) + return X; - return R; -} + if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) + return FoldedMul; -Instruction *InstCombiner::visitFMul(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); + // X * -1.0 --> -X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (match(Op1, m_SpecificFP(-1.0))) + return BinaryOperator::CreateFNegFMF(Op0, &I); - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); + // -X * -Y --> X * Y + Value *X, *Y; + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) + return BinaryOperator::CreateFMulFMF(X, Y, &I); - if (isa<Constant>(Op0)) - std::swap(Op0, Op1); + // -X * C --> X * -C + Constant *C; + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C))) + return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); - if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), - SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + // Sink negation: -X * Y --> -(X * Y) + if (match(Op0, m_OneUse(m_FNeg(m_Value(X))))) + return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op1, &I), &I); - bool AllowReassociate = I.isFast(); + // Sink negation: Y * -X --> -(X * Y) + if (match(Op1, m_OneUse(m_FNeg(m_Value(X))))) + return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op0, &I), &I); - // Simplify mul instructions with a constant RHS. - if (isa<Constant>(Op1)) { - if (Instruction *FoldedMul = foldOpWithConstantIntoOperand(I)) - return FoldedMul; - - // (fmul X, -1.0) --> (fsub -0.0, X) - if (match(Op1, m_SpecificFP(-1.0))) { - Constant *NegZero = ConstantFP::getNegativeZero(Op1->getType()); - Instruction *RI = BinaryOperator::CreateFSub(NegZero, Op0); - RI->copyFastMathFlags(&I); - return RI; - } - - Constant *C = cast<Constant>(Op1); - if (AllowReassociate && isFiniteNonZeroFp(C)) { - // Let MDC denote an expression in one of these forms: - // X * C, C/X, X/C, where C is a constant. 
- // - // Try to simplify "MDC * Constant" - if (isFMulOrFDivWithConstant(Op0)) - if (Value *V = foldFMulConst(cast<Instruction>(Op0), C, &I)) - return replaceInstUsesWith(I, V); - - // (MDC +/- C1) * C => (MDC * C) +/- (C1 * C) - Instruction *FAddSub = dyn_cast<Instruction>(Op0); - if (FAddSub && - (FAddSub->getOpcode() == Instruction::FAdd || - FAddSub->getOpcode() == Instruction::FSub)) { - Value *Opnd0 = FAddSub->getOperand(0); - Value *Opnd1 = FAddSub->getOperand(1); - Constant *C0 = dyn_cast<Constant>(Opnd0); - Constant *C1 = dyn_cast<Constant>(Opnd1); - bool Swap = false; - if (C0) { - std::swap(C0, C1); - std::swap(Opnd0, Opnd1); - Swap = true; - } + // fabs(X) * fabs(X) -> X * X + if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X)))) + return BinaryOperator::CreateFMulFMF(X, X, &I); - if (C1 && isFiniteNonZeroFp(C1) && isFMulOrFDivWithConstant(Opnd0)) { - Value *M1 = ConstantExpr::getFMul(C1, C); - Value *M0 = isNormalFp(cast<Constant>(M1)) ? - foldFMulConst(cast<Instruction>(Opnd0), C, &I) : - nullptr; - if (M0 && M1) { - if (Swap && FAddSub->getOpcode() == Instruction::FSub) - std::swap(M0, M1); - - Instruction *RI = (FAddSub->getOpcode() == Instruction::FAdd) - ? BinaryOperator::CreateFAdd(M0, M1) - : BinaryOperator::CreateFSub(M0, M1); - RI->copyFastMathFlags(&I); - return RI; - } - } - } - } - } + // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E) + if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) + return replaceInstUsesWith(I, V); - if (Op0 == Op1) { - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) { - // sqrt(X) * sqrt(X) -> X - if (AllowReassociate && II->getIntrinsicID() == Intrinsic::sqrt) - return replaceInstUsesWith(I, II->getOperand(0)); - - // fabs(X) * fabs(X) -> X * X - if (II->getIntrinsicID() == Intrinsic::fabs) { - Instruction *FMulVal = BinaryOperator::CreateFMul(II->getOperand(0), - II->getOperand(0), - I.getName()); - FMulVal->copyFastMathFlags(&I); - return FMulVal; + if (I.hasAllowReassoc()) { + // Reassociate constant RHS with another constant to form constant + // expression. + if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) { + Constant *C1; + if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) { + // (C1 / X) * C --> (C * C1) / X + Constant *CC1 = ConstantExpr::getFMul(C, C1); + if (CC1->isNormalFP()) + return BinaryOperator::CreateFDivFMF(CC1, X, &I); } - } - } - - // Under unsafe algebra do: - // X * log2(0.5*Y) = X*log2(Y) - X - if (AllowReassociate) { - Value *OpX = nullptr; - Value *OpY = nullptr; - IntrinsicInst *Log2; - detectLog2OfHalf(Op0, OpY, Log2); - if (OpY) { - OpX = Op1; - } else { - detectLog2OfHalf(Op1, OpY, Log2); - if (OpY) { - OpX = Op0; + if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { + // (X / C1) * C --> X * (C / C1) + Constant *CDivC1 = ConstantExpr::getFDiv(C, C1); + if (CDivC1->isNormalFP()) + return BinaryOperator::CreateFMulFMF(X, CDivC1, &I); + + // If the constant was a denormal, try reassociating differently. 
+ // (X / C1) * C --> X / (C1 / C) + Constant *C1DivC = ConstantExpr::getFDiv(C1, C); + if (Op0->hasOneUse() && C1DivC->isNormalFP()) + return BinaryOperator::CreateFDivFMF(X, C1DivC, &I); } - } - // if pattern detected emit alternate sequence - if (OpX && OpY) { - BuilderTy::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(Log2->getFastMathFlags()); - Log2->setArgOperand(0, OpY); - Value *FMulVal = Builder.CreateFMul(OpX, Log2); - Value *FSub = Builder.CreateFSub(FMulVal, OpX); - FSub->takeName(&I); - return replaceInstUsesWith(I, FSub); - } - } - // Handle symmetric situation in a 2-iteration loop - Value *Opnd0 = Op0; - Value *Opnd1 = Op1; - for (int i = 0; i < 2; i++) { - bool IgnoreZeroSign = I.hasNoSignedZeros(); - if (BinaryOperator::isFNeg(Opnd0, IgnoreZeroSign)) { - BuilderTy::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - - Value *N0 = dyn_castFNegVal(Opnd0, IgnoreZeroSign); - Value *N1 = dyn_castFNegVal(Opnd1, IgnoreZeroSign); - - // -X * -Y => X*Y - if (N1) { - Value *FMul = Builder.CreateFMul(N0, N1); - FMul->takeName(&I); - return replaceInstUsesWith(I, FMul); + // We do not need to match 'fadd C, X' and 'fsub X, C' because they are + // canonicalized to 'fadd X, C'. Distributing the multiply may allow + // further folds and (X * C) + C2 is 'fma'. + if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) { + // (X + C1) * C --> (X * C) + (C * C1) + Constant *CC1 = ConstantExpr::getFMul(C, C1); + Value *XC = Builder.CreateFMulFMF(X, C, &I); + return BinaryOperator::CreateFAddFMF(XC, CC1, &I); } - - if (Opnd0->hasOneUse()) { - // -X * Y => -(X*Y) (Promote negation as high as possible) - Value *T = Builder.CreateFMul(N0, Opnd1); - Value *Neg = Builder.CreateFNeg(T); - Neg->takeName(&I); - return replaceInstUsesWith(I, Neg); + if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) { + // (C1 - X) * C --> (C * C1) - (X * C) + Constant *CC1 = ConstantExpr::getFMul(C, C1); + Value *XC = Builder.CreateFMulFMF(X, C, &I); + return BinaryOperator::CreateFSubFMF(CC1, XC, &I); } } - // Handle specials cases for FMul with selects feeding the operation - if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) - return replaceInstUsesWith(I, V); + // sqrt(X) * sqrt(Y) -> sqrt(X * Y) + // nnan disallows the possibility of returning a number if both operands are + // negative (in that case, we should return NaN). + if (I.hasNoNaNs() && + match(Op0, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(X)))) && + match(Op1, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) { + Value *XY = Builder.CreateFMulFMF(X, Y, &I); + Value *Sqrt = Builder.CreateIntrinsic(Intrinsic::sqrt, { XY }, &I); + return replaceInstUsesWith(I, Sqrt); + } // (X*Y) * X => (X*X) * Y where Y != X // The purpose is two-fold: @@ -767,34 +514,40 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { // latency of the instruction Y is amortized by the expression of X*X, // and therefore Y is in a "less critical" position compared to what it // was before the transformation. 
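// Standalone illustration (not from the diff) of why the
// sqrt(X) * sqrt(Y) --> sqrt(X * Y) fold in this hunk requires nnan: with two
// negative inputs the original expression is NaN * NaN == NaN, while the
// folded form takes sqrt of a positive product and returns a number.
#include <cassert>
#include <cmath>

int main() {
  double X = -1.0, Y = -4.0;
  double Original = std::sqrt(X) * std::sqrt(Y); // NaN * NaN
  double Folded = std::sqrt(X * Y);              // sqrt(4.0) == 2.0
  assert(std::isnan(Original));
  assert(Folded == 2.0);
  return 0;
}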
- if (AllowReassociate) { - Value *Opnd0_0, *Opnd0_1; - if (Opnd0->hasOneUse() && - match(Opnd0, m_FMul(m_Value(Opnd0_0), m_Value(Opnd0_1)))) { - Value *Y = nullptr; - if (Opnd0_0 == Opnd1 && Opnd0_1 != Opnd1) - Y = Opnd0_1; - else if (Opnd0_1 == Opnd1 && Opnd0_0 != Opnd1) - Y = Opnd0_0; - - if (Y) { - BuilderTy::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - Value *T = Builder.CreateFMul(Opnd1, Opnd1); - Value *R = Builder.CreateFMul(T, Y); - R->takeName(&I); - return replaceInstUsesWith(I, R); - } - } + if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && + Op1 != Y) { + Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I); + return BinaryOperator::CreateFMulFMF(XX, Y, &I); + } + if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && + Op0 != Y) { + Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I); + return BinaryOperator::CreateFMulFMF(XX, Y, &I); } + } - if (!isa<Constant>(Op1)) - std::swap(Opnd0, Opnd1); - else - break; + // log2(X * 0.5) * Y = log2(X) * Y - Y + if (I.isFast()) { + IntrinsicInst *Log2 = nullptr; + if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::log2>( + m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) { + Log2 = cast<IntrinsicInst>(Op0); + Y = Op1; + } + if (match(Op1, m_OneUse(m_Intrinsic<Intrinsic::log2>( + m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) { + Log2 = cast<IntrinsicInst>(Op1); + Y = Op0; + } + if (Log2) { + Log2->setArgOperand(0, X); + Log2->copyFastMathFlags(&I); + Value *LogXTimesY = Builder.CreateFMulFMF(Log2, Y, &I); + return BinaryOperator::CreateFSubFMF(LogXTimesY, Y, &I); + } } - return Changed ? &I : nullptr; + return nullptr; } /// Fold a divide or remainder with a select instruction divisor when one of the @@ -835,9 +588,9 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { Type *CondTy = SelectCond->getType(); while (BBI != BBFront) { --BBI; - // If we found a call to a function, we can't assume it will return, so + // If we found an instruction that we can't assume will return, so // information from below it cannot be propagated above it. - if (isa<CallInst>(BBI) && !isa<IntrinsicInst>(BBI)) + if (!isGuaranteedToTransferExecutionToSuccessor(&*BBI)) break; // Replace uses of the select or its condition with the known values. @@ -867,12 +620,44 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { return true; } +/// True if the multiply can not be expressed in an int this size. +static bool multiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, + bool IsSigned) { + bool Overflow; + Product = IsSigned ? C1.smul_ov(C2, Overflow) : C1.umul_ov(C2, Overflow); + return Overflow; +} + +/// True if C1 is a multiple of C2. Quotient contains C1/C2. +static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, + bool IsSigned) { + assert(C1.getBitWidth() == C2.getBitWidth() && "Constant widths not equal"); + + // Bail if we will divide by zero. + if (C2.isNullValue()) + return false; + + // Bail if we would divide INT_MIN by -1. + if (IsSigned && C1.isMinSignedValue() && C2.isAllOnesValue()) + return false; + + APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned); + if (IsSigned) + APInt::sdivrem(C1, C2, Quotient, Remainder); + else + APInt::udivrem(C1, C2, Quotient, Remainder); + + return Remainder.isMinValue(); +} + /// This function implements the transforms common to both integer division /// instructions (udiv and sdiv). It is called by the visitors to those integer /// division instructions. 
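// Standalone sketch (not part of the patch): the first fold below,
// (X / C1) / C2 --> X / (C1 * C2), relies on the nested floor-quotient
// identity, which holds as long as C1 * C2 does not overflow -- exactly what
// the multiplyOverflows() guard checks. Unsigned case:
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 1000; ++X)
    assert((X / 3u) / 5u == X / 15u);
  return 0;
}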
-/// @brief Common integer divide transforms +/// Common integer divide transforms Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + bool IsSigned = I.getOpcode() == Instruction::SDiv; + Type *Ty = I.getType(); // The RHS is known non-zero. if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) { @@ -885,94 +670,87 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; - if (Instruction *LHS = dyn_cast<Instruction>(Op0)) { - const APInt *C2; - if (match(Op1, m_APInt(C2))) { - Value *X; - const APInt *C1; - bool IsSigned = I.getOpcode() == Instruction::SDiv; - - // (X / C1) / C2 -> X / (C1*C2) - if ((IsSigned && match(LHS, m_SDiv(m_Value(X), m_APInt(C1)))) || - (!IsSigned && match(LHS, m_UDiv(m_Value(X), m_APInt(C1))))) { - APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); - if (!MultiplyOverflows(*C1, *C2, Product, IsSigned)) - return BinaryOperator::Create(I.getOpcode(), X, - ConstantInt::get(I.getType(), Product)); - } - - if ((IsSigned && match(LHS, m_NSWMul(m_Value(X), m_APInt(C1)))) || - (!IsSigned && match(LHS, m_NUWMul(m_Value(X), m_APInt(C1))))) { - APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + const APInt *C2; + if (match(Op1, m_APInt(C2))) { + Value *X; + const APInt *C1; + + // (X / C1) / C2 -> X / (C1*C2) + if ((IsSigned && match(Op0, m_SDiv(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_APInt(C1))))) { + APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + if (!multiplyOverflows(*C1, *C2, Product, IsSigned)) + return BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(Ty, Product)); + } - // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. - if (IsMultiple(*C2, *C1, Quotient, IsSigned)) { - BinaryOperator *BO = BinaryOperator::Create( - I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); - BO->setIsExact(I.isExact()); - return BO; - } + if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); - // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2. - if (IsMultiple(*C1, *C2, Quotient, IsSigned)) { - BinaryOperator *BO = BinaryOperator::Create( - Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); - BO->setHasNoUnsignedWrap( - !IsSigned && - cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap( - cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); - return BO; - } + // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. + if (isMultiple(*C2, *C1, Quotient, IsSigned)) { + auto *NewDiv = BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(Ty, Quotient)); + NewDiv->setIsExact(I.isExact()); + return NewDiv; } - if ((IsSigned && match(LHS, m_NSWShl(m_Value(X), m_APInt(C1))) && - *C1 != C1->getBitWidth() - 1) || - (!IsSigned && match(LHS, m_NUWShl(m_Value(X), m_APInt(C1))))) { - APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); - APInt C1Shifted = APInt::getOneBitSet( - C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); - - // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of C1. 
- if (IsMultiple(*C2, C1Shifted, Quotient, IsSigned)) { - BinaryOperator *BO = BinaryOperator::Create( - I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); - BO->setIsExact(I.isExact()); - return BO; - } + // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2. + if (isMultiple(*C1, *C2, Quotient, IsSigned)) { + auto *Mul = BinaryOperator::Create(Instruction::Mul, X, + ConstantInt::get(Ty, Quotient)); + auto *OBO = cast<OverflowingBinaryOperator>(Op0); + Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap()); + Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap()); + return Mul; + } + } - // (X << C1) / C2 -> X * (C2 >> C1) if C1 is a multiple of C2. - if (IsMultiple(C1Shifted, *C2, Quotient, IsSigned)) { - BinaryOperator *BO = BinaryOperator::Create( - Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); - BO->setHasNoUnsignedWrap( - !IsSigned && - cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap( - cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); - return BO; - } + if ((IsSigned && match(Op0, m_NSWShl(m_Value(X), m_APInt(C1))) && + *C1 != C1->getBitWidth() - 1) || + (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + APInt C1Shifted = APInt::getOneBitSet( + C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); + + // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of 1 << C1. + if (isMultiple(*C2, C1Shifted, Quotient, IsSigned)) { + auto *BO = BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(Ty, Quotient)); + BO->setIsExact(I.isExact()); + return BO; } - if (!C2->isNullValue()) // avoid X udiv 0 - if (Instruction *FoldedDiv = foldOpWithConstantIntoOperand(I)) - return FoldedDiv; + // (X << C1) / C2 -> X * ((1 << C1) / C2) if 1 << C1 is a multiple of C2. + if (isMultiple(C1Shifted, *C2, Quotient, IsSigned)) { + auto *Mul = BinaryOperator::Create(Instruction::Mul, X, + ConstantInt::get(Ty, Quotient)); + auto *OBO = cast<OverflowingBinaryOperator>(Op0); + Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap()); + Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap()); + return Mul; + } } + + if (!C2->isNullValue()) // avoid X udiv 0 + if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I)) + return FoldedDiv; } if (match(Op0, m_One())) { - assert(!I.getType()->isIntOrIntVectorTy(1) && "i1 divide not removed?"); - if (I.getOpcode() == Instruction::SDiv) { + assert(!Ty->isIntOrIntVectorTy(1) && "i1 divide not removed?"); + if (IsSigned) { // If Op1 is 0 then it's undefined behaviour, if Op1 is 1 then the // result is one, if Op1 is -1 then the result is minus one, otherwise // it's zero. Value *Inc = Builder.CreateAdd(Op1, Op0); - Value *Cmp = Builder.CreateICmpULT(Inc, ConstantInt::get(I.getType(), 3)); - return SelectInst::Create(Cmp, Op1, ConstantInt::get(I.getType(), 0)); + Value *Cmp = Builder.CreateICmpULT(Inc, ConstantInt::get(Ty, 3)); + return SelectInst::Create(Cmp, Op1, ConstantInt::get(Ty, 0)); } else { // If Op1 is 0 then it's undefined behaviour. If Op1 is 1 then the // result is one, otherwise it's zero. 
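// Standalone sketch (not from the diff) of the branch-free "1 / X" lowerings
// above. For sdiv, X + 1 is unsigned-less-than 3 exactly for X in {-1, 0, 1};
// X == 0 is UB for the division anyway, so selecting X in that range and 0
// otherwise reproduces 1 / X. For udiv, the quotient is simply zext(X == 1).
#include <cassert>
#include <cstdint>

static int32_t oneSDiv(int32_t X) {   // mirrors the emitted add/icmp/select
  uint32_t Inc = uint32_t(X) + 1u;    // add Op1, Op0
  return Inc < 3u ? X : 0;            // icmp ult Inc, 3 ; select
}

int main() {
  for (int32_t X : {-1000, -2, -1, 1, 2, 1000})
    assert(oneSDiv(X) == 1 / X);
  for (uint32_t X : {1u, 2u, 77u})
    assert(1u / X == uint32_t(X == 1u)); // zext (icmp eq X, 1)
  return 0;
}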
- return new ZExtInst(Builder.CreateICmpEQ(Op1, Op0), I.getType()); + return new ZExtInst(Builder.CreateICmpEQ(Op1, Op0), Ty); } } @@ -981,12 +759,28 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { return &I; // (X - (X rem Y)) / Y -> X / Y; usually originates as ((X / Y) * Y) / Y - Value *X = nullptr, *Z = nullptr; - if (match(Op0, m_Sub(m_Value(X), m_Value(Z)))) { // (X - Z) / Y; Y = Op1 - bool isSigned = I.getOpcode() == Instruction::SDiv; - if ((isSigned && match(Z, m_SRem(m_Specific(X), m_Specific(Op1)))) || - (!isSigned && match(Z, m_URem(m_Specific(X), m_Specific(Op1))))) + Value *X, *Z; + if (match(Op0, m_Sub(m_Value(X), m_Value(Z)))) // (X - Z) / Y; Y = Op1 + if ((IsSigned && match(Z, m_SRem(m_Specific(X), m_Specific(Op1)))) || + (!IsSigned && match(Z, m_URem(m_Specific(X), m_Specific(Op1))))) return BinaryOperator::Create(I.getOpcode(), X, Op1); + + // (X << Y) / X -> 1 << Y + Value *Y; + if (IsSigned && match(Op0, m_NSWShl(m_Specific(Op1), m_Value(Y)))) + return BinaryOperator::CreateNSWShl(ConstantInt::get(Ty, 1), Y); + if (!IsSigned && match(Op0, m_NUWShl(m_Specific(Op1), m_Value(Y)))) + return BinaryOperator::CreateNUWShl(ConstantInt::get(Ty, 1), Y); + + // X / (X * Y) -> 1 / Y if the multiplication does not overflow. + if (match(Op1, m_c_Mul(m_Specific(Op0), m_Value(Y)))) { + bool HasNSW = cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap(); + bool HasNUW = cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap(); + if ((IsSigned && HasNSW) || (!IsSigned && HasNUW)) { + I.setOperand(0, ConstantInt::get(Ty, 1)); + I.setOperand(1, Y); + return &I; + } } return nullptr; @@ -1000,7 +794,7 @@ using FoldUDivOperandCb = Instruction *(*)(Value *Op0, Value *Op1, const BinaryOperator &I, InstCombiner &IC); -/// \brief Used to maintain state for visitUDivOperand(). +/// Used to maintain state for visitUDivOperand(). struct UDivFoldAction { /// Informs visitUDiv() how to fold this operand. This can be zero if this /// action joins two actions together. 
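// Standalone sketch (not part of the patch): the (X << Y) / X --> 1 << Y and
// X / (X * Y) --> 1 / Y folds earlier in this hunk are exact as long as the
// shift/multiply does not wrap, which is why the code insists on the nuw/nsw
// flags before rewriting.
#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 5, Y = 3;
  assert((X << Y) / X == (1u << Y));   // 40 / 5 == 8
  assert(X / (X * Y) == 1u / Y);       // 5 / 15 == 0 == 1 / 3
  uint32_t Z = 1;
  assert(X / (X * Z) == 1u / Z);       // both sides are 1 when Y == 1
  return 0;
}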
@@ -1028,23 +822,15 @@ struct UDivFoldAction { // X udiv 2^C -> X >> C static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, const BinaryOperator &I, InstCombiner &IC) { - const APInt &C = cast<Constant>(Op1)->getUniqueInteger(); - BinaryOperator *LShr = BinaryOperator::CreateLShr( - Op0, ConstantInt::get(Op0->getType(), C.logBase2())); + Constant *C1 = getLogBase2(Op0->getType(), cast<Constant>(Op1)); + if (!C1) + llvm_unreachable("Failed to constant fold udiv -> logbase2"); + BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, C1); if (I.isExact()) LShr->setIsExact(); return LShr; } -// X udiv C, where C >= signbit -static Instruction *foldUDivNegCst(Value *Op0, Value *Op1, - const BinaryOperator &I, InstCombiner &IC) { - Value *ICI = IC.Builder.CreateICmpULT(Op0, cast<ConstantInt>(Op1)); - - return SelectInst::Create(ICI, Constant::getNullValue(I.getType()), - ConstantInt::get(I.getType(), 1)); -} - // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) // X udiv (zext (C1 << N)), where C1 is "1<<C2" --> X >> (N+C2) static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, @@ -1053,12 +839,14 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) ShiftLeft = Op1; - const APInt *CI; + Constant *CI; Value *N; - if (!match(ShiftLeft, m_Shl(m_APInt(CI), m_Value(N)))) + if (!match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N)))) llvm_unreachable("match should never fail here!"); - if (*CI != 1) - N = IC.Builder.CreateAdd(N, ConstantInt::get(N->getType(), CI->logBase2())); + Constant *Log2Base = getLogBase2(N->getType(), CI); + if (!Log2Base) + llvm_unreachable("getLogBase2 should never fail here!"); + N = IC.Builder.CreateAdd(N, Log2Base); if (Op1 != ShiftLeft) N = IC.Builder.CreateZExt(N, Op1->getType()); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); @@ -1067,7 +855,7 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, return LShr; } -// \brief Recursively visits the possible right hand operands of a udiv +// Recursively visits the possible right hand operands of a udiv // instruction, seeing through select instructions, to determine if we can // replace the udiv with something simpler. If we find that an operand is not // able to simplify the udiv, we abort the entire transformation. 
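// Illustrative sketch, separate from the diff: the shape handled by
// foldUDivShl above. Dividing by a power of two that is itself shifted,
// (1 << C2) << N, is the same as shifting right by N + C2.
#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 1000, C2 = 2, N = 3;              // divisor = (1 << 2) << 3 == 32
  assert(X / ((1u << C2) << N) == (X >> (N + C2)));
  return 0;
}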
@@ -1081,13 +869,6 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, return Actions.size(); } - if (ConstantInt *C = dyn_cast<ConstantInt>(Op1)) - // X udiv C, where C >= signbit - if (C->getValue().isNegative()) { - Actions.push_back(UDivFoldAction(foldUDivNegCst, C)); - return Actions.size(); - } - // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) if (match(Op1, m_Shl(m_Power2(), m_Value())) || match(Op1, m_ZExt(m_Shl(m_Power2(), m_Value())))) { @@ -1148,40 +929,65 @@ static Instruction *narrowUDivURem(BinaryOperator &I, } Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyUDivInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) return Common; - // (x lshr C1) udiv C2 --> x udiv (C2 << C1) - { - Value *X; - const APInt *C1, *C2; - if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && - match(Op1, m_APInt(C2))) { - bool Overflow; - APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow); - if (!Overflow) { - bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value())); - BinaryOperator *BO = BinaryOperator::CreateUDiv( - X, ConstantInt::get(X->getType(), C2ShlC1)); - if (IsExact) - BO->setIsExact(); - return BO; - } + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *X; + const APInt *C1, *C2; + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && match(Op1, m_APInt(C2))) { + // (X lshr C1) udiv C2 --> X udiv (C2 << C1) + bool Overflow; + APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow); + if (!Overflow) { + bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value())); + BinaryOperator *BO = BinaryOperator::CreateUDiv( + X, ConstantInt::get(X->getType(), C2ShlC1)); + if (IsExact) + BO->setIsExact(); + return BO; } } + // Op0 / C where C is large (negative) --> zext (Op0 >= C) + // TODO: Could use isKnownNegative() to handle non-constant values. + Type *Ty = I.getType(); + if (match(Op1, m_Negative())) { + Value *Cmp = Builder.CreateICmpUGE(Op0, Op1); + return CastInst::CreateZExtOrBitCast(Cmp, Ty); + } + // Op0 / (sext i1 X) --> zext (Op0 == -1) (if X is 0, the div is undefined) + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty)); + return CastInst::CreateZExtOrBitCast(Cmp, Ty); + } + if (Instruction *NarrowDiv = narrowUDivURem(I, Builder)) return NarrowDiv; + // If the udiv operands are non-overflowing multiplies with a common operand, + // then eliminate the common factor: + // (A * B) / (A * X) --> B / X (and commuted variants) + // TODO: The code would be reduced if we had m_c_NUWMul pattern matching. + // TODO: If -reassociation handled this generally, we could remove this. 
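// Standalone sketch (not part of the patch) of two udiv rewrites in this hunk:
// a divisor with the sign bit set is larger than half the unsigned range, so
// the quotient is 0 or 1 and equals zext(Op0 >= C); and with non-overflowing
// multiplies a common factor simply cancels.
#include <cassert>
#include <cstdint>

int main() {
  uint32_t C = 0x80000005u;                  // "negative" when read as signed
  for (uint32_t X : {0u, 0x7FFFFFFFu, 0x80000005u, 0xFFFFFFFFu})
    assert(X / C == uint32_t(X >= C));       // zext of the compare

  uint32_t A = 3, B = 10, Y = 2;             // (A * B) / (A * Y) --> B / Y
  assert((A * B) / (A * Y) == B / Y);        // 30 / 6 == 5 == 10 / 2
  return 0;
}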
+ Value *A, *B; + if (match(Op0, m_NUWMul(m_Value(A), m_Value(B)))) { + if (match(Op1, m_NUWMul(m_Specific(A), m_Value(X))) || + match(Op1, m_NUWMul(m_Value(X), m_Specific(A)))) + return BinaryOperator::CreateUDiv(B, X); + if (match(Op1, m_NUWMul(m_Specific(B), m_Value(X))) || + match(Op1, m_NUWMul(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateUDiv(A, X); + } + // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) SmallVector<UDivFoldAction, 6> UDivActions; if (visitUDivOperand(Op0, Op1, I, UDivActions)) @@ -1217,24 +1023,27 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { } Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifySDivInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) return Common; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *X; + // sdiv Op0, -1 --> -Op0 + // sdiv Op0, (sext i1 X) --> -Op0 (because if X is 0, the op is undefined) + if (match(Op1, m_AllOnes()) || + (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) + return BinaryOperator::CreateNeg(Op0); + const APInt *Op1C; if (match(Op1, m_APInt(Op1C))) { - // sdiv X, -1 == -X - if (Op1C->isAllOnesValue()) - return BinaryOperator::CreateNeg(Op0); - // sdiv exact X, C --> ashr exact X, log2(C) if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) { Value *ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2()); @@ -1298,166 +1107,148 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return nullptr; } -/// CvtFDivConstToReciprocal tries to convert X/C into X*1/C if C not a special -/// FP value and: -/// 1) 1/C is exact, or -/// 2) reciprocal is allowed. -/// If the conversion was successful, the simplified expression "X * 1/C" is -/// returned; otherwise, nullptr is returned. -static Instruction *CvtFDivConstToReciprocal(Value *Dividend, Constant *Divisor, - bool AllowReciprocal) { - if (!isa<ConstantFP>(Divisor)) // TODO: handle vectors. +/// Remove negation and try to convert division into multiplication. +static Instruction *foldFDivConstantDivisor(BinaryOperator &I) { + Constant *C; + if (!match(I.getOperand(1), m_Constant(C))) return nullptr; - const APFloat &FpVal = cast<ConstantFP>(Divisor)->getValueAPF(); - APFloat Reciprocal(FpVal.getSemantics()); - bool Cvt = FpVal.getExactInverse(&Reciprocal); + // -X / C --> X / -C + Value *X; + if (match(I.getOperand(0), m_FNeg(m_Value(X)))) + return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); - if (!Cvt && AllowReciprocal && FpVal.isFiniteNonZero()) { - Reciprocal = APFloat(FpVal.getSemantics(), 1.0f); - (void)Reciprocal.divide(FpVal, APFloat::rmNearestTiesToEven); - Cvt = !Reciprocal.isDenormal(); - } + // If the constant divisor has an exact inverse, this is always safe. If not, + // then we can still create a reciprocal if fast-math-flags allow it and the + // constant is a regular number (not zero, infinite, or denormal). 
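// Standalone illustration (not from the diff): when 1/C is exactly
// representable (e.g. C is a power of two), X / C and X * (1/C) round
// identically, so the X / C --> X * (1/C) rewrite needs no fast-math flags.
// For a constant like 3.0 the reciprocal is inexact, which is the case gated
// behind hasAllowReciprocal() plus the isNormalFP() checks below.
#include <cassert>

int main() {
  for (double X : {1.0, 3.0, 0.1, 12345.678, -2.5})
    assert(X / 4.0 == X * 0.25);   // exact reciprocal: bit-identical results
  return 0;
}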
+ if (!(C->hasExactInverseFP() || (I.hasAllowReciprocal() && C->isNormalFP()))) + return nullptr; - if (!Cvt) + // Disallow denormal constants because we don't know what would happen + // on all targets. + // TODO: Use Intrinsic::canonicalize or let function attributes tell us that + // denorms are flushed? + auto *RecipC = ConstantExpr::getFDiv(ConstantFP::get(I.getType(), 1.0), C); + if (!RecipC->isNormalFP()) return nullptr; - ConstantFP *R; - R = ConstantFP::get(Dividend->getType()->getContext(), Reciprocal); - return BinaryOperator::CreateFMul(Dividend, R); + // X / C --> X * (1 / C) + return BinaryOperator::CreateFMulFMF(I.getOperand(0), RecipC, &I); } -Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); +/// Remove negation and try to reassociate constant math. +static Instruction *foldFDivConstantDividend(BinaryOperator &I) { + Constant *C; + if (!match(I.getOperand(0), m_Constant(C))) + return nullptr; - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); + // C / -X --> -C / X + Value *X; + if (match(I.getOperand(1), m_FNeg(m_Value(X)))) + return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + + if (!I.hasAllowReassoc() || !I.hasAllowReciprocal()) + return nullptr; + + // Try to reassociate C / X expressions where X includes another constant. + Constant *C2, *NewC = nullptr; + if (match(I.getOperand(1), m_FMul(m_Value(X), m_Constant(C2)))) { + // C / (X * C2) --> (C / C2) / X + NewC = ConstantExpr::getFDiv(C, C2); + } else if (match(I.getOperand(1), m_FDiv(m_Value(X), m_Constant(C2)))) { + // C / (X / C2) --> (C * C2) / X + NewC = ConstantExpr::getFMul(C, C2); + } + // Disallow denormal constants because we don't know what would happen + // on all targets. + // TODO: Use Intrinsic::canonicalize or let function attributes tell us that + // denorms are flushed? 
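// Standalone sketch (not part of the patch) of the C / (X * C2) --> (C / C2) / X
// reassociation. With constants chosen so that C / C2 folds exactly, both
// forms perform a single rounding of the same real value; in general the
// rewrite is only done under reassoc + arcp, as the checks above require.
#include <cassert>

int main() {
  for (double X : {3.0, 7.5, -11.0}) {
    double Original = 10.0 / (X * 2.0);  // C = 10.0, C2 = 2.0
    double Folded = 5.0 / X;             // (C / C2) == 5.0 exactly
    assert(Original == Folded);
  }
  return 0;
}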
+ if (!NewC || !NewC->isNormalFP()) + return nullptr; + + return BinaryOperator::CreateFDivFMF(NewC, X, &I); +} - if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), +Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { + if (Value *V = SimplifyFDivInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; + + if (Instruction *R = foldFDivConstantDivisor(I)) + return R; + + if (Instruction *R = foldFDivConstantDividend(I)) + return R; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - bool AllowReassociate = I.isFast(); - bool AllowReciprocal = I.hasAllowReciprocal(); - - if (Constant *Op1C = dyn_cast<Constant>(Op1)) { + if (isa<Constant>(Op1)) if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - if (AllowReassociate) { - Constant *C1 = nullptr; - Constant *C2 = Op1C; - Value *X; - Instruction *Res = nullptr; - - if (match(Op0, m_FMul(m_Value(X), m_Constant(C1)))) { - // (X*C1)/C2 => X * (C1/C2) - // - Constant *C = ConstantExpr::getFDiv(C1, C2); - if (isNormalFp(C)) - Res = BinaryOperator::CreateFMul(X, C); - } else if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { - // (X/C1)/C2 => X /(C2*C1) [=> X * 1/(C2*C1) if reciprocal is allowed] - Constant *C = ConstantExpr::getFMul(C1, C2); - if (isNormalFp(C)) { - Res = CvtFDivConstToReciprocal(X, C, AllowReciprocal); - if (!Res) - Res = BinaryOperator::CreateFDiv(X, C); - } - } - - if (Res) { - Res->setFastMathFlags(I.getFastMathFlags()); - return Res; - } + if (I.hasAllowReassoc() && I.hasAllowReciprocal()) { + Value *X, *Y; + if (match(Op0, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))) && + (!isa<Constant>(Y) || !isa<Constant>(Op1))) { + // (X / Y) / Z => X / (Y * Z) + Value *YZ = Builder.CreateFMulFMF(Y, Op1, &I); + return BinaryOperator::CreateFDivFMF(X, YZ, &I); } - - // X / C => X * 1/C - if (Instruction *T = CvtFDivConstToReciprocal(Op0, Op1C, AllowReciprocal)) { - T->copyFastMathFlags(&I); - return T; + if (match(Op1, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))) && + (!isa<Constant>(Y) || !isa<Constant>(Op0))) { + // Z / (X / Y) => (Y * Z) / X + Value *YZ = Builder.CreateFMulFMF(Y, Op0, &I); + return BinaryOperator::CreateFDivFMF(YZ, X, &I); } - - return nullptr; } - if (AllowReassociate && isa<Constant>(Op0)) { - Constant *C1 = cast<Constant>(Op0), *C2; - Constant *Fold = nullptr; + if (I.hasAllowReassoc() && Op0->hasOneUse() && Op1->hasOneUse()) { + // sin(X) / cos(X) -> tan(X) + // cos(X) / sin(X) -> 1/tan(X) (cotangent) Value *X; - bool CreateDiv = true; - - // C1 / (X*C2) => (C1/C2) / X - if (match(Op1, m_FMul(m_Value(X), m_Constant(C2)))) - Fold = ConstantExpr::getFDiv(C1, C2); - else if (match(Op1, m_FDiv(m_Value(X), m_Constant(C2)))) { - // C1 / (X/C2) => (C1*C2) / X - Fold = ConstantExpr::getFMul(C1, C2); - } else if (match(Op1, m_FDiv(m_Constant(C2), m_Value(X)))) { - // C1 / (C2/X) => (C1/C2) * X - Fold = ConstantExpr::getFDiv(C1, C2); - CreateDiv = false; - } - - if (Fold && isNormalFp(Fold)) { - Instruction *R = CreateDiv ? 
BinaryOperator::CreateFDiv(Fold, X) - : BinaryOperator::CreateFMul(X, Fold); - R->setFastMathFlags(I.getFastMathFlags()); - return R; + bool IsTan = match(Op0, m_Intrinsic<Intrinsic::sin>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::cos>(m_Specific(X))); + bool IsCot = + !IsTan && match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X))); + + if ((IsTan || IsCot) && hasUnaryFloatFn(&TLI, I.getType(), LibFunc_tan, + LibFunc_tanf, LibFunc_tanl)) { + IRBuilder<> B(&I); + IRBuilder<>::FastMathFlagGuard FMFGuard(B); + B.setFastMathFlags(I.getFastMathFlags()); + AttributeList Attrs = CallSite(Op0).getCalledFunction()->getAttributes(); + Value *Res = emitUnaryFloatFnCall(X, TLI.getName(LibFunc_tan), B, Attrs); + if (IsCot) + Res = B.CreateFDiv(ConstantFP::get(I.getType(), 1.0), Res); + return replaceInstUsesWith(I, Res); } - return nullptr; } - if (AllowReassociate) { - Value *X, *Y; - Value *NewInst = nullptr; - Instruction *SimpR = nullptr; - - if (Op0->hasOneUse() && match(Op0, m_FDiv(m_Value(X), m_Value(Y)))) { - // (X/Y) / Z => X / (Y*Z) - if (!isa<Constant>(Y) || !isa<Constant>(Op1)) { - NewInst = Builder.CreateFMul(Y, Op1); - if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { - FastMathFlags Flags = I.getFastMathFlags(); - Flags &= cast<Instruction>(Op0)->getFastMathFlags(); - RI->setFastMathFlags(Flags); - } - SimpR = BinaryOperator::CreateFDiv(X, NewInst); - } - } else if (Op1->hasOneUse() && match(Op1, m_FDiv(m_Value(X), m_Value(Y)))) { - // Z / (X/Y) => Z*Y / X - if (!isa<Constant>(Y) || !isa<Constant>(Op0)) { - NewInst = Builder.CreateFMul(Op0, Y); - if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { - FastMathFlags Flags = I.getFastMathFlags(); - Flags &= cast<Instruction>(Op1)->getFastMathFlags(); - RI->setFastMathFlags(Flags); - } - SimpR = BinaryOperator::CreateFDiv(NewInst, X); - } - } - - if (NewInst) { - if (Instruction *T = dyn_cast<Instruction>(NewInst)) - T->setDebugLoc(I.getDebugLoc()); - SimpR->setFastMathFlags(I.getFastMathFlags()); - return SimpR; - } + // -X / -Y -> X / Y + Value *X, *Y; + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) { + I.setOperand(0, X); + I.setOperand(1, Y); + return &I; } - Value *LHS; - Value *RHS; - - // -x / -y -> x / y - if (match(Op0, m_FNeg(m_Value(LHS))) && match(Op1, m_FNeg(m_Value(RHS)))) { - I.setOperand(0, LHS); - I.setOperand(1, RHS); + // X / (X * Y) --> 1.0 / Y + // Reassociate to (X / X -> 1.0) is legal when NaNs are not allowed. + // We can ignore the possibility that X is infinity because INF/INF is NaN. + if (I.hasNoNaNs() && I.hasAllowReassoc() && + match(Op1, m_c_FMul(m_Specific(Op0), m_Value(Y)))) { + I.setOperand(0, ConstantFP::get(I.getType(), 1.0)); + I.setOperand(1, Y); return &I; } @@ -1467,7 +1258,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { /// This function implements the transforms common to both integer remainder /// instructions (urem and srem). It is called by the visitors to those integer /// remainder instructions. 
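// Standalone sketch (not from the diff) of two identities behind the visitURem
// folds just below: for a power-of-two divisor the remainder is a mask, and
// for a divisor with the sign bit set the quotient is 0 or 1, so the remainder
// is either X itself or X - C.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0u, 5u, 8u, 123u, 0xFFFFFFFFu}) {
    assert(X % 8u == (X & 7u));               // X urem 2^k --> X & (2^k - 1)
    uint32_t C = 0x80000001u;                 // C >= signbit
    assert(X % C == (X < C ? X : X - C));     // the select form
  }
  return 0;
}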
-/// @brief Common integer remainder transforms +/// Common integer remainder transforms Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1509,13 +1300,12 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { } Instruction *InstCombiner::visitURem(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyURemInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *common = commonIRemTransforms(I)) return common; @@ -1524,47 +1314,55 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { return NarrowRem; // X urem Y -> X and Y-1, where Y is a power of 2, + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { - Constant *N1 = Constant::getAllOnesValue(I.getType()); + Constant *N1 = Constant::getAllOnesValue(Ty); Value *Add = Builder.CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); } // 1 urem X -> zext(X != 1) - if (match(Op0, m_One())) { - Value *Cmp = Builder.CreateICmpNE(Op1, Op0); - Value *Ext = Builder.CreateZExt(Cmp, I.getType()); - return replaceInstUsesWith(I, Ext); - } + if (match(Op0, m_One())) + return CastInst::CreateZExtOrBitCast(Builder.CreateICmpNE(Op1, Op0), Ty); // X urem C -> X < C ? X : X - C, where C >= signbit. - const APInt *DivisorC; - if (match(Op1, m_APInt(DivisorC)) && DivisorC->isNegative()) { + if (match(Op1, m_Negative())) { Value *Cmp = Builder.CreateICmpULT(Op0, Op1); Value *Sub = Builder.CreateSub(Op0, Op1); return SelectInst::Create(Cmp, Op0, Sub); } + // If the divisor is a sext of a boolean, then the divisor must be max + // unsigned value (-1). Therefore, the remainder is Op0 unless Op0 is also + // max unsigned value. In that case, the remainder is 0: + // urem Op0, (sext i1 X) --> (Op0 == -1) ? 
0 : Op0 + Value *X; + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty)); + return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0); + } + return nullptr; } Instruction *InstCombiner::visitSRem(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifySRemInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; // Handle the integer rem common cases if (Instruction *Common = commonIRemTransforms(I)) return Common; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); { const APInt *Y; // X % -Y -> X % Y - if (match(Op1, m_APInt(Y)) && Y->isNegative() && !Y->isMinSignedValue()) { + if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue()) { Worklist.AddValue(I.getOperand(1)); I.setOperand(1, ConstantInt::get(I.getType(), -*Y)); return &I; @@ -1622,14 +1420,13 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { } Instruction *InstCombiner::visitFRem(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); - - if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), + if (Value *V = SimplifyFRemInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 7ee018dbc49b..e54a1dd05a24 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -15,14 +15,18 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" +static cl::opt<unsigned> +MaxNumPhis("instcombine-max-num-phis", cl::init(512), + cl::desc("Maximum number phis to handle in intptr/ptrint folding")); + /// The PHI arguments will be folded into a single operation with a PHI node /// as input. The debug location of the single operation will be the merged /// locations of the original PHI node arguments. @@ -176,8 +180,12 @@ Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { assert(AvailablePtrVals.size() == PN.getNumIncomingValues() && "Not enough available ptr typed incoming values"); PHINode *MatchingPtrPHI = nullptr; + unsigned NumPhis = 0; for (auto II = BB->begin(), EI = BasicBlock::iterator(BB->getFirstNonPHI()); - II != EI; II++) { + II != EI; II++, NumPhis++) { + // FIXME: consider handling this in AggressiveInstCombine + if (NumPhis > MaxNumPhis) + return nullptr; PHINode *PtrPHI = dyn_cast<PHINode>(II); if (!PtrPHI || PtrPHI == &PN || PtrPHI->getType() != IntToPtr->getType()) continue; @@ -1008,10 +1016,9 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // extracted out of it. First, sort the users by their offset and size. 
array_pod_sort(PHIUsers.begin(), PHIUsers.end()); - DEBUG(dbgs() << "SLICING UP PHI: " << FirstPhi << '\n'; - for (unsigned i = 1, e = PHIsToSlice.size(); i != e; ++i) - dbgs() << "AND USER PHI #" << i << ": " << *PHIsToSlice[i] << '\n'; - ); + LLVM_DEBUG(dbgs() << "SLICING UP PHI: " << FirstPhi << '\n'; + for (unsigned i = 1, e = PHIsToSlice.size(); i != e; ++i) dbgs() + << "AND USER PHI #" << i << ": " << *PHIsToSlice[i] << '\n';); // PredValues - This is a temporary used when rewriting PHI nodes. It is // hoisted out here to avoid construction/destruction thrashing. @@ -1092,8 +1099,8 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { } PredValues.clear(); - DEBUG(dbgs() << " Made element PHI for offset " << Offset << ": " - << *EltPHI << '\n'); + LLVM_DEBUG(dbgs() << " Made element PHI for offset " << Offset << ": " + << *EltPHI << '\n'); ExtractedVals[LoweredPHIRecord(PN, Offset, Ty)] = EltPHI; } diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 6f26f7f5cd19..4867808478a3 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -47,93 +47,51 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -static SelectPatternFlavor -getInverseMinMaxSelectPattern(SelectPatternFlavor SPF) { - switch (SPF) { - default: - llvm_unreachable("unhandled!"); - - case SPF_SMIN: - return SPF_SMAX; - case SPF_UMIN: - return SPF_UMAX; - case SPF_SMAX: - return SPF_SMIN; - case SPF_UMAX: - return SPF_UMIN; - } -} - -static CmpInst::Predicate getCmpPredicateForMinMax(SelectPatternFlavor SPF, - bool Ordered=false) { - switch (SPF) { - default: - llvm_unreachable("unhandled!"); - - case SPF_SMIN: - return ICmpInst::ICMP_SLT; - case SPF_UMIN: - return ICmpInst::ICMP_ULT; - case SPF_SMAX: - return ICmpInst::ICMP_SGT; - case SPF_UMAX: - return ICmpInst::ICMP_UGT; - case SPF_FMINNUM: - return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT; - case SPF_FMAXNUM: - return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; - } -} - -static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, - SelectPatternFlavor SPF, Value *A, - Value *B) { - CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF); - assert(CmpInst::isIntPredicate(Pred)); +static Value *createMinMax(InstCombiner::BuilderTy &Builder, + SelectPatternFlavor SPF, Value *A, Value *B) { + CmpInst::Predicate Pred = getMinMaxPred(SPF); + assert(CmpInst::isIntPredicate(Pred) && "Expected integer predicate"); return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } -/// If one of the constants is zero (we know they can't both be) and we have an -/// icmp instruction with zero, and we have an 'and' with the non-constant value -/// and a power of two we can turn the select into a shift on the result of the -/// 'and'. /// This folds: -/// select (icmp eq (and X, C1)), C2, C3 -/// iff C1 is a power 2 and the difference between C2 and C3 is a power of 2. +/// select (icmp eq (and X, C1)), TC, FC +/// iff C1 is a power 2 and the difference between TC and FC is a power-of-2. 
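// Standalone illustration (not part of the patch) of the two shapes handled by
// foldSelectICmpAnd, whose full description continues below. When one arm is
// zero the select becomes a shifted copy of the masked bit; the new special
// case handles non-zero arms whose xor equals the mask, where a single or/xor
// sets or clears that bit in the result.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 32; ++X) {
    // select ((X & 4) == 0), 0, 8  -->  (X & 4) << 1
    assert(((X & 4u) == 0 ? 0u : 8u) == ((X & 4u) << 1));
    // select ((X & 4) == 0), 12, 8 -->  (X & 4) ^ 12   (TC ^ FC == AndMask)
    assert(((X & 4u) == 0 ? 12u : 8u) == ((X & 4u) ^ 12u));
  }
  return 0;
}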
/// To something like: -/// (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3 +/// (shr (and (X, C1)), (log2(C1) - log2(TC-FC))) + FC /// Or: -/// (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3 -/// With some variations depending if C3 is larger than C2, or the shift +/// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC +/// With some variations depending if FC is larger than TC, or the shift /// isn't needed, or the bit widths don't match. -static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, - APInt TrueVal, APInt FalseVal, +static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, InstCombiner::BuilderTy &Builder) { - assert(SelType->isIntOrIntVectorTy() && "Not an integer select?"); + const APInt *SelTC, *SelFC; + if (!match(Sel.getTrueValue(), m_APInt(SelTC)) || + !match(Sel.getFalseValue(), m_APInt(SelFC))) + return nullptr; // If this is a vector select, we need a vector compare. - if (SelType->isVectorTy() != IC->getType()->isVectorTy()) + Type *SelType = Sel.getType(); + if (SelType->isVectorTy() != Cmp->getType()->isVectorTy()) return nullptr; Value *V; APInt AndMask; bool CreateAnd = false; - ICmpInst::Predicate Pred = IC->getPredicate(); + ICmpInst::Predicate Pred = Cmp->getPredicate(); if (ICmpInst::isEquality(Pred)) { - if (!match(IC->getOperand(1), m_Zero())) + if (!match(Cmp->getOperand(1), m_Zero())) return nullptr; - V = IC->getOperand(0); - + V = Cmp->getOperand(0); const APInt *AndRHS; if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) return nullptr; AndMask = *AndRHS; - } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1), + } else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1), Pred, V, AndMask)) { assert(ICmpInst::isEquality(Pred) && "Not equality test?"); - if (!AndMask.isPowerOf2()) return nullptr; @@ -142,39 +100,58 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, return nullptr; } - // If both select arms are non-zero see if we have a select of the form - // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic - // for 'x ? 2^n : 0' and fix the thing up at the end. - APInt Offset(TrueVal.getBitWidth(), 0); - if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { - if ((TrueVal - FalseVal).isPowerOf2()) - Offset = FalseVal; - else if ((FalseVal - TrueVal).isPowerOf2()) - Offset = TrueVal; - else + // In general, when both constants are non-zero, we would need an offset to + // replace the select. This would require more instructions than we started + // with. But there's one special-case that we handle here because it can + // simplify/reduce the instructions. + APInt TC = *SelTC; + APInt FC = *SelFC; + if (!TC.isNullValue() && !FC.isNullValue()) { + // If the select constants differ by exactly one bit and that's the same + // bit that is masked and checked by the select condition, the select can + // be replaced by bitwise logic to set/clear one bit of the constant result. + if (TC.getBitWidth() != AndMask.getBitWidth() || (TC ^ FC) != AndMask) return nullptr; - - // Adjust TrueVal and FalseVal to the offset. - TrueVal -= Offset; - FalseVal -= Offset; + if (CreateAnd) { + // If we have to create an 'and', then we must kill the cmp to not + // increase the instruction count. + if (!Cmp->hasOneUse()) + return nullptr; + V = Builder.CreateAnd(V, ConstantInt::get(SelType, AndMask)); + } + bool ExtraBitInTC = TC.ugt(FC); + if (Pred == ICmpInst::ICMP_EQ) { + // If the masked bit in V is clear, clear or set the bit in the result: + // (V & AndMaskC) == 0 ? 
TC : FC --> (V & AndMaskC) ^ TC + // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) | TC + Constant *C = ConstantInt::get(SelType, TC); + return ExtraBitInTC ? Builder.CreateXor(V, C) : Builder.CreateOr(V, C); + } + if (Pred == ICmpInst::ICMP_NE) { + // If the masked bit in V is set, set or clear the bit in the result: + // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) | FC + // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) ^ FC + Constant *C = ConstantInt::get(SelType, FC); + return ExtraBitInTC ? Builder.CreateOr(V, C) : Builder.CreateXor(V, C); + } + llvm_unreachable("Only expecting equality predicates"); } - // Make sure one of the select arms is a power of 2. - if (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2()) + // Make sure one of the select arms is a power-of-2. + if (!TC.isPowerOf2() && !FC.isPowerOf2()) return nullptr; // Determine which shift is needed to transform result of the 'and' into the // desired result. - const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; + const APInt &ValC = !TC.isNullValue() ? TC : FC; unsigned ValZeros = ValC.logBase2(); unsigned AndZeros = AndMask.logBase2(); - if (CreateAnd) { - // Insert the AND instruction on the input to the truncate. + // Insert the 'and' instruction on the input to the truncate. + if (CreateAnd) V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); - } - // If types don't match we can still convert the select by introducing a zext + // If types don't match, we can still convert the select by introducing a zext // or a trunc of the 'and'. if (ValZeros > AndZeros) { V = Builder.CreateZExtOrTrunc(V, SelType); @@ -182,19 +159,17 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, } else if (ValZeros < AndZeros) { V = Builder.CreateLShr(V, AndZeros - ValZeros); V = Builder.CreateZExtOrTrunc(V, SelType); - } else + } else { V = Builder.CreateZExtOrTrunc(V, SelType); + } // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TrueVal.isNullValue(); + bool ShouldNotVal = !TC.isNullValue(); ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; if (ShouldNotVal) V = Builder.CreateXor(V, ValC); - // Apply an offset if needed. - if (!Offset.isNullValue()) - V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); return V; } @@ -300,12 +275,13 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, TI->getType()); } - // Only handle binary operators with one-use here. As with the cast case - // above, it may be possible to relax the one-use constraint, but that needs - // be examined carefully since it may not reduce the total number of - // instructions. - BinaryOperator *BO = dyn_cast<BinaryOperator>(TI); - if (!BO || !TI->hasOneUse() || !FI->hasOneUse()) + // Only handle binary operators (including two-operand getelementptr) with + // one-use here. As with the cast case above, it may be possible to relax the + // one-use constraint, but that needs be examined carefully since it may not + // reduce the total number of instructions. + if (TI->getNumOperands() != 2 || FI->getNumOperands() != 2 || + (!isa<BinaryOperator>(TI) && !isa<GetElementPtrInst>(TI)) || + !TI->hasOneUse() || !FI->hasOneUse()) return nullptr; // Figure out if the operations have any operands in common. @@ -342,7 +318,18 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, SI.getName() + ".v", &SI); Value *Op0 = MatchIsOpZero ? 
MatchOp : NewSI; Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; - return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + if (auto *BO = dyn_cast<BinaryOperator>(TI)) { + return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + } + if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) { + auto *FGEP = cast<GetElementPtrInst>(FI); + Type *ElementType = TGEP->getResultElementType(); + return TGEP->isInBounds() && FGEP->isInBounds() + ? GetElementPtrInst::CreateInBounds(ElementType, Op0, {Op1}) + : GetElementPtrInst::Create(ElementType, Op0, {Op1}); + } + llvm_unreachable("Expected BinaryOperator or GEP"); + return nullptr; } static bool isSelect01(const APInt &C1I, const APInt &C2I) { @@ -424,6 +411,47 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } /// We want to turn: +/// (select (icmp eq (and X, Y), 0), (and (lshr X, Z), 1), 1) +/// into: +/// zext (icmp ne i32 (and X, (or Y, (shl 1, Z))), 0) +/// Note: +/// Z may be 0 if lshr is missing. +/// Worst-case scenario is that we will replace 5 instructions with 5 different +/// instructions, but we got rid of select. +static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp, + Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + if (!(Cmp->hasOneUse() && Cmp->getOperand(0)->hasOneUse() && + Cmp->getPredicate() == ICmpInst::ICMP_EQ && + match(Cmp->getOperand(1), m_Zero()) && match(FVal, m_One()))) + return nullptr; + + // The TrueVal has general form of: and %B, 1 + Value *B; + if (!match(TVal, m_OneUse(m_And(m_Value(B), m_One())))) + return nullptr; + + // Where %B may be optionally shifted: lshr %X, %Z. + Value *X, *Z; + const bool HasShift = match(B, m_OneUse(m_LShr(m_Value(X), m_Value(Z)))); + if (!HasShift) + X = B; + + Value *Y; + if (!match(Cmp->getOperand(0), m_c_And(m_Specific(X), m_Value(Y)))) + return nullptr; + + // ((X & Y) == 0) ? ((X >> Z) & 1) : 1 --> (X & (Y | (1 << Z))) != 0 + // ((X & Y) == 0) ? (X & 1) : 1 --> (X & (Y | 1)) != 0 + Constant *One = ConstantInt::get(SelType, 1); + Value *MaskB = HasShift ? Builder.CreateShl(One, Z) : One; + Value *FullMask = Builder.CreateOr(Y, MaskB); + Value *MaskedX = Builder.CreateAnd(X, FullMask); + Value *ICmpNeZero = Builder.CreateIsNotNull(MaskedX); + return new ZExtInst(ICmpNeZero, SelType); +} + +/// We want to turn: /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) /// into: /// (or (shl (and X, C1), C3), Y) @@ -526,6 +554,59 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, return Builder.CreateOr(V, Y); } +/// Transform patterns such as: (a > b) ? a - b : 0 +/// into: ((a > b) ? a : b) - b) +/// This produces a canonical max pattern that is more easily recognized by the +/// backend and converted into saturated subtraction instructions if those +/// exist. +/// There are 8 commuted/swapped variants of this pattern. +/// TODO: Also support a - UMIN(a,b) patterns. +static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, + const Value *TrueVal, + const Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + if (!ICmpInst::isUnsigned(Pred)) + return nullptr; + + // (b > a) ? 0 : a - b -> (b <= a) ? a - b : 0 + if (match(TrueVal, m_Zero())) { + Pred = ICmpInst::getInversePredicate(Pred); + std::swap(TrueVal, FalseVal); + } + if (!match(FalseVal, m_Zero())) + return nullptr; + + Value *A = ICI->getOperand(0); + Value *B = ICI->getOperand(1); + if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_ULT) { + // (b < a) ? 
a - b : 0 -> (a > b) ? a - b : 0 + std::swap(A, B); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && + "Unexpected isUnsigned predicate!"); + + // Account for swapped form of subtraction: ((a > b) ? b - a : 0). + bool IsNegative = false; + if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A)))) + IsNegative = true; + else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B)))) + return nullptr; + + // If sub is used anywhere else, we wouldn't be able to eliminate it + // afterwards. + if (!TrueVal->hasOneUse()) + return nullptr; + + // All checks passed, convert to canonical unsigned saturated subtraction + // form: sub(max()). + // (a > b) ? a - b : 0 -> ((a > b) ? a : b) - b) + Value *Max = Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); + return IsNegative ? Builder.CreateSub(B, Max) : Builder.CreateSub(Max, B); +} + /// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single /// call to cttz/ctlz with flag 'is_zero_undef' cleared. /// @@ -687,23 +768,18 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, // Canonicalize the compare predicate based on whether we have min or max. Value *LHS, *RHS; - ICmpInst::Predicate NewPred; SelectPatternResult SPR = matchSelectPattern(&Sel, LHS, RHS); - switch (SPR.Flavor) { - case SPF_SMIN: NewPred = ICmpInst::ICMP_SLT; break; - case SPF_UMIN: NewPred = ICmpInst::ICMP_ULT; break; - case SPF_SMAX: NewPred = ICmpInst::ICMP_SGT; break; - case SPF_UMAX: NewPred = ICmpInst::ICMP_UGT; break; - default: return nullptr; - } + if (!SelectPatternResult::isMinOrMax(SPR.Flavor)) + return nullptr; // Is this already canonical? + ICmpInst::Predicate CanonicalPred = getMinMaxPred(SPR.Flavor); if (Cmp.getOperand(0) == LHS && Cmp.getOperand(1) == RHS && - Cmp.getPredicate() == NewPred) + Cmp.getPredicate() == CanonicalPred) return nullptr; // Create the canonical compare and plug it into the select. - Sel.setCondition(Builder.CreateICmp(NewPred, LHS, RHS)); + Sel.setCondition(Builder.CreateICmp(CanonicalPred, LHS, RHS)); // If the select operands did not change, we're done. if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS) @@ -718,6 +794,89 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, return &Sel; } +/// There are many select variants for each of ABS/NABS. +/// In matchSelectPattern(), there are different compare constants, compare +/// predicates/operands and select operands. +/// In isKnownNegation(), there are different formats of negated operands. +/// Canonicalize all these variants to 1 pattern. +/// This makes CSE more likely. +static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) + return nullptr; + + // Choose a sign-bit check for the compare (likely simpler for codegen). + // ABS: (X <s 0) ? -X : X + // NABS: (X <s 0) ? X : -X + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor; + if (SPF != SelectPatternFlavor::SPF_ABS && + SPF != SelectPatternFlavor::SPF_NABS) + return nullptr; + + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + assert(isKnownNegation(TVal, FVal) && + "Unexpected result from matchSelectPattern"); + + // The compare may use the negated abs()/nabs() operand, or it may use + // negation in non-canonical form such as: sub A, B. 
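// [Editor's note -- illustrative sketch, not part of this patch] An example of
// "negation in non-canonical form": the two select arms below are mutual
// negations spelled as (a - b) and (b - a) rather than 0 - X, which is what
// isKnownNegation() accepts and what this routine then rewrites into the
// canonical icmp-slt-zero plus sub(0, X) shape. The function name is made up.
#include <cstdint>
static int32_t absDiffNonCanonical(int32_t a, int32_t b) {
  // SPF_ABS written with sub/sub arms: select (a - b < 0), (b - a), (a - b)
  return (a - b < 0) ? (b - a) : (a - b);
}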
+ bool CmpUsesNegatedOp = match(Cmp.getOperand(0), m_Neg(m_Specific(TVal))) || + match(Cmp.getOperand(0), m_Neg(m_Specific(FVal))); + + bool CmpCanonicalized = !CmpUsesNegatedOp && + match(Cmp.getOperand(1), m_ZeroInt()) && + Cmp.getPredicate() == ICmpInst::ICMP_SLT; + bool RHSCanonicalized = match(RHS, m_Neg(m_Specific(LHS))); + + // Is this already canonical? + if (CmpCanonicalized && RHSCanonicalized) + return nullptr; + + // If RHS is used by other instructions except compare and select, don't + // canonicalize it to not increase the instruction count. + if (!(RHS->hasOneUse() || (RHS->hasNUses(2) && CmpUsesNegatedOp))) + return nullptr; + + // Create the canonical compare: icmp slt LHS 0. + if (!CmpCanonicalized) { + Cmp.setPredicate(ICmpInst::ICMP_SLT); + Cmp.setOperand(1, ConstantInt::getNullValue(Cmp.getOperand(0)->getType())); + if (CmpUsesNegatedOp) + Cmp.setOperand(0, LHS); + } + + // Create the canonical RHS: RHS = sub (0, LHS). + if (!RHSCanonicalized) { + assert(RHS->hasOneUse() && "RHS use number is not right"); + RHS = Builder.CreateNeg(LHS); + if (TVal == LHS) { + Sel.setFalseValue(RHS); + FVal = RHS; + } else { + Sel.setTrueValue(RHS); + TVal = RHS; + } + } + + // If the select operands do not change, we're done. + if (SPF == SelectPatternFlavor::SPF_NABS) { + if (TVal == LHS) + return &Sel; + assert(FVal == LHS && "Unexpected results from matchSelectPattern"); + } else { + if (FVal == LHS) + return &Sel; + assert(TVal == LHS && "Unexpected results from matchSelectPattern"); + } + + // We are swapping the select operands, so swap the metadata too. + Sel.setTrueValue(FVal); + Sel.setFalseValue(TVal); + Sel.swapProfMetadata(); + return &Sel; +} + /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { @@ -727,59 +886,18 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; + if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, Builder)) + return NewAbs; + bool Changed = adjustMinMax(SI, *ICI); + if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) + return replaceInstUsesWith(SI, V); + + // NOTE: if we wanted to, this is where to detect integer MIN/MAX ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - - // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 - // and (X <s 0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1 - // FIXME: Type and constness constraints could be lifted, but we have to - // watch code size carefully. We should consider xor instead of - // sub/add when we decide to do that. - // TODO: Merge this with foldSelectICmpAnd somehow. - if (CmpLHS->getType()->isIntOrIntVectorTy() && - CmpLHS->getType() == TrueVal->getType()) { - const APInt *C1, *C2; - if (match(TrueVal, m_APInt(C1)) && match(FalseVal, m_APInt(C2))) { - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *X; - APInt Mask; - if (decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask, false)) { - if (Mask.isSignMask()) { - assert(X == CmpLHS && "Expected to use the compare input directly"); - assert(ICmpInst::isEquality(Pred) && "Expected equality predicate"); - - if (Pred == ICmpInst::ICMP_NE) - std::swap(C1, C2); - - // This shift results in either -1 or 0. - Value *AShr = Builder.CreateAShr(X, Mask.getBitWidth() - 1); - - // Check if we can express the operation with a single or. 
- if (C2->isAllOnesValue()) - return replaceInstUsesWith(SI, Builder.CreateOr(AShr, *C1)); - - Value *And = Builder.CreateAnd(AShr, *C2 - *C1); - return replaceInstUsesWith(SI, Builder.CreateAdd(And, - ConstantInt::get(And->getType(), *C1))); - } - } - } - } - - { - const APInt *TrueValC, *FalseValC; - if (match(TrueVal, m_APInt(TrueValC)) && - match(FalseVal, m_APInt(FalseValC))) - if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, - *FalseValC, Builder)) - return replaceInstUsesWith(SI, V); - } - - // NOTE: if we wanted to, this is where to detect integer MIN/MAX - if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) { // Transform (X == C) ? X : Y -> (X == C) ? C : Y @@ -842,16 +960,22 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, } } + if (Instruction *V = + foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder)) + return V; + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); + if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + return Changed ? &SI : nullptr; } - /// SI is a select whose condition is a PHI node (but the two may be in /// different blocks). See if the true/false values (V) are live in all of the /// predecessor blocks of the PHI. For example, cases like this can't be mapped: @@ -900,7 +1024,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (C == A || C == B) { // MAX(MAX(A, B), B) -> MAX(A, B) // MIN(MIN(a, b), a) -> MIN(a, b) - if (SPF1 == SPF2) + if (SPF1 == SPF2 && SelectPatternResult::isMinOrMax(SPF1)) return replaceInstUsesWith(Outer, Inner); // MAX(MIN(a, b), a) -> a @@ -992,10 +1116,10 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (!NotC) NotC = Builder.CreateNot(C); - Value *NewInner = generateMinMaxSelectPattern( - Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB); - Value *NewOuter = Builder.CreateNot(generateMinMaxSelectPattern( - Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC)); + Value *NewInner = createMinMax(Builder, getInverseMinMaxFlavor(SPF1), NotA, + NotB); + Value *NewOuter = Builder.CreateNot( + createMinMax(Builder, getInverseMinMaxFlavor(SPF2), NewInner, NotC)); return replaceInstUsesWith(Outer, NewOuter); } @@ -1075,6 +1199,11 @@ static Instruction *foldAddSubSelect(SelectInst &SI, } Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { + Constant *C; + if (!match(Sel.getTrueValue(), m_Constant(C)) && + !match(Sel.getFalseValue(), m_Constant(C))) + return nullptr; + Instruction *ExtInst; if (!match(Sel.getTrueValue(), m_Instruction(ExtInst)) && !match(Sel.getFalseValue(), m_Instruction(ExtInst))) @@ -1084,20 +1213,18 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) return nullptr; - // TODO: Handle larger types? That requires adjusting FoldOpIntoSelect too. + // If we are extending from a boolean type or if we can create a select that + // has the same size operands as its condition, try to narrow the select. 
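// [Editor's note -- illustrative sketch, not part of this patch] The narrowing
// below is only sound when the constant survives the round trip through the
// smaller type (see the trunc/ext check further down). A C++ analogue, with a
// made-up name, for select(c, zext i8 %x to i32, i32 42):
#include <cassert>
#include <cstdint>
static void narrowSelectOfZext(bool c, uint8_t x) {
  uint32_t wide = c ? static_cast<uint32_t>(x) : 42u;
  uint32_t narrowed = static_cast<uint32_t>(c ? x : static_cast<uint8_t>(42));
  assert(wide == narrowed && "42 == zext(trunc(42)), so the fold is exact");
}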
Value *X = ExtInst->getOperand(0); Type *SmallType = X->getType(); - if (!SmallType->isIntOrIntVectorTy(1)) - return nullptr; - - Constant *C; - if (!match(Sel.getTrueValue(), m_Constant(C)) && - !match(Sel.getFalseValue(), m_Constant(C))) + Value *Cond = Sel.getCondition(); + auto *Cmp = dyn_cast<CmpInst>(Cond); + if (!SmallType->isIntOrIntVectorTy(1) && + (!Cmp || Cmp->getOperand(0)->getType() != SmallType)) return nullptr; // If the constant is the same after truncation to the smaller type and // extension to the original type, we can narrow the select. - Value *Cond = Sel.getCondition(); Type *SelType = Sel.getType(); Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); @@ -1289,6 +1416,63 @@ static Instruction *foldSelectCmpXchg(SelectInst &SI) { return nullptr; } +/// Reduce a sequence of min/max with a common operand. +static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, + Value *RHS, + InstCombiner::BuilderTy &Builder) { + assert(SelectPatternResult::isMinOrMax(SPF) && "Expected a min/max"); + // TODO: Allow FP min/max with nnan/nsz. + if (!LHS->getType()->isIntOrIntVectorTy()) + return nullptr; + + // Match 3 of the same min/max ops. Example: umin(umin(), umin()). + Value *A, *B, *C, *D; + SelectPatternResult L = matchSelectPattern(LHS, A, B); + SelectPatternResult R = matchSelectPattern(RHS, C, D); + if (SPF != L.Flavor || L.Flavor != R.Flavor) + return nullptr; + + // Look for a common operand. The use checks are different than usual because + // a min/max pattern typically has 2 uses of each op: 1 by the cmp and 1 by + // the select. + Value *MinMaxOp = nullptr; + Value *ThirdOp = nullptr; + if (!LHS->hasNUsesOrMore(3) && RHS->hasNUsesOrMore(3)) { + // If the LHS is only used in this chain and the RHS is used outside of it, + // reuse the RHS min/max because that will eliminate the LHS. + if (D == A || C == A) { + // min(min(a, b), min(c, a)) --> min(min(c, a), b) + // min(min(a, b), min(a, d)) --> min(min(a, d), b) + MinMaxOp = RHS; + ThirdOp = B; + } else if (D == B || C == B) { + // min(min(a, b), min(c, b)) --> min(min(c, b), a) + // min(min(a, b), min(b, d)) --> min(min(b, d), a) + MinMaxOp = RHS; + ThirdOp = A; + } + } else if (!RHS->hasNUsesOrMore(3)) { + // Reuse the LHS. This will eliminate the RHS. + if (D == A || D == B) { + // min(min(a, b), min(c, a)) --> min(min(a, b), c) + // min(min(a, b), min(c, b)) --> min(min(a, b), c) + MinMaxOp = LHS; + ThirdOp = C; + } else if (C == A || C == B) { + // min(min(a, b), min(b, d)) --> min(min(a, b), d) + // min(min(a, b), min(c, b)) --> min(min(a, b), d) + MinMaxOp = LHS; + ThirdOp = D; + } + } + if (!MinMaxOp || !ThirdOp) + return nullptr; + + CmpInst::Predicate P = getMinMaxPred(SPF); + Value *CmpABC = Builder.CreateICmp(P, MinMaxOp, ThirdOp); + return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -1489,7 +1673,37 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // NOTE: if we wanted to, this is where to detect MIN/MAX } - // NOTE: if we wanted to, this is where to detect ABS + + // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need + // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. We + // also require nnan because we do not want to unintentionally change the + // sign of a NaN value. 
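// [Editor's note -- illustrative sketch, not part of this patch] Why the flags
// matter for the fabs() canonicalization below, shown in plain C++ (IEEE-754,
// no fast-math): with x == -0.0 the fsub(+0.0, x) form already matches fabs,
// but the fneg-style form keeps the negative zero, which is why that variant
// needs nsz; and for a NaN both compares are false, so the select would return
// the NaN with its original sign bit, which fabs would clear (hence nnan).
#include <cassert>
#include <cmath>
static void signedZeroFabsCheck() {
  double x = -0.0;
  double viaSub = (x <= 0.0) ? (0.0 - x) : x; // +0.0, same as std::fabs(x)
  double viaNeg = (x < 0.0) ? -x : x;         // stays -0.0
  assert(!std::signbit(viaSub) && !std::signbit(std::fabs(x)));
  assert(std::signbit(viaNeg) && "not fabs() unless the sign of zero is ignored");
}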
+ Value *X = FCI->getOperand(0); + FCmpInst::Predicate Pred = FCI->getPredicate(); + if (match(FCI->getOperand(1), m_AnyZeroFP()) && FCI->hasNoNaNs()) { + // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X) + // (X > +/-0.0) ? X : (0.0 - X) --> fabs(X) + if ((X == FalseVal && Pred == FCmpInst::FCMP_OLE && + match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) || + (X == TrueVal && Pred == FCmpInst::FCMP_OGT && + match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(X))))) { + Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + return replaceInstUsesWith(SI, Fabs); + } + // With nsz: + // (X < +/-0.0) ? -X : X --> fabs(X) + // (X <= +/-0.0) ? -X : X --> fabs(X) + // (X > +/-0.0) ? X : -X --> fabs(X) + // (X >= +/-0.0) ? X : -X --> fabs(X) + if (FCI->hasNoSignedZeros() && + ((X == FalseVal && match(TrueVal, m_FNeg(m_Specific(X))) && + (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE)) || + (X == TrueVal && match(FalseVal, m_FNeg(m_Specific(X))) && + (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE)))) { + Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + return replaceInstUsesWith(SI, Fabs); + } + } } // See if we are selecting two values based on a comparison of the two values. @@ -1532,7 +1746,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { (LHS->getType()->isFPOrFPVectorTy() && ((CmpLHS != LHS && CmpLHS != RHS) || (CmpRHS != LHS && CmpRHS != RHS)))) { - CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); + CmpInst::Predicate Pred = getMinMaxPred(SPF, SPR.Ordered); Value *Cmp; if (CmpInst::isIntPredicate(Pred)) { @@ -1551,6 +1765,20 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); return replaceInstUsesWith(SI, NewCast); } + + // MAX(~a, ~b) -> ~MIN(a, b) + // MIN(~a, ~b) -> ~MAX(a, b) + Value *A, *B; + if (match(LHS, m_Not(m_Value(A))) && match(RHS, m_Not(m_Value(B))) && + (LHS->getNumUses() <= 2 || RHS->getNumUses() <= 2)) { + CmpInst::Predicate InvertedPred = getInverseMinMaxPred(SPF); + Value *InvertedCmp = Builder.CreateICmp(InvertedPred, A, B); + Value *NewSel = Builder.CreateSelect(InvertedCmp, A, B); + return BinaryOperator::CreateNot(NewSel); + } + + if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) + return I; } if (SPF) { @@ -1570,28 +1798,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return R; } - // MAX(~a, ~b) -> ~MIN(a, b) - if ((SPF == SPF_SMAX || SPF == SPF_UMAX) && - IsFreeToInvert(LHS, LHS->hasNUses(2)) && - IsFreeToInvert(RHS, RHS->hasNUses(2))) { - // For this transform to be profitable, we need to eliminate at least two - // 'not' instructions if we're going to add one 'not' instruction. - int NumberOfNots = - (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) + - (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) + - (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); - - if (NumberOfNots >= 2) { - Value *NewLHS = Builder.CreateNot(LHS); - Value *NewRHS = Builder.CreateNot(RHS); - Value *NewCmp = SPF == SPF_SMAX ? Builder.CreateICmpSLT(NewLHS, NewRHS) - : Builder.CreateICmpULT(NewLHS, NewRHS); - Value *NewSI = - Builder.CreateNot(Builder.CreateSelect(NewCmp, NewLHS, NewRHS)); - return replaceInstUsesWith(SI, NewSI); - } - } - // TODO. // ABS(-X) -> ABS(X) } @@ -1643,11 +1849,25 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + auto canMergeSelectThroughBinop = [](BinaryOperator *BO) { + // The select might be preventing a division by 0. 
+ switch (BO->getOpcode()) { + default: + return true; + case Instruction::SRem: + case Instruction::URem: + case Instruction::SDiv: + case Instruction::UDiv: + return false; + } + }; + // Try to simplify a binop sandwiched between 2 selects with the same // condition. // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) BinaryOperator *TrueBO; - if (match(TrueVal, m_OneUse(m_BinOp(TrueBO)))) { + if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && + canMergeSelectThroughBinop(TrueBO)) { if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { if (TrueBOSI->getCondition() == CondVal) { TrueBO->setOperand(0, TrueBOSI->getTrueValue()); @@ -1666,7 +1886,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) BinaryOperator *FalseBO; - if (match(FalseVal, m_OneUse(m_BinOp(FalseBO)))) { + if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && + canMergeSelectThroughBinop(FalseBO)) { if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { if (FalseBOSI->getCondition() == CondVal) { FalseBO->setOperand(0, FalseBOSI->getFalseValue()); diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 44bbb84686ab..34f8037e519f 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -87,8 +87,7 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, // Equal shift amounts in opposite directions become bitwise 'and': // lshr (shl X, C), C --> and X, C' // shl (lshr X, C), C --> and X, C' - unsigned InnerShAmt = InnerShiftConst->getZExtValue(); - if (InnerShAmt == OuterShAmt) + if (*InnerShiftConst == OuterShAmt) return true; // If the 2nd shift is bigger than the 1st, we can fold: @@ -98,7 +97,8 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, // Also, check that the inner shift is valid (less than the type width) or // we'll crash trying to produce the bit mask for the 'and'. unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits(); - if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) { + if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) { + unsigned InnerShAmt = InnerShiftConst->getZExtValue(); unsigned MaskShift = IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt; APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift; @@ -135,7 +135,7 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, ConstantInt *CI = nullptr; if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { - if (CI->getZExtValue() == NumBits) { + if (CI->getValue() == NumBits) { // TODO: Check that the input bits are already zero with MaskedValueIsZero #if 0 // If this is a truncate of a logical shr, we can truncate it to a smaller @@ -356,8 +356,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // cast of lshr(shl(x,c1),c2) as well as other more complex cases. 
if (I.getOpcode() != Instruction::AShr && canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { - DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" - " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); + LLVM_DEBUG( + dbgs() << "ICE: GetShiftedValue propagating shift through expression" + " to eliminate shift:\n IN: " + << *Op0 << "\n SH: " << I << "\n"); return replaceInstUsesWith( I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); @@ -370,7 +372,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); - if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I)) + if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) return FoldedShift; // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) @@ -586,23 +588,23 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } Instruction *InstCombiner::visitShl(BinaryOperator &I) { - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = - SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *V = commonShiftTransforms(I)) return V; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); const APInt *ShAmtAPInt; if (match(Op1, m_APInt(ShAmtAPInt))) { unsigned ShAmt = ShAmtAPInt->getZExtValue(); - unsigned BitWidth = I.getType()->getScalarSizeInBits(); - Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); // shl (zext X), ShAmt --> zext (shl X, ShAmt) // This is only valid if X would have zeros shifted out. @@ -620,11 +622,8 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } - // Be careful about hiding shl instructions behind bit masks. They are used - // to represent multiplies by a constant, and it is important that simple - // arithmetic expressions are still recognizable by scalar evolution. - // The inexact versions are deferred to DAGCombine, so we don't hide shl - // behind a bit mask. + // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine) + // needs a few fixes for the rotate pattern recognition first. const APInt *ShOp1; if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) { unsigned ShrAmt = ShOp1->getZExtValue(); @@ -668,6 +667,15 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { } } + // Transform (x >> y) << y to x & (-1 << y) + // Valid for any type of right-shift. 
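// [Editor's note -- illustrative sketch, not part of this patch] The identity
// behind the transform below: shifting right and then left by the same amount
// just clears the low y bits, so the pair can be replaced with a single mask.
// In IR this holds for lshr and ashr alike (the shl discards whatever was
// shifted in at the top); the C++ check below uses the unsigned case.
#include <cassert>
#include <cstdint>
static void shrThenShlIsMask(uint32_t x, unsigned y) {
  assert(y < 32 && "shift amount must be smaller than the bit width");
  const uint32_t mask = UINT32_MAX << y;      // the "-1 << y" constant
  assert(((x >> y) << y) == (x & mask));
}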
+ Value *X; + if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + Value *Mask = Builder.CreateShl(AllOnes, Op1); + return BinaryOperator::CreateAnd(Mask, X); + } + Constant *C1; if (match(Op1, m_Constant(C1))) { Constant *C2; @@ -685,17 +693,17 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { } Instruction *InstCombiner::visitLShr(BinaryOperator &I) { - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = - SimplifyLShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *R = commonShiftTransforms(I)) return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); const APInt *ShAmtAPInt; if (match(Op1, m_APInt(ShAmtAPInt))) { @@ -800,25 +808,34 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return &I; } } + + // Transform (x << y) >> y to x & (-1 >> y) + Value *X; + if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + Value *Mask = Builder.CreateLShr(AllOnes, Op1); + return BinaryOperator::CreateAnd(Mask, X); + } + return nullptr; } Instruction *InstCombiner::visitAShr(BinaryOperator &I) { - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = - SimplifyAShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *R = commonShiftTransforms(I)) return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); const APInt *ShAmtAPInt; - if (match(Op1, m_APInt(ShAmtAPInt))) { + if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) { unsigned ShAmt = ShAmtAPInt->getZExtValue(); // If the shift amount equals the difference in width of the destination @@ -832,7 +849,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However, // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. const APInt *ShOp1; - if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) { + if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) && + ShOp1->ult(BitWidth)) { unsigned ShlAmt = ShOp1->getZExtValue(); if (ShlAmt < ShAmt) { // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1) @@ -850,7 +868,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { } } - if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) { + if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) && + ShOp1->ult(BitWidth)) { unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); // Oversized arithmetic shifts replicate the sign bit. 
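// [Editor's note] Example of the clamp on the next line, for i8: (X ashr 5)
// ashr 4 is equivalent to X ashr 7, because once only copies of the sign bit
// remain, shifting further cannot change the value; so the summed shift
// amount is capped at BitWidth - 1 instead of producing an oversized shift.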
AmtSum = std::min(AmtSum, BitWidth - 1); diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index a2e757cb4273..425f5ce384be 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -23,6 +23,17 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" +namespace { + +struct AMDGPUImageDMaskIntrinsic { + unsigned Intr; +}; + +#define GET_AMDGPUImageDMaskIntrinsicTable_IMPL +#include "InstCombineTables.inc" + +} // end anonymous namespace + /// Check to see if the specified operand of the specified instruction is a /// constant integer. If so, check to see if there are any bits set in the /// constant that are not demanded. If so, shrink the constant and return true. @@ -333,7 +344,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits InputKnown(SrcBitWidth); if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) return I; - Known = Known.zextOrTrunc(BitWidth); + Known = InputKnown.zextOrTrunc(BitWidth); // Any top bits are known to be zero. if (BitWidth > SrcBitWidth) Known.Zero.setBitsFrom(SrcBitWidth); @@ -545,6 +556,27 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } break; } + case Instruction::UDiv: { + // UDiv doesn't demand low bits that are zero in the divisor. + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { + // If the shift is exact, then it does demand the low bits. + if (cast<UDivOperator>(I)->isExact()) + break; + + // FIXME: Take the demanded mask of the result into account. + unsigned RHSTrailingZeros = SA->countTrailingZeros(); + APInt DemandedMaskIn = + APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) + return I; + + // Propagate zero bits from the input. + Known.Zero.setHighBits(std::min( + BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros)); + } + break; + } case Instruction::SRem: if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { // X % -1 demands all the bits because we don't want to introduce @@ -888,6 +920,110 @@ InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1, return nullptr; } +/// Implement SimplifyDemandedVectorElts for amdgcn buffer and image intrinsics. +Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, + APInt DemandedElts, + int DMaskIdx) { + unsigned VWidth = II->getType()->getVectorNumElements(); + if (VWidth == 1) + return nullptr; + + ConstantInt *NewDMask = nullptr; + + if (DMaskIdx < 0) { + // Pretend that a prefix of elements is demanded to simplify the code + // below. 
+ DemandedElts = (1 << DemandedElts.getActiveBits()) - 1; + } else { + ConstantInt *DMask = dyn_cast<ConstantInt>(II->getArgOperand(DMaskIdx)); + if (!DMask) + return nullptr; // non-constant dmask is not supported by codegen + + unsigned DMaskVal = DMask->getZExtValue() & 0xf; + + // Mask off values that are undefined because the dmask doesn't cover them + DemandedElts &= (1 << countPopulation(DMaskVal)) - 1; + + unsigned NewDMaskVal = 0; + unsigned OrigLoadIdx = 0; + for (unsigned SrcIdx = 0; SrcIdx < 4; ++SrcIdx) { + const unsigned Bit = 1 << SrcIdx; + if (!!(DMaskVal & Bit)) { + if (!!DemandedElts[OrigLoadIdx]) + NewDMaskVal |= Bit; + OrigLoadIdx++; + } + } + + if (DMaskVal != NewDMaskVal) + NewDMask = ConstantInt::get(DMask->getType(), NewDMaskVal); + } + + // TODO: Handle 3 vectors when supported in code gen. + unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countPopulation()); + if (!NewNumElts) + return UndefValue::get(II->getType()); + + if (NewNumElts >= VWidth && DemandedElts.isMask()) { + if (NewDMask) + II->setArgOperand(DMaskIdx, NewDMask); + return nullptr; + } + + // Determine the overload types of the original intrinsic. + auto IID = II->getIntrinsicID(); + SmallVector<Intrinsic::IITDescriptor, 16> Table; + getIntrinsicInfoTableEntries(IID, Table); + ArrayRef<Intrinsic::IITDescriptor> TableRef = Table; + + FunctionType *FTy = II->getCalledFunction()->getFunctionType(); + SmallVector<Type *, 6> OverloadTys; + Intrinsic::matchIntrinsicType(FTy->getReturnType(), TableRef, OverloadTys); + for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) + Intrinsic::matchIntrinsicType(FTy->getParamType(i), TableRef, OverloadTys); + + // Get the new return type overload of the intrinsic. + Module *M = II->getParent()->getParent()->getParent(); + Type *EltTy = II->getType()->getVectorElementType(); + Type *NewTy = (NewNumElts == 1) ? EltTy : VectorType::get(EltTy, NewNumElts); + + OverloadTys[0] = NewTy; + Function *NewIntrin = Intrinsic::getDeclaration(M, IID, OverloadTys); + + SmallVector<Value *, 16> Args; + for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I) + Args.push_back(II->getArgOperand(I)); + + if (NewDMask) + Args[DMaskIdx] = NewDMask; + + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(II); + + CallInst *NewCall = Builder.CreateCall(NewIntrin, Args); + NewCall->takeName(II); + NewCall->copyMetadata(*II); + + if (NewNumElts == 1) { + return Builder.CreateInsertElement(UndefValue::get(II->getType()), NewCall, + DemandedElts.countTrailingZeros()); + } + + SmallVector<uint32_t, 8> EltMask; + unsigned NewLoadIdx = 0; + for (unsigned OrigLoadIdx = 0; OrigLoadIdx < VWidth; ++OrigLoadIdx) { + if (!!DemandedElts[OrigLoadIdx]) + EltMask.push_back(NewLoadIdx++); + else + EltMask.push_back(NewNumElts); + } + + Value *Shuffle = + Builder.CreateShuffleVector(NewCall, UndefValue::get(NewTy), EltMask); + + return Shuffle; +} + /// The specified value produces a vector with any number of elements. /// DemandedElts contains the set of elements that are actually used by the /// caller. This method analyzes which elements of the operand are undef and @@ -1187,7 +1323,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; } - // div/rem demand all inputs, because they don't want divide by zero. 
TmpV = SimplifyDemandedVectorElts(I->getOperand(0), InputDemandedElts, UndefElts2, Depth + 1); if (TmpV) { @@ -1247,8 +1382,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); if (!II) break; switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::x86_xop_vfrcz_ss: case Intrinsic::x86_xop_vfrcz_sd: // The instructions for these intrinsics are speced to zero upper bits not @@ -1273,8 +1406,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // Unary scalar-as-vector operations that work column-wise. case Intrinsic::x86_sse_rcp_ss: case Intrinsic::x86_sse_rsqrt_ss: - case Intrinsic::x86_sse_sqrt_ss: - case Intrinsic::x86_sse2_sqrt_sd: TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } @@ -1366,18 +1497,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, case Intrinsic::x86_avx512_mask_sub_sd_round: case Intrinsic::x86_avx512_mask_max_sd_round: case Intrinsic::x86_avx512_mask_min_sd_round: - case Intrinsic::x86_fma_vfmadd_ss: - case Intrinsic::x86_fma_vfmsub_ss: - case Intrinsic::x86_fma_vfnmadd_ss: - case Intrinsic::x86_fma_vfnmsub_ss: - case Intrinsic::x86_fma_vfmadd_sd: - case Intrinsic::x86_fma_vfmsub_sd: - case Intrinsic::x86_fma_vfnmadd_sd: - case Intrinsic::x86_fma_vfnmsub_sd: - case Intrinsic::x86_avx512_mask_vfmadd_ss: - case Intrinsic::x86_avx512_mask_vfmadd_sd: - case Intrinsic::x86_avx512_maskz_vfmadd_ss: - case Intrinsic::x86_avx512_maskz_vfmadd_sd: TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } @@ -1404,68 +1523,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; - case Intrinsic::x86_avx512_mask3_vfmadd_ss: - case Intrinsic::x86_avx512_mask3_vfmadd_sd: - case Intrinsic::x86_avx512_mask3_vfmsub_ss: - case Intrinsic::x86_avx512_mask3_vfmsub_sd: - case Intrinsic::x86_avx512_mask3_vfnmsub_ss: - case Intrinsic::x86_avx512_mask3_vfnmsub_sd: - // These intrinsics get the passthru bits from operand 2. - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(2), DemandedElts, - UndefElts, Depth + 1); - if (TmpV) { II->setArgOperand(2, TmpV); MadeChange = true; } - - // If lowest element of a scalar op isn't used then use Arg2. - if (!DemandedElts[0]) { - Worklist.Add(II); - return II->getArgOperand(2); - } - - // Only lower element is used for operand 0 and 1. - DemandedElts = 1; - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, - UndefElts2, Depth + 1); - if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, - UndefElts3, Depth + 1); - if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } - - // Lower element is undefined if all three lower elements are undefined. - // Consider things like undef&0. The result is known zero, not undef. 
- if (!UndefElts2[0] || !UndefElts3[0]) - UndefElts.clearBit(0); - - break; - - case Intrinsic::x86_sse2_pmulu_dq: - case Intrinsic::x86_sse41_pmuldq: - case Intrinsic::x86_avx2_pmul_dq: - case Intrinsic::x86_avx2_pmulu_dq: - case Intrinsic::x86_avx512_pmul_dq_512: - case Intrinsic::x86_avx512_pmulu_dq_512: { - Value *Op0 = II->getArgOperand(0); - Value *Op1 = II->getArgOperand(1); - unsigned InnerVWidth = Op0->getType()->getVectorNumElements(); - assert((VWidth * 2) == InnerVWidth && "Unexpected input size"); - - APInt InnerDemandedElts(InnerVWidth, 0); - for (unsigned i = 0; i != VWidth; ++i) - if (DemandedElts[i]) - InnerDemandedElts.setBit(i * 2); - - UndefElts2 = APInt(InnerVWidth, 0); - TmpV = SimplifyDemandedVectorElts(Op0, InnerDemandedElts, UndefElts2, - Depth + 1); - if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } - - UndefElts3 = APInt(InnerVWidth, 0); - TmpV = SimplifyDemandedVectorElts(Op1, InnerDemandedElts, UndefElts3, - Depth + 1); - if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } - - break; - } - case Intrinsic::x86_sse2_packssdw_128: case Intrinsic::x86_sse2_packsswb_128: case Intrinsic::x86_sse2_packuswb_128: @@ -1554,124 +1611,12 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; case Intrinsic::amdgcn_buffer_load: case Intrinsic::amdgcn_buffer_load_format: - case Intrinsic::amdgcn_image_sample: - case Intrinsic::amdgcn_image_sample_cl: - case Intrinsic::amdgcn_image_sample_d: - case Intrinsic::amdgcn_image_sample_d_cl: - case Intrinsic::amdgcn_image_sample_l: - case Intrinsic::amdgcn_image_sample_b: - case Intrinsic::amdgcn_image_sample_b_cl: - case Intrinsic::amdgcn_image_sample_lz: - case Intrinsic::amdgcn_image_sample_cd: - case Intrinsic::amdgcn_image_sample_cd_cl: - - case Intrinsic::amdgcn_image_sample_c: - case Intrinsic::amdgcn_image_sample_c_cl: - case Intrinsic::amdgcn_image_sample_c_d: - case Intrinsic::amdgcn_image_sample_c_d_cl: - case Intrinsic::amdgcn_image_sample_c_l: - case Intrinsic::amdgcn_image_sample_c_b: - case Intrinsic::amdgcn_image_sample_c_b_cl: - case Intrinsic::amdgcn_image_sample_c_lz: - case Intrinsic::amdgcn_image_sample_c_cd: - case Intrinsic::amdgcn_image_sample_c_cd_cl: - - case Intrinsic::amdgcn_image_sample_o: - case Intrinsic::amdgcn_image_sample_cl_o: - case Intrinsic::amdgcn_image_sample_d_o: - case Intrinsic::amdgcn_image_sample_d_cl_o: - case Intrinsic::amdgcn_image_sample_l_o: - case Intrinsic::amdgcn_image_sample_b_o: - case Intrinsic::amdgcn_image_sample_b_cl_o: - case Intrinsic::amdgcn_image_sample_lz_o: - case Intrinsic::amdgcn_image_sample_cd_o: - case Intrinsic::amdgcn_image_sample_cd_cl_o: - - case Intrinsic::amdgcn_image_sample_c_o: - case Intrinsic::amdgcn_image_sample_c_cl_o: - case Intrinsic::amdgcn_image_sample_c_d_o: - case Intrinsic::amdgcn_image_sample_c_d_cl_o: - case Intrinsic::amdgcn_image_sample_c_l_o: - case Intrinsic::amdgcn_image_sample_c_b_o: - case Intrinsic::amdgcn_image_sample_c_b_cl_o: - case Intrinsic::amdgcn_image_sample_c_lz_o: - case Intrinsic::amdgcn_image_sample_c_cd_o: - case Intrinsic::amdgcn_image_sample_c_cd_cl_o: - - case Intrinsic::amdgcn_image_getlod: { - if (VWidth == 1 || !DemandedElts.isMask()) - return nullptr; - - // TODO: Handle 3 vectors when supported in code gen. 
- unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countTrailingOnes()); - if (NewNumElts == VWidth) - return nullptr; - - Module *M = II->getParent()->getParent()->getParent(); - Type *EltTy = V->getType()->getVectorElementType(); - - Type *NewTy = (NewNumElts == 1) ? EltTy : - VectorType::get(EltTy, NewNumElts); - - auto IID = II->getIntrinsicID(); - - bool IsBuffer = IID == Intrinsic::amdgcn_buffer_load || - IID == Intrinsic::amdgcn_buffer_load_format; - - Function *NewIntrin = IsBuffer ? - Intrinsic::getDeclaration(M, IID, NewTy) : - // Samplers have 3 mangled types. - Intrinsic::getDeclaration(M, IID, - { NewTy, II->getArgOperand(0)->getType(), - II->getArgOperand(1)->getType()}); - - SmallVector<Value *, 5> Args; - for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I) - Args.push_back(II->getArgOperand(I)); - - IRBuilderBase::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(II); - - CallInst *NewCall = Builder.CreateCall(NewIntrin, Args); - NewCall->takeName(II); - NewCall->copyMetadata(*II); - - if (!IsBuffer) { - ConstantInt *DMask = dyn_cast<ConstantInt>(NewCall->getArgOperand(3)); - if (DMask) { - unsigned DMaskVal = DMask->getZExtValue() & 0xf; - - unsigned PopCnt = 0; - unsigned NewDMask = 0; - for (unsigned I = 0; I < 4; ++I) { - const unsigned Bit = 1 << I; - if (!!(DMaskVal & Bit)) { - if (++PopCnt > NewNumElts) - break; + return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts); + default: { + if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID())) + return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts, 0); - NewDMask |= Bit; - } - } - - NewCall->setArgOperand(3, ConstantInt::get(DMask->getType(), NewDMask)); - } - } - - - if (NewNumElts == 1) { - return Builder.CreateInsertElement(UndefValue::get(V->getType()), - NewCall, static_cast<uint64_t>(0)); - } - - SmallVector<uint32_t, 8> EltMask; - for (unsigned I = 0; I < VWidth; ++I) - EltMask.push_back(I); - - Value *Shuffle = Builder.CreateShuffleVector( - NewCall, UndefValue::get(NewTy), EltMask); - - MadeChange = true; - return Shuffle; + break; } } break; diff --git a/lib/Transforms/InstCombine/InstCombineTables.td b/lib/Transforms/InstCombine/InstCombineTables.td new file mode 100644 index 000000000000..98b2adc442fa --- /dev/null +++ b/lib/Transforms/InstCombine/InstCombineTables.td @@ -0,0 +1,11 @@ +include "llvm/TableGen/SearchableTable.td" +include "llvm/IR/Intrinsics.td" + +def AMDGPUImageDMaskIntrinsicTable : GenericTable { + let FilterClass = "AMDGPUImageDMaskIntrinsic"; + let Fields = ["Intr"]; + + let PrimaryKey = ["Intr"]; + let PrimaryKeyName = "getAMDGPUImageDMaskIntrinsic"; + let PrimaryKeyEarlyOut = 1; +} diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index aeac8910af6b..2560feb37d66 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -1140,6 +1140,216 @@ static bool isShuffleExtractingFromLHS(ShuffleVectorInst &SVI, return true; } +/// These are the ingredients in an alternate form binary operator as described +/// below. +struct BinopElts { + BinaryOperator::BinaryOps Opcode; + Value *Op0; + Value *Op1; + BinopElts(BinaryOperator::BinaryOps Opc = (BinaryOperator::BinaryOps)0, + Value *V0 = nullptr, Value *V1 = nullptr) : + Opcode(Opc), Op0(V0), Op1(V1) {} + operator bool() const { return Opcode != 0; } +}; + +/// Binops may be transformed into binops with different opcodes and operands. 
+/// Reverse the usual canonicalization to enable folds with the non-canonical +/// form of the binop. If a transform is possible, return the elements of the +/// new binop. If not, return invalid elements. +static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) { + Value *BO0 = BO->getOperand(0), *BO1 = BO->getOperand(1); + Type *Ty = BO->getType(); + switch (BO->getOpcode()) { + case Instruction::Shl: { + // shl X, C --> mul X, (1 << C) + Constant *C; + if (match(BO1, m_Constant(C))) { + Constant *ShlOne = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C); + return { Instruction::Mul, BO0, ShlOne }; + } + break; + } + case Instruction::Or: { + // or X, C --> add X, C (when X and C have no common bits set) + const APInt *C; + if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL)) + return { Instruction::Add, BO0, BO1 }; + break; + } + default: + break; + } + return {}; +} + +static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { + assert(Shuf.isSelect() && "Must have select-equivalent shuffle"); + + // Are we shuffling together some value and that same value after it has been + // modified by a binop with a constant? + Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); + Constant *C; + bool Op0IsBinop; + if (match(Op0, m_BinOp(m_Specific(Op1), m_Constant(C)))) + Op0IsBinop = true; + else if (match(Op1, m_BinOp(m_Specific(Op0), m_Constant(C)))) + Op0IsBinop = false; + else + return nullptr; + + // The identity constant for a binop leaves a variable operand unchanged. For + // a vector, this is a splat of something like 0, -1, or 1. + // If there's no identity constant for this binop, we're done. + auto *BO = cast<BinaryOperator>(Op0IsBinop ? Op0 : Op1); + BinaryOperator::BinaryOps BOpcode = BO->getOpcode(); + Constant *IdC = ConstantExpr::getBinOpIdentity(BOpcode, Shuf.getType(), true); + if (!IdC) + return nullptr; + + // Shuffle identity constants into the lanes that return the original value. + // Example: shuf (mul X, {-1,-2,-3,-4}), X, {0,5,6,3} --> mul X, {-1,1,1,-4} + // Example: shuf X, (add X, {-1,-2,-3,-4}), {0,1,6,7} --> add X, {0,0,-3,-4} + // The existing binop constant vector remains in the same operand position. + Constant *Mask = Shuf.getMask(); + Constant *NewC = Op0IsBinop ? ConstantExpr::getShuffleVector(C, IdC, Mask) : + ConstantExpr::getShuffleVector(IdC, C, Mask); + + bool MightCreatePoisonOrUB = + Mask->containsUndefElement() && + (Instruction::isIntDivRem(BOpcode) || Instruction::isShift(BOpcode)); + if (MightCreatePoisonOrUB) + NewC = getSafeVectorConstantForBinop(BOpcode, NewC, true); + + // shuf (bop X, C), X, M --> bop X, C' + // shuf X, (bop X, C), M --> bop X, C' + Value *X = Op0IsBinop ? Op1 : Op0; + Instruction *NewBO = BinaryOperator::Create(BOpcode, X, NewC); + NewBO->copyIRFlags(BO); + + // An undef shuffle mask element may propagate as an undef constant element in + // the new binop. That would produce poison where the original code might not. + // If we already made a safe constant, then there's no danger. + if (Mask->containsUndefElement() && !MightCreatePoisonOrUB) + NewBO->dropPoisonGeneratingFlags(); + return NewBO; +} + +/// Try to fold shuffles that are the equivalent of a vector select. 
+static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder, + const DataLayout &DL) { + if (!Shuf.isSelect()) + return nullptr; + + if (Instruction *I = foldSelectShuffleWith1Binop(Shuf)) + return I; + + BinaryOperator *B0, *B1; + if (!match(Shuf.getOperand(0), m_BinOp(B0)) || + !match(Shuf.getOperand(1), m_BinOp(B1))) + return nullptr; + + Value *X, *Y; + Constant *C0, *C1; + bool ConstantsAreOp1; + if (match(B0, m_BinOp(m_Value(X), m_Constant(C0))) && + match(B1, m_BinOp(m_Value(Y), m_Constant(C1)))) + ConstantsAreOp1 = true; + else if (match(B0, m_BinOp(m_Constant(C0), m_Value(X))) && + match(B1, m_BinOp(m_Constant(C1), m_Value(Y)))) + ConstantsAreOp1 = false; + else + return nullptr; + + // We need matching binops to fold the lanes together. + BinaryOperator::BinaryOps Opc0 = B0->getOpcode(); + BinaryOperator::BinaryOps Opc1 = B1->getOpcode(); + bool DropNSW = false; + if (ConstantsAreOp1 && Opc0 != Opc1) { + // TODO: We drop "nsw" if shift is converted into multiply because it may + // not be correct when the shift amount is BitWidth - 1. We could examine + // each vector element to determine if it is safe to keep that flag. + if (Opc0 == Instruction::Shl || Opc1 == Instruction::Shl) + DropNSW = true; + if (BinopElts AltB0 = getAlternateBinop(B0, DL)) { + assert(isa<Constant>(AltB0.Op1) && "Expecting constant with alt binop"); + Opc0 = AltB0.Opcode; + C0 = cast<Constant>(AltB0.Op1); + } else if (BinopElts AltB1 = getAlternateBinop(B1, DL)) { + assert(isa<Constant>(AltB1.Op1) && "Expecting constant with alt binop"); + Opc1 = AltB1.Opcode; + C1 = cast<Constant>(AltB1.Op1); + } + } + + if (Opc0 != Opc1) + return nullptr; + + // The opcodes must be the same. Use a new name to make that clear. + BinaryOperator::BinaryOps BOpc = Opc0; + + // Select the constant elements needed for the single binop. + Constant *Mask = Shuf.getMask(); + Constant *NewC = ConstantExpr::getShuffleVector(C0, C1, Mask); + + // We are moving a binop after a shuffle. When a shuffle has an undefined + // mask element, the result is undefined, but it is not poison or undefined + // behavior. That is not necessarily true for div/rem/shift. + bool MightCreatePoisonOrUB = + Mask->containsUndefElement() && + (Instruction::isIntDivRem(BOpc) || Instruction::isShift(BOpc)); + if (MightCreatePoisonOrUB) + NewC = getSafeVectorConstantForBinop(BOpc, NewC, ConstantsAreOp1); + + Value *V; + if (X == Y) { + // Remove a binop and the shuffle by rearranging the constant: + // shuffle (op V, C0), (op V, C1), M --> op V, C' + // shuffle (op C0, V), (op C1, V), M --> op C', V + V = X; + } else { + // If there are 2 different variable operands, we must create a new shuffle + // (select) first, so check uses to ensure that we don't end up with more + // instructions than we started with. + if (!B0->hasOneUse() && !B1->hasOneUse()) + return nullptr; + + // If we use the original shuffle mask and op1 is *variable*, we would be + // putting an undef into operand 1 of div/rem/shift. This is either UB or + // poison. We do not have to guard against UB when *constants* are op1 + // because safe constants guarantee that we do not overflow sdiv/srem (and + // there's no danger for other opcodes). + // TODO: To allow this case, create a new shuffle mask with no undefs. + if (MightCreatePoisonOrUB && !ConstantsAreOp1) + return nullptr; + + // Note: In general, we do not create new shuffles in InstCombine because we + // do not know if a target can lower an arbitrary shuffle optimally. 
In this + // case, the shuffle uses the existing mask, so there is no additional risk. + + // Select the variable vectors first, then perform the binop: + // shuffle (op X, C0), (op Y, C1), M --> op (shuffle X, Y, M), C' + // shuffle (op C0, X), (op C1, Y), M --> op C', (shuffle X, Y, M) + V = Builder.CreateShuffleVector(X, Y, Mask); + } + + Instruction *NewBO = ConstantsAreOp1 ? BinaryOperator::Create(BOpc, V, NewC) : + BinaryOperator::Create(BOpc, NewC, V); + + // Flags are intersected from the 2 source binops. But there are 2 exceptions: + // 1. If we changed an opcode, poison conditions might have changed. + // 2. If the shuffle had undef mask elements, the new binop might have undefs + // where the original code did not. But if we already made a safe constant, + // then there's no danger. + NewBO->copyIRFlags(B0); + NewBO->andIRFlags(B1); + if (DropNSW) + NewBO->setHasNoSignedWrap(false); + if (Mask->containsUndefElement() && !MightCreatePoisonOrUB) + NewBO->dropPoisonGeneratingFlags(); + return NewBO; +} + Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); @@ -1150,6 +1360,9 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI))) return replaceInstUsesWith(SVI, V); + if (Instruction *I = foldSelectShuffle(SVI, Builder, DL)) + return I; + bool MadeChange = false; unsigned VWidth = SVI.getType()->getVectorNumElements(); diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index b332e75c7feb..12fcc8752ea9 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -34,6 +34,8 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm-c/Initialization.h" +#include "llvm-c/Transforms/InstCombine.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -55,6 +57,7 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -72,6 +75,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" @@ -93,8 +97,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -144,12 +146,20 @@ Value *InstCombiner::EmitGEPOffset(User *GEP) { /// We don't want to convert from a legal to an illegal type or from a smaller /// to a larger illegal type. A width of '1' is always treated as a legal type /// because i1 is a fundamental type in IR, and there are many specialized -/// optimizations for i1 types. +/// optimizations for i1 types. Widths of 8, 16 or 32 are equally treated as +/// legal to convert to, in order to open up more combining opportunities. 
+/// NOTE: this treats i8, i16 and i32 specially, due to them being so common +/// from frontend languages. bool InstCombiner::shouldChangeType(unsigned FromWidth, unsigned ToWidth) const { bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth); + // Convert to widths of 8, 16 or 32 even if they are not legal types. Only + // shrink types, to prevent infinite loops. + if (ToWidth < FromWidth && (ToWidth == 8 || ToWidth == 16 || ToWidth == 32)) + return true; + // If this is a legal integer from type, and the result would be an illegal // type, don't do the transformation. if (FromLegal && !ToLegal) @@ -396,28 +406,23 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // Transform: "(A op C1) op (B op C2)" ==> "(A op B) op (C1 op C2)" // if C1 and C2 are constants. + Value *A, *B; + Constant *C1, *C2; if (Op0 && Op1 && Op0->getOpcode() == Opcode && Op1->getOpcode() == Opcode && - isa<Constant>(Op0->getOperand(1)) && - isa<Constant>(Op1->getOperand(1)) && - Op0->hasOneUse() && Op1->hasOneUse()) { - Value *A = Op0->getOperand(0); - Constant *C1 = cast<Constant>(Op0->getOperand(1)); - Value *B = Op1->getOperand(0); - Constant *C2 = cast<Constant>(Op1->getOperand(1)); - - Constant *Folded = ConstantExpr::get(Opcode, C1, C2); - BinaryOperator *New = BinaryOperator::Create(Opcode, A, B); - if (isa<FPMathOperator>(New)) { + match(Op0, m_OneUse(m_BinOp(m_Value(A), m_Constant(C1)))) && + match(Op1, m_OneUse(m_BinOp(m_Value(B), m_Constant(C2))))) { + BinaryOperator *NewBO = BinaryOperator::Create(Opcode, A, B); + if (isa<FPMathOperator>(NewBO)) { FastMathFlags Flags = I.getFastMathFlags(); Flags &= Op0->getFastMathFlags(); Flags &= Op1->getFastMathFlags(); - New->setFastMathFlags(Flags); + NewBO->setFastMathFlags(Flags); } - InsertNewInstWith(New, I); - New->takeName(Op1); - I.setOperand(0, New); - I.setOperand(1, Folded); + InsertNewInstWith(NewBO, I); + NewBO->takeName(Op1); + I.setOperand(0, NewBO); + I.setOperand(1, ConstantExpr::get(Opcode, C1, C2)); // Conservatively clear the optional flags, since they may not be // preserved by the reassociation. ClearSubclassDataAfterReassociation(I); @@ -434,72 +439,38 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { /// Return whether "X LOp (Y ROp Z)" is always equal to /// "(X LOp Y) ROp (X LOp Z)". -static bool LeftDistributesOverRight(Instruction::BinaryOps LOp, +static bool leftDistributesOverRight(Instruction::BinaryOps LOp, Instruction::BinaryOps ROp) { - switch (LOp) { - default: - return false; + // X & (Y | Z) <--> (X & Y) | (X & Z) + // X & (Y ^ Z) <--> (X & Y) ^ (X & Z) + if (LOp == Instruction::And) + return ROp == Instruction::Or || ROp == Instruction::Xor; - case Instruction::And: - // And distributes over Or and Xor. - switch (ROp) { - default: - return false; - case Instruction::Or: - case Instruction::Xor: - return true; - } + // X | (Y & Z) <--> (X | Y) & (X | Z) + if (LOp == Instruction::Or) + return ROp == Instruction::And; - case Instruction::Mul: - // Multiplication distributes over addition and subtraction. - switch (ROp) { - default: - return false; - case Instruction::Add: - case Instruction::Sub: - return true; - } + // X * (Y + Z) <--> (X * Y) + (X * Z) + // X * (Y - Z) <--> (X * Y) - (X * Z) + if (LOp == Instruction::Mul) + return ROp == Instruction::Add || ROp == Instruction::Sub; - case Instruction::Or: - // Or distributes over And. 
- switch (ROp) { - default: - return false; - case Instruction::And: - return true; - } - } + return false; } /// Return whether "(X LOp Y) ROp Z" is always equal to /// "(X ROp Z) LOp (Y ROp Z)". -static bool RightDistributesOverLeft(Instruction::BinaryOps LOp, +static bool rightDistributesOverLeft(Instruction::BinaryOps LOp, Instruction::BinaryOps ROp) { if (Instruction::isCommutative(ROp)) - return LeftDistributesOverRight(ROp, LOp); + return leftDistributesOverRight(ROp, LOp); + + // (X {&|^} Y) >> Z <--> (X >> Z) {&|^} (Y >> Z) for all shifts. + return Instruction::isBitwiseLogicOp(LOp) && Instruction::isShift(ROp); - switch (LOp) { - default: - return false; - // (X >> Z) & (Y >> Z) -> (X&Y) >> Z for all shifts. - // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts. - // (X >> Z) ^ (Y >> Z) -> (X^Y) >> Z for all shifts. - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: - switch (ROp) { - default: - return false; - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - return true; - } - } // TODO: It would be nice to handle division, aka "(X + Y)/Z = X/Z + Y/Z", // but this requires knowing that the addition does not overflow and other // such subtleties. - return false; } /// This function returns identity value for given opcode, which can be used to @@ -511,37 +482,27 @@ static Value *getIdentityValue(Instruction::BinaryOps Opcode, Value *V) { return ConstantExpr::getBinOpIdentity(Opcode, V->getType()); } -/// This function factors binary ops which can be combined using distributive -/// laws. This function tries to transform 'Op' based TopLevelOpcode to enable -/// factorization e.g for ADD(SHL(X , 2), MUL(X, 5)), When this function called -/// with TopLevelOpcode == Instruction::Add and Op = SHL(X, 2), transforms -/// SHL(X, 2) to MUL(X, 4) i.e. returns Instruction::Mul with LHS set to 'X' and -/// RHS to 4. +/// This function predicates factorization using distributive laws. By default, +/// it just returns the 'Op' inputs. But for special-cases like +/// 'add(shl(X, 5), ...)', this function will have TopOpcode == Instruction::Add +/// and Op = shl(X, 5). The 'shl' is treated as the more general 'mul X, 32' to +/// allow more factorization opportunities. static Instruction::BinaryOps -getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode, - BinaryOperator *Op, Value *&LHS, Value *&RHS) { +getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op, + Value *&LHS, Value *&RHS) { assert(Op && "Expected a binary operator"); - LHS = Op->getOperand(0); RHS = Op->getOperand(1); - - switch (TopLevelOpcode) { - default: - return Op->getOpcode(); - - case Instruction::Add: - case Instruction::Sub: - if (Op->getOpcode() == Instruction::Shl) { - if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) { - // The multiplier is really 1 << CST. - RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST); - return Instruction::Mul; - } + if (TopOpcode == Instruction::Add || TopOpcode == Instruction::Sub) { + Constant *C; + if (match(Op, m_Shl(m_Value(), m_Constant(C)))) { + // X << C --> X * (1 << C) + RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), C); + return Instruction::Mul; } - return Op->getOpcode(); + // TODO: We can add other conversions e.g. shr => div etc. } - - // TODO: We can add other conversions e.g. shr => div etc. 
+ return Op->getOpcode(); } /// This tries to simplify binary operations by factorizing out common terms @@ -560,7 +521,7 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I, bool InnerCommutative = Instruction::isCommutative(InnerOpcode); // Does "X op' (Y op Z)" always equal "(X op' Y) op (X op' Z)"? - if (LeftDistributesOverRight(InnerOpcode, TopLevelOpcode)) + if (leftDistributesOverRight(InnerOpcode, TopLevelOpcode)) // Does the instruction have the form "(A op' B) op (A op' D)" or, in the // commutative case, "(A op' B) op (C op' A)"? if (A == C || (InnerCommutative && A == D)) { @@ -579,7 +540,7 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I, } // Does "(X op Y) op' Z" always equal "(X op' Z) op (Y op' Z)"? - if (!SimplifiedInst && RightDistributesOverLeft(TopLevelOpcode, InnerOpcode)) + if (!SimplifiedInst && rightDistributesOverLeft(TopLevelOpcode, InnerOpcode)) // Does the instruction have the form "(A op' B) op (C op' B)" or, in the // commutative case, "(A op' B) op (B op' D)"? if (B == D || (InnerCommutative && B == C)) { @@ -664,21 +625,19 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { // term. if (Op0) if (Value *Ident = getIdentityValue(LHSOpcode, RHS)) - if (Value *V = - tryFactorization(I, LHSOpcode, A, B, RHS, Ident)) + if (Value *V = tryFactorization(I, LHSOpcode, A, B, RHS, Ident)) return V; // The instruction has the form "(B) op (C op' D)". Try to factorize common // term. if (Op1) if (Value *Ident = getIdentityValue(RHSOpcode, LHS)) - if (Value *V = - tryFactorization(I, RHSOpcode, LHS, Ident, C, D)) + if (Value *V = tryFactorization(I, RHSOpcode, LHS, Ident, C, D)) return V; } // Expansion. - if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { + if (Op0 && rightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { // The instruction has the form "(A op' B) op C". See if expanding it out // to "(A op C) op' (B op C)" results in simplifications. Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS; @@ -715,7 +674,7 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { } } - if (Op1 && LeftDistributesOverRight(TopLevelOpcode, Op1->getOpcode())) { + if (Op1 && leftDistributesOverRight(TopLevelOpcode, Op1->getOpcode())) { // The instruction has the form "A op (B op' C)". See if expanding it out // to "(A op B) op' (A op C)" results in simplifications. Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1); @@ -817,23 +776,6 @@ Value *InstCombiner::dyn_castNegVal(Value *V) const { return nullptr; } -/// Given a 'fsub' instruction, return the RHS of the instruction if the LHS is -/// a constant negative zero (which is the 'negate' form). -Value *InstCombiner::dyn_castFNegVal(Value *V, bool IgnoreZeroSign) const { - if (BinaryOperator::isFNeg(V, IgnoreZeroSign)) - return BinaryOperator::getFNegArgument(V); - - // Constants can be considered to be negated values if they can be folded. 
- if (ConstantFP *C = dyn_cast<ConstantFP>(V)) - return ConstantExpr::getFNeg(C); - - if (ConstantDataVector *C = dyn_cast<ConstantDataVector>(V)) - if (C->getType()->getElementType()->isFloatingPointTy()) - return ConstantExpr::getFNeg(C); - - return nullptr; -} - static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, InstCombiner::BuilderTy &Builder) { if (auto *Cast = dyn_cast<CastInst>(&I)) @@ -1081,8 +1023,9 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { return replaceInstUsesWith(I, NewPN); } -Instruction *InstCombiner::foldOpWithConstantIntoOperand(BinaryOperator &I) { - assert(isa<Constant>(I.getOperand(1)) && "Unexpected operand type"); +Instruction *InstCombiner::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { + if (!isa<Constant>(I.getOperand(1))) + return nullptr; if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) { if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) @@ -1107,7 +1050,7 @@ Type *InstCombiner::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, // Start with the index over the outer type. Note that the type size // might be zero (even if the offset isn't zero) if the indexed type // is something like [0 x {int, int}] - Type *IntPtrTy = DL.getIntPtrType(PtrTy); + Type *IndexTy = DL.getIndexType(PtrTy); int64_t FirstIdx = 0; if (int64_t TySize = DL.getTypeAllocSize(Ty)) { FirstIdx = Offset/TySize; @@ -1122,7 +1065,7 @@ Type *InstCombiner::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, assert((uint64_t)Offset < (uint64_t)TySize && "Out of range offset"); } - NewIndices.push_back(ConstantInt::get(IntPtrTy, FirstIdx)); + NewIndices.push_back(ConstantInt::get(IndexTy, FirstIdx)); // Index into the types. If we fail, set OrigBase to null. while (Offset) { @@ -1144,7 +1087,7 @@ Type *InstCombiner::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) { uint64_t EltSize = DL.getTypeAllocSize(AT->getElementType()); assert(EltSize && "Cannot index into a zero-sized array"); - NewIndices.push_back(ConstantInt::get(IntPtrTy,Offset/EltSize)); + NewIndices.push_back(ConstantInt::get(IndexTy,Offset/EltSize)); Offset %= EltSize; Ty = AT->getElementType(); } else { @@ -1408,22 +1351,7 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { } while (true); } -/// \brief Creates node of binary operation with the same attributes as the -/// specified one but with other operands. -static Value *CreateBinOpAsGiven(BinaryOperator &Inst, Value *LHS, Value *RHS, - InstCombiner::BuilderTy &B) { - Value *BO = B.CreateBinOp(Inst.getOpcode(), LHS, RHS); - // If LHS and RHS are constant, BO won't be a binary operator. - if (BinaryOperator *NewBO = dyn_cast<BinaryOperator>(BO)) - NewBO->copyIRFlags(&Inst); - return BO; -} - -/// \brief Makes transformation of binary operation specific for vector types. -/// \param Inst Binary operator to transform. -/// \return Pointer to node that must replace the original binary operator, or -/// null pointer if no transformation was made. -Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { +Instruction *InstCombiner::foldShuffledBinop(BinaryOperator &Inst) { if (!Inst.getType()->isVectorTy()) return nullptr; // It may not be safe to reorder shuffles and things like div, urem, etc. 
@@ -1437,58 +1365,71 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { assert(cast<VectorType>(LHS->getType())->getNumElements() == VWidth); assert(cast<VectorType>(RHS->getType())->getNumElements() == VWidth); + auto createBinOpShuffle = [&](Value *X, Value *Y, Constant *M) { + Value *XY = Builder.CreateBinOp(Inst.getOpcode(), X, Y); + if (auto *BO = dyn_cast<BinaryOperator>(XY)) + BO->copyIRFlags(&Inst); + return new ShuffleVectorInst(XY, UndefValue::get(XY->getType()), M); + }; + // If both arguments of the binary operation are shuffles that use the same - // mask and shuffle within a single vector, move the shuffle after the binop: - // Op(shuffle(v1, m), shuffle(v2, m)) -> shuffle(Op(v1, v2), m) - auto *LShuf = dyn_cast<ShuffleVectorInst>(LHS); - auto *RShuf = dyn_cast<ShuffleVectorInst>(RHS); - if (LShuf && RShuf && LShuf->getMask() == RShuf->getMask() && - isa<UndefValue>(LShuf->getOperand(1)) && - isa<UndefValue>(RShuf->getOperand(1)) && - LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType()) { - Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0), - RShuf->getOperand(0), Builder); - return Builder.CreateShuffleVector( - NewBO, UndefValue::get(NewBO->getType()), LShuf->getMask()); + // mask and shuffle within a single vector, move the shuffle after the binop. + Value *V1, *V2; + Constant *Mask; + if (match(LHS, m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(Mask))) && + match(RHS, m_ShuffleVector(m_Value(V2), m_Undef(), m_Specific(Mask))) && + V1->getType() == V2->getType() && + (LHS->hasOneUse() || RHS->hasOneUse() || LHS == RHS)) { + // Op(shuffle(V1, Mask), shuffle(V2, Mask)) -> shuffle(Op(V1, V2), Mask) + return createBinOpShuffle(V1, V2, Mask); } - // If one argument is a shuffle within one vector, the other is a constant, - // try moving the shuffle after the binary operation. - ShuffleVectorInst *Shuffle = nullptr; - Constant *C1 = nullptr; - if (isa<ShuffleVectorInst>(LHS)) Shuffle = cast<ShuffleVectorInst>(LHS); - if (isa<ShuffleVectorInst>(RHS)) Shuffle = cast<ShuffleVectorInst>(RHS); - if (isa<Constant>(LHS)) C1 = cast<Constant>(LHS); - if (isa<Constant>(RHS)) C1 = cast<Constant>(RHS); - if (Shuffle && C1 && - (isa<ConstantVector>(C1) || isa<ConstantDataVector>(C1)) && - isa<UndefValue>(Shuffle->getOperand(1)) && - Shuffle->getType() == Shuffle->getOperand(0)->getType()) { - SmallVector<int, 16> ShMask = Shuffle->getShuffleMask(); - // Find constant C2 that has property: - // shuffle(C2, ShMask) = C1 - // If such constant does not exist (example: ShMask=<0,0> and C1=<1,2>) - // reorder is not possible. - SmallVector<Constant*, 16> C2M(VWidth, - UndefValue::get(C1->getType()->getScalarType())); + // If one argument is a shuffle within one vector and the other is a constant, + // try moving the shuffle after the binary operation. This canonicalization + // intends to move shuffles closer to other shuffles and binops closer to + // other binops, so they can be folded. It may also enable demanded elements + // transforms. + Constant *C; + if (match(&Inst, m_c_BinOp( + m_OneUse(m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(Mask))), + m_Constant(C))) && + V1->getType() == Inst.getType()) { + // Find constant NewC that has property: + // shuffle(NewC, ShMask) = C + // If such constant does not exist (example: ShMask=<0,0> and C=<1,2>) + // reorder is not possible. A 1-to-1 mapping is not required. 
Example: + // ShMask = <1,1,2,2> and C = <5,5,6,6> --> NewC = <undef,5,6,undef> + SmallVector<int, 16> ShMask; + ShuffleVectorInst::getShuffleMask(Mask, ShMask); + SmallVector<Constant *, 16> + NewVecC(VWidth, UndefValue::get(C->getType()->getScalarType())); bool MayChange = true; for (unsigned I = 0; I < VWidth; ++I) { if (ShMask[I] >= 0) { assert(ShMask[I] < (int)VWidth); - if (!isa<UndefValue>(C2M[ShMask[I]])) { + Constant *CElt = C->getAggregateElement(I); + Constant *NewCElt = NewVecC[ShMask[I]]; + if (!CElt || (!isa<UndefValue>(NewCElt) && NewCElt != CElt)) { MayChange = false; break; } - C2M[ShMask[I]] = C1->getAggregateElement(I); + NewVecC[ShMask[I]] = CElt; } } if (MayChange) { - Constant *C2 = ConstantVector::get(C2M); - Value *NewLHS = isa<Constant>(LHS) ? C2 : Shuffle->getOperand(0); - Value *NewRHS = isa<Constant>(LHS) ? Shuffle->getOperand(0) : C2; - Value *NewBO = CreateBinOpAsGiven(Inst, NewLHS, NewRHS, Builder); - return Builder.CreateShuffleVector(NewBO, - UndefValue::get(Inst.getType()), Shuffle->getMask()); + Constant *NewC = ConstantVector::get(NewVecC); + // It may not be safe to execute a binop on a vector with undef elements + // because the entire instruction can be folded to undef or create poison + // that did not exist in the original code. + bool ConstOp1 = isa<Constant>(Inst.getOperand(1)); + if (Inst.isIntDivRem() || (Inst.isShift() && ConstOp1)) + NewC = getSafeVectorConstantForBinop(Inst.getOpcode(), NewC, ConstOp1); + + // Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask) + // Op(C, shuffle(V1, Mask)) -> shuffle(Op(NewC, V1), Mask) + Value *NewLHS = isa<Constant>(LHS) ? NewC : V1; + Value *NewRHS = isa<Constant>(LHS) ? V1 : NewC; + return createBinOpShuffle(NewLHS, NewRHS, Mask); } } @@ -1497,9 +1438,9 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - - if (Value *V = SimplifyGEPInst(GEP.getSourceElementType(), Ops, - SQ.getWithInstruction(&GEP))) + Type *GEPType = GEP.getType(); + Type *GEPEltType = GEP.getSourceElementType(); + if (Value *V = SimplifyGEPInst(GEPEltType, Ops, SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -1507,8 +1448,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Eliminate unneeded casts for indices, and replace indices which displace // by multiples of a zero size type with zero. bool MadeChange = false; - Type *IntPtrTy = - DL.getIntPtrType(GEP.getPointerOperandType()->getScalarType()); + + // Index width may not be the same width as pointer width. + // Data layout chooses the right type based on supported integer types. + Type *NewScalarIndexTy = + DL.getIndexType(GEP.getPointerOperandType()->getScalarType()); gep_type_iterator GTI = gep_type_begin(GEP); for (User::op_iterator I = GEP.op_begin() + 1, E = GEP.op_end(); I != E; @@ -1517,10 +1461,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GTI.isStruct()) continue; - // Index type should have the same width as IntPtr Type *IndexTy = (*I)->getType(); - Type *NewIndexType = IndexTy->isVectorTy() ? - VectorType::get(IntPtrTy, IndexTy->getVectorNumElements()) : IntPtrTy; + Type *NewIndexType = + IndexTy->isVectorTy() + ? 
VectorType::get(NewScalarIndexTy, IndexTy->getVectorNumElements()) + : NewScalarIndexTy; // If the element type has zero size then any index over it is equivalent // to an index of zero, so replace it with zero if it is not zero already. @@ -1543,8 +1488,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { return &GEP; // Check to see if the inputs to the PHI node are getelementptr instructions. - if (PHINode *PN = dyn_cast<PHINode>(PtrOp)) { - GetElementPtrInst *Op1 = dyn_cast<GetElementPtrInst>(PN->getOperand(0)); + if (auto *PN = dyn_cast<PHINode>(PtrOp)) { + auto *Op1 = dyn_cast<GetElementPtrInst>(PN->getOperand(0)); if (!Op1) return nullptr; @@ -1560,7 +1505,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { int DI = -1; for (auto I = PN->op_begin()+1, E = PN->op_end(); I !=E; ++I) { - GetElementPtrInst *Op2 = dyn_cast<GetElementPtrInst>(*I); + auto *Op2 = dyn_cast<GetElementPtrInst>(*I); if (!Op2 || Op1->getNumOperands() != Op2->getNumOperands()) return nullptr; @@ -1602,7 +1547,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (J > 0) { if (J == 1) { CurTy = Op1->getSourceElementType(); - } else if (CompositeType *CT = dyn_cast<CompositeType>(CurTy)) { + } else if (auto *CT = dyn_cast<CompositeType>(CurTy)) { CurTy = CT->getTypeAtIndex(Op1->getOperand(J)); } else { CurTy = nullptr; @@ -1617,7 +1562,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (DI != -1 && !PN->hasOneUse()) return nullptr; - GetElementPtrInst *NewGEP = cast<GetElementPtrInst>(Op1->clone()); + auto *NewGEP = cast<GetElementPtrInst>(Op1->clone()); if (DI == -1) { // All the GEPs feeding the PHI are identical. Clone one down into our // BB so that it can be merged with the current GEP. @@ -1652,15 +1597,64 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Combine Indices - If the source pointer to this getelementptr instruction // is a getelementptr instruction, combine the indices of the two // getelementptr instructions into a single instruction. - if (GEPOperator *Src = dyn_cast<GEPOperator>(PtrOp)) { + if (auto *Src = dyn_cast<GEPOperator>(PtrOp)) { if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src)) return nullptr; + // Try to reassociate loop invariant GEP chains to enable LICM. + if (LI && Src->getNumOperands() == 2 && GEP.getNumOperands() == 2 && + Src->hasOneUse()) { + if (Loop *L = LI->getLoopFor(GEP.getParent())) { + Value *GO1 = GEP.getOperand(1); + Value *SO1 = Src->getOperand(1); + // Reassociate the two GEPs if SO1 is variant in the loop and GO1 is + // invariant: this breaks the dependence between GEPs and allows LICM + // to hoist the invariant part out of the loop. + if (L->isLoopInvariant(GO1) && !L->isLoopInvariant(SO1)) { + // We have to be careful here. + // We have something like: + // %src = getelementptr <ty>, <ty>* %base, <ty> %idx + // %gep = getelementptr <ty>, <ty>* %src, <ty> %idx2 + // If we just swap idx & idx2 then we could inadvertantly + // change %src from a vector to a scalar, or vice versa. + // Cases: + // 1) %base a scalar & idx a scalar & idx2 a vector + // => Swapping idx & idx2 turns %src into a vector type. + // 2) %base a scalar & idx a vector & idx2 a scalar + // => Swapping idx & idx2 turns %src in a scalar type + // 3) %base, %idx, and %idx2 are scalars + // => %src & %gep are scalars + // => swapping idx & idx2 is safe + // 4) %base a vector + // => %src is a vector + // => swapping idx & idx2 is safe. 
+ auto *SO0 = Src->getOperand(0); + auto *SO0Ty = SO0->getType(); + if (!isa<VectorType>(GEPType) || // case 3 + isa<VectorType>(SO0Ty)) { // case 4 + Src->setOperand(1, GO1); + GEP.setOperand(1, SO1); + return &GEP; + } else { + // Case 1 or 2 + // -- have to recreate %src & %gep + // put NewSrc at same location as %src + Builder.SetInsertPoint(cast<Instruction>(PtrOp)); + auto *NewSrc = cast<GetElementPtrInst>( + Builder.CreateGEP(SO0, GO1, Src->getName())); + NewSrc->setIsInBounds(Src->isInBounds()); + auto *NewGEP = GetElementPtrInst::Create(nullptr, NewSrc, {SO1}); + NewGEP->setIsInBounds(GEP.isInBounds()); + return NewGEP; + } + } + } + } + // Note that if our source is a gep chain itself then we wait for that // chain to be resolved before we perform this transformation. This // avoids us creating a TON of code in some cases. - if (GEPOperator *SrcGEP = - dyn_cast<GEPOperator>(Src->getOperand(0))) + if (auto *SrcGEP = dyn_cast<GEPOperator>(Src->getOperand(0))) if (SrcGEP->getNumOperands() == 2 && shouldMergeGEPs(*Src, *SrcGEP)) return nullptr; // Wait until our source is folded to completion. @@ -1723,9 +1717,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GEP.getNumIndices() == 1) { unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == - DL.getPointerSizeInBits(AS)) { - Type *Ty = GEP.getSourceElementType(); - uint64_t TyAllocSize = DL.getTypeAllocSize(Ty); + DL.getIndexSizeInBits(AS)) { + uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType); bool Matched = false; uint64_t C; @@ -1752,22 +1745,20 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { Operator *Index = cast<Operator>(V); Value *PtrToInt = Builder.CreatePtrToInt(PtrOp, Index->getType()); Value *NewSub = Builder.CreateSub(PtrToInt, Index->getOperand(1)); - return CastInst::Create(Instruction::IntToPtr, NewSub, GEP.getType()); + return CastInst::Create(Instruction::IntToPtr, NewSub, GEPType); } // Canonicalize (gep i8* X, (ptrtoint Y)-(ptrtoint X)) // to (bitcast Y) Value *Y; if (match(V, m_Sub(m_PtrToInt(m_Value(Y)), - m_PtrToInt(m_Specific(GEP.getOperand(0)))))) { - return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, - GEP.getType()); - } + m_PtrToInt(m_Specific(GEP.getOperand(0)))))) + return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, GEPType); } } } // We do not handle pointer-vector geps here. - if (GEP.getType()->isVectorTy()) + if (GEPType->isVectorTy()) return nullptr; // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). @@ -1776,7 +1767,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (StrippedPtr != PtrOp) { bool HasZeroPointerIndex = false; - if (ConstantInt *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) + if (auto *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) HasZeroPointerIndex = C->isZero(); // Transform: GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... @@ -1787,8 +1778,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // // This occurs when the program declares an array extern like "int X[];" if (HasZeroPointerIndex) { - if (ArrayType *CATy = - dyn_cast<ArrayType>(GEP.getSourceElementType())) { + if (auto *CATy = dyn_cast<ArrayType>(GEPEltType)) { // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? if (CATy->getElementType() == StrippedPtrTy->getElementType()) { // -> GEP i8* X, ... 
@@ -1804,11 +1794,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // -> // %0 = GEP i8 addrspace(1)* X, ... // addrspacecast i8 addrspace(1)* %0 to i8* - return new AddrSpaceCastInst(Builder.Insert(Res), GEP.getType()); + return new AddrSpaceCastInst(Builder.Insert(Res), GEPType); } - if (ArrayType *XATy = - dyn_cast<ArrayType>(StrippedPtrTy->getElementType())){ + if (auto *XATy = dyn_cast<ArrayType>(StrippedPtrTy->getElementType())) { // GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... ? if (CATy->getElementType() == XATy->getElementType()) { // -> GEP [10 x i8]* X, i32 0, ... @@ -1836,7 +1825,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { nullptr, StrippedPtr, Idx, GEP.getName()) : Builder.CreateGEP(nullptr, StrippedPtr, Idx, GEP.getName()); - return new AddrSpaceCastInst(NewGEP, GEP.getType()); + return new AddrSpaceCastInst(NewGEP, GEPType); } } } @@ -1844,12 +1833,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Transform things like: // %t = getelementptr i32* bitcast ([2 x i32]* %str to i32*), i32 %V // into: %t1 = getelementptr [2 x i32]* %str, i32 0, i32 %V; bitcast - Type *SrcElTy = StrippedPtrTy->getElementType(); - Type *ResElTy = GEP.getSourceElementType(); - if (SrcElTy->isArrayTy() && - DL.getTypeAllocSize(SrcElTy->getArrayElementType()) == - DL.getTypeAllocSize(ResElTy)) { - Type *IdxType = DL.getIntPtrType(GEP.getType()); + Type *SrcEltTy = StrippedPtrTy->getElementType(); + if (SrcEltTy->isArrayTy() && + DL.getTypeAllocSize(SrcEltTy->getArrayElementType()) == + DL.getTypeAllocSize(GEPEltType)) { + Type *IdxType = DL.getIndexType(GEPType); Value *Idx[2] = { Constant::getNullValue(IdxType), GEP.getOperand(1) }; Value *NewGEP = GEP.isInBounds() @@ -1858,28 +1846,28 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { : Builder.CreateGEP(nullptr, StrippedPtr, Idx, GEP.getName()); // V and GEP are both pointer types --> BitCast - return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEP.getType()); + return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType); } // Transform things like: // %V = mul i64 %N, 4 // %t = getelementptr i8* bitcast (i32* %arr to i8*), i32 %V // into: %t1 = getelementptr i32* %arr, i32 %N; bitcast - if (ResElTy->isSized() && SrcElTy->isSized()) { + if (GEPEltType->isSized() && SrcEltTy->isSized()) { // Check that changing the type amounts to dividing the index by a scale // factor. - uint64_t ResSize = DL.getTypeAllocSize(ResElTy); - uint64_t SrcSize = DL.getTypeAllocSize(SrcElTy); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType); + uint64_t SrcSize = DL.getTypeAllocSize(SrcEltTy); if (ResSize && SrcSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); uint64_t Scale = SrcSize / ResSize; - // Earlier transforms ensure that the index has type IntPtrType, which - // considerably simplifies the logic by eliminating implicit casts. - assert(Idx->getType() == DL.getIntPtrType(GEP.getType()) && - "Index not cast to pointer width?"); + // Earlier transforms ensure that the index has the right type + // according to Data Layout, which considerably simplifies the + // logic by eliminating implicit casts. 
+ assert(Idx->getType() == DL.getIndexType(GEPType) && + "Index type does not match the Data Layout preferences"); bool NSW; if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { @@ -1895,7 +1883,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // The NewGEP must be pointer typed, so must the old one -> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEP.getType()); + GEPType); } } } @@ -1904,39 +1892,40 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp // (where tmp = 8*tmp2) into: // getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast - if (ResElTy->isSized() && SrcElTy->isSized() && SrcElTy->isArrayTy()) { + if (GEPEltType->isSized() && SrcEltTy->isSized() && + SrcEltTy->isArrayTy()) { // Check that changing to the array element type amounts to dividing the // index by a scale factor. - uint64_t ResSize = DL.getTypeAllocSize(ResElTy); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType); uint64_t ArrayEltSize = - DL.getTypeAllocSize(SrcElTy->getArrayElementType()); + DL.getTypeAllocSize(SrcEltTy->getArrayElementType()); if (ResSize && ArrayEltSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); uint64_t Scale = ArrayEltSize / ResSize; - // Earlier transforms ensure that the index has type IntPtrType, which - // considerably simplifies the logic by eliminating implicit casts. - assert(Idx->getType() == DL.getIntPtrType(GEP.getType()) && - "Index not cast to pointer width?"); + // Earlier transforms ensure that the index has the right type + // according to the Data Layout, which considerably simplifies + // the logic by eliminating implicit casts. + assert(Idx->getType() == DL.getIndexType(GEPType) && + "Index type does not match the Data Layout preferences"); bool NSW; if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { // Successfully decomposed Idx as NewIdx * Scale, form a new GEP. // If the multiplication NewIdx * Scale may overflow then the new // GEP may not be "inbounds". - Value *Off[2] = { - Constant::getNullValue(DL.getIntPtrType(GEP.getType())), - NewIdx}; + Type *IndTy = DL.getIndexType(GEPType); + Value *Off[2] = {Constant::getNullValue(IndTy), NewIdx}; Value *NewGEP = GEP.isInBounds() && NSW ? Builder.CreateInBoundsGEP( - SrcElTy, StrippedPtr, Off, GEP.getName()) - : Builder.CreateGEP(SrcElTy, StrippedPtr, Off, + SrcEltTy, StrippedPtr, Off, GEP.getName()) + : Builder.CreateGEP(SrcEltTy, StrippedPtr, Off, GEP.getName()); // The NewGEP must be pointer typed, so must the old one -> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEP.getType()); + GEPType); } } } @@ -1946,34 +1935,53 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // addrspacecast between types is canonicalized as a bitcast, then an // addrspacecast. To take advantage of the below bitcast + struct GEP, look // through the addrspacecast. - if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) { + Value *ASCStrippedPtrOp = PtrOp; + if (auto *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) { // X = bitcast A addrspace(1)* to B addrspace(1)* // Y = addrspacecast A addrspace(1)* to B addrspace(2)* // Z = gep Y, <...constant indices...> // Into an addrspacecasted GEP of the struct. 
- if (BitCastInst *BC = dyn_cast<BitCastInst>(ASC->getOperand(0))) - PtrOp = BC; + if (auto *BC = dyn_cast<BitCastInst>(ASC->getOperand(0))) + ASCStrippedPtrOp = BC; } - /// See if we can simplify: - /// X = bitcast A* to B* - /// Y = gep X, <...constant indices...> - /// into a gep of the original struct. This is important for SROA and alias - /// analysis of unions. If "A" is also a bitcast, wait for A/X to be merged. - if (BitCastInst *BCI = dyn_cast<BitCastInst>(PtrOp)) { - Value *Operand = BCI->getOperand(0); - PointerType *OpType = cast<PointerType>(Operand->getType()); - unsigned OffsetBits = DL.getPointerTypeSizeInBits(GEP.getType()); - APInt Offset(OffsetBits, 0); - if (!isa<BitCastInst>(Operand) && - GEP.accumulateConstantOffset(DL, Offset)) { + if (auto *BCI = dyn_cast<BitCastInst>(ASCStrippedPtrOp)) { + Value *SrcOp = BCI->getOperand(0); + PointerType *SrcType = cast<PointerType>(BCI->getSrcTy()); + Type *SrcEltType = SrcType->getElementType(); + + // GEP directly using the source operand if this GEP is accessing an element + // of a bitcasted pointer to vector or array of the same dimensions: + // gep (bitcast <c x ty>* X to [c x ty]*), Y, Z --> gep X, Y, Z + // gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z + auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy) { + return ArrTy->getArrayElementType() == VecTy->getVectorElementType() && + ArrTy->getArrayNumElements() == VecTy->getVectorNumElements(); + }; + if (GEP.getNumOperands() == 3 && + ((GEPEltType->isArrayTy() && SrcEltType->isVectorTy() && + areMatchingArrayAndVecTypes(GEPEltType, SrcEltType)) || + (GEPEltType->isVectorTy() && SrcEltType->isArrayTy() && + areMatchingArrayAndVecTypes(SrcEltType, GEPEltType)))) { + GEP.setOperand(0, SrcOp); + GEP.setSourceElementType(SrcEltType); + return &GEP; + } + // See if we can simplify: + // X = bitcast A* to B* + // Y = gep X, <...constant indices...> + // into a gep of the original struct. This is important for SROA and alias + // analysis of unions. If "A" is also a bitcast, wait for A/X to be merged. + unsigned OffsetBits = DL.getIndexTypeSizeInBits(GEPType); + APInt Offset(OffsetBits, 0); + if (!isa<BitCastInst>(SrcOp) && GEP.accumulateConstantOffset(DL, Offset)) { // If this GEP instruction doesn't move the pointer, just replace the GEP // with a bitcast of the real input to the dest type. if (!Offset) { // If the bitcast is of an allocation, and the allocation will be // converted to match the type of the cast, don't touch this. - if (isa<AllocaInst>(Operand) || isAllocationFn(Operand, &TLI)) { + if (isa<AllocaInst>(SrcOp) || isAllocationFn(SrcOp, &TLI)) { // See if the bitcast simplifies, if so, don't nuke this GEP yet. if (Instruction *I = visitBitCast(*BCI)) { if (I != BCI) { @@ -1985,43 +1993,43 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } - if (Operand->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(Operand, GEP.getType()); - return new BitCastInst(Operand, GEP.getType()); + if (SrcType->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(SrcOp, GEPType); + return new BitCastInst(SrcOp, GEPType); } // Otherwise, if the offset is non-zero, we need to find out if there is a // field at Offset in 'A's type. If so, we can pull the cast through the // GEP. 
SmallVector<Value*, 8> NewIndices; - if (FindElementAtOffset(OpType, Offset.getSExtValue(), NewIndices)) { + if (FindElementAtOffset(SrcType, Offset.getSExtValue(), NewIndices)) { Value *NGEP = GEP.isInBounds() - ? Builder.CreateInBoundsGEP(nullptr, Operand, NewIndices) - : Builder.CreateGEP(nullptr, Operand, NewIndices); + ? Builder.CreateInBoundsGEP(nullptr, SrcOp, NewIndices) + : Builder.CreateGEP(nullptr, SrcOp, NewIndices); - if (NGEP->getType() == GEP.getType()) + if (NGEP->getType() == GEPType) return replaceInstUsesWith(GEP, NGEP); NGEP->takeName(&GEP); if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(NGEP, GEP.getType()); - return new BitCastInst(NGEP, GEP.getType()); + return new AddrSpaceCastInst(NGEP, GEPType); + return new BitCastInst(NGEP, GEPType); } } } if (!GEP.isInBounds()) { - unsigned PtrWidth = - DL.getPointerSizeInBits(PtrOp->getType()->getPointerAddressSpace()); - APInt BasePtrOffset(PtrWidth, 0); + unsigned IdxWidth = + DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace()); + APInt BasePtrOffset(IdxWidth, 0); Value *UnderlyingPtrOp = PtrOp->stripAndAccumulateInBoundsConstantOffsets(DL, BasePtrOffset); if (auto *AI = dyn_cast<AllocaInst>(UnderlyingPtrOp)) { if (GEP.accumulateConstantOffset(DL, BasePtrOffset) && BasePtrOffset.isNonNegative()) { - APInt AllocSize(PtrWidth, DL.getTypeAllocSize(AI->getAllocatedType())); + APInt AllocSize(IdxWidth, DL.getTypeAllocSize(AI->getAllocatedType())); if (BasePtrOffset.ule(AllocSize)) { return GetElementPtrInst::CreateInBounds( PtrOp, makeArrayRef(Ops).slice(1), GEP.getName()); @@ -2198,7 +2206,7 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { return nullptr; } -/// \brief Move the call to free before a NULL test. +/// Move the call to free before a NULL test. /// /// Check if this free is accessed after its argument has been test /// against NULL (property 0). @@ -2562,6 +2570,7 @@ static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) { case EHPersonality::MSVC_Win64SEH: case EHPersonality::MSVC_CXX: case EHPersonality::CoreCLR: + case EHPersonality::Wasm_CXX: return TypeInfo->isNullValue(); } llvm_unreachable("invalid enum"); @@ -2889,6 +2898,7 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) { /// block. static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { assert(I->hasOneUse() && "Invariants didn't hold!"); + BasicBlock *SrcBlock = I->getParent(); // Cannot move control-flow-involving, volatile loads, vaarg, etc. if (isa<PHINode>(I) || I->isEHPad() || I->mayHaveSideEffects() || @@ -2918,10 +2928,20 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { if (Scan->mayWriteToMemory()) return false; } - BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt(); I->moveBefore(&*InsertPos); ++NumSunkInst; + + // Also sink all related debug uses from the source basic block. Otherwise we + // get debug use before the def. + SmallVector<DbgInfoIntrinsic *, 1> DbgUsers; + findDbgUsers(DbgUsers, I); + for (auto *DII : DbgUsers) { + if (DII->getParent() == SrcBlock) { + DII->moveBefore(&*InsertPos); + LLVM_DEBUG(dbgs() << "SINK: " << *DII << '\n'); + } + } return true; } @@ -2932,7 +2952,7 @@ bool InstCombiner::run() { // Check to see if we can DCE the instruction. 
if (isInstructionTriviallyDead(I, &TLI)) { - DEBUG(dbgs() << "IC: DCE: " << *I << '\n'); + LLVM_DEBUG(dbgs() << "IC: DCE: " << *I << '\n'); eraseInstFromFunction(*I); ++NumDeadInst; MadeIRChange = true; @@ -2946,7 +2966,8 @@ bool InstCombiner::run() { if (!I->use_empty() && (I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) { if (Constant *C = ConstantFoldInstruction(I, DL, &TLI)) { - DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I << '\n'); + LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I + << '\n'); // Add operands to the worklist. replaceInstUsesWith(*I, C); @@ -2965,8 +2986,8 @@ bool InstCombiner::run() { KnownBits Known = computeKnownBits(I, /*Depth*/0, I); if (Known.isConstant()) { Constant *C = ConstantInt::get(Ty, Known.getConstant()); - DEBUG(dbgs() << "IC: ConstFold (all bits known) to: " << *C << - " from: " << *I << '\n'); + LLVM_DEBUG(dbgs() << "IC: ConstFold (all bits known) to: " << *C + << " from: " << *I << '\n'); // Add operands to the worklist. replaceInstUsesWith(*I, C); @@ -3005,7 +3026,7 @@ bool InstCombiner::run() { if (UserIsSuccessor && UserParent->getUniquePredecessor()) { // Okay, the CFG is simple enough, try to sink this instruction. if (TryToSinkInstruction(I, UserParent)) { - DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); + LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); MadeIRChange = true; // We'll add uses of the sunk instruction below, but since sinking // can expose opportunities for it's *operands* add them to the @@ -3025,15 +3046,15 @@ bool InstCombiner::run() { #ifndef NDEBUG std::string OrigI; #endif - DEBUG(raw_string_ostream SS(OrigI); I->print(SS); OrigI = SS.str();); - DEBUG(dbgs() << "IC: Visiting: " << OrigI << '\n'); + LLVM_DEBUG(raw_string_ostream SS(OrigI); I->print(SS); OrigI = SS.str();); + LLVM_DEBUG(dbgs() << "IC: Visiting: " << OrigI << '\n'); if (Instruction *Result = visit(*I)) { ++NumCombined; // Should we replace the old instruction with a new one? if (Result != I) { - DEBUG(dbgs() << "IC: Old = " << *I << '\n' - << " New = " << *Result << '\n'); + LLVM_DEBUG(dbgs() << "IC: Old = " << *I << '\n' + << " New = " << *Result << '\n'); if (I->getDebugLoc()) Result->setDebugLoc(I->getDebugLoc()); @@ -3060,8 +3081,8 @@ bool InstCombiner::run() { eraseInstFromFunction(*I); } else { - DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n' - << " New = " << *I << '\n'); + LLVM_DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n' + << " New = " << *I << '\n'); // If the instruction was modified, it's possible that it is now dead. // if so, remove it. @@ -3112,7 +3133,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, // DCE instruction if trivially dead. 
if (isInstructionTriviallyDead(Inst, TLI)) { ++NumDeadInst; - DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); salvageDebugInfo(*Inst); Inst->eraseFromParent(); MadeIRChange = true; @@ -3123,8 +3144,8 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, if (!Inst->use_empty() && (Inst->getNumOperands() == 0 || isa<Constant>(Inst->getOperand(0)))) if (Constant *C = ConstantFoldInstruction(Inst, DL, TLI)) { - DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " - << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *Inst + << '\n'); Inst->replaceAllUsesWith(C); ++NumConstProp; if (isInstructionTriviallyDead(Inst, TLI)) @@ -3146,9 +3167,9 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, FoldRes = C; if (FoldRes != C) { - DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst - << "\n Old = " << *C - << "\n New = " << *FoldRes << '\n'); + LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst + << "\n Old = " << *C + << "\n New = " << *FoldRes << '\n'); U = FoldRes; MadeIRChange = true; } @@ -3191,7 +3212,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, return MadeIRChange; } -/// \brief Populate the IC worklist from a function, and prune any dead basic +/// Populate the IC worklist from a function, and prune any dead basic /// blocks discovered in the process. /// /// This also does basic constant propagation and other forward fixing to make @@ -3251,8 +3272,8 @@ static bool combineInstructionsOverFunction( int Iteration = 0; while (true) { ++Iteration; - DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " - << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " + << F.getName() << "\n"); MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); @@ -3348,3 +3369,7 @@ void LLVMInitializeInstCombine(LLVMPassRegistryRef R) { FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines) { return new InstructionCombiningPass(ExpensiveCombines); } + +void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createInstructionCombiningPass()); +} |
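The last hunk above adds a C-API binding, LLVMAddInstructionCombiningPass, alongside the existing LLVMInitializeInstCombine entry point, and the new include of "llvm-c/Transforms/InstCombine.h" declares it. What follows is a minimal sketch, not part of the patch, of how a client of the stable C API might drive InstCombine through that binding; the module M and the helper name runInstCombine are illustrative, and the pass-manager calls are the usual llvm-c Core routines.

#include <llvm-c/Core.h>
#include <llvm-c/Transforms/InstCombine.h> /* declares LLVMAddInstructionCombiningPass */

static void runInstCombine(LLVMModuleRef M) {
  /* Create a module-level legacy pass manager, schedule instcombine, run it, clean up. */
  LLVMPassManagerRef PM = LLVMCreatePassManager();
  LLVMAddInstructionCombiningPass(PM);   /* wraps createInstructionCombiningPass() */
  LLVMRunPassManager(PM, M);             /* combines instructions in every function of M */
  LLVMDisposePassManager(PM);
}

As the final hunk shows, the binding simply forwards to createInstructionCombiningPass() via unwrap(PM)->add(...), so the C and C++ entry points construct the same legacy pass.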