diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/Scalarizer.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/Scalarizer.cpp | 103 |
1 files changed, 68 insertions, 35 deletions
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 3606c8a4b073..08f4b2173da2 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -39,8 +39,6 @@ #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/MathExtras.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> @@ -52,7 +50,7 @@ using namespace llvm; #define DEBUG_TYPE "scalarizer" -static cl::opt<bool> ScalarizeVariableInsertExtract( +static cl::opt<bool> ClScalarizeVariableInsertExtract( "scalarize-variable-insert-extract", cl::init(true), cl::Hidden, cl::desc("Allow the scalarizer pass to scalarize " "insertelement/extractelement with variable index")); @@ -60,9 +58,9 @@ static cl::opt<bool> ScalarizeVariableInsertExtract( // This is disabled by default because having separate loads and stores // makes it more likely that the -combiner-alias-analysis limits will be // reached. -static cl::opt<bool> - ScalarizeLoadStore("scalarize-load-store", cl::init(false), cl::Hidden, - cl::desc("Allow the scalarizer pass to scalarize loads and store")); +static cl::opt<bool> ClScalarizeLoadStore( + "scalarize-load-store", cl::init(false), cl::Hidden, + cl::desc("Allow the scalarizer pass to scalarize loads and store")); namespace { @@ -96,7 +94,7 @@ public: // Scatter V into Size components. If new instructions are needed, // insert them before BBI in BB. If Cache is nonnull, use it to cache // the results. - Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, + Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, Type *PtrElemTy, ValueVector *cachePtr = nullptr); // Return component I, creating a new Value for it if necessary. @@ -109,8 +107,8 @@ private: BasicBlock *BB; BasicBlock::iterator BBI; Value *V; + Type *PtrElemTy; ValueVector *CachePtr; - PointerType *PtrTy; ValueVector Tmp; unsigned Size; }; @@ -188,10 +186,23 @@ struct VectorLayout { uint64_t ElemSize = 0; }; +template <typename T> +T getWithDefaultOverride(const cl::opt<T> &ClOption, + const llvm::Optional<T> &DefaultOverride) { + return ClOption.getNumOccurrences() ? ClOption + : DefaultOverride.value_or(ClOption); +} + class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> { public: - ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT) - : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT) { + ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT, + ScalarizerPassOptions Options) + : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT), + ScalarizeVariableInsertExtract( + getWithDefaultOverride(ClScalarizeVariableInsertExtract, + Options.ScalarizeVariableInsertExtract)), + ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore, + Options.ScalarizeLoadStore)) { } bool visit(Function &F); @@ -216,8 +227,9 @@ public: bool visitCallInst(CallInst &ICI); private: - Scatterer scatter(Instruction *Point, Value *V); + Scatterer scatter(Instruction *Point, Value *V, Type *PtrElemTy = nullptr); void gather(Instruction *Op, const ValueVector &CV); + void replaceUses(Instruction *Op, Value *CV); bool canTransferMetadata(unsigned Kind); void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV); Optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment, @@ -231,12 +243,16 @@ private: ScatterMap Scattered; GatherList Gathered; + bool Scalarized; SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs; unsigned ParallelLoopAccessMDKind; DominatorTree *DT; + + const bool ScalarizeVariableInsertExtract; + const bool ScalarizeLoadStore; }; class ScalarizerLegacyPass : public FunctionPass { @@ -265,12 +281,14 @@ INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer", "Scalarize vector operations", false, false) Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, - ValueVector *cachePtr) - : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) { + Type *PtrElemTy, ValueVector *cachePtr) + : BB(bb), BBI(bbi), V(v), PtrElemTy(PtrElemTy), CachePtr(cachePtr) { Type *Ty = V->getType(); - PtrTy = dyn_cast<PointerType>(Ty); - if (PtrTy) - Ty = PtrTy->getPointerElementType(); + if (Ty->isPointerTy()) { + assert(cast<PointerType>(Ty)->isOpaqueOrPointeeTypeMatches(PtrElemTy) && + "Pointer element type mismatch"); + Ty = PtrElemTy; + } Size = cast<FixedVectorType>(Ty)->getNumElements(); if (!CachePtr) Tmp.resize(Size, nullptr); @@ -287,15 +305,15 @@ Value *Scatterer::operator[](unsigned I) { if (CV[I]) return CV[I]; IRBuilder<> Builder(BB, BBI); - if (PtrTy) { - Type *ElTy = - cast<VectorType>(PtrTy->getPointerElementType())->getElementType(); + if (PtrElemTy) { + Type *VectorElemTy = cast<VectorType>(PtrElemTy)->getElementType(); if (!CV[0]) { - Type *NewPtrTy = PointerType::get(ElTy, PtrTy->getAddressSpace()); + Type *NewPtrTy = PointerType::get( + VectorElemTy, V->getType()->getPointerAddressSpace()); CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0"); } if (I != 0) - CV[I] = Builder.CreateConstGEP1_32(ElTy, CV[0], I, + CV[I] = Builder.CreateConstGEP1_32(VectorElemTy, CV[0], I, V->getName() + ".i" + Twine(I)); } else { // Search through a chain of InsertElementInsts looking for element I. @@ -334,7 +352,7 @@ bool ScalarizerLegacyPass::runOnFunction(Function &F) { unsigned ParallelLoopAccessMDKind = M.getContext().getMDKindID("llvm.mem.parallel_loop_access"); DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT); + ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, ScalarizerPassOptions()); return Impl.visit(F); } @@ -345,6 +363,8 @@ FunctionPass *llvm::createScalarizerPass() { bool ScalarizerVisitor::visit(Function &F) { assert(Gathered.empty() && Scattered.empty()); + Scalarized = false; + // To ensure we replace gathered components correctly we need to do an ordered // traversal of the basic blocks in the function. ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock()); @@ -362,13 +382,14 @@ bool ScalarizerVisitor::visit(Function &F) { // Return a scattered form of V that can be accessed by Point. V must be a // vector or a pointer to a vector. -Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) { +Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V, + Type *PtrElemTy) { if (Argument *VArg = dyn_cast<Argument>(V)) { // Put the scattered form of arguments in the entry block, // so that it can be used everywhere. Function *F = VArg->getParent(); BasicBlock *BB = &F->getEntryBlock(); - return Scatterer(BB, BB->begin(), V, &Scattered[V]); + return Scatterer(BB, BB->begin(), V, PtrElemTy, &Scattered[V]); } if (Instruction *VOp = dyn_cast<Instruction>(V)) { // When scalarizing PHI nodes we might try to examine/rewrite InsertElement @@ -379,17 +400,17 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) { // need to analyse them further. if (!DT->isReachableFromEntry(VOp->getParent())) return Scatterer(Point->getParent(), Point->getIterator(), - UndefValue::get(V->getType())); + PoisonValue::get(V->getType()), PtrElemTy); // Put the scattered form of an instruction directly after the // instruction, skipping over PHI nodes and debug intrinsics. BasicBlock *BB = VOp->getParent(); return Scatterer( BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, - &Scattered[V]); + PtrElemTy, &Scattered[V]); } // In the fallback case, just put the scattered before Point and // keep the result local to Point. - return Scatterer(Point->getParent(), Point->getIterator(), V); + return Scatterer(Point->getParent(), Point->getIterator(), V, PtrElemTy); } // Replace Op with the gathered form of the components in CV. Defer the @@ -419,6 +440,15 @@ void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) { Gathered.push_back(GatherList::value_type(Op, &SV)); } +// Replace Op with CV and collect Op has a potentially dead instruction. +void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) { + if (CV != Op) { + Op->replaceAllUsesWith(CV); + PotentiallyDeadInstrs.emplace_back(Op); + Scalarized = true; + } +} + // Return true if it is safe to transfer the given metadata tag from // vector to scalar instructions. bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) { @@ -558,9 +588,11 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { if (OpI->getType()->isVectorTy()) { Scattered[I] = scatter(&CI, OpI); assert(Scattered[I].size() == NumElems && "mismatched call operands"); + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) + Tys.push_back(OpI->getType()->getScalarType()); } else { ScalarOperands[I] = OpI; - if (hasVectorInstrinsicOverloadedScalarOpd(ID, I)) + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) Tys.push_back(OpI->getType()); } } @@ -576,7 +608,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { ScalarCallOps.clear(); for (unsigned J = 0; J != NumArgs; ++J) { - if (hasVectorInstrinsicScalarOpd(ID, J)) + if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) ScalarCallOps.push_back(ScalarOperands[J]); else ScalarCallOps.push_back(Scattered[J][Elem]); @@ -809,7 +841,7 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) { Value *Res = Op0[CI->getValue().getZExtValue()]; - gather(&EEI, {Res}); + replaceUses(&EEI, Res); return true; } @@ -825,7 +857,7 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { Res = Builder.CreateSelect(ShouldExtract, Elt, Res, EEI.getName() + ".upto" + Twine(I)); } - gather(&EEI, {Res}); + replaceUses(&EEI, Res); return true; } @@ -891,7 +923,7 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) { unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements(); IRBuilder<> Builder(&LI); - Scatterer Ptr = scatter(&LI, LI.getPointerOperand()); + Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), LI.getType()); ValueVector Res; Res.resize(NumElems); @@ -917,7 +949,7 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) { unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements(); IRBuilder<> Builder(&SI); - Scatterer VPtr = scatter(&SI, SI.getPointerOperand()); + Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), FullValue->getType()); Scatterer VVal = scatter(&SI, FullValue); ValueVector Stores; @@ -940,7 +972,7 @@ bool ScalarizerVisitor::visitCallInst(CallInst &CI) { bool ScalarizerVisitor::finish() { // The presence of data in Gathered or Scattered indicates changes // made to the Function. - if (Gathered.empty() && Scattered.empty()) + if (Gathered.empty() && Scattered.empty() && !Scalarized) return false; for (const auto &GMI : Gathered) { Instruction *Op = GMI.first; @@ -971,6 +1003,7 @@ bool ScalarizerVisitor::finish() { } Gathered.clear(); Scattered.clear(); + Scalarized = false; RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); @@ -982,7 +1015,7 @@ PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) unsigned ParallelLoopAccessMDKind = M.getContext().getMDKindID("llvm.mem.parallel_loop_access"); DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); - ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT); + ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, Options); bool Changed = Impl.visit(F); PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); |