diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineVectorOps.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 171 |
1 files changed, 153 insertions, 18 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index dc9abdd7f47a..9c890748e5ab 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -253,6 +253,69 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, return nullptr; } +/// Find elements of V demanded by UserInstr. +static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { + unsigned VWidth = V->getType()->getVectorNumElements(); + + // Conservatively assume that all elements are needed. + APInt UsedElts(APInt::getAllOnesValue(VWidth)); + + switch (UserInstr->getOpcode()) { + case Instruction::ExtractElement: { + ExtractElementInst *EEI = cast<ExtractElementInst>(UserInstr); + assert(EEI->getVectorOperand() == V); + ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand()); + if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) { + UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue()); + } + break; + } + case Instruction::ShuffleVector: { + ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr); + unsigned MaskNumElts = UserInstr->getType()->getVectorNumElements(); + + UsedElts = APInt(VWidth, 0); + for (unsigned i = 0; i < MaskNumElts; i++) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (MaskVal == -1u || MaskVal >= 2 * VWidth) + continue; + if (Shuffle->getOperand(0) == V && (MaskVal < VWidth)) + UsedElts.setBit(MaskVal); + if (Shuffle->getOperand(1) == V && + ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth))) + UsedElts.setBit(MaskVal - VWidth); + } + break; + } + default: + break; + } + return UsedElts; +} + +/// Find union of elements of V demanded by all its users. +/// If it is known by querying findDemandedEltsBySingleUser that +/// no user demands an element of V, then the corresponding bit +/// remains unset in the returned value. +static APInt findDemandedEltsByAllUsers(Value *V) { + unsigned VWidth = V->getType()->getVectorNumElements(); + + APInt UnionUsedElts(VWidth, 0); + for (const Use &U : V->uses()) { + if (Instruction *I = dyn_cast<Instruction>(U.getUser())) { + UnionUsedElts |= findDemandedEltsBySingleUser(V, I); + } else { + UnionUsedElts = APInt::getAllOnesValue(VWidth); + break; + } + + if (UnionUsedElts.isAllOnesValue()) + break; + } + + return UnionUsedElts; +} + Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Value *SrcVec = EI.getVectorOperand(); Value *Index = EI.getIndexOperand(); @@ -271,19 +334,35 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { return nullptr; // This instruction only demands the single element from the input vector. - // If the input vector has a single use, simplify it based on this use - // property. - if (SrcVec->hasOneUse() && NumElts != 1) { - APInt UndefElts(NumElts, 0); - APInt DemandedElts(NumElts, 0); - DemandedElts.setBit(IndexC->getZExtValue()); - if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts, - UndefElts)) { - EI.setOperand(0, V); - return &EI; + if (NumElts != 1) { + // If the input vector has a single use, simplify it based on this use + // property. + if (SrcVec->hasOneUse()) { + APInt UndefElts(NumElts, 0); + APInt DemandedElts(NumElts, 0); + DemandedElts.setBit(IndexC->getZExtValue()); + if (Value *V = + SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) { + EI.setOperand(0, V); + return &EI; + } + } else { + // If the input vector has multiple uses, simplify it based on a union + // of all elements used. + APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec); + if (!DemandedElts.isAllOnesValue()) { + APInt UndefElts(NumElts, 0); + if (Value *V = SimplifyDemandedVectorElts( + SrcVec, DemandedElts, UndefElts, 0 /* Depth */, + true /* AllowMultipleUsers */)) { + if (V != SrcVec) { + SrcVec->replaceAllUsesWith(V); + return &EI; + } + } + } } } - if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian())) return I; @@ -766,6 +845,55 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { return new ShuffleVectorInst(Op0, UndefValue::get(Op0->getType()), NewMask); } +/// Try to fold an extract+insert element into an existing identity shuffle by +/// changing the shuffle's mask to include the index of this insert element. +static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { + // Check if the vector operand of this insert is an identity shuffle. + auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0)); + if (!Shuf || !isa<UndefValue>(Shuf->getOperand(1)) || + !(Shuf->isIdentityWithExtract() || Shuf->isIdentityWithPadding())) + return nullptr; + + // Check for a constant insertion index. + uint64_t IdxC; + if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC))) + return nullptr; + + // Check if this insert's scalar op is extracted from the identity shuffle's + // input vector. + Value *Scalar = InsElt.getOperand(1); + Value *X = Shuf->getOperand(0); + if (!match(Scalar, m_ExtractElement(m_Specific(X), m_SpecificInt(IdxC)))) + return nullptr; + + // Replace the shuffle mask element at the index of this extract+insert with + // that same index value. + // For example: + // inselt (shuf X, IdMask), (extelt X, IdxC), IdxC --> shuf X, IdMask' + unsigned NumMaskElts = Shuf->getType()->getVectorNumElements(); + SmallVector<Constant *, 16> NewMaskVec(NumMaskElts); + Type *I32Ty = IntegerType::getInt32Ty(Shuf->getContext()); + Constant *NewMaskEltC = ConstantInt::get(I32Ty, IdxC); + Constant *OldMask = Shuf->getMask(); + for (unsigned i = 0; i != NumMaskElts; ++i) { + if (i != IdxC) { + // All mask elements besides the inserted element remain the same. + NewMaskVec[i] = OldMask->getAggregateElement(i); + } else if (OldMask->getAggregateElement(i) == NewMaskEltC) { + // If the mask element was already set, there's nothing to do + // (demanded elements analysis may unset it later). + return nullptr; + } else { + assert(isa<UndefValue>(OldMask->getAggregateElement(i)) && + "Unexpected shuffle mask element for identity shuffle"); + NewMaskVec[i] = NewMaskEltC; + } + } + + Constant *NewMask = ConstantVector::get(NewMaskVec); + return new ShuffleVectorInst(X, Shuf->getOperand(1), NewMask); +} + /// If we have an insertelement instruction feeding into another insertelement /// and the 2nd is inserting a constant into the vector, canonicalize that /// constant insertion before the insertion of a variable: @@ -987,6 +1115,9 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *Splat = foldInsEltIntoSplat(IE)) return Splat; + if (Instruction *IdentityShuf = foldInsEltIntoIdentityShuffle(IE)) + return IdentityShuf; + return nullptr; } @@ -1009,17 +1140,23 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, if (Depth == 0) return false; switch (I->getOpcode()) { + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::URem: + case Instruction::SRem: + // Propagating an undefined shuffle mask element to integer div/rem is not + // allowed because those opcodes can create immediate undefined behavior + // from an undefined element in an operand. + if (llvm::any_of(Mask, [](int M){ return M == -1; })) + return false; + LLVM_FALLTHROUGH; case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: case Instruction::FSub: case Instruction::Mul: case Instruction::FMul: - case Instruction::UDiv: - case Instruction::SDiv: case Instruction::FDiv: - case Instruction::URem: - case Instruction::SRem: case Instruction::FRem: case Instruction::Shl: case Instruction::LShr: @@ -1040,9 +1177,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, case Instruction::FPExt: case Instruction::GetElementPtr: { // Bail out if we would create longer vector ops. We could allow creating - // longer vector ops, but that may result in more expensive codegen. We - // would also need to limit the transform to avoid undefined behavior for - // integer div/rem. + // longer vector ops, but that may result in more expensive codegen. Type *ITy = I->getType(); if (ITy->isVectorTy() && Mask.size() > ITy->getVectorNumElements()) return false; |