| author | Dimitry Andric <dim@FreeBSD.org> | 2020-07-26 19:36:28 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2020-07-26 19:36:28 +0000 |
| commit | cfca06d7963fa0909f90483b42a6d7d194d01e08 | |
| tree | 209fb2a2d68f8f277793fc8df46c753d31bc853b /llvm/lib/Transforms/Scalar | |
| parent | 706b4fc47bbc608932d3b491ae19a3b9cde9497b | |
Diffstat (limited to 'llvm/lib/Transforms/Scalar')
61 files changed, 6579 insertions, 4020 deletions
diff --git a/llvm/lib/Transforms/Scalar/ADCE.cpp b/llvm/lib/Transforms/Scalar/ADCE.cpp index cc3d3bf7cdbf..c3709b9afffb 100644 --- a/llvm/lib/Transforms/Scalar/ADCE.cpp +++ b/llvm/lib/Transforms/Scalar/ADCE.cpp @@ -182,7 +182,7 @@ class AggressiveDeadCodeElimination { /// Identify connected sections of the control flow graph which have /// dead terminators and rewrite the control flow graph to remove them. - void updateDeadRegions(); + bool updateDeadRegions(); /// Set the BlockInfo::PostOrder field based on a post-order /// numbering of the reverse control flow graph. @@ -505,7 +505,7 @@ void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() { //===----------------------------------------------------------------------===// bool AggressiveDeadCodeElimination::removeDeadInstructions() { // Updates control and dataflow around dead blocks - updateDeadRegions(); + bool RegionsUpdated = updateDeadRegions(); LLVM_DEBUG({ for (Instruction &I : instructions(F)) { @@ -556,11 +556,11 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { I->eraseFromParent(); } - return !Worklist.empty(); + return !Worklist.empty() || RegionsUpdated; } // A dead region is the set of dead blocks with a common live post-dominator. -void AggressiveDeadCodeElimination::updateDeadRegions() { +bool AggressiveDeadCodeElimination::updateDeadRegions() { LLVM_DEBUG({ dbgs() << "final dead terminator blocks: " << '\n'; for (auto *BB : BlocksWithDeadTerminators) @@ -570,6 +570,7 @@ void AggressiveDeadCodeElimination::updateDeadRegions() { // Don't compute the post ordering unless we needed it. bool HavePostOrder = false; + bool Changed = false; for (auto *BB : BlocksWithDeadTerminators) { auto &Info = BlockInfo[BB]; @@ -624,7 +625,10 @@ void AggressiveDeadCodeElimination::updateDeadRegions() { .applyUpdates(DeletedEdges); NumBranchesRemoved += 1; + Changed = true; } + + return Changed; } // reverse top-sort order @@ -685,10 +689,14 @@ PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) { return PreservedAnalyses::all(); PreservedAnalyses PA; - PA.preserveSet<CFGAnalyses>(); + // TODO: We could track if we have actually done CFG changes. + if (!RemoveControlFlowFlag) + PA.preserveSet<CFGAnalyses>(); + else { + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<PostDominatorTreeAnalysis>(); + } PA.preserve<GlobalsAA>(); - PA.preserve<DominatorTreeAnalysis>(); - PA.preserve<PostDominatorTreeAnalysis>(); return PA; } diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index 06deaf3c4f9a..bccf94fc217f 100644 --- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -15,6 +15,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/IR/Instructions.h" #include "llvm/InitializePasses.h" #define AA_NAME "alignment-from-assumptions" #define DEBUG_TYPE AA_NAME @@ -30,6 +31,7 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" @@ -90,9 +92,9 @@ FunctionPass *llvm::createAlignmentFromAssumptionsPass() { // to a constant. Using SCEV to compute alignment handles the case where // DiffSCEV is a recurrence with constant start such that the aligned offset // is constant. e.g. {16,+,32} % 32 -> 16. 
-static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV, - const SCEV *AlignSCEV, - ScalarEvolution *SE) { +static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV, + const SCEV *AlignSCEV, + ScalarEvolution *SE) { // DiffUnits = Diff % int64_t(Alignment) const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV); @@ -107,26 +109,30 @@ static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV, // displaced pointer has the same alignment as the aligned pointer, so // return the alignment value. if (!DiffUnits) - return (unsigned) - cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue(); + return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue(); // If the displacement is not an exact multiple, but the remainder is a // constant, then return this remainder (but only if it is a power of 2). uint64_t DiffUnitsAbs = std::abs(DiffUnits); if (isPowerOf2_64(DiffUnitsAbs)) - return (unsigned) DiffUnitsAbs; + return Align(DiffUnitsAbs); } - return 0; + return None; } // There is an address given by an offset OffSCEV from AASCEV which has an // alignment AlignSCEV. Use that information, if possible, to compute a new // alignment for Ptr. -static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, - const SCEV *OffSCEV, Value *Ptr, - ScalarEvolution *SE) { +static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, + const SCEV *OffSCEV, Value *Ptr, + ScalarEvolution *SE) { const SCEV *PtrSCEV = SE->getSCEV(Ptr); + // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes + // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV + // may disagree. Trunc/extend so they agree. + PtrSCEV = SE->getTruncateOrZeroExtend( + PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType())); const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV); // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always @@ -141,13 +147,12 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, << *AlignSCEV << " and offset " << *OffSCEV << " using diff " << *DiffSCEV << "\n"); - unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE); - LLVM_DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n"); + if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) { + LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n"); + return *NewAlignment; + } - if (NewAlignment) { - return NewAlignment; - } else if (const SCEVAddRecExpr *DiffARSCEV = - dyn_cast<SCEVAddRecExpr>(DiffSCEV)) { + if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) { // The relative offset to the alignment assumption did not yield a constant, // but we should try harder: if we assume that a is 32-byte aligned, then in // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are @@ -165,134 +170,67 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, // first iteration, and also the alignment using the per-iteration delta. // If these are the same, then use that answer. Otherwise, use the smaller // one, but only if it divides the larger one. 
- NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE); - unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE); - - LLVM_DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n"); - LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n"); - - if (!NewAlignment || !NewIncAlignment) { - return 0; - } else if (NewAlignment > NewIncAlignment) { - if (NewAlignment % NewIncAlignment == 0) { - LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewIncAlignment - << "\n"); - return NewIncAlignment; - } - } else if (NewIncAlignment > NewAlignment) { - if (NewIncAlignment % NewAlignment == 0) { - LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment - << "\n"); - return NewAlignment; - } - } else if (NewIncAlignment == NewAlignment) { - LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment + MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE); + MaybeAlign NewIncAlignment = + getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE); + + LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment) + << "\n"); + LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment) + << "\n"); + + if (!NewAlignment || !NewIncAlignment) + return Align(1); + + const Align NewAlign = *NewAlignment; + const Align NewIncAlign = *NewIncAlignment; + if (NewAlign > NewIncAlign) { + LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " + << DebugStr(NewIncAlign) << "\n"); + return NewIncAlign; + } + if (NewIncAlign > NewAlign) { + LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign) << "\n"); - return NewAlignment; + return NewAlign; } + assert(NewIncAlign == NewAlign); + LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign) + << "\n"); + return NewAlign; } - return 0; + return Align(1); } bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I, + unsigned Idx, Value *&AAPtr, const SCEV *&AlignSCEV, const SCEV *&OffSCEV) { - // An alignment assume must be a statement about the least-significant - // bits of the pointer being zero, possibly with some offset. - ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0)); - if (!ICI) - return false; - - // This must be an expression of the form: x & m == 0. - if (ICI->getPredicate() != ICmpInst::ICMP_EQ) - return false; - - // Swap things around so that the RHS is 0. - Value *CmpLHS = ICI->getOperand(0); - Value *CmpRHS = ICI->getOperand(1); - const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS); - const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS); - if (CmpLHSSCEV->isZero()) - std::swap(CmpLHS, CmpRHS); - else if (!CmpRHSSCEV->isZero()) + Type *Int64Ty = Type::getInt64Ty(I->getContext()); + OperandBundleUse AlignOB = I->getOperandBundleAt(Idx); + if (AlignOB.getTagName() != "align") return false; - - BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS); - if (!CmpBO || CmpBO->getOpcode() != Instruction::And) - return false; - - // Swap things around so that the right operand of the and is a constant - // (the mask); we cannot deal with variable masks. 
- Value *AndLHS = CmpBO->getOperand(0); - Value *AndRHS = CmpBO->getOperand(1); - const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS); - const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS); - if (isa<SCEVConstant>(AndLHSSCEV)) { - std::swap(AndLHS, AndRHS); - std::swap(AndLHSSCEV, AndRHSSCEV); - } - - const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV); - if (!MaskSCEV) - return false; - - // The mask must have some trailing ones (otherwise the condition is - // trivial and tells us nothing about the alignment of the left operand). - unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes(); - if (!TrailingOnes) - return false; - - // Cap the alignment at the maximum with which LLVM can deal (and make sure - // we don't overflow the shift). - uint64_t Alignment; - TrailingOnes = std::min(TrailingOnes, - unsigned(sizeof(unsigned) * CHAR_BIT - 1)); - Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment); - - Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext()); - AlignSCEV = SE->getConstant(Int64Ty, Alignment); - - // The LHS might be a ptrtoint instruction, or it might be the pointer - // with an offset. - AAPtr = nullptr; - OffSCEV = nullptr; - if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) { - AAPtr = PToI->getPointerOperand(); + assert(AlignOB.Inputs.size() >= 2); + AAPtr = AlignOB.Inputs[0].get(); + // TODO: Consider accumulating the offset to the base. + AAPtr = AAPtr->stripPointerCastsSameRepresentation(); + AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get()); + AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty); + if (AlignOB.Inputs.size() == 3) + OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get()); + else OffSCEV = SE->getZero(Int64Ty); - } else if (const SCEVAddExpr* AndLHSAddSCEV = - dyn_cast<SCEVAddExpr>(AndLHSSCEV)) { - // Try to find the ptrtoint; subtract it and the rest is the offset. - for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(), - JE = AndLHSAddSCEV->op_end(); J != JE; ++J) - if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J)) - if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) { - AAPtr = PToI->getPointerOperand(); - OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J); - break; - } - } - - if (!AAPtr) - return false; - - // Sign extend the offset to 64 bits (so that it is like all of the other - // expressions). - unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits(); - if (OffSCEVBits < 64) - OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty); - else if (OffSCEVBits > 64) - return false; - - AAPtr = AAPtr->stripPointerCasts(); + OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty); return true; } -bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { +bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall, + unsigned Idx) { Value *AAPtr; const SCEV *AlignSCEV, *OffSCEV; - if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV)) + if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV)) return false; // Skip ConstantPointerNull and UndefValue. 
Assumptions on these shouldn't @@ -310,35 +248,38 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { continue; if (Instruction *K = dyn_cast<Instruction>(J)) - if (isValidAssumeForContext(ACall, K, DT)) WorkList.push_back(K); } while (!WorkList.empty()) { Instruction *J = WorkList.pop_back_val(); - if (LoadInst *LI = dyn_cast<LoadInst>(J)) { - unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, - LI->getPointerOperand(), SE); - - if (NewAlignment > LI->getAlignment()) { - LI->setAlignment(MaybeAlign(NewAlignment)); + if (!isValidAssumeForContext(ACall, J, DT)) + continue; + Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + LI->getPointerOperand(), SE); + if (NewAlignment > LI->getAlign()) { + LI->setAlignment(NewAlignment); ++NumLoadAlignChanged; } } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) { - unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, - SI->getPointerOperand(), SE); - - if (NewAlignment > SI->getAlignment()) { - SI->setAlignment(MaybeAlign(NewAlignment)); + if (!isValidAssumeForContext(ACall, J, DT)) + continue; + Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + SI->getPointerOperand(), SE); + if (NewAlignment > SI->getAlign()) { + SI->setAlignment(NewAlignment); ++NumStoreAlignChanged; } } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) { - unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, - MI->getDest(), SE); - - LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";); - if (NewDestAlignment > MI->getDestAlignment()) { + if (!isValidAssumeForContext(ACall, J, DT)) + continue; + Align NewDestAlignment = + getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE); + + LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment) + << "\n";); + if (NewDestAlignment > *MI->getDestAlign()) { MI->setDestAlignment(NewDestAlignment); ++NumMemIntAlignChanged; } @@ -346,12 +287,13 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { // For memory transfers, there is also a source alignment that // can be set. 
if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { - unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, - MTI->getSource(), SE); + Align NewSrcAlignment = + getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE); - LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";); + LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment) + << "\n";); - if (NewSrcAlignment > MTI->getSourceAlignment()) { + if (NewSrcAlignment > *MTI->getSourceAlign()) { MTI->setSourceAlignment(NewSrcAlignment); ++NumMemIntAlignChanged; } @@ -363,7 +305,7 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { Visited.insert(J); for (User *UJ : J->users()) { Instruction *K = cast<Instruction>(UJ); - if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT)) + if (!Visited.count(K)) WorkList.push_back(K); } } @@ -390,8 +332,11 @@ bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC, bool Changed = false; for (auto &AssumeVH : AC.assumptions()) - if (AssumeVH) - Changed |= processAssumption(cast<CallInst>(AssumeVH)); + if (AssumeVH) { + CallInst *Call = cast<CallInst>(AssumeVH); + for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++) + Changed |= processAssumption(Call, Idx); + } return Changed; } diff --git a/llvm/lib/Transforms/Scalar/BDCE.cpp b/llvm/lib/Transforms/Scalar/BDCE.cpp index 0fa38fa80b17..767c7656dcfa 100644 --- a/llvm/lib/Transforms/Scalar/BDCE.cpp +++ b/llvm/lib/Transforms/Scalar/BDCE.cpp @@ -9,7 +9,8 @@ // This file implements the Bit-Tracking Dead Code Elimination pass. Some // instructions (shifts, some ands, ors, etc.) kill some of their input bits. // We track these dead bits and remove instructions that compute only these -// dead bits. +// dead bits. We also simplify sext that generates unused extension bits, +// converting it to a zext. 
// //===----------------------------------------------------------------------===// @@ -19,6 +20,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/InitializePasses.h" @@ -33,6 +35,8 @@ using namespace llvm; STATISTIC(NumRemoved, "Number of instructions removed (unused)"); STATISTIC(NumSimplified, "Number of instructions trivialized (dead bits)"); +STATISTIC(NumSExt2ZExt, + "Number of sign extension instructions converted to zero extension"); /// If an instruction is trivialized (dead), then the chain of users of that /// instruction may need to be cleared of assumptions that can no longer be @@ -102,13 +106,31 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { (I.getType()->isIntOrIntVectorTy() && DB.getDemandedBits(&I).isNullValue() && wouldInstructionBeTriviallyDead(&I))) { - salvageDebugInfoOrMarkUndef(I); + salvageDebugInfo(I); Worklist.push_back(&I); I.dropAllReferences(); Changed = true; continue; } + // Convert SExt into ZExt if none of the extension bits is required + if (SExtInst *SE = dyn_cast<SExtInst>(&I)) { + APInt Demanded = DB.getDemandedBits(SE); + const uint32_t SrcBitSize = SE->getSrcTy()->getScalarSizeInBits(); + auto *const DstTy = SE->getDestTy(); + const uint32_t DestBitSize = DstTy->getScalarSizeInBits(); + if (Demanded.countLeadingZeros() >= (DestBitSize - SrcBitSize)) { + clearAssumptionsOfUsers(SE, DB); + IRBuilder<> Builder(SE); + I.replaceAllUsesWith( + Builder.CreateZExt(SE->getOperand(0), DstTy, SE->getName())); + Worklist.push_back(SE); + Changed = true; + NumSExt2ZExt++; + continue; + } + } + for (Use &U : I.operands()) { // DemandedBits only detects dead integer uses. if (!U->getType()->isIntOrIntVectorTy()) diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp index e34c011b1c87..b26bd1114bd4 100644 --- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -85,37 +85,36 @@ static cl::opt<unsigned> "their cost is below DuplicationThreshold"), cl::init(5)); -static void addNonNullAttribute(CallSite CS, Value *Op) { +static void addNonNullAttribute(CallBase &CB, Value *Op) { unsigned ArgNo = 0; - for (auto &I : CS.args()) { + for (auto &I : CB.args()) { if (&*I == Op) - CS.addParamAttr(ArgNo, Attribute::NonNull); + CB.addParamAttr(ArgNo, Attribute::NonNull); ++ArgNo; } } -static void setConstantInArgument(CallSite CS, Value *Op, +static void setConstantInArgument(CallBase &CB, Value *Op, Constant *ConstValue) { unsigned ArgNo = 0; - for (auto &I : CS.args()) { + for (auto &I : CB.args()) { if (&*I == Op) { // It is possible we have already added the non-null attribute to the // parameter by using an earlier constraining condition. 
- CS.removeParamAttr(ArgNo, Attribute::NonNull); - CS.setArgument(ArgNo, ConstValue); + CB.removeParamAttr(ArgNo, Attribute::NonNull); + CB.setArgOperand(ArgNo, ConstValue); } ++ArgNo; } } -static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) { +static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallBase &CB) { assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand."); Value *Op0 = Cmp->getOperand(0); unsigned ArgNo = 0; - for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E; - ++I, ++ArgNo) { + for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I, ++ArgNo) { // Don't consider constant or arguments that are already known non-null. - if (isa<Constant>(*I) || CS.paramHasAttr(ArgNo, Attribute::NonNull)) + if (isa<Constant>(*I) || CB.paramHasAttr(ArgNo, Attribute::NonNull)) continue; if (*I == Op0) @@ -128,8 +127,8 @@ typedef std::pair<ICmpInst *, unsigned> ConditionTy; typedef SmallVector<ConditionTy, 2> ConditionsTy; /// If From has a conditional jump to To, add the condition to Conditions, -/// if it is relevant to any argument at CS. -static void recordCondition(CallSite CS, BasicBlock *From, BasicBlock *To, +/// if it is relevant to any argument at CB. +static void recordCondition(CallBase &CB, BasicBlock *From, BasicBlock *To, ConditionsTy &Conditions) { auto *BI = dyn_cast<BranchInst>(From->getTerminator()); if (!BI || !BI->isConditional()) @@ -142,38 +141,38 @@ static void recordCondition(CallSite CS, BasicBlock *From, BasicBlock *To, ICmpInst *Cmp = cast<ICmpInst>(Cond); if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) - if (isCondRelevantToAnyCallArgument(Cmp, CS)) + if (isCondRelevantToAnyCallArgument(Cmp, CB)) Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To ? Pred : Cmp->getInversePredicate()}); } -/// Record ICmp conditions relevant to any argument in CS following Pred's +/// Record ICmp conditions relevant to any argument in CB following Pred's /// single predecessors. If there are conflicting conditions along a path, like /// x == 1 and x == 0, the first condition will be used. We stop once we reach /// an edge to StopAt. 
-static void recordConditions(CallSite CS, BasicBlock *Pred, +static void recordConditions(CallBase &CB, BasicBlock *Pred, ConditionsTy &Conditions, BasicBlock *StopAt) { BasicBlock *From = Pred; BasicBlock *To = Pred; SmallPtrSet<BasicBlock *, 4> Visited; while (To != StopAt && !Visited.count(From->getSinglePredecessor()) && (From = From->getSinglePredecessor())) { - recordCondition(CS, From, To, Conditions); + recordCondition(CB, From, To, Conditions); Visited.insert(From); To = From; } } -static void addConditions(CallSite CS, const ConditionsTy &Conditions) { +static void addConditions(CallBase &CB, const ConditionsTy &Conditions) { for (auto &Cond : Conditions) { Value *Arg = Cond.first->getOperand(0); Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1)); if (Cond.second == ICmpInst::ICMP_EQ) - setConstantInArgument(CS, Arg, ConstVal); + setConstantInArgument(CB, Arg, ConstVal); else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) { assert(Cond.second == ICmpInst::ICMP_NE); - addNonNullAttribute(CS, Arg); + addNonNullAttribute(CB, Arg); } } } @@ -184,17 +183,16 @@ static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) { return Preds; } -static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) { - if (CS.isConvergent() || CS.cannotDuplicate()) +static bool canSplitCallSite(CallBase &CB, TargetTransformInfo &TTI) { + if (CB.isConvergent() || CB.cannotDuplicate()) return false; // FIXME: As of now we handle only CallInst. InvokeInst could be handled // without too much effort. - Instruction *Instr = CS.getInstruction(); - if (!isa<CallInst>(Instr)) + if (!isa<CallInst>(CB)) return false; - BasicBlock *CallSiteBB = Instr->getParent(); + BasicBlock *CallSiteBB = CB.getParent(); // Need 2 predecessors and cannot split an edge from an IndirectBrInst. SmallVector<BasicBlock *, 2> Preds(predecessors(CallSiteBB)); if (Preds.size() != 2 || isa<IndirectBrInst>(Preds[0]->getTerminator()) || @@ -212,7 +210,7 @@ static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) { // corresponding uses will be updated. unsigned Cost = 0; for (auto &InstBeforeCall : - llvm::make_range(CallSiteBB->begin(), Instr->getIterator())) { + llvm::make_range(CallSiteBB->begin(), CB.getIterator())) { Cost += TTI.getInstructionCost(&InstBeforeCall, TargetTransformInfo::TCK_CodeSize); if (Cost >= DuplicationThreshold) @@ -304,24 +302,23 @@ static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI, /// predecessors, new call-sites with more constrained arguments will be /// created in createCallSitesOnPredicatedArgument(). static void splitCallSite( - CallSite CS, + CallBase &CB, const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds, DomTreeUpdater &DTU) { - Instruction *Instr = CS.getInstruction(); - BasicBlock *TailBB = Instr->getParent(); - bool IsMustTailCall = CS.isMustTailCall(); + BasicBlock *TailBB = CB.getParent(); + bool IsMustTailCall = CB.isMustTailCall(); PHINode *CallPN = nullptr; // `musttail` calls must be followed by optional `bitcast`, and `ret`. The // split blocks will be terminated right after that so there're no users for // this phi in a `TailBB`. 
- if (!IsMustTailCall && !Instr->use_empty()) { - CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call"); - CallPN->setDebugLoc(Instr->getDebugLoc()); + if (!IsMustTailCall && !CB.use_empty()) { + CallPN = PHINode::Create(CB.getType(), Preds.size(), "phi.call"); + CallPN->setDebugLoc(CB.getDebugLoc()); } - LLVM_DEBUG(dbgs() << "split call-site : " << *Instr << " into \n"); + LLVM_DEBUG(dbgs() << "split call-site : " << CB << " into \n"); assert(Preds.size() == 2 && "The ValueToValueMaps array has size 2."); // ValueToValueMapTy is neither copy nor moveable, so we use a simple array @@ -330,21 +327,20 @@ static void splitCallSite( for (unsigned i = 0; i < Preds.size(); i++) { BasicBlock *PredBB = Preds[i].first; BasicBlock *SplitBlock = DuplicateInstructionsInSplitBetween( - TailBB, PredBB, &*std::next(Instr->getIterator()), ValueToValueMaps[i], + TailBB, PredBB, &*std::next(CB.getIterator()), ValueToValueMaps[i], DTU); assert(SplitBlock && "Unexpected new basic block split."); - Instruction *NewCI = - &*std::prev(SplitBlock->getTerminator()->getIterator()); - CallSite NewCS(NewCI); - addConditions(NewCS, Preds[i].second); + auto *NewCI = + cast<CallBase>(&*std::prev(SplitBlock->getTerminator()->getIterator())); + addConditions(*NewCI, Preds[i].second); // Handle PHIs used as arguments in the call-site. for (PHINode &PN : TailBB->phis()) { unsigned ArgNo = 0; - for (auto &CI : CS.args()) { + for (auto &CI : CB.args()) { if (&*CI == &PN) { - NewCS.setArgument(ArgNo, PN.getIncomingValueForBlock(SplitBlock)); + NewCI->setArgOperand(ArgNo, PN.getIncomingValueForBlock(SplitBlock)); } ++ArgNo; } @@ -356,7 +352,7 @@ static void splitCallSite( // Clone and place bitcast and return instructions before `TI` if (IsMustTailCall) - copyMustTailReturn(SplitBlock, Instr, NewCI); + copyMustTailReturn(SplitBlock, &CB, NewCI); } NumCallSiteSplit++; @@ -383,7 +379,7 @@ static void splitCallSite( // Replace users of the original call with a PHI mering call-sites split. if (CallPN) { CallPN->insertBefore(OriginalBegin); - Instr->replaceAllUsesWith(CallPN); + CB.replaceAllUsesWith(CallPN); } // Remove instructions moved to split blocks from TailBB, from the duplicated @@ -393,7 +389,7 @@ static void splitCallSite( // instruction, so we do not end up deleting them. By using reverse-order, we // do not introduce unnecessary PHI nodes for def-use chains from the call // instruction to the beginning of the block. - auto I = Instr->getReverseIterator(); + auto I = CB.getReverseIterator(); while (I != TailBB->rend()) { Instruction *CurrentI = &*I++; if (!CurrentI->use_empty()) { @@ -418,28 +414,25 @@ static void splitCallSite( // Return true if the call-site has an argument which is a PHI with only // constant incoming values. 
-static bool isPredicatedOnPHI(CallSite CS) { - Instruction *Instr = CS.getInstruction(); - BasicBlock *Parent = Instr->getParent(); - if (Instr != Parent->getFirstNonPHIOrDbg()) +static bool isPredicatedOnPHI(CallBase &CB) { + BasicBlock *Parent = CB.getParent(); + if (&CB != Parent->getFirstNonPHIOrDbg()) return false; - for (auto &BI : *Parent) { - if (PHINode *PN = dyn_cast<PHINode>(&BI)) { - for (auto &I : CS.args()) - if (&*I == PN) { - assert(PN->getNumIncomingValues() == 2 && - "Unexpected number of incoming values"); - if (PN->getIncomingBlock(0) == PN->getIncomingBlock(1)) - return false; - if (PN->getIncomingValue(0) == PN->getIncomingValue(1)) - continue; - if (isa<Constant>(PN->getIncomingValue(0)) && - isa<Constant>(PN->getIncomingValue(1))) - return true; - } + for (auto &PN : Parent->phis()) { + for (auto &Arg : CB.args()) { + if (&*Arg != &PN) + continue; + assert(PN.getNumIncomingValues() == 2 && + "Unexpected number of incoming values"); + if (PN.getIncomingBlock(0) == PN.getIncomingBlock(1)) + return false; + if (PN.getIncomingValue(0) == PN.getIncomingValue(1)) + continue; + if (isa<Constant>(PN.getIncomingValue(0)) && + isa<Constant>(PN.getIncomingValue(1))) + return true; } - break; } return false; } @@ -448,20 +441,20 @@ using PredsWithCondsTy = SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2>; // Check if any of the arguments in CS are predicated on a PHI node and return // the set of predecessors we should use for splitting. -static PredsWithCondsTy shouldSplitOnPHIPredicatedArgument(CallSite CS) { - if (!isPredicatedOnPHI(CS)) +static PredsWithCondsTy shouldSplitOnPHIPredicatedArgument(CallBase &CB) { + if (!isPredicatedOnPHI(CB)) return {}; - auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); + auto Preds = getTwoPredecessors(CB.getParent()); return {{Preds[0], {}}, {Preds[1], {}}}; } // Checks if any of the arguments in CS are predicated in a predecessor and // returns a list of predecessors with the conditions that hold on their edges // to CS. -static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS, +static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallBase &CB, DomTreeUpdater &DTU) { - auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); + auto Preds = getTwoPredecessors(CB.getParent()); if (Preds[0] == Preds[1]) return {}; @@ -470,16 +463,16 @@ static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS, // that node will be the same for all paths to the call site and splitting // is not beneficial. assert(DTU.hasDomTree() && "We need a DTU with a valid DT!"); - auto *CSDTNode = DTU.getDomTree().getNode(CS.getInstruction()->getParent()); + auto *CSDTNode = DTU.getDomTree().getNode(CB.getParent()); BasicBlock *StopAt = CSDTNode ? CSDTNode->getIDom()->getBlock() : nullptr; SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2> PredsCS; for (auto *Pred : make_range(Preds.rbegin(), Preds.rend())) { ConditionsTy Conditions; // Record condition on edge BB(CS) <- Pred - recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions); + recordCondition(CB, Pred, CB.getParent(), Conditions); // Record conditions following Pred's single predecessors. 
- recordConditions(CS, Pred, Conditions, StopAt); + recordConditions(CB, Pred, Conditions, StopAt); PredsCS.push_back({Pred, Conditions}); } @@ -491,19 +484,19 @@ static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS, return PredsCS; } -static bool tryToSplitCallSite(CallSite CS, TargetTransformInfo &TTI, +static bool tryToSplitCallSite(CallBase &CB, TargetTransformInfo &TTI, DomTreeUpdater &DTU) { // Check if we can split the call site. - if (!CS.arg_size() || !canSplitCallSite(CS, TTI)) + if (!CB.arg_size() || !canSplitCallSite(CB, TTI)) return false; - auto PredsWithConds = shouldSplitOnPredicatedArgument(CS, DTU); + auto PredsWithConds = shouldSplitOnPredicatedArgument(CB, DTU); if (PredsWithConds.empty()) - PredsWithConds = shouldSplitOnPHIPredicatedArgument(CS); + PredsWithConds = shouldSplitOnPHIPredicatedArgument(CB); if (PredsWithConds.empty()) return false; - splitCallSite(CS, PredsWithConds, DTU); + splitCallSite(CB, PredsWithConds, DTU); return true; } @@ -521,20 +514,19 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI, // case, IE will be invalidated and we also have to check the current // terminator. while (II != IE && &*II != BB.getTerminator()) { - Instruction *I = &*II++; - CallSite CS(cast<Value>(I)); - if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI)) + CallBase *CB = dyn_cast<CallBase>(&*II++); + if (!CB || isa<IntrinsicInst>(CB) || isInstructionTriviallyDead(CB, &TLI)) continue; - Function *Callee = CS.getCalledFunction(); + Function *Callee = CB->getCalledFunction(); if (!Callee || Callee->isDeclaration()) continue; // Successful musttail call-site splits result in erased CI and erased BB. // Check if such path is possible before attempting the splitting. - bool IsMustTail = CS.isMustTailCall(); + bool IsMustTail = CB->isMustTailCall(); - Changed |= tryToSplitCallSite(CS, TTI, DTU); + Changed |= tryToSplitCallSite(*CB, TTI, DTU); // There're no interesting instructions after this. The call site // itself might have been erased on splitting. diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp index 5bfece010bec..7c14b69d658d 100644 --- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -250,7 +250,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, Orders.push_back(Entry); while (Idx != Orders.size()) { BasicBlock *Node = Orders[Idx++]; - for (auto ChildDomNode : DT.getNode(Node)->getChildren()) { + for (auto ChildDomNode : DT.getNode(Node)->children()) { if (Candidates.count(ChildDomNode->getBlock())) Orders.push_back(ChildDomNode->getBlock()); } @@ -363,10 +363,12 @@ void ConstantHoistingPass::collectConstantCandidates( // instruction and operand index. if (auto IntrInst = dyn_cast<IntrinsicInst>(Inst)) Cost = TTI->getIntImmCostIntrin(IntrInst->getIntrinsicID(), Idx, - ConstInt->getValue(), ConstInt->getType()); + ConstInt->getValue(), ConstInt->getType(), + TargetTransformInfo::TCK_SizeAndLatency); else Cost = TTI->getIntImmCostInst(Inst->getOpcode(), Idx, ConstInt->getValue(), - ConstInt->getType()); + ConstInt->getType(), + TargetTransformInfo::TCK_SizeAndLatency); // Ignore cheap integer constants. if (Cost > TargetTransformInfo::TCC_Basic) { @@ -416,7 +418,8 @@ void ConstantHoistingPass::collectConstantCandidates( // usually lowered to a load from constant pool. 
Such operation is unlikely // to be cheaper than compute it by <Base + Offset>, which can be lowered to // an ADD instruction or folded into Load/Store instruction. - int Cost = TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy); + int Cost = TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy, + TargetTransformInfo::TCK_SizeAndLatency); ConstCandVecType &ExprCandVec = ConstGEPCandMap[BaseGV]; ConstCandMapType::iterator Itr; bool Inserted; @@ -491,7 +494,7 @@ void ConstantHoistingPass::collectConstantCandidates( // take constant variables is lower than `TargetTransformInfo::TCC_Basic`. // So it's safe for us to collect constant candidates from all // IntrinsicInsts. - if (canReplaceOperandWithVariable(Inst, Idx) || isa<IntrinsicInst>(Inst)) { + if (canReplaceOperandWithVariable(Inst, Idx)) { collectConstantCandidates(ConstCandMap, Inst, Idx); } } // end of for all operands @@ -582,7 +585,8 @@ ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S, for (auto User : ConstCand->Uses) { unsigned Opcode = User.Inst->getOpcode(); unsigned OpndIdx = User.OpndIdx; - Cost += TTI->getIntImmCostInst(Opcode, OpndIdx, Value, Ty); + Cost += TTI->getIntImmCostInst(Opcode, OpndIdx, Value, Ty, + TargetTransformInfo::TCK_SizeAndLatency); LLVM_DEBUG(dbgs() << "Cost: " << Cost << "\n"); for (auto C2 = S; C2 != E; ++C2) { @@ -975,8 +979,8 @@ PreservedAnalyses ConstantHoistingPass::run(Function &F, auto BFI = ConstHoistWithBlockFrequency ? &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; - auto &MAM = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); - auto *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); + auto *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock(), PSI)) return PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 3435bc7f5eaa..cd2f4ca36f3b 100644 --- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -22,7 +22,6 @@ #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" @@ -125,7 +124,7 @@ Pass *llvm::createCorrelatedValuePropagationPass() { static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { if (S->getType()->isVectorTy()) return false; - if (isa<Constant>(S->getOperand(0))) return false; + if (isa<Constant>(S->getCondition())) return false; Constant *C = LVI->getConstant(S->getCondition(), S->getParent(), S); if (!C) return false; @@ -133,11 +132,7 @@ static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { ConstantInt *CI = dyn_cast<ConstantInt>(C); if (!CI) return false; - Value *ReplaceWith = S->getTrueValue(); - Value *Other = S->getFalseValue(); - if (!CI->isOne()) std::swap(ReplaceWith, Other); - if (ReplaceWith == S) ReplaceWith = UndefValue::get(S->getType()); - + Value *ReplaceWith = CI->isOne() ? S->getTrueValue() : S->getFalseValue(); S->replaceAllUsesWith(ReplaceWith); S->eraseFromParent(); @@ -310,9 +305,10 @@ static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) { // the comparison is testing local values. While LVI can sometimes reason // about such cases, it's not its primary purpose. 
We do make sure to do // the block local query for uses from terminator instructions, but that's - // handled in the code for each terminator. + // handled in the code for each terminator. As an exception, we allow phi + // nodes, for which LVI can thread the condition into predecessors. auto *I = dyn_cast<Instruction>(Op0); - if (I && I->getParent() == Cmp->getParent()) + if (I && I->getParent() == Cmp->getParent() && !isa<PHINode>(I)) return false; LazyValueInfo::Tristate Result = @@ -535,18 +531,18 @@ static void processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) { } /// Infer nonnull attributes for the arguments at the specified callsite. -static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { +static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) { SmallVector<unsigned, 4> ArgNos; unsigned ArgNo = 0; - if (auto *WO = dyn_cast<WithOverflowInst>(CS.getInstruction())) { + if (auto *WO = dyn_cast<WithOverflowInst>(&CB)) { if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) { processOverflowIntrinsic(WO, LVI); return true; } } - if (auto *SI = dyn_cast<SaturatingInst>(CS.getInstruction())) { + if (auto *SI = dyn_cast<SaturatingInst>(&CB)) { if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) { processSaturatingInst(SI, LVI); return true; @@ -559,8 +555,8 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { // desireable since it may allow further optimization of that value (e.g. via // single use rules in instcombine). Since deopt uses tend to, // idiomatically, appear along rare conditional paths, it's reasonable likely - // we may have a conditional fact with which LVI can fold. - if (auto DeoptBundle = CS.getOperandBundle(LLVMContext::OB_deopt)) { + // we may have a conditional fact with which LVI can fold. + if (auto DeoptBundle = CB.getOperandBundle(LLVMContext::OB_deopt)) { bool Progress = false; for (const Use &ConstU : DeoptBundle->Inputs) { Use &U = const_cast<Use&>(ConstU); @@ -568,7 +564,7 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { if (V->getType()->isVectorTy()) continue; if (isa<Constant>(V)) continue; - Constant *C = LVI->getConstant(V, CS.getParent(), CS.getInstruction()); + Constant *C = LVI->getConstant(V, CB.getParent(), &CB); if (!C) continue; U.set(C); Progress = true; @@ -577,30 +573,30 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { return true; } - for (Value *V : CS.args()) { + for (Value *V : CB.args()) { PointerType *Type = dyn_cast<PointerType>(V->getType()); // Try to mark pointer typed parameters as non-null. We skip the // relatively expensive analysis for constants which are obviously either // null or non-null to start with. 
- if (Type && !CS.paramHasAttr(ArgNo, Attribute::NonNull) && + if (Type && !CB.paramHasAttr(ArgNo, Attribute::NonNull) && !isa<Constant>(V) && LVI->getPredicateAt(ICmpInst::ICMP_EQ, V, ConstantPointerNull::get(Type), - CS.getInstruction()) == LazyValueInfo::False) + &CB) == LazyValueInfo::False) ArgNos.push_back(ArgNo); ArgNo++; } - assert(ArgNo == CS.arg_size() && "sanity check"); + assert(ArgNo == CB.arg_size() && "sanity check"); if (ArgNos.empty()) return false; - AttributeList AS = CS.getAttributes(); - LLVMContext &Ctx = CS.getInstruction()->getContext(); + AttributeList AS = CB.getAttributes(); + LLVMContext &Ctx = CB.getContext(); AS = AS.addParamAttribute(Ctx, ArgNos, Attribute::get(Ctx, Attribute::NonNull)); - CS.setAttributes(AS); + CB.setAttributes(AS); return true; } @@ -793,7 +789,10 @@ static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) { if (!RHS || !RHS->getValue().isMask()) return false; - ConstantRange LRange = LVI->getConstantRange(LHS, BB, BinOp); + // We can only replace the AND with LHS based on range info if the range does + // not include undef. + ConstantRange LRange = + LVI->getConstantRange(LHS, BB, BinOp, /*UndefAllowed=*/false); if (!LRange.getUnsignedMax().ule(RHS->getValue())) return false; @@ -856,7 +855,7 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT, break; case Instruction::Call: case Instruction::Invoke: - BBChanged |= processCallSite(CallSite(II), LVI); + BBChanged |= processCallSite(cast<CallBase>(*II), LVI); break; case Instruction::SRem: BBChanged |= processSRem(cast<BinaryOperator>(II), LVI); diff --git a/llvm/lib/Transforms/Scalar/DCE.cpp b/llvm/lib/Transforms/Scalar/DCE.cpp index a4b0c8df98f6..28947482e303 100644 --- a/llvm/lib/Transforms/Scalar/DCE.cpp +++ b/llvm/lib/Transforms/Scalar/DCE.cpp @@ -25,6 +25,7 @@ #include "llvm/Pass.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -127,6 +128,7 @@ static bool DCEInstruction(Instruction *I, return false; salvageDebugInfo(*I); + salvageKnowledge(I); // Null out all of the instruction's operands to see if any operand becomes // dead as we go. 
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index 1ba4aab999e1..e58db03225ee 100644 --- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -29,17 +30,19 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" -#include "llvm/Analysis/OrderedBasicBlock.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" @@ -48,16 +51,19 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> @@ -68,14 +74,23 @@ #include <utility> using namespace llvm; +using namespace PatternMatch; #define DEBUG_TYPE "dse" +STATISTIC(NumRemainingStores, "Number of stores remaining after DSE"); STATISTIC(NumRedundantStores, "Number of redundant stores deleted"); STATISTIC(NumFastStores, "Number of stores deleted"); STATISTIC(NumFastOther, "Number of other instrs removed"); STATISTIC(NumCompletePartials, "Number of stores dead by later partials"); STATISTIC(NumModifiedStores, "Number of stores modified"); +STATISTIC(NumNoopStores, "Number of noop stores deleted"); +STATISTIC(NumCFGChecks, "Number of stores modified"); +STATISTIC(NumCFGTries, "Number of stores modified"); +STATISTIC(NumCFGSuccess, "Number of stores modified"); + +DEBUG_COUNTER(MemorySSACounter, "dse-memoryssa", + "Controls which MemoryDefs are eliminated."); static cl::opt<bool> EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking", @@ -87,6 +102,25 @@ EnablePartialStoreMerging("enable-dse-partial-store-merging", cl::init(true), cl::Hidden, cl::desc("Enable partial store merging in DSE")); +static cl::opt<bool> + EnableMemorySSA("enable-dse-memoryssa", cl::init(false), cl::Hidden, + cl::desc("Use the new MemorySSA-backed DSE.")); + +static cl::opt<unsigned> + MemorySSAScanLimit("dse-memoryssa-scanlimit", cl::init(100), cl::Hidden, + cl::desc("The number of memory instructions to scan for " + "dead store elimination (default = 100)")); + +static cl::opt<unsigned> MemorySSADefsPerBlockLimit( + "dse-memoryssa-defs-per-block-limit", cl::init(5000), cl::Hidden, + 
cl::desc("The number of MemoryDefs we consider as candidates to eliminated " + "other stores per basic block (default = 5000)")); + +static cl::opt<unsigned> MemorySSAPathCheckLimit( + "dse-memoryssa-path-check-limit", cl::init(50), cl::Hidden, + cl::desc("The maximum number of blocks to check when trying to prove that " + "all paths to an exit go through a killing block (default = 50)")); + //===----------------------------------------------------------------------===// // Helper functions //===----------------------------------------------------------------------===// @@ -100,7 +134,7 @@ using InstOverlapIntervalsTy = DenseMap<Instruction *, OverlapIntervalsTy>; static void deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, MemoryDependenceResults &MD, const TargetLibraryInfo &TLI, - InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB, + InstOverlapIntervalsTy &IOL, MapVector<Instruction *, bool> &ThrowableInst, SmallSetVector<const Value *, 16> *ValueSet = nullptr) { SmallVector<Instruction*, 32> NowDeadInsts; @@ -123,6 +157,7 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, // Try to preserve debug information attached to the dead instruction. salvageDebugInfo(*DeadInst); + salvageKnowledge(DeadInst); // This instruction is dead, zap it, in stages. Start by removing it from // MemDep, which needs to know the operands and needs it to be in the @@ -143,7 +178,6 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, if (ValueSet) ValueSet->remove(DeadInst); IOL.erase(DeadInst); - OBB.eraseInstruction(DeadInst); if (NewIter == DeadInst->getIterator()) NewIter = DeadInst->eraseFromParent(); @@ -177,19 +211,17 @@ static bool hasAnalyzableMemoryWrite(Instruction *I, return true; } } - if (auto CS = CallSite(I)) { - if (Function *F = CS.getCalledFunction()) { - LibFunc LF; - if (TLI.getLibFunc(*F, LF) && TLI.has(LF)) { - switch (LF) { - case LibFunc_strcpy: - case LibFunc_strncpy: - case LibFunc_strcat: - case LibFunc_strncat: - return true; - default: - return false; - } + if (auto *CB = dyn_cast<CallBase>(I)) { + LibFunc LF; + if (TLI.getLibFunc(*CB, LF) && TLI.has(LF)) { + switch (LF) { + case LibFunc_strcpy: + case LibFunc_strncpy: + case LibFunc_strcat: + case LibFunc_strncat: + return true; + default: + return false; } } } @@ -222,10 +254,10 @@ static MemoryLocation getLocForWrite(Instruction *Inst) { } } } - if (auto CS = CallSite(Inst)) + if (auto *CB = dyn_cast<CallBase>(Inst)) // All the supported TLI functions so far happen to have dest as their // first argument. - return MemoryLocation(CS.getArgument(0)); + return MemoryLocation(CB->getArgOperand(0)); return MemoryLocation(); } @@ -272,8 +304,8 @@ static bool isRemovable(Instruction *I) { } // note: only get here for calls with analyzable writes - i.e. libcalls - if (auto CS = CallSite(I)) - return CS.getInstruction()->use_empty(); + if (auto *CB = dyn_cast<CallBase>(I)) + return CB->use_empty(); return false; } @@ -597,51 +629,82 @@ static bool isPossibleSelfRead(Instruction *Inst, /// instruction. static bool memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI, - AliasAnalysis *AA) { - SmallVector<BasicBlock *, 16> WorkList; - SmallPtrSet<BasicBlock *, 8> Visited; + AliasAnalysis *AA, + const DataLayout &DL, + DominatorTree *DT) { + // Do a backwards scan through the CFG from SecondI to FirstI. Look for + // instructions which can modify the memory location accessed by SecondI. + // + // While doing the walk keep track of the address to check. 
It might be + // different in different basic blocks due to PHI translation. + using BlockAddressPair = std::pair<BasicBlock *, PHITransAddr>; + SmallVector<BlockAddressPair, 16> WorkList; + // Keep track of the address we visited each block with. Bail out if we + // visit a block with different addresses. + DenseMap<BasicBlock *, Value *> Visited; + BasicBlock::iterator FirstBBI(FirstI); ++FirstBBI; BasicBlock::iterator SecondBBI(SecondI); BasicBlock *FirstBB = FirstI->getParent(); BasicBlock *SecondBB = SecondI->getParent(); MemoryLocation MemLoc = MemoryLocation::get(SecondI); + auto *MemLocPtr = const_cast<Value *>(MemLoc.Ptr); - // Start checking the store-block. - WorkList.push_back(SecondBB); + // Start checking the SecondBB. + WorkList.push_back( + std::make_pair(SecondBB, PHITransAddr(MemLocPtr, DL, nullptr))); bool isFirstBlock = true; - // Check all blocks going backward until we reach the load-block. + // Check all blocks going backward until we reach the FirstBB. while (!WorkList.empty()) { - BasicBlock *B = WorkList.pop_back_val(); + BlockAddressPair Current = WorkList.pop_back_val(); + BasicBlock *B = Current.first; + PHITransAddr &Addr = Current.second; + Value *Ptr = Addr.getAddr(); - // Ignore instructions before LI if this is the FirstBB. + // Ignore instructions before FirstI if this is the FirstBB. BasicBlock::iterator BI = (B == FirstBB ? FirstBBI : B->begin()); BasicBlock::iterator EI; if (isFirstBlock) { - // Ignore instructions after SI if this is the first visit of SecondBB. + // Ignore instructions after SecondI if this is the first visit of SecondBB. assert(B == SecondBB && "first block is not the store block"); EI = SecondBBI; isFirstBlock = false; } else { // It's not SecondBB or (in case of a loop) the second visit of SecondBB. - // In this case we also have to look at instructions after SI. + // In this case we also have to look at instructions after SecondI. EI = B->end(); } for (; BI != EI; ++BI) { Instruction *I = &*BI; if (I->mayWriteToMemory() && I != SecondI) - if (isModSet(AA->getModRefInfo(I, MemLoc))) + if (isModSet(AA->getModRefInfo(I, MemLoc.getWithNewPtr(Ptr)))) return false; } if (B != FirstBB) { assert(B != &FirstBB->getParent()->getEntryBlock() && "Should not hit the entry block because SI must be dominated by LI"); for (auto PredI = pred_begin(B), PE = pred_end(B); PredI != PE; ++PredI) { - if (!Visited.insert(*PredI).second) + PHITransAddr PredAddr = Addr; + if (PredAddr.NeedsPHITranslationFromBlock(B)) { + if (!PredAddr.IsPotentiallyPHITranslatable()) + return false; + if (PredAddr.PHITranslateValue(B, *PredI, DT, false)) + return false; + } + Value *TranslatedPtr = PredAddr.getAddr(); + auto Inserted = Visited.insert(std::make_pair(*PredI, TranslatedPtr)); + if (!Inserted.second) { + // We already visited this block before. If it was with a different + // address - bail out! + if (TranslatedPtr != Inserted.first->second) + return false; + // ... otherwise just skip it. 
continue; - WorkList.push_back(*PredI); + } + WorkList.push_back(std::make_pair(*PredI, PredAddr)); } } } @@ -669,7 +732,7 @@ static void findUnconditionalPreds(SmallVectorImpl<BasicBlock *> &Blocks, static bool handleFree(CallInst *F, AliasAnalysis *AA, MemoryDependenceResults *MD, DominatorTree *DT, const TargetLibraryInfo *TLI, - InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB, + InstOverlapIntervalsTy &IOL, MapVector<Instruction *, bool> &ThrowableInst) { bool MadeChange = false; @@ -704,7 +767,7 @@ static bool handleFree(CallInst *F, AliasAnalysis *AA, // DCE instructions only used to calculate that store. BasicBlock::iterator BBI(Dependency); - deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL, OBB, + deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL, ThrowableInst); ++NumFastStores; MadeChange = true; @@ -762,7 +825,7 @@ static void removeAccessedObjects(const MemoryLocation &LoadedLoc, static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, MemoryDependenceResults *MD, const TargetLibraryInfo *TLI, - InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB, + InstOverlapIntervalsTy &IOL, MapVector<Instruction *, bool> &ThrowableInst) { bool MadeChange = false; @@ -785,7 +848,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, // Treat byval or inalloca arguments the same, stores to them are dead at the // end of the function. for (Argument &AI : BB.getParent()->args()) - if (AI.hasByValOrInAllocaAttr()) + if (AI.hasPassPointeeByValueAttr()) DeadStackObjects.insert(&AI); const DataLayout &DL = BB.getModule()->getDataLayout(); @@ -824,7 +887,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, << '\n'); // DCE instructions only used to calculate that store. - deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst, + deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, ThrowableInst, &DeadStackObjects); ++NumFastStores; MadeChange = true; @@ -836,7 +899,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, if (isInstructionTriviallyDead(&*BBI, TLI)) { LLVM_DEBUG(dbgs() << "DSE: Removing trivially dead instruction:\n DEAD: " << *&*BBI << '\n'); - deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst, + deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, ThrowableInst, &DeadStackObjects); ++NumFastOther; MadeChange = true; @@ -1043,8 +1106,8 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, const DataLayout &DL, const TargetLibraryInfo *TLI, InstOverlapIntervalsTy &IOL, - OrderedBasicBlock &OBB, - MapVector<Instruction *, bool> &ThrowableInst) { + MapVector<Instruction *, bool> &ThrowableInst, + DominatorTree *DT) { // Must be a store instruction. StoreInst *SI = dyn_cast<StoreInst>(Inst); if (!SI) @@ -1054,13 +1117,14 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, // then the store can be removed. 
if (LoadInst *DepLoad = dyn_cast<LoadInst>(SI->getValueOperand())) { if (SI->getPointerOperand() == DepLoad->getPointerOperand() && - isRemovable(SI) && memoryIsNotModifiedBetween(DepLoad, SI, AA)) { + isRemovable(SI) && + memoryIsNotModifiedBetween(DepLoad, SI, AA, DL, DT)) { LLVM_DEBUG( dbgs() << "DSE: Remove Store Of Load from same pointer:\n LOAD: " << *DepLoad << "\n STORE: " << *SI << '\n'); - deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst); + deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, ThrowableInst); ++NumRedundantStores; return true; } @@ -1073,12 +1137,12 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, dyn_cast<Instruction>(GetUnderlyingObject(SI->getPointerOperand(), DL)); if (UnderlyingPointer && isCallocLikeFn(UnderlyingPointer, TLI) && - memoryIsNotModifiedBetween(UnderlyingPointer, SI, AA)) { + memoryIsNotModifiedBetween(UnderlyingPointer, SI, AA, DL, DT)) { LLVM_DEBUG( dbgs() << "DSE: Remove null store to the calloc'ed object:\n DEAD: " << *Inst << "\n OBJECT: " << *UnderlyingPointer << '\n'); - deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst); + deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, ThrowableInst); ++NumRedundantStores; return true; } @@ -1086,13 +1150,58 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, return false; } +static Constant * +tryToMergePartialOverlappingStores(StoreInst *Earlier, StoreInst *Later, + int64_t InstWriteOffset, + int64_t DepWriteOffset, const DataLayout &DL, + AliasAnalysis *AA, DominatorTree *DT) { + + if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && + DL.typeSizeEqualsStoreSize(Earlier->getValueOperand()->getType()) && + Later && isa<ConstantInt>(Later->getValueOperand()) && + DL.typeSizeEqualsStoreSize(Later->getValueOperand()->getType()) && + memoryIsNotModifiedBetween(Earlier, Later, AA, DL, DT)) { + // If the store we find is: + // a) partially overwritten by the store to 'Loc' + // b) the later store is fully contained in the earlier one and + // c) they both have a constant value + // d) none of the two stores need padding + // Merge the two stores, replacing the earlier store's value with a + // merge of both values. + // TODO: Deal with other constant types (vectors, etc), and probably + // some mem intrinsics (if needed) + + APInt EarlierValue = + cast<ConstantInt>(Earlier->getValueOperand())->getValue(); + APInt LaterValue = cast<ConstantInt>(Later->getValueOperand())->getValue(); + unsigned LaterBits = LaterValue.getBitWidth(); + assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth()); + LaterValue = LaterValue.zext(EarlierValue.getBitWidth()); + + // Offset of the smaller store inside the larger store + unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8; + unsigned LShiftAmount = DL.isBigEndian() ? EarlierValue.getBitWidth() - + BitOffsetDiff - LaterBits + : BitOffsetDiff; + APInt Mask = APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount, + LShiftAmount + LaterBits); + // Clear the bits we'll be replacing, then OR with the smaller + // store, shifted appropriately. 
+ APInt Merged = (EarlierValue & ~Mask) | (LaterValue << LShiftAmount); + LLVM_DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *Earlier + << "\n Later: " << *Later + << "\n Merged Value: " << Merged << '\n'); + return ConstantInt::get(Earlier->getValueOperand()->getType(), Merged); + } + return nullptr; +} + static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, MemoryDependenceResults *MD, DominatorTree *DT, const TargetLibraryInfo *TLI) { const DataLayout &DL = BB.getModule()->getDataLayout(); bool MadeChange = false; - OrderedBasicBlock OBB(&BB); MapVector<Instruction *, bool> ThrowableInst; // A map of interval maps representing partially-overwritten value parts. @@ -1102,7 +1211,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) { // Handle 'free' calls specially. if (CallInst *F = isFreeCall(&*BBI, TLI)) { - MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, OBB, ThrowableInst); + MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, ThrowableInst); // Increment BBI after handleFree has potentially deleted instructions. // This ensures we maintain a valid iterator. ++BBI; @@ -1121,14 +1230,14 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, continue; // eliminateNoopStore will update in iterator, if necessary. - if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL, OBB, - ThrowableInst)) { + if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL, + ThrowableInst, DT)) { MadeChange = true; continue; } // If we find something that writes memory, get its memory dependence. - MemDepResult InstDep = MD->getDependency(Inst, &OBB); + MemDepResult InstDep = MD->getDependency(Inst); // Ignore any store where we can't find a local dependence. // FIXME: cross-block DSE would be fun. :) @@ -1179,7 +1288,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, // If the underlying object is a non-escaping memory allocation, any store // to it is dead along the unwind edge. Otherwise, we need to preserve // the store. - if (LastThrowing && OBB.dominates(DepWrite, LastThrowing)) { + if (LastThrowing && DepWrite->comesBefore(LastThrowing)) { const Value* Underlying = GetUnderlyingObject(DepLoc.Ptr, DL); bool IsStoreDeadOnUnwind = isa<AllocaInst>(Underlying); if (!IsStoreDeadOnUnwind) { @@ -1210,13 +1319,13 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, << "\n KILLER: " << *Inst << '\n'); // Delete the store and now-dead instructions that feed it. - deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, OBB, + deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, ThrowableInst); ++NumFastStores; MadeChange = true; // We erased DepWrite; start over. 
- InstDep = MD->getDependency(Inst, &OBB); + InstDep = MD->getDependency(Inst); continue; } else if ((OR == OW_End && isShortenableAtTheEnd(DepWrite)) || ((OR == OW_Begin && @@ -1234,53 +1343,12 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, OR == OW_PartialEarlierWithFullLater) { auto *Earlier = dyn_cast<StoreInst>(DepWrite); auto *Later = dyn_cast<StoreInst>(Inst); - if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && - DL.typeSizeEqualsStoreSize( - Earlier->getValueOperand()->getType()) && - Later && isa<ConstantInt>(Later->getValueOperand()) && - DL.typeSizeEqualsStoreSize( - Later->getValueOperand()->getType()) && - memoryIsNotModifiedBetween(Earlier, Later, AA)) { - // If the store we find is: - // a) partially overwritten by the store to 'Loc' - // b) the later store is fully contained in the earlier one and - // c) they both have a constant value - // d) none of the two stores need padding - // Merge the two stores, replacing the earlier store's value with a - // merge of both values. - // TODO: Deal with other constant types (vectors, etc), and probably - // some mem intrinsics (if needed) - - APInt EarlierValue = - cast<ConstantInt>(Earlier->getValueOperand())->getValue(); - APInt LaterValue = - cast<ConstantInt>(Later->getValueOperand())->getValue(); - unsigned LaterBits = LaterValue.getBitWidth(); - assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth()); - LaterValue = LaterValue.zext(EarlierValue.getBitWidth()); - - // Offset of the smaller store inside the larger store - unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8; - unsigned LShiftAmount = - DL.isBigEndian() - ? EarlierValue.getBitWidth() - BitOffsetDiff - LaterBits - : BitOffsetDiff; - APInt Mask = - APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount, - LShiftAmount + LaterBits); - // Clear the bits we'll be replacing, then OR with the smaller - // store, shifted appropriately. - APInt Merged = - (EarlierValue & ~Mask) | (LaterValue << LShiftAmount); - LLVM_DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *DepWrite - << "\n Later: " << *Inst - << "\n Merged Value: " << Merged << '\n'); - + if (Constant *C = tryToMergePartialOverlappingStores( + Earlier, Later, InstWriteOffset, DepWriteOffset, DL, AA, + DT)) { auto *SI = new StoreInst( - ConstantInt::get(Earlier->getValueOperand()->getType(), Merged), - Earlier->getPointerOperand(), false, - MaybeAlign(Earlier->getAlignment()), Earlier->getOrdering(), - Earlier->getSyncScopeID(), DepWrite); + C, Earlier->getPointerOperand(), false, Earlier->getAlign(), + Earlier->getOrdering(), Earlier->getSyncScopeID(), DepWrite); unsigned MDToKeep[] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, @@ -1289,13 +1357,10 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, SI->copyMetadata(*DepWrite, MDToKeep); ++NumModifiedStores; - // Remove earlier, wider, store - OBB.replaceInstruction(DepWrite, SI); - // Delete the old stores and now-dead instructions that feed them. - deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, OBB, + deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, ThrowableInst); - deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, OBB, + deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, ThrowableInst); MadeChange = true; @@ -1331,7 +1396,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, // If this block ends in a return, unwind, or unreachable, all allocas are // dead at its end, which means stores to them are also dead. 
if (BB.getTerminator()->getNumSuccessors() == 0) - MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, OBB, ThrowableInst); + MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, ThrowableInst); return MadeChange; } @@ -1349,22 +1414,913 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis *AA, return MadeChange; } +namespace { +//============================================================================= +// MemorySSA backed dead store elimination. +// +// The code below implements dead store elimination using MemorySSA. It uses +// the following general approach: given a MemoryDef, walk upwards to find +// clobbering MemoryDefs that may be killed by the starting def. Then check +// that there are no uses that may read the location of the original MemoryDef +// in between both MemoryDefs. A bit more concretely: +// +// For all MemoryDefs StartDef: +// 1. Get the next dominating clobbering MemoryDef (DomAccess) by walking +// upwards. +// 2. Check that there are no reads between DomAccess and the StartDef by +// checking all uses starting at DomAccess and walking until we see StartDef. +// 3. For each found DomDef, check that: +// 1. There are no barrier instructions between DomDef and StartDef (like +// throws or stores with ordering constraints). +// 2. StartDef is executed whenever DomDef is executed. +// 3. StartDef completely overwrites DomDef. +// 4. Erase DomDef from the function and MemorySSA. + +// Returns true if \p M is an intrisnic that does not read or write memory. +bool isNoopIntrinsic(MemoryUseOrDef *M) { + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(M->getMemoryInst())) { + switch (II->getIntrinsicID()) { + case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: + case Intrinsic::invariant_end: + case Intrinsic::launder_invariant_group: + case Intrinsic::assume: + return true; + case Intrinsic::dbg_addr: + case Intrinsic::dbg_declare: + case Intrinsic::dbg_label: + case Intrinsic::dbg_value: + llvm_unreachable("Intrinsic should not be modeled in MemorySSA"); + default: + return false; + } + } + return false; +} + +// Check if we can ignore \p D for DSE. +bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) { + Instruction *DI = D->getMemoryInst(); + // Calls that only access inaccessible memory cannot read or write any memory + // locations we consider for elimination. + if (auto *CB = dyn_cast<CallBase>(DI)) + if (CB->onlyAccessesInaccessibleMemory()) + return true; + + // We can eliminate stores to locations not visible to the caller across + // throwing instructions. + if (DI->mayThrow() && !DefVisibleToCaller) + return true; + + // We can remove the dead stores, irrespective of the fence and its ordering + // (release/acquire/seq_cst). Fences only constraints the ordering of + // already visible stores, it does not make a store visible to other + // threads. So, skipping over a fence does not change a store from being + // dead. + if (isa<FenceInst>(DI)) + return true; + + // Skip intrinsics that do not really read or modify memory. + if (isNoopIntrinsic(D)) + return true; + + return false; +} + +struct DSEState { + Function &F; + AliasAnalysis &AA; + MemorySSA &MSSA; + DominatorTree &DT; + PostDominatorTree &PDT; + const TargetLibraryInfo &TLI; + + // All MemoryDefs that potentially could kill other MemDefs. 
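A minimal standalone model of the walk sketched in the overview comment above: scan upwards from a killing write, bail out at a read that may observe the killed location, and report an earlier write that is completely covered. Types and names here are illustrative only; the real pass works on MemoryDefs and MemoryUses across the CFG, not a flat trace.

#include <cstddef>
#include <optional>
#include <vector>

struct ToyAccess {
  bool IsWrite;
  unsigned Begin, End; // half-open byte range [Begin, End)
};

// If the last access is a write, look upwards for an earlier write it fully
// covers, with no intervening read of the killing range. Conservatively stops
// at the first such read, like the walk above stops at a read clobber.
std::optional<std::size_t>
findDeadEarlierWrite(const std::vector<ToyAccess> &Trace) {
  if (Trace.empty() || !Trace.back().IsWrite)
    return std::nullopt;
  const ToyAccess &Killer = Trace.back();
  for (std::size_t I = Trace.size() - 1; I-- > 0;) {
    const ToyAccess &A = Trace[I];
    bool Overlaps = A.Begin < Killer.End && Killer.Begin < A.End;
    if (!A.IsWrite) {
      if (Overlaps)
        return std::nullopt; // a read in between blocks the elimination
      continue;
    }
    if (Killer.Begin <= A.Begin && A.End <= Killer.End)
      return I; // earlier write fully overwritten before being read
  }
  return std::nullopt;
}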
+ SmallVector<MemoryDef *, 64> MemDefs; + // Any that should be skipped as they are already deleted + SmallPtrSet<MemoryAccess *, 4> SkipStores; + // Keep track of all of the objects that are invisible to the caller before + // the function returns. + SmallPtrSet<const Value *, 16> InvisibleToCallerBeforeRet; + // Keep track of all of the objects that are invisible to the caller after + // the function returns. + SmallPtrSet<const Value *, 16> InvisibleToCallerAfterRet; + // Keep track of blocks with throwing instructions not modeled in MemorySSA. + SmallPtrSet<BasicBlock *, 16> ThrowingBlocks; + // Post-order numbers for each basic block. Used to figure out if memory + // accesses are executed before another access. + DenseMap<BasicBlock *, unsigned> PostOrderNumbers; + + /// Keep track of instructions (partly) overlapping with killing MemoryDefs per + /// basic block. + DenseMap<BasicBlock *, InstOverlapIntervalsTy> IOLs; + + DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT, + PostDominatorTree &PDT, const TargetLibraryInfo &TLI) + : F(F), AA(AA), MSSA(MSSA), DT(DT), PDT(PDT), TLI(TLI) {} + + static DSEState get(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, + DominatorTree &DT, PostDominatorTree &PDT, + const TargetLibraryInfo &TLI) { + DSEState State(F, AA, MSSA, DT, PDT, TLI); + // Collect blocks with throwing instructions not modeled in MemorySSA and + // alloc-like objects. + unsigned PO = 0; + for (BasicBlock *BB : post_order(&F)) { + State.PostOrderNumbers[BB] = PO++; + for (Instruction &I : *BB) { + MemoryAccess *MA = MSSA.getMemoryAccess(&I); + if (I.mayThrow() && !MA) + State.ThrowingBlocks.insert(I.getParent()); + + auto *MD = dyn_cast_or_null<MemoryDef>(MA); + if (MD && State.MemDefs.size() < MemorySSADefsPerBlockLimit && + (State.getLocForWriteEx(&I) || State.isMemTerminatorInst(&I))) + State.MemDefs.push_back(MD); + + // Track whether alloca and alloca-like objects are visible in the + // caller before and after the function returns. Alloca objects are + // invalid in the caller, so they are neither visible before or after + // the function returns. + if (isa<AllocaInst>(&I)) { + State.InvisibleToCallerBeforeRet.insert(&I); + State.InvisibleToCallerAfterRet.insert(&I); + } + + // For alloca-like objects we need to check if they are captured before + // the function returns and if the return might capture the object. + if (isAllocLikeFn(&I, &TLI)) { + bool CapturesBeforeRet = PointerMayBeCaptured(&I, false, true); + if (!CapturesBeforeRet) { + State.InvisibleToCallerBeforeRet.insert(&I); + if (!PointerMayBeCaptured(&I, true, false)) + State.InvisibleToCallerAfterRet.insert(&I); + } + } + } + } + + // Treat byval or inalloca arguments the same as Allocas, stores to them are + // dead at the end of the function. + for (Argument &AI : F.args()) + if (AI.hasPassPointeeByValueAttr()) { + // For byval, the caller doesn't know the address of the allocation. 
+ if (AI.hasByValAttr()) + State.InvisibleToCallerBeforeRet.insert(&AI); + State.InvisibleToCallerAfterRet.insert(&AI); + } + + return State; + } + + Optional<MemoryLocation> getLocForWriteEx(Instruction *I) const { + if (!I->mayWriteToMemory()) + return None; + + if (auto *MTI = dyn_cast<AnyMemIntrinsic>(I)) + return {MemoryLocation::getForDest(MTI)}; + + if (auto *CB = dyn_cast<CallBase>(I)) { + LibFunc LF; + if (TLI.getLibFunc(*CB, LF) && TLI.has(LF)) { + switch (LF) { + case LibFunc_strcpy: + case LibFunc_strncpy: + case LibFunc_strcat: + case LibFunc_strncat: + return {MemoryLocation(CB->getArgOperand(0))}; + default: + break; + } + } + return None; + } + + return MemoryLocation::getOrNone(I); + } + + /// Returns true if \p Use completely overwrites \p DefLoc. + bool isCompleteOverwrite(MemoryLocation DefLoc, Instruction *UseInst) const { + // UseInst has a MemoryDef associated in MemorySSA. It's possible for a + // MemoryDef to not write to memory, e.g. a volatile load is modeled as a + // MemoryDef. + if (!UseInst->mayWriteToMemory()) + return false; + + if (auto *CB = dyn_cast<CallBase>(UseInst)) + if (CB->onlyAccessesInaccessibleMemory()) + return false; + + int64_t InstWriteOffset, DepWriteOffset; + auto CC = getLocForWriteEx(UseInst); + InstOverlapIntervalsTy IOL; + + const DataLayout &DL = F.getParent()->getDataLayout(); + + return CC && + isOverwrite(*CC, DefLoc, DL, TLI, DepWriteOffset, InstWriteOffset, + UseInst, IOL, AA, &F) == OW_Complete; + } + + /// Returns true if \p Def is not read before returning from the function. + bool isWriteAtEndOfFunction(MemoryDef *Def) { + LLVM_DEBUG(dbgs() << " Check if def " << *Def << " (" + << *Def->getMemoryInst() + << ") is at the end the function \n"); + + auto MaybeLoc = getLocForWriteEx(Def->getMemoryInst()); + if (!MaybeLoc) { + LLVM_DEBUG(dbgs() << " ... could not get location for write.\n"); + return false; + } + + SmallVector<MemoryAccess *, 4> WorkList; + SmallPtrSet<MemoryAccess *, 8> Visited; + auto PushMemUses = [&WorkList, &Visited](MemoryAccess *Acc) { + if (!Visited.insert(Acc).second) + return; + for (Use &U : Acc->uses()) + WorkList.push_back(cast<MemoryAccess>(U.getUser())); + }; + PushMemUses(Def); + for (unsigned I = 0; I < WorkList.size(); I++) { + if (WorkList.size() >= MemorySSAScanLimit) { + LLVM_DEBUG(dbgs() << " ... hit exploration limit.\n"); + return false; + } + + MemoryAccess *UseAccess = WorkList[I]; + if (isa<MemoryPhi>(UseAccess)) { + PushMemUses(UseAccess); + continue; + } + + // TODO: Checking for aliasing is expensive. Consider reducing the amount + // of times this is called and/or caching it. + Instruction *UseInst = cast<MemoryUseOrDef>(UseAccess)->getMemoryInst(); + if (isReadClobber(*MaybeLoc, UseInst)) { + LLVM_DEBUG(dbgs() << " ... hit read clobber " << *UseInst << ".\n"); + return false; + } + + if (MemoryDef *UseDef = dyn_cast<MemoryDef>(UseAccess)) + PushMemUses(UseDef); + } + return true; + } + + /// If \p I is a memory terminator like llvm.lifetime.end or free, return a + /// pair with the MemoryLocation terminated by \p I and a boolean flag + /// indicating whether \p I is a free-like call. 
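For intuition, a hypothetical source-level case of what isWriteAtEndOfFunction and the invisible-to-caller sets enable: a store into a non-escaping local that nothing reads before the function returns.

// The buffer never escapes, so the store to Buf[1] cannot be observed once
// the function returns and is removable at the end of the function.
int lastUseBeforeReturn() {
  int Buf[4] = {1, 2, 3, 4};
  int Result = Buf[0] + Buf[3];
  Buf[1] = 42; // dead: Buf is invisible to the caller after the return
  return Result;
}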
+ Optional<std::pair<MemoryLocation, bool>> + getLocForTerminator(Instruction *I) const { + uint64_t Len; + Value *Ptr; + if (match(I, m_Intrinsic<Intrinsic::lifetime_end>(m_ConstantInt(Len), + m_Value(Ptr)))) + return {std::make_pair(MemoryLocation(Ptr, Len), false)}; + + if (auto *CB = dyn_cast<CallBase>(I)) { + if (isFreeCall(I, &TLI)) + return {std::make_pair(MemoryLocation(CB->getArgOperand(0)), true)}; + } + + return None; + } + + /// Returns true if \p I is a memory terminator instruction like + /// llvm.lifetime.end or free. + bool isMemTerminatorInst(Instruction *I) const { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); + return (II && II->getIntrinsicID() == Intrinsic::lifetime_end) || + isFreeCall(I, &TLI); + } + + /// Returns true if \p MaybeTerm is a memory terminator for the same + /// underlying object as \p DefLoc. + bool isMemTerminator(MemoryLocation DefLoc, Instruction *MaybeTerm) const { + Optional<std::pair<MemoryLocation, bool>> MaybeTermLoc = + getLocForTerminator(MaybeTerm); + + if (!MaybeTermLoc) + return false; + + // If the terminator is a free-like call, all accesses to the underlying + // object can be considered terminated. + if (MaybeTermLoc->second) { + DataLayout DL = MaybeTerm->getParent()->getModule()->getDataLayout(); + DefLoc = MemoryLocation(GetUnderlyingObject(DefLoc.Ptr, DL)); + } + return AA.isMustAlias(MaybeTermLoc->first, DefLoc); + } + + // Returns true if \p Use may read from \p DefLoc. + bool isReadClobber(MemoryLocation DefLoc, Instruction *UseInst) const { + if (!UseInst->mayReadFromMemory()) + return false; + + if (auto *CB = dyn_cast<CallBase>(UseInst)) + if (CB->onlyAccessesInaccessibleMemory()) + return false; + + ModRefInfo MR = AA.getModRefInfo(UseInst, DefLoc); + // If necessary, perform additional analysis. + if (isRefSet(MR)) + MR = AA.callCapturesBefore(UseInst, DefLoc, &DT); + return isRefSet(MR); + } + + // Find a MemoryDef writing to \p DefLoc and dominating \p Current, with no + // read access between them or on any other path to a function exit block if + // \p DefLoc is not accessible after the function returns. If there is no such + // MemoryDef, return None. The returned value may not (completely) overwrite + // \p DefLoc. Currently we bail out when we encounter an aliasing MemoryUse + // (read). + Optional<MemoryAccess *> + getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *Current, + MemoryLocation DefLoc, bool DefVisibleToCallerBeforeRet, + bool DefVisibleToCallerAfterRet, int &ScanLimit) const { + MemoryAccess *DomAccess; + bool StepAgain; + LLVM_DEBUG(dbgs() << " trying to get dominating access for " << *Current + << "\n"); + // Find the next clobbering Mod access for DefLoc, starting at Current. + do { + StepAgain = false; + // Reached TOP. + if (MSSA.isLiveOnEntryDef(Current)) + return None; + + if (isa<MemoryPhi>(Current)) { + DomAccess = Current; + break; + } + MemoryUseOrDef *CurrentUD = cast<MemoryUseOrDef>(Current); + // Look for access that clobber DefLoc. + DomAccess = MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(CurrentUD, + DefLoc); + if (MSSA.isLiveOnEntryDef(DomAccess)) + return None; + + if (isa<MemoryPhi>(DomAccess)) + break; + + // Check if we can skip DomDef for DSE. 
+ MemoryDef *DomDef = dyn_cast<MemoryDef>(DomAccess); + if (DomDef && canSkipDef(DomDef, DefVisibleToCallerBeforeRet)) { + StepAgain = true; + Current = DomDef->getDefiningAccess(); + } + + } while (StepAgain); + + // Accesses to objects accessible after the function returns can only be + // eliminated if the access is killed along all paths to the exit. Collect + // the blocks with killing (=completely overwriting MemoryDefs) and check if + // they cover all paths from DomAccess to any function exit. + SmallPtrSet<BasicBlock *, 16> KillingBlocks = {KillingDef->getBlock()}; + LLVM_DEBUG({ + dbgs() << " Checking for reads of " << *DomAccess; + if (isa<MemoryDef>(DomAccess)) + dbgs() << " (" << *cast<MemoryDef>(DomAccess)->getMemoryInst() << ")\n"; + else + dbgs() << ")\n"; + }); + + SmallSetVector<MemoryAccess *, 32> WorkList; + auto PushMemUses = [&WorkList](MemoryAccess *Acc) { + for (Use &U : Acc->uses()) + WorkList.insert(cast<MemoryAccess>(U.getUser())); + }; + PushMemUses(DomAccess); + + // Check if DomDef may be read. + for (unsigned I = 0; I < WorkList.size(); I++) { + MemoryAccess *UseAccess = WorkList[I]; + + LLVM_DEBUG(dbgs() << " " << *UseAccess); + if (--ScanLimit == 0) { + LLVM_DEBUG(dbgs() << "\n ... hit scan limit\n"); + return None; + } + + if (isa<MemoryPhi>(UseAccess)) { + LLVM_DEBUG(dbgs() << "\n ... adding PHI uses\n"); + PushMemUses(UseAccess); + continue; + } + + Instruction *UseInst = cast<MemoryUseOrDef>(UseAccess)->getMemoryInst(); + LLVM_DEBUG(dbgs() << " (" << *UseInst << ")\n"); + + if (isNoopIntrinsic(cast<MemoryUseOrDef>(UseAccess))) { + LLVM_DEBUG(dbgs() << " ... adding uses of intrinsic\n"); + PushMemUses(UseAccess); + continue; + } + + // A memory terminator kills all preceeding MemoryDefs and all succeeding + // MemoryAccesses. We do not have to check it's users. + if (isMemTerminator(DefLoc, UseInst)) + continue; + + // Uses which may read the original MemoryDef mean we cannot eliminate the + // original MD. Stop walk. + if (isReadClobber(DefLoc, UseInst)) { + LLVM_DEBUG(dbgs() << " ... found read clobber\n"); + return None; + } + + // For the KillingDef and DomAccess we only have to check if it reads the + // memory location. + // TODO: It would probably be better to check for self-reads before + // calling the function. + if (KillingDef == UseAccess || DomAccess == UseAccess) { + LLVM_DEBUG(dbgs() << " ... skipping killing def/dom access\n"); + continue; + } + + // Check all uses for MemoryDefs, except for defs completely overwriting + // the original location. Otherwise we have to check uses of *all* + // MemoryDefs we discover, including non-aliasing ones. Otherwise we might + // miss cases like the following + // 1 = Def(LoE) ; <----- DomDef stores [0,1] + // 2 = Def(1) ; (2, 1) = NoAlias, stores [2,3] + // Use(2) ; MayAlias 2 *and* 1, loads [0, 3]. + // (The Use points to the *first* Def it may alias) + // 3 = Def(1) ; <---- Current (3, 2) = NoAlias, (3,1) = MayAlias, + // stores [0,1] + if (MemoryDef *UseDef = dyn_cast<MemoryDef>(UseAccess)) { + if (isCompleteOverwrite(DefLoc, UseInst)) { + if (DefVisibleToCallerAfterRet && UseAccess != DomAccess) { + BasicBlock *MaybeKillingBlock = UseInst->getParent(); + if (PostOrderNumbers.find(MaybeKillingBlock)->second < + PostOrderNumbers.find(DomAccess->getBlock())->second) { + + LLVM_DEBUG(dbgs() << " ... 
found killing block " + << MaybeKillingBlock->getName() << "\n"); + KillingBlocks.insert(MaybeKillingBlock); + } + } + } else + PushMemUses(UseDef); + } + } + + // For accesses to locations visible after the function returns, make sure + // that the location is killed (=overwritten) along all paths from DomAccess + // to the exit. + if (DefVisibleToCallerAfterRet) { + assert(!KillingBlocks.empty() && + "Expected at least a single killing block"); + // Find the common post-dominator of all killing blocks. + BasicBlock *CommonPred = *KillingBlocks.begin(); + for (auto I = std::next(KillingBlocks.begin()), E = KillingBlocks.end(); + I != E; I++) { + if (!CommonPred) + break; + CommonPred = PDT.findNearestCommonDominator(CommonPred, *I); + } + + // If CommonPred is in the set of killing blocks, just check if it + // post-dominates DomAccess. + if (KillingBlocks.count(CommonPred)) { + if (PDT.dominates(CommonPred, DomAccess->getBlock())) + return {DomAccess}; + return None; + } + + // If the common post-dominator does not post-dominate DomAccess, there + // is a path from DomAccess to an exit not going through a killing block. + if (PDT.dominates(CommonPred, DomAccess->getBlock())) { + SetVector<BasicBlock *> WorkList; + + // DomAccess's post-order number provides an upper bound of the blocks + // on a path starting at DomAccess. + unsigned UpperBound = + PostOrderNumbers.find(DomAccess->getBlock())->second; + + // If CommonPred is null, there are multiple exits from the function. + // They all have to be added to the worklist. + if (CommonPred) + WorkList.insert(CommonPred); + else + for (BasicBlock *R : PDT.roots()) + WorkList.insert(R); + + NumCFGTries++; + // Check if all paths starting from an exit node go through one of the + // killing blocks before reaching DomAccess. + for (unsigned I = 0; I < WorkList.size(); I++) { + NumCFGChecks++; + BasicBlock *Current = WorkList[I]; + if (KillingBlocks.count(Current)) + continue; + if (Current == DomAccess->getBlock()) + return None; + + // DomAccess is reachable from the entry, so we don't have to explore + // unreachable blocks further. + if (!DT.isReachableFromEntry(Current)) + continue; + + unsigned CPO = PostOrderNumbers.find(Current)->second; + // Current block is not on a path starting at DomAccess. + if (CPO > UpperBound) + continue; + for (BasicBlock *Pred : predecessors(Current)) + WorkList.insert(Pred); + + if (WorkList.size() >= MemorySSAPathCheckLimit) + return None; + } + NumCFGSuccess++; + return {DomAccess}; + } + return None; + } + + // No aliasing MemoryUses of DomAccess found, DomAccess is potentially dead. + return {DomAccess}; + } + + // Delete dead memory defs + void deleteDeadInstruction(Instruction *SI) { + MemorySSAUpdater Updater(&MSSA); + SmallVector<Instruction *, 32> NowDeadInsts; + NowDeadInsts.push_back(SI); + --NumFastOther; + + while (!NowDeadInsts.empty()) { + Instruction *DeadInst = NowDeadInsts.pop_back_val(); + ++NumFastOther; + + // Try to preserve debug information attached to the dead instruction. + salvageDebugInfo(*DeadInst); + salvageKnowledge(DeadInst); + + // Remove the Instruction from MSSA. 
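The killing-block check above can be pictured as the following standalone graph routine (illustrative names, integer block ids, no scan limit): walk backwards from the exit blocks and never expand past a killing block; reaching the block of DomAccess means some path to an exit avoids every killing block.

#include <set>
#include <vector>

// Returns true if every path from From to any block in Exits passes through a
// block in Killing. The patch additionally prunes this walk with post-order
// numbers and a path-check limit; that is omitted here.
bool killedOnAllExitPaths(int From, const std::vector<std::vector<int>> &Preds,
                          const std::vector<int> &Exits,
                          const std::set<int> &Killing) {
  std::vector<int> Worklist(Exits.begin(), Exits.end());
  std::set<int> Visited(Exits.begin(), Exits.end());
  while (!Worklist.empty()) {
    int B = Worklist.back();
    Worklist.pop_back();
    if (Killing.count(B))
      continue;   // this path is already covered by a killing block
    if (B == From)
      return false; // an exit is reachable without crossing a killing block
    for (int P : Preds[B])
      if (Visited.insert(P).second)
        Worklist.push_back(P);
  }
  return true;
}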
+ if (MemoryAccess *MA = MSSA.getMemoryAccess(DeadInst)) { + if (MemoryDef *MD = dyn_cast<MemoryDef>(MA)) { + SkipStores.insert(MD); + } + Updater.removeMemoryAccess(MA); + } + + auto I = IOLs.find(DeadInst->getParent()); + if (I != IOLs.end()) + I->second.erase(DeadInst); + // Remove its operands + for (Use &O : DeadInst->operands()) + if (Instruction *OpI = dyn_cast<Instruction>(O)) { + O = nullptr; + if (isInstructionTriviallyDead(OpI, &TLI)) + NowDeadInsts.push_back(OpI); + } + + DeadInst->eraseFromParent(); + } + } + + // Check for any extra throws between SI and NI that block DSE. This only + // checks extra maythrows (those that aren't MemoryDef's). MemoryDef that may + // throw are handled during the walk from one def to the next. + bool mayThrowBetween(Instruction *SI, Instruction *NI, + const Value *SILocUnd) const { + // First see if we can ignore it by using the fact that SI is an + // alloca/alloca like object that is not visible to the caller during + // execution of the function. + if (SILocUnd && InvisibleToCallerBeforeRet.count(SILocUnd)) + return false; + + if (SI->getParent() == NI->getParent()) + return ThrowingBlocks.count(SI->getParent()); + return !ThrowingBlocks.empty(); + } + + // Check if \p NI acts as a DSE barrier for \p SI. The following instructions + // act as barriers: + // * A memory instruction that may throw and \p SI accesses a non-stack + // object. + // * Atomic stores stronger that monotonic. + bool isDSEBarrier(const Value *SILocUnd, Instruction *NI) const { + // If NI may throw it acts as a barrier, unless we are to an alloca/alloca + // like object that does not escape. + if (NI->mayThrow() && !InvisibleToCallerBeforeRet.count(SILocUnd)) + return true; + + // If NI is an atomic load/store stronger than monotonic, do not try to + // eliminate/reorder it. + if (NI->isAtomic()) { + if (auto *LI = dyn_cast<LoadInst>(NI)) + return isStrongerThanMonotonic(LI->getOrdering()); + if (auto *SI = dyn_cast<StoreInst>(NI)) + return isStrongerThanMonotonic(SI->getOrdering()); + llvm_unreachable("other instructions should be skipped in MemorySSA"); + } + return false; + } + + /// Eliminate writes to objects that are not visible in the caller and are not + /// accessed before returning from the function. + bool eliminateDeadWritesAtEndOfFunction() { + const DataLayout &DL = F.getParent()->getDataLayout(); + bool MadeChange = false; + LLVM_DEBUG( + dbgs() + << "Trying to eliminate MemoryDefs at the end of the function\n"); + for (int I = MemDefs.size() - 1; I >= 0; I--) { + MemoryDef *Def = MemDefs[I]; + if (SkipStores.find(Def) != SkipStores.end() || + !isRemovable(Def->getMemoryInst())) + continue; + + // TODO: Consider doing the underlying object check first, if it is + // beneficial compile-time wise. + if (isWriteAtEndOfFunction(Def)) { + Instruction *DefI = Def->getMemoryInst(); + // See through pointer-to-pointer bitcasts + SmallVector<const Value *, 4> Pointers; + GetUnderlyingObjects(getLocForWriteEx(DefI)->Ptr, Pointers, DL); + + LLVM_DEBUG(dbgs() << " ... MemoryDef is not accessed until the end " + "of the function\n"); + bool CanKill = true; + for (const Value *Pointer : Pointers) { + if (!InvisibleToCallerAfterRet.count(Pointer)) { + CanKill = false; + break; + } + } + + if (CanKill) { + deleteDeadInstruction(DefI); + ++NumFastStores; + MadeChange = true; + } + } + } + return MadeChange; + } + + /// \returns true if \p Def is a no-op store, either because it + /// directly stores back a loaded value or stores zero to a calloced object. 
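Hypothetical C++ sources that give rise to the two no-op patterns checked for here: storing a value straight back to the pointer it was just loaded from, and storing zero into calloc'ed memory. The IR-level check is stricter (the load and store must share a defining access, and calloc must be the clobbering access); this only shows the shape of the input.

#include <cstddef>
#include <cstdlib>

// Pattern 1: storing back the value just loaded from the same pointer, with
// nothing writing the location in between.
void writeback(int *P) {
  int V = *P;
  *P = V; // no-op store: the location already holds V
}

// Pattern 2: storing zero into memory that calloc already zero-initialized.
int *zeroed(std::size_t N) {
  int *P = static_cast<int *>(calloc(N, sizeof(int)));
  if (P && N > 0)
    P[0] = 0; // no-op store: calloc'ed memory is already zero
  return P;
}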
+ bool storeIsNoop(MemoryDef *Def, MemoryLocation DefLoc, const Value *DefUO) { + StoreInst *Store = dyn_cast<StoreInst>(Def->getMemoryInst()); + if (!Store) + return false; + + if (auto *LoadI = dyn_cast<LoadInst>(Store->getOperand(0))) { + if (LoadI->getPointerOperand() == Store->getOperand(1)) { + auto *LoadAccess = MSSA.getMemoryAccess(LoadI)->getDefiningAccess(); + // If both accesses share the same defining access, no instructions + // between them can modify the memory location. + return LoadAccess == Def->getDefiningAccess(); + } + } + + Constant *StoredConstant = dyn_cast<Constant>(Store->getOperand(0)); + if (StoredConstant && StoredConstant->isNullValue()) { + auto *DefUOInst = dyn_cast<Instruction>(DefUO); + if (DefUOInst && isCallocLikeFn(DefUOInst, &TLI)) { + auto *UnderlyingDef = cast<MemoryDef>(MSSA.getMemoryAccess(DefUOInst)); + // If UnderlyingDef is the clobbering access of Def, no instructions + // between them can modify the memory location. + auto *ClobberDef = + MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(Def); + return UnderlyingDef == ClobberDef; + } + } + return false; + } +}; + +bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA, + MemorySSA &MSSA, DominatorTree &DT, + PostDominatorTree &PDT, + const TargetLibraryInfo &TLI) { + const DataLayout &DL = F.getParent()->getDataLayout(); + bool MadeChange = false; + + DSEState State = DSEState::get(F, AA, MSSA, DT, PDT, TLI); + // For each store: + for (unsigned I = 0; I < State.MemDefs.size(); I++) { + MemoryDef *KillingDef = State.MemDefs[I]; + if (State.SkipStores.count(KillingDef)) + continue; + Instruction *SI = KillingDef->getMemoryInst(); + + auto MaybeSILoc = State.getLocForWriteEx(SI); + if (State.isMemTerminatorInst(SI)) + MaybeSILoc = State.getLocForTerminator(SI).map( + [](const std::pair<MemoryLocation, bool> &P) { return P.first; }); + else + MaybeSILoc = State.getLocForWriteEx(SI); + + if (!MaybeSILoc) { + LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for " + << *SI << "\n"); + continue; + } + MemoryLocation SILoc = *MaybeSILoc; + assert(SILoc.Ptr && "SILoc should not be null"); + const Value *SILocUnd = GetUnderlyingObject(SILoc.Ptr, DL); + + // Check if the store is a no-op. + if (isRemovable(SI) && State.storeIsNoop(KillingDef, SILoc, SILocUnd)) { + LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " << *SI << '\n'); + State.deleteDeadInstruction(SI); + NumNoopStores++; + MadeChange = true; + continue; + } + + Instruction *DefObj = + const_cast<Instruction *>(dyn_cast<Instruction>(SILocUnd)); + bool DefVisibleToCallerBeforeRet = + !State.InvisibleToCallerBeforeRet.count(SILocUnd); + bool DefVisibleToCallerAfterRet = + !State.InvisibleToCallerAfterRet.count(SILocUnd); + if (DefObj && isAllocLikeFn(DefObj, &TLI)) { + if (DefVisibleToCallerBeforeRet) + DefVisibleToCallerBeforeRet = + PointerMayBeCapturedBefore(DefObj, false, true, SI, &DT); + } + + MemoryAccess *Current = KillingDef; + LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " + << *KillingDef << " (" << *SI << ")\n"); + + int ScanLimit = MemorySSAScanLimit; + // Worklist of MemoryAccesses that may be killed by KillingDef. + SetVector<MemoryAccess *> ToCheck; + ToCheck.insert(KillingDef->getDefiningAccess()); + + // Check if MemoryAccesses in the worklist are killed by KillingDef. 
+ for (unsigned I = 0; I < ToCheck.size(); I++) { + Current = ToCheck[I]; + if (State.SkipStores.count(Current)) + continue; + + Optional<MemoryAccess *> Next = State.getDomMemoryDef( + KillingDef, Current, SILoc, DefVisibleToCallerBeforeRet, + DefVisibleToCallerAfterRet, ScanLimit); + + if (!Next) { + LLVM_DEBUG(dbgs() << " finished walk\n"); + continue; + } + + MemoryAccess *DomAccess = *Next; + LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DomAccess); + if (isa<MemoryPhi>(DomAccess)) { + LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n"); + for (Value *V : cast<MemoryPhi>(DomAccess)->incoming_values()) { + MemoryAccess *IncomingAccess = cast<MemoryAccess>(V); + BasicBlock *IncomingBlock = IncomingAccess->getBlock(); + BasicBlock *PhiBlock = DomAccess->getBlock(); + + // We only consider incoming MemoryAccesses that come before the + // MemoryPhi. Otherwise we could discover candidates that do not + // strictly dominate our starting def. + if (State.PostOrderNumbers[IncomingBlock] > + State.PostOrderNumbers[PhiBlock]) + ToCheck.insert(IncomingAccess); + } + continue; + } + MemoryDef *NextDef = dyn_cast<MemoryDef>(DomAccess); + Instruction *NI = NextDef->getMemoryInst(); + LLVM_DEBUG(dbgs() << " (" << *NI << ")\n"); + + // Before we try to remove anything, check for any extra throwing + // instructions that block us from DSEing + if (State.mayThrowBetween(SI, NI, SILocUnd)) { + LLVM_DEBUG(dbgs() << " ... skip, may throw!\n"); + break; + } + + // Check for anything that looks like it will be a barrier to further + // removal + if (State.isDSEBarrier(SILocUnd, NI)) { + LLVM_DEBUG(dbgs() << " ... skip, barrier\n"); + continue; + } + + ToCheck.insert(NextDef->getDefiningAccess()); + + if (!hasAnalyzableMemoryWrite(NI, TLI)) { + LLVM_DEBUG(dbgs() << " ... skip, cannot analyze def\n"); + continue; + } + + if (!isRemovable(NI)) { + LLVM_DEBUG(dbgs() << " ... skip, cannot remove def\n"); + continue; + } + + if (!DebugCounter::shouldExecute(MemorySSACounter)) + continue; + + MemoryLocation NILoc = *State.getLocForWriteEx(NI); + + if (State.isMemTerminatorInst(SI)) { + const Value *NIUnd = GetUnderlyingObject(NILoc.Ptr, DL); + if (!SILocUnd || SILocUnd != NIUnd) + continue; + LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *NI + << "\n KILLER: " << *SI << '\n'); + State.deleteDeadInstruction(NI); + ++NumFastStores; + MadeChange = true; + } else { + // Check if NI overwrites SI. + int64_t InstWriteOffset, DepWriteOffset; + auto Iter = State.IOLs.insert( + std::make_pair<BasicBlock *, InstOverlapIntervalsTy>( + NI->getParent(), InstOverlapIntervalsTy())); + auto &IOL = Iter.first->second; + OverwriteResult OR = isOverwrite(SILoc, NILoc, DL, TLI, DepWriteOffset, + InstWriteOffset, NI, IOL, AA, &F); + + if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) { + auto *Earlier = dyn_cast<StoreInst>(NI); + auto *Later = dyn_cast<StoreInst>(SI); + if (Constant *Merged = tryToMergePartialOverlappingStores( + Earlier, Later, InstWriteOffset, DepWriteOffset, DL, &AA, + &DT)) { + + // Update stored value of earlier store to merged constant. + Earlier->setOperand(0, Merged); + ++NumModifiedStores; + MadeChange = true; + + // Remove later store and remove any outstanding overlap intervals + // for the updated store. 
+ State.deleteDeadInstruction(Later); + auto I = State.IOLs.find(Earlier->getParent()); + if (I != State.IOLs.end()) + I->second.erase(Earlier); + break; + } + } + + if (OR == OW_Complete) { + LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *NI + << "\n KILLER: " << *SI << '\n'); + State.deleteDeadInstruction(NI); + ++NumFastStores; + MadeChange = true; + } + } + } + } + + if (EnablePartialOverwriteTracking) + for (auto &KV : State.IOLs) + MadeChange |= removePartiallyOverlappedStores(&AA, DL, KV.second); + + MadeChange |= State.eliminateDeadWritesAtEndOfFunction(); + return MadeChange; +} +} // end anonymous namespace + //===----------------------------------------------------------------------===// // DSE Pass //===----------------------------------------------------------------------===// PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) { - AliasAnalysis *AA = &AM.getResult<AAManager>(F); - DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); - MemoryDependenceResults *MD = &AM.getResult<MemoryDependenceAnalysis>(F); - const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F); + AliasAnalysis &AA = AM.getResult<AAManager>(F); + const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F); + DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + + bool Changed = false; + if (EnableMemorySSA) { + MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); + PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); - if (!eliminateDeadStores(F, AA, MD, DT, TLI)) + Changed = eliminateDeadStoresMemorySSA(F, AA, MSSA, DT, PDT, TLI); + } else { + MemoryDependenceResults &MD = AM.getResult<MemoryDependenceAnalysis>(F); + + Changed = eliminateDeadStores(F, &AA, &MD, &DT, &TLI); + } + +#ifdef LLVM_ENABLE_STATS + if (AreStatisticsEnabled()) + for (auto &I : instructions(F)) + NumRemainingStores += isa<StoreInst>(&I); +#endif + + if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); - PA.preserve<MemoryDependenceAnalysis>(); + if (EnableMemorySSA) + PA.preserve<MemorySSAAnalysis>(); + else + PA.preserve<MemoryDependenceAnalysis>(); return PA; } @@ -1383,25 +2339,51 @@ public: if (skipFunction(F)) return false; - DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - MemoryDependenceResults *MD = - &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + const TargetLibraryInfo &TLI = + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + + bool Changed = false; + if (EnableMemorySSA) { + MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); + PostDominatorTree &PDT = + getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + + Changed = eliminateDeadStoresMemorySSA(F, AA, MSSA, DT, PDT, TLI); + } else { + MemoryDependenceResults &MD = + getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + + Changed = eliminateDeadStores(F, &AA, &MD, &DT, &TLI); + } - return eliminateDeadStores(F, AA, MD, DT, TLI); +#ifdef LLVM_ENABLE_STATS + if (AreStatisticsEnabled()) + for (auto &I : instructions(F)) + NumRemainingStores += isa<StoreInst>(&I); +#endif + + return Changed; } void 
getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); - AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<MemoryDependenceWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<MemoryDependenceWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + + if (EnableMemorySSA) { + AU.addRequired<PostDominatorTreeWrapperPass>(); + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<PostDominatorTreeWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } else { + AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addPreserved<MemoryDependenceWrapperPass>(); + } } }; @@ -1412,8 +2394,10 @@ char DSELegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(DSELegacyPass, "dse", "Dead Store Elimination", false, diff --git a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp index 132dfc8f6da1..d44a5979a8b2 100644 --- a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp +++ b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/PatternMatch.h" @@ -71,6 +72,7 @@ static llvm::Optional<ExpandedMatch> matchExpandedRem(Instruction &I) { return M; } +namespace { /// A thin wrapper to store two values that we matched as div-rem pair. /// We want this extra indirection to avoid dealing with RAUW'ing the map keys. struct DivRemPairWorklistEntry { @@ -111,6 +113,7 @@ struct DivRemPairWorklistEntry { } } }; +} // namespace using DivRemWorklistTy = SmallVector<DivRemPairWorklistEntry, 4>; /// Find matching pairs of integer div/rem ops (they have the same numerator, @@ -218,6 +221,7 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, NumRecomposed++; // Note that we have left ((X / Y) * Y) around. // If it had other uses we could rewrite it as X - X % Y + Changed = true; } assert((!E.isRemExpanded() || !HasDivRemOp) && @@ -301,6 +305,29 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, Mul->insertAfter(RemInst); Sub->insertAfter(Mul); + // If X can be undef, X should be frozen first. + // For example, let's assume that Y = 1 & X = undef: + // %div = sdiv undef, 1 // %div = undef + // %rem = srem undef, 1 // %rem = 0 + // => + // %div = sdiv undef, 1 // %div = undef + // %mul = mul %div, 1 // %mul = undef + // %rem = sub %x, %mul // %rem = undef - undef = undef + // If X is not frozen, %rem becomes undef after transformation. + // TODO: We need a undef-specific checking function in ValueTracking + if (!isGuaranteedNotToBeUndefOrPoison(X, DivInst, &DT)) { + auto *FrX = new FreezeInst(X, X->getName() + ".frozen", DivInst); + DivInst->setOperand(0, FrX); + Sub->setOperand(0, FrX); + } + // Same for Y. 
If X = 1 and Y = (undef | 1), %rem in src is either 1 or 0, + // but %rem in tgt can be one of many integer values. + if (!isGuaranteedNotToBeUndefOrPoison(Y, DivInst, &DT)) { + auto *FrY = new FreezeInst(Y, Y->getName() + ".frozen", DivInst); + DivInst->setOperand(1, FrY); + Mul->setOperand(1, FrY); + } + // Now kill the explicit remainder. We have replaced it with: // (sub X, (mul (div X, Y), Y) Sub->setName(RemInst->getName() + ".decomposed"); diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 40c1ba88354f..ddfc8555b0a0 100644 --- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -41,6 +41,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Statepoint.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" @@ -54,6 +55,7 @@ #include "llvm/Support/RecyclingAllocator.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/GuardUtils.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> @@ -114,7 +116,7 @@ struct SimpleValue { isa<CmpInst>(Inst) || isa<SelectInst>(Inst) || isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst) || isa<ExtractValueInst>(Inst) || - isa<InsertValueInst>(Inst); + isa<InsertValueInst>(Inst) || isa<FreezeInst>(Inst); } }; @@ -152,13 +154,50 @@ static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A, std::swap(A, B); } - // Set flavor if we find a match, or set it to unknown otherwise; in - // either case, return true to indicate that this is a select we can - // process. - if (auto *CmpI = dyn_cast<ICmpInst>(Cond)) - Flavor = matchDecomposedSelectPattern(CmpI, A, B, A, B).Flavor; - else - Flavor = SPF_UNKNOWN; + // Match canonical forms of abs/nabs/min/max. We are not using ValueTracking's + // more powerful matchSelectPattern() because it may rely on instruction flags + // such as "nsw". That would be incompatible with the current hashing + // mechanism that may remove flags to increase the likelihood of CSE. + + // These are the canonical forms of abs(X) and nabs(X) created by instcombine: + // %N = sub i32 0, %X + // %C = icmp slt i32 %X, 0 + // %ABS = select i1 %C, i32 %N, i32 %X + // + // %N = sub i32 0, %X + // %C = icmp slt i32 %X, 0 + // %NABS = select i1 %C, i32 %X, i32 %N + Flavor = SPF_UNKNOWN; + CmpInst::Predicate Pred; + if (match(Cond, m_ICmp(Pred, m_Specific(B), m_ZeroInt())) && + Pred == ICmpInst::ICMP_SLT && match(A, m_Neg(m_Specific(B)))) { + // ABS: B < 0 ? -B : B + Flavor = SPF_ABS; + return true; + } + if (match(Cond, m_ICmp(Pred, m_Specific(A), m_ZeroInt())) && + Pred == ICmpInst::ICMP_SLT && match(B, m_Neg(m_Specific(A)))) { + // NABS: A < 0 ? A : -A + Flavor = SPF_NABS; + return true; + } + + if (!match(Cond, m_ICmp(Pred, m_Specific(A), m_Specific(B)))) { + // Check for commuted variants of min/max by swapping predicate. + // If we do not match the standard or commuted patterns, this is not a + // recognized form of min/max, but it is still a select, so return true. 
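Written out as plain C++, the two canonical select forms quoted above compute the following (a sketch only; the INT_MIN corner case of negation is ignored here):

int absForm(int X) {
  int N = 0 - X;
  bool C = X < 0;
  return C ? N : X; // SPF_ABS: X < 0 ? -X : X
}

int nabsForm(int X) {
  int N = 0 - X;
  bool C = X < 0;
  return C ? X : N; // SPF_NABS: X < 0 ? X : -X
}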
+ if (!match(Cond, m_ICmp(Pred, m_Specific(B), m_Specific(A)))) + return true; + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + switch (Pred) { + case CmpInst::ICMP_UGT: Flavor = SPF_UMAX; break; + case CmpInst::ICMP_ULT: Flavor = SPF_UMIN; break; + case CmpInst::ICMP_SGT: Flavor = SPF_SMAX; break; + case CmpInst::ICMP_SLT: Flavor = SPF_SMIN; break; + default: break; + } return true; } @@ -231,6 +270,9 @@ static unsigned getHashValueImpl(SimpleValue Val) { if (CastInst *CI = dyn_cast<CastInst>(Inst)) return hash_combine(CI->getOpcode(), CI->getType(), CI->getOperand(0)); + if (FreezeInst *FI = dyn_cast<FreezeInst>(Inst)) + return hash_combine(FI->getOpcode(), FI->getOperand(0)); + if (const ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(Inst)) return hash_combine(EVI->getOpcode(), EVI->getOperand(0), hash_combine_range(EVI->idx_begin(), EVI->idx_end())); @@ -242,7 +284,8 @@ static unsigned getHashValueImpl(SimpleValue Val) { assert((isa<CallInst>(Inst) || isa<GetElementPtrInst>(Inst) || isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) || - isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst)) && + isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst) || + isa<FreezeInst>(Inst)) && "Invalid/unknown instruction"); // Mix in the opcode. @@ -414,6 +457,14 @@ template <> struct DenseMapInfo<CallValue> { unsigned DenseMapInfo<CallValue>::getHashValue(CallValue Val) { Instruction *Inst = Val.Inst; + + // gc.relocate is 'special' call: its second and third operands are + // not real values, but indices into statepoint's argument list. + // Get values they point to. + if (const GCRelocateInst *GCR = dyn_cast<GCRelocateInst>(Inst)) + return hash_combine(GCR->getOpcode(), GCR->getOperand(0), + GCR->getBasePtr(), GCR->getDerivedPtr()); + // Hash all of the operands as pointers and mix in the opcode. return hash_combine( Inst->getOpcode(), @@ -424,6 +475,14 @@ bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) { Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst; if (LHS.isSentinel() || RHS.isSentinel()) return LHSI == RHSI; + + // See comment above in `getHashValue()`. 
+ if (const GCRelocateInst *GCR1 = dyn_cast<GCRelocateInst>(LHSI)) + if (const GCRelocateInst *GCR2 = dyn_cast<GCRelocateInst>(RHSI)) + return GCR1->getOperand(0) == GCR2->getOperand(0) && + GCR1->getBasePtr() == GCR2->getBasePtr() && + GCR1->getDerivedPtr() == GCR2->getDerivedPtr(); + return LHSI->isIdenticalTo(RHSI); } @@ -561,8 +620,8 @@ private: public: StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads, InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls, - unsigned cg, DomTreeNode *n, DomTreeNode::iterator child, - DomTreeNode::iterator end) + unsigned cg, DomTreeNode *n, DomTreeNode::const_iterator child, + DomTreeNode::const_iterator end) : CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child), EndIter(end), Scopes(AvailableValues, AvailableLoads, AvailableInvariants, @@ -576,7 +635,7 @@ private: unsigned childGeneration() { return ChildGeneration; } void childGeneration(unsigned generation) { ChildGeneration = generation; } DomTreeNode *node() { return Node; } - DomTreeNode::iterator childIter() { return ChildIter; } + DomTreeNode::const_iterator childIter() { return ChildIter; } DomTreeNode *nextChild() { DomTreeNode *child = *ChildIter; @@ -584,7 +643,7 @@ private: return child; } - DomTreeNode::iterator end() { return EndIter; } + DomTreeNode::const_iterator end() { return EndIter; } bool isProcessed() { return Processed; } void process() { Processed = true; } @@ -592,8 +651,8 @@ private: unsigned CurrentGeneration; unsigned ChildGeneration; DomTreeNode *Node; - DomTreeNode::iterator ChildIter; - DomTreeNode::iterator EndIter; + DomTreeNode::const_iterator ChildIter; + DomTreeNode::const_iterator EndIter; NodeScope Scopes; bool Processed = false; }; @@ -716,7 +775,7 @@ private: bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration, Instruction *EarlierInst, Instruction *LaterInst); - void removeMSSA(Instruction *Inst) { + void removeMSSA(Instruction &Inst) { if (!MSSA) return; if (VerifyMemorySSA) @@ -727,7 +786,7 @@ private: // is handled by MemorySSA when passing OptimizePhis = true to // removeMemoryAccess. The non-optimized MemoryUse case is lazily updated // by MemorySSA's getClobberingMemoryAccess. - MSSAUpdater->removeMemoryAccess(Inst, true); + MSSAUpdater->removeMemoryAccess(&Inst, true); } }; @@ -897,20 +956,19 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // See if any instructions in the block can be eliminated. If so, do it. If // not, add them to AvailableValues. - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { - Instruction *Inst = &*I++; - + for (Instruction &Inst : make_early_inc_range(BB->getInstList())) { // Dead instructions should just be removed. - if (isInstructionTriviallyDead(Inst, &TLI)) { - LLVM_DEBUG(dbgs() << "EarlyCSE DCE: " << *Inst << '\n'); + if (isInstructionTriviallyDead(&Inst, &TLI)) { + LLVM_DEBUG(dbgs() << "EarlyCSE DCE: " << Inst << '\n'); if (!DebugCounter::shouldExecute(CSECounter)) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); continue; } - salvageDebugInfoOrMarkUndef(*Inst); + salvageKnowledge(&Inst, &AC); + salvageDebugInfo(Inst); removeMSSA(Inst); - Inst->eraseFromParent(); + Inst.eraseFromParent(); Changed = true; ++NumSimplify; continue; @@ -920,21 +978,21 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // they're marked as such to ensure preservation of control dependencies), // and this pass will not bother with its removal. However, we should mark // its condition as true for all dominated blocks. 
- if (match(Inst, m_Intrinsic<Intrinsic::assume>())) { + if (match(&Inst, m_Intrinsic<Intrinsic::assume>())) { auto *CondI = - dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0)); + dyn_cast<Instruction>(cast<CallInst>(Inst).getArgOperand(0)); if (CondI && SimpleValue::canHandle(CondI)) { - LLVM_DEBUG(dbgs() << "EarlyCSE considering assumption: " << *Inst + LLVM_DEBUG(dbgs() << "EarlyCSE considering assumption: " << Inst << '\n'); AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext())); } else - LLVM_DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE skipping assumption: " << Inst << '\n'); continue; } // Skip sideeffect intrinsics, for the same reason as assume intrinsics. - if (match(Inst, m_Intrinsic<Intrinsic::sideeffect>())) { - LLVM_DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << *Inst << '\n'); + if (match(&Inst, m_Intrinsic<Intrinsic::sideeffect>())) { + LLVM_DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << Inst << '\n'); continue; } @@ -951,21 +1009,21 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // store 40, i8* p // We can DSE the store to 30, since the store 40 to invariant location p // causes undefined behaviour. - if (match(Inst, m_Intrinsic<Intrinsic::invariant_start>())) { + if (match(&Inst, m_Intrinsic<Intrinsic::invariant_start>())) { // If there are any uses, the scope might end. - if (!Inst->use_empty()) + if (!Inst.use_empty()) continue; - auto *CI = cast<CallInst>(Inst); - MemoryLocation MemLoc = MemoryLocation::getForArgument(CI, 1, TLI); + MemoryLocation MemLoc = + MemoryLocation::getForArgument(&cast<CallInst>(Inst), 1, TLI); // Don't start a scope if we already have a better one pushed if (!AvailableInvariants.count(MemLoc)) AvailableInvariants.insert(MemLoc, CurrentGeneration); continue; } - if (isGuard(Inst)) { + if (isGuard(&Inst)) { if (auto *CondI = - dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0))) { + dyn_cast<Instruction>(cast<CallInst>(Inst).getArgOperand(0))) { if (SimpleValue::canHandle(CondI)) { // Do we already know the actual value of this condition? if (auto *KnownCond = AvailableValues.lookup(CondI)) { @@ -973,14 +1031,15 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { if (isa<ConstantInt>(KnownCond) && cast<ConstantInt>(KnownCond)->isOne()) { LLVM_DEBUG(dbgs() - << "EarlyCSE removing guard: " << *Inst << '\n'); + << "EarlyCSE removing guard: " << Inst << '\n'); + salvageKnowledge(&Inst, &AC); removeMSSA(Inst); - Inst->eraseFromParent(); + Inst.eraseFromParent(); Changed = true; continue; } else // Use the known value if it wasn't true. - cast<CallInst>(Inst)->setArgOperand(0, KnownCond); + cast<CallInst>(Inst).setArgOperand(0, KnownCond); } // The condition we're on guarding here is true for all dominated // locations. @@ -997,20 +1056,21 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // If the instruction can be simplified (e.g. X+0 = X) then replace it with // its simpler value. 
- if (Value *V = SimplifyInstruction(Inst, SQ)) { - LLVM_DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << " to: " << *V + if (Value *V = SimplifyInstruction(&Inst, SQ)) { + LLVM_DEBUG(dbgs() << "EarlyCSE Simplify: " << Inst << " to: " << *V << '\n'); if (!DebugCounter::shouldExecute(CSECounter)) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); } else { bool Killed = false; - if (!Inst->use_empty()) { - Inst->replaceAllUsesWith(V); + if (!Inst.use_empty()) { + Inst.replaceAllUsesWith(V); Changed = true; } - if (isInstructionTriviallyDead(Inst, &TLI)) { + if (isInstructionTriviallyDead(&Inst, &TLI)) { + salvageKnowledge(&Inst, &AC); removeMSSA(Inst); - Inst->eraseFromParent(); + Inst.eraseFromParent(); Changed = true; Killed = true; } @@ -1022,31 +1082,32 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { } // If this is a simple instruction that we can value number, process it. - if (SimpleValue::canHandle(Inst)) { + if (SimpleValue::canHandle(&Inst)) { // See if the instruction has an available value. If so, use it. - if (Value *V = AvailableValues.lookup(Inst)) { - LLVM_DEBUG(dbgs() << "EarlyCSE CSE: " << *Inst << " to: " << *V + if (Value *V = AvailableValues.lookup(&Inst)) { + LLVM_DEBUG(dbgs() << "EarlyCSE CSE: " << Inst << " to: " << *V << '\n'); if (!DebugCounter::shouldExecute(CSECounter)) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); continue; } if (auto *I = dyn_cast<Instruction>(V)) - I->andIRFlags(Inst); - Inst->replaceAllUsesWith(V); + I->andIRFlags(&Inst); + Inst.replaceAllUsesWith(V); + salvageKnowledge(&Inst, &AC); removeMSSA(Inst); - Inst->eraseFromParent(); + Inst.eraseFromParent(); Changed = true; ++NumCSE; continue; } // Otherwise, just remember that this value is available. - AvailableValues.insert(Inst, Inst); + AvailableValues.insert(&Inst, &Inst); continue; } - ParseMemoryInst MemInst(Inst, TTI); + ParseMemoryInst MemInst(&Inst, TTI); // If this is a non-volatile load, process it. if (MemInst.isValid() && MemInst.isLoad()) { // (conservatively) we can't peak past the ordering implied by this @@ -1062,7 +1123,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // We conservatively treat the invariant_load as that moment. If we // pass a invariant load after already establishing a scope, don't // restart it since we want to preserve the earliest point seen. - auto MemLoc = MemoryLocation::get(Inst); + auto MemLoc = MemoryLocation::get(&Inst); if (!AvailableInvariants.count(MemLoc)) AvailableInvariants.insert(MemLoc, CurrentGeneration); } @@ -1081,21 +1142,22 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { !MemInst.isVolatile() && MemInst.isUnordered() && // We can't replace an atomic load with one which isn't also atomic. 
InVal.IsAtomic >= MemInst.isAtomic() && - (isOperatingOnInvariantMemAt(Inst, InVal.Generation) || + (isOperatingOnInvariantMemAt(&Inst, InVal.Generation) || isSameMemGeneration(InVal.Generation, CurrentGeneration, - InVal.DefInst, Inst))) { - Value *Op = getOrCreateResult(InVal.DefInst, Inst->getType()); + InVal.DefInst, &Inst))) { + Value *Op = getOrCreateResult(InVal.DefInst, Inst.getType()); if (Op != nullptr) { - LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << *Inst + LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << Inst << " to: " << *InVal.DefInst << '\n'); if (!DebugCounter::shouldExecute(CSECounter)) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); continue; } - if (!Inst->use_empty()) - Inst->replaceAllUsesWith(Op); + if (!Inst.use_empty()) + Inst.replaceAllUsesWith(Op); + salvageKnowledge(&Inst, &AC); removeMSSA(Inst); - Inst->eraseFromParent(); + Inst.eraseFromParent(); Changed = true; ++NumCSELoad; continue; @@ -1103,10 +1165,10 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { } // Otherwise, remember that we have this instruction. - AvailableLoads.insert( - MemInst.getPointerOperand(), - LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(), - MemInst.isAtomic())); + AvailableLoads.insert(MemInst.getPointerOperand(), + LoadValue(&Inst, CurrentGeneration, + MemInst.getMatchingId(), + MemInst.isAtomic())); LastStore = nullptr; continue; } @@ -1117,36 +1179,36 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // may override this (e.g. so that a store intrinsic does not read from // memory, and thus will be treated the same as a regular store for // commoning purposes). - if ((Inst->mayReadFromMemory() || Inst->mayThrow()) && + if ((Inst.mayReadFromMemory() || Inst.mayThrow()) && !(MemInst.isValid() && !MemInst.mayReadFromMemory())) LastStore = nullptr; // If this is a read-only call, process it. - if (CallValue::canHandle(Inst)) { + if (CallValue::canHandle(&Inst)) { // If we have an available version of this call, and if it is the right // generation, replace this instruction. - std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(Inst); + std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(&Inst); if (InVal.first != nullptr && isSameMemGeneration(InVal.second, CurrentGeneration, InVal.first, - Inst)) { - LLVM_DEBUG(dbgs() << "EarlyCSE CSE CALL: " << *Inst + &Inst)) { + LLVM_DEBUG(dbgs() << "EarlyCSE CSE CALL: " << Inst << " to: " << *InVal.first << '\n'); if (!DebugCounter::shouldExecute(CSECounter)) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); continue; } - if (!Inst->use_empty()) - Inst->replaceAllUsesWith(InVal.first); + if (!Inst.use_empty()) + Inst.replaceAllUsesWith(InVal.first); + salvageKnowledge(&Inst, &AC); removeMSSA(Inst); - Inst->eraseFromParent(); + Inst.eraseFromParent(); Changed = true; ++NumCSECall; continue; } // Otherwise, remember that we have this instruction. - AvailableCalls.insert( - Inst, std::pair<Instruction *, unsigned>(Inst, CurrentGeneration)); + AvailableCalls.insert(&Inst, std::make_pair(&Inst, CurrentGeneration)); continue; } @@ -1155,9 +1217,9 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // result, we don't need to consider it as writing to memory and don't need // to advance the generation. We do need to prevent DSE across the fence, // but that's handled above. 
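The load, call, and store decisions above all hinge on the generation counter: anything that may write to memory bumps it, and a cached result is trusted only when its recording generation still matches (or MemorySSA proves no intervening clobber). A small conceptual sketch of that bookkeeping in plain C++, with hypothetical names rather than the pass's real data structures:

    #include <cstdint>

    // What gets remembered for an available load or read-only call.
    struct AvailableResult {
      const void *DefInst = nullptr; // producing instruction (opaque here)
      uint64_t Generation = 0;       // memory generation when recorded
    };

    // Conservative reuse test when MemorySSA is not consulted: nothing that
    // may write to memory has executed since the result was recorded.
    static bool isReusable(const AvailableResult &R, uint64_t CurrentGeneration) {
      return R.DefInst != nullptr && R.Generation == CurrentGeneration;
    }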
- if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) + if (auto *FI = dyn_cast<FenceInst>(&Inst)) if (FI->getOrdering() == AtomicOrdering::Release) { - assert(Inst->mayReadFromMemory() && "relied on to prevent DSE above"); + assert(Inst.mayReadFromMemory() && "relied on to prevent DSE above"); continue; } @@ -1169,13 +1231,13 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { if (MemInst.isValid() && MemInst.isStore()) { LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand()); if (InVal.DefInst && - InVal.DefInst == getOrCreateResult(Inst, InVal.DefInst->getType()) && + InVal.DefInst == getOrCreateResult(&Inst, InVal.DefInst->getType()) && InVal.MatchingId == MemInst.getMatchingId() && // We don't yet handle removing stores with ordering of any kind. !MemInst.isVolatile() && MemInst.isUnordered() && - (isOperatingOnInvariantMemAt(Inst, InVal.Generation) || + (isOperatingOnInvariantMemAt(&Inst, InVal.Generation) || isSameMemGeneration(InVal.Generation, CurrentGeneration, - InVal.DefInst, Inst))) { + InVal.DefInst, &Inst))) { // It is okay to have a LastStore to a different pointer here if MemorySSA // tells us that the load and store are from the same memory generation. // In that case, LastStore should keep its present value since we're @@ -1185,13 +1247,14 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { MemInst.getPointerOperand() || MSSA) && "can't have an intervening store if not using MemorySSA!"); - LLVM_DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << Inst << '\n'); if (!DebugCounter::shouldExecute(CSECounter)) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); continue; } + salvageKnowledge(&Inst, &AC); removeMSSA(Inst); - Inst->eraseFromParent(); + Inst.eraseFromParent(); Changed = true; ++NumDSE; // We can avoid incrementing the generation count since we were able @@ -1203,7 +1266,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // Okay, this isn't something we can CSE at all. Check to see if it is // something that could modify memory. If so, our available memory values // cannot be used so bump the generation count. - if (Inst->mayWriteToMemory()) { + if (Inst.mayWriteToMemory()) { ++CurrentGeneration; if (MemInst.isValid() && MemInst.isStore()) { @@ -1221,11 +1284,12 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { "Violated invariant"); if (LastStoreMemInst.isMatchingMemLoc(MemInst)) { LLVM_DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore - << " due to: " << *Inst << '\n'); + << " due to: " << Inst << '\n'); if (!DebugCounter::shouldExecute(CSECounter)) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); } else { - removeMSSA(LastStore); + salvageKnowledge(&Inst, &AC); + removeMSSA(*LastStore); LastStore->eraseFromParent(); Changed = true; ++NumDSE; @@ -1240,10 +1304,10 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // version of the pointer. It is safe to forward from volatile stores // to non-volatile loads, so we don't have to check for volatility of // the store. - AvailableLoads.insert( - MemInst.getPointerOperand(), - LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(), - MemInst.isAtomic())); + AvailableLoads.insert(MemInst.getPointerOperand(), + LoadValue(&Inst, CurrentGeneration, + MemInst.getMatchingId(), + MemInst.isAtomic())); // Remember that this was the last unordered store we saw for DSE. 
We // don't yet handle DSE on ordered or volatile stores since we don't @@ -1253,7 +1317,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // it's not clear this is a profitable transform. Another option would // be to merge the ordering with that of the post dominating store. if (MemInst.isUnordered() && !MemInst.isVolatile()) - LastStore = Inst; + LastStore = &Inst; else LastStore = nullptr; } diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp index af223cc837f2..83f4c402ed4d 100644 --- a/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -120,8 +120,7 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) { // Find the roots - instructions that convert from the FP domain to // integer domain. -void Float2IntPass::findRoots(Function &F, const DominatorTree &DT, - SmallPtrSet<Instruction*,8> &Roots) { +void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) { for (BasicBlock &BB : F) { // Unreachable code can take on strange forms that we are not prepared to // handle. For example, an instruction may have itself as an operand. @@ -184,7 +183,7 @@ ConstantRange Float2IntPass::validateRange(ConstantRange R) { // Breadth-first walk of the use-def graph; determine the set of nodes // we care about and eagerly determine if some of them are poisonous. -void Float2IntPass::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) { +void Float2IntPass::walkBackwards() { std::deque<Instruction*> Worklist(Roots.begin(), Roots.end()); while (!Worklist.empty()) { Instruction *I = Worklist.back(); @@ -327,7 +326,7 @@ void Float2IntPass::walkForwards() { APFloat NewF = F; auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven); - if (Res != APFloat::opOK || NewF.compare(F) != APFloat::cmpEqual) { + if (Res != APFloat::opOK || NewF != F) { seen(I, badRange()); Abort = true; break; @@ -525,9 +524,9 @@ bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) { Ctx = &F.getParent()->getContext(); - findRoots(F, DT, Roots); + findRoots(F, DT); - walkBackwards(Roots); + walkBackwards(); walkForwards(); bool Modified = validateAndTransform(); diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index 1e6aab14e7b4..b16f8591b5a4 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" @@ -42,7 +43,6 @@ #include "llvm/Config/llvm-config.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -72,6 +72,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" @@ -97,10 +98,11 @@ STATISTIC(NumGVNSimpl, "Number of instructions simplified"); STATISTIC(NumGVNEqProp, "Number of equalities propagated"); STATISTIC(NumPRELoad, "Number of loads PRE'd"); -static cl::opt<bool> EnablePRE("enable-pre", - cl::init(true), cl::Hidden); -static cl::opt<bool> EnableLoadPRE("enable-load-pre", 
cl::init(true)); -static cl::opt<bool> EnableMemDep("enable-gvn-memdep", cl::init(true)); +static cl::opt<bool> GVNEnablePRE("enable-pre", cl::init(true), cl::Hidden); +static cl::opt<bool> GVNEnableLoadPRE("enable-load-pre", cl::init(true)); +static cl::opt<bool> GVNEnableLoadInLoopPRE("enable-load-in-loop-pre", + cl::init(true)); +static cl::opt<bool> GVNEnableMemDep("enable-gvn-memdep", cl::init(true)); // Maximum allowed recursion depth. static cl::opt<uint32_t> @@ -113,8 +115,8 @@ static cl::opt<uint32_t> MaxNumDeps( struct llvm::GVN::Expression { uint32_t opcode; - Type *type = nullptr; bool commutative = false; + Type *type = nullptr; SmallVector<uint32_t, 4> varargs; Expression(uint32_t o = ~2U) : opcode(o) {} @@ -288,7 +290,7 @@ GVN::Expression GVN::ValueTable::createExpr(Instruction *I) { e.commutative = true; } - if (CmpInst *C = dyn_cast<CmpInst>(I)) { + if (auto *C = dyn_cast<CmpInst>(I)) { // Sort the operand value numbers so x<y and y>x get the same value number. CmpInst::Predicate Predicate = C->getPredicate(); if (e.varargs[0] > e.varargs[1]) { @@ -297,10 +299,11 @@ GVN::Expression GVN::ValueTable::createExpr(Instruction *I) { } e.opcode = (C->getOpcode() << 8) | Predicate; e.commutative = true; - } else if (InsertValueInst *E = dyn_cast<InsertValueInst>(I)) { - for (InsertValueInst::idx_iterator II = E->idx_begin(), IE = E->idx_end(); - II != IE; ++II) - e.varargs.push_back(*II); + } else if (auto *E = dyn_cast<InsertValueInst>(I)) { + e.varargs.append(E->idx_begin(), E->idx_end()); + } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) { + ArrayRef<int> ShuffleMask = SVI->getShuffleMask(); + e.varargs.append(ShuffleMask.begin(), ShuffleMask.end()); } return e; @@ -530,6 +533,7 @@ uint32_t GVN::ValueTable::lookupOrAdd(Value *V) { case Instruction::AddrSpaceCast: case Instruction::BitCast: case Instruction::Select: + case Instruction::Freeze: case Instruction::ExtractElement: case Instruction::InsertElement: case Instruction::ShuffleVector: @@ -610,6 +614,22 @@ void GVN::ValueTable::verifyRemoved(const Value *V) const { // GVN Pass //===----------------------------------------------------------------------===// +bool GVN::isPREEnabled() const { + return Options.AllowPRE.getValueOr(GVNEnablePRE); +} + +bool GVN::isLoadPREEnabled() const { + return Options.AllowLoadPRE.getValueOr(GVNEnableLoadPRE); +} + +bool GVN::isLoadInLoopPREEnabled() const { + return Options.AllowLoadInLoopPRE.getValueOr(GVNEnableLoadInLoopPRE); +} + +bool GVN::isMemDepEnabled() const { + return Options.AllowMemDep.getValueOr(GVNEnableMemDep); +} + PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) { // FIXME: The order of evaluation of these 'getResult' calls is very // significant! Re-ordering these variables will cause GVN when run alone to @@ -619,10 +639,11 @@ PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); - auto &MemDep = AM.getResult<MemoryDependenceAnalysis>(F); + auto *MemDep = + isMemDepEnabled() ? 
&AM.getResult<MemoryDependenceAnalysis>(F) : nullptr; auto *LI = AM.getCachedResult<LoopAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); - bool Changed = runImpl(F, AC, DT, TLI, AA, &MemDep, LI, &ORE); + bool Changed = runImpl(F, AC, DT, TLI, AA, MemDep, LI, &ORE); if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; @@ -927,6 +948,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // Loading the allocation -> undef. if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) || + isAlignedAllocLikeFn(DepInst, TLI) || // Loading immediately after lifetime begin -> undef. isLifetimeStart(DepInst)) { Res = AvailableValue::get(UndefValue::get(LI->getType())); @@ -1245,7 +1267,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, auto *NewLoad = new LoadInst( LI->getType(), LoadPtr, LI->getName() + ".pre", LI->isVolatile(), - MaybeAlign(LI->getAlignment()), LI->getOrdering(), LI->getSyncScopeID(), + LI->getAlign(), LI->getOrdering(), LI->getSyncScopeID(), UnavailablePred->getTerminator()); NewLoad->setDebugLoc(LI->getDebugLoc()); @@ -1383,7 +1405,10 @@ bool GVN::processNonLocalLoad(LoadInst *LI) { } // Step 4: Eliminate partial redundancy. - if (!EnablePRE || !EnableLoadPRE) + if (!isPREEnabled() || !isLoadPREEnabled()) + return false; + if (!isLoadInLoopPREEnabled() && this->LI && + this->LI->getLoopFor(LI->getParent())) return false; return PerformLoadPRE(LI, ValuesPerBlock, UnavailableBlocks); @@ -1428,7 +1453,7 @@ static bool impliesEquivalanceIfFalse(CmpInst* Cmp) { Value *LHS = Cmp->getOperand(0); Value *RHS = Cmp->getOperand(1); // If we can prove either side non-zero, then equality must imply - // equivalence. + // equivalence. // FIXME: We should do this optimization if 'no signed zeros' is // applicable via an instruction-level fast-math-flag or some other // indicator that relaxed FP semantics are being used. @@ -1465,7 +1490,8 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { Constant::getNullValue(Int8Ty->getPointerTo()), IntrinsicI); } - markInstructionForDeletion(IntrinsicI); + if (isAssumeWithEmptyBundle(*IntrinsicI)) + markInstructionForDeletion(IntrinsicI); return false; } else if (isa<Constant>(V)) { // If it's not false, and constant, it must evaluate to true. This means our @@ -1493,10 +1519,10 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { // If we find an equality fact, canonicalize all dominated uses in this block // to one of the two values. We heuristically choice the "oldest" of the // two where age is determined by value number. (Note that propagateEquality - // above handles the cross block case.) - // + // above handles the cross block case.) + // // Key case to cover are: - // 1) + // 1) // %cmp = fcmp oeq float 3.000000e+00, %0 ; const on lhs could happen // call void @llvm.assume(i1 %cmp) // ret float %0 ; will change it to ret float 3.000000e+00 @@ -1537,7 +1563,7 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { << *CmpLHS << " with " << *CmpRHS << " in block " << IntrinsicI->getParent()->getName() << "\n"); - + // Setup the replacement map - this handles uses within the same block if (hasUsersIn(CmpLHS, IntrinsicI->getParent())) @@ -1710,7 +1736,8 @@ uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred, // instead of value numbers. Those index numbers should not be // translated. 
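The comment above notes that some expression operands are literal indices or shuffle-mask elements rather than value numbers; the check that follows skips them during phi translation. A sketch of that classification as a hypothetical standalone predicate:

    #include "llvm/IR/Instruction.h"

    // True if operand i of an expression with this opcode holds a value
    // number that phi translation is allowed to rewrite.
    static bool isTranslatableOperand(unsigned Opcode, unsigned i) {
      if (Opcode == llvm::Instruction::InsertValue && i > 1)
        return false; // operands 2+ are aggregate indices
      if (Opcode == llvm::Instruction::ExtractValue && i > 0)
        return false; // operands 1+ are aggregate indices
      if (Opcode == llvm::Instruction::ShuffleVector && i > 1)
        return false; // operands 2+ are shuffle-mask elements
      return true;
    }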
if ((i > 1 && Exp.opcode == Instruction::InsertValue) || - (i > 0 && Exp.opcode == Instruction::ExtractValue)) + (i > 0 && Exp.opcode == Instruction::ExtractValue) || + (i > 1 && Exp.opcode == Instruction::ShuffleVector)) continue; Exp.varargs[i] = phiTranslate(Pred, PhiBlock, Exp.varargs[i], Gvn); } @@ -1802,7 +1829,7 @@ void GVN::assignBlockRPONumber(Function &F) { bool GVN::replaceOperandsForInBlockEquality(Instruction *Instr) const { bool Changed = false; for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) { - Value *Operand = Instr->getOperand(OpNum); + Value *Operand = Instr->getOperand(OpNum); auto it = ReplaceOperandsWithMap.find(Operand); if (it != ReplaceOperandsWithMap.end()) { LLVM_DEBUG(dbgs() << "GVN replacing: " << *Operand << " with " @@ -1922,7 +1949,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, // If "A == B" is known true, or "A != B" is known false, then replace // A with B everywhere in the scope. For floating point operations, we - // have to be careful since equality does not always imply equivalance. + // have to be careful since equality does not always imply equivalance. if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) || (isKnownFalse && impliesEquivalanceIfFalse(Cmp))) Worklist.push_back(std::make_pair(Op0, Op1)); @@ -2117,7 +2144,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, TLI = &RunTLI; VN.setAliasAnalysis(&RunAA); MD = RunMD; - ImplicitControlFlowTracking ImplicitCFT(DT); + ImplicitControlFlowTracking ImplicitCFT; ICF = &ImplicitCFT; this->LI = LI; VN.setMemDep(MD); @@ -2148,7 +2175,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, ++Iteration; } - if (EnablePRE) { + if (isPREEnabled()) { // Fabricate val-num for dead-code in order to suppress assertion in // performPRE(). assignValNumForDeadCode(); @@ -2206,6 +2233,7 @@ bool GVN::processBlock(BasicBlock *BB) { for (auto *I : InstrsToErase) { assert(I->getParent() == BB && "Removing instruction from wrong block?"); LLVM_DEBUG(dbgs() << "GVN removed: " << *I << '\n'); + salvageKnowledge(I, AC); salvageDebugInfo(*I); if (MD) MD->removeInstruction(I); LLVM_DEBUG(verifyRemoved(I)); @@ -2478,8 +2506,11 @@ bool GVN::performPRE(Function &F) { /// Split the critical edge connecting the given two blocks, and return /// the block inserted to the critical edge. BasicBlock *GVN::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) { - BasicBlock *BB = - SplitCriticalEdge(Pred, Succ, CriticalEdgeSplittingOptions(DT, LI)); + // GVN does not require loop-simplify, do not try to preserve it if it is not + // possible. + BasicBlock *BB = SplitCriticalEdge( + Pred, Succ, + CriticalEdgeSplittingOptions(DT, LI).unsetPreserveLoopSimplify()); if (MD) MD->invalidateCachedPredecessors(); InvalidBlockRPONumbers = true; @@ -2682,8 +2713,8 @@ class llvm::gvn::GVNLegacyPass : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid - explicit GVNLegacyPass(bool NoMemDepAnalysis = !EnableMemDep) - : FunctionPass(ID), NoMemDepAnalysis(NoMemDepAnalysis) { + explicit GVNLegacyPass(bool NoMemDepAnalysis = !GVNEnableMemDep) + : FunctionPass(ID), Impl(GVNOptions().setMemDep(!NoMemDepAnalysis)) { initializeGVNLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -2698,9 +2729,9 @@ public: getAnalysis<DominatorTreeWrapperPass>().getDomTree(), getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F), getAnalysis<AAResultsWrapperPass>().getAAResults(), - NoMemDepAnalysis - ? 
nullptr - : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(), + Impl.isMemDepEnabled() + ? &getAnalysis<MemoryDependenceWrapperPass>().getMemDep() + : nullptr, LIWP ? &LIWP->getLoopInfo() : nullptr, &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE()); } @@ -2710,7 +2741,7 @@ public: AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); - if (!NoMemDepAnalysis) + if (Impl.isMemDepEnabled()) AU.addRequired<MemoryDependenceWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); @@ -2718,12 +2749,10 @@ public: AU.addPreserved<GlobalsAAWrapperPass>(); AU.addPreserved<TargetLibraryInfoWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); - AU.addPreservedID(LoopSimplifyID); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } private: - bool NoMemDepAnalysis; GVN Impl; }; diff --git a/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/llvm/lib/Transforms/Scalar/GVNHoist.cpp index e1796f6bf05a..9c4cdf2feb56 100644 --- a/llvm/lib/Transforms/Scalar/GVNHoist.cpp +++ b/llvm/lib/Transforms/Scalar/GVNHoist.cpp @@ -890,18 +890,16 @@ private: void updateAlignment(Instruction *I, Instruction *Repl) { if (auto *ReplacementLoad = dyn_cast<LoadInst>(Repl)) { - ReplacementLoad->setAlignment(MaybeAlign(std::min( - ReplacementLoad->getAlignment(), cast<LoadInst>(I)->getAlignment()))); + ReplacementLoad->setAlignment( + std::min(ReplacementLoad->getAlign(), cast<LoadInst>(I)->getAlign())); ++NumLoadsRemoved; } else if (auto *ReplacementStore = dyn_cast<StoreInst>(Repl)) { - ReplacementStore->setAlignment( - MaybeAlign(std::min(ReplacementStore->getAlignment(), - cast<StoreInst>(I)->getAlignment()))); + ReplacementStore->setAlignment(std::min(ReplacementStore->getAlign(), + cast<StoreInst>(I)->getAlign())); ++NumStoresRemoved; } else if (auto *ReplacementAlloca = dyn_cast<AllocaInst>(Repl)) { - ReplacementAlloca->setAlignment( - MaybeAlign(std::max(ReplacementAlloca->getAlignment(), - cast<AllocaInst>(I)->getAlignment()))); + ReplacementAlloca->setAlignment(std::max( + ReplacementAlloca->getAlign(), cast<AllocaInst>(I)->getAlign())); } else if (isa<CallInst>(Repl)) { ++NumCallsRemoved; } diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp index 6d0a4975e266..dfb4b7e038ba 100644 --- a/llvm/lib/Transforms/Scalar/GVNSink.cpp +++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -350,6 +350,7 @@ using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>; class InstructionUseExpr : public GVNExpression::BasicExpression { unsigned MemoryUseOrder = -1; bool Volatile = false; + ArrayRef<int> ShuffleMask; public: InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R, @@ -359,6 +360,9 @@ public: setOpcode(I->getOpcode()); setType(I->getType()); + if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(I)) + ShuffleMask = SVI->getShuffleMask().copy(A); + for (auto &U : I->uses()) op_push_back(U.getUser()); llvm::sort(op_begin(), op_end()); @@ -369,12 +373,12 @@ public: hash_code getHashValue() const override { return hash_combine(GVNExpression::BasicExpression::getHashValue(), - MemoryUseOrder, Volatile); + MemoryUseOrder, Volatile, ShuffleMask); } template <typename Function> hash_code getHashValue(Function MapFn) { - hash_code H = - hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile); + hash_code H = hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile, + ShuffleMask); for (auto *V : operands()) H = hash_combine(H, MapFn(V)); return H; @@ -475,6 +479,7 @@ 
public: case Instruction::PtrToInt: case Instruction::IntToPtr: case Instruction::BitCast: + case Instruction::AddrSpaceCast: case Instruction::Select: case Instruction::ExtractElement: case Instruction::InsertElement: @@ -576,7 +581,7 @@ public: private: ValueTable VN; - bool isInstructionBlacklisted(Instruction *I) { + bool shouldAvoidSinkingInstruction(Instruction *I) { // These instructions may change or break semantics if moved. if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) || I->getType()->isTokenTy()) @@ -668,7 +673,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( NewInsts.push_back(I); } for (auto *I : NewInsts) - if (isInstructionBlacklisted(I)) + if (shouldAvoidSinkingInstruction(I)) return None; // If we've restricted the incoming blocks, restrict all needed PHIs also diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index d8d7acae5c9f..0f36c3f772e6 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -38,8 +38,9 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -81,6 +82,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" #include <cassert> #include <cstdint> @@ -100,10 +102,10 @@ STATISTIC(NumElimIV , "Number of congruent IVs eliminated"); // implement a strong expression equivalence checker in SCEV. Until then, we // use the verify-indvars flag, which may assert in some cases. static cl::opt<bool> VerifyIndvars( - "verify-indvars", cl::Hidden, - cl::desc("Verify the ScalarEvolution result after running indvars")); - -enum ReplaceExitVal { NeverRepl, OnlyCheapRepl, NoHardUse, AlwaysRepl }; + "verify-indvars", cl::Hidden, + cl::desc("Verify the ScalarEvolution result after running indvars. Has no " + "effect in release builds. (Note: this adds additional SCEV " + "queries potentially changing the analysis result)")); static cl::opt<ReplaceExitVal> ReplaceExitValue( "replexitval", cl::Hidden, cl::init(OnlyCheapRepl), @@ -140,11 +142,10 @@ class IndVarSimplify { const DataLayout &DL; TargetLibraryInfo *TLI; const TargetTransformInfo *TTI; + std::unique_ptr<MemorySSAUpdater> MSSAU; SmallVector<WeakTrackingVH, 16> DeadInsts; - bool isValidRewrite(Value *FromVal, Value *ToVal); - bool handleFloatingPointIV(Loop *L, PHINode *PH); bool rewriteNonIntegerIVs(Loop *L); @@ -155,10 +156,7 @@ class IndVarSimplify { /// iterations of the loop run when that is unobservable. 
bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter); - bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet); - bool rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter); bool rewriteFirstIterationLoopExitValues(Loop *L); - bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) const; bool linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, const SCEV *ExitCount, @@ -169,66 +167,17 @@ class IndVarSimplify { public: IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, const DataLayout &DL, TargetLibraryInfo *TLI, - TargetTransformInfo *TTI) - : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) {} + TargetTransformInfo *TTI, MemorySSA *MSSA) + : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) { + if (MSSA) + MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); + } bool run(Loop *L); }; } // end anonymous namespace -/// Return true if the SCEV expansion generated by the rewriter can replace the -/// original value. SCEV guarantees that it produces the same value, but the way -/// it is produced may be illegal IR. Ideally, this function will only be -/// called for verification. -bool IndVarSimplify::isValidRewrite(Value *FromVal, Value *ToVal) { - // If an SCEV expression subsumed multiple pointers, its expansion could - // reassociate the GEP changing the base pointer. This is illegal because the - // final address produced by a GEP chain must be inbounds relative to its - // underlying object. Otherwise basic alias analysis, among other things, - // could fail in a dangerous way. Ultimately, SCEV will be improved to avoid - // producing an expression involving multiple pointers. Until then, we must - // bail out here. - // - // Retrieve the pointer operand of the GEP. Don't use GetUnderlyingObject - // because it understands lcssa phis while SCEV does not. - Value *FromPtr = FromVal; - Value *ToPtr = ToVal; - if (auto *GEP = dyn_cast<GEPOperator>(FromVal)) { - FromPtr = GEP->getPointerOperand(); - } - if (auto *GEP = dyn_cast<GEPOperator>(ToVal)) { - ToPtr = GEP->getPointerOperand(); - } - if (FromPtr != FromVal || ToPtr != ToVal) { - // Quickly check the common case - if (FromPtr == ToPtr) - return true; - - // SCEV may have rewritten an expression that produces the GEP's pointer - // operand. That's ok as long as the pointer operand has the same base - // pointer. Unlike GetUnderlyingObject(), getPointerBase() will find the - // base of a recurrence. This handles the case in which SCEV expansion - // converts a pointer type recurrence into a nonrecurrent pointer base - // indexed by an integer recurrence. - - // If the GEP base pointer is a vector of pointers, abort. - if (!FromPtr->getType()->isPointerTy() || !ToPtr->getType()->isPointerTy()) - return false; - - const SCEV *FromBase = SE->getPointerBase(SE->getSCEV(FromPtr)); - const SCEV *ToBase = SE->getPointerBase(SE->getSCEV(ToPtr)); - if (FromBase == ToBase) - return true; - - LLVM_DEBUG(dbgs() << "INDVARS: GEP rewrite bail out " << *FromBase - << " != " << *ToBase << "\n"); - - return false; - } - return true; -} - /// Determine the insertion point for this user. By default, insert immediately /// before the user. SCEVExpander or LICM will hoist loop invariants out of the /// loop. For PHI nodes, there may be multiple uses, so compute the nearest @@ -477,11 +426,11 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { // new comparison. 
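The IndVarSimplify constructor above now accepts MemorySSA and wraps it in a MemorySSAUpdater only when the analysis is actually available. A minimal sketch of that pattern (hypothetical struct, using the same headers the hunk adds):

    #include <memory>

    #include "llvm/Analysis/MemorySSA.h"
    #include "llvm/Analysis/MemorySSAUpdater.h"

    struct LoopPassState {
      std::unique_ptr<llvm::MemorySSAUpdater> MSSAU;

      explicit LoopPassState(llvm::MemorySSA *MSSA) {
        if (MSSA) // MemorySSA is optional; build the updater only when present.
          MSSAU = std::make_unique<llvm::MemorySSAUpdater>(MSSA);
      }
    };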
NewCompare->takeName(Compare); Compare->replaceAllUsesWith(NewCompare); - RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI); + RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI, MSSAU.get()); // Delete the old floating point increment. Incr->replaceAllUsesWith(UndefValue::get(Incr->getType())); - RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI); + RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI, MSSAU.get()); // If the FP induction variable still has uses, this is because something else // in the loop uses its value. In order to canonicalize the induction @@ -494,7 +443,7 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { Value *Conv = new SIToFPInst(NewPHI, PN->getType(), "indvar.conv", &*PN->getParent()->getFirstInsertionPt()); PN->replaceAllUsesWith(Conv); - RecursivelyDeleteTriviallyDeadInstructions(PN, TLI); + RecursivelyDeleteTriviallyDeadInstructions(PN, TLI, MSSAU.get()); } return true; } @@ -522,222 +471,6 @@ bool IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { return Changed; } -namespace { - -// Collect information about PHI nodes which can be transformed in -// rewriteLoopExitValues. -struct RewritePhi { - PHINode *PN; - - // Ith incoming value. - unsigned Ith; - - // Exit value after expansion. - Value *Val; - - // High Cost when expansion. - bool HighCost; - - RewritePhi(PHINode *P, unsigned I, Value *V, bool H) - : PN(P), Ith(I), Val(V), HighCost(H) {} -}; - -} // end anonymous namespace - -//===----------------------------------------------------------------------===// -// rewriteLoopExitValues - Optimize IV users outside the loop. -// As a side effect, reduces the amount of IV processing within the loop. -//===----------------------------------------------------------------------===// - -bool IndVarSimplify::hasHardUserWithinLoop(const Loop *L, const Instruction *I) const { - SmallPtrSet<const Instruction *, 8> Visited; - SmallVector<const Instruction *, 8> WorkList; - Visited.insert(I); - WorkList.push_back(I); - while (!WorkList.empty()) { - const Instruction *Curr = WorkList.pop_back_val(); - // This use is outside the loop, nothing to do. - if (!L->contains(Curr)) - continue; - // Do we assume it is a "hard" use which will not be eliminated easily? - if (Curr->mayHaveSideEffects()) - return true; - // Otherwise, add all its users to worklist. - for (auto U : Curr->users()) { - auto *UI = cast<Instruction>(U); - if (Visited.insert(UI).second) - WorkList.push_back(UI); - } - } - return false; -} - -/// Check to see if this loop has a computable loop-invariant execution count. -/// If so, this means that we can compute the final value of any expressions -/// that are recurrent in the loop, and substitute the exit values from the loop -/// into any instructions outside of the loop that use the final values of the -/// current expressions. -/// -/// This is mostly redundant with the regular IndVarSimplify activities that -/// happen later, except that it's more powerful in some cases, because it's -/// able to brute-force evaluate arbitrary instructions as long as they have -/// constant operands at the beginning of the loop. -bool IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { - // Check a pre-condition. - assert(L->isRecursivelyLCSSAForm(*DT, *LI) && - "Indvars did not preserve LCSSA!"); - - SmallVector<BasicBlock*, 8> ExitBlocks; - L->getUniqueExitBlocks(ExitBlocks); - - SmallVector<RewritePhi, 8> RewritePhiSet; - // Find all values that are computed inside the loop, but used outside of it. 
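handleFloatingPointIV above now hands the MemorySSAUpdater to the cleanup utilities so MemorySSA stays consistent while the replaced floating-point IV instructions are erased. A sketch of that replace-then-clean-up idiom as a standalone helper (hypothetical name):

    #include "llvm/Analysis/MemorySSAUpdater.h"
    #include "llvm/Analysis/TargetLibraryInfo.h"
    #include "llvm/IR/Instruction.h"
    #include "llvm/Transforms/Utils/Local.h"

    using namespace llvm;

    // Replace Old with New (types must match) and erase Old plus any operands
    // that become dead; the updater, if provided, keeps MemorySSA in sync.
    static void replaceAndClean(Instruction *Old, Value *New,
                                const TargetLibraryInfo *TLI,
                                MemorySSAUpdater *MSSAU) {
      Old->replaceAllUsesWith(New);
      RecursivelyDeleteTriviallyDeadInstructions(Old, TLI, MSSAU);
    }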
- // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan - // the exit blocks of the loop to find them. - for (BasicBlock *ExitBB : ExitBlocks) { - // If there are no PHI nodes in this exit block, then no values defined - // inside the loop are used on this path, skip it. - PHINode *PN = dyn_cast<PHINode>(ExitBB->begin()); - if (!PN) continue; - - unsigned NumPreds = PN->getNumIncomingValues(); - - // Iterate over all of the PHI nodes. - BasicBlock::iterator BBI = ExitBB->begin(); - while ((PN = dyn_cast<PHINode>(BBI++))) { - if (PN->use_empty()) - continue; // dead use, don't replace it - - if (!SE->isSCEVable(PN->getType())) - continue; - - // It's necessary to tell ScalarEvolution about this explicitly so that - // it can walk the def-use list and forget all SCEVs, as it may not be - // watching the PHI itself. Once the new exit value is in place, there - // may not be a def-use connection between the loop and every instruction - // which got a SCEVAddRecExpr for that loop. - SE->forgetValue(PN); - - // Iterate over all of the values in all the PHI nodes. - for (unsigned i = 0; i != NumPreds; ++i) { - // If the value being merged in is not integer or is not defined - // in the loop, skip it. - Value *InVal = PN->getIncomingValue(i); - if (!isa<Instruction>(InVal)) - continue; - - // If this pred is for a subloop, not L itself, skip it. - if (LI->getLoopFor(PN->getIncomingBlock(i)) != L) - continue; // The Block is in a subloop, skip it. - - // Check that InVal is defined in the loop. - Instruction *Inst = cast<Instruction>(InVal); - if (!L->contains(Inst)) - continue; - - // Okay, this instruction has a user outside of the current loop - // and varies predictably *inside* the loop. Evaluate the value it - // contains when the loop exits, if possible. We prefer to start with - // expressions which are true for all exits (so as to maximize - // expression reuse by the SCEVExpander), but resort to per-exit - // evaluation if that fails. - const SCEV *ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop()); - if (isa<SCEVCouldNotCompute>(ExitValue) || - !SE->isLoopInvariant(ExitValue, L) || - !isSafeToExpand(ExitValue, *SE)) { - // TODO: This should probably be sunk into SCEV in some way; maybe a - // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for - // most SCEV expressions and other recurrence types (e.g. shift - // recurrences). Is there existing code we can reuse? - const SCEV *ExitCount = SE->getExitCount(L, PN->getIncomingBlock(i)); - if (isa<SCEVCouldNotCompute>(ExitCount)) - continue; - if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Inst))) - if (AddRec->getLoop() == L) - ExitValue = AddRec->evaluateAtIteration(ExitCount, *SE); - if (isa<SCEVCouldNotCompute>(ExitValue) || - !SE->isLoopInvariant(ExitValue, L) || - !isSafeToExpand(ExitValue, *SE)) - continue; - } - - // Computing the value outside of the loop brings no benefit if it is - // definitely used inside the loop in a way which can not be optimized - // away. Avoid doing so unless we know we have a value which computes - // the ExitValue already. TODO: This should be merged into SCEV - // expander to leverage its knowledge of existing expressions. 
- if (ReplaceExitValue != AlwaysRepl && - !isa<SCEVConstant>(ExitValue) && !isa<SCEVUnknown>(ExitValue) && - hasHardUserWithinLoop(L, Inst)) - continue; - - bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, Inst); - Value *ExitVal = Rewriter.expandCodeFor(ExitValue, PN->getType(), Inst); - - LLVM_DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal - << '\n' - << " LoopVal = " << *Inst << "\n"); - - if (!isValidRewrite(Inst, ExitVal)) { - DeadInsts.push_back(ExitVal); - continue; - } - -#ifndef NDEBUG - // If we reuse an instruction from a loop which is neither L nor one of - // its containing loops, we end up breaking LCSSA form for this loop by - // creating a new use of its instruction. - if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal)) - if (auto *EVL = LI->getLoopFor(ExitInsn->getParent())) - if (EVL != L) - assert(EVL->contains(L) && "LCSSA breach detected!"); -#endif - - // Collect all the candidate PHINodes to be rewritten. - RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost); - } - } - } - - bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet); - - bool Changed = false; - // Transformation. - for (const RewritePhi &Phi : RewritePhiSet) { - PHINode *PN = Phi.PN; - Value *ExitVal = Phi.Val; - - // Only do the rewrite when the ExitValue can be expanded cheaply. - // If LoopCanBeDel is true, rewrite exit value aggressively. - if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) { - DeadInsts.push_back(ExitVal); - continue; - } - - Changed = true; - ++NumReplaced; - Instruction *Inst = cast<Instruction>(PN->getIncomingValue(Phi.Ith)); - PN->setIncomingValue(Phi.Ith, ExitVal); - - // If this instruction is dead now, delete it. Don't do it now to avoid - // invalidating iterators. - if (isInstructionTriviallyDead(Inst, TLI)) - DeadInsts.push_back(Inst); - - // Replace PN with ExitVal if that is legal and does not break LCSSA. - if (PN->getNumIncomingValues() == 1 && - LI->replacementPreservesLCSSAForm(PN, ExitVal)) { - PN->replaceAllUsesWith(ExitVal); - PN->eraseFromParent(); - } - } - - // The insertion point instruction may have been deleted; clear it out - // so that the rewriter doesn't trip over it later. - Rewriter.clearInsertPoint(); - return Changed; -} - //===---------------------------------------------------------------------===// // rewriteFirstIterationLoopExitValues: Rewrite loop exit values if we know // they will exit at the first iteration. @@ -813,61 +546,6 @@ bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { return MadeAnyChanges; } -/// Check whether it is possible to delete the loop after rewriting exit -/// value. If it is possible, ignore ReplaceExitValue and do rewriting -/// aggressively. -bool IndVarSimplify::canLoopBeDeleted( - Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) { - BasicBlock *Preheader = L->getLoopPreheader(); - // If there is no preheader, the loop will not be deleted. - if (!Preheader) - return false; - - // In LoopDeletion pass Loop can be deleted when ExitingBlocks.size() > 1. - // We obviate multiple ExitingBlocks case for simplicity. - // TODO: If we see testcase with multiple ExitingBlocks can be deleted - // after exit value rewriting, we can enhance the logic here. 
- SmallVector<BasicBlock *, 4> ExitingBlocks; - L->getExitingBlocks(ExitingBlocks); - SmallVector<BasicBlock *, 8> ExitBlocks; - L->getUniqueExitBlocks(ExitBlocks); - if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1) - return false; - - BasicBlock *ExitBlock = ExitBlocks[0]; - BasicBlock::iterator BI = ExitBlock->begin(); - while (PHINode *P = dyn_cast<PHINode>(BI)) { - Value *Incoming = P->getIncomingValueForBlock(ExitingBlocks[0]); - - // If the Incoming value of P is found in RewritePhiSet, we know it - // could be rewritten to use a loop invariant value in transformation - // phase later. Skip it in the loop invariant check below. - bool found = false; - for (const RewritePhi &Phi : RewritePhiSet) { - unsigned i = Phi.Ith; - if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) { - found = true; - break; - } - } - - Instruction *I; - if (!found && (I = dyn_cast<Instruction>(Incoming))) - if (!L->hasLoopInvariantOperands(I)) - return false; - - ++BI; - } - - for (auto *BB : L->blocks()) - if (llvm::any_of(*BB, [](Instruction &I) { - return I.mayHaveSideEffects(); - })) - return false; - - return true; -} - //===----------------------------------------------------------------------===// // IV Widening - Extend the width of an IV to cover its widest uses. //===----------------------------------------------------------------------===// @@ -1060,8 +738,8 @@ protected: Instruction *widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter); bool widenLoopCompare(NarrowIVDefUse DU); - bool widenWithVariantLoadUse(NarrowIVDefUse DU); - void widenWithVariantLoadUseCodegen(NarrowIVDefUse DU); + bool widenWithVariantUse(NarrowIVDefUse DU); + void widenWithVariantUseCodegen(NarrowIVDefUse DU); void pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef); }; @@ -1399,20 +1077,27 @@ bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) { return true; } -/// If the narrow use is an instruction whose two operands are the defining -/// instruction of DU and a load instruction, then we have the following: -/// if the load is hoisted outside the loop, then we do not reach this function -/// as scalar evolution analysis works fine in widenIVUse with variables -/// hoisted outside the loop and efficient code is subsequently generated by -/// not emitting truncate instructions. But when the load is not hoisted -/// (whether due to limitation in alias analysis or due to a true legality), -/// then scalar evolution can not proceed with loop variant values and -/// inefficient code is generated. This function handles the non-hoisted load -/// special case by making the optimization generate the same type of code for -/// hoisted and non-hoisted load (widen use and eliminate sign extend -/// instruction). This special case is important especially when the induction -/// variables are affecting addressing mode in code generation. -bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) { +// The widenIVUse avoids generating trunc by evaluating the use as AddRec, this +// will not work when: +// 1) SCEV traces back to an instruction inside the loop that SCEV can not +// expand, eg. add %indvar, (load %addr) +// 2) SCEV finds a loop variant, eg. add %indvar, %loopvariant +// While SCEV fails to avoid trunc, we can still try to use instruction +// combining approach to prove trunc is not required. 
This can be further +// extended with other instruction combining checks, but for now we handle the +// following case (sub can be "add" and "mul", "nsw + sext" can be "nus + zext") +// +// Src: +// %c = sub nsw %b, %indvar +// %d = sext %c to i64 +// Dst: +// %indvar.ext1 = sext %indvar to i64 +// %m = sext %b to i64 +// %d = sub nsw i64 %m, %indvar.ext1 +// Therefore, as long as the result of add/sub/mul is extended to wide type, no +// trunc is required regardless of how %b is generated. This pattern is common +// when calculating address in 64 bit architecture +bool WidenIV::widenWithVariantUse(NarrowIVDefUse DU) { Instruction *NarrowUse = DU.NarrowUse; Instruction *NarrowDef = DU.NarrowDef; Instruction *WideDef = DU.WideDef; @@ -1443,12 +1128,6 @@ bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) { else return false; - // We are interested in the other operand being a load instruction. - // But, we should look into relaxing this restriction later on. - auto *I = dyn_cast<Instruction>(NarrowUse->getOperand(ExtendOperIdx)); - if (I && I->getOpcode() != Instruction::Load) - return false; - // Verifying that Defining operand is an AddRec const SCEV *Op1 = SE->getSCEV(WideDef); const SCEVAddRecExpr *AddRecOp1 = dyn_cast<SCEVAddRecExpr>(Op1); @@ -1480,9 +1159,9 @@ bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) { return true; } -/// Special Case for widening with variant Loads (see -/// WidenIV::widenWithVariantLoadUse). This is the code generation part. -void WidenIV::widenWithVariantLoadUseCodegen(NarrowIVDefUse DU) { +/// Special Case for widening with loop variant (see +/// WidenIV::widenWithVariant). This is the code generation part. +void WidenIV::widenWithVariantUseCodegen(NarrowIVDefUse DU) { Instruction *NarrowUse = DU.NarrowUse; Instruction *NarrowDef = DU.NarrowDef; Instruction *WideDef = DU.WideDef; @@ -1508,33 +1187,22 @@ void WidenIV::widenWithVariantLoadUseCodegen(NarrowIVDefUse DU) { Builder.Insert(WideBO); WideBO->copyIRFlags(NarrowBO); - if (ExtKind == SignExtended) - ExtendKindMap[NarrowUse] = SignExtended; - else - ExtendKindMap[NarrowUse] = ZeroExtended; + assert(ExtKind != Unknown && "Unknown ExtKind not handled"); - // Update the Use. 
- if (ExtKind == SignExtended) { - for (Use &U : NarrowUse->uses()) { - SExtInst *User = dyn_cast<SExtInst>(U.getUser()); - if (User && User->getType() == WideType) { - LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by " - << *WideBO << "\n"); - ++NumElimExt; - User->replaceAllUsesWith(WideBO); - DeadInsts.emplace_back(User); - } - } - } else { // ExtKind == ZeroExtended - for (Use &U : NarrowUse->uses()) { - ZExtInst *User = dyn_cast<ZExtInst>(U.getUser()); - if (User && User->getType() == WideType) { - LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by " - << *WideBO << "\n"); - ++NumElimExt; - User->replaceAllUsesWith(WideBO); - DeadInsts.emplace_back(User); - } + ExtendKindMap[NarrowUse] = ExtKind; + + for (Use &U : NarrowUse->uses()) { + Instruction *User = nullptr; + if (ExtKind == SignExtended) + User = dyn_cast<SExtInst>(U.getUser()); + else + User = dyn_cast<ZExtInst>(U.getUser()); + if (User && User->getType() == WideType) { + LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by " + << *WideBO << "\n"); + ++NumElimExt; + User->replaceAllUsesWith(WideBO); + DeadInsts.emplace_back(User); } } } @@ -1641,8 +1309,8 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { // in WideAddRec.first does not indicate a polynomial induction expression. // In that case, look at the operands of the use instruction to determine // if we can still widen the use instead of truncating its operand. - if (widenWithVariantLoadUse(DU)) { - widenWithVariantLoadUseCodegen(DU); + if (widenWithVariantUse(DU)) { + widenWithVariantUseCodegen(DU); return nullptr; } @@ -1992,8 +1660,8 @@ bool IndVarSimplify::simplifyAndExtend(Loop *L, // Information about sign/zero extensions of CurrIV. IndVarSimplifyVisitor Visitor(CurrIV, SE, TTI, DT); - Changed |= - simplifyUsersOfIV(CurrIV, SE, DT, LI, DeadInsts, Rewriter, &Visitor); + Changed |= simplifyUsersOfIV(CurrIV, SE, DT, LI, TTI, DeadInsts, Rewriter, + &Visitor); if (Visitor.WI.WidestNativeType) { WideIVs.push_back(Visitor.WI); @@ -2017,7 +1685,7 @@ bool IndVarSimplify::simplifyAndExtend(Loop *L, /// Given an Value which is hoped to be part of an add recurance in the given /// loop, return the associated Phi node if so. Otherwise, return null. Note -/// that this is less general than SCEVs AddRec checking. +/// that this is less general than SCEVs AddRec checking. static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L) { Instruction *IncI = dyn_cast<Instruction>(IncV); if (!IncI) @@ -2079,7 +1747,7 @@ static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) { BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); if (L->isLoopInvariant(BI->getCondition())) return false; - + // Do LFTR to simplify the exit condition to an ICMP. ICmpInst *Cond = dyn_cast<ICmpInst>(BI->getCondition()); if (!Cond) @@ -2122,9 +1790,9 @@ static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) { /// actually poison. This can be used to assess whether a new use of Root can /// be added at a location which is control equivalent with OnPathTo (such as /// immediately before it) without introducing UB which didn't previously -/// exist. Note that a false result conveys no information. +/// exist. Note that a false result conveys no information. 
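Stepping back to the widenWithVariantUseCodegen change above: the two near-identical sign- and zero-extend loops collapse into one that chooses the cast kind from ExtKind. A sketch of that unified user-replacement step (hypothetical helper; the pass additionally queues the now-dead extensions in DeadInsts):

    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Redirect any sext/zext of Narrow that produces WideType to the already
    // widened operation Wide, leaving the extension trivially dead.
    static void replaceExtUsersWithWideOp(Instruction *Narrow, Instruction *Wide,
                                          Type *WideType, bool Signed) {
      for (Use &U : Narrow->uses()) {
        Instruction *User = nullptr;
        if (Signed)
          User = dyn_cast<SExtInst>(U.getUser());
        else
          User = dyn_cast<ZExtInst>(U.getUser());
        if (User && User->getType() == WideType)
          User->replaceAllUsesWith(Wide);
      }
    }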
static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root, - Instruction *OnPathTo, + Instruction *OnPathTo, DominatorTree *DT) { // Basic approach is to assume Root is poison, propagate poison forward // through all users we can easily track, and then check whether any of those @@ -2142,10 +1810,10 @@ static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root, // If we know this must trigger UB on a path leading our target. if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo)) return true; - + // If we can't analyze propagation through this instruction, just skip it // and transitive users. Safe as false is a conservative result. - if (!propagatesFullPoison(I) && I != Root) + if (!propagatesPoison(I) && I != Root) continue; if (KnownPoison.insert(I).second) @@ -2154,7 +1822,7 @@ static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root, } // Might be non-UB, or might have a path we couldn't prove must execute on - // way to exiting bb. + // way to exiting bb. return false; } @@ -2221,7 +1889,7 @@ static bool isLoopCounter(PHINode* Phi, Loop *L, ScalarEvolution *SE) { assert(Phi->getParent() == L->getHeader()); assert(L->getLoopLatch()); - + if (!SE->isSCEVable(Phi->getType())) return false; @@ -2282,7 +1950,7 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB, if (!hasConcreteDef(Phi)) { // We explicitly allow unknown phis as long as they are already used by // the loop exit test. This is legal since performing LFTR could not - // increase the number of undef users. + // increase the number of undef users. Value *IncPhi = Phi->getIncomingValueForBlock(LatchBlock); if (!isLoopExitTestBasedOn(Phi, ExitingBB) && !isLoopExitTestBasedOn(IncPhi, ExitingBB)) @@ -2300,7 +1968,7 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB, if (!Phi->getType()->isIntegerTy() && !mustExecuteUBIfPoisonOnPathTo(Phi, ExitingBB->getTerminator(), DT)) continue; - + const SCEV *Init = AR->getStart(); if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) { @@ -2506,14 +2174,14 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, // reasoning as from SimplifyIndvar::eliminateTrunc to see if we can extend // the other side of the comparison instead. We still evaluate the limit // in the narrower bitwidth, we just prefer a zext/sext outside the loop to - // a truncate within in. + // a truncate within in. bool Extended = false; const SCEV *IV = SE->getSCEV(CmpIndVar); const SCEV *TruncatedIV = SE->getTruncateExpr(SE->getSCEV(CmpIndVar), ExitCnt->getType()); const SCEV *ZExtTrunc = SE->getZeroExtendExpr(TruncatedIV, CmpIndVar->getType()); - + if (ZExtTrunc == IV) { Extended = true; ExitCnt = Builder.CreateZExt(ExitCnt, IndVar->getType(), @@ -2531,7 +2199,7 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, if (Extended) { bool Discard; L->makeLoopInvariant(ExitCnt, Discard); - } else + } else CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(), "lftr.wideiv"); } @@ -2551,7 +2219,7 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, // update the branch to use the new comparison; in the common case this // will make old comparison dead. BI->setCondition(Cond); - DeadInsts.push_back(OrigCond); + DeadInsts.emplace_back(OrigCond); ++NumLFTR; return true; @@ -2685,11 +2353,10 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { L->getExitingBlocks(ExitingBlocks); // Remove all exits which aren't both rewriteable and analyzeable. 
- auto NewEnd = llvm::remove_if(ExitingBlocks, - [&](BasicBlock *ExitingBB) { + auto NewEnd = llvm::remove_if(ExitingBlocks, [&](BasicBlock *ExitingBB) { // If our exitting block exits multiple loops, we can only rewrite the // innermost one. Otherwise, we're changing how many times the innermost - // loop runs before it exits. + // loop runs before it exits. if (LI->getLoopFor(ExitingBB) != L) return true; @@ -2701,18 +2368,18 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { // If already constant, nothing to do. if (isa<Constant>(BI->getCondition())) return true; - + const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); if (isa<SCEVCouldNotCompute>(ExitCount)) return true; return false; - }); + }); ExitingBlocks.erase(NewEnd, ExitingBlocks.end()); if (ExitingBlocks.empty()) return false; - - // Get a symbolic upper bound on the loop backedge taken count. + + // Get a symbolic upper bound on the loop backedge taken count. const SCEV *MaxExitCount = getMaxBackedgeTakenCount(*SE, *DT, L); if (isa<SCEVCouldNotCompute>(MaxExitCount)) return false; @@ -2720,11 +2387,12 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { // Visit our exit blocks in order of dominance. We know from the fact that // all exits (left) are analyzeable that the must be a total dominance order // between them as each must dominate the latch. The visit order only - // matters for the provably equal case. + // matters for the provably equal case. llvm::sort(ExitingBlocks, [&](BasicBlock *A, BasicBlock *B) { // std::sort sorts in ascending order, so we want the inverse of // the normal dominance relation. + if (A == B) return false; if (DT->properlyDominates(A, B)) return true; if (DT->properlyDominates(B, A)) return false; llvm_unreachable("expected total dominance order!"); @@ -2734,7 +2402,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { assert(DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i])); } #endif - + auto FoldExit = [&](BasicBlock *ExitingBB, bool IsTaken) { BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); @@ -2743,7 +2411,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { IsTaken ? ExitIfTrue : !ExitIfTrue); BI->setCondition(NewCond); if (OldCond->use_empty()) - DeadInsts.push_back(OldCond); + DeadInsts.emplace_back(OldCond); }; bool Changed = false; @@ -2751,7 +2419,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { for (BasicBlock *ExitingBB : ExitingBlocks) { const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); assert(!isa<SCEVCouldNotCompute>(ExitCount) && "checked above"); - + // If we know we'd exit on the first iteration, rewrite the exit to // reflect this. This does not imply the loop must exit through this // exit; there may be an earlier one taken on the first iteration. @@ -2769,13 +2437,13 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { if (!ExitCount->getType()->isIntegerTy() || !MaxExitCount->getType()->isIntegerTy()) continue; - + Type *WiderType = SE->getWiderType(MaxExitCount->getType(), ExitCount->getType()); ExitCount = SE->getNoopOrZeroExtend(ExitCount, WiderType); MaxExitCount = SE->getNoopOrZeroExtend(MaxExitCount, WiderType); assert(MaxExitCount->getType() == ExitCount->getType()); - + // Can we prove that some other exit must be taken strictly before this // one? 
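The FoldExit lambda above folds an exit proven always or never taken by rewriting the branch condition to a constant and queuing the old condition for cleanup. A minimal standalone sketch of that step (hypothetical helper; BI must be a conditional branch, and the caller later deletes whatever lands in DeadInsts):

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/Constants.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/ValueHandle.h"

    using namespace llvm;

    static void foldExitBranch(BranchInst *BI, bool IsTaken, bool ExitIfTrue,
                               SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
      Value *OldCond = BI->getCondition();
      // The exit sits on the true edge iff ExitIfTrue; pick the constant so the
      // branch always (or never) takes that edge.
      Value *NewCond =
          ConstantInt::get(OldCond->getType(), IsTaken ? ExitIfTrue : !ExitIfTrue);
      BI->setCondition(NewCond);
      if (OldCond->use_empty())
        DeadInsts.emplace_back(OldCond); // cleaned up by the caller
    }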
if (SE->isLoopEntryGuardedByCond(L, CmpInst::ICMP_ULT, @@ -2788,7 +2456,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { // As we run, keep track of which exit counts we've encountered. If we // find a duplicate, we've found an exit which would have exited on the // exiting iteration, but (from the visit order) strictly follows another - // which does the same and is thus dead. + // which does the same and is thus dead. if (!DominatingExitCounts.insert(ExitCount).second) { FoldExit(ExitingBB, false); Changed = true; @@ -2809,22 +2477,20 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { SmallVector<BasicBlock*, 16> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); - bool Changed = false; - // Finally, see if we can rewrite our exit conditions into a loop invariant - // form. If we have a read-only loop, and we can tell that we must exit down + // form. If we have a read-only loop, and we can tell that we must exit down // a path which does not need any of the values computed within the loop, we // can rewrite the loop to exit on the first iteration. Note that this // doesn't either a) tell us the loop exits on the first iteration (unless // *all* exits are predicateable) or b) tell us *which* exit might be taken. // This transformation looks a lot like a restricted form of dead loop // elimination, but restricted to read-only loops and without neccesssarily - // needing to kill the loop entirely. + // needing to kill the loop entirely. if (!LoopPredication) - return Changed; + return false; if (!SE->hasLoopInvariantBackedgeTakenCount(L)) - return Changed; + return false; // Note: ExactBTC is the exact backedge taken count *iff* the loop exits // through *explicit* control flow. We have to eliminate the possibility of @@ -2833,16 +2499,16 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { if (isa<SCEVCouldNotCompute>(ExactBTC) || !SE->isLoopInvariant(ExactBTC, L) || !isSafeToExpand(ExactBTC, *SE)) - return Changed; + return false; // If we end up with a pointer exit count, bail. It may be unsized. if (!ExactBTC->getType()->isIntegerTy()) - return Changed; + return false; auto BadExit = [&](BasicBlock *ExitingBB) { // If our exiting block exits multiple loops, we can only rewrite the // innermost one. Otherwise, we're changing how many times the innermost - // loop runs before it exits. + // loop runs before it exits. if (LI->getLoopFor(ExitingBB) != L) return true; @@ -2897,18 +2563,18 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { // is complicated and we choose not to for now. for (unsigned i = 1; i < ExitingBlocks.size(); i++) if (!DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i])) - return Changed; + return false; // Given our sorted total order, we know that exit[j] must be evaluated // after all exit[i] such j > i. for (unsigned i = 0, e = ExitingBlocks.size(); i < e; i++) if (BadExit(ExitingBlocks[i])) { - ExitingBlocks.resize(i); + ExitingBlocks.resize(i); break; } if (ExitingBlocks.empty()) - return Changed; + return false; // We rely on not being able to reach an exiting block on a later iteration // then it's statically compute exit count. 
The implementaton of @@ -2930,8 +2596,9 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { for (auto &I : *BB) // TODO:isGuaranteedToTransfer if (I.mayHaveSideEffects() || I.mayThrow()) - return Changed; + return false; + bool Changed = false; // Finally, do the actual predication for all predicatable blocks. A couple // of notes here: // 1) We don't bother to constant fold dominated exits with identical exit @@ -2970,7 +2637,7 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { Value *OldCond = BI->getCondition(); BI->setCondition(NewCond); if (OldCond->use_empty()) - DeadInsts.push_back(OldCond); + DeadInsts.emplace_back(OldCond); Changed = true; } @@ -2985,7 +2652,6 @@ bool IndVarSimplify::run(Loop *L) { // We need (and expect!) the incoming loop to be in LCSSA. assert(L->isRecursivelyLCSSAForm(*DT, *LI) && "LCSSA required to run indvars!"); - bool Changed = false; // If LoopSimplify form is not available, stay out of trouble. Some notes: // - LSR currently only supports LoopSimplify-form loops. Indvars' @@ -3001,9 +2667,15 @@ bool IndVarSimplify::run(Loop *L) { #ifndef NDEBUG // Used below for a consistency check only - const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); + // Note: Since the result returned by ScalarEvolution may depend on the order + // in which previous results are added to its cache, the call to + // getBackedgeTakenCount() may change following SCEV queries. + const SCEV *BackedgeTakenCount; + if (VerifyIndvars) + BackedgeTakenCount = SE->getBackedgeTakenCount(L); #endif + bool Changed = false; // If there are any floating-point recurrences, attempt to // transform them to use integer recurrences. Changed |= rewriteNonIntegerIVs(L); @@ -3027,8 +2699,13 @@ bool IndVarSimplify::run(Loop *L) { // that are recurrent in the loop, and substitute the exit values from the // loop into any instructions outside of the loop that use the final values // of the current expressions. - if (ReplaceExitValue != NeverRepl) - Changed |= rewriteLoopExitValues(L, Rewriter); + if (ReplaceExitValue != NeverRepl) { + if (int Rewrites = rewriteLoopExitValues(L, LI, TLI, SE, TTI, Rewriter, DT, + ReplaceExitValue, DeadInsts)) { + NumReplaced += Rewrites; + Changed = true; + } + } // Eliminate redundant IV cycles. NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts); @@ -3039,7 +2716,7 @@ bool IndVarSimplify::run(Loop *L) { // Given we've changed exit counts, notify SCEV SE->forgetLoop(L); } - + // Try to form loop invariant tests for loop exits by changing how many // iterations of the loop run when that is unobservable. if (predicateLoopExits(L, Rewriter)) { @@ -3049,8 +2726,11 @@ bool IndVarSimplify::run(Loop *L) { } // If we have a trip count expression, rewrite the loop's exit condition - // using it. + // using it. if (!DisableLFTR) { + BasicBlock *PreHeader = L->getLoopPreheader(); + BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator()); + SmallVector<BasicBlock*, 16> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); for (BasicBlock *ExitingBB : ExitingBlocks) { @@ -3060,10 +2740,10 @@ bool IndVarSimplify::run(Loop *L) { // If our exitting block exits multiple loops, we can only rewrite the // innermost one. Otherwise, we're changing how many times the innermost - // loop runs before it exits. + // loop runs before it exits. 
if (LI->getLoopFor(ExitingBB) != L) continue; - + if (!needsLFTR(L, ExitingBB)) continue; @@ -3077,14 +2757,15 @@ bool IndVarSimplify::run(Loop *L) { // until stable to handle cases like this better. if (ExitCount->isZero()) continue; - + PHINode *IndVar = FindLoopCounter(L, ExitingBB, ExitCount, SE, DT); if (!IndVar) continue; - + // Avoid high cost expansions. Note: This heuristic is questionable in - // that our definition of "high cost" is not exactly principled. - if (Rewriter.isHighCostExpansion(ExitCount, L)) + // that our definition of "high cost" is not exactly principled. + if (Rewriter.isHighCostExpansion(ExitCount, L, SCEVCheapExpansionBudget, + TTI, PreHeaderBR)) continue; // Check preconditions for proper SCEVExpander operation. SCEV does not @@ -3092,7 +2773,7 @@ bool IndVarSimplify::run(Loop *L) { // any pass that uses the SCEVExpander must do it. This does not work // well for loop passes because SCEVExpander makes assumptions about // all loops, while LoopPassManager only forces the current loop to be - // simplified. + // simplified. // // FIXME: SCEV expansion has no way to bail out, so the caller must // explicitly check any assumptions made by SCEV. Brittle. @@ -3113,7 +2794,8 @@ bool IndVarSimplify::run(Loop *L) { while (!DeadInsts.empty()) if (Instruction *Inst = dyn_cast_or_null<Instruction>(DeadInsts.pop_back_val())) - Changed |= RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI); + Changed |= + RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI, MSSAU.get()); // The Rewriter may not be used from this point on. @@ -3127,7 +2809,7 @@ bool IndVarSimplify::run(Loop *L) { Changed |= rewriteFirstIterationLoopExitValues(L); // Clean up dead instructions. - Changed |= DeleteDeadPHIs(L->getHeader(), TLI); + Changed |= DeleteDeadPHIs(L->getHeader(), TLI, MSSAU.get()); // Check a post-condition. assert(L->isRecursivelyLCSSAForm(*DT, *LI) && @@ -3150,6 +2832,8 @@ bool IndVarSimplify::run(Loop *L) { assert(!SE->isKnownPredicate(ICmpInst::ICMP_ULT, BackedgeTakenCount, NewBECount) && "indvars must preserve SCEV"); } + if (VerifyMemorySSA && MSSAU) + MSSAU->getMemorySSA()->verifyMemorySSA(); #endif return Changed; @@ -3161,12 +2845,14 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, Function *F = L.getHeader()->getParent(); const DataLayout &DL = F->getParent()->getDataLayout(); - IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI); + IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI, AR.MSSA); if (!IVS.run(&L)) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); PA.preserveSet<CFGAnalyses>(); + if (AR.MSSA) + PA.preserve<MemorySSAAnalysis>(); return PA; } @@ -3191,13 +2877,18 @@ struct IndVarSimplifyLegacyPass : public LoopPass { auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); auto *TTI = TTIP ? 
&TTIP->getTTI(*L->getHeader()->getParent()) : nullptr; const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + MemorySSA *MSSA = nullptr; + if (MSSAAnalysis) + MSSA = &MSSAAnalysis->getMSSA(); - IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI); + IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI, MSSA); return IVS.run(L); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addPreserved<MemorySSAWrapperPass>(); getLoopAnalysisUsage(AU); } }; diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 58469749600e..30e4822b6769 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -47,6 +47,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -55,8 +56,8 @@ #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -87,6 +88,7 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> @@ -242,20 +244,25 @@ public: bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop); }; -class IRCELegacyPass : public LoopPass { +class IRCELegacyPass : public FunctionPass { public: static char ID; - IRCELegacyPass() : LoopPass(ID) { + IRCELegacyPass() : FunctionPass(ID) { initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry()); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<BranchProbabilityInfoWrapperPass>(); - getLoopAnalysisUsage(AU); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addPreserved<ScalarEvolutionWrapperPass>(); } - bool runOnLoop(Loop *L, LPPassManager &LPM) override; + bool runOnFunction(Function &F) override; }; } // end anonymous namespace @@ -265,7 +272,9 @@ char IRCELegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce", "Inductive range check elimination", false, false) INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination", false, false) @@ -866,7 +875,14 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); const SCEV *Step = SE.getSCEV(StepCI); - ConstantInt *One = ConstantInt::get(IndVarTy, 1); + const SCEV *FixedRightSCEV = nullptr; + + // If RightValue resides within loop (but still being loop invariant), + // 
regenerate it as preheader. + if (auto *I = dyn_cast<Instruction>(RightValue)) + if (L.contains(I->getParent())) + FixedRightSCEV = RightSCEV; + if (IsIncreasing) { bool DecreasedRightValueByOne = false; if (StepCI->isOne()) { @@ -928,10 +944,9 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, if (LatchBrExitIdx == 0) { // We need to increase the right value unless we have already decreased // it virtually when we replaced EQ with SGT. - if (!DecreasedRightValueByOne) { - IRBuilder<> B(Preheader->getTerminator()); - RightValue = B.CreateAdd(RightValue, One); - } + if (!DecreasedRightValueByOne) + FixedRightSCEV = + SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); } else { assert(!DecreasedRightValueByOne && "Right value can be decreased only for LatchBrExitIdx == 0!"); @@ -995,10 +1010,9 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, if (LatchBrExitIdx == 0) { // We need to decrease the right value unless we have already increased // it virtually when we replaced EQ with SLT. - if (!IncreasedRightValueByOne) { - IRBuilder<> B(Preheader->getTerminator()); - RightValue = B.CreateSub(RightValue, One); - } + if (!IncreasedRightValueByOne) + FixedRightSCEV = + SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); } else { assert(!IncreasedRightValueByOne && "Right value can be increased only for LatchBrExitIdx == 0!"); @@ -1012,9 +1026,14 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, assert(!L.contains(LatchExit) && "expected an exit block!"); const DataLayout &DL = Preheader->getModule()->getDataLayout(); - Value *IndVarStartV = - SCEVExpander(SE, DL, "irce") - .expandCodeFor(IndVarStart, IndVarTy, Preheader->getTerminator()); + SCEVExpander Expander(SE, DL, "irce"); + Instruction *Ins = Preheader->getTerminator(); + + if (FixedRightSCEV) + RightValue = + Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins); + + Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins); IndVarStartV->setName("indvar.start"); LoopStructure Result; @@ -1747,27 +1766,41 @@ IntersectUnsignedRange(ScalarEvolution &SE, return Ret; } -PreservedAnalyses IRCEPass::run(Loop &L, LoopAnalysisManager &AM, - LoopStandardAnalysisResults &AR, - LPMUpdater &U) { - Function *F = L.getHeader()->getParent(); - const auto &FAM = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); - auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); - InductiveRangeCheckElimination IRCE(AR.SE, BPI, AR.DT, AR.LI); - auto LPMAddNewLoop = [&U](Loop *NL, bool IsSubloop) { +PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) { + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &BPI = AM.getResult<BranchProbabilityAnalysis>(F); + LoopInfo &LI = AM.getResult<LoopAnalysis>(F); + + InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI); + + bool Changed = false; + + for (const auto &L : LI) { + Changed |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, + /*PreserveLCSSA=*/false); + Changed |= formLCSSARecursively(*L, DT, &LI, &SE); + } + + SmallPriorityWorklist<Loop *, 4> Worklist; + appendLoopsToWorklist(LI, Worklist); + auto LPMAddNewLoop = [&Worklist](Loop *NL, bool IsSubloop) { if (!IsSubloop) - U.addSiblingLoops(NL); + appendLoopsToWorklist(*NL, Worklist); }; - bool Changed = IRCE.run(&L, LPMAddNewLoop); + + while (!Worklist.empty()) { + Loop *L = Worklist.pop_back_val(); + Changed |= IRCE.run(L, LPMAddNewLoop); + } + if (!Changed) return 
PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); } -bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipLoop(L)) +bool IRCELegacyPass::runOnFunction(Function &F) { + if (skipFunction(F)) return false; ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); @@ -1776,10 +1809,27 @@ bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI); - auto LPMAddNewLoop = [&LPM](Loop *NL, bool /* IsSubLoop */) { - LPM.addLoop(*NL); + + bool Changed = false; + + for (const auto &L : LI) { + Changed |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, + /*PreserveLCSSA=*/false); + Changed |= formLCSSARecursively(*L, DT, &LI, &SE); + } + + SmallPriorityWorklist<Loop *, 4> Worklist; + appendLoopsToWorklist(LI, Worklist); + auto LPMAddNewLoop = [&](Loop *NL, bool IsSubloop) { + if (!IsSubloop) + appendLoopsToWorklist(*NL, Worklist); }; - return IRCE.run(L, LPMAddNewLoop); + + while (!Worklist.empty()) { + Loop *L = Worklist.pop_back_val(); + Changed |= IRCE.run(L, LPMAddNewLoop); + } + return Changed; } bool InductiveRangeCheckElimination::run( diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index dfb1b6bfb739..db9cc58bbfc4 100644 --- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -96,7 +96,6 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -116,11 +115,13 @@ #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <cassert> #include <iterator> @@ -132,16 +133,23 @@ using namespace llvm; +static cl::opt<bool> AssumeDefaultIsFlatAddressSpace( + "assume-default-is-flat-addrspace", cl::init(false), cl::ReallyHidden, + cl::desc("The default address space is assumed as the flat address space. " + "This is mainly for test purpose.")); + static const unsigned UninitializedAddressSpace = std::numeric_limits<unsigned>::max(); namespace { using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>; +using PostorderStackTy = llvm::SmallVector<PointerIntPair<Value *, 1, bool>, 4>; /// InferAddressSpaces class InferAddressSpaces : public FunctionPass { const TargetTransformInfo *TTI = nullptr; + const DataLayout *DL = nullptr; /// Target specific address space which uses of should be replaced if /// possible. @@ -174,6 +182,11 @@ private: bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const; + Value *cloneInstructionWithNewAddressSpace( + Instruction *I, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace, + SmallVectorImpl<const Use *> *UndefUsesToFix) const; + // Changes the flat address expressions in function F to point to specific // address spaces if InferredAddrSpace says so. Postorder is the postorder of // all flat expressions in the use-def graph of function F. 
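The new IRCE driver in the hunks above replaces the per-loop pass plumbing with an explicit worklist of loops, re-queuing sibling loops that the transform creates. Below is a minimal standalone sketch of that driver pattern with no LLVM dependencies; the Loop struct, appendLoopsToWorklist helper, and the toy transform are simplified stand-ins for llvm::Loop, llvm::appendLoopsToWorklist, and InductiveRangeCheckElimination::run, not the commit's actual code.

#include <deque>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy loop nest: a loop is just a name plus its sub-loops.
struct Loop {
  std::string Name;
  std::vector<Loop *> SubLoops;
};

// Add a loop and, recursively, all of its sub-loops to the worklist.
static void appendLoopsToWorklist(Loop &L, std::deque<Loop *> &Worklist) {
  Worklist.push_back(&L);
  for (Loop *Sub : L.SubLoops)
    appendLoopsToWorklist(*Sub, Worklist);
}

int main() {
  Loop Inner{"inner", {}};
  Loop Outer{"outer", {&Inner}};
  std::vector<Loop *> TopLevel = {&Outer};

  // Storage for loops created while the driver runs.
  std::vector<std::unique_ptr<Loop>> Created;

  std::deque<Loop *> Worklist;
  for (Loop *L : TopLevel)
    appendLoopsToWorklist(*L, Worklist);

  // Callback handed to the per-loop transform: newly created sibling loops
  // are queued for processing; newly created sub-loops are left alone.
  auto AddNewLoop = [&Worklist](Loop *NL, bool IsSubloop) {
    if (!IsSubloop)
      appendLoopsToWorklist(*NL, Worklist);
  };

  // Toy "transform": splits the outer loop once into an extra sibling loop.
  bool SplitDone = false;
  auto RunOnLoop = [&](Loop &L) -> bool {
    std::cout << "processing " << L.Name << "\n";
    if (L.Name == "outer" && !SplitDone) {
      SplitDone = true;
      Created.push_back(std::make_unique<Loop>(Loop{"outer.split", {}}));
      AddNewLoop(Created.back().get(), /*IsSubloop=*/false);
      return true;
    }
    return false;
  };

  bool Changed = false;
  while (!Worklist.empty()) {
    Loop *L = Worklist.front();
    Worklist.pop_front();
    Changed |= RunOnLoop(*L);
  }
  std::cout << (Changed ? "changed\n" : "unchanged\n");
}

The same idea in the pass itself lets loops created by the transform be picked up within the same function-pass run, without relying on the loop pass manager to schedule them.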
@@ -182,15 +195,14 @@ private: const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const; void appendsFlatAddressExpressionToPostorderStack( - Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack, - DenseSet<Value *> &Visited) const; + Value *V, PostorderStackTy &PostorderStack, + DenseSet<Value *> &Visited) const; bool rewriteIntrinsicOperands(IntrinsicInst *II, Value *OldV, Value *NewV) const; - void collectRewritableIntrinsicOperands( - IntrinsicInst *II, - std::vector<std::pair<Value *, bool>> &PostorderStack, - DenseSet<Value *> &Visited) const; + void collectRewritableIntrinsicOperands(IntrinsicInst *II, + PostorderStackTy &PostorderStack, + DenseSet<Value *> &Visited) const; std::vector<WeakTrackingVH> collectFlatAddressExpressions(Function &F) const; @@ -214,24 +226,65 @@ void initializeInferAddressSpacesPass(PassRegistry &); INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", false, false) +// Check whether that's no-op pointer bicast using a pair of +// `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over +// different address spaces. +static bool isNoopPtrIntCastPair(const Operator *I2P, const DataLayout &DL, + const TargetTransformInfo *TTI) { + assert(I2P->getOpcode() == Instruction::IntToPtr); + auto *P2I = dyn_cast<Operator>(I2P->getOperand(0)); + if (!P2I || P2I->getOpcode() != Instruction::PtrToInt) + return false; + // Check it's really safe to treat that pair of `ptrtoint`/`inttoptr` as a + // no-op cast. Besides checking both of them are no-op casts, as the + // reinterpreted pointer may be used in other pointer arithmetic, we also + // need to double-check that through the target-specific hook. That ensures + // the underlying target also agrees that's a no-op address space cast and + // pointer bits are preserved. + // The current IR spec doesn't have clear rules on address space casts, + // especially a clear definition for pointer bits in non-default address + // spaces. It would be undefined if that pointer is dereferenced after an + // invalid reinterpret cast. Also, due to the unclearness for the meaning of + // bits in non-default address spaces in the current spec, the pointer + // arithmetic may also be undefined after invalid pointer reinterpret cast. + // However, as we confirm through the target hooks that it's a no-op + // addrspacecast, it doesn't matter since the bits should be the same. + return CastInst::isNoopCast(Instruction::CastOps(I2P->getOpcode()), + I2P->getOperand(0)->getType(), I2P->getType(), + DL) && + CastInst::isNoopCast(Instruction::CastOps(P2I->getOpcode()), + P2I->getOperand(0)->getType(), P2I->getType(), + DL) && + TTI->isNoopAddrSpaceCast( + P2I->getOperand(0)->getType()->getPointerAddressSpace(), + I2P->getType()->getPointerAddressSpace()); +} + // Returns true if V is an address expression. // TODO: Currently, we consider only phi, bitcast, addrspacecast, and // getelementptr operators. 
-static bool isAddressExpression(const Value &V) { - if (!isa<Operator>(V)) +static bool isAddressExpression(const Value &V, const DataLayout &DL, + const TargetTransformInfo *TTI) { + const Operator *Op = dyn_cast<Operator>(&V); + if (!Op) return false; - const Operator &Op = cast<Operator>(V); - switch (Op.getOpcode()) { + switch (Op->getOpcode()) { case Instruction::PHI: - assert(Op.getType()->isPointerTy()); + assert(Op->getType()->isPointerTy()); return true; case Instruction::BitCast: case Instruction::AddrSpaceCast: case Instruction::GetElementPtr: return true; case Instruction::Select: - return Op.getType()->isPointerTy(); + return Op->getType()->isPointerTy(); + case Instruction::Call: { + const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V); + return II && II->getIntrinsicID() == Intrinsic::ptrmask; + } + case Instruction::IntToPtr: + return isNoopPtrIntCastPair(Op, DL, TTI); default: return false; } @@ -240,7 +293,9 @@ static bool isAddressExpression(const Value &V) { // Returns the pointer operands of V. // // Precondition: V is an address expression. -static SmallVector<Value *, 2> getPointerOperands(const Value &V) { +static SmallVector<Value *, 2> +getPointerOperands(const Value &V, const DataLayout &DL, + const TargetTransformInfo *TTI) { const Operator &Op = cast<Operator>(V); switch (Op.getOpcode()) { case Instruction::PHI: { @@ -254,12 +309,22 @@ static SmallVector<Value *, 2> getPointerOperands(const Value &V) { return {Op.getOperand(0)}; case Instruction::Select: return {Op.getOperand(1), Op.getOperand(2)}; + case Instruction::Call: { + const IntrinsicInst &II = cast<IntrinsicInst>(Op); + assert(II.getIntrinsicID() == Intrinsic::ptrmask && + "unexpected intrinsic call"); + return {II.getArgOperand(0)}; + } + case Instruction::IntToPtr: { + assert(isNoopPtrIntCastPair(&Op, DL, TTI)); + auto *P2I = cast<Operator>(Op.getOperand(0)); + return {P2I->getOperand(0)}; + } default: llvm_unreachable("Unexpected instruction type."); } } -// TODO: Move logic to TTI? bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II, Value *OldV, Value *NewV) const { @@ -275,16 +340,26 @@ bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II, II->setCalledFunction(NewDecl); return true; } - default: - return TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV); + case Intrinsic::ptrmask: + // This is handled as an address expression, not as a use memory operation. + return false; + default: { + Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV); + if (!Rewrite) + return false; + if (Rewrite != II) + II->replaceAllUsesWith(Rewrite); + return true; + } } } void InferAddressSpaces::collectRewritableIntrinsicOperands( - IntrinsicInst *II, std::vector<std::pair<Value *, bool>> &PostorderStack, + IntrinsicInst *II, PostorderStackTy &PostorderStack, DenseSet<Value *> &Visited) const { auto IID = II->getIntrinsicID(); switch (IID) { + case Intrinsic::ptrmask: case Intrinsic::objectsize: appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0), PostorderStack, Visited); @@ -305,7 +380,7 @@ void InferAddressSpaces::collectRewritableIntrinsicOperands( // If V is an unvisited flat address expression, appends V to PostorderStack // and marks it as visited. 
void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack( - Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack, + Value *V, PostorderStackTy &PostorderStack, DenseSet<Value *> &Visited) const { assert(V->getType()->isPointerTy()); @@ -313,21 +388,21 @@ void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack( // expressions. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { // TODO: Look in non-address parts, like icmp operands. - if (isAddressExpression(*CE) && Visited.insert(CE).second) - PostorderStack.push_back(std::make_pair(CE, false)); + if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second) + PostorderStack.emplace_back(CE, false); return; } - if (isAddressExpression(*V) && + if (isAddressExpression(*V, *DL, TTI) && V->getType()->getPointerAddressSpace() == FlatAddrSpace) { if (Visited.insert(V).second) { - PostorderStack.push_back(std::make_pair(V, false)); + PostorderStack.emplace_back(V, false); Operator *Op = cast<Operator>(V); for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) { if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op->getOperand(I))) { - if (isAddressExpression(*CE) && Visited.insert(CE).second) + if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second) PostorderStack.emplace_back(CE, false); } } @@ -341,7 +416,7 @@ std::vector<WeakTrackingVH> InferAddressSpaces::collectFlatAddressExpressions(Function &F) const { // This function implements a non-recursive postorder traversal of a partial // use-def graph of function F. - std::vector<std::pair<Value *, bool>> PostorderStack; + PostorderStackTy PostorderStack; // The set of visited expressions. DenseSet<Value *> Visited; @@ -383,23 +458,27 @@ InferAddressSpaces::collectFlatAddressExpressions(Function &F) const { } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) { if (!ASC->getType()->isVectorTy()) PushPtrOperand(ASC->getPointerOperand()); + } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) { + if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI)) + PushPtrOperand( + cast<PtrToIntInst>(I2P->getOperand(0))->getPointerOperand()); } } std::vector<WeakTrackingVH> Postorder; // The resultant postorder. while (!PostorderStack.empty()) { - Value *TopVal = PostorderStack.back().first; + Value *TopVal = PostorderStack.back().getPointer(); // If the operands of the expression on the top are already explored, // adds that expression to the resultant postorder. - if (PostorderStack.back().second) { + if (PostorderStack.back().getInt()) { if (TopVal->getType()->getPointerAddressSpace() == FlatAddrSpace) Postorder.push_back(TopVal); PostorderStack.pop_back(); continue; } // Otherwise, adds its operands to the stack and explores them. - PostorderStack.back().second = true; - for (Value *PtrOperand : getPointerOperands(*TopVal)) { + PostorderStack.back().setInt(true); + for (Value *PtrOperand : getPointerOperands(*TopVal, *DL, TTI)) { appendsFlatAddressExpressionToPostorderStack(PtrOperand, PostorderStack, Visited); } @@ -438,10 +517,13 @@ static Value *operandWithNewAddressSpaceOrCreateUndef( // Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast // from a pointer whose type already matches. Therefore, this function returns a // Value* instead of an Instruction*. -static Value *cloneInstructionWithNewAddressSpace( +// +// This may also return nullptr in the case the instruction could not be +// rewritten. 
+Value *InferAddressSpaces::cloneInstructionWithNewAddressSpace( Instruction *I, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, - SmallVectorImpl<const Use *> *UndefUsesToFix) { + SmallVectorImpl<const Use *> *UndefUsesToFix) const { Type *NewPtrType = I->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); @@ -456,6 +538,23 @@ static Value *cloneInstructionWithNewAddressSpace( return Src; } + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + // Technically the intrinsic ID is a pointer typed argument, so specially + // handle calls early. + assert(II->getIntrinsicID() == Intrinsic::ptrmask); + Value *NewPtr = operandWithNewAddressSpaceOrCreateUndef( + II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace, + UndefUsesToFix); + Value *Rewrite = + TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr); + if (Rewrite) { + assert(Rewrite != II && "cannot modify this pointer operation in place"); + return Rewrite; + } + + return nullptr; + } + // Computes the converted pointer operands. SmallVector<Value *, 4> NewPointerOperands; for (const Use &OperandUse : I->operands()) { @@ -492,6 +591,14 @@ static Value *cloneInstructionWithNewAddressSpace( assert(I->getType()->isPointerTy()); return SelectInst::Create(I->getOperand(0), NewPointerOperands[1], NewPointerOperands[2], "", nullptr, I); + case Instruction::IntToPtr: { + assert(isNoopPtrIntCastPair(cast<Operator>(I), *DL, TTI)); + Value *Src = cast<Operator>(I->getOperand(0))->getOperand(0); + assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace); + if (Src->getType() != NewPtrType) + return new BitCastInst(Src, NewPtrType); + return Src; + } default: llvm_unreachable("Unexpected opcode"); } @@ -501,8 +608,9 @@ static Value *cloneInstructionWithNewAddressSpace( // constant expression `CE` with its operands replaced as specified in // ValueWithNewAddrSpace. static Value *cloneConstantExprWithNewAddressSpace( - ConstantExpr *CE, unsigned NewAddrSpace, - const ValueToValueMapTy &ValueWithNewAddrSpace) { + ConstantExpr *CE, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL, + const TargetTransformInfo *TTI) { Type *TargetType = CE->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); @@ -533,6 +641,13 @@ static Value *cloneConstantExprWithNewAddressSpace( } } + if (CE->getOpcode() == Instruction::IntToPtr) { + assert(isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI)); + Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0); + assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace); + return ConstantExpr::getBitCast(Src, TargetType); + } + // Computes the operands of the new constant expression. bool IsNew = false; SmallVector<Constant *, 4> NewOperands; @@ -550,7 +665,7 @@ static Value *cloneConstantExprWithNewAddressSpace( } if (auto CExpr = dyn_cast<ConstantExpr>(Operand)) if (Value *NewOperand = cloneConstantExprWithNewAddressSpace( - CExpr, NewAddrSpace, ValueWithNewAddrSpace)) { + CExpr, NewAddrSpace, ValueWithNewAddrSpace, DL, TTI)) { IsNew = true; NewOperands.push_back(cast<Constant>(NewOperand)); continue; @@ -585,13 +700,13 @@ Value *InferAddressSpaces::cloneValueWithNewAddressSpace( const ValueToValueMapTy &ValueWithNewAddrSpace, SmallVectorImpl<const Use *> *UndefUsesToFix) const { // All values in Postorder are flat address expressions. 
- assert(isAddressExpression(*V) && + assert(isAddressExpression(*V, *DL, TTI) && V->getType()->getPointerAddressSpace() == FlatAddrSpace); if (Instruction *I = dyn_cast<Instruction>(V)) { Value *NewV = cloneInstructionWithNewAddressSpace( I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix); - if (Instruction *NewI = dyn_cast<Instruction>(NewV)) { + if (Instruction *NewI = dyn_cast_or_null<Instruction>(NewV)) { if (NewI->getParent() == nullptr) { NewI->insertBefore(I); NewI->takeName(I); @@ -601,7 +716,7 @@ Value *InferAddressSpaces::cloneValueWithNewAddressSpace( } return cloneConstantExprWithNewAddressSpace( - cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace); + cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace, DL, TTI); } // Defines the join operation on the address space lattice (see the file header @@ -625,6 +740,10 @@ bool InferAddressSpaces::runOnFunction(Function &F) { return false; TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + DL = &F.getParent()->getDataLayout(); + + if (AssumeDefaultIsFlatAddressSpace) + FlatAddrSpace = 0; if (FlatAddrSpace == UninitializedAddressSpace) { FlatAddrSpace = TTI->getFlatAddressSpace(); @@ -729,7 +848,7 @@ Optional<unsigned> InferAddressSpaces::updateAddressSpace( else NewAS = joinAddressSpaces(Src0AS, Src1AS); } else { - for (Value *PtrOperand : getPointerOperands(V)) { + for (Value *PtrOperand : getPointerOperands(V, *DL, TTI)) { auto I = InferredAddrSpace.find(PtrOperand); unsigned OperandAS = I != InferredAddrSpace.end() ? I->second : PtrOperand->getType()->getPointerAddressSpace(); @@ -879,8 +998,10 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces( for (Value* V : Postorder) { unsigned NewAddrSpace = InferredAddrSpace.lookup(V); if (V->getType()->getPointerAddressSpace() != NewAddrSpace) { - ValueWithNewAddrSpace[V] = cloneValueWithNewAddressSpace( - V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix); + Value *New = cloneValueWithNewAddressSpace( + V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix); + if (New) + ValueWithNewAddrSpace[V] = New; } } @@ -890,7 +1011,10 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces( // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace. for (const Use *UndefUse : UndefUsesToFix) { User *V = UndefUse->getUser(); - User *NewV = cast<User>(ValueWithNewAddrSpace.lookup(V)); + User *NewV = cast_or_null<User>(ValueWithNewAddrSpace.lookup(V)); + if (!NewV) + continue; + unsigned OperandNo = UndefUse->getOperandNo(); assert(isa<UndefValue>(NewV->getOperand(OperandNo))); NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get())); diff --git a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp index e8bbf2936da6..e87b622ab19f 100644 --- a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp +++ b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp @@ -40,7 +40,7 @@ static bool runImpl(Function &F, const SimplifyQuery &SQ, if (!SQ.DT->isReachableFromEntry(&BB)) continue; - SmallVector<Instruction *, 8> DeadInstsInBB; + SmallVector<WeakTrackingVH, 8> DeadInstsInBB; for (Instruction &I : BB) { // The first time through the loop, ToSimplify is empty and we try to // simplify all instructions. 
On later iterations, ToSimplify is not diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index 98c2fcb3dae0..9d0500419a7f 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -13,6 +13,7 @@ #include "llvm/Transforms/Scalar/JumpThreading.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -170,7 +171,7 @@ FunctionPass *llvm::createJumpThreadingPass(int Threshold) { } JumpThreadingPass::JumpThreadingPass(int T) { - BBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T); + DefaultBBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T); } // Update branch probability information according to conditional @@ -213,11 +214,16 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { if (!CondBr) return; - BranchProbability BP; uint64_t TrueWeight, FalseWeight; if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight)) return; + if (TrueWeight + FalseWeight == 0) + // Zero branch_weights do not give a hint for getting branch probabilities. + // Technically it would result in division by zero denominator, which is + // TrueWeight + FalseWeight. + return; + // Returns the outgoing edge of the dominating predecessor block // that leads to the PhiNode's incoming block: auto GetPredOutEdge = @@ -252,10 +258,11 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { if (!CI || !CI->getType()->isIntegerTy(1)) continue; - BP = (CI->isOne() ? BranchProbability::getBranchProbability( - TrueWeight, TrueWeight + FalseWeight) - : BranchProbability::getBranchProbability( - FalseWeight, TrueWeight + FalseWeight)); + BranchProbability BP = + (CI->isOne() ? BranchProbability::getBranchProbability( + TrueWeight, TrueWeight + FalseWeight) + : BranchProbability::getBranchProbability( + FalseWeight, TrueWeight + FalseWeight)); auto PredOutEdge = GetPredOutEdge(PN->getIncomingBlock(i), BB); if (!PredOutEdge.first) @@ -298,8 +305,6 @@ bool JumpThreading::runOnFunction(Function &F) { if (skipFunction(F)) return false; auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - // Get DT analysis before LVI. When LVI is initialized it conditionally adds - // DT if it's available. auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); @@ -316,7 +321,7 @@ bool JumpThreading::runOnFunction(Function &F) { std::move(BFI), std::move(BPI)); if (PrintLVIAfterJumpThreading) { dbgs() << "LVI for function '" << F.getName() << "':\n"; - LVI->printLVI(F, *DT, dbgs()); + LVI->printLVI(F, DTU.getDomTree(), dbgs()); } return Changed; } @@ -324,8 +329,6 @@ bool JumpThreading::runOnFunction(Function &F) { PreservedAnalyses JumpThreadingPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); - // Get DT analysis before LVI. When LVI is initialized it conditionally adds - // DT if it's available. auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LVI = AM.getResult<LazyValueAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); @@ -374,6 +377,15 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, BFI = std::move(BFI_); } + // Reduce the number of instructions duplicated when optimizing strictly for + // size. 
+ if (BBDuplicateThreshold.getNumOccurrences()) + BBDupThreshold = BBDuplicateThreshold; + else if (F.hasFnAttribute(Attribute::MinSize)) + BBDupThreshold = 3; + else + BBDupThreshold = DefaultBBDupThreshold; + // JumpThreading must not processes blocks unreachable from entry. It's a // waste of compute time and can potentially lead to hangs. SmallPtrSet<BasicBlock *, 16> Unreachable; @@ -396,6 +408,12 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, continue; while (ProcessBlock(&BB)) // Thread all of the branches we can over BB. Changed = true; + + // Jump threading may have introduced redundant debug values into BB + // which should be removed. + if (Changed) + RemoveRedundantDbgInstrs(&BB); + // Stop processing BB if it's the entry or is now deleted. The following // routines attempt to eliminate BB and locating a suitable replacement // for the entry is non-trivial. @@ -418,26 +436,27 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, // ProcessBlock doesn't thread BBs with unconditional TIs. However, if BB // is "almost empty", we attempt to merge BB with its sole successor. auto *BI = dyn_cast<BranchInst>(BB.getTerminator()); - if (BI && BI->isUnconditional() && - // The terminator must be the only non-phi instruction in BB. - BB.getFirstNonPHIOrDbg()->isTerminator() && - // Don't alter Loop headers and latches to ensure another pass can - // detect and transform nested loops later. - !LoopHeaders.count(&BB) && !LoopHeaders.count(BI->getSuccessor(0)) && - TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU)) { - // BB is valid for cleanup here because we passed in DTU. F remains - // BB's parent until a DTU->getDomTree() event. - LVI->eraseBlock(&BB); - Changed = true; + if (BI && BI->isUnconditional()) { + BasicBlock *Succ = BI->getSuccessor(0); + if ( + // The terminator must be the only non-phi instruction in BB. + BB.getFirstNonPHIOrDbg()->isTerminator() && + // Don't alter Loop headers and latches to ensure another pass can + // detect and transform nested loops later. + !LoopHeaders.count(&BB) && !LoopHeaders.count(Succ) && + TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU)) { + RemoveRedundantDbgInstrs(Succ); + // BB is valid for cleanup here because we passed in DTU. F remains + // BB's parent until a DTU->getDomTree() event. + LVI->eraseBlock(&BB); + Changed = true; + } } } EverChanged |= Changed; } while (Changed); LoopHeaders.clear(); - // Flush only the Dominator Tree. - DTU->getDomTree(); - LVI->enableDT(); return EverChanged; } @@ -592,20 +611,19 @@ static Constant *getKnownConstant(Value *Val, ConstantPreference Preference) { /// This returns true if there were any known values. bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( Value *V, BasicBlock *BB, PredValueInfo &Result, - ConstantPreference Preference, - DenseSet<std::pair<Value *, BasicBlock *>> &RecursionSet, + ConstantPreference Preference, DenseSet<Value *> &RecursionSet, Instruction *CxtI) { // This method walks up use-def chains recursively. Because of this, we could // get into an infinite loop going around loops in the use-def chain. To // prevent this, keep track of what (value, block) pairs we've already visited // and terminate the search if we loop back to them - if (!RecursionSet.insert(std::make_pair(V, BB)).second) + if (!RecursionSet.insert(V).second) return false; // If V is a constant, then it is known in all predecessors. 
if (Constant *KC = getKnownConstant(V, Preference)) { for (BasicBlock *Pred : predecessors(BB)) - Result.push_back(std::make_pair(KC, Pred)); + Result.emplace_back(KC, Pred); return !Result.empty(); } @@ -627,17 +645,12 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( // able to handle value inequalities better, for example if the compare is // "X < 4" and "X < 3" is known true but "X < 4" itself is not available. // Perhaps getConstantOnEdge should be smart enough to do this? - - if (DTU->hasPendingDomTreeUpdates()) - LVI->disableDT(); - else - LVI->enableDT(); for (BasicBlock *P : predecessors(BB)) { // If the value is known by LazyValueInfo to be a constant in a // predecessor, use that information to try to thread this block. Constant *PredCst = LVI->getConstantOnEdge(V, P, BB, CxtI); if (Constant *KC = getKnownConstant(PredCst, Preference)) - Result.push_back(std::make_pair(KC, P)); + Result.emplace_back(KC, P); } return !Result.empty(); @@ -645,20 +658,16 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( /// If I is a PHI node, then we know the incoming values for any constants. if (PHINode *PN = dyn_cast<PHINode>(I)) { - if (DTU->hasPendingDomTreeUpdates()) - LVI->disableDT(); - else - LVI->enableDT(); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { Value *InVal = PN->getIncomingValue(i); if (Constant *KC = getKnownConstant(InVal, Preference)) { - Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i))); + Result.emplace_back(KC, PN->getIncomingBlock(i)); } else { Constant *CI = LVI->getConstantOnEdge(InVal, PN->getIncomingBlock(i), BB, CxtI); if (Constant *KC = getKnownConstant(CI, Preference)) - Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i))); + Result.emplace_back(KC, PN->getIncomingBlock(i)); } } @@ -757,7 +766,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( Constant *Folded = ConstantExpr::get(BO->getOpcode(), V, CI); if (Constant *KC = getKnownConstant(Folded, WantInteger)) - Result.push_back(std::make_pair(KC, LHSVal.second)); + Result.emplace_back(KC, LHSVal.second); } } @@ -779,10 +788,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( const DataLayout &DL = PN->getModule()->getDataLayout(); // We can do this simplification if any comparisons fold to true or false. // See if any do. - if (DTU->hasPendingDomTreeUpdates()) - LVI->disableDT(); - else - LVI->enableDT(); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { BasicBlock *PredBB = PN->getIncomingBlock(i); Value *LHS, *RHS; @@ -813,7 +818,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( } if (Constant *KC = getKnownConstant(Res, WantInteger)) - Result.push_back(std::make_pair(KC, PredBB)); + Result.emplace_back(KC, PredBB); } return !Result.empty(); @@ -826,10 +831,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( if (!isa<Instruction>(CmpLHS) || cast<Instruction>(CmpLHS)->getParent() != BB) { - if (DTU->hasPendingDomTreeUpdates()) - LVI->disableDT(); - else - LVI->enableDT(); for (BasicBlock *P : predecessors(BB)) { // If the value is known by LazyValueInfo to be a constant in a // predecessor, use that information to try to thread this block. 
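ComputeValueKnownInPredecessorsImpl, touched in the hunks above, walks use-def chains recursively and relies on a RecursionSet (now keyed on the Value alone) to terminate when the walk loops back on itself, as can happen with mutually dependent PHI nodes. The snippet below is a small self-contained sketch of that cycle-guard pattern in plain C++; Node and findConstant are illustrative stand-ins rather than the pass's real types.

#include <iostream>
#include <optional>
#include <unordered_set>
#include <vector>

// Toy value graph: a node is either a known constant or derives its value
// from other nodes.  Cycles are possible, much like PHI nodes feeding each
// other.
struct Node {
  std::optional<int> Constant;  // set if the value is already known
  std::vector<Node *> Operands; // otherwise, inputs it depends on
};

// Recursively look for a constant, using Visited to terminate when the walk
// loops back onto a node that is already being evaluated.
static std::optional<int> findConstant(Node *N,
                                       std::unordered_set<Node *> &Visited) {
  if (!Visited.insert(N).second)
    return std::nullopt; // already on this walk: give up on this path
  if (N->Constant)
    return N->Constant;
  for (Node *Op : N->Operands)
    if (auto C = findConstant(Op, Visited))
      return C;
  return std::nullopt;
}

int main() {
  Node A, B, C;
  A.Operands = {&B};
  B.Operands = {&A, &C}; // A and B form a cycle; C breaks it with a constant
  C.Constant = 42;

  std::unordered_set<Node *> Visited;
  if (auto V = findConstant(&A, Visited))
    std::cout << "known constant: " << *V << "\n";
  else
    std::cout << "unknown\n";
}

Without the visited-set check the A/B cycle would recurse forever; with it, the walk abandons the cyclic path and still finds the constant through C.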
@@ -840,7 +841,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( continue; Constant *ResC = ConstantInt::get(CmpType, Res); - Result.push_back(std::make_pair(ResC, P)); + Result.emplace_back(ResC, P); } return !Result.empty(); @@ -858,10 +859,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( match(CmpLHS, m_Add(m_Value(AddLHS), m_ConstantInt(AddConst)))) { if (!isa<Instruction>(AddLHS) || cast<Instruction>(AddLHS)->getParent() != BB) { - if (DTU->hasPendingDomTreeUpdates()) - LVI->disableDT(); - else - LVI->enableDT(); for (BasicBlock *P : predecessors(BB)) { // If the value is known by LazyValueInfo to be a ConstantRange in // a predecessor, use that information to try to thread this @@ -883,7 +880,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( else continue; - Result.push_back(std::make_pair(ResC, P)); + Result.emplace_back(ResC, P); } return !Result.empty(); @@ -901,7 +898,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( Constant *V = LHSVal.first; Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst); if (Constant *KC = getKnownConstant(Folded, WantInteger)) - Result.push_back(std::make_pair(KC, LHSVal.second)); + Result.emplace_back(KC, LHSVal.second); } return !Result.empty(); @@ -935,7 +932,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( // See if the select has a known constant value for this predecessor. if (Constant *Val = KnownCond ? TrueVal : FalseVal) - Result.push_back(std::make_pair(Val, C.second)); + Result.emplace_back(Val, C.second); } return !Result.empty(); @@ -943,14 +940,10 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( } // If all else fails, see if LVI can figure out a constant value for us. - if (DTU->hasPendingDomTreeUpdates()) - LVI->disableDT(); - else - LVI->enableDT(); Constant *CI = LVI->getConstant(V, BB, CxtI); if (Constant *KC = getKnownConstant(CI, Preference)) { for (BasicBlock *Pred : predecessors(BB)) - Result.push_back(std::make_pair(KC, Pred)); + Result.emplace_back(KC, Pred); } return !Result.empty(); @@ -1106,10 +1099,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // threading is concerned. assert(CondBr->isConditional() && "Threading on unconditional terminator"); - if (DTU->hasPendingDomTreeUpdates()) - LVI->disableDT(); - else - LVI->enableDT(); LazyValueInfo::Tristate Ret = LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0), CondConst, CondBr); @@ -1363,7 +1352,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) { // If so, this load is partially redundant. Remember this info so that we // can create a PHI node. 
- AvailablePreds.push_back(std::make_pair(PredBB, PredAvailable)); + AvailablePreds.emplace_back(PredBB, PredAvailable); } // If the loaded value isn't available in any predecessor, it isn't partially @@ -1430,14 +1419,14 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) { "Can't handle critical edge here!"); LoadInst *NewVal = new LoadInst( LoadI->getType(), LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred), - LoadI->getName() + ".pr", false, MaybeAlign(LoadI->getAlignment()), + LoadI->getName() + ".pr", false, LoadI->getAlign(), LoadI->getOrdering(), LoadI->getSyncScopeID(), UnavailablePred->getTerminator()); NewVal->setDebugLoc(LoadI->getDebugLoc()); if (AATags) NewVal->setAAMetadata(AATags); - AvailablePreds.push_back(std::make_pair(UnavailablePred, NewVal)); + AvailablePreds.emplace_back(UnavailablePred, NewVal); } // Now we know that each predecessor of this block has a value in @@ -1496,56 +1485,70 @@ FindMostPopularDest(BasicBlock *BB, // explicitly choose to ignore 'undef' destinations. We prefer to thread // blocks with known and real destinations to threading undef. We'll handle // them later if interesting. - DenseMap<BasicBlock*, unsigned> DestPopularity; + MapVector<BasicBlock *, unsigned> DestPopularity; + + // Populate DestPopularity with the successors in the order they appear in the + // successor list. This way, we ensure determinism by iterating it in the + // same order in std::max_element below. We map nullptr to 0 so that we can + // return nullptr when PredToDestList contains nullptr only. + DestPopularity[nullptr] = 0; + for (auto *SuccBB : successors(BB)) + DestPopularity[SuccBB] = 0; + for (const auto &PredToDest : PredToDestList) if (PredToDest.second) DestPopularity[PredToDest.second]++; - if (DestPopularity.empty()) - return nullptr; - // Find the most popular dest. - DenseMap<BasicBlock*, unsigned>::iterator DPI = DestPopularity.begin(); - BasicBlock *MostPopularDest = DPI->first; - unsigned Popularity = DPI->second; - SmallVector<BasicBlock*, 4> SamePopularity; - - for (++DPI; DPI != DestPopularity.end(); ++DPI) { - // If the popularity of this entry isn't higher than the popularity we've - // seen so far, ignore it. - if (DPI->second < Popularity) - ; // ignore. - else if (DPI->second == Popularity) { - // If it is the same as what we've seen so far, keep track of it. - SamePopularity.push_back(DPI->first); - } else { - // If it is more popular, remember it. - SamePopularity.clear(); - MostPopularDest = DPI->first; - Popularity = DPI->second; - } + using VT = decltype(DestPopularity)::value_type; + auto MostPopular = std::max_element( + DestPopularity.begin(), DestPopularity.end(), + [](const VT &L, const VT &R) { return L.second < R.second; }); + + // Okay, we have finally picked the most popular destination. + return MostPopular->first; +} + +// Try to evaluate the value of V when the control flows from PredPredBB to +// BB->getSinglePredecessor() and then on to BB. +Constant *JumpThreadingPass::EvaluateOnPredecessorEdge(BasicBlock *BB, + BasicBlock *PredPredBB, + Value *V) { + BasicBlock *PredBB = BB->getSinglePredecessor(); + assert(PredBB && "Expected a single predecessor"); + + if (Constant *Cst = dyn_cast<Constant>(V)) { + return Cst; } - // Okay, now we know the most popular destination. If there is more than one - // destination, we need to determine one. This is arbitrary, but we need - // to make a deterministic decision. Pick the first one that appears in the - // successor list. 
- if (!SamePopularity.empty()) { - SamePopularity.push_back(MostPopularDest); - Instruction *TI = BB->getTerminator(); - for (unsigned i = 0; ; ++i) { - assert(i != TI->getNumSuccessors() && "Didn't find any successor!"); + // Consult LVI if V is not an instruction in BB or PredBB. + Instruction *I = dyn_cast<Instruction>(V); + if (!I || (I->getParent() != BB && I->getParent() != PredBB)) { + return LVI->getConstantOnEdge(V, PredPredBB, PredBB, nullptr); + } - if (!is_contained(SamePopularity, TI->getSuccessor(i))) - continue; + // Look into a PHI argument. + if (PHINode *PHI = dyn_cast<PHINode>(V)) { + if (PHI->getParent() == PredBB) + return dyn_cast<Constant>(PHI->getIncomingValueForBlock(PredPredBB)); + return nullptr; + } - MostPopularDest = TI->getSuccessor(i); - break; + // If we have a CmpInst, try to fold it for each incoming edge into PredBB. + if (CmpInst *CondCmp = dyn_cast<CmpInst>(V)) { + if (CondCmp->getParent() == BB) { + Constant *Op0 = + EvaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(0)); + Constant *Op1 = + EvaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1)); + if (Op0 && Op1) { + return ConstantExpr::getCompare(CondCmp->getPredicate(), Op0, Op1); + } } + return nullptr; } - // Okay, we have finally picked the most popular destination. - return MostPopularDest; + return nullptr; } bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, @@ -1557,8 +1560,12 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, return false; PredValueInfoTy PredValues; - if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference, CxtI)) - return false; + if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference, + CxtI)) { + // We don't have known values in predecessors. See if we can thread through + // BB and its sole predecessor. + return MaybeThreadThroughTwoBasicBlocks(BB, Cond); + } assert(!PredValues.empty() && "ComputeValueKnownInPredecessors returned true with no values"); @@ -1624,7 +1631,7 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, isa<CallBrInst>(Pred->getTerminator())) continue; - PredToDestList.push_back(std::make_pair(Pred, DestBB)); + PredToDestList.emplace_back(Pred, DestBB); } // If all edges were unthreadable, we fail. @@ -2015,6 +2022,205 @@ JumpThreadingPass::CloneInstructions(BasicBlock::iterator BI, return ValueMapping; } +/// Attempt to thread through two successive basic blocks. +bool JumpThreadingPass::MaybeThreadThroughTwoBasicBlocks(BasicBlock *BB, + Value *Cond) { + // Consider: + // + // PredBB: + // %var = phi i32* [ null, %bb1 ], [ @a, %bb2 ] + // %tobool = icmp eq i32 %cond, 0 + // br i1 %tobool, label %BB, label ... + // + // BB: + // %cmp = icmp eq i32* %var, null + // br i1 %cmp, label ..., label ... + // + // We don't know the value of %var at BB even if we know which incoming edge + // we take to BB. However, once we duplicate PredBB for each of its incoming + // edges (say, PredBB1 and PredBB2), we know the value of %var in each copy of + // PredBB. Then we can thread edges PredBB1->BB and PredBB2->BB through BB. + + // Require that BB end with a Branch for simplicity. + BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator()); + if (!CondBr) + return false; + + // BB must have exactly one predecessor. + BasicBlock *PredBB = BB->getSinglePredecessor(); + if (!PredBB) + return false; + + // Require that PredBB end with a conditional Branch. 
If PredBB ends with an + // unconditional branch, we should be merging PredBB and BB instead. For + // simplicity, we don't deal with a switch. + BranchInst *PredBBBranch = dyn_cast<BranchInst>(PredBB->getTerminator()); + if (!PredBBBranch || PredBBBranch->isUnconditional()) + return false; + + // If PredBB has exactly one incoming edge, we don't gain anything by copying + // PredBB. + if (PredBB->getSinglePredecessor()) + return false; + + // Don't thread through PredBB if it contains a successor edge to itself, in + // which case we would infinite loop. Suppose we are threading an edge from + // PredPredBB through PredBB and BB to SuccBB with PredBB containing a + // successor edge to itself. If we allowed jump threading in this case, we + // could duplicate PredBB and BB as, say, PredBB.thread and BB.thread. Since + // PredBB.thread has a successor edge to PredBB, we would immediately come up + // with another jump threading opportunity from PredBB.thread through PredBB + // and BB to SuccBB. This jump threading would repeatedly occur. That is, we + // would keep peeling one iteration from PredBB. + if (llvm::is_contained(successors(PredBB), PredBB)) + return false; + + // Don't thread across a loop header. + if (LoopHeaders.count(PredBB)) + return false; + + // Avoid complication with duplicating EH pads. + if (PredBB->isEHPad()) + return false; + + // Find a predecessor that we can thread. For simplicity, we only consider a + // successor edge out of BB to which we thread exactly one incoming edge into + // PredBB. + unsigned ZeroCount = 0; + unsigned OneCount = 0; + BasicBlock *ZeroPred = nullptr; + BasicBlock *OnePred = nullptr; + for (BasicBlock *P : predecessors(PredBB)) { + if (ConstantInt *CI = dyn_cast_or_null<ConstantInt>( + EvaluateOnPredecessorEdge(BB, P, Cond))) { + if (CI->isZero()) { + ZeroCount++; + ZeroPred = P; + } else if (CI->isOne()) { + OneCount++; + OnePred = P; + } + } + } + + // Disregard complicated cases where we have to thread multiple edges. + BasicBlock *PredPredBB; + if (ZeroCount == 1) { + PredPredBB = ZeroPred; + } else if (OneCount == 1) { + PredPredBB = OnePred; + } else { + return false; + } + + BasicBlock *SuccBB = CondBr->getSuccessor(PredPredBB == ZeroPred); + + // If threading to the same block as we come from, we would infinite loop. + if (SuccBB == BB) { + LLVM_DEBUG(dbgs() << " Not threading across BB '" << BB->getName() + << "' - would thread to self!\n"); + return false; + } + + // If threading this would thread across a loop header, don't thread the edge. + // See the comments above FindLoopHeaders for justifications and caveats. + if (LoopHeaders.count(BB) || LoopHeaders.count(SuccBB)) { + LLVM_DEBUG({ + bool BBIsHeader = LoopHeaders.count(BB); + bool SuccIsHeader = LoopHeaders.count(SuccBB); + dbgs() << " Not threading across " + << (BBIsHeader ? "loop header BB '" : "block BB '") + << BB->getName() << "' to dest " + << (SuccIsHeader ? "loop header BB '" : "block BB '") + << SuccBB->getName() + << "' - it might create an irreducible loop!\n"; + }); + return false; + } + + // Compute the cost of duplicating BB and PredBB. + unsigned BBCost = + getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold); + unsigned PredBBCost = getJumpThreadDuplicationCost( + PredBB, PredBB->getTerminator(), BBDupThreshold); + + // Give up if costs are too high. 
We need to check BBCost and PredBBCost + // individually before checking their sum because getJumpThreadDuplicationCost + // return (unsigned)~0 for those basic blocks that cannot be duplicated. + if (BBCost > BBDupThreshold || PredBBCost > BBDupThreshold || + BBCost + PredBBCost > BBDupThreshold) { + LLVM_DEBUG(dbgs() << " Not threading BB '" << BB->getName() + << "' - Cost is too high: " << PredBBCost + << " for PredBB, " << BBCost << "for BB\n"); + return false; + } + + // Now we are ready to duplicate PredBB. + ThreadThroughTwoBasicBlocks(PredPredBB, PredBB, BB, SuccBB); + return true; +} + +void JumpThreadingPass::ThreadThroughTwoBasicBlocks(BasicBlock *PredPredBB, + BasicBlock *PredBB, + BasicBlock *BB, + BasicBlock *SuccBB) { + LLVM_DEBUG(dbgs() << " Threading through '" << PredBB->getName() << "' and '" + << BB->getName() << "'\n"); + + BranchInst *CondBr = cast<BranchInst>(BB->getTerminator()); + BranchInst *PredBBBranch = cast<BranchInst>(PredBB->getTerminator()); + + BasicBlock *NewBB = + BasicBlock::Create(PredBB->getContext(), PredBB->getName() + ".thread", + PredBB->getParent(), PredBB); + NewBB->moveAfter(PredBB); + + // Set the block frequency of NewBB. + if (HasProfileData) { + auto NewBBFreq = BFI->getBlockFreq(PredPredBB) * + BPI->getEdgeProbability(PredPredBB, PredBB); + BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + } + + // We are going to have to map operands from the original BB block to the new + // copy of the block 'NewBB'. If there are PHI nodes in PredBB, evaluate them + // to account for entry from PredPredBB. + DenseMap<Instruction *, Value *> ValueMapping = + CloneInstructions(PredBB->begin(), PredBB->end(), NewBB, PredPredBB); + + // Update the terminator of PredPredBB to jump to NewBB instead of PredBB. + // This eliminates predecessors from PredPredBB, which requires us to simplify + // any PHI nodes in PredBB. + Instruction *PredPredTerm = PredPredBB->getTerminator(); + for (unsigned i = 0, e = PredPredTerm->getNumSuccessors(); i != e; ++i) + if (PredPredTerm->getSuccessor(i) == PredBB) { + PredBB->removePredecessor(PredPredBB, true); + PredPredTerm->setSuccessor(i, NewBB); + } + + AddPHINodeEntriesForMappedBlock(PredBBBranch->getSuccessor(0), PredBB, NewBB, + ValueMapping); + AddPHINodeEntriesForMappedBlock(PredBBBranch->getSuccessor(1), PredBB, NewBB, + ValueMapping); + + DTU->applyUpdatesPermissive( + {{DominatorTree::Insert, NewBB, CondBr->getSuccessor(0)}, + {DominatorTree::Insert, NewBB, CondBr->getSuccessor(1)}, + {DominatorTree::Insert, PredPredBB, NewBB}, + {DominatorTree::Delete, PredPredBB, PredBB}}); + + UpdateSSA(PredBB, NewBB, ValueMapping); + + // Clean up things like PHI nodes with single operands, dead instructions, + // etc. + SimplifyInstructionsInBlock(NewBB, TLI); + SimplifyInstructionsInBlock(PredBB, TLI); + + SmallVector<BasicBlock *, 1> PredsToFactor; + PredsToFactor.push_back(NewBB); + ThreadEdge(BB, PredsToFactor, SuccBB); +} + /// TryThreadEdge - Thread an edge if it's safe and profitable to do so. 
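One detail worth spelling out from the cost check in MaybeThreadThroughTwoBasicBlocks above: getJumpThreadDuplicationCost uses (unsigned)~0 as a "cannot duplicate" sentinel, so the summed comparison alone is not enough, because an unsigned sum involving the sentinel wraps around. The standalone example below uses made-up cost values to show the wrap-around; it is an illustration of the arithmetic, not code from the pass.

#include <iostream>
#include <limits>

int main() {
  const unsigned Threshold = 6;
  // Sentinel meaning "this block cannot be duplicated", as described above.
  const unsigned CannotDuplicate = std::numeric_limits<unsigned>::max();

  unsigned BBCost = 3;
  unsigned PredBBCost = CannotDuplicate;

  // Insufficient: the unsigned sum wraps (3 + 0xFFFFFFFF == 2), so this check
  // by itself would report the pair as cheap enough to duplicate.
  bool SumOnlyRejects = (BBCost + PredBBCost > Threshold);

  // Sufficient: reject either block individually before looking at the sum.
  bool FullCheckRejects = (BBCost > Threshold || PredBBCost > Threshold ||
                           BBCost + PredBBCost > Threshold);

  std::cout << std::boolalpha
            << "sum-only check rejects: " << SumOnlyRejects << "\n"
            << "full check rejects:     " << FullCheckRejects << "\n";
}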
bool JumpThreadingPass::TryThreadEdge( BasicBlock *BB, const SmallVectorImpl<BasicBlock *> &PredBBs, @@ -2078,10 +2284,6 @@ void JumpThreadingPass::ThreadEdge(BasicBlock *BB, << "' to '" << SuccBB->getName() << ", across block:\n " << *BB << "\n"); - if (DTU->hasPendingDomTreeUpdates()) - LVI->disableDT(); - else - LVI->enableDT(); LVI->threadEdge(PredBB, BB, SuccBB); BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), @@ -2246,8 +2448,7 @@ void JumpThreadingPass::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, } // Update edge probabilities in BPI. - for (int I = 0, E = BBSuccProbs.size(); I < E; I++) - BPI->setEdgeProbability(BB, I, BBSuccProbs[I]); + BPI->setEdgeProbability(BB, BBSuccProbs); // Update the profile metadata as well. // @@ -2524,10 +2725,6 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { // Now check if one of the select values would allow us to constant fold the // terminator in BB. We don't do the transform if both sides fold, those // cases will be threaded in any case. - if (DTU->hasPendingDomTreeUpdates()) - LVI->disableDT(); - else - LVI->enableDT(); LazyValueInfo::Tristate LHSFolds = LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(1), CondRHS, Pred, BB, CondCmp); @@ -2565,6 +2762,16 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { /// select is not jump-threaded, it will be folded again in the later /// optimizations. bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { + // This transform can introduce a UB (a conditional branch that depends on a + // poison value) that was not present in the original program. See + // @TryToUnfoldSelectInCurrBB test in test/Transforms/JumpThreading/select.ll. + // Disable this transform under MemorySanitizer. + // FIXME: either delete it or replace with a valid transform. This issue is + // not limited to MemorySanitizer (but has only been observed as an MSan false + // positive in practice so far). + if (BB->getParent()->hasFnAttribute(Attribute::SanitizeMemory)) + return false; + // If threading this would thread across a loop header, don't thread the edge. // See the comments above FindLoopHeaders for justifications and caveats. 
if (LoopHeaders.count(BB)) diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index 8c33045c2380..1a22edaf8726 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -46,6 +46,7 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" @@ -69,6 +70,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -151,11 +153,11 @@ static bool isSafeToExecuteUnconditionally(Instruction &Inst, const Instruction *CtxI = nullptr); static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, AliasSetTracker *CurAST, Loop *CurLoop, - AliasAnalysis *AA); + AAResults *AA); static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, Loop *CurLoop, SinkAndHoistLICMFlags &Flags); -static Instruction *CloneInstructionInExitBlock( +static Instruction *cloneInstructionInExitBlock( Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU); @@ -168,27 +170,24 @@ static void moveInstructionBefore(Instruction &I, Instruction &Dest, namespace { struct LoopInvariantCodeMotion { - using ASTrackerMapTy = DenseMap<Loop *, std::unique_ptr<AliasSetTracker>>; - bool runOnLoop(Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, + bool runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, ScalarEvolution *SE, MemorySSA *MSSA, - OptimizationRemarkEmitter *ORE, bool DeleteAST); + OptimizationRemarkEmitter *ORE); - ASTrackerMapTy &getLoopToAliasSetMap() { return LoopToAliasSetMap; } LoopInvariantCodeMotion(unsigned LicmMssaOptCap, unsigned LicmMssaNoAccForPromotionCap) : LicmMssaOptCap(LicmMssaOptCap), LicmMssaNoAccForPromotionCap(LicmMssaNoAccForPromotionCap) {} private: - ASTrackerMapTy LoopToAliasSetMap; unsigned LicmMssaOptCap; unsigned LicmMssaNoAccForPromotionCap; std::unique_ptr<AliasSetTracker> - collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AliasAnalysis *AA); + collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AAResults *AA); std::unique_ptr<AliasSetTracker> - collectAliasInfoForLoopWithMSSA(Loop *L, AliasAnalysis *AA, + collectAliasInfoForLoopWithMSSA(Loop *L, AAResults *AA, MemorySSAUpdater *MSSAU); }; @@ -202,13 +201,8 @@ struct LegacyLICMPass : public LoopPass { } bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) { - // If we have run LICM on a previous loop but now we are skipping - // (because we've hit the opt-bisect limit), we need to clear the - // loop alias information. - LICM.getLoopToAliasSetMap().clear(); + if (skipLoop(L)) return false; - } auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); MemorySSA *MSSA = EnableMSSALoopDependency @@ -226,7 +220,7 @@ struct LegacyLICMPass : public LoopPass { *L->getHeader()->getParent()), &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( *L->getHeader()->getParent()), - SE ? &SE->getSE() : nullptr, MSSA, &ORE, false); + SE ? 
&SE->getSE() : nullptr, MSSA, &ORE); } /// This transformation requires natural loop information & requires that @@ -244,53 +238,21 @@ struct LegacyLICMPass : public LoopPass { getLoopAnalysisUsage(AU); } - using llvm::Pass::doFinalization; - - bool doFinalization() override { - auto &AliasSetMap = LICM.getLoopToAliasSetMap(); - // All loops in the AliasSetMap should be cleaned up already. The only case - // where we fail to do so is if an outer loop gets deleted before LICM - // visits it. - assert(all_of(AliasSetMap, - [](LoopInvariantCodeMotion::ASTrackerMapTy::value_type &KV) { - return !KV.first->getParentLoop(); - }) && - "Didn't free loop alias sets"); - AliasSetMap.clear(); - return false; - } - private: LoopInvariantCodeMotion LICM; - - /// cloneBasicBlockAnalysis - Simple Analysis hook. Clone alias set info. - void cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, - Loop *L) override; - - /// deleteAnalysisValue - Simple Analysis hook. Delete value V from alias - /// set. - void deleteAnalysisValue(Value *V, Loop *L) override; - - /// Simple Analysis hook. Delete loop L from alias set map. - void deleteAnalysisLoop(Loop *L) override; }; } // namespace PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { - const auto &FAM = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); - Function *F = L.getHeader()->getParent(); - - auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); - // FIXME: This should probably be optional rather than required. - if (!ORE) - report_fatal_error("LICM: OptimizationRemarkEmitterAnalysis not " - "cached at a higher level"); + // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). + OptimizationRemarkEmitter ORE(L.getHeader()->getParent()); LoopInvariantCodeMotion LICM(LicmMssaOptCap, LicmMssaNoAccForPromotionCap); if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.TTI, &AR.SE, - AR.MSSA, ORE, true)) + AR.MSSA, &ORE)) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); @@ -322,13 +284,10 @@ Pass *llvm::createLICMPass(unsigned LicmMssaOptCap, /// Hoist expressions out of the specified loop. Note, alias info for inner /// loop is not preserved so it is not a good idea to run LICM multiple /// times on one loop. -/// We should delete AST for inner loops in the new pass manager to avoid -/// memory leak. -/// bool LoopInvariantCodeMotion::runOnLoop( - Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, + Loop *L, AAResults *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, ScalarEvolution *SE, - MemorySSA *MSSA, OptimizationRemarkEmitter *ORE, bool DeleteAST) { + MemorySSA *MSSA, OptimizationRemarkEmitter *ORE) { bool Changed = false; assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); @@ -372,7 +331,7 @@ bool LoopInvariantCodeMotion::runOnLoop( BasicBlock *Preheader = L->getLoopPreheader(); // Compute loop safety information. - ICFLoopSafetyInfo SafetyInfo(DT); + ICFLoopSafetyInfo SafetyInfo; SafetyInfo.computeLoopSafetyInfo(L); // We want to visit all of the instructions in this loop... 
that are not parts @@ -476,11 +435,6 @@ bool LoopInvariantCodeMotion::runOnLoop( assert((!L->getParentLoop() || L->getParentLoop()->isLCSSAForm(*DT)) && "Parent loop not left in LCSSA form after LICM!"); - // If this loop is nested inside of another one, save the alias information - // for when we process the outer loop. - if (!MSSAU.get() && CurAST.get() && L->getParentLoop() && !DeleteAST) - LoopToAliasSetMap[L] = std::move(CurAST); - if (MSSAU.get() && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); @@ -494,7 +448,7 @@ bool LoopInvariantCodeMotion::runOnLoop( /// first order w.r.t the DominatorTree. This allows us to visit uses before /// definitions, allowing us to sink a loop body in one pass without iteration. /// -bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, +bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, Loop *CurLoop, AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, @@ -529,6 +483,7 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // used in the loop, instead, just delete it. if (isInstructionTriviallyDead(&I, TLI)) { LLVM_DEBUG(dbgs() << "LICM deleting dead inst: " << I << '\n'); + salvageKnowledge(&I); salvageDebugInfo(I); ++II; eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); @@ -542,13 +497,14 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // operands of the instruction are loop invariant. // bool FreeInLoop = false; - if (isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) && + if (!I.mayHaveSideEffects() && + isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) && canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags, - ORE) && - !I.mayHaveSideEffects()) { + ORE)) { if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE)) { if (!FreeInLoop) { ++II; + salvageDebugInfo(I); eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); } Changed = true; @@ -790,47 +746,12 @@ public: }; } // namespace - -/// Return true if we know how to rewrite all uses of the given alloca after -/// hoisting it out of the loop. The main concerns are a) potential captures -/// and b) invariant.start markers which don't capture, but are no longer -/// valid w/o a corresponding invariant.end. -static bool canRewriteUsesOfAlloca(AllocaInst &AI) { - // TODO: This looks a lot like capture tracking, but we need to remove any - // invariant starts if we extend the lifetime of the alloca by hoisting it. - // We should probably refactor capture tracking into a form which allows us - // to reuse the relevant bits and remove the duplicated logic here. - - SmallVector<Use *, 16> Worklist; - for (Use &U : AI.uses()) - Worklist.push_back(&U); - - unsigned NumUsesExplored = 0; - while (!Worklist.empty()) { - Use *U = Worklist.pop_back_val(); - Instruction *I = cast<Instruction>(U->getUser()); - NumUsesExplored++; - if (NumUsesExplored > DefaultMaxUsesToExplore) - return false; - // Non capturing, terminating uses - if (isa<LoadInst>(I) || - (isa<StoreInst>(I) && U->getOperandNo() == 1)) - continue; - // Non capturing, non-terminating - if (!isa<BitCastInst>(I) && !isa<GetElementPtrInst>(I)) - return false; - for (Use &U : I->uses()) - Worklist.push_back(&U); - } - return true; -} - /// Walk the specified region of the CFG (defined by all blocks dominated by /// the specified block, and that are in the current loop) in depth first /// order w.r.t the DominatorTree. 
This allows us to visit definitions before /// uses, allowing us to hoist a loop body in one pass without iteration. /// -bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, +bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, ScalarEvolution *SE, ICFLoopSafetyInfo *SafetyInfo, @@ -901,9 +822,8 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // Attempt to remove floating point division out of the loop by // converting it to a reciprocal multiplication. - if (I.getOpcode() == Instruction::FDiv && - CurLoop->isLoopInvariant(I.getOperand(1)) && - I.hasAllowReciprocal()) { + if (I.getOpcode() == Instruction::FDiv && I.hasAllowReciprocal() && + CurLoop->isLoopInvariant(I.getOperand(1))) { auto Divisor = I.getOperand(1); auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0); auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor); @@ -945,16 +865,6 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, continue; } - if (isa<AllocaInst>(&I) && - SafetyInfo->isGuaranteedToExecute(I, DT, CurLoop) && - canRewriteUsesOfAlloca(cast<AllocaInst>(I))) { - hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, - MSSAU, SE, ORE); - HoistedInstructions.push_back(&I); - Changed = true; - continue; - } - if (PHINode *PN = dyn_cast<PHINode>(&I)) { if (CFH.canHoistPHI(PN)) { // Redirect incoming blocks first to ensure that we create hoisted @@ -1081,12 +991,12 @@ namespace { bool isHoistableAndSinkableInst(Instruction &I) { // Only these instructions are hoistable/sinkable. return (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<CallInst>(I) || - isa<FenceInst>(I) || isa<CastInst>(I) || - isa<UnaryOperator>(I) || isa<BinaryOperator>(I) || - isa<SelectInst>(I) || isa<GetElementPtrInst>(I) || isa<CmpInst>(I) || + isa<FenceInst>(I) || isa<CastInst>(I) || isa<UnaryOperator>(I) || + isa<BinaryOperator>(I) || isa<SelectInst>(I) || + isa<GetElementPtrInst>(I) || isa<CmpInst>(I) || isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) || isa<ShuffleVectorInst>(I) || isa<ExtractValueInst>(I) || - isa<InsertValueInst>(I)); + isa<InsertValueInst>(I) || isa<FreezeInst>(I)); } /// Return true if all of the alias sets within this AST are known not to /// contain a Mod, or if MSSA knows thare are no MemoryDefs in the loop. @@ -1198,11 +1108,11 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, FunctionModRefBehavior Behavior = AA->getModRefBehavior(CI); if (Behavior == FMRB_DoesNotAccessMemory) return true; - if (AliasAnalysis::onlyReadsMemory(Behavior)) { + if (AAResults::onlyReadsMemory(Behavior)) { // A readonly argmemonly function only reads from memory pointed to by // it's arguments with arbitrary offsets. If we can prove there are no // writes to this memory in the loop, we can hoist or sink. 
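// Note on the AliasAnalysis -> AAResults renaming that runs through these
// LICM hunks: at this revision AliasAnalysis is simply a typedef of
// AAResults, so the change below is a spelling cleanup rather than a
// behavioral one.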
- if (AliasAnalysis::onlyAccessesArgPointees(Behavior)) { + if (AAResults::onlyAccessesArgPointees(Behavior)) { // TODO: expand to writeable arguments for (Value *Op : CI->arg_operands()) if (Op->getType()->isPointerTy()) { @@ -1351,7 +1261,8 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, const TargetTransformInfo *TTI) { if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&I)) { - if (TTI->getUserCost(GEP) != TargetTransformInfo::TCC_Free) + if (TTI->getUserCost(GEP, TargetTransformInfo::TCK_SizeAndLatency) != + TargetTransformInfo::TCC_Free) return false; // For a GEP, we cannot simply use getUserCost because currently it // optimistically assume that a GEP will fold into addressing mode @@ -1366,7 +1277,8 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, } return true; } else - return TTI->getUserCost(&I) == TargetTransformInfo::TCC_Free; + return TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency) == + TargetTransformInfo::TCC_Free; } /// Return true if the only users of this instruction are outside of @@ -1407,7 +1319,7 @@ static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, return true; } -static Instruction *CloneInstructionInExitBlock( +static Instruction *cloneInstructionInExitBlock( Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU) { Instruction *New; @@ -1520,7 +1432,7 @@ static Instruction *sinkThroughTriviallyReplaceablePHI( if (It != SunkCopies.end()) New = It->second; else - New = SunkCopies[ExitBlock] = CloneInstructionInExitBlock( + New = SunkCopies[ExitBlock] = cloneInstructionInExitBlock( *I, *ExitBlock, *TPN, LI, SafetyInfo, MSSAU); return New; } @@ -1537,7 +1449,8 @@ static bool canSplitPredecessors(PHINode *PN, LoopSafetyInfo *SafetyInfo) { return false; for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { BasicBlock *BBPred = *PI; - if (isa<IndirectBrInst>(BBPred->getTerminator())) + if (isa<IndirectBrInst>(BBPred->getTerminator()) || + isa<CallBrInst>(BBPred->getTerminator())) return false; } return true; @@ -1857,7 +1770,7 @@ public: StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos); if (UnorderedAtomic) NewSI->setOrdering(AtomicOrdering::Unordered); - NewSI->setAlignment(MaybeAlign(Alignment)); + NewSI->setAlignment(Align(Alignment)); NewSI->setDebugLoc(DL); if (AATags) NewSI->setAAMetadata(AATags); @@ -1981,7 +1894,7 @@ bool llvm::promoteLoopAccessesToScalars( // We start with an alignment of one and try to find instructions that allow // us to prove better alignment. - unsigned Alignment = 1; + Align Alignment; // Keep track of which types of access we see bool SawUnorderedAtomic = false; bool SawNotAtomic = false; @@ -2029,10 +1942,7 @@ bool llvm::promoteLoopAccessesToScalars( SawUnorderedAtomic |= Load->isAtomic(); SawNotAtomic |= !Load->isAtomic(); - unsigned InstAlignment = Load->getAlignment(); - if (!InstAlignment) - InstAlignment = - MDL.getABITypeAlignment(Load->getType()); + Align InstAlignment = Load->getAlign(); // Note that proving a load safe to speculate requires proving // sufficient alignment at the target location. Proving it guaranteed @@ -2060,10 +1970,7 @@ bool llvm::promoteLoopAccessesToScalars( // already know that promotion is safe, since it may have higher // alignment than any other guaranteed stores, in which case we can // raise the alignment on the promoted store. 
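// The alignment hunks around here replace the old "unsigned, 0 means unknown"
// convention with llvm::Align, which always carries a known power of two, so
// the MDL.getABITypeAlignment() fallback disappears. A minimal sketch,
// mirroring the names used in the surrounding promotion code:
//
//   Align Alignment;                          // running maximum, defaults to 1
//   Align InstAlignment = Store->getAlign();  // never an "unknown" sentinel
//   if (InstAlignment > Alignment)
//     Alignment = InstAlignment;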
- unsigned InstAlignment = Store->getAlignment(); - if (!InstAlignment) - InstAlignment = - MDL.getABITypeAlignment(Store->getValueOperand()->getType()); + Align InstAlignment = Store->getAlign(); if (!DereferenceableInPH || !SafeToInsertStore || (InstAlignment > Alignment)) { @@ -2090,8 +1997,7 @@ bool llvm::promoteLoopAccessesToScalars( if (!DereferenceableInPH) { DereferenceableInPH = isDereferenceableAndAlignedPointer( Store->getPointerOperand(), Store->getValueOperand()->getType(), - MaybeAlign(Store->getAlignment()), MDL, - Preheader->getTerminator(), DT); + Store->getAlign(), MDL, Preheader->getTerminator(), DT); } } else return false; // Not a load or store. @@ -2156,18 +2062,19 @@ bool llvm::promoteLoopAccessesToScalars( }); ++NumPromoted; - // Grab a debug location for the inserted loads/stores; given that the - // inserted loads/stores have little relation to the original loads/stores, - // this code just arbitrarily picks a location from one, since any debug - // location is better than none. - DebugLoc DL = LoopUses[0]->getDebugLoc(); + // Look at all the loop uses, and try to merge their locations. + std::vector<const DILocation *> LoopUsesLocs; + for (auto U : LoopUses) + LoopUsesLocs.push_back(U->getDebugLoc().get()); + auto DL = DebugLoc(DILocation::getMergedLocations(LoopUsesLocs)); // We use the SSAUpdater interface to insert phi nodes as required. SmallVector<PHINode *, 16> NewPHIs; SSAUpdater SSA(&NewPHIs); LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks, InsertPts, MSSAInsertPts, PIC, *CurAST, MSSAU, *LI, DL, - Alignment, SawUnorderedAtomic, AATags, *SafetyInfo); + Alignment.value(), SawUnorderedAtomic, AATags, + *SafetyInfo); // Set up the preheader to have a definition of the value. It is the live-out // value from the preheader that uses in the loop will use. @@ -2176,8 +2083,8 @@ bool llvm::promoteLoopAccessesToScalars( SomePtr->getName() + ".promoted", Preheader->getTerminator()); if (SawUnorderedAtomic) PreheaderLoad->setOrdering(AtomicOrdering::Unordered); - PreheaderLoad->setAlignment(MaybeAlign(Alignment)); - PreheaderLoad->setDebugLoc(DL); + PreheaderLoad->setAlignment(Alignment); + PreheaderLoad->setDebugLoc(DebugLoc()); if (AATags) PreheaderLoad->setAAMetadata(AATags); SSA.AddAvailableValue(Preheader, PreheaderLoad); @@ -2206,41 +2113,13 @@ bool llvm::promoteLoopAccessesToScalars( /// Returns an owning pointer to an alias set which incorporates aliasing info /// from L and all subloops of L. -/// FIXME: In new pass manager, there is no helper function to handle loop -/// analysis such as cloneBasicBlockAnalysis, so the AST needs to be recomputed -/// from scratch for every loop. Hook up with the helper functions when -/// available in the new pass manager to avoid redundant computation. std::unique_ptr<AliasSetTracker> LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, - AliasAnalysis *AA) { - std::unique_ptr<AliasSetTracker> CurAST; - SmallVector<Loop *, 4> RecomputeLoops; - for (Loop *InnerL : L->getSubLoops()) { - auto MapI = LoopToAliasSetMap.find(InnerL); - // If the AST for this inner loop is missing it may have been merged into - // some other loop's AST and then that loop unrolled, and so we need to - // recompute it. 
- if (MapI == LoopToAliasSetMap.end()) { - RecomputeLoops.push_back(InnerL); - continue; - } - std::unique_ptr<AliasSetTracker> InnerAST = std::move(MapI->second); + AAResults *AA) { + auto CurAST = std::make_unique<AliasSetTracker>(*AA); - if (CurAST) { - // What if InnerLoop was modified by other passes ? - // Once we've incorporated the inner loop's AST into ours, we don't need - // the subloop's anymore. - CurAST->add(*InnerAST); - } else { - CurAST = std::move(InnerAST); - } - LoopToAliasSetMap.erase(MapI); - } - if (!CurAST) - CurAST = std::make_unique<AliasSetTracker>(*AA); - - // Add everything from the sub loops that are no longer directly available. - for (Loop *InnerL : RecomputeLoops) + // Add everything from all the sub loops. + for (Loop *InnerL : L->getSubLoops()) for (BasicBlock *BB : InnerL->blocks()) CurAST->add(*BB); @@ -2254,46 +2133,16 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, std::unique_ptr<AliasSetTracker> LoopInvariantCodeMotion::collectAliasInfoForLoopWithMSSA( - Loop *L, AliasAnalysis *AA, MemorySSAUpdater *MSSAU) { + Loop *L, AAResults *AA, MemorySSAUpdater *MSSAU) { auto *MSSA = MSSAU->getMemorySSA(); auto CurAST = std::make_unique<AliasSetTracker>(*AA, MSSA, L); CurAST->addAllInstructionsInLoopUsingMSSA(); return CurAST; } -/// Simple analysis hook. Clone alias set info. -/// -void LegacyLICMPass::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, - Loop *L) { - auto ASTIt = LICM.getLoopToAliasSetMap().find(L); - if (ASTIt == LICM.getLoopToAliasSetMap().end()) - return; - - ASTIt->second->copyValue(From, To); -} - -/// Simple Analysis hook. Delete value V from alias set -/// -void LegacyLICMPass::deleteAnalysisValue(Value *V, Loop *L) { - auto ASTIt = LICM.getLoopToAliasSetMap().find(L); - if (ASTIt == LICM.getLoopToAliasSetMap().end()) - return; - - ASTIt->second->deleteValue(V); -} - -/// Simple Analysis hook. Delete value L from alias set map. -/// -void LegacyLICMPass::deleteAnalysisLoop(Loop *L) { - if (!LICM.getLoopToAliasSetMap().count(L)) - return; - - LICM.getLoopToAliasSetMap().erase(L); -} - static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, AliasSetTracker *CurAST, Loop *CurLoop, - AliasAnalysis *AA) { + AAResults *AA) { // First check to see if any of the basic blocks in CurLoop invalidate *V. bool isInvalidatedAccordingToAST = CurAST->getAliasSetFor(MemLoc).isMod(); diff --git a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp index ab65f56d088f..687e14d6d7d2 100644 --- a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -21,7 +21,6 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CFG.h" @@ -32,6 +31,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/ValueMapper.h" using namespace llvm; @@ -61,10 +61,10 @@ namespace { /// Loop prefetch implementation class. 
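// The prefetcher now also requires a DominatorTree: when a single Prefetch
// record (defined further down) ends up covering memory accesses in different
// blocks, Prefetch::addInstruction moves the insertion point up to
// DT->findNearestCommonDominator(...) so the emitted llvm.prefetch call
// dominates every access it is meant to cover.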
class LoopDataPrefetch { public: - LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE, - const TargetTransformInfo *TTI, + LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI, + ScalarEvolution *SE, const TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE) - : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {} + : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {} bool run(); @@ -73,12 +73,16 @@ private: /// Check if the stride of the accesses is large enough to /// warrant a prefetch. - bool isStrideLargeEnough(const SCEVAddRecExpr *AR); + bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride); - unsigned getMinPrefetchStride() { + unsigned getMinPrefetchStride(unsigned NumMemAccesses, + unsigned NumStridedMemAccesses, + unsigned NumPrefetches, + bool HasCall) { if (MinPrefetchStride.getNumOccurrences() > 0) return MinPrefetchStride; - return TTI->getMinPrefetchStride(); + return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses, + NumPrefetches, HasCall); } unsigned getPrefetchDistance() { @@ -93,7 +97,14 @@ private: return TTI->getMaxPrefetchIterationsAhead(); } + bool doPrefetchWrites() { + if (PrefetchWrites.getNumOccurrences() > 0) + return PrefetchWrites; + return TTI->enableWritePrefetching(); + } + AssumptionCache *AC; + DominatorTree *DT; LoopInfo *LI; ScalarEvolution *SE; const TargetTransformInfo *TTI; @@ -110,6 +121,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); @@ -138,8 +150,8 @@ FunctionPass *llvm::createLoopDataPrefetchPass() { return new LoopDataPrefetchLegacyPass(); } -bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) { - unsigned TargetMinStride = getMinPrefetchStride(); +bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR, + unsigned TargetMinStride) { // No need to check if any stride goes. 
if (TargetMinStride <= 1) return true; @@ -156,6 +168,7 @@ bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) { PreservedAnalyses LoopDataPrefetchPass::run(Function &F, FunctionAnalysisManager &AM) { + DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); LoopInfo *LI = &AM.getResult<LoopAnalysis>(F); ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); @@ -163,7 +176,7 @@ PreservedAnalyses LoopDataPrefetchPass::run(Function &F, &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F); - LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); + LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE); bool Changed = LDP.run(); if (Changed) { @@ -180,6 +193,7 @@ bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) { if (skipFunction(F)) return false; + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); AssumptionCache *AC = @@ -189,7 +203,7 @@ bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) { const TargetTransformInfo *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); + LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE); return LDP.run(); } @@ -210,6 +224,49 @@ bool LoopDataPrefetch::run() { return MadeChange; } +/// A record for a potential prefetch made during the initial scan of the +/// loop. This is used to let a single prefetch target multiple memory accesses. +struct Prefetch { + /// The address formula for this prefetch as returned by ScalarEvolution. + const SCEVAddRecExpr *LSCEVAddRec; + /// The point of insertion for the prefetch instruction. + Instruction *InsertPt; + /// True if targeting a write memory access. + bool Writes; + /// The (first seen) prefetched instruction. + Instruction *MemI; + + /// Constructor to create a new Prefetch for \p I. + Prefetch(const SCEVAddRecExpr *L, Instruction *I) + : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) { + addInstruction(I); + }; + + /// Add the instruction \param I to this prefetch. If it's not the first + /// one, 'InsertPt' and 'Writes' will be updated as required. + /// \param PtrDiff the known constant address difference to the first added + /// instruction. + void addInstruction(Instruction *I, DominatorTree *DT = nullptr, + int64_t PtrDiff = 0) { + if (!InsertPt) { + MemI = I; + InsertPt = I; + Writes = isa<StoreInst>(I); + } else { + BasicBlock *PrefBB = InsertPt->getParent(); + BasicBlock *InsBB = I->getParent(); + if (PrefBB != InsBB) { + BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB); + if (DomBB != PrefBB) + InsertPt = DomBB->getTerminator(); + } + + if (isa<StoreInst>(I) && PtrDiff == 0) + Writes = true; + } + } +}; + bool LoopDataPrefetch::runOnLoop(Loop *L) { bool MadeChange = false; @@ -222,15 +279,22 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { // Calculate the number of iterations ahead to prefetch CodeMetrics Metrics; + bool HasCall = false; for (const auto BB : L->blocks()) { // If the loop already has prefetches, then assume that the user knows // what they are doing and don't add any more. 
- for (auto &I : *BB) - if (CallInst *CI = dyn_cast<CallInst>(&I)) - if (Function *F = CI->getCalledFunction()) + for (auto &I : *BB) { + if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) { + if (const Function *F = cast<CallBase>(I).getCalledFunction()) { if (F->getIntrinsicID() == Intrinsic::prefetch) return MadeChange; - + if (TTI->isLoweredToCall(F)) + HasCall = true; + } else { // indirect call. + HasCall = true; + } + } + } Metrics.analyzeBasicBlock(BB, *TTI, EphValues); } unsigned LoopSize = Metrics.NumInsts; @@ -244,12 +308,14 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { if (ItersAhead > getMaxPrefetchIterationsAhead()) return MadeChange; - LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead - << " iterations ahead (loop size: " << LoopSize << ") in " - << L->getHeader()->getParent()->getName() << ": " << *L); + unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L); + if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1) + return MadeChange; - SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads; - for (const auto BB : L->blocks()) { + unsigned NumMemAccesses = 0; + unsigned NumStridedMemAccesses = 0; + SmallVector<Prefetch, 16> Prefetches; + for (const auto BB : L->blocks()) for (auto &I : *BB) { Value *PtrValue; Instruction *MemI; @@ -258,7 +324,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { MemI = LMemI; PtrValue = LMemI->getPointerOperand(); } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) { - if (!PrefetchWrites) continue; + if (!doPrefetchWrites()) continue; MemI = SMemI; PtrValue = SMemI->getPointerOperand(); } else continue; @@ -266,7 +332,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace(); if (PtrAddrSpace) continue; - + NumMemAccesses++; if (L->isLoopInvariant(PtrValue)) continue; @@ -274,62 +340,79 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV); if (!LSCEVAddRec) continue; + NumStridedMemAccesses++; - // Check if the stride of the accesses is large enough to warrant a - // prefetch. - if (!isStrideLargeEnough(LSCEVAddRec)) - continue; - - // We don't want to double prefetch individual cache lines. If this load - // is known to be within one cache line of some other load that has - // already been prefetched, then don't prefetch this one as well. + // We don't want to double prefetch individual cache lines. If this + // access is known to be within one cache line of some other one that + // has already been prefetched, then don't prefetch this one as well. 
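// Concretely, the loop below asks SCEV for the difference between the two
// address recurrences; if it folds to a constant whose absolute value is
// smaller than TTI->getCacheLineSize(), the access is folded into the
// existing Prefetch record via Prefetch::addInstruction instead of getting
// its own llvm.prefetch call.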
bool DupPref = false; - for (const auto &PrefLoad : PrefLoads) { - const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second); + for (auto &Pref : Prefetches) { + const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec); if (const SCEVConstant *ConstPtrDiff = dyn_cast<SCEVConstant>(PtrDiff)) { int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue()); if (PD < (int64_t) TTI->getCacheLineSize()) { + Pref.addInstruction(MemI, DT, PD); DupPref = true; break; } } } - if (DupPref) - continue; + if (!DupPref) + Prefetches.push_back(Prefetch(LSCEVAddRec, MemI)); + } - const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr( - SE->getConstant(LSCEVAddRec->getType(), ItersAhead), - LSCEVAddRec->getStepRecurrence(*SE))); - if (!isSafeToExpand(NextLSCEV, *SE)) - continue; + unsigned TargetMinStride = + getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses, + Prefetches.size(), HasCall); - PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec)); - - Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace); - SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr"); - Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI); - - IRBuilder<> Builder(MemI); - Module *M = BB->getParent()->getParent(); - Type *I32 = Type::getInt32Ty(BB->getContext()); - Function *PrefetchFunc = Intrinsic::getDeclaration( - M, Intrinsic::prefetch, PrefPtrValue->getType()); - Builder.CreateCall( - PrefetchFunc, - {PrefPtrValue, - ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1), - ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)}); - ++NumPrefetches; - LLVM_DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV - << "\n"); - ORE->emit([&]() { - return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI) - << "prefetched memory access"; + LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead + << " iterations ahead (loop size: " << LoopSize << ") in " + << L->getHeader()->getParent()->getName() << ": " << *L); + LLVM_DEBUG(dbgs() << "Loop has: " + << NumMemAccesses << " memory accesses, " + << NumStridedMemAccesses << " strided memory accesses, " + << Prefetches.size() << " potential prefetch(es), " + << "a minimum stride of " << TargetMinStride << ", " + << (HasCall ? "calls" : "no calls") << ".\n"); + + for (auto &P : Prefetches) { + // Check if the stride of the accesses is large enough to warrant a + // prefetch. + if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride)) + continue; + + const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr( + SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead), + P.LSCEVAddRec->getStepRecurrence(*SE))); + if (!isSafeToExpand(NextLSCEV, *SE)) + continue; + + BasicBlock *BB = P.InsertPt->getParent(); + Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/); + SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr"); + Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt); + + IRBuilder<> Builder(P.InsertPt); + Module *M = BB->getParent()->getParent(); + Type *I32 = Type::getInt32Ty(BB->getContext()); + Function *PrefetchFunc = Intrinsic::getDeclaration( + M, Intrinsic::prefetch, PrefPtrValue->getType()); + Builder.CreateCall( + PrefetchFunc, + {PrefPtrValue, + ConstantInt::get(I32, P.Writes), + ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)}); + ++NumPrefetches; + LLVM_DEBUG(dbgs() << " Access: " + << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 
0 : 1) + << ", SCEV: " << *P.LSCEVAddRec << "\n"); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI) + << "prefetched memory access"; }); - MadeChange = true; - } + MadeChange = true; } return MadeChange; diff --git a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp index 2451572d6171..be209d34be42 100644 --- a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -18,6 +18,8 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" @@ -134,7 +136,9 @@ static bool isLoopNeverExecuted(Loop *L) { /// is unable to delete it due to hoisting trivially loop invariant /// instructions out of the loop. static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, - ScalarEvolution &SE, LoopInfo &LI) { + ScalarEvolution &SE, LoopInfo &LI, + MemorySSA *MSSA, + OptimizationRemarkEmitter &ORE) { assert(L->isLCSSAForm(DT) && "Expected LCSSA!"); // We can only remove the loop if there is a preheader that we can branch from @@ -164,7 +168,12 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, std::fill(P.incoming_values().begin(), P.incoming_values().end(), UndefValue::get(P.getType())); } - deleteDeadLoop(L, &DT, &SE, &LI); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "NeverExecutes", L->getStartLoc(), + L->getHeader()) + << "Loop deleted because it never executes"; + }); + deleteDeadLoop(L, &DT, &SE, &LI, MSSA); ++NumDeleted; return LoopDeletionResult::Deleted; } @@ -200,7 +209,12 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, } LLVM_DEBUG(dbgs() << "Loop is invariant, delete it!"); - deleteDeadLoop(L, &DT, &SE, &LI); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "Invariant", L->getStartLoc(), + L->getHeader()) + << "Loop deleted because it is invariant"; + }); + deleteDeadLoop(L, &DT, &SE, &LI, MSSA); ++NumDeleted; return LoopDeletionResult::Deleted; @@ -212,15 +226,22 @@ PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM, LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: "); LLVM_DEBUG(L.dump()); - std::string LoopName = L.getName(); - auto Result = deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI); + std::string LoopName = std::string(L.getName()); + // For the new PM, we can't use OptimizationRemarkEmitter as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). 
+ OptimizationRemarkEmitter ORE(L.getHeader()->getParent()); + auto Result = deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI, AR.MSSA, ORE); if (Result == LoopDeletionResult::Unmodified) return PreservedAnalyses::all(); if (Result == LoopDeletionResult::Deleted) Updater.markLoopAsDeleted(L, LoopName); - return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + if (AR.MSSA) + PA.preserve<MemorySSAAnalysis>(); + return PA; } namespace { @@ -235,6 +256,7 @@ public: bool runOnLoop(Loop *L, LPPassManager &) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<MemorySSAWrapperPass>(); getLoopAnalysisUsage(AU); } }; @@ -255,11 +277,19 @@ bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + MemorySSA *MSSA = nullptr; + if (MSSAAnalysis) + MSSA = &MSSAAnalysis->getMSSA(); + // For the old PM, we can't use OptimizationRemarkEmitter as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). + OptimizationRemarkEmitter ORE(L->getHeader()->getParent()); LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: "); LLVM_DEBUG(L->dump()); - LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI); + LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI, MSSA, ORE); if (Result == LoopDeletionResult::Deleted) LPM.markLoopAsDeleted(*L); diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 8e04e6e0ffe8..7867a5468891 100644 --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -789,12 +789,6 @@ public: // instructions to partitions. Partitions.setupPartitionIdOnInstructions(); - // To keep things simple have an empty preheader before we version or clone - // the loop. (Also split if this has no predecessor, i.e. entry, because we - // rely on PH having a predecessor.) - if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator()) - SplitBlock(PH, PH->getTerminator(), DT, LI); - // If we need run-time checks, version the loop now. auto PtrToPartition = Partitions.computePartitionSetForPointers(*LAI); const auto *RtPtrChecking = LAI->getRuntimePointerChecking(); @@ -807,6 +801,12 @@ public: "may not insert runtime check with convergent operation"); } + // To keep things simple have an empty preheader before we version or clone + // the loop. (Also split if this has no predecessor, i.e. entry, because we + // rely on PH having a predecessor.) + if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator()) + SplitBlock(PH, PH->getTerminator(), DT, LI); + if (!Pred.isAlwaysTrue() || !Checks.empty()) { assert(!LAI->hasConvergentOp() && "inserting illegal loop versioning"); @@ -903,15 +903,14 @@ private: /// \p PtrToPartition contains the partition number for pointers. Partition /// number -1 means that the pointer is used in multiple partitions. In this /// case we can't safely omit the check. 
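// RuntimePointerCheck in the hunk below is the same pair-of-pointer-groups
// type that used to be spelled RuntimePointerChecking::PointerCheck; only the
// spelling of the type changes, and the cross-partition filtering itself is
// untouched.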
- SmallVector<RuntimePointerChecking::PointerCheck, 4> - includeOnlyCrossPartitionChecks( - const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &AllChecks, + SmallVector<RuntimePointerCheck, 4> includeOnlyCrossPartitionChecks( + const SmallVectorImpl<RuntimePointerCheck> &AllChecks, const SmallVectorImpl<int> &PtrToPartition, const RuntimePointerChecking *RtPtrChecking) { - SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; + SmallVector<RuntimePointerCheck, 4> Checks; copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const RuntimePointerCheck &Check) { for (unsigned PtrIdx1 : Check.first->Members) for (unsigned PtrIdx2 : Check.second->Members) // Only include this check if there is a pair of pointers diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp index e1738f08eb23..20edc8699d79 100644 --- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp @@ -86,11 +86,15 @@ STATISTIC(UnknownTripCount, "Loop has unknown trip count"); STATISTIC(UncomputableTripCount, "SCEV cannot compute trip count of loop"); STATISTIC(NonEqualTripCount, "Loop trip counts are not the same"); STATISTIC(NonAdjacent, "Loops are not adjacent"); -STATISTIC(NonEmptyPreheader, "Loop has a non-empty preheader"); +STATISTIC( + NonEmptyPreheader, + "Loop has a non-empty preheader with instructions that cannot be moved"); STATISTIC(FusionNotBeneficial, "Fusion is not beneficial"); STATISTIC(NonIdenticalGuards, "Candidates have different guards"); -STATISTIC(NonEmptyExitBlock, "Candidate has a non-empty exit block"); -STATISTIC(NonEmptyGuardBlock, "Candidate has a non-empty guard block"); +STATISTIC(NonEmptyExitBlock, "Candidate has a non-empty exit block with " + "instructions that cannot be moved"); +STATISTIC(NonEmptyGuardBlock, "Candidate has a non-empty guard block with " + "instructions that cannot be moved"); STATISTIC(NotRotated, "Candidate is not rotated"); enum FusionDependenceAnalysisChoice { @@ -738,33 +742,40 @@ private: continue; } - // The following three checks look for empty blocks in FC0 and FC1. If - // any of these blocks are non-empty, we do not fuse. This is done - // because we currently do not have the safety checks to determine if - // it is safe to move the blocks past other blocks in the loop. Once - // these checks are added, these conditions can be relaxed. - if (!isEmptyPreheader(*FC1)) { - LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty " - "preheader. Not fusing.\n"); + if (!isSafeToMoveBefore(*FC1->Preheader, + *FC0->Preheader->getTerminator(), DT, &PDT, + &DI)) { + LLVM_DEBUG(dbgs() << "Fusion candidate contains unsafe " + "instructions in preheader. Not fusing.\n"); reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, NonEmptyPreheader); continue; } - if (FC0->GuardBranch && !isEmptyExitBlock(*FC0)) { - LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty exit " - "block. Not fusing.\n"); - reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, - NonEmptyExitBlock); - continue; - } + if (FC0->GuardBranch) { + assert(FC1->GuardBranch && "Expecting valid FC1 guard branch"); + + if (!isSafeToMoveBefore(*FC0->ExitBlock, + *FC1->ExitBlock->getFirstNonPHIOrDbg(), DT, + &PDT, &DI)) { + LLVM_DEBUG(dbgs() << "Fusion candidate contains unsafe " + "instructions in exit block. 
Not fusing.\n"); + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + NonEmptyExitBlock); + continue; + } - if (FC1->GuardBranch && !isEmptyGuardBlock(*FC1)) { - LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty guard " - "block. Not fusing.\n"); - reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, - NonEmptyGuardBlock); - continue; + if (!isSafeToMoveBefore( + *FC1->GuardBranch->getParent(), + *FC0->GuardBranch->getParent()->getTerminator(), DT, &PDT, + &DI)) { + LLVM_DEBUG(dbgs() + << "Fusion candidate contains unsafe " + "instructions in guard block. Not fusing.\n"); + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + NonEmptyGuardBlock); + continue; + } } // Check the dependencies across the loops and do not fuse if it would @@ -1075,38 +1086,6 @@ private: return (FC1.GuardBranch->getSuccessor(1) == FC1.Preheader); } - /// Check that the guard for \p FC *only* contains the cmp/branch for the - /// guard. - /// Once we are able to handle intervening code, any code in the guard block - /// for FC1 will need to be treated as intervening code and checked whether - /// it can safely move around the loops. - bool isEmptyGuardBlock(const FusionCandidate &FC) const { - assert(FC.GuardBranch && "Expecting a fusion candidate with guard branch."); - if (auto *CmpInst = dyn_cast<Instruction>(FC.GuardBranch->getCondition())) { - auto *GuardBlock = FC.GuardBranch->getParent(); - // If the generation of the cmp value is in GuardBlock, then the size of - // the guard block should be 2 (cmp + branch). If the generation of the - // cmp value is in a different block, then the size of the guard block - // should only be 1. - if (CmpInst->getParent() == GuardBlock) - return GuardBlock->size() == 2; - else - return GuardBlock->size() == 1; - } - - return false; - } - - bool isEmptyPreheader(const FusionCandidate &FC) const { - assert(FC.Preheader && "Expecting a valid preheader"); - return FC.Preheader->size() == 1; - } - - bool isEmptyExitBlock(const FusionCandidate &FC) const { - assert(FC.ExitBlock && "Expecting a valid exit block"); - return FC.ExitBlock->size() == 1; - } - /// Simplify the condition of the latch branch of \p FC to true, when both of /// its successors are the same. void simplifyLatchBranch(const FusionCandidate &FC) const { @@ -1123,7 +1102,7 @@ private: /// Move instructions from FC0.Latch to FC1.Latch. If FC0.Latch has an unique /// successor, then merge FC0.Latch with its unique successor. void mergeLatch(const FusionCandidate &FC0, const FusionCandidate &FC1) { - moveInstsBottomUp(*FC0.Latch, *FC1.Latch, DT, PDT, DI); + moveInstructionsToTheBeginning(*FC0.Latch, *FC1.Latch, DT, PDT, DI); if (BasicBlock *Succ = FC0.Latch->getUniqueSuccessor()) { MergeBlockIntoPredecessor(Succ, &DTU, &LI); DTU.flush(); @@ -1166,6 +1145,10 @@ private: LLVM_DEBUG(dbgs() << "Fusion Candidate 0: \n"; FC0.dump(); dbgs() << "Fusion Candidate 1: \n"; FC1.dump();); + // Move instructions from the preheader of FC1 to the end of the preheader + // of FC0. + moveInstructionsToTheEnd(*FC1.Preheader, *FC0.Preheader, DT, PDT, DI); + // Fusing guarded loops is handled slightly differently than non-guarded // loops and has been broken out into a separate method instead of trying to // intersperse the logic within a single method. @@ -1382,6 +1365,14 @@ private: BasicBlock *FC0NonLoopBlock = FC0.getNonLoopBlock(); BasicBlock *FC1NonLoopBlock = FC1.getNonLoopBlock(); + // Move instructions from the exit block of FC0 to the beginning of the exit + // block of FC1. 
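// The old "block must be empty" requirements are replaced in these hunks by
// the code-mover utilities: isSafeToMoveBefore() proves that everything in
// the preheader / exit / guard block can legally be moved past the other
// loop, and moveInstructionsToTheBeginning() / moveInstructionsToTheEnd()
// then perform those moves while fusing.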
+ moveInstructionsToTheBeginning(*FC0.ExitBlock, *FC1.ExitBlock, DT, PDT, DI); + + // Move instructions from the guard block of FC1 to the end of the guard + // block of FC0. + moveInstructionsToTheEnd(*FC1GuardBlock, *FC0GuardBlock, DT, PDT, DI); + assert(FC0NonLoopBlock == FC1GuardBlock && "Loops are not adjacent"); SmallVector<DominatorTree::UpdateType, 8> TreeUpdates; @@ -1394,6 +1385,7 @@ private: // Thus, one path from the guard goes to the preheader for FC0 (and thus // executes the new fused loop) and the other path goes to the NonLoopBlock // for FC1 (where FC1 guard would have gone if FC1 was not executed). + FC1NonLoopBlock->replacePhiUsesWith(FC1GuardBlock, FC0GuardBlock); FC0.GuardBranch->replaceUsesOfWith(FC0NonLoopBlock, FC1NonLoopBlock); FC0.ExitBlock->getTerminator()->replaceUsesOfWith(FC1GuardBlock, FC1.Header); @@ -1545,7 +1537,10 @@ private: // Update DT/PDT DTU.applyUpdates(TreeUpdates); + LI.removeBlock(FC1GuardBlock); LI.removeBlock(FC1.Preheader); + LI.removeBlock(FC0.ExitBlock); + DTU.deleteBB(FC1GuardBlock); DTU.deleteBB(FC1.Preheader); DTU.deleteBB(FC0.ExitBlock); DTU.flush(); diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index b77843d7cd71..3cb4df12e9b0 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -51,9 +51,11 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -91,6 +93,7 @@ #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -123,15 +126,19 @@ class LoopIdiomRecognize { const DataLayout *DL; OptimizationRemarkEmitter &ORE; bool ApplyCodeSizeHeuristics; + std::unique_ptr<MemorySSAUpdater> MSSAU; public: explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, + const TargetTransformInfo *TTI, MemorySSA *MSSA, const DataLayout *DL, OptimizationRemarkEmitter &ORE) - : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {} + : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) { + if (MSSA) + MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); + } bool runOnLoop(Loop *L); @@ -224,13 +231,17 @@ public: &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( *L->getHeader()->getParent()); const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout(); + auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + MemorySSA *MSSA = nullptr; + if (MSSAAnalysis) + MSSA = &MSSAAnalysis->getMSSA(); // For the old PM, we can't use OptimizationRemarkEmitter as an analysis // pass. Function analyses need to be preserved across loop transformations // but ORE cannot be preserved (see comment before the pass definition). 
OptimizationRemarkEmitter ORE(L->getHeader()->getParent()); - LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL, ORE); + LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, MSSA, DL, ORE); return LIR.runOnLoop(L); } @@ -239,6 +250,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); getLoopAnalysisUsage(AU); } }; @@ -252,23 +264,20 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, LPMUpdater &) { const auto *DL = &L.getHeader()->getModule()->getDataLayout(); - const auto &FAM = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); - Function *F = L.getHeader()->getParent(); - - auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); - // FIXME: This should probably be optional rather than required. - if (!ORE) - report_fatal_error( - "LoopIdiomRecognizePass: OptimizationRemarkEmitterAnalysis not cached " - "at a higher level"); + // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). + OptimizationRemarkEmitter ORE(L.getHeader()->getParent()); - LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, DL, - *ORE); + LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, + AR.MSSA, DL, ORE); if (!LIR.runOnLoop(&L)) return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + if (AR.MSSA) + PA.preserve<MemorySSAAnalysis>(); + return PA; } INITIALIZE_PASS_BEGIN(LoopIdiomRecognizeLegacyPass, "loop-idiom", @@ -339,14 +348,14 @@ bool LoopIdiomRecognize::runOnCountableLoop() { << "] Countable Loop %" << CurLoop->getHeader()->getName() << "\n"); - bool MadeChange = false; - // The following transforms hoist stores/memsets into the loop pre-header. - // Give up if the loop has instructions may throw. + // Give up if the loop has instructions that may throw. SimpleLoopSafetyInfo SafetyInfo; SafetyInfo.computeLoopSafetyInfo(CurLoop); if (SafetyInfo.anyBlockMayThrow()) - return MadeChange; + return false; + + bool MadeChange = false; // Scan all the blocks in the loop that are not in subloops. for (auto *BB : CurLoop->getBlocks()) { @@ -968,11 +977,17 @@ bool LoopIdiomRecognize::processLoopStridedStore( Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy); NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes}); } + NewCall->setDebugLoc(TheStore->getDebugLoc()); + + if (MSSAU) { + MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB( + NewCall, nullptr, NewCall->getParent(), MemorySSA::BeforeTerminator); + MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true); + } LLVM_DEBUG(dbgs() << " Formed memset: " << *NewCall << "\n" << " from store to: " << *Ev << " at: " << *TheStore << "\n"); - NewCall->setDebugLoc(TheStore->getDebugLoc()); ORE.emit([&]() { return OptimizationRemark(DEBUG_TYPE, "ProcessLoopStridedStore", @@ -984,12 +999,40 @@ bool LoopIdiomRecognize::processLoopStridedStore( // Okay, the memset has been formed. Zap the original store and anything that // feeds into it. 
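// The MemorySSA discipline in these hunks: each newly created memset/memcpy
// call gets a MemoryDef through createMemoryAccessInBB() + insertDef(), and
// every store about to be erased is first dropped from MemorySSA with
// removeMemoryAccess(), so the updater never points at a deleted instruction;
// verifyMemorySSA() double-checks this when MemorySSA verification is enabled.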
- for (auto *I : Stores) + for (auto *I : Stores) { + if (MSSAU) + MSSAU->removeMemoryAccess(I, true); deleteDeadInstruction(I); + } + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); ++NumMemSet; return true; } +class ExpandedValuesCleaner { + SCEVExpander &Expander; + TargetLibraryInfo *TLI; + SmallVector<Value *, 4> ExpandedValues; + bool Commit = false; + +public: + ExpandedValuesCleaner(SCEVExpander &Expander, TargetLibraryInfo *TLI) + : Expander(Expander), TLI(TLI) {} + + void add(Value *V) { ExpandedValues.push_back(V); } + + void commit() { Commit = true; } + + ~ExpandedValuesCleaner() { + if (!Commit) { + Expander.clear(); + for (auto *V : ExpandedValues) + RecursivelyDeleteTriviallyDeadInstructions(V, TLI); + } + } +}; + /// If the stored value is a strided load in the same loop with the same stride /// this may be transformable into a memcpy. This kicks in for stuff like /// for (i) A[i] = B[i]; @@ -1020,6 +1063,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, IRBuilder<> Builder(Preheader->getTerminator()); SCEVExpander Expander(*SE, *DL, "loop-idiom"); + ExpandedValuesCleaner EVC(Expander, TLI); + const SCEV *StrStart = StoreEv->getStart(); unsigned StrAS = SI->getPointerAddressSpace(); Type *IntIdxTy = Builder.getIntNTy(DL->getIndexSizeInBits(StrAS)); @@ -1036,16 +1081,13 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // checking everything. Value *StoreBasePtr = Expander.expandCodeFor( StrStart, Builder.getInt8PtrTy(StrAS), Preheader->getTerminator()); + EVC.add(StoreBasePtr); SmallPtrSet<Instruction *, 1> Stores; Stores.insert(SI); if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount, - StoreSize, *AA, Stores)) { - Expander.clear(); - // If we generated new code for the base pointer, clean up. - RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); + StoreSize, *AA, Stores)) return false; - } const SCEV *LdStart = LoadEv->getStart(); unsigned LdAS = LI->getPointerAddressSpace(); @@ -1058,15 +1100,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // mutated by the loop. Value *LoadBasePtr = Expander.expandCodeFor( LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator()); + EVC.add(LoadBasePtr); if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, - StoreSize, *AA, Stores)) { - Expander.clear(); - // If we generated new code for the base pointer, clean up. - RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI); - RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); + StoreSize, *AA, Stores)) return false; - } if (avoidLIRForMultiBlockLoop()) return false; @@ -1078,6 +1116,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator()); + EVC.add(NumBytes); CallInst *NewCall = nullptr; // Check whether to generate an unordered atomic memcpy: @@ -1089,8 +1128,9 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, else { // We cannot allow unaligned ops for unordered load/store, so reject // anything where the alignment isn't at least the element size. 
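// A note on the ExpandedValuesCleaner helper added earlier in this hunk: it
// is a commit-or-rollback scope guard, so every early "return false" path
// automatically clears the SCEV expander and deletes the expanded values,
// and only an explicit commit() keeps them. A minimal, self-contained C++
// sketch of the same idiom (the class and names below are illustrative and
// not part of this patch):

#include <functional>
#include <utility>
#include <vector>

class ScopedRollback {
  std::vector<std::function<void()>> UndoActions;
  bool Committed = false;

public:
  void onRollback(std::function<void()> Undo) {
    UndoActions.push_back(std::move(Undo));
  }
  void commit() { Committed = true; }
  ~ScopedRollback() {
    if (Committed)
      return;
    // Undo in reverse registration order, mirroring how the cleaner first
    // clears the expander and then deletes the now-dead expanded values.
    for (auto It = UndoActions.rbegin(); It != UndoActions.rend(); ++It)
      (*It)();
  }
};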
- unsigned Align = std::min(SI->getAlignment(), LI->getAlignment()); - if (Align < StoreSize) + const Align StoreAlign = SI->getAlign(); + const Align LoadAlign = LI->getAlign(); + if (StoreAlign < StoreSize || LoadAlign < StoreSize) return false; // If the element.atomic memcpy is not lowered into explicit @@ -1104,11 +1144,17 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // Note that unordered atomic loads/stores are *required* by the spec to // have an alignment but non-atomic loads/stores may not. NewCall = Builder.CreateElementUnorderedAtomicMemCpy( - StoreBasePtr, SI->getAlignment(), LoadBasePtr, LI->getAlignment(), - NumBytes, StoreSize); + StoreBasePtr, StoreAlign, LoadBasePtr, LoadAlign, NumBytes, + StoreSize); } NewCall->setDebugLoc(SI->getDebugLoc()); + if (MSSAU) { + MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB( + NewCall, nullptr, NewCall->getParent(), MemorySSA::BeforeTerminator); + MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true); + } + LLVM_DEBUG(dbgs() << " Formed memcpy: " << *NewCall << "\n" << " from load ptr=" << *LoadEv << " at: " << *LI << "\n" << " from store ptr=" << *StoreEv << " at: " << *SI @@ -1124,8 +1170,13 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // Okay, the memcpy has been formed. Zap the original store and anything that // feeds into it. + if (MSSAU) + MSSAU->removeMemoryAccess(SI, true); deleteDeadInstruction(SI); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); ++NumMemCpy; + EVC.commit(); return true; } @@ -1502,18 +1553,20 @@ bool LoopIdiomRecognize::recognizeAndInsertFFS() { // %inc = add nsw %i.0, 1 // br i1 %tobool - const Value *Args[] = - {InitX, ZeroCheck ? ConstantInt::getTrue(InitX->getContext()) - : ConstantInt::getFalse(InitX->getContext())}; + const Value *Args[] = { + InitX, ZeroCheck ? ConstantInt::getTrue(InitX->getContext()) + : ConstantInt::getFalse(InitX->getContext())}; // @llvm.dbg doesn't count as they have no semantic effect. auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug(); uint32_t HeaderSize = std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end()); + IntrinsicCostAttributes Attrs(IntrinID, InitX->getType(), Args); + int Cost = + TTI->getIntrinsicInstrCost(Attrs, TargetTransformInfo::TCK_SizeAndLatency); if (HeaderSize != IdiomCanonicalSize && - TTI->getIntrinsicCost(IntrinID, InitX->getType(), Args) > - TargetTransformInfo::TCC_Basic) + Cost > TargetTransformInfo::TCC_Basic) return false; transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX, diff --git a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index 901204181a7c..3153a8721193 100644 --- a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -68,7 +68,7 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, // While simplifying we may discover dead code or cause code to become dead. // Keep track of all such instructions and we will delete them at the end. - SmallVector<Instruction *, 8> DeadInsts; + SmallVector<WeakTrackingVH, 8> DeadInsts; // First we want to create an RPO traversal of the loop body. 
By processing in // RPO we can ensure that definitions are processed prior to uses (for non PHI diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 6ce2d06058cf..7787c0bccd4c 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -412,7 +412,6 @@ public: private: bool adjustLoopLinks(); - void adjustLoopPreheaders(); bool adjustLoopBranches(); Loop *OuterLoop; @@ -580,6 +579,12 @@ struct LoopInterchange : public LoopPass { LIT.transform(); LLVM_DEBUG(dbgs() << "Loops interchanged.\n"); LoopsInterchanged++; + + assert(InnerLoop->isLCSSAForm(*DT) && + "Inner loop not left in LCSSA form after loop interchange!"); + assert(OuterLoop->isLCSSAForm(*DT) && + "Outer loop not left in LCSSA form after loop interchange!"); + return true; } }; @@ -689,7 +694,7 @@ bool LoopInterchangeLegality::findInductionAndReductions( // PHIs in inner loops need to be part of a reduction in the outer loop, // discovered when checking the PHIs of the outer loop earlier. if (!InnerLoop) { - if (OuterInnerReductions.find(&PHI) == OuterInnerReductions.end()) { + if (!OuterInnerReductions.count(&PHI)) { LLVM_DEBUG(dbgs() << "Inner loop PHI is not part of reductions " "across the outer loop.\n"); return false; @@ -903,8 +908,8 @@ areInnerLoopExitPHIsSupported(Loop *InnerL, Loop *OuterL, return false; if (any_of(PHI.users(), [&Reductions, OuterL](User *U) { PHINode *PN = dyn_cast<PHINode>(U); - return !PN || (Reductions.find(PN) == Reductions.end() && - OuterL->contains(PN->getParent())); + return !PN || + (!Reductions.count(PN) && OuterL->contains(PN->getParent())); })) { return false; } @@ -1319,6 +1324,23 @@ static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { FromBB->getTerminator()->getIterator()); } +/// Swap instructions between \p BB1 and \p BB2 but keep terminators intact. +static void swapBBContents(BasicBlock *BB1, BasicBlock *BB2) { + // Save all non-terminator instructions of BB1 into TempInstrs and unlink them + // from BB1 afterwards. + auto Iter = map_range(*BB1, [](Instruction &I) { return &I; }); + SmallVector<Instruction *, 4> TempInstrs(Iter.begin(), std::prev(Iter.end())); + for (Instruction *I : TempInstrs) + I->removeFromParent(); + + // Move instructions from BB2 to BB1. + moveBBContents(BB2, BB1->getTerminator()); + + // Move instructions from TempInstrs to BB2. + for (Instruction *I : TempInstrs) + I->insertBefore(BB2->getTerminator()); +} + // Update BI to jump to NewBB instead of OldBB. Records updates to the // dominator tree in DTUpdates. If \p MustUpdateOnce is true, assert that // \p OldBB is exactly once in BI's successor list. @@ -1560,13 +1582,11 @@ bool LoopInterchangeTransform::adjustLoopBranches() { // outer loop and all the remains to do is and updating the incoming blocks. for (PHINode *PHI : OuterLoopPHIs) { PHI->moveBefore(InnerLoopHeader->getFirstNonPHI()); - assert(OuterInnerReductions.find(PHI) != OuterInnerReductions.end() && - "Expected a reduction PHI node"); + assert(OuterInnerReductions.count(PHI) && "Expected a reduction PHI node"); } for (PHINode *PHI : InnerLoopPHIs) { PHI->moveBefore(OuterLoopHeader->getFirstNonPHI()); - assert(OuterInnerReductions.find(PHI) != OuterInnerReductions.end() && - "Expected a reduction PHI node"); + assert(OuterInnerReductions.count(PHI) && "Expected a reduction PHI node"); } // Update the incoming blocks for moved PHI nodes. 
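// The swapBBContents helper introduced above exchanges every instruction of
// the two blocks except their terminators, so the CFG edges stay untouched.
// A small, self-contained sketch of that behaviour, assuming the helper is
// visible in this translation unit (the IR built here is illustrative only):

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include <memory>

void demoSwapBBContents(llvm::LLVMContext &Ctx) {
  auto M = std::make_unique<llvm::Module>("swap-demo", Ctx);
  auto *FnTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
  auto *F = llvm::Function::Create(FnTy, llvm::Function::ExternalLinkage,
                                   "demo", M.get());
  auto *BB1 = llvm::BasicBlock::Create(Ctx, "bb1", F);
  auto *BB2 = llvm::BasicBlock::Create(Ctx, "bb2", F);

  llvm::IRBuilder<> B(BB1);
  B.CreateAlloca(B.getInt32Ty(), nullptr, "a"); // non-terminator in bb1
  B.CreateBr(BB2);                              // bb1 terminator
  B.SetInsertPoint(BB2);
  B.CreateAlloca(B.getInt64Ty(), nullptr, "b"); // non-terminator in bb2
  B.CreateRetVoid();                            // bb2 terminator

  swapBBContents(BB1, BB2);
  // Now %b sits in bb1 and %a in bb2, while the branch and the ret stay put.
}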
@@ -1578,30 +1598,17 @@ bool LoopInterchangeTransform::adjustLoopBranches() { return true; } -void LoopInterchangeTransform::adjustLoopPreheaders() { - // We have interchanged the preheaders so we need to interchange the data in - // the preheader as well. - // This is because the content of inner preheader was previously executed - // inside the outer loop. - BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); - BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); - BasicBlock *OuterLoopHeader = OuterLoop->getHeader(); - BranchInst *InnerTermBI = - cast<BranchInst>(InnerLoopPreHeader->getTerminator()); - - // These instructions should now be executed inside the loop. - // Move instruction into a new block after outer header. - moveBBContents(InnerLoopPreHeader, OuterLoopHeader->getTerminator()); - // These instructions were not executed previously in the loop so move them to - // the older inner loop preheader. - moveBBContents(OuterLoopPreHeader, InnerTermBI); -} - bool LoopInterchangeTransform::adjustLoopLinks() { // Adjust all branches in the inner and outer loop. bool Changed = adjustLoopBranches(); - if (Changed) - adjustLoopPreheaders(); + if (Changed) { + // We have interchanged the preheaders so we need to interchange the data in + // the preheaders as well. This is because the content of the inner + // preheader was previously executed inside the outer loop. + BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); + BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); + swapBBContents(OuterLoopPreHeader, InnerLoopPreHeader); + } return Changed; } diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 4e1b4e87ebc9..4412b3079461 100644 --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -38,7 +38,6 @@ #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -58,6 +57,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/SizeOpts.h" #include <algorithm> #include <cassert> @@ -377,7 +377,7 @@ public: /// Determine the pointer alias checks to prove that there are no /// intervening stores. 
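// For orientation, the rewrite behind the "load_initial" / "store_forwarded"
// values created further down in this file is plain store-to-load forwarding
// across one iteration. A hedged C++ before/after sketch (the array names,
// bounds and the distance-one dependence are assumptions made up for the
// illustration, not taken from the patch):

void beforeForwarding(int *A, int *B, const int *C, int N) {
  for (int I = 0; I < N; ++I) {
    B[I] = A[I] + 1;  // re-loads the value the previous iteration stored
    A[I + 1] = C[I];
  }
}

void afterForwarding(int *A, int *B, const int *C, int N) {
  int T = A[0];       // corresponds to "load_initial" in the preheader
  for (int I = 0; I < N; ++I) {
    B[I] = T + 1;     // the in-loop load is gone
    T = C[I];         // corresponds to the "store_forwarded" PHI update
    A[I + 1] = T;
  }
}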
- SmallVector<RuntimePointerChecking::PointerCheck, 4> collectMemchecks( + SmallVector<RuntimePointerCheck, 4> collectMemchecks( const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) { SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath = @@ -391,10 +391,10 @@ public: std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr)); const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks(); - SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; + SmallVector<RuntimePointerCheck, 4> Checks; copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const RuntimePointerCheck &Check) { for (auto PtrIdx1 : Check.first->Members) for (auto PtrIdx2 : Check.second->Members) if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, @@ -432,12 +432,12 @@ public: Value *Ptr = Cand.Load->getPointerOperand(); auto *PtrSCEV = cast<SCEVAddRecExpr>(PSE.getSCEV(Ptr)); auto *PH = L->getLoopPreheader(); + assert(PH && "Preheader should exist!"); Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(), PH->getTerminator()); Value *Initial = new LoadInst( Cand.Load->getType(), InitialPtr, "load_initial", - /* isVolatile */ false, MaybeAlign(Cand.Load->getAlignment()), - PH->getTerminator()); + /* isVolatile */ false, Cand.Load->getAlign(), PH->getTerminator()); PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded", &L->getHeader()->front()); @@ -520,8 +520,7 @@ public: // Check intervening may-alias stores. These need runtime checks for alias // disambiguation. - SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks = - collectMemchecks(Candidates); + SmallVector<RuntimePointerCheck, 4> Checks = collectMemchecks(Candidates); // Too many checks are likely to outweigh the benefits of forwarding. if (Checks.size() > Candidates.size() * CheckPerElim) { @@ -535,6 +534,11 @@ public: return false; } + if (!L->isLoopSimplifyForm()) { + LLVM_DEBUG(dbgs() << "Loop is not is loop-simplify form"); + return false; + } + if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) { if (LAI.hasConvergentOp()) { LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with " @@ -554,11 +558,6 @@ public: return false; } - if (!L->isLoopSimplifyForm()) { - LLVM_DEBUG(dbgs() << "Loop is not is loop-simplify form"); - return false; - } - // Point of no-return, start the transformation. First, version the loop // if necessary. @@ -697,8 +696,8 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F, auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); - auto &MAM = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); - auto *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); + auto *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); auto *BFI = (PSI && PSI->hasProfileSummary()) ? 
&AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; MemorySSA *MSSA = EnableMSSALoopDependency diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp index f3bfbd3564ab..98889a9df116 100644 --- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Support/TimeProfiler.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Analysis/LoopInfo.h" @@ -33,15 +34,19 @@ PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &, // instrumenting callbacks for the passes later. PassInstrumentation PI = AM.getResult<PassInstrumentationAnalysis>(L, AR); for (auto &Pass : Passes) { - if (DebugLogging) - dbgs() << "Running pass: " << Pass->name() << " on " << L; - // Check the PassInstrumentation's BeforePass callbacks before running the // pass, skip its execution completely if asked to (callback returns false). if (!PI.runBeforePass<Loop>(*Pass, L)) continue; - PreservedAnalyses PassPA = Pass->run(L, AM, AR, U); + if (DebugLogging) + dbgs() << "Running pass: " << Pass->name() << " on " << L; + + PreservedAnalyses PassPA; + { + TimeTraceScope TimeScope(Pass->name(), L.getName()); + PassPA = Pass->run(L, AM, AR, U); + } // do not pass deleted Loop into the instrumentation if (U.skipCurrentLoop()) diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp index 1a42f6b23443..edde22d6708f 100644 --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -184,7 +184,6 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" @@ -199,6 +198,7 @@ #include "llvm/Transforms/Utils/GuardUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #define DEBUG_TYPE "loop-predication" @@ -268,7 +268,7 @@ class LoopPredication { /// Return an insertion point suitable for inserting a safe to speculate /// instruction whose only user will be 'User' which has operands 'Ops'. A /// trivial result would be the at the User itself, but we try to return a - /// loop invariant location if possible. + /// loop invariant location if possible. Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops); /// Same as above, *except* that this uses the SCEV definition of invariant /// which is that an expression *can be made* invariant via SCEVExpander. @@ -278,7 +278,7 @@ class LoopPredication { /// Return true if the value is known to produce a single fixed value across /// all iterations on which it executes. Note that this does not imply - /// speculation safety. That must be established seperately. + /// speculation safety. That must be established separately. 
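// The loop pass manager hunk above wraps each pass invocation in a
// TimeTraceScope so it shows up in -ftime-trace output. The pattern in
// isolation (a sketch; the function and argument names are illustrative):

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/TimeProfiler.h"

void runWithTimeTrace(llvm::StringRef PassName, llvm::StringRef LoopName) {
  // Records an event covering this scope while the time-trace profiler is
  // active; when tracing is disabled the scope does nothing.
  llvm::TimeTraceScope Scope(PassName, LoopName);
  // ... run the pass here ...
}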
bool isLoopInvariantValue(const SCEV* S); Value *expandCheck(SCEVExpander &Expander, Instruction *Guard, @@ -342,7 +342,7 @@ public: }; char LoopPredicationLegacyPass::ID = 0; -} // end namespace llvm +} // end namespace INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) @@ -358,11 +358,12 @@ Pass *llvm::createLoopPredicationPass() { PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { - const auto &FAM = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); Function *F = L.getHeader()->getParent(); - auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); - LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, BPI); + // For the new PM, we also can't use BranchProbabilityInfo as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but BPI is not preserved, hence a newly built one is needed. + BranchProbabilityInfo BPI(*F, AR.LI, &AR.TLI); + LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, &BPI); if (!LP.runOnLoop(&L)) return PreservedAnalyses::all(); @@ -397,7 +398,7 @@ LoopPredication::parseLoopICmp(ICmpInst *ICI) { } Value *LoopPredication::expandCheck(SCEVExpander &Expander, - Instruction *Guard, + Instruction *Guard, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { Type *Ty = LHS->getType(); @@ -521,7 +522,7 @@ Instruction *LoopPredication::findInsertPt(Instruction *Use, return Preheader->getTerminator(); } -bool LoopPredication::isLoopInvariantValue(const SCEV* S) { +bool LoopPredication::isLoopInvariantValue(const SCEV* S) { // Handling expressions which produce invariant results, but *haven't* yet // been removed from the loop serves two important purposes. // 1) Most importantly, it resolves a pass ordering cycle which would @@ -534,12 +535,12 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) { // much more obviously in the IR. Otherwise, the cost modeling for other // transforms would end up needing to duplicate all of this logic to model a // check which becomes predictable based on a modeled peel or unswitch. - // + // // The cost of doing so in the worst case is an extra fill from the stack in // the loop to materialize the loop invariant test value instead of checking // against the original IV which is presumable in a register inside the loop. // Such cases are presumably rare, and hint at missing oppurtunities for - // other passes. + // other passes. if (SE->isLoopInvariant(S, L)) // Note: This the SCEV variant, so the original Value* may be within the @@ -547,7 +548,7 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) { return true; // Handle a particular important case which SCEV doesn't yet know about which - // shows up in range checks on arrays with immutable lengths. + // shows up in range checks on arrays with immutable lengths. // TODO: This should be sunk inside SCEV. if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) if (const auto *LI = dyn_cast<LoadInst>(U->getValue())) @@ -574,7 +575,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( const SCEV *LatchLimit = LatchCheck.Limit; // Subtlety: We need all the values to be *invariant* across all iterations, // but we only need to check expansion safety for those which *aren't* - // already guaranteed to dominate the guard. + // already guaranteed to dominate the guard. 
if (!isLoopInvariantValue(GuardStart) || !isLoopInvariantValue(GuardLimit) || !isLoopInvariantValue(LatchStart) || @@ -598,7 +599,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n"); LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n"); LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); - + auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS); auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred, @@ -617,7 +618,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( const SCEV *LatchLimit = LatchCheck.Limit; // Subtlety: We need all the values to be *invariant* across all iterations, // but we only need to check expansion safety for those which *aren't* - // already guaranteed to dominate the guard. + // already guaranteed to dominate the guard. if (!isLoopInvariantValue(GuardStart) || !isLoopInvariantValue(GuardLimit) || !isLoopInvariantValue(LatchStart) || @@ -658,7 +659,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( static void normalizePredicate(ScalarEvolution *SE, Loop *L, LoopICmp& RC) { // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the - // ULT/UGE form for ease of handling by our caller. + // ULT/UGE form for ease of handling by our caller. if (ICmpInst::isEquality(RC.Pred) && RC.IV->getStepRecurrence(*SE)->isOne() && SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit)) @@ -1020,17 +1021,6 @@ static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE, return SE.getUMinFromMismatchedTypes(ExitCounts); } -/// Return true if we can be fairly sure that executing block BB will probably -/// lead to executing an __llvm_deoptimize. This is a profitability heuristic, -/// not a legality constraint. -static bool isVeryLikelyToDeopt(BasicBlock *BB) { - while (BB->getUniqueSuccessor()) - // Will skip side effects, that's okay - BB = BB->getUniqueSuccessor(); - - return BB->getTerminatingDeoptimizeCall(); -} - /// This implements an analogous, but entirely distinct transform from the main /// loop predication transform. This one is phrased in terms of using a /// widenable branch *outside* the loop to allow us to simplify loop exits in a @@ -1054,7 +1044,7 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { // inserting a branch on the value which can be either poison or undef. In // this case, the branch can legally go either way; we just need to avoid // introducing UB. This is achieved through the use of the freeze - // instruction. + // instruction. SmallVector<BasicBlock *, 16> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); @@ -1082,7 +1072,7 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { // analyzeable after dropping widenability. { bool Invalidate = false; - + for (auto *ExitingBB : ExitingBlocks) { if (LI->getLoopFor(ExitingBB) != L) continue; @@ -1150,10 +1140,13 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1); - if (!isVeryLikelyToDeopt(ExitBB)) - // Profitability: indicator of rarely/never taken exit + if (!ExitBB->getPostdominatingDeoptimizeCall()) continue; + /// Here we can be fairly sure that executing this exit will most likely + /// lead to executing llvm.experimental.deoptimize. 
+ /// This is a profitability heuristic, not a legality constraint. + // If we found a widenable exit condition, do two things: // 1) fold the widened exit test into the widenable condition // 2) fold the branch to untaken - avoids infinite looping diff --git a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp index da13a342ae12..3542d0a4ee73 100644 --- a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -24,7 +24,6 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -55,6 +54,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include <cassert> #include <cstddef> #include <cstdint> @@ -880,6 +880,12 @@ bool LoopReroll::DAGRootTracker::validateRootSet(DAGRootSet &DRS) { if (DRS.Roots.empty()) return false; + // If the value of the base instruction is used outside the loop, we cannot + // reroll the loop. Check for other root instructions is unnecessary because + // they don't match any base instructions if their values are used outside. + if (hasUsesOutsideLoop(DRS.BaseInst, L)) + return false; + // Consider a DAGRootSet with N-1 roots (so N different values including // BaseInst). // Define d = Roots[0] - BaseInst, which should be the same as @@ -1126,7 +1132,7 @@ static bool isIgnorableInst(const Instruction *I) { case Intrinsic::annotation: case Intrinsic::ptr_annotation: case Intrinsic::var_annotation: - // TODO: the following intrinsics may also be whitelisted: + // TODO: the following intrinsics may also be allowed: // lifetime_start, lifetime_end, invariant_start, invariant_end return true; } diff --git a/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/llvm/lib/Transforms/Scalar/LoopRotation.cpp index 0868e742f4ee..f92566ba77ce 100644 --- a/llvm/lib/Transforms/Scalar/LoopRotation.cpp +++ b/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -81,10 +81,8 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); - if (EnableMSSALoopDependency) { - AU.addRequired<MemorySSAWrapperPass>(); + if (EnableMSSALoopDependency) AU.addPreserved<MemorySSAWrapperPass>(); - } getLoopAnalysisUsage(AU); } @@ -101,15 +99,18 @@ public: const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); Optional<MemorySSAUpdater> MSSAU; if (EnableMSSALoopDependency) { - MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MSSAU = MemorySSAUpdater(MSSA); + // Not requiring MemorySSA and getting it only if available will split + // the loop pass pipeline when LoopRotate is being run first. + auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + if (MSSAA) + MSSAU = MemorySSAUpdater(&MSSAA->getMSSA()); } return LoopRotation(L, LI, TTI, AC, &DT, &SE, MSSAU.hasValue() ? 
MSSAU.getPointer() : nullptr, SQ, false, MaxHeaderSize, false); } }; -} +} // end namespace char LoopRotateLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", diff --git a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index b27e65e0adb7..031e5b9c1d2c 100644 --- a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" @@ -30,6 +31,7 @@ #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar.h" @@ -673,13 +675,13 @@ static bool mergeBlocksIntoPredecessors(Loop &L, DominatorTree &DT, static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI, ScalarEvolution &SE, MemorySSAUpdater *MSSAU, - bool &isLoopDeleted) { + bool &IsLoopDeleted) { bool Changed = false; // Constant-fold terminators with known constant conditions. - Changed |= constantFoldTerminators(L, DT, LI, SE, MSSAU, isLoopDeleted); + Changed |= constantFoldTerminators(L, DT, LI, SE, MSSAU, IsLoopDeleted); - if (isLoopDeleted) + if (IsLoopDeleted) return true; // Eliminate unconditional branches by merging blocks into their predecessors. @@ -752,7 +754,7 @@ public: getLoopAnalysisUsage(AU); } }; -} +} // end namespace char LoopSimplifyCFGLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index e9f368628a08..cf02ef1e83f3 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -65,12 +65,14 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/IVUsers.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ScalarEvolutionNormalization.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -109,6 +111,7 @@ #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include <algorithm> #include <cassert> #include <cstddef> @@ -807,9 +810,14 @@ static bool isAddressUse(const TargetTransformInfo &TTI, switch (II->getIntrinsicID()) { case Intrinsic::memset: case Intrinsic::prefetch: + case Intrinsic::masked_load: if (II->getArgOperand(0) == OperandVal) isAddress = true; break; + case Intrinsic::masked_store: + if (II->getArgOperand(1) == OperandVal) + isAddress = true; + break; case Intrinsic::memmove: case Intrinsic::memcpy: if (II->getArgOperand(0) == OperandVal || @@ -859,6 +867,15 @@ static MemAccessTy 
getAccessType(const TargetTransformInfo &TTI, AccessTy.AddrSpace = OperandVal->getType()->getPointerAddressSpace(); AccessTy.MemTy = OperandVal->getType(); break; + case Intrinsic::masked_load: + AccessTy.AddrSpace = + II->getArgOperand(0)->getType()->getPointerAddressSpace(); + break; + case Intrinsic::masked_store: + AccessTy.MemTy = II->getOperand(0)->getType(); + AccessTy.AddrSpace = + II->getArgOperand(1)->getType()->getPointerAddressSpace(); + break; default: { MemIntrinsicInfo IntrInfo; if (TTI.getTgtMemIntrinsic(II, IntrInfo) && IntrInfo.PtrVal) { @@ -962,33 +979,6 @@ static bool isHighCostExpansion(const SCEV *S, return true; } -/// If any of the instructions in the specified set are trivially dead, delete -/// them and see if this makes any of their operands subsequently dead. -static bool -DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakTrackingVH> &DeadInsts) { - bool Changed = false; - - while (!DeadInsts.empty()) { - Value *V = DeadInsts.pop_back_val(); - Instruction *I = dyn_cast_or_null<Instruction>(V); - - if (!I || !isInstructionTriviallyDead(I)) - continue; - - for (Use &O : I->operands()) - if (Instruction *U = dyn_cast<Instruction>(O)) { - O = nullptr; - if (U->use_empty()) - DeadInsts.emplace_back(U); - } - - I->eraseFromParent(); - Changed = true; - } - - return Changed; -} - namespace { class LSRUse; @@ -1242,7 +1232,7 @@ void Cost::RateRegister(const Formula &F, const SCEV *Reg, // for now LSR only handles innermost loops). if (AR->getLoop() != L) { // If the AddRec exists, consider it's register free and leave it alone. - if (isExistingPhi(AR, *SE)) + if (isExistingPhi(AR, *SE) && !TTI->shouldFavorPostInc()) return; // It is bad to allow LSR for current loop to add induction variables @@ -1913,9 +1903,10 @@ class LSRInstance { DominatorTree &DT; LoopInfo &LI; AssumptionCache &AC; - TargetLibraryInfo &LibInfo; + TargetLibraryInfo &TLI; const TargetTransformInfo &TTI; Loop *const L; + MemorySSAUpdater *MSSAU; bool FavorBackedgeIndex = false; bool Changed = false; @@ -2018,6 +2009,7 @@ class LSRInstance { void NarrowSearchSpaceByCollapsingUnrolledCode(); void NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(); void NarrowSearchSpaceByFilterFormulaWithSameScaledReg(); + void NarrowSearchSpaceByFilterPostInc(); void NarrowSearchSpaceByDeletingCostlyFormulas(); void NarrowSearchSpaceByPickingWinnerRegs(); void NarrowSearchSpaceUsingHeuristics(); @@ -2053,7 +2045,7 @@ class LSRInstance { public: LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, const TargetTransformInfo &TTI, AssumptionCache &AC, - TargetLibraryInfo &LibInfo); + TargetLibraryInfo &TLI, MemorySSAUpdater *MSSAU); bool getChanged() const { return Changed; } @@ -2830,9 +2822,10 @@ bool IVChain::isProfitableIncrement(const SCEV *OperExpr, /// increments can be computed in fewer registers when chained. /// /// TODO: Consider IVInc free if it's already used in another chains. 
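// The masked_load/masked_store cases added to isAddressUse and getAccessType
// above key off the intrinsics' operand order: llvm.masked.load takes
// (pointer, alignment, mask, passthru) and llvm.masked.store takes
// (value, pointer, alignment, mask). A small helper restating that layout
// (illustrative only, not part of the patch):

#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"

llvm::Value *getMaskedPointerOperand(llvm::IntrinsicInst *II) {
  switch (II->getIntrinsicID()) {
  case llvm::Intrinsic::masked_load:
    return II->getArgOperand(0); // (ptr, align, mask, passthru)
  case llvm::Intrinsic::masked_store:
    return II->getArgOperand(1); // (value, ptr, align, mask)
  default:
    return nullptr;
  }
}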
-static bool -isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users, - ScalarEvolution &SE) { +static bool isProfitableChain(IVChain &Chain, + SmallPtrSetImpl<Instruction *> &Users, + ScalarEvolution &SE, + const TargetTransformInfo &TTI) { if (StressIVChain) return true; @@ -2861,7 +2854,14 @@ isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users, unsigned NumConstIncrements = 0; unsigned NumVarIncrements = 0; unsigned NumReusedIncrements = 0; + + if (TTI.isProfitableLSRChainElement(Chain.Incs[0].UserInst)) + return true; + for (const IVInc &Inc : Chain) { + if (TTI.isProfitableLSRChainElement(Inc.UserInst)) + return true; + if (Inc.IncExpr->isZero()) continue; @@ -3092,7 +3092,7 @@ void LSRInstance::CollectChains() { for (unsigned UsersIdx = 0, NChains = IVChainVec.size(); UsersIdx < NChains; ++UsersIdx) { if (!isProfitableChain(IVChainVec[UsersIdx], - ChainUsersVec[UsersIdx].FarUsers, SE)) + ChainUsersVec[UsersIdx].FarUsers, SE, TTI)) continue; // Preserve the chain at UsesIdx. if (ChainIdx != UsersIdx) @@ -3212,7 +3212,8 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter, IVOper = Builder.CreateTruncOrBitCast(IVOper, OperTy, "lsr.chain"); } Inc.UserInst->replaceUsesOfWith(Inc.IVOperand, IVOper); - DeadInsts.emplace_back(Inc.IVOperand); + if (auto *OperandIsInstr = dyn_cast<Instruction>(Inc.IVOperand)) + DeadInsts.emplace_back(OperandIsInstr); } // If LSR created a new, wider phi, we may also replace its postinc. We only // do this if we also found a wide value for the head of the chain. @@ -3240,7 +3241,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter, void LSRInstance::CollectFixupsAndInitialFormulae() { BranchInst *ExitBranch = nullptr; - bool SaveCmp = TTI.canSaveCmp(L, &ExitBranch, &SE, &LI, &DT, &AC, &LibInfo); + bool SaveCmp = TTI.canSaveCmp(L, &ExitBranch, &SE, &LI, &DT, &AC, &TLI); for (const IVStrideUse &U : IU) { Instruction *UserInst = U.getUser(); @@ -3553,9 +3554,6 @@ static bool mayUsePostIncMode(const TargetTransformInfo &TTI, const SCEV *LoopStep = AR->getStepRecurrence(SE); if (!isa<SCEVConstant>(LoopStep)) return false; - if (LU.AccessTy.getType()->getScalarSizeInBits() != - LoopStep->getType()->getScalarSizeInBits()) - return false; // Check if a post-indexed load/store can be used. if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, AR->getType()) || TTI.isIndexedStoreLegal(TTI.MIM_PostInc, AR->getType())) { @@ -4673,6 +4671,54 @@ void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() { }); } +/// If we are over the complexity limit, filter out any post-inc prefering +/// variables to only post-inc values. 
+void LSRInstance::NarrowSearchSpaceByFilterPostInc() { + if (!TTI.shouldFavorPostInc()) + return; + if (EstimateSearchSpaceComplexity() < ComplexityLimit) + return; + + LLVM_DEBUG(dbgs() << "The search space is too complex.\n" + "Narrowing the search space by choosing the lowest " + "register Formula for PostInc Uses.\n"); + + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + + if (LU.Kind != LSRUse::Address) + continue; + if (!TTI.isIndexedLoadLegal(TTI.MIM_PostInc, LU.AccessTy.getType()) && + !TTI.isIndexedStoreLegal(TTI.MIM_PostInc, LU.AccessTy.getType())) + continue; + + size_t MinRegs = std::numeric_limits<size_t>::max(); + for (const Formula &F : LU.Formulae) + MinRegs = std::min(F.getNumRegs(), MinRegs); + + bool Any = false; + for (size_t FIdx = 0, NumForms = LU.Formulae.size(); FIdx != NumForms; + ++FIdx) { + Formula &F = LU.Formulae[FIdx]; + if (F.getNumRegs() > MinRegs) { + LLVM_DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs()); + dbgs() << "\n"); + LU.DeleteFormula(F); + --FIdx; + --NumForms; + Any = true; + } + } + if (Any) + LU.RecomputeRegs(LUIdx, RegUses); + + if (EstimateSearchSpaceComplexity() < ComplexityLimit) + break; + } + + LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); +} + /// The function delete formulas with high registers number expectation. /// Assuming we don't know the value of each formula (already delete /// all inefficient), generate probability of not selecting for each @@ -4883,6 +4929,7 @@ void LSRInstance::NarrowSearchSpaceUsingHeuristics() { NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(); if (FilterSameScaledReg) NarrowSearchSpaceByFilterFormulaWithSameScaledReg(); + NarrowSearchSpaceByFilterPostInc(); if (LSRExpNarrow) NarrowSearchSpaceByDeletingCostlyFormulas(); else @@ -4923,19 +4970,24 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, // Ignore formulae which may not be ideal in terms of register reuse of // ReqRegs. The formula should use all required registers before // introducing new ones. - int NumReqRegsToFind = std::min(F.getNumRegs(), ReqRegs.size()); - for (const SCEV *Reg : ReqRegs) { - if ((F.ScaledReg && F.ScaledReg == Reg) || - is_contained(F.BaseRegs, Reg)) { - --NumReqRegsToFind; - if (NumReqRegsToFind == 0) - break; + // This can sometimes (notably when trying to favour postinc) lead to + // sub-optimial decisions. There it is best left to the cost modelling to + // get correct. + if (!TTI.shouldFavorPostInc() || LU.Kind != LSRUse::Address) { + int NumReqRegsToFind = std::min(F.getNumRegs(), ReqRegs.size()); + for (const SCEV *Reg : ReqRegs) { + if ((F.ScaledReg && F.ScaledReg == Reg) || + is_contained(F.BaseRegs, Reg)) { + --NumReqRegsToFind; + if (NumReqRegsToFind == 0) + break; + } + } + if (NumReqRegsToFind != 0) { + // If none of the formulae satisfied the required registers, then we could + // clear ReqRegs and try again. Currently, we simply give up in this case. + continue; } - } - if (NumReqRegsToFind != 0) { - // If none of the formulae satisfied the required registers, then we could - // clear ReqRegs and try again. Currently, we simply give up in this case. - continue; } // Evaluate the cost of the current formula. If it's already worse than @@ -5268,7 +5320,8 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF, // form, update the ICmp's other operand. 
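// The per-use filtering in NarrowSearchSpaceByFilterPostInc above is a
// two-pass "find the minimum register count, then drop every formula above
// it" scan. The same pattern in isolation on plain data (illustrative only):

#include <algorithm>
#include <vector>

void keepOnlyMinimum(std::vector<unsigned> &RegCounts) {
  if (RegCounts.empty())
    return;
  unsigned Min = *std::min_element(RegCounts.begin(), RegCounts.end());
  RegCounts.erase(std::remove_if(RegCounts.begin(), RegCounts.end(),
                                 [Min](unsigned N) { return N > Min; }),
                  RegCounts.end());
}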
if (LU.Kind == LSRUse::ICmpZero) { ICmpInst *CI = cast<ICmpInst>(LF.UserInst); - DeadInsts.emplace_back(CI->getOperand(1)); + if (auto *OperandIsInstr = dyn_cast<Instruction>(CI->getOperand(1))) + DeadInsts.emplace_back(OperandIsInstr); assert(!F.BaseGV && "ICmp does not support folding a global value and " "a scale at the same time!"); if (F.Scale == -1) { @@ -5449,7 +5502,8 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF, LF.UserInst->replaceUsesOfWith(LF.OperandValToReplace, FullV); } - DeadInsts.emplace_back(LF.OperandValToReplace); + if (auto *OperandIsInstr = dyn_cast<Instruction>(LF.OperandValToReplace)) + DeadInsts.emplace_back(OperandIsInstr); } /// Rewrite all the fixup locations with new values, following the chosen @@ -5490,16 +5544,17 @@ void LSRInstance::ImplementSolution( // instructions. Rewriter.clear(); - Changed |= DeleteTriviallyDeadInstructions(DeadInsts); + Changed |= RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts, + &TLI, MSSAU); } LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, const TargetTransformInfo &TTI, AssumptionCache &AC, - TargetLibraryInfo &LibInfo) - : IU(IU), SE(SE), DT(DT), LI(LI), AC(AC), LibInfo(LibInfo), TTI(TTI), L(L), - FavorBackedgeIndex(EnableBackedgeIndexing && - TTI.shouldFavorBackedgeIndex(L)) { + TargetLibraryInfo &TLI, MemorySSAUpdater *MSSAU) + : IU(IU), SE(SE), DT(DT), LI(LI), AC(AC), TLI(TLI), TTI(TTI), L(L), + MSSAU(MSSAU), FavorBackedgeIndex(EnableBackedgeIndexing && + TTI.shouldFavorBackedgeIndex(L)) { // If LoopSimplify form is not available, stay out of trouble. if (!L->isLoopSimplifyForm()) return; @@ -5702,21 +5757,26 @@ void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<IVUsersWrapperPass>(); AU.addPreserved<IVUsersWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); } static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, const TargetTransformInfo &TTI, - AssumptionCache &AC, - TargetLibraryInfo &LibInfo) { + AssumptionCache &AC, TargetLibraryInfo &TLI, + MemorySSA *MSSA) { bool Changed = false; + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (MSSA) + MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); // Run the main LSR transformation. - Changed |= LSRInstance(L, IU, SE, DT, LI, TTI, AC, LibInfo).getChanged(); + Changed |= + LSRInstance(L, IU, SE, DT, LI, TTI, AC, TLI, MSSAU.get()).getChanged(); // Remove any extra phis created by processing inner loops. 
- Changed |= DeleteDeadPHIs(L->getHeader()); + Changed |= DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); if (EnablePhiElim && L->isLoopSimplifyForm()) { SmallVector<WeakTrackingVH, 16> DeadInsts; const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); @@ -5727,8 +5787,9 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, unsigned numFolded = Rewriter.replaceCongruentIVs(L, &DT, DeadInsts, &TTI); if (numFolded) { Changed = true; - DeleteTriviallyDeadInstructions(DeadInsts); - DeleteDeadPHIs(L->getHeader()); + RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts, &TLI, + MSSAU.get()); + DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); } } return Changed; @@ -5746,19 +5807,26 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { *L->getHeader()->getParent()); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache( *L->getHeader()->getParent()); - auto &LibInfo = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( *L->getHeader()->getParent()); - return ReduceLoopStrength(L, IU, SE, DT, LI, TTI, AC, LibInfo); + auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + MemorySSA *MSSA = nullptr; + if (MSSAAnalysis) + MSSA = &MSSAAnalysis->getMSSA(); + return ReduceLoopStrength(L, IU, SE, DT, LI, TTI, AC, TLI, MSSA); } PreservedAnalyses LoopStrengthReducePass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { if (!ReduceLoopStrength(&L, AM.getResult<IVUsersAnalysis>(L, AR), AR.SE, - AR.DT, AR.LI, AR.TTI, AR.AC, AR.TLI)) + AR.DT, AR.LI, AR.TTI, AR.AC, AR.TLI, AR.MSSA)) return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + if (AR.MSSA) + PA.preserve<MemorySSAAnalysis>(); + return PA; } char LoopStrengthReduce::ID = 0; diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 92ad8dafa5ab..285cba6ee205 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -11,8 +11,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" -#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AssumptionCache.h" @@ -20,37 +22,36 @@ #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/PassRegistry.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" -#include 
"llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" -#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/UnrollLoop.h" -#include <algorithm> #include <cassert> #include <cstdint> -#include <string> +#include <vector> + +namespace llvm { +class Instruction; +class Value; +} // namespace llvm using namespace llvm; @@ -91,7 +92,7 @@ static cl::opt<unsigned> PragmaUnrollAndJamThreshold( // Returns the loop hint metadata node with the given name (for example, // "llvm.loop.unroll.count"). If no such metadata node exists, then nullptr is // returned. -static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) { +static MDNode *getUnrollMetadataForLoop(const Loop *L, StringRef Name) { if (MDNode *LoopID = L->getLoopID()) return GetUnrollMetadata(LoopID, Name); return nullptr; @@ -99,14 +100,14 @@ static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) { // Returns true if the loop has any metadata starting with Prefix. For example a // Prefix of "llvm.loop.unroll." returns true if we have any unroll metadata. -static bool HasAnyUnrollPragma(const Loop *L, StringRef Prefix) { +static bool hasAnyUnrollPragma(const Loop *L, StringRef Prefix) { if (MDNode *LoopID = L->getLoopID()) { // First operand should refer to the loop id itself. assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); - for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { - MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + for (unsigned I = 1, E = LoopID->getNumOperands(); I < E; ++I) { + MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(I)); if (!MD) continue; @@ -122,14 +123,14 @@ static bool HasAnyUnrollPragma(const Loop *L, StringRef Prefix) { } // Returns true if the loop has an unroll_and_jam(enable) pragma. -static bool HasUnrollAndJamEnablePragma(const Loop *L) { - return GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable"); +static bool hasUnrollAndJamEnablePragma(const Loop *L) { + return getUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable"); } // If loop has an unroll_and_jam_count pragma return the (necessarily // positive) value from the pragma. Otherwise return 0. -static unsigned UnrollAndJamCountPragmaValue(const Loop *L) { - MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.count"); +static unsigned unrollAndJamCountPragmaValue(const Loop *L) { + MDNode *MD = getUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.count"); if (MD) { assert(MD->getNumOperands() == 2 && "Unroll count hint metadata should have two operands."); @@ -157,7 +158,8 @@ static bool computeUnrollAndJamCount( const SmallPtrSetImpl<const Value *> &EphValues, OptimizationRemarkEmitter *ORE, unsigned OuterTripCount, unsigned OuterTripMultiple, unsigned OuterLoopSize, unsigned InnerTripCount, - unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP) { + unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP, + TargetTransformInfo::PeelingPreferences &PP) { // First up use computeUnrollCount from the loop unroller to get a count // for unrolling the outer loop, plus any loops requiring explicit // unrolling we leave to the unroller. 
This uses UP.Threshold / @@ -167,7 +169,8 @@ static bool computeUnrollAndJamCount( bool UseUpperBound = false; bool ExplicitUnroll = computeUnrollCount( L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount, - /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, UseUpperBound); + /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, PP, + UseUpperBound); if (ExplicitUnroll || UseUpperBound) { // If the user explicitly set the loop as unrolled, dont UnJ it. Leave it // for the unroller instead. @@ -190,7 +193,7 @@ static bool computeUnrollAndJamCount( } // Check for unroll_and_jam pragmas - unsigned PragmaCount = UnrollAndJamCountPragmaValue(L); + unsigned PragmaCount = unrollAndJamCountPragmaValue(L); if (PragmaCount > 0) { UP.Count = PragmaCount; UP.Runtime = true; @@ -202,7 +205,7 @@ static bool computeUnrollAndJamCount( return true; } - bool PragmaEnableUnroll = HasUnrollAndJamEnablePragma(L); + bool PragmaEnableUnroll = hasUnrollAndJamEnablePragma(L); bool ExplicitUnrollAndJamCount = PragmaCount > 0 || UserUnrollCount; bool ExplicitUnrollAndJam = PragmaEnableUnroll || ExplicitUnrollAndJamCount; @@ -279,24 +282,11 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, const TargetTransformInfo &TTI, AssumptionCache &AC, DependenceInfo &DI, OptimizationRemarkEmitter &ORE, int OptLevel) { - // Quick checks of the correct loop form - if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1) - return LoopUnrollResult::Unmodified; - Loop *SubLoop = L->getSubLoops()[0]; - if (!SubLoop->isLoopSimplifyForm()) - return LoopUnrollResult::Unmodified; - - BasicBlock *Latch = L->getLoopLatch(); - BasicBlock *Exit = L->getExitingBlock(); - BasicBlock *SubLoopLatch = SubLoop->getLoopLatch(); - BasicBlock *SubLoopExit = SubLoop->getExitingBlock(); - - if (Latch != Exit || SubLoopLatch != SubLoopExit) - return LoopUnrollResult::Unmodified; - TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(L, SE, TTI, nullptr, nullptr, OptLevel, None, - None, None, None, None, None, None, None); + None, None, None, None, None); + TargetTransformInfo::PeelingPreferences PP = + gatherPeelingPreferences(L, SE, TTI, None, None); if (AllowUnrollAndJam.getNumOccurrences() > 0) UP.UnrollAndJam = AllowUnrollAndJam; if (UnrollAndJamThreshold.getNumOccurrences() > 0) @@ -317,13 +307,13 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, // the unroller, so long as it does not explicitly have unroll_and_jam // metadata. 
This means #pragma nounroll will disable unroll and jam as well // as unrolling - if (HasAnyUnrollPragma(L, "llvm.loop.unroll.") && - !HasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam.")) { + if (hasAnyUnrollPragma(L, "llvm.loop.unroll.") && + !hasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam.")) { LLVM_DEBUG(dbgs() << " Disabled due to pragma.\n"); return LoopUnrollResult::Unmodified; } - if (!isSafeToUnrollAndJam(L, SE, DT, DI)) { + if (!isSafeToUnrollAndJam(L, SE, DT, DI, *LI)) { LLVM_DEBUG(dbgs() << " Disabled due to not being safe.\n"); return LoopUnrollResult::Unmodified; } @@ -334,6 +324,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, bool Convergent; SmallPtrSet<const Value *, 32> EphValues; CodeMetrics::collectEphemeralValues(L, &AC, EphValues); + Loop *SubLoop = L->getSubLoops()[0]; unsigned InnerLoopSize = ApproximateLoopSize(SubLoop, NumInlineCandidates, NotDuplicatable, Convergent, TTI, EphValues, UP.BEInsns); @@ -371,6 +362,8 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, SubLoop->setLoopID(NewInnerEpilogueLoopID.getValue()); // Find trip count and trip multiple + BasicBlock *Latch = L->getLoopLatch(); + BasicBlock *SubLoopLatch = SubLoop->getLoopLatch(); unsigned OuterTripCount = SE.getSmallConstantTripCount(L, Latch); unsigned OuterTripMultiple = SE.getSmallConstantTripMultiple(L, Latch); unsigned InnerTripCount = SE.getSmallConstantTripCount(SubLoop, SubLoopLatch); @@ -378,7 +371,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, // Decide if, and by how much, to unroll bool IsCountSetExplicitly = computeUnrollAndJamCount( L, SubLoop, TTI, DT, LI, SE, EphValues, &ORE, OuterTripCount, - OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP); + OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP, PP); if (UP.Count <= 1) return LoopUnrollResult::Unmodified; // Unroll factor (Count) must be less or equal to TripCount. @@ -388,7 +381,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, Loop *EpilogueOuterLoop = nullptr; LoopUnrollResult UnrollResult = UnrollAndJamLoop( L, UP.Count, OuterTripCount, OuterTripMultiple, UP.UnrollRemainder, LI, - &SE, &DT, &AC, &ORE, &EpilogueOuterLoop); + &SE, &DT, &AC, &TTI, &ORE, &EpilogueOuterLoop); // Assign new loop attributes. if (EpilogueOuterLoop) { @@ -435,22 +428,23 @@ static bool tryToUnrollAndJamLoop(Function &F, DominatorTree &DT, LoopInfo &LI, int OptLevel) { bool DidSomething = false; - // The loop unroll and jam pass requires loops to be in simplified form, and also needs LCSSA. - // Since simplification may add new inner loops, it has to run before the - // legality and profitability checks. This means running the loop unroll and jam pass - // will simplify all loops, regardless of whether anything end up being - // unroll and jammed. + // The loop unroll and jam pass requires loops to be in simplified form, and + // also needs LCSSA. Since simplification may add new inner loops, it has to + // run before the legality and profitability checks. This means running the + // loop unroll and jam pass will simplify all loops, regardless of whether + // anything end up being unroll and jammed. for (auto &L : LI) { DidSomething |= simplifyLoop(L, &DT, &LI, &SE, &AC, nullptr, false /* PreserveLCSSA */); DidSomething |= formLCSSARecursively(*L, DT, &LI, &SE); } + // Add the loop nests in the reverse order of LoopInfo. See method + // declaration. 
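// The driver below switches to appendLoopsToWorklist plus a priority
// worklist. The traversal shape in isolation (a sketch; Visit stands in for
// the per-loop-nest transform and is not from the patch):

#include "llvm/ADT/PriorityWorklist.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

void forEachLoopNest(llvm::LoopInfo &LI,
                     llvm::function_ref<void(llvm::Loop &)> Visit) {
  llvm::SmallPriorityWorklist<llvm::Loop *, 4> Worklist;
  llvm::appendLoopsToWorklist(LI, Worklist);
  while (!Worklist.empty())
    Visit(*Worklist.pop_back_val());
}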
SmallPriorityWorklist<Loop *, 4> Worklist; - internal::appendLoopsToWorklist(reverse(LI), Worklist); + appendLoopsToWorklist(LI, Worklist); while (!Worklist.empty()) { Loop *L = Worklist.pop_back_val(); - formLCSSA(*L, DT, &LI, &SE); LoopUnrollResult Result = tryToUnrollAndJamLoop(L, DT, &LI, SE, TTI, AC, DI, ORE, OptLevel); if (Result != LoopUnrollResult::Unmodified) diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index 4c2b079c6bb5..87f40bb7ba85 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -154,6 +154,10 @@ static cl::opt<bool> cl::desc("Allows loops to be peeled when the dynamic " "trip count is known to be low.")); +static cl::opt<bool> UnrollAllowLoopNestsPeeling( + "unroll-allow-loop-nests-peeling", cl::init(false), cl::Hidden, + cl::desc("Allows loop nests to be peeled.")); + static cl::opt<bool> UnrollUnrollRemainder( "unroll-remainder", cl::Hidden, cl::desc("Allow the loop remainder to be unrolled.")); @@ -167,6 +171,16 @@ static cl::opt<bool> UnrollRevisitChildLoops( "This shouldn't typically be needed as child loops (or their " "clones) were already visited.")); +static cl::opt<unsigned> UnrollThresholdAggressive( + "unroll-threshold-aggressive", cl::init(300), cl::Hidden, + cl::desc("Threshold (max size of unrolled loop) to use in aggressive (O3) " + "optimizations")); +static cl::opt<unsigned> + UnrollThresholdDefault("unroll-threshold-default", cl::init(150), + cl::Hidden, + cl::desc("Default threshold (max size of unrolled " + "loop), used in all but O3 optimizations")); + /// A magic value for use with the Threshold parameter to indicate /// that the loop unroll should be performed regardless of how much /// code expansion would result. @@ -179,19 +193,17 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, int OptLevel, Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, - Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling, - Optional<bool> UserAllowProfileBasedPeeling, - Optional<unsigned> UserFullUnrollMaxCount) { + Optional<bool> UserUpperBound, Optional<unsigned> UserFullUnrollMaxCount) { TargetTransformInfo::UnrollingPreferences UP; // Set up the defaults - UP.Threshold = OptLevel > 2 ? 300 : 150; + UP.Threshold = + OptLevel > 2 ? 
UnrollThresholdAggressive : UnrollThresholdDefault; UP.MaxPercentThresholdBoost = 400; UP.OptSizeThreshold = 0; UP.PartialThreshold = 150; UP.PartialOptSizeThreshold = 0; UP.Count = 0; - UP.PeelCount = 0; UP.DefaultUnrollRuntimeCount = 8; UP.MaxCount = std::numeric_limits<unsigned>::max(); UP.FullUnrollMaxCount = std::numeric_limits<unsigned>::max(); @@ -203,10 +215,9 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( UP.AllowExpensiveTripCount = false; UP.Force = false; UP.UpperBound = false; - UP.AllowPeeling = true; UP.UnrollAndJam = false; - UP.PeelProfiledIterations = true; UP.UnrollAndJamInnerLoopThreshold = 60; + UP.MaxIterationsCountToAnalyze = UnrollMaxIterationsCountToAnalyze; // Override with any target specific settings TTI.getUnrollingPreferences(L, SE, UP); @@ -232,8 +243,6 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( UP.MaxCount = UnrollMaxCount; if (UnrollFullMaxCount.getNumOccurrences() > 0) UP.FullUnrollMaxCount = UnrollFullMaxCount; - if (UnrollPeelCount.getNumOccurrences() > 0) - UP.PeelCount = UnrollPeelCount; if (UnrollAllowPartial.getNumOccurrences() > 0) UP.Partial = UnrollAllowPartial; if (UnrollAllowRemainder.getNumOccurrences() > 0) @@ -242,10 +251,10 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( UP.Runtime = UnrollRuntime; if (UnrollMaxUpperBound == 0) UP.UpperBound = false; - if (UnrollAllowPeeling.getNumOccurrences() > 0) - UP.AllowPeeling = UnrollAllowPeeling; if (UnrollUnrollRemainder.getNumOccurrences() > 0) UP.UnrollRemainder = UnrollUnrollRemainder; + if (UnrollMaxIterationsCountToAnalyze.getNumOccurrences() > 0) + UP.MaxIterationsCountToAnalyze = UnrollMaxIterationsCountToAnalyze; // Apply user values provided by argument if (UserThreshold.hasValue()) { @@ -260,16 +269,45 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( UP.Runtime = *UserRuntime; if (UserUpperBound.hasValue()) UP.UpperBound = *UserUpperBound; - if (UserAllowPeeling.hasValue()) - UP.AllowPeeling = *UserAllowPeeling; - if (UserAllowProfileBasedPeeling.hasValue()) - UP.PeelProfiledIterations = *UserAllowProfileBasedPeeling; if (UserFullUnrollMaxCount.hasValue()) UP.FullUnrollMaxCount = *UserFullUnrollMaxCount; return UP; } +TargetTransformInfo::PeelingPreferences +llvm::gatherPeelingPreferences(Loop *L, ScalarEvolution &SE, + const TargetTransformInfo &TTI, + Optional<bool> UserAllowPeeling, + Optional<bool> UserAllowProfileBasedPeeling) { + TargetTransformInfo::PeelingPreferences PP; + + // Default values + PP.PeelCount = 0; + PP.AllowPeeling = true; + PP.AllowLoopNestsPeeling = false; + PP.PeelProfiledIterations = true; + + // Get Target Specifc Values + TTI.getPeelingPreferences(L, SE, PP); + + // User Specified Values using cl::opt + if (UnrollPeelCount.getNumOccurrences() > 0) + PP.PeelCount = UnrollPeelCount; + if (UnrollAllowPeeling.getNumOccurrences() > 0) + PP.AllowPeeling = UnrollAllowPeeling; + if (UnrollAllowLoopNestsPeeling.getNumOccurrences() > 0) + PP.AllowLoopNestsPeeling = UnrollAllowLoopNestsPeeling; + + // User Specifed values provided by argument + if (UserAllowPeeling.hasValue()) + PP.AllowPeeling = *UserAllowPeeling; + if (UserAllowProfileBasedPeeling.hasValue()) + PP.PeelProfiledIterations = *UserAllowProfileBasedPeeling; + + return PP; +} + namespace { /// A struct to densely store the state of an instruction after unrolling at @@ -335,11 +373,12 @@ struct EstimatedUnrollCost { static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( const 
Loop *L, unsigned TripCount, DominatorTree &DT, ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, - const TargetTransformInfo &TTI, unsigned MaxUnrolledLoopSize) { + const TargetTransformInfo &TTI, unsigned MaxUnrolledLoopSize, + unsigned MaxIterationsCountToAnalyze) { // We want to be able to scale offsets by the trip count and add more offsets // to them without checking for overflows, and we already don't want to // analyze *massive* trip counts, so we force the max to be reasonably small. - assert(UnrollMaxIterationsCountToAnalyze < + assert(MaxIterationsCountToAnalyze < (unsigned)(std::numeric_limits<int>::max() / 2) && "The unroll iterations max is too large!"); @@ -349,8 +388,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( return None; // Don't simulate loops with a big or unknown tripcount - if (!UnrollMaxIterationsCountToAnalyze || !TripCount || - TripCount > UnrollMaxIterationsCountToAnalyze) + if (!TripCount || TripCount > MaxIterationsCountToAnalyze) return None; SmallSetVector<BasicBlock *, 16> BBWorklist; @@ -428,7 +466,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( // First accumulate the cost of this instruction. if (!Cost.IsFree) { - UnrolledCost += TTI.getUserCost(I); + UnrolledCost += TTI.getUserCost(I, TargetTransformInfo::TCK_CodeSize); LLVM_DEBUG(dbgs() << "Adding cost of instruction (iteration " << Iteration << "): "); LLVM_DEBUG(I->dump()); @@ -521,7 +559,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( // Track this instruction's expected baseline cost when executing the // rolled loop form. - RolledDynamicCost += TTI.getUserCost(&I); + RolledDynamicCost += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize); // Visit the instruction to analyze its loop cost after unrolling, // and if the visitor returns true, mark the instruction as free after @@ -665,32 +703,32 @@ unsigned llvm::ApproximateLoopSize( // Returns the loop hint metadata node with the given name (for example, // "llvm.loop.unroll.count"). If no such metadata node exists, then nullptr is // returned. -static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) { +static MDNode *getUnrollMetadataForLoop(const Loop *L, StringRef Name) { if (MDNode *LoopID = L->getLoopID()) return GetUnrollMetadata(LoopID, Name); return nullptr; } // Returns true if the loop has an unroll(full) pragma. -static bool HasUnrollFullPragma(const Loop *L) { - return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.full"); +static bool hasUnrollFullPragma(const Loop *L) { + return getUnrollMetadataForLoop(L, "llvm.loop.unroll.full"); } // Returns true if the loop has an unroll(enable) pragma. This metadata is used // for both "#pragma unroll" and "#pragma clang loop unroll(enable)" directives. -static bool HasUnrollEnablePragma(const Loop *L) { - return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.enable"); +static bool hasUnrollEnablePragma(const Loop *L) { + return getUnrollMetadataForLoop(L, "llvm.loop.unroll.enable"); } // Returns true if the loop has an runtime unroll(disable) pragma. -static bool HasRuntimeUnrollDisablePragma(const Loop *L) { - return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.runtime.disable"); +static bool hasRuntimeUnrollDisablePragma(const Loop *L) { + return getUnrollMetadataForLoop(L, "llvm.loop.unroll.runtime.disable"); } // If loop has an unroll_count pragma return the (necessarily // positive) value from the pragma. Otherwise return 0. 
-static unsigned UnrollCountPragmaValue(const Loop *L) { - MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll.count"); +static unsigned unrollCountPragmaValue(const Loop *L) { + MDNode *MD = getUnrollMetadataForLoop(L, "llvm.loop.unroll.count"); if (MD) { assert(MD->getNumOperands() == 2 && "Unroll count hint metadata should have two operands."); @@ -740,7 +778,8 @@ bool llvm::computeUnrollCount( ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, OptimizationRemarkEmitter *ORE, unsigned &TripCount, unsigned MaxTripCount, bool MaxOrZero, unsigned &TripMultiple, unsigned LoopSize, - TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) { + TargetTransformInfo::UnrollingPreferences &UP, + TargetTransformInfo::PeelingPreferences &PP, bool &UseUpperBound) { // Check for explicit Count. // 1st priority is unroll count set by "unroll-count" option. @@ -754,7 +793,7 @@ bool llvm::computeUnrollCount( } // 2nd priority is unroll count set by pragma. - unsigned PragmaCount = UnrollCountPragmaValue(L); + unsigned PragmaCount = unrollCountPragmaValue(L); if (PragmaCount > 0) { UP.Count = PragmaCount; UP.Runtime = true; @@ -764,14 +803,14 @@ bool llvm::computeUnrollCount( getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold) return true; } - bool PragmaFullUnroll = HasUnrollFullPragma(L); + bool PragmaFullUnroll = hasUnrollFullPragma(L); if (PragmaFullUnroll && TripCount != 0) { UP.Count = TripCount; if (getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold) return false; } - bool PragmaEnableUnroll = HasUnrollEnablePragma(L); + bool PragmaEnableUnroll = hasUnrollEnablePragma(L); bool ExplicitUnroll = PragmaCount > 0 || PragmaFullUnroll || PragmaEnableUnroll || UserUnrollCount; @@ -827,7 +866,8 @@ bool llvm::computeUnrollCount( // To check that, run additional analysis on the loop. if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost( L, FullUnrollTripCount, DT, SE, EphValues, TTI, - UP.Threshold * UP.MaxPercentThresholdBoost / 100)) { + UP.Threshold * UP.MaxPercentThresholdBoost / 100, + UP.MaxIterationsCountToAnalyze)) { unsigned Boost = getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost); if (Cost->UnrolledCost < UP.Threshold * Boost / 100) { @@ -841,8 +881,8 @@ bool llvm::computeUnrollCount( } // 4th priority is loop peeling. - computePeelCount(L, LoopSize, UP, TripCount, SE); - if (UP.PeelCount) { + computePeelCount(L, LoopSize, UP, PP, TripCount, SE); + if (PP.PeelCount) { UP.Runtime = false; UP.Count = 1; return ExplicitUnroll; @@ -925,7 +965,7 @@ bool llvm::computeUnrollCount( // 6th priority is runtime unrolling. // Don't unroll a runtime trip count loop when it is disabled. - if (HasRuntimeUnrollDisablePragma(L)) { + if (hasRuntimeUnrollDisablePragma(L)) { UP.Count = 0; return false; } @@ -1045,8 +1085,9 @@ static LoopUnrollResult tryToUnrollLoop( TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( L, SE, TTI, BFI, PSI, OptLevel, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, - ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling, ProvidedFullUnrollMaxCount); + TargetTransformInfo::PeelingPreferences PP = gatherPeelingPreferences( + L, SE, TTI, ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling); // Exit early if unrolling is disabled. For OptForSize, we pick the loop size // as threshold later on. 
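[Editor's note, not part of the patch] The hunks above move the peeling knobs out of TargetTransformInfo::UnrollingPreferences into a new TargetTransformInfo::PeelingPreferences struct, filled in by the separate llvm::gatherPeelingPreferences() helper. A minimal sketch of how a caller now obtains both preference sets, based only on the signatures shown in this diff; the variables L, SE, TTI, BFI, PSI and OptLevel are assumed to already be in scope, as they are inside tryToUnrollLoop:

    // Unrolling knobs: thresholds, counts, partial/runtime unrolling, etc.
    TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
        L, SE, TTI, BFI, PSI, OptLevel,
        /*UserThreshold=*/None, /*UserCount=*/None, /*UserAllowPartial=*/None,
        /*UserRuntime=*/None, /*UserUpperBound=*/None,
        /*UserFullUnrollMaxCount=*/None);
    // Peeling knobs (PeelCount, AllowPeeling, AllowLoopNestsPeeling,
    // PeelProfiledIterations) now travel separately from UP.
    TargetTransformInfo::PeelingPreferences PP =
        gatherPeelingPreferences(L, SE, TTI, /*UserAllowPeeling=*/None,
                                 /*UserAllowProfileBasedPeeling=*/None);
    // Both structs are then threaded through computeUnrollCount(..., UP, PP, ...).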
@@ -1120,7 +1161,7 @@ static LoopUnrollResult tryToUnrollLoop( bool UseUpperBound = false; bool IsCountSetExplicitly = computeUnrollCount( L, TTI, DT, LI, SE, EphValues, &ORE, TripCount, MaxTripCount, MaxOrZero, - TripMultiple, LoopSize, UP, UseUpperBound); + TripMultiple, LoopSize, UP, PP, UseUpperBound); if (!UP.Count) return LoopUnrollResult::Unmodified; // Unroll factor (Count) must be less or equal to TripCount. @@ -1135,9 +1176,9 @@ static LoopUnrollResult tryToUnrollLoop( LoopUnrollResult UnrollResult = UnrollLoop( L, {UP.Count, TripCount, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount, - UseUpperBound, MaxOrZero, TripMultiple, UP.PeelCount, UP.UnrollRemainder, + UseUpperBound, MaxOrZero, TripMultiple, PP.PeelCount, UP.UnrollRemainder, ForgetAllSCEV}, - LI, &SE, &DT, &AC, &ORE, PreserveLCSSA, &RemainderLoop); + LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop); if (UnrollResult == LoopUnrollResult::Unmodified) return LoopUnrollResult::Unmodified; @@ -1167,7 +1208,7 @@ static LoopUnrollResult tryToUnrollLoop( // If the loop was peeled, we already "used up" the profile information // we had, so we don't want to unroll or peel again. if (UnrollResult != LoopUnrollResult::FullyUnrolled && - (IsCountSetExplicitly || (UP.PeelProfiledIterations && UP.PeelCount))) + (IsCountSetExplicitly || (PP.PeelProfiledIterations && PP.PeelCount))) L->setLoopAlreadyUnrolled(); return UnrollResult; @@ -1296,16 +1337,10 @@ Pass *llvm::createSimpleLoopUnrollPass(int OptLevel, bool OnlyWhenForced, PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &Updater) { - const auto &FAM = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); - Function *F = L.getHeader()->getParent(); - - auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); - // FIXME: This should probably be optional rather than required. - if (!ORE) - report_fatal_error( - "LoopFullUnrollPass: OptimizationRemarkEmitterAnalysis not " - "cached at a higher level"); + // For the new PM, we can't use OptimizationRemarkEmitter as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). + OptimizationRemarkEmitter ORE(L.getHeader()->getParent()); // Keep track of the previous loop structure so we can identify new loops // created by unrolling. @@ -1316,9 +1351,9 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, else OldLoops.insert(AR.LI.begin(), AR.LI.end()); - std::string LoopName = L.getName(); + std::string LoopName = std::string(L.getName()); - bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, + bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, ORE, /*BFI*/ nullptr, /*PSI*/ nullptr, /*PreserveLCSSA*/ true, OptLevel, OnlyWhenForced, ForgetSCEV, /*Count*/ None, @@ -1384,30 +1419,6 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, return getLoopPassPreservedAnalyses(); } -template <typename RangeT> -static SmallVector<Loop *, 8> appendLoopsToWorklist(RangeT &&Loops) { - SmallVector<Loop *, 8> Worklist; - // We use an internal worklist to build up the preorder traversal without - // recursion. 
- SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist; - - for (Loop *RootL : Loops) { - assert(PreOrderLoops.empty() && "Must start with an empty preorder walk."); - assert(PreOrderWorklist.empty() && - "Must start with an empty preorder walk worklist."); - PreOrderWorklist.push_back(RootL); - do { - Loop *L = PreOrderWorklist.pop_back_val(); - PreOrderWorklist.append(L->begin(), L->end()); - PreOrderLoops.push_back(L); - } while (!PreOrderWorklist.empty()); - - Worklist.append(PreOrderLoops.begin(), PreOrderLoops.end()); - PreOrderLoops.clear(); - } - return Worklist; -} - PreservedAnalyses LoopUnrollPass::run(Function &F, FunctionAnalysisManager &AM) { auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); @@ -1421,10 +1432,9 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, if (auto *LAMProxy = AM.getCachedResult<LoopAnalysisManagerFunctionProxy>(F)) LAM = &LAMProxy->getManager(); - const ModuleAnalysisManager &MAM = - AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); + auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); ProfileSummaryInfo *PSI = - MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); auto *BFI = (PSI && PSI->hasProfileSummary()) ? &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; @@ -1441,7 +1451,10 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, Changed |= formLCSSARecursively(*L, DT, &LI, &SE); } - SmallVector<Loop *, 8> Worklist = appendLoopsToWorklist(LI); + // Add the loop nests in the reverse order of LoopInfo. See method + // declaration. + SmallPriorityWorklist<Loop *, 4> Worklist; + appendLoopsToWorklist(LI, Worklist); while (!Worklist.empty()) { // Because the LoopInfo stores the loops in RPO, we walk the worklist @@ -1459,7 +1472,7 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, Optional<bool> LocalAllowPeeling = UnrollOpts.AllowPeeling; if (PSI && PSI->hasHugeWorkingSetSize()) LocalAllowPeeling = false; - std::string LoopName = L.getName(); + std::string LoopName = std::string(L.getName()); // The API here is quite complex to call and we allow to select some // flavors of unrolling during construction time (by setting UnrollOpts). LoopUnrollResult Result = tryToUnrollLoop( diff --git a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp index 915e053704b2..645a89bbd0ff 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -38,11 +38,11 @@ #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" @@ -158,7 +158,7 @@ namespace { // Returns true if another unswitching could be done within the cost // threshold. - bool CostAllowsUnswitching(); + bool costAllowsUnswitching(); // Clone all loop-unswitch related loop properties. // Redistribute unswitching quotas. @@ -173,20 +173,20 @@ namespace { AssumptionCache *AC; // Used to check if second loop needs processing after - // RewriteLoopBodyWithConditionConstant rewrites first loop. + // rewriteLoopBodyWithConditionConstant rewrites first loop. 
std::vector<Loop*> LoopProcessWorklist; LUAnalysisCache BranchesInfo; bool OptimizeForSize; - bool redoLoop = false; + bool RedoLoop = false; - Loop *currentLoop = nullptr; + Loop *CurrentLoop = nullptr; DominatorTree *DT = nullptr; MemorySSA *MSSA = nullptr; std::unique_ptr<MemorySSAUpdater> MSSAU; - BasicBlock *loopHeader = nullptr; - BasicBlock *loopPreheader = nullptr; + BasicBlock *LoopHeader = nullptr; + BasicBlock *LoopPreheader = nullptr; bool SanitizeMemory; SimpleLoopSafetyInfo SafetyInfo; @@ -198,15 +198,15 @@ namespace { // NewBlocks contained cloned copy of basic blocks from LoopBlocks. std::vector<BasicBlock*> NewBlocks; - bool hasBranchDivergence; + bool HasBranchDivergence; public: static char ID; // Pass ID, replacement for typeid - explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false) + explicit LoopUnswitch(bool Os = false, bool HasBranchDivergence = false) : LoopPass(ID), OptimizeForSize(Os), - hasBranchDivergence(hasBranchDivergence) { - initializeLoopUnswitchPass(*PassRegistry::getPassRegistry()); + HasBranchDivergence(HasBranchDivergence) { + initializeLoopUnswitchPass(*PassRegistry::getPassRegistry()); } bool runOnLoop(Loop *L, LPPassManager &LPM) override; @@ -223,48 +223,46 @@ namespace { AU.addRequired<MemorySSAWrapperPass>(); AU.addPreserved<MemorySSAWrapperPass>(); } - if (hasBranchDivergence) + if (HasBranchDivergence) AU.addRequired<LegacyDivergenceAnalysis>(); getLoopAnalysisUsage(AU); } private: - void releaseMemory() override { - BranchesInfo.forgetLoop(currentLoop); - } + void releaseMemory() override { BranchesInfo.forgetLoop(CurrentLoop); } void initLoopData() { - loopHeader = currentLoop->getHeader(); - loopPreheader = currentLoop->getLoopPreheader(); + LoopHeader = CurrentLoop->getHeader(); + LoopPreheader = CurrentLoop->getLoopPreheader(); } /// Split all of the edges from inside the loop to their exit blocks. /// Update the appropriate Phi nodes as we do so. - void SplitExitEdges(Loop *L, + void splitExitEdges(Loop *L, const SmallVectorImpl<BasicBlock *> &ExitBlocks); - bool TryTrivialLoopUnswitch(bool &Changed); + bool tryTrivialLoopUnswitch(bool &Changed); - bool UnswitchIfProfitable(Value *LoopCond, Constant *Val, + bool unswitchIfProfitable(Value *LoopCond, Constant *Val, Instruction *TI = nullptr); - void UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, + void unswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, BasicBlock *ExitBlock, Instruction *TI); - void UnswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L, + void unswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L, Instruction *TI); - void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, - Constant *Val, bool isEqual); + void rewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, + Constant *Val, bool IsEqual); - void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, + void emitPreheaderBranchOnCondition(Value *LIC, Constant *Val, BasicBlock *TrueDest, BasicBlock *FalseDest, BranchInst *OldBranch, Instruction *TI); - void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L); + void simplifyCode(std::vector<Instruction *> &Worklist, Loop *L); /// Given that the Invariant is not equal to Val. Simplify instructions /// in the loop. 
- Value *SimplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant, + Value *simplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant, Constant *Val); }; @@ -347,7 +345,7 @@ bool LUAnalysisCache::isUnswitched(const SwitchInst *SI, const Value *V) { return (*CurLoopInstructions)[SI].count(V); } -bool LUAnalysisCache::CostAllowsUnswitching() { +bool LUAnalysisCache::costAllowsUnswitching() { return CurrentLoopProperties->CanBeUnswitchedCount > 0; } @@ -396,8 +394,8 @@ INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops", false, false) -Pass *llvm::createLoopUnswitchPass(bool Os, bool hasBranchDivergence) { - return new LoopUnswitch(Os, hasBranchDivergence); +Pass *llvm::createLoopUnswitchPass(bool Os, bool HasBranchDivergence) { + return new LoopUnswitch(Os, HasBranchDivergence); } /// Operator chain lattice. @@ -411,15 +409,15 @@ enum OperatorChain { /// Cond is a condition that occurs in L. If it is invariant in the loop, or has /// an invariant piece, return the invariant. Otherwise, return null. // -/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a -/// mixed operator chain, as we can not reliably find a value which will simplify -/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0 -/// to simplify the chain. +/// NOTE: findLIVLoopCondition will not return a partial LIV by walking up a +/// mixed operator chain, as we can not reliably find a value which will +/// simplify the operator chain. If the chain is AND-only or OR-only, we can use +/// 0 or ~0 to simplify the chain. /// /// NOTE: In case a partial LIV and a mixed operator chain, we may be able to /// simplify the condition itself to a loop variant condition, but at the /// cost of creating an entirely new loop. -static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, +static Value *findLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, OperatorChain &ParentChain, DenseMap<Value *, Value *> &Cache, MemorySSAUpdater *MSSAU) { @@ -479,7 +477,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, // If either the left or right side is invariant, we can unswitch on this, // which will cause the branch to go away in one loop and the condition to // simplify in the other one. - if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed, + if (Value *LHS = findLIVLoopCondition(BO->getOperand(0), L, Changed, ParentChain, Cache, MSSAU)) { Cache[Cond] = LHS; return LHS; @@ -487,7 +485,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, // We did not manage to find a partial LIV in operand(0). Backtrack and try // operand(1). ParentChain = NewChain; - if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed, + if (Value *RHS = findLIVLoopCondition(BO->getOperand(1), L, Changed, ParentChain, Cache, MSSAU)) { Cache[Cond] = RHS; return RHS; @@ -503,11 +501,11 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, /// an invariant piece, return the invariant along with the operator chain type. /// Otherwise, return null. 
static std::pair<Value *, OperatorChain> -FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, +findLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, MemorySSAUpdater *MSSAU) { DenseMap<Value *, Value *> Cache; OperatorChain OpChain = OC_OpChainNone; - Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache, MSSAU); + Value *FCond = findLIVLoopCondition(Cond, L, Changed, OpChain, Cache, MSSAU); // In case we do find a LIV, it can not be obtained by walking up a mixed // operator chain. @@ -516,22 +514,22 @@ FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, return {FCond, OpChain}; } -bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { +bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPMRef) { if (skipLoop(L)) return false; AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( *L->getHeader()->getParent()); LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - LPM = &LPM_Ref; + LPM = &LPMRef; DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); if (EnableMSSALoopDependency) { MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); assert(DT && "Cannot update MemorySSA without a valid DomTree."); } - currentLoop = L; - Function *F = currentLoop->getHeader()->getParent(); + CurrentLoop = L; + Function *F = CurrentLoop->getHeader()->getParent(); SanitizeMemory = F->hasFnAttribute(Attribute::SanitizeMemory); if (SanitizeMemory) @@ -542,12 +540,12 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { bool Changed = false; do { - assert(currentLoop->isLCSSAForm(*DT)); + assert(CurrentLoop->isLCSSAForm(*DT)); if (MSSA && VerifyMemorySSA) MSSA->verifyMemorySSA(); - redoLoop = false; + RedoLoop = false; Changed |= processCurrentLoop(); - } while(redoLoop); + } while (RedoLoop); if (MSSA && VerifyMemorySSA) MSSA->verifyMemorySSA(); @@ -560,7 +558,7 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) { auto *Node = DT->getNode(BB)->getIDom(); BasicBlock *DomBB = Node->getBlock(); - while (currentLoop->contains(DomBB)) { + while (CurrentLoop->contains(DomBB)) { BranchInst *BInst = dyn_cast<BranchInst>(DomBB->getTerminator()); Node = DT->getNode(DomBB)->getIDom(); @@ -591,7 +589,7 @@ bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) { /// causing problems. Detail could be found in PR31652. Note if the /// func returns true, it is unsafe. But if it is false, it doesn't mean /// it is necessarily safe. 
-static bool EqualityPropUnSafe(Value &LoopCond) { +static bool equalityPropUnSafe(Value &LoopCond) { ICmpInst *CI = dyn_cast<ICmpInst>(&LoopCond); if (!CI || !CI->isEquality()) return false; @@ -601,7 +599,7 @@ static bool EqualityPropUnSafe(Value &LoopCond) { if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) return true; - auto hasUndefInPHI = [](PHINode &PN) { + auto HasUndefInPHI = [](PHINode &PN) { for (Value *Opd : PN.incoming_values()) { if (isa<UndefValue>(Opd)) return true; @@ -610,10 +608,10 @@ static bool EqualityPropUnSafe(Value &LoopCond) { }; PHINode *LPHI = dyn_cast<PHINode>(LHS); PHINode *RPHI = dyn_cast<PHINode>(RHS); - if ((LPHI && hasUndefInPHI(*LPHI)) || (RPHI && hasUndefInPHI(*RPHI))) + if ((LPHI && HasUndefInPHI(*LPHI)) || (RPHI && HasUndefInPHI(*RPHI))) return true; - auto hasUndefInSelect = [](SelectInst &SI) { + auto HasUndefInSelect = [](SelectInst &SI) { if (isa<UndefValue>(SI.getTrueValue()) || isa<UndefValue>(SI.getFalseValue())) return true; @@ -621,7 +619,7 @@ static bool EqualityPropUnSafe(Value &LoopCond) { }; SelectInst *LSI = dyn_cast<SelectInst>(LHS); SelectInst *RSI = dyn_cast<SelectInst>(RHS); - if ((LSI && hasUndefInSelect(*LSI)) || (RSI && hasUndefInSelect(*RSI))) + if ((LSI && HasUndefInSelect(*LSI)) || (RSI && HasUndefInSelect(*RSI))) return true; return false; } @@ -633,35 +631,36 @@ bool LoopUnswitch::processCurrentLoop() { initLoopData(); // If LoopSimplify was unable to form a preheader, don't do any unswitching. - if (!loopPreheader) + if (!LoopPreheader) return false; // Loops with indirectbr cannot be cloned. - if (!currentLoop->isSafeToClone()) + if (!CurrentLoop->isSafeToClone()) return false; // Without dedicated exits, splitting the exit edge may fail. - if (!currentLoop->hasDedicatedExits()) + if (!CurrentLoop->hasDedicatedExits()) return false; - LLVMContext &Context = loopHeader->getContext(); + LLVMContext &Context = LoopHeader->getContext(); // Analyze loop cost, and stop unswitching if loop content can not be duplicated. if (!BranchesInfo.countLoop( - currentLoop, getAnalysis<TargetTransformInfoWrapperPass>().getTTI( - *currentLoop->getHeader()->getParent()), + CurrentLoop, + getAnalysis<TargetTransformInfoWrapperPass>().getTTI( + *CurrentLoop->getHeader()->getParent()), AC)) return false; // Try trivial unswitch first before loop over other basic blocks in the loop. - if (TryTrivialLoopUnswitch(Changed)) { + if (tryTrivialLoopUnswitch(Changed)) { return true; } // Do not do non-trivial unswitch while optimizing for size. // FIXME: Use Function::hasOptSize(). 
if (OptimizeForSize || - loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize)) + LoopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize)) return false; // Run through the instructions in the loop, keeping track of three things: @@ -680,11 +679,12 @@ bool LoopUnswitch::processCurrentLoop() { SmallVector<IntrinsicInst *, 4> Guards; - for (const auto BB : currentLoop->blocks()) { + for (const auto BB : CurrentLoop->blocks()) { for (auto &I : *BB) { - auto CS = CallSite(&I); - if (!CS) continue; - if (CS.isConvergent()) + auto *CB = dyn_cast<CallBase>(&I); + if (!CB) + continue; + if (CB->isConvergent()) return false; if (auto *II = dyn_cast<InvokeInst>(&I)) if (!II->getUnwindDest()->canSplitPredecessors()) @@ -696,11 +696,11 @@ bool LoopUnswitch::processCurrentLoop() { } for (IntrinsicInst *Guard : Guards) { - Value *LoopCond = FindLIVLoopCondition(Guard->getOperand(0), currentLoop, + Value *LoopCond = findLIVLoopCondition(Guard->getOperand(0), CurrentLoop, Changed, MSSAU.get()) .first; if (LoopCond && - UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { + unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { // NB! Unswitching (if successful) could have erased some of the // instructions in Guards leaving dangling pointers there. This is fine // because we're returning now, and won't look at Guards again. @@ -712,8 +712,9 @@ bool LoopUnswitch::processCurrentLoop() { // Loop over all of the basic blocks in the loop. If we find an interior // block that is branching on a loop-invariant condition, we can unswitch this // loop. - for (Loop::block_iterator I = currentLoop->block_begin(), - E = currentLoop->block_end(); I != E; ++I) { + for (Loop::block_iterator I = CurrentLoop->block_begin(), + E = CurrentLoop->block_end(); + I != E; ++I) { Instruction *TI = (*I)->getTerminator(); // Unswitching on a potentially uninitialized predicate is not @@ -723,7 +724,7 @@ bool LoopUnswitch::processCurrentLoop() { // This is a workaround for the discrepancy between LLVM IR and MSan // semantics. See PR28054 for more details. if (SanitizeMemory && - !SafetyInfo.isGuaranteedToExecute(*TI, DT, currentLoop)) + !SafetyInfo.isGuaranteedToExecute(*TI, DT, CurrentLoop)) continue; if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { @@ -738,11 +739,11 @@ bool LoopUnswitch::processCurrentLoop() { if (BI->isConditional()) { // See if this, or some part of it, is loop invariant. If so, we can // unswitch on it if we desire. - Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop, + Value *LoopCond = findLIVLoopCondition(BI->getCondition(), CurrentLoop, Changed, MSSAU.get()) .first; - if (LoopCond && !EqualityPropUnSafe(*LoopCond) && - UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { + if (LoopCond && !equalityPropUnSafe(*LoopCond) && + unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { ++NumBranches; return true; } @@ -752,7 +753,7 @@ bool LoopUnswitch::processCurrentLoop() { Value *LoopCond; OperatorChain OpChain; std::tie(LoopCond, OpChain) = - FindLIVLoopCondition(SC, currentLoop, Changed, MSSAU.get()); + findLIVLoopCondition(SC, CurrentLoop, Changed, MSSAU.get()); unsigned NumCases = SI->getNumCases(); if (LoopCond && NumCases) { @@ -796,7 +797,7 @@ bool LoopUnswitch::processCurrentLoop() { if (!UnswitchVal) continue; - if (UnswitchIfProfitable(LoopCond, UnswitchVal)) { + if (unswitchIfProfitable(LoopCond, UnswitchVal)) { ++NumSwitches; // In case of a full LIV, UnswitchVal is the value we unswitched out. 
// In case of a partial LIV, we only unswitch when its an AND-chain @@ -812,11 +813,11 @@ bool LoopUnswitch::processCurrentLoop() { for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end(); BBI != E; ++BBI) if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) { - Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop, + Value *LoopCond = findLIVLoopCondition(SI->getCondition(), CurrentLoop, Changed, MSSAU.get()) .first; - if (LoopCond && UnswitchIfProfitable(LoopCond, - ConstantInt::getTrue(Context))) { + if (LoopCond && + unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { ++NumSelects; return true; } @@ -875,62 +876,38 @@ static BasicBlock *isTrivialLoopExitBlock(Loop *L, BasicBlock *BB) { return nullptr; } -/// We have found that we can unswitch currentLoop when LoopCond == Val to +/// We have found that we can unswitch CurrentLoop when LoopCond == Val to /// simplify the loop. If we decide that this is profitable, /// unswitch the loop, reprocess the pieces, then return true. -bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, +bool LoopUnswitch::unswitchIfProfitable(Value *LoopCond, Constant *Val, Instruction *TI) { // Check to see if it would be profitable to unswitch current loop. - if (!BranchesInfo.CostAllowsUnswitching()) { + if (!BranchesInfo.costAllowsUnswitching()) { LLVM_DEBUG(dbgs() << "NOT unswitching loop %" - << currentLoop->getHeader()->getName() + << CurrentLoop->getHeader()->getName() << " at non-trivial condition '" << *Val << "' == " << *LoopCond << "\n" << ". Cost too high.\n"); return false; } - if (hasBranchDivergence && + if (HasBranchDivergence && getAnalysis<LegacyDivergenceAnalysis>().isDivergent(LoopCond)) { LLVM_DEBUG(dbgs() << "NOT unswitching loop %" - << currentLoop->getHeader()->getName() + << CurrentLoop->getHeader()->getName() << " at non-trivial condition '" << *Val << "' == " << *LoopCond << "\n" << ". Condition is divergent.\n"); return false; } - UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI); + unswitchNontrivialCondition(LoopCond, Val, CurrentLoop, TI); return true; } -/// Recursively clone the specified loop and all of its children, -/// mapping the blocks with the specified map. -static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, - LoopInfo *LI, LPPassManager *LPM) { - Loop &New = *LI->AllocateLoop(); - if (PL) - PL->addChildLoop(&New); - else - LI->addTopLevelLoop(&New); - LPM->addLoop(New); - - // Add all of the blocks in L to the new loop. - for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); - I != E; ++I) - if (LI->getLoopFor(*I) == L) - New.addBasicBlockToLoop(cast<BasicBlock>(VM[*I]), *LI); - - // Add all of the subloops to the new loop. - for (Loop *I : *L) - CloneLoop(I, &New, VM, LI, LPM); - - return &New; -} - /// Emit a conditional branch on two values if LIC == Val, branch to TrueDst, /// otherwise branch to FalseDest. Insert the code immediately before OldBranch /// and remove (but not erase!) it from the function. -void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, +void LoopUnswitch::emitPreheaderBranchOnCondition(Value *LIC, Constant *Val, BasicBlock *TrueDest, BasicBlock *FalseDest, BranchInst *OldBranch, @@ -997,11 +974,11 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, /// that doesn't execute its body has no side-effects), unswitch it. This /// doesn't involve any code duplication, just moving the conditional branch /// outside of the loop and updating loop info. 
-void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, +void LoopUnswitch::unswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, BasicBlock *ExitBlock, Instruction *TI) { LLVM_DEBUG(dbgs() << "loop-unswitch: Trivial-Unswitch loop %" - << loopHeader->getName() << " [" << L->getBlocks().size() + << LoopHeader->getName() << " [" << L->getBlocks().size() << " blocks] in Function " << L->getHeader()->getParent()->getName() << " on cond: " << *Val << " == " << *Cond << "\n"); @@ -1011,9 +988,9 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, SEWP->getSE().forgetTopmostLoop(L); // First step, split the preheader, so that we know that there is a safe place - // to insert the conditional branch. We will change loopPreheader to have a + // to insert the conditional branch. We will change LoopPreheader to have a // conditional branch on Cond. - BasicBlock *NewPH = SplitEdge(loopPreheader, loopHeader, DT, LI, MSSAU.get()); + BasicBlock *NewPH = SplitEdge(LoopPreheader, LoopHeader, DT, LI, MSSAU.get()); // Now that we have a place to insert the conditional branch, create a place // to branch to: this is the exit block out of the loop that we should @@ -1029,22 +1006,21 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, // Okay, now we have a position to branch from and a position to branch to, // insert the new conditional branch. - auto *OldBranch = dyn_cast<BranchInst>(loopPreheader->getTerminator()); + auto *OldBranch = dyn_cast<BranchInst>(LoopPreheader->getTerminator()); assert(OldBranch && "Failed to split the preheader"); - EmitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, OldBranch, TI); - LPM->deleteSimpleAnalysisValue(OldBranch, L); + emitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, OldBranch, TI); - // EmitPreheaderBranchOnCondition removed the OldBranch from the function. + // emitPreheaderBranchOnCondition removed the OldBranch from the function. // Delete it, as it is no longer needed. delete OldBranch; // We need to reprocess this loop, it could be unswitched again. - redoLoop = true; + RedoLoop = true; // Now that we know that the loop is never entered when this condition is a // particular value, rewrite the loop with this info. We know that this will // at least eliminate the old branch. - RewriteLoopBodyWithConditionConstant(L, Cond, Val, false); + rewriteLoopBodyWithConditionConstant(L, Cond, Val, /*IsEqual=*/false); ++NumTrivial; } @@ -1055,8 +1031,8 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, /// produces no code duplications (equivalently, it produces a simpler loop and /// a new empty loop, which gets deleted). Therefore always unswitch trivial /// condition. -bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { - BasicBlock *CurrentBB = currentLoop->getHeader(); +bool LoopUnswitch::tryTrivialLoopUnswitch(bool &Changed) { + BasicBlock *CurrentBB = CurrentLoop->getHeader(); Instruction *CurrentTerm = CurrentBB->getTerminator(); LLVMContext &Context = CurrentBB->getContext(); @@ -1081,7 +1057,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { // we can not reach any trivial condition candidates (unfoldable // branch instructions or switch instructions) and no unswitch // can happen. Exit and return false. 
- if (!currentLoop->contains(CurrentBB) || !Visited.insert(CurrentBB).second) + if (!CurrentLoop->contains(CurrentBB) || !Visited.insert(CurrentBB).second) return false; // Check if this loop will execute any side-effecting instructions (e.g. @@ -1128,7 +1104,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { if (!BI->isConditional()) return false; - Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop, + Value *LoopCond = findLIVLoopCondition(BI->getCondition(), CurrentLoop, Changed, MSSAU.get()) .first; @@ -1141,11 +1117,11 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { // exit through a unique exit block without having any // side-effects. If so, determine the value of Cond that causes // it to do this. - if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop, - BI->getSuccessor(0)))) { + if ((LoopExitBB = + isTrivialLoopExitBlock(CurrentLoop, BI->getSuccessor(0)))) { CondVal = ConstantInt::getTrue(Context); - } else if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop, - BI->getSuccessor(1)))) { + } else if ((LoopExitBB = + isTrivialLoopExitBlock(CurrentLoop, BI->getSuccessor(1)))) { CondVal = ConstantInt::getFalse(Context); } @@ -1154,16 +1130,16 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin())) return false; // Can't handle this. - if (EqualityPropUnSafe(*LoopCond)) + if (equalityPropUnSafe(*LoopCond)) return false; - UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB, + unswitchTrivialCondition(CurrentLoop, LoopCond, CondVal, LoopExitBB, CurrentTerm); ++NumBranches; return true; } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) { // If this isn't switching on an invariant condition, we can't unswitch it. - Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop, + Value *LoopCond = findLIVLoopCondition(SI->getCondition(), CurrentLoop, Changed, MSSAU.get()) .first; @@ -1181,7 +1157,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { for (auto Case : SI->cases()) { BasicBlock *LoopExitCandidate; if ((LoopExitCandidate = - isTrivialLoopExitBlock(currentLoop, Case.getCaseSuccessor()))) { + isTrivialLoopExitBlock(CurrentLoop, Case.getCaseSuccessor()))) { // Okay, we found a trivial case, remember the value that is trivial. ConstantInt *CaseVal = Case.getCaseValue(); @@ -1200,7 +1176,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin())) return false; // Can't handle this. - UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB, + unswitchTrivialCondition(CurrentLoop, LoopCond, CondVal, LoopExitBB, nullptr); // We are only unswitching full LIV. @@ -1213,11 +1189,11 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { /// Split all of the edges from inside the loop to their exit blocks. /// Update the appropriate Phi nodes as we do so. 
-void LoopUnswitch::SplitExitEdges(Loop *L, - const SmallVectorImpl<BasicBlock *> &ExitBlocks){ +void LoopUnswitch::splitExitEdges( + Loop *L, const SmallVectorImpl<BasicBlock *> &ExitBlocks) { - for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { - BasicBlock *ExitBlock = ExitBlocks[i]; + for (unsigned I = 0, E = ExitBlocks.size(); I != E; ++I) { + BasicBlock *ExitBlock = ExitBlocks[I]; SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBlock), pred_end(ExitBlock)); @@ -1231,11 +1207,11 @@ void LoopUnswitch::SplitExitEdges(Loop *L, /// We determined that the loop is profitable to unswitch when LIC equal Val. /// Split it into loop versions and test the condition outside of either loop. /// Return the loops created as Out1/Out2. -void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, +void LoopUnswitch::unswitchNontrivialCondition(Value *LIC, Constant *Val, Loop *L, Instruction *TI) { - Function *F = loopHeader->getParent(); + Function *F = LoopHeader->getParent(); LLVM_DEBUG(dbgs() << "loop-unswitch: Unswitching loop %" - << loopHeader->getName() << " [" << L->getBlocks().size() + << LoopHeader->getName() << " [" << L->getBlocks().size() << " blocks] in Function " << F->getName() << " when '" << *Val << "' == " << *LIC << "\n"); @@ -1253,7 +1229,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, // First step, split the preheader and exit blocks, and add these blocks to // the LoopBlocks list. BasicBlock *NewPreheader = - SplitEdge(loopPreheader, loopHeader, DT, LI, MSSAU.get()); + SplitEdge(LoopPreheader, LoopHeader, DT, LI, MSSAU.get()); LoopBlocks.push_back(NewPreheader); // We want the loop to come after the preheader, but before the exit blocks. @@ -1264,7 +1240,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, // Split all of the edges from inside the loop to their exit blocks. Update // the appropriate Phi nodes as we do so. - SplitExitEdges(L, ExitBlocks); + splitExitEdges(L, ExitBlocks); // The exit blocks may have been changed due to edge splitting, recompute. ExitBlocks.clear(); @@ -1278,12 +1254,11 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, // the instructions and blocks. NewBlocks.reserve(LoopBlocks.size()); ValueToValueMapTy VMap; - for (unsigned i = 0, e = LoopBlocks.size(); i != e; ++i) { - BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[i], VMap, ".us", F); + for (unsigned I = 0, E = LoopBlocks.size(); I != E; ++I) { + BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[I], VMap, ".us", F); NewBlocks.push_back(NewBB); - VMap[LoopBlocks[i]] = NewBB; // Keep the BB mapping. - LPM->cloneBasicBlockSimpleAnalysis(LoopBlocks[i], NewBB, L); + VMap[LoopBlocks[I]] = NewBB; // Keep the BB mapping. } // Splice the newly inserted blocks into the function right before the @@ -1293,7 +1268,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, NewBlocks[0]->getIterator(), F->end()); // Now we create the new Loop object for the versioned loop. - Loop *NewLoop = CloneLoop(L, L->getParentLoop(), VMap, LI, LPM); + Loop *NewLoop = cloneLoop(L, L->getParentLoop(), VMap, LI, LPM); // Recalculate unswitching quota, inherit simplified switches info for NewBB, // Probably clone more loop-unswitch related loop properties. 
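[Editor's note, not part of the patch] The hunk above clones every block of the loop through one ValueToValueMapTy, and the following hunk remaps the cloned instructions so they refer to the clones rather than the originals. A small sketch of that clone-then-remap pattern using the same LLVM utilities that appear in this diff; the helper name cloneAndRemapBlocks and its argument list are illustrative only:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Function.h"
    #include "llvm/Transforms/Utils/Cloning.h"
    #include "llvm/Transforms/Utils/ValueMapper.h"
    using namespace llvm;

    // Clone a list of blocks into F and rewire the clones to reference each
    // other instead of the original blocks/values.
    static void cloneAndRemapBlocks(Function *F,
                                    ArrayRef<BasicBlock *> LoopBlocks) {
      ValueToValueMapTy VMap;
      SmallVector<BasicBlock *, 8> NewBlocks;
      for (BasicBlock *BB : LoopBlocks) {
        // CloneBasicBlock records old-value -> new-value mappings in VMap.
        BasicBlock *NewBB = CloneBasicBlock(BB, VMap, ".us", F);
        VMap[BB] = NewBB; // Map the block itself so branch targets remap too.
        NewBlocks.push_back(NewBB);
      }
      // Rewrite operands of the cloned instructions via the mapping.
      for (BasicBlock *NewBB : NewBlocks)
        for (Instruction &I : *NewBB)
          RemapInstruction(&I, VMap,
                           RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
    }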
@@ -1306,10 +1281,10 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, ParentLoop->addBasicBlockToLoop(NewBlocks[0], *LI); } - for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { - BasicBlock *NewExit = cast<BasicBlock>(VMap[ExitBlocks[i]]); + for (unsigned EBI = 0, EBE = ExitBlocks.size(); EBI != EBE; ++EBI) { + BasicBlock *NewExit = cast<BasicBlock>(VMap[ExitBlocks[EBI]]); // The new exit block should be in the same loop as the old one. - if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[i])) + if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[EBI])) ExitBBLoop->addBasicBlockToLoop(NewExit, *LI); assert(NewExit->getTerminator()->getNumSuccessors() == 1 && @@ -1319,7 +1294,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, // If the successor of the exit block had PHI nodes, add an entry for // NewExit. for (PHINode &PN : ExitSucc->phis()) { - Value *V = PN.getIncomingValueForBlock(ExitBlocks[i]); + Value *V = PN.getIncomingValueForBlock(ExitBlocks[EBI]); ValueToValueMapTy::iterator It = VMap.find(V); if (It != VMap.end()) V = It->second; PN.addIncoming(V, NewExit); @@ -1340,8 +1315,8 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, } // Rewrite the code to refer to itself. - for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) { - for (Instruction &I : *NewBlocks[i]) { + for (unsigned NBI = 0, NBE = NewBlocks.size(); NBI != NBE; ++NBI) { + for (Instruction &I : *NewBlocks[NBI]) { RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); if (auto *II = dyn_cast<IntrinsicInst>(&I)) @@ -1351,7 +1326,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, } // Rewrite the original preheader to select between versions of the loop. - BranchInst *OldBR = cast<BranchInst>(loopPreheader->getTerminator()); + BranchInst *OldBR = cast<BranchInst>(LoopPreheader->getTerminator()); assert(OldBR->isUnconditional() && OldBR->getSuccessor(0) == LoopBlocks[0] && "Preheader splitting did not work correctly!"); @@ -1364,9 +1339,8 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, } // Emit the new branch that selects between the two versions of this loop. - EmitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR, + emitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR, TI); - LPM->deleteSimpleAnalysisValue(OldBR, L); if (MSSAU) { // Update MemoryPhis in Exit blocks. MSSAU->updateExitBlocksForClonedLoop(ExitBlocks, VMap, *DT); @@ -1375,11 +1349,11 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, } // The OldBr was replaced by a new one and removed (but not erased) by - // EmitPreheaderBranchOnCondition. It is no longer needed, so delete it. + // emitPreheaderBranchOnCondition. It is no longer needed, so delete it. delete OldBR; LoopProcessWorklist.push_back(NewLoop); - redoLoop = true; + RedoLoop = true; // Keep a WeakTrackingVH holding onto LIC. If the first call to // RewriteLoopBody @@ -1390,22 +1364,23 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, // Now we rewrite the original code to know that the condition is true and the // new code to know that the condition is false. - RewriteLoopBodyWithConditionConstant(L, LIC, Val, false); + rewriteLoopBodyWithConditionConstant(L, LIC, Val, /*IsEqual=*/false); // It's possible that simplifying one loop could cause the other to be // changed to another value or a constant. 
If its a constant, don't simplify // it. if (!LoopProcessWorklist.empty() && LoopProcessWorklist.back() == NewLoop && LICHandle && !isa<Constant>(LICHandle)) - RewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val, true); + rewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val, + /*IsEqual=*/true); if (MSSA && VerifyMemorySSA) MSSA->verifyMemorySSA(); } /// Remove all instances of I from the worklist vector specified. -static void RemoveFromWorklist(Instruction *I, - std::vector<Instruction*> &Worklist) { +static void removeFromWorklist(Instruction *I, + std::vector<Instruction *> &Worklist) { Worklist.erase(std::remove(Worklist.begin(), Worklist.end(), I), Worklist.end()); @@ -1413,7 +1388,7 @@ static void RemoveFromWorklist(Instruction *I, /// When we find that I really equals V, remove I from the /// program, replacing all uses with V and update the worklist. -static void ReplaceUsesOfWith(Instruction *I, Value *V, +static void replaceUsesOfWith(Instruction *I, Value *V, std::vector<Instruction *> &Worklist, Loop *L, LPPassManager *LPM, MemorySSAUpdater *MSSAU) { LLVM_DEBUG(dbgs() << "Replace with '" << *V << "': " << *I << "\n"); @@ -1426,8 +1401,7 @@ static void ReplaceUsesOfWith(Instruction *I, Value *V, // Add users to the worklist which may be simplified now. for (User *U : I->users()) Worklist.push_back(cast<Instruction>(U)); - LPM->deleteSimpleAnalysisValue(I, L); - RemoveFromWorklist(I, Worklist); + removeFromWorklist(I, Worklist); I->replaceAllUsesWith(V); if (!I->mayHaveSideEffects()) { if (MSSAU) @@ -1440,7 +1414,7 @@ static void ReplaceUsesOfWith(Instruction *I, Value *V, /// We know either that the value LIC has the value specified by Val in the /// specified loop, or we know it does NOT have that value. /// Rewrite any uses of LIC or of properties correlated to it. -void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, +void LoopUnswitch::rewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, Constant *Val, bool IsEqual) { assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?"); @@ -1478,7 +1452,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, for (Instruction *UI : Worklist) UI->replaceUsesOfWith(LIC, Replacement); - SimplifyCode(Worklist, L); + simplifyCode(Worklist, L); return; } @@ -1492,7 +1466,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, // At this point, we know LIC is definitely not Val. Try to use some simple // logic to simplify the user w.r.t. to the context. - if (Value *Replacement = SimplifyInstructionWithNotEqual(UI, LIC, Val)) { + if (Value *Replacement = simplifyInstructionWithNotEqual(UI, LIC, Val)) { if (LI->replacementPreservesLCSSAForm(UI, Replacement)) { // This in-loop instruction has been simplified w.r.t. its context, // i.e. LIC != Val, make sure we propagate its replacement value to @@ -1506,7 +1480,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, } } - // This is a LIC user, push it into the worklist so that SimplifyCode can + // This is a LIC user, push it into the worklist so that simplifyCode can // attempt to simplify it. 
Worklist.push_back(UI); @@ -1568,7 +1542,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, DT->addNewBlock(Abort, NewSISucc); } - SimplifyCode(Worklist, L); + simplifyCode(Worklist, L); } /// Now that we have simplified some instructions in the loop, walk over it and @@ -1579,7 +1553,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, /// FIXME: When the loop optimizer is more mature, separate this out to a new /// pass. /// -void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { +void LoopUnswitch::simplifyCode(std::vector<Instruction *> &Worklist, Loop *L) { const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); while (!Worklist.empty()) { Instruction *I = Worklist.back(); @@ -1593,8 +1567,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) if (Instruction *Use = dyn_cast<Instruction>(I->getOperand(i))) Worklist.push_back(Use); - LPM->deleteSimpleAnalysisValue(I, L); - RemoveFromWorklist(I, Worklist); + removeFromWorklist(I, Worklist); if (MSSAU) MSSAU->removeMemoryAccess(I); I->eraseFromParent(); @@ -1607,7 +1580,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { // 'false'. TODO: update the domtree properly so we can pass it here. if (Value *V = SimplifyInstruction(I, DL)) if (LI->replacementPreservesLCSSAForm(I, V)) { - ReplaceUsesOfWith(I, V, Worklist, L, LPM, MSSAU.get()); + replaceUsesOfWith(I, V, Worklist, L, LPM, MSSAU.get()); continue; } @@ -1624,9 +1597,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { assert(SinglePred == Pred && "CFG broken"); // Make the LPM and Worklist updates specific to LoopUnswitch. - LPM->deleteSimpleAnalysisValue(BI, L); - RemoveFromWorklist(BI, Worklist); - LPM->deleteSimpleAnalysisValue(Succ, L); + removeFromWorklist(BI, Worklist); auto SuccIt = Succ->begin(); while (PHINode *PN = dyn_cast<PHINode>(SuccIt++)) { for (unsigned It = 0, E = PN->getNumOperands(); It != E; ++It) @@ -1634,8 +1605,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { Worklist.push_back(Use); for (User *U : PN->users()) Worklist.push_back(cast<Instruction>(U)); - LPM->deleteSimpleAnalysisValue(PN, L); - RemoveFromWorklist(PN, Worklist); + removeFromWorklist(PN, Worklist); ++NumSimplify; } // Merge the block and make the remaining analyses updates. @@ -1652,7 +1622,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { /// Simple simplifications we can do given the information that Cond is /// definitely not equal to Val. 
-Value *LoopUnswitch::SimplifyInstructionWithNotEqual(Instruction *Inst, +Value *LoopUnswitch::simplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant, Constant *Val) { // icmp eq cond, val -> false diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 7b9af527d444..06b684ef1e70 100644 --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -69,7 +69,6 @@ #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instruction.h" diff --git a/llvm/lib/Transforms/Scalar/LowerAtomic.cpp b/llvm/lib/Transforms/Scalar/LowerAtomic.cpp index ab7b85e89e7b..d1f67b355b19 100644 --- a/llvm/lib/Transforms/Scalar/LowerAtomic.cpp +++ b/llvm/lib/Transforms/Scalar/LowerAtomic.cpp @@ -117,18 +117,17 @@ static bool LowerStoreInst(StoreInst *SI) { static bool runOnBasicBlock(BasicBlock &BB) { bool Changed = false; - for (BasicBlock::iterator DI = BB.begin(), DE = BB.end(); DI != DE;) { - Instruction *Inst = &*DI++; - if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) + for (Instruction &Inst : make_early_inc_range(BB)) { + if (FenceInst *FI = dyn_cast<FenceInst>(&Inst)) Changed |= LowerFenceInst(FI); - else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(Inst)) + else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(&Inst)) Changed |= LowerAtomicCmpXchgInst(CXI); - else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(Inst)) + else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(&Inst)) Changed |= LowerAtomicRMWInst(RMWI); - else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + else if (LoadInst *LI = dyn_cast<LoadInst>(&Inst)) { if (LI->isAtomic()) LowerLoadInst(LI); - } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + } else if (StoreInst *SI = dyn_cast<StoreInst>(&Inst)) { if (SI->isAtomic()) LowerStoreInst(SI); } diff --git a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp index 21c6c32e8e02..fddf28c281fc 100644 --- a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp @@ -13,7 +13,9 @@ #include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -135,8 +137,12 @@ static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo *TLI) { PreservedAnalyses LowerConstantIntrinsicsPass::run(Function &F, FunctionAnalysisManager &AM) { - if (lowerConstantIntrinsics(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) - return PreservedAnalyses::none(); + if (lowerConstantIntrinsics(F, + AM.getCachedResult<TargetLibraryAnalysis>(F))) { + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; + } return PreservedAnalyses::all(); } @@ -145,7 +151,7 @@ namespace { /// Legacy pass for lowering is.constant intrinsics out of the IR. /// /// When this pass is run over a function it converts is.constant intrinsics -/// into 'true' or 'false'. This is completements the normal constand folding +/// into 'true' or 'false'. 
This complements the normal constant folding /// to 'true' as part of Instruction Simplify passes. class LowerConstantIntrinsics : public FunctionPass { public: @@ -159,6 +165,10 @@ public: const TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI(F) : nullptr; return lowerConstantIntrinsics(F, TLI); } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<GlobalsAAWrapperPass>(); + } }; } // namespace diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 53671c7bc3d1..0fe7dd9cfb39 100644 --- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -55,13 +55,35 @@ static cl::opt<uint32_t> UnlikelyBranchWeight( "unlikely-branch-weight", cl::Hidden, cl::init(1), cl::desc("Weight of the branch unlikely to be taken (default = 1)")); +static std::tuple<uint32_t, uint32_t> +getBranchWeight(Intrinsic::ID IntrinsicID, CallInst *CI, int BranchCount) { + if (IntrinsicID == Intrinsic::expect) { + // __builtin_expect + return std::make_tuple(LikelyBranchWeight.getValue(), + UnlikelyBranchWeight.getValue()); + } else { + // __builtin_expect_with_probability + assert(CI->getNumOperands() >= 3 && + "expect with probability must have 3 arguments"); + ConstantFP *Confidence = dyn_cast<ConstantFP>(CI->getArgOperand(2)); + double TrueProb = Confidence->getValueAPF().convertToDouble(); + assert((TrueProb >= 0.0 && TrueProb <= 1.0) && + "probability value must be in the range [0.0, 1.0]"); + double FalseProb = (1.0 - TrueProb) / (BranchCount - 1); + uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0); + uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0); + return std::make_tuple(LikelyBW, UnlikelyBW); + } +} + static bool handleSwitchExpect(SwitchInst &SI) { CallInst *CI = dyn_cast<CallInst>(SI.getCondition()); if (!CI) return false; Function *Fn = CI->getCalledFunction(); - if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect) + if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect && + Fn->getIntrinsicID() != Intrinsic::expect_with_probability)) return false; Value *ArgValue = CI->getArgOperand(0); @@ -71,15 +93,19 @@ static bool handleSwitchExpect(SwitchInst &SI) { SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue); unsigned n = SI.getNumCases(); // +1 for default case. - SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight); + uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; + std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = + getBranchWeight(Fn->getIntrinsicID(), CI, n + 1); + + SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal); uint64_t Index = (Case == *SI.case_default()) ? 
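// For __builtin_expect_with_probability, the getBranchWeight helper introduced
// above derives the weights from the probability argument instead of the fixed
// likely/unlikely-branch-weight options. A scalar model of that computation,
// kept in comment form here and using hypothetical values (TrueProb stands for
// the builtin's third argument, BranchCount for the number of successors):
//
//   double   TrueProb    = 0.9;
//   int      BranchCount = 2;
//   double   FalseProb   = (1.0 - TrueProb) / (BranchCount - 1);
//   uint32_t LikelyBW    = std::ceil(TrueProb  * (INT32_MAX - 1) + 1.0);
//   uint32_t UnlikelyBW  = std::ceil(FalseProb * (INT32_MAX - 1) + 1.0);
//
// With these values the resulting metadata weights are roughly 9:1, scaled to
// the INT32_MAX range rather than taken from the fixed option defaults.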
0 : Case.getCaseIndex() + 1; - Weights[Index] = LikelyBranchWeight; + Weights[Index] = LikelyBranchWeightVal; - SI.setMetadata( - LLVMContext::MD_misexpect, - MDBuilder(CI->getContext()) - .createMisExpect(Index, LikelyBranchWeight, UnlikelyBranchWeight)); + SI.setMetadata(LLVMContext::MD_misexpect, + MDBuilder(CI->getContext()) + .createMisExpect(Index, LikelyBranchWeightVal, + UnlikelyBranchWeightVal)); SI.setCondition(ArgValue); misexpect::checkFrontendInstrumentation(SI); @@ -223,15 +249,18 @@ static void handlePhiDef(CallInst *Expect) { return true; return false; }; + uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; + std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight( + Expect->getCalledFunction()->getIntrinsicID(), Expect, 2); if (IsOpndComingFromSuccessor(BI->getSuccessor(1))) - BI->setMetadata( - LLVMContext::MD_prof, - MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight)); + BI->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(LikelyBranchWeightVal, + UnlikelyBranchWeightVal)); else if (IsOpndComingFromSuccessor(BI->getSuccessor(0))) - BI->setMetadata( - LLVMContext::MD_prof, - MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight)); + BI->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(UnlikelyBranchWeightVal, + LikelyBranchWeightVal)); } } @@ -277,7 +306,8 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { } Function *Fn = CI->getCalledFunction(); - if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect) + if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect && + Fn->getIntrinsicID() != Intrinsic::expect_with_probability)) return false; Value *ArgValue = CI->getArgOperand(0); @@ -289,13 +319,21 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { MDNode *Node; MDNode *ExpNode; + uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; + std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = + getBranchWeight(Fn->getIntrinsicID(), CI, 2); + if ((ExpectedValue->getZExtValue() == ValueComparedTo) == (Predicate == CmpInst::ICMP_EQ)) { - Node = MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight); - ExpNode = MDB.createMisExpect(0, LikelyBranchWeight, UnlikelyBranchWeight); + Node = + MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal); + ExpNode = + MDB.createMisExpect(0, LikelyBranchWeightVal, UnlikelyBranchWeightVal); } else { - Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight); - ExpNode = MDB.createMisExpect(1, LikelyBranchWeight, UnlikelyBranchWeight); + Node = + MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal); + ExpNode = + MDB.createMisExpect(1, LikelyBranchWeightVal, UnlikelyBranchWeightVal); } BSI.setMetadata(LLVMContext::MD_misexpect, ExpNode); @@ -347,7 +385,8 @@ static bool lowerExpectIntrinsic(Function &F) { } Function *Fn = CI->getCalledFunction(); - if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) { + if (Fn && (Fn->getIntrinsicID() == Intrinsic::expect || + Fn->getIntrinsicID() == Intrinsic::expect_with_probability)) { // Before erasing the llvm.expect, walk backward to find // phi that define llvm.expect's first arg, and // infer branch probability: diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 0ff6ee8bcfcc..90314b17b5e2 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -9,8 +9,11 @@ // Lower 
matrix intrinsics to vector operations. // // TODO: -// * Implement multiply & add fusion -// * Add remark, summarizing the available matrix optimization opportunities. +// * Improve fusion: +// * Support more cases, e.g. multiply-add, multiply-sub, operands/results +// transposed. +// * Improve cost-modeling, e.g. choose different number of rows/columns +// columns for tiles, consider cost of copies on alias. // //===----------------------------------------------------------------------===// @@ -18,10 +21,15 @@ #include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -29,30 +37,69 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/Support/Alignment.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "lower-matrix-intrinsics" -static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape", - cl::init(true)); - +static cl::opt<bool> EnableShapePropagation( + "matrix-propagate-shape", cl::init(true), cl::Hidden, + cl::desc("Enable/disable shape propagation from matrix intrinsics to other " + "instructions.")); + +static cl::opt<bool> + FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, + cl::desc("Enable/disable fusing matrix instructions.")); +// TODO: Allow and use non-square tiles. +static cl::opt<unsigned> TileSize( + "fuse-matrix-tile-size", cl::init(4), cl::Hidden, + cl::desc( + "Tile size for matrix instruction fusion using square-shaped tiles.")); +static cl::opt<bool> ForceFusion( + "force-fuse-matrix", cl::init(false), cl::Hidden, + cl::desc("Force matrix instruction fusion even if not profitable.")); static cl::opt<bool> AllowContractEnabled( "matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error.")); +enum class MatrixLayoutTy { ColumnMajor, RowMajor }; + +static cl::opt<MatrixLayoutTy> MatrixLayout( + "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), + cl::desc("Sets the default matrix layout"), + cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", + "Use column-major layout"), + clEnumValN(MatrixLayoutTy::RowMajor, "row-major", + "Use row-major layout"))); + +/// Helper function to either return Scope, if it is a subprogram or the +/// attached subprogram for a local scope. +static DISubprogram *getSubprogram(DIScope *Scope) { + if (auto *Subprogram = dyn_cast<DISubprogram>(Scope)) + return Subprogram; + return cast<DILocalScope>(Scope)->getSubprogram(); +} + namespace { -// Given an element poitner \p BasePtr to the start of a (sub) matrix, compute -// the start address of column \p Col with type (\p EltType x \p NumRows) -// assuming \p Stride elements between start two consecutive columns. -// \p Stride must be >= \p NumRows. 
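// The -matrix-default-layout option added above selects how the flat vector
// that embeds a matrix is interpreted. A minimal standalone model of the
// element indexing implied by the two layouts (an illustration, not code from
// this patch):
#include <cstddef>

enum class Layout { ColumnMajor, RowMajor };

// Index of element (Row, Col) of an R x C matrix inside the flat embedding.
static size_t flatIndex(Layout L, size_t Row, size_t Col, size_t R, size_t C) {
  return L == Layout::ColumnMajor ? Col * R + Row  // columns are contiguous
                                  : Row * C + Col; // rows are contiguous
}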
+// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute +// the start address of vector \p VecIdx with type (\p EltType x \p NumElements) +// assuming \p Stride elements between start two consecutive vectors. +// \p Stride must be >= \p NumElements. +// For column-major matrixes, the function computes the address of a column +// vectors and \p NumElements must be set to the number of elements in a column +// (= number of rows of the matrix). For row-major matrixes, the function +// computes the address of a row vector and \p NumElements must be set to the +// number of elements in a column (= number of columns of the matrix). // -// Consider a 4x4 matrix like below +// Consider a 4x4 matrix in column-mjaor layout like below // // 0 1 2 3 // 0 v_0_0 v_0_1 v_0_2 v_0_3 @@ -62,14 +109,14 @@ namespace { // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1, // we need a pointer to the first element of the submatrix as base pointer. -// Then we can use computeColumnAddr to compute the addresses for the columns +// Then we can use computeVectorAddr to compute the addresses for the columns // of the sub-matrix. // -// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) +// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) // -> just returns Base -// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) +// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) // -> returns Base + (1 * 4) -// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) +// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) // -> returns Base + (2 * 4) // // The graphic below illustrates the number of elements in a column (marked @@ -82,30 +129,30 @@ namespace { // v_2_0 |v_2_1 |v_2_2 |v_2_3 // v_3_0 {v_3_1 {v_3_2 v_3_3 // -Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride, - unsigned NumRows, Type *EltType, +Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, + unsigned NumElements, Type *EltType, IRBuilder<> &Builder) { assert((!isa<ConstantInt>(Stride) || - cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) && - "Stride must be >= the number of rows."); + cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && + "Stride must be >= the number of elements in the result vector."); unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); - // Compute the start of the column with index Col as Col * Stride. - Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start"); + // Compute the start of the vector with index VecIdx as VecIdx * Stride. + Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); - // Get pointer to the start of the selected column. Skip GEP creation, - // if we select column 0. - if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero()) - ColumnStart = BasePtr; + // Get pointer to the start of the selected vector. Skip GEP creation, + // if we select vector 0. + if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) + VecStart = BasePtr; else - ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep"); + VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); - // Cast elementwise column start pointer to a pointer to a column - // (EltType x NumRows)*. 
- Type *ColumnType = VectorType::get(EltType, NumRows); - Type *ColumnPtrType = PointerType::get(ColumnType, AS); - return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast"); + // Cast elementwise vector start pointer to a pointer to a vector + // (EltType x NumElements)*. + auto *VecType = FixedVectorType::get(EltType, NumElements); + Type *VecPtrType = PointerType::get(VecType, AS); + return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); } /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. @@ -113,15 +160,16 @@ Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride, /// Currently, the lowering for each matrix intrinsic is done as follows: /// 1. Propagate the shape information from intrinsics to connected /// instructions. -/// 2. Lower instructions with shape information. +/// 2. Lower instructions with shape information (assuming column-major layout). +/// The lowering works similarly using row-major layout. /// 2.1. Get column vectors for each argument. If we already lowered the /// definition of an argument, use the produced column vectors directly. /// If not, split the operand vector containing an embedded matrix into /// a set of column vectors, -/// 2.2. Lower the instruction in terms of columnwise operations, which yields -/// a set of column vectors containing result matrix. Note that we lower -/// all instructions that have shape information. Besides the intrinsics, -/// this includes stores for example. +/// 2.2. Lower the instruction in terms of column major operations, which +/// yields a set of column vectors containing result matrix. Note that we +/// lower all instructions that have shape information. Besides the +/// intrinsics, this includes stores for example. /// 2.3. Update uses of the lowered instruction. If we have shape information /// for a user, there is nothing to do, as we will look up the result /// column matrix when lowering the user. For other uses, we embed the @@ -134,42 +182,157 @@ class LowerMatrixIntrinsics { Function &Func; const DataLayout &DL; const TargetTransformInfo &TTI; + AliasAnalysis &AA; + DominatorTree &DT; + LoopInfo &LI; + OptimizationRemarkEmitter &ORE; + + /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. + struct OpInfoTy { + /// Number of stores emitted to generate this matrix. + unsigned NumStores = 0; + /// Number of loads emitted to generate this matrix. + unsigned NumLoads = 0; + /// Number of compute operations emitted to generate this matrix. + unsigned NumComputeOps = 0; + + OpInfoTy &operator+=(const OpInfoTy &RHS) { + NumStores += RHS.NumStores; + NumLoads += RHS.NumLoads; + NumComputeOps += RHS.NumComputeOps; + return *this; + } + }; + + /// Wrapper class representing a matrix as a set of vectors, either in row or + /// column major layout. All vectors must have the same vector type. + class MatrixTy { + SmallVector<Value *, 16> Vectors; + + OpInfoTy OpInfo; - /// Wrapper class representing a matrix as a set of column vectors. - /// All column vectors must have the same vector type. 
- class ColumnMatrixTy { - SmallVector<Value *, 16> Columns; + bool IsColumnMajor = true; public: - ColumnMatrixTy() : Columns() {} - ColumnMatrixTy(ArrayRef<Value *> Cols) - : Columns(Cols.begin(), Cols.end()) {} + MatrixTy() + : Vectors(), + IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} + MatrixTy(ArrayRef<Value *> Vectors) + : Vectors(Vectors.begin(), Vectors.end()), + IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} + MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) + : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { + + unsigned D = isColumnMajor() ? NumColumns : NumRows; + for (unsigned J = 0; J < D; ++J) + addVector(UndefValue::get(FixedVectorType::get( + EltTy, isColumnMajor() ? NumRows : NumColumns))); + } + + Value *getVector(unsigned i) const { return Vectors[i]; } + Value *getColumn(unsigned i) const { + assert(isColumnMajor() && "only supported for column-major matrixes"); + return Vectors[i]; + } + Value *getRow(unsigned i) const { + assert(!isColumnMajor() && "only supported for row-major matrixes"); + return Vectors[i]; + } - Value *getColumn(unsigned i) const { return Columns[i]; } + void setVector(unsigned i, Value *V) { Vectors[i] = V; } - void setColumn(unsigned i, Value *V) { Columns[i] = V; } + Type *getElementType() { return getVectorTy()->getElementType(); } - size_t getNumColumns() const { return Columns.size(); } - size_t getNumRows() const { - assert(Columns.size() > 0 && "Cannot call getNumRows without columns"); - return cast<VectorType>(Columns[0]->getType())->getNumElements(); + unsigned getNumVectors() const { + if (isColumnMajor()) + return getNumColumns(); + return getNumRows(); } - const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; } + unsigned getNumColumns() const { + if (isColumnMajor()) + return Vectors.size(); + else { + assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); + return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); + } + } + unsigned getNumRows() const { + if (isColumnMajor()) { + assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); + return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); + } else + return Vectors.size(); + } - SmallVectorImpl<Value *> &getColumnVectors() { return Columns; } + void addVector(Value *V) { Vectors.push_back(V); } + VectorType *getColumnTy() { + assert(isColumnMajor() && "only supported for column-major matrixes"); + return getVectorTy(); + } - void addColumn(Value *V) { Columns.push_back(V); } + VectorType *getVectorTy() { + return cast<VectorType>(Vectors[0]->getType()); + } iterator_range<SmallVector<Value *, 8>::iterator> columns() { - return make_range(Columns.begin(), Columns.end()); + assert(isColumnMajor() && + "columns() only supported for column-major matrixes"); + return make_range(Vectors.begin(), Vectors.end()); } - /// Embed the columns of the matrix into a flat vector by concatenating + iterator_range<SmallVector<Value *, 8>::iterator> vectors() { + return make_range(Vectors.begin(), Vectors.end()); + } + + /// Embed the vectors of the matrix into a flat vector by concatenating /// them. Value *embedInVector(IRBuilder<> &Builder) const { - return Columns.size() == 1 ? Columns[0] - : concatenateVectors(Builder, Columns); + return Vectors.size() == 1 ? 
Vectors[0] + : concatenateVectors(Builder, Vectors); + } + + MatrixTy &addNumLoads(unsigned N) { + OpInfo.NumLoads += N; + return *this; + } + + void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } + + MatrixTy &addNumStores(unsigned N) { + OpInfo.NumStores += N; + return *this; + } + + MatrixTy &addNumComputeOps(unsigned N) { + OpInfo.NumComputeOps += N; + return *this; + } + + unsigned getNumStores() const { return OpInfo.NumStores; } + unsigned getNumLoads() const { return OpInfo.NumLoads; } + unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } + + const OpInfoTy &getOpInfo() const { return OpInfo; } + + bool isColumnMajor() const { return IsColumnMajor; } + + unsigned getStride() const { + if (isColumnMajor()) + return getNumRows(); + return getNumColumns(); + } + + /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the + /// matrix is column-major, the result vector is extracted from a column + /// vector, otherwise from a row vector. + Value *extractVector(unsigned I, unsigned J, unsigned NumElts, + IRBuilder<> &Builder) const { + Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); + Value *Undef = UndefValue::get(Vec->getType()); + return Builder.CreateShuffleVector( + Vec, Undef, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), + "block"); } }; @@ -177,12 +340,15 @@ class LowerMatrixIntrinsics { unsigned NumRows; unsigned NumColumns; + bool IsColumnMajor; + ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) - : NumRows(NumRows), NumColumns(NumColumns) {} + : NumRows(NumRows), NumColumns(NumColumns), + IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} ShapeInfo(Value *NumRows, Value *NumColumns) - : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()), - NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {} + : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), + cast<ConstantInt>(NumColumns)->getZExtValue()) {} bool operator==(const ShapeInfo &other) { return NumRows == other.NumRows && NumColumns == other.NumColumns; @@ -195,12 +361,24 @@ class LowerMatrixIntrinsics { assert(NumRows == 0 || NumColumns != 0); return NumRows != 0; } + + unsigned getStride() const { + if (IsColumnMajor) + return NumRows; + return NumColumns; + } + + unsigned getNumVectors() const { + if (IsColumnMajor) + return NumColumns; + return NumRows; + } }; /// Maps instructions to their shape information. The shape information /// describes the shape to be used while lowering. This matches the shape of /// the result value of the instruction, with the only exceptions being store - /// instructions and the matrix_columnwise_store intrinsics. For those, the + /// instructions and the matrix_column_major_store intrinsics. For those, the /// shape information indicates that those instructions should be lowered /// using shape information as well. DenseMap<Value *, ShapeInfo> ShapeMap; @@ -211,31 +389,49 @@ class LowerMatrixIntrinsics { SmallVector<Instruction *, 16> ToRemove; /// Map from instructions to their produced column matrix. 
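// MatrixTy::extractVector above pulls a contiguous block of NumElts elements
// out of one row or column vector via a shuffle with a sequential mask. A
// scalar model of that extraction, assuming plain std::vector storage
// (illustrative only):
#include <cstddef>
#include <vector>

static std::vector<double> extractBlock(const std::vector<double> &Vec,
                                        size_t Start, size_t NumElts) {
  // Equivalent of shuffling with the mask {Start, Start+1, ..., Start+NumElts-1}.
  return std::vector<double>(Vec.begin() + Start, Vec.begin() + Start + NumElts);
}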
- DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix; + MapVector<Value *, MatrixTy> Inst2ColumnMatrix; public: - LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI) - : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {} + LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, + AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI, + OptimizationRemarkEmitter &ORE) + : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), + LI(LI), ORE(ORE) {} + + unsigned getNumOps(Type *VT) { + assert(isa<VectorType>(VT) && "Expected vector type"); + return getNumOps(VT->getScalarType(), + cast<FixedVectorType>(VT)->getNumElements()); + } - /// Return the set of column vectors that a matrix value is lowered to. + // + /// Return the estimated number of vector ops required for an operation on + /// \p VT * N. + unsigned getNumOps(Type *ST, unsigned N) { + return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / + double(TTI.getRegisterBitWidth(true))); + } + + /// Return the set of vectors that a matrix value is lowered to. /// - /// If we lowered \p MatrixVal, just return the cache result column matrix. - /// Otherwie split the flat vector \p MatrixVal containing a matrix with - /// shape \p SI into column vectors. - ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, - IRBuilder<> Builder) { + /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise + /// split the flat vector \p MatrixVal containing a matrix with shape \p SI + /// into vectors. + MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, + IRBuilder<> &Builder) { VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType()); assert(VType && "MatrixVal must be a vector type"); - assert(VType->getNumElements() == SI.NumRows * SI.NumColumns && + assert(cast<FixedVectorType>(VType)->getNumElements() == + SI.NumRows * SI.NumColumns && "The vector size must match the number of matrix elements"); // Check if we lowered MatrixVal using shape information. In that case, - // return the existing column matrix, if it matches the requested shape + // return the existing matrix, if it matches the requested shape // information. If there is a mis-match, embed the result in a flat // vector and split it later. auto Found = Inst2ColumnMatrix.find(MatrixVal); if (Found != Inst2ColumnMatrix.end()) { - ColumnMatrixTy &M = Found->second; + MatrixTy &M = Found->second; // Return the found matrix, if its shape matches the requested shape // information if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) @@ -247,10 +443,12 @@ public: // Otherwise split MatrixVal. 
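// When no cached lowering matches, getMatrix falls back to splitting the flat
// embedding into getNumVectors() chunks of getStride() elements each (columns
// for column-major, rows for row-major); the code below does this with
// shuffles. A scalar model of the split, assuming the flat size is a multiple
// of the stride (illustrative only):
#include <cstddef>
#include <vector>

static std::vector<std::vector<double>>
splitEmbedding(const std::vector<double> &Flat, size_t Stride) {
  std::vector<std::vector<double>> Vectors;
  for (size_t Start = 0; Start < Flat.size(); Start += Stride)
    Vectors.emplace_back(Flat.begin() + Start, Flat.begin() + Start + Stride);
  return Vectors;
}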
SmallVector<Value *, 16> SplitVecs; Value *Undef = UndefValue::get(VType); - for (unsigned MaskStart = 0; MaskStart < VType->getNumElements(); - MaskStart += SI.NumRows) { - Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0); - Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split"); + for (unsigned MaskStart = 0; + MaskStart < cast<FixedVectorType>(VType)->getNumElements(); + MaskStart += SI.getStride()) { + Value *V = Builder.CreateShuffleVector( + MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0), + "split"); SplitVecs.push_back(V); } @@ -308,8 +506,8 @@ public: switch (II->getIntrinsicID()) { case Intrinsic::matrix_multiply: case Intrinsic::matrix_transpose: - case Intrinsic::matrix_columnwise_load: - case Intrinsic::matrix_columnwise_store: + case Intrinsic::matrix_column_major_load: + case Intrinsic::matrix_column_major_store: return true; default: return false; @@ -348,13 +546,13 @@ public: m_Value(MatrixA), m_Value(M), m_Value(N)))) { // Flip dimensions. Propagate = setShapeInfo(Inst, {N, M}); - } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>( + } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( m_Value(MatrixA), m_Value(), m_Value(), - m_Value(M), m_Value(N)))) { + m_Value(), m_Value(M), m_Value(N)))) { Propagate = setShapeInfo(Inst, {N, M}); - } else if (match(Inst, - m_Intrinsic<Intrinsic::matrix_columnwise_load>( - m_Value(), m_Value(), m_Value(M), m_Value(N)))) { + } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( + m_Value(), m_Value(), m_Value(), m_Value(M), + m_Value(N)))) { Propagate = setShapeInfo(Inst, {M, N}); } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { auto OpShape = ShapeMap.find(MatrixA); @@ -426,14 +624,14 @@ public: // Flip dimensions. if (setShapeInfo(MatrixA, {M, N})) pushInstruction(MatrixA, WorkList); - } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>( - m_Value(MatrixA), m_Value(), m_Value(), + } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( + m_Value(MatrixA), m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N)))) { if (setShapeInfo(MatrixA, {M, N})) { pushInstruction(MatrixA, WorkList); } } else if (isa<LoadInst>(V) || - match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) { + match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) { // Nothing to do, no matrix input. } else if (isa<StoreInst>(V)) { // Nothing to do. 
We forward-propagated to this so we would just @@ -472,8 +670,8 @@ public: switch (II->getIntrinsicID()) { case Intrinsic::matrix_multiply: case Intrinsic::matrix_transpose: - case Intrinsic::matrix_columnwise_load: - case Intrinsic::matrix_columnwise_store: + case Intrinsic::matrix_column_major_load: + case Intrinsic::matrix_column_major_store: WorkList.push_back(&Inst); break; default: @@ -487,45 +685,57 @@ public: } } - ReversePostOrderTraversal<Function *> RPOT(&Func); bool Changed = false; - for (auto *BB : RPOT) { - for (Instruction &Inst : make_early_inc_range(*BB)) { - IRBuilder<> Builder(&Inst); - - if (CallInst *CInst = dyn_cast<CallInst>(&Inst)) - Changed |= VisitCallInst(CInst); - - Value *Op1; - Value *Op2; - if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst)) - Changed |= VisitBinaryOperator(BinOp); - if (match(&Inst, m_Load(m_Value(Op1)))) - Changed |= VisitLoad(&Inst, Op1, Builder); - else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2)))) - Changed |= VisitStore(&Inst, Op1, Op2, Builder); + SmallVector<CallInst *, 16> MaybeFusableInsts; + SmallVector<Instruction *, 16> MatrixInsts; + + // First, collect all instructions with shape information and candidates for + // fusion (currently only matrix multiplies). + ReversePostOrderTraversal<Function *> RPOT(&Func); + for (auto *BB : RPOT) + for (Instruction &I : *BB) { + if (ShapeMap.find(&I) == ShapeMap.end()) + continue; + if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) + MaybeFusableInsts.push_back(cast<CallInst>(&I)); + MatrixInsts.push_back(&I); } + + // Second, try to fuse candidates. + SmallPtrSet<Instruction *, 16> FusedInsts; + for (CallInst *CI : MaybeFusableInsts) + LowerMatrixMultiplyFused(CI, FusedInsts); + Changed = !FusedInsts.empty(); + + // Third, lower remaining instructions with shape information. + for (Instruction *Inst : MatrixInsts) { + if (FusedInsts.count(Inst)) + continue; + + IRBuilder<> Builder(Inst); + + if (CallInst *CInst = dyn_cast<CallInst>(Inst)) + Changed |= VisitCallInst(CInst); + + Value *Op1; + Value *Op2; + if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) + Changed |= VisitBinaryOperator(BinOp); + if (match(Inst, m_Load(m_Value(Op1)))) + Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); + else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) + Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); } + RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func); + RemarkGen.emitRemarks(); + for (Instruction *Inst : reverse(ToRemove)) Inst->eraseFromParent(); return Changed; } - LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType, - IRBuilder<> Builder) { - unsigned Align = DL.getABITypeAlignment(EltType); - return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load"); - } - - StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr, - Type *EltType, IRBuilder<> Builder) { - unsigned Align = DL.getABITypeAlignment(EltType); - return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align); - } - - /// Turns \p BasePtr into an elementwise pointer to \p EltType. 
Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); @@ -545,11 +755,11 @@ public: case Intrinsic::matrix_transpose: LowerTranspose(Inst); break; - case Intrinsic::matrix_columnwise_load: - LowerColumnwiseLoad(Inst); + case Intrinsic::matrix_column_major_load: + LowerColumnMajorLoad(Inst); break; - case Intrinsic::matrix_columnwise_store: - LowerColumnwiseStore(Inst); + case Intrinsic::matrix_column_major_store: + LowerColumnMajorStore(Inst); break; default: return false; @@ -557,108 +767,200 @@ public: return true; } - void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, - ShapeInfo Shape) { - IRBuilder<> Builder(Inst); - auto VType = cast<VectorType>(Inst->getType()); + /// Compute the alignment for a column/row \p Idx with \p Stride between them. + /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a + /// ConstantInt, reduce the initial alignment based on the byte offset. For + /// non-ConstantInt strides, return the common alignment of the initial + /// alignment and the element size in bytes. + Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, + MaybeAlign A) const { + Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); + if (Idx == 0) + return InitialAlign; + + TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); + if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { + uint64_t StrideInBytes = + ConstStride->getZExtValue() * ElementSizeInBits / 8; + return commonAlignment(InitialAlign, Idx * StrideInBytes); + } + return commonAlignment(InitialAlign, ElementSizeInBits / 8); + } + + /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between + /// vectors. + MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, + bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { + auto VType = cast<VectorType>(Ty); Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); - ColumnMatrixTy Result; - // Distance between start of one column and the start of the next - for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) { - Value *GEP = - computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows, - VType->getElementType(), Builder); - Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder); - Result.addColumn(Column); + MatrixTy Result; + for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { + Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, + Shape.getStride(), VType->getElementType(), + Builder); + Value *Vector = Builder.CreateAlignedLoad( + GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign), + IsVolatile, "col.load"); + + Result.addVector(Vector); } + return Result.addNumLoads(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()); + } - finalizeLowering(Inst, Result, Builder); + /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, + /// starting at \p MatrixPtr[I][J]. 
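// The tile variants of loadMatrix/storeMatrix below address the sub-matrix at
// (I, J) by first computing a flat element offset. For the column-major case,
// a scalar model of that address computation is (illustrative only):
#include <cstddef>

static const double *tileStart(const double *MatrixPtr, size_t I, size_t J,
                               size_t Stride /* rows of the outer matrix */) {
  // Matches "Offset = J * Stride + I" in the code below: skip J full columns,
  // then I elements into the column.
  return MatrixPtr + J * Stride + I;
}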
+ MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, + ShapeInfo MatrixShape, Value *I, Value *J, + ShapeInfo ResultShape, Type *EltTy, + IRBuilder<> &Builder) { + + Value *Offset = Builder.CreateAdd( + Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); + + unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); + Value *EltPtr = + Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); + Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); + auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * + ResultShape.NumColumns); + Type *TilePtrTy = PointerType::get(TileTy, AS); + Value *TilePtr = + Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); + + return loadMatrix(TileTy, TilePtr, Align, + Builder.getInt64(MatrixShape.getStride()), IsVolatile, + ResultShape, Builder); + } + + /// Lower a load instruction with shape information. + void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, + bool IsVolatile, ShapeInfo Shape) { + IRBuilder<> Builder(Inst); + finalizeLowering(Inst, + loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, + Shape, Builder), + Builder); } - /// Lowers llvm.matrix.columnwise.load. + /// Lowers llvm.matrix.column.major.load. /// /// The intrinsic loads a matrix from memory using a stride between columns. - void LowerColumnwiseLoad(CallInst *Inst) { + void LowerColumnMajorLoad(CallInst *Inst) { + assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && + "Intrinsic only supports column-major layout!"); Value *Ptr = Inst->getArgOperand(0); Value *Stride = Inst->getArgOperand(1); - LowerLoad(Inst, Ptr, Stride, - {Inst->getArgOperand(2), Inst->getArgOperand(3)}); + LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, + cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), + {Inst->getArgOperand(3), Inst->getArgOperand(4)}); } - void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride, - ShapeInfo Shape) { - IRBuilder<> Builder(Inst); - auto VType = cast<VectorType>(Matrix->getType()); + /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p + /// MatrixPtr[I][J]. + void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, + MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, + Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { + Value *Offset = Builder.CreateAdd( + Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); + + unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); + Value *EltPtr = + Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); + Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); + auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * + StoreVal.getNumColumns()); + Type *TilePtrTy = PointerType::get(TileTy, AS); + Value *TilePtr = + Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); + + storeMatrix(TileTy, StoreVal, TilePtr, MAlign, + Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); + } + + /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between + /// vectors. 
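// Both the strided load above and the strided store below derive the per-vector
// alignment via getAlignForIndex: the alignment of vector Idx is the largest
// power of two dividing both the initial alignment and the byte offset of that
// vector. A scalar model for the constant-stride case, assuming a power-of-two
// InitialAlign (illustrative only):
#include <cstdint>
#include <numeric>

static uint64_t alignForIndex(uint64_t InitialAlign, uint64_t Idx,
                              uint64_t StrideInBytes) {
  if (Idx == 0)
    return InitialAlign;
  // For power-of-two InitialAlign, commonAlignment(A, Offset) reduces to
  // gcd(A, Offset).
  return std::gcd(InitialAlign, Idx * StrideInBytes);
}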
+ MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, + MaybeAlign MAlign, Value *Stride, bool IsVolatile, + IRBuilder<> &Builder) { + auto VType = cast<VectorType>(Ty); Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); - auto LM = getMatrix(Matrix, Shape, Builder); - for (auto C : enumerate(LM.columns())) { - Value *GEP = - computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride, - Shape.NumRows, VType->getElementType(), Builder); - createColumnStore(C.value(), GEP, VType->getElementType(), Builder); + for (auto Vec : enumerate(StoreVal.vectors())) { + Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()), + Stride, StoreVal.getStride(), + VType->getElementType(), Builder); + Builder.CreateAlignedStore(Vec.value(), GEP, + getAlignForIndex(Vec.index(), Stride, + VType->getElementType(), + MAlign), + IsVolatile); } + return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * + StoreVal.getNumVectors()); + } - ToRemove.push_back(Inst); + /// Lower a store instruction with shape information. + void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, + Value *Stride, bool IsVolatile, ShapeInfo Shape) { + IRBuilder<> Builder(Inst); + auto StoreVal = getMatrix(Matrix, Shape, Builder); + finalizeLowering(Inst, + storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, + IsVolatile, Builder), + Builder); } - /// Lowers llvm.matrix.columnwise.store. + /// Lowers llvm.matrix.column.major.store. /// /// The intrinsic store a matrix back memory using a stride between columns. - void LowerColumnwiseStore(CallInst *Inst) { + void LowerColumnMajorStore(CallInst *Inst) { + assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && + "Intrinsic only supports column-major layout!"); Value *Matrix = Inst->getArgOperand(0); Value *Ptr = Inst->getArgOperand(1); Value *Stride = Inst->getArgOperand(2); - LowerStore(Inst, Matrix, Ptr, Stride, - {Inst->getArgOperand(3), Inst->getArgOperand(4)}); - } - - /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from - /// the matrix \p LM represented as a vector of column vectors. 
- Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J, - unsigned NumElts, IRBuilder<> Builder) { - Value *Col = LM.getColumn(J); - Value *Undef = UndefValue::get(Col->getType()); - Constant *Mask = createSequentialMask(Builder, I, NumElts, 0); - return Builder.CreateShuffleVector(Col, Undef, Mask, "block"); + LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, + cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), + {Inst->getArgOperand(4), Inst->getArgOperand(5)}); } // Set elements I..I+NumElts-1 to Block Value *insertVector(Value *Col, unsigned I, Value *Block, - IRBuilder<> Builder) { + IRBuilder<> &Builder) { // First, bring Block to the same size as Col unsigned BlockNumElts = - cast<VectorType>(Block->getType())->getNumElements(); - unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements(); + cast<FixedVectorType>(Block->getType())->getNumElements(); + unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); assert(NumElts >= BlockNumElts && "Too few elements for current block"); - Value *ExtendMask = - createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts); Value *Undef = UndefValue::get(Block->getType()); - Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask); + Block = Builder.CreateShuffleVector( + Block, Undef, + createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, // 8, 4, 5, 6 - SmallVector<Constant *, 16> Mask; + SmallVector<int, 16> Mask; unsigned i; for (i = 0; i < I; i++) - Mask.push_back(Builder.getInt32(i)); + Mask.push_back(i); - unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements(); + unsigned VecNumElts = + cast<FixedVectorType>(Col->getType())->getNumElements(); for (; i < I + BlockNumElts; i++) - Mask.push_back(Builder.getInt32(i - I + VecNumElts)); + Mask.push_back(i - I + VecNumElts); for (; i < VecNumElts; i++) - Mask.push_back(Builder.getInt32(i)); - - Value *MaskVal = ConstantVector::get(Mask); + Mask.push_back(i); - return Builder.CreateShuffleVector(Col, Block, MaskVal); + return Builder.CreateShuffleVector(Col, Block, Mask); } Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, - IRBuilder<> &Builder, bool AllowContraction) { - + IRBuilder<> &Builder, bool AllowContraction, + unsigned &NumComputeOps) { + NumComputeOps += getNumOps(A->getType()); if (!Sum) return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); @@ -666,14 +968,16 @@ public: if (AllowContraction) { // Use fmuladd for floating point operations and let the backend decide // if that's profitable. - Value *FMulAdd = Intrinsic::getDeclaration( + Function *FMulAdd = Intrinsic::getDeclaration( Func.getParent(), Intrinsic::fmuladd, A->getType()); return Builder.CreateCall(FMulAdd, {A, B, Sum}); } + NumComputeOps += getNumOps(A->getType()); Value *Mul = Builder.CreateFMul(A, B); return Builder.CreateFAdd(Sum, Mul); } + NumComputeOps += getNumOps(A->getType()); Value *Mul = Builder.CreateMul(A, B); return Builder.CreateAdd(Sum, Mul); } @@ -683,7 +987,7 @@ public: /// cached value when they are lowered. For other users, \p Matrix is /// flattened and the uses are updated to use it. Also marks \p Inst for /// deletion. 
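// The createMulAdd helper above emits either a separate multiply and add, or a
// single llvm.fmuladd call when contraction is allowed. The scalar
// floating-point equivalent of that choice (illustrative only):
#include <cmath>

static double mulAdd(double Sum, double A, double B, bool AllowContraction) {
  if (AllowContraction)
    return std::fma(A, B, Sum); // single rounding, like llvm.fmuladd
  return Sum + A * B;           // two roundings: fmul followed by fadd
}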
- void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix, + void finalizeLowering(Instruction *Inst, MatrixTy Matrix, IRBuilder<> &Builder) { Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); @@ -699,6 +1003,294 @@ public: } } + /// Compute \p Result += \p A * \p B for input matrices with left-associating + /// addition. + void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, + const MatrixTy &B, bool AllowContraction, + IRBuilder<> &Builder, bool isTiled) { + const unsigned VF = std::max<unsigned>( + TTI.getRegisterBitWidth(true) / + Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), + 1U); + unsigned R = Result.getNumRows(); + unsigned C = Result.getNumColumns(); + unsigned M = A.getNumColumns(); + + bool IsFP = Result.getElementType()->isFloatingPointTy(); + assert(A.isColumnMajor() == B.isColumnMajor() && + Result.isColumnMajor() == A.isColumnMajor() && + "operands must agree on matrix layout"); + unsigned NumComputeOps = 0; + if (A.isColumnMajor()) { + // Multiply columns from the first operand with scalars from the second + // operand. Then move along the K axes and accumulate the columns. With + // this the adds can be vectorized without reassociation. + for (unsigned J = 0; J < C; ++J) { + unsigned BlockSize = VF; + // If Result is zero, we don't need to accumulate in the K==0 iteration. + bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); + + for (unsigned I = 0; I < R; I += BlockSize) { + // Gradually lower the vectorization factor to cover the remainder. + while (I + BlockSize > R) + BlockSize /= 2; + + Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder) + : nullptr; + for (unsigned K = 0; K < M; ++K) { + Value *L = A.extractVector(I, K, BlockSize, Builder); + Value *RH = Builder.CreateExtractElement(B.getColumn(J), K); + Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); + Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, + Result.getElementType()->isFloatingPointTy(), + Builder, AllowContraction, NumComputeOps); + } + Result.setVector(J, + insertVector(Result.getVector(J), I, Sum, Builder)); + } + } + } else { + // Multiply rows from the second operand with scalars from the first + // operand. Then move along the K axes and accumulate the rows. With this + // the adds can be vectorized without reassociation. + for (unsigned I = 0; I < R; ++I) { + unsigned BlockSize = VF; + bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); + for (unsigned J = 0; J < C; J += BlockSize) { + // Gradually lower the vectorization factor to cover the remainder. + while (J + BlockSize > C) + BlockSize /= 2; + + Value *Sum = nullptr; + for (unsigned K = 0; K < M; ++K) { + Value *R = B.extractVector(K, J, BlockSize, Builder); + Value *LH = Builder.CreateExtractElement(A.getVector(I), K); + Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); + Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, + IsFP, Builder, AllowContraction, NumComputeOps); + } + Result.setVector(I, + insertVector(Result.getVector(I), J, Sum, Builder)); + } + } + } + Result.addNumComputeOps(NumComputeOps); + } + + /// Ensure that the memory in \p Load does not alias \p Store by potentially + /// copying it to a new location. This new or otherwise the original location + /// is returned. 
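// getNonAliasingPointer below guards fusion with a runtime overlap check on the
// two memory ranges; only if they may overlap is the load operand copied into a
// fresh buffer. A scalar model of the emitted check (illustrative only):
#include <cstdint>

static bool mayOverlap(uint64_t LoadBegin, uint64_t LoadSize,
                       uint64_t StoreBegin, uint64_t StoreSize) {
  uint64_t LoadEnd = LoadBegin + LoadSize;
  uint64_t StoreEnd = StoreBegin + StoreSize;
  // Mirrors the two emitted branches: the load starts before the store ends
  // and the store starts before the load ends.
  return LoadBegin < StoreEnd && StoreBegin < LoadEnd;
}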
+ Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, + CallInst *MatMul) { + MemoryLocation StoreLoc = MemoryLocation::get(Store); + MemoryLocation LoadLoc = MemoryLocation::get(Load); + + AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc); + + // If we can statically determine noalias we're good. + if (!LdAliased) + return Load->getPointerOperand(); + + // Create code to check if the memory locations of the Load and Store + // overlap and if they do, copy Load's operand to a new buffer. + + // First, create new blocks for 2n part of the check and the copy. + BasicBlock *Check0 = MatMul->getParent(); + // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a + // DT. Manually collect dominator tree updates, to avoid unnecessary work, + // as we adjust Check0 and Check1's branches. + SmallVector<DominatorTree::UpdateType, 4> DTUpdates; + for (BasicBlock *Succ : successors(Check0)) + DTUpdates.push_back({DT.Delete, Check0, Succ}); + + BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, + nullptr, "alias_cont"); + BasicBlock *Copy = + SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy"); + BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, + nullptr, "no_alias"); + + // Check if the loaded memory location begins before the end of the store + // location. If the condition holds, they might overlap, otherwise they are + // guaranteed to not overlap. + IRBuilder<> Builder(MatMul); + Check0->getTerminator()->eraseFromParent(); + Builder.SetInsertPoint(Check0); + Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); + Value *StoreBegin = Builder.CreatePtrToInt( + const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); + Value *StoreEnd = Builder.CreateAdd( + StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), + "store.end", true, true); + Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), + IntPtrTy, "load.begin"); + Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, + Fusion); + + // Check if the store begins before the end of the load location. If the + // condition holds, they alias, otherwise they are guaranteed to not + // overlap. + Check1->getTerminator()->eraseFromParent(); + Builder.SetInsertPoint(Check1, Check1->begin()); + Value *LoadEnd = Builder.CreateAdd( + LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), + "load.end", true, true); + Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, + Fusion); + + // Copy load operand to new alloca. + Builder.SetInsertPoint(Copy, Copy->begin()); + AllocaInst *NewLd = + Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); + Builder.CreateMemCpy(NewLd, NewLd->getAlign(), + Load->getPointerOperand(), Load->getAlign(), + LoadLoc.Size.getValue()); + Builder.SetInsertPoint(Fusion, Fusion->begin()); + PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); + PHI->addIncoming(Load->getPointerOperand(), Check0); + PHI->addIncoming(Load->getPointerOperand(), Check1); + PHI->addIncoming(NewLd, Copy); + + // Adjust DT. 
+ DTUpdates.push_back({DT.Insert, Check0, Check1}); + DTUpdates.push_back({DT.Insert, Check0, Fusion}); + DTUpdates.push_back({DT.Insert, Check1, Copy}); + DTUpdates.push_back({DT.Insert, Check1, Fusion}); + DT.applyUpdates(DTUpdates); + return PHI; + } + + bool isFusionProfitable(CallInst *MatMul) { + if (ForceFusion) + return true; + + ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); + ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); + + const unsigned R = LShape.NumRows; + const unsigned C = RShape.NumColumns; + const unsigned M = LShape.NumColumns; + auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); + + const unsigned VF = + std::max<unsigned>(TTI.getRegisterBitWidth(true) / + EltType->getPrimitiveSizeInBits().getFixedSize(), + 1U); + + // Cost model for tiling + // + // For tiling to be beneficial, we need reuse either along the R or + // the C axis. We vectorize along the R axis so that means at least + // 3 elements. + // TODO: Also consider cost of copying if operands alias. + if (R <= VF && C == 1) + return false; + // Then we need enough elements to exceed the number of vector + // registers we have. Note that this is an oversimplification since + // fusing also takes some extra loads which may exceed the number of + // reloads necessary. + unsigned Op0Regs = (R + VF - 1) / VF * M; + unsigned Op1Regs = (M + VF - 1) / VF * C; + return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); + } + + MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { + MatrixTy Res; + auto *ColumType = FixedVectorType::get(EltType, R); + for (unsigned I = 0; I < C; ++I) + Res.addVector(ConstantAggregateZero::get(ColumType)); + return Res; + } + + void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, + StoreInst *Store, + SmallPtrSetImpl<Instruction *> &FusedInsts) { + assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && + "Tiling only supported for column-major matrixes at the moment!"); + if (!isFusionProfitable(MatMul)) + return; + + ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); + ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); + + const unsigned R = LShape.NumRows; + const unsigned C = RShape.NumColumns; + const unsigned M = LShape.NumColumns; + auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); + + Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); + Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); + Value *CPtr = Store->getPointerOperand(); + + bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && + MatMul->hasAllowContract()); + IRBuilder<> Builder(Store); + for (unsigned J = 0; J < C; J += TileSize) + for (unsigned I = 0; I < R; I += TileSize) { + const unsigned TileR = std::min(R - I, unsigned(TileSize)); + const unsigned TileC = std::min(C - J, unsigned(TileSize)); + MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); + + for (unsigned K = 0; K < M; K += TileSize) { + const unsigned TileM = std::min(M - K, unsigned(TileSize)); + MatrixTy A = + loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), + LShape, Builder.getInt64(I), Builder.getInt64(K), + {TileR, TileM}, EltType, Builder); + MatrixTy B = + loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), + RShape, Builder.getInt64(K), Builder.getInt64(J), + {TileM, TileC}, EltType, Builder); + emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); + } + storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, 
M}, + Builder.getInt64(I), Builder.getInt64(J), EltType, Builder); + } + + // Mark eliminated instructions as fused and remove them. + FusedInsts.insert(Store); + FusedInsts.insert(MatMul); + Store->eraseFromParent(); + MatMul->eraseFromParent(); + if (LoadOp0->hasNUses(0)) { + FusedInsts.insert(LoadOp0); + LoadOp0->eraseFromParent(); + } + if (LoadOp1->hasNUses(0)) { + FusedInsts.insert(LoadOp1); + LoadOp1->eraseFromParent(); + } + } + + /// Try to lower matrix multiply chains by fusing operations. + /// + /// Currently we only lower {ld, ld} -> matmul -> st chains. + // + /// No need to return a MatrixTy object for the result of the operation, since + /// the single store user will be lowered as part of this. Instructions that + /// are completely eliminated by fusion are added to \p FusedInsts. + void LowerMatrixMultiplyFused(CallInst *MatMul, + SmallPtrSetImpl<Instruction *> &FusedInsts) { + if (!FuseMatrix || !MatMul->hasOneUse() || + MatrixLayout != MatrixLayoutTy::ColumnMajor) + return; + + auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); + auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); + auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); + if (LoadOp0 && LoadOp1 && Store) { + // The store address must dominate the MatMul instruction, otherwise + // we create invalid IR. + // FIXME: See if we can hoist the store address computation. + auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); + if (AddrI && (!DT.dominates(AddrI, MatMul))) + return; + + emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); + return; + } + } + /// Lowers llvm.matrix.multiply. void LowerMultiply(CallInst *MatMul) { IRBuilder<> Builder(MatMul); @@ -706,97 +1298,80 @@ public: ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); - const ColumnMatrixTy &Lhs = - getMatrix(MatMul->getArgOperand(0), LShape, Builder); - const ColumnMatrixTy &Rhs = - getMatrix(MatMul->getArgOperand(1), RShape, Builder); + const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); + const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); const unsigned R = LShape.NumRows; - const unsigned M = LShape.NumColumns; const unsigned C = RShape.NumColumns; - assert(M == RShape.NumRows); + assert(LShape.NumColumns == RShape.NumRows); // Initialize the output - ColumnMatrixTy Result; - for (unsigned J = 0; J < C; ++J) - Result.addColumn(UndefValue::get(VectorType::get(EltType, R))); - - const unsigned VF = std::max(TTI.getRegisterBitWidth(true) / - EltType->getPrimitiveSizeInBits(), - uint64_t(1)); + MatrixTy Result(R, C, EltType); bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && MatMul->hasAllowContract()); - // Multiply columns from the first operand with scalars from the second - // operand. Then move along the K axes and accumulate the columns. With - // this the adds can be vectorized without reassociation. - for (unsigned J = 0; J < C; ++J) { - unsigned BlockSize = VF; - for (unsigned I = 0; I < R; I += BlockSize) { - // Gradually lower the vectorization factor to cover the remainder. 
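// The emitSIMDTiling/emitMatrixMultiply helpers above, and the inline loop
// being removed here, all compute a blocked (tiled) column-major matrix
// multiply. A plain scalar model of the loop nest, ignoring vectorization and
// the alias checks (illustrative only):
#include <algorithm>
#include <cstddef>
#include <vector>

// C (RxN) += A (RxM) * B (MxN), all column-major, processed in TileSize blocks.
static void tiledMatMul(std::vector<double> &C, const std::vector<double> &A,
                        const std::vector<double> &B, size_t R, size_t M,
                        size_t N, size_t TileSize = 4) {
  for (size_t J = 0; J < N; J += TileSize)
    for (size_t I = 0; I < R; I += TileSize)
      for (size_t K = 0; K < M; K += TileSize)
        // Multiply-accumulate one tile of the result.
        for (size_t j = J; j < std::min(J + TileSize, N); ++j)
          for (size_t i = I; i < std::min(I + TileSize, R); ++i)
            for (size_t k = K; k < std::min(K + TileSize, M); ++k)
              C[j * R + i] += A[k * R + i] * B[j * M + k];
}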
- while (I + BlockSize > R) - BlockSize /= 2; - - Value *Sum = nullptr; - for (unsigned K = 0; K < M; ++K) { - Value *L = extractVector(Lhs, I, K, BlockSize, Builder); - Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K); - Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); - Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(), - Builder, AllowContract); - } - Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder)); - } - } + emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false); finalizeLowering(MatMul, Result, Builder); } /// Lowers llvm.matrix.transpose. void LowerTranspose(CallInst *Inst) { - ColumnMatrixTy Result; + MatrixTy Result; IRBuilder<> Builder(Inst); Value *InputVal = Inst->getArgOperand(0); VectorType *VectorTy = cast<VectorType>(InputVal->getType()); ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); - ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); - - for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) { - // Build a single column vector for this row. First initialize it. - Value *ResultColumn = UndefValue::get( - VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns)); - - // Go through the elements of this row and insert it into the resulting - // column vector. - for (auto C : enumerate(InputMatrix.columns())) { - Value *Elt = Builder.CreateExtractElement(C.value(), Row); - // We insert at index Column since that is the row index after the - // transpose. - ResultColumn = - Builder.CreateInsertElement(ResultColumn, Elt, C.index()); + MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); + + const unsigned NewNumVecs = + InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; + const unsigned NewNumElts = + InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; + + for (unsigned I = 0; I < NewNumVecs; ++I) { + // Build a single result vector. First initialize it. + Value *ResultVector = UndefValue::get( + FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); + // Go through the old elements and insert it into the resulting vector. + for (auto J : enumerate(InputMatrix.vectors())) { + Value *Elt = Builder.CreateExtractElement(J.value(), I); + // Row and column indices are transposed. + ResultVector = + Builder.CreateInsertElement(ResultVector, Elt, J.index()); } - Result.addColumn(ResultColumn); + Result.addVector(ResultVector); } - finalizeLowering(Inst, Result, Builder); + // TODO: Improve estimate of operations needed for transposes. Currently we + // just count the insertelement/extractelement instructions, but do not + // account for later simplifications/combines. + finalizeLowering( + Inst, + Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), + Builder); } /// Lower load instructions, if shape information is available. 
- bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) { + bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { auto I = ShapeMap.find(Inst); if (I == ShapeMap.end()) return false; - LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second); + LowerLoad(Inst, Ptr, Inst->getAlign(), + Builder.getInt64(I->second.getStride()), Inst->isVolatile(), + I->second); return true; } - bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr, + bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, IRBuilder<> &Builder) { auto I = ShapeMap.find(StoredVal); if (I == ShapeMap.end()) return false; - LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second); + LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), + Builder.getInt64(I->second.getStride()), Inst->isVolatile(), + I->second); return true; } @@ -812,12 +1387,15 @@ public: IRBuilder<> Builder(Inst); ShapeInfo &Shape = I->second; - ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder); - ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder); + MatrixTy Result; + MatrixTy A = getMatrix(Lhs, Shape, Builder); + MatrixTy B = getMatrix(Rhs, Shape, Builder); + assert(A.isColumnMajor() == B.isColumnMajor() && + Result.isColumnMajor() == A.isColumnMajor() && + "operands must agree on matrix layout"); - // Add each column and store the result back into the opmapping - ColumnMatrixTy Result; - auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) { + // Helper to perform binary op on vectors. + auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) { switch (Inst->getOpcode()) { case Instruction::Add: return Builder.CreateAdd(LHS, RHS); @@ -835,20 +1413,462 @@ public: llvm_unreachable("Unsupported binary operator for matrix"); } }; - for (unsigned C = 0; C < Shape.NumColumns; ++C) - Result.addColumn( - BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C))); - finalizeLowering(Inst, Result, Builder); + for (unsigned I = 0; I < Shape.getNumVectors(); ++I) + Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); + + finalizeLowering(Inst, + Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()), + Builder); return true; } + + /// Helper to linearize a matrix expression tree into a string. Currently + /// matrix expressions are linarized by starting at an expression leaf and + /// linearizing bottom up. + struct ExprLinearizer { + unsigned LengthToBreak = 100; + std::string Str; + raw_string_ostream Stream; + unsigned LineLength = 0; + const DataLayout &DL; + + /// Mapping from instructions to matrixes. It is used to identify + /// matrix instructions. + const MapVector<Value *, MatrixTy> &Inst2Matrix; + + /// Mapping from values to the leaves of all expressions that the value is + /// part of. + const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; + + /// Set of matrix expressions in the scope of a given DISubprogram. + const SmallSetVector<Value *, 32> &ExprsInSubprogram; + + /// Leaf node of the expression to linearize. + Value *Leaf; + + /// Used to keep track of sub-expressions that get reused while linearizing + /// the expression. Re-used sub-expressions are marked as (reused). 
+ SmallPtrSet<Value *, 8> ReusedExprs; + + ExprLinearizer(const DataLayout &DL, + const MapVector<Value *, MatrixTy> &Inst2Matrix, + const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, + const SmallSetVector<Value *, 32> &ExprsInSubprogram, + Value *Leaf) + : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), + ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} + + void indent(unsigned N) { + LineLength += N; + for (unsigned i = 0; i < N; i++) + Stream << " "; + } + + void lineBreak() { + Stream << "\n"; + LineLength = 0; + } + + void maybeIndent(unsigned Indent) { + if (LineLength >= LengthToBreak) + lineBreak(); + + if (LineLength == 0) + indent(Indent); + } + + void write(StringRef S) { + LineLength += S.size(); + Stream << S; + } + + Value *getUnderlyingObjectThroughLoads(Value *V) { + if (Value *Ptr = getPointerOperand(V)) + return getUnderlyingObjectThroughLoads(Ptr); + else if (V->getType()->isPointerTy()) + return GetUnderlyingObject(V, DL); + return V; + } + + /// Returns true if \p V is a matrix value in the given subprogram. + bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } + + /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to + /// \p SS. + void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { + auto M = Inst2Matrix.find(V); + if (M == Inst2Matrix.end()) + SS << "unknown"; + else { + SS << M->second.getNumRows(); + SS << "x"; + SS << M->second.getNumColumns(); + } + } + + /// Write the called function name. Handles calls to llvm.matrix.* + /// specially: we write the name, followed by the dimensions of the input + /// matrixes, followed by the scalar type name. + void writeFnName(CallInst *CI) { + if (!CI->getCalledFunction()) + write("<no called fn>"); + else { + StringRef Name = CI->getCalledFunction()->getName(); + if (!Name.startswith("llvm.matrix")) { + write(Name); + return; + } + IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); + write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {})) + .drop_front(StringRef("llvm.matrix.").size())); + write("."); + std::string Tmp = ""; + raw_string_ostream SS(Tmp); + + switch (II->getIntrinsicID()) { + case Intrinsic::matrix_multiply: + prettyPrintMatrixType(II->getOperand(0), SS); + SS << "."; + prettyPrintMatrixType(II->getOperand(1), SS); + SS << "." << *II->getType()->getScalarType(); + break; + case Intrinsic::matrix_transpose: + prettyPrintMatrixType(II->getOperand(0), SS); + SS << "." << *II->getType()->getScalarType(); + break; + case Intrinsic::matrix_column_major_load: + prettyPrintMatrixType(II, SS); + SS << "." << *II->getType()->getScalarType(); + break; + case Intrinsic::matrix_column_major_store: + prettyPrintMatrixType(II->getOperand(0), SS); + SS << "." << *II->getOperand(0)->getType()->getScalarType(); + break; + default: + llvm_unreachable("Unhandled case"); + } + SS.flush(); + write(Tmp); + } + } + + unsigned getNumShapeArgs(CallInst *CI) const { + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { + switch (II->getIntrinsicID()) { + case Intrinsic::matrix_multiply: + return 3; + case Intrinsic::matrix_transpose: + return 2; + case Intrinsic::matrix_column_major_load: + case Intrinsic::matrix_column_major_store: + return 3; + default: + return 0; + } + } + return 0; + } + + /// Special printing for values: for pointers, we print if they refer to an + /// (function) external address or a stack address, for other values we + /// either print the constant or "scalar"/"matrix" for other values. 
+ void write(Value *V) { + V = getUnderlyingObjectThroughLoads(V); + if (V->getType()->isPointerTy()) { + if (isa<AllocaInst>(V)) { + Stream << "stack addr"; + LineLength += StringRef("stack addr").size(); + } else { + Stream << "addr"; + LineLength += StringRef("addr").size(); + } + if (!V->getName().empty()) { + Stream << " %" << V->getName() << ""; + LineLength += V->getName().size() + 2; + } + return; + } + + std::string Tmp; + raw_string_ostream TmpStream(Tmp); + + if (auto *CI = dyn_cast<ConstantInt>(V)) + TmpStream << CI->getValue(); + else if (isa<Constant>(V)) + TmpStream << "constant"; + else { + if (isMatrix(V)) + TmpStream << "matrix"; + else + TmpStream << "scalar"; + } + TmpStream.flush(); + Tmp = std::string(StringRef(Tmp).trim()); + LineLength += Tmp.size(); + Stream << Tmp; + } + + /// Linearize expression \p Expr starting at an indentation of \p Indent. + /// Expressions that are re-used multiple times are prefixed with (reused) + /// at the re-used root instruction. + void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, + bool ParentShared) { + auto *I = cast<Instruction>(Expr); + maybeIndent(Indent); + SmallVector<Value *, 8> Ops; + + // Is Expr shared with other expression leaves? + bool ExprShared = false; + + // Deal with shared subtrees. Mark them as shared, if required. + if (!ParentShared) { + auto SI = Shared.find(Expr); + assert(SI != Shared.end() && SI->second.count(Leaf)); + + for (Value *S : SI->second) { + if (S == Leaf) + continue; + DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); + write("shared with remark at line " + std::to_string(DL.getLine()) + + " column " + std::to_string(DL.getCol()) + " ("); + } + ExprShared = SI->second.size() > 1; + } + + bool Reused = !ReusedExprs.insert(Expr).second; + if (Reused && !ParentReused) + write("(reused) "); + + if (auto *CI = dyn_cast<CallInst>(I)) { + writeFnName(CI); + + Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); + } else if (isa<BitCastInst>(Expr)) { + // Special case bitcasts, which are used to materialize matrixes from + // non-matrix ops. + write("matrix"); + return; + } else { + Ops.append(I->value_op_begin(), I->value_op_end()); + write(std::string(I->getOpcodeName())); + } + + write(std::string("(")); + + unsigned NumOpsToBreak = 1; + if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) + NumOpsToBreak = 2; + + for (Value *Op : Ops) { + if (Ops.size() > NumOpsToBreak) + lineBreak(); + + maybeIndent(Indent + 1); + if (isMatrix(Op)) + linearizeExpr(Op, Indent + 1, Reused, ExprShared); + else + write(Op); + if (Op != Ops.back()) + write(", "); + } + + write(")"); + } + + const std::string &getResult() { + Stream.flush(); + return Str; + } + }; + + /// Generate remarks for matrix operations in a function. To generate remarks + /// for matrix expressions, the following approach is used: + /// 1. Use the inlined-at debug information to group matrix operations to the + /// DISubprograms they are contained in. + /// 2. Collect leaves of matrix expressions (done in + /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression + // mapping. Leaves are lowered matrix instructions without other matrix + // users (like stores) in the current subprogram. + /// 3. For each leaf, create a remark containing a linearizied version of the + /// matrix expression. The expression is linearized by a recursive + /// bottom-up traversal of the matrix operands, starting at a leaf. Note + /// that multiple leaves can share sub-expressions. 
Shared subexpressions + /// are explicitly marked as shared(). + struct RemarkGenerator { + const MapVector<Value *, MatrixTy> &Inst2Matrix; + OptimizationRemarkEmitter &ORE; + Function &Func; + const DataLayout &DL; + + RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, + OptimizationRemarkEmitter &ORE, Function &Func) + : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), + DL(Func.getParent()->getDataLayout()) {} + + /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are + /// instructions in Inst2Matrix returning void or without any users in + /// \p ExprsInSubprogram. Currently that should only include stores. + SmallVector<Value *, 4> + getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) { + SmallVector<Value *, 4> Leaves; + for (auto *Expr : ExprsInSubprogram) + if (Expr->getType()->isVoidTy() || + !any_of(Expr->users(), [&ExprsInSubprogram](User *U) { + return ExprsInSubprogram.count(U); + })) + Leaves.push_back(Expr); + return Leaves; + } + + /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf + /// to all visited expressions in \p Shared. Limit the matrix operations to + /// the ones in \p ExprsInSubprogram. + void collectSharedInfo(Value *Leaf, Value *V, + const SmallSetVector<Value *, 32> &ExprsInSubprogram, + DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) { + + if (!ExprsInSubprogram.count(V)) + return; + + auto I = Shared.insert({V, {}}); + I.first->second.insert(Leaf); + + for (Value *Op : cast<Instruction>(V)->operand_values()) + collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared); + return; + } + + /// Calculate the number of exclusive and shared op counts for expression + /// starting at \p V. Expressions used multiple times are counted once. + /// Limit the matrix operations to the ones in \p ExprsInSubprogram. + std::pair<OpInfoTy, OpInfoTy> + sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs, + const SmallSetVector<Value *, 32> &ExprsInSubprogram, + DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const { + if (!ExprsInSubprogram.count(Root)) + return {}; + + // Already counted this expression. Stop. + if (!ReusedExprs.insert(Root).second) + return {}; + + OpInfoTy SharedCount; + OpInfoTy Count; + + auto I = Shared.find(Root); + auto CM = Inst2Matrix.find(Root); + if (I->second.size() == 1) + Count = CM->second.getOpInfo(); + else + SharedCount = CM->second.getOpInfo(); + + for (Value *Op : cast<Instruction>(Root)->operand_values()) { + auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared); + Count += C.first; + SharedCount += C.second; + } + return {Count, SharedCount}; + } + + void emitRemarks() { + if (!ORE.allowExtraAnalysis(DEBUG_TYPE)) + return; + + // Map matrix operations to their containting subprograms, by traversing + // the inlinedAt chain. If the function does not have a DISubprogram, we + // only map them to the containing function. 
+ MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs; + for (auto &KV : Inst2Matrix) { + if (Func.getSubprogram()) { + auto *I = cast<Instruction>(KV.first); + DILocation *Context = I->getDebugLoc(); + while (Context) { + auto I = + Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}}); + I.first->second.push_back(KV.first); + Context = DebugLoc(Context).getInlinedAt(); + } + } else { + auto I = Subprog2Exprs.insert({nullptr, {}}); + I.first->second.push_back(KV.first); + } + } + for (auto &KV : Subprog2Exprs) { + SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(), + KV.second.end()); + auto Leaves = getExpressionLeaves(ExprsInSubprogram); + + DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared; + for (Value *Leaf : Leaves) + collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared); + + // Generate remarks for each leaf. + for (auto *L : Leaves) { + + DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); + DILocation *Context = cast<Instruction>(L)->getDebugLoc(); + while (Context) { + if (getSubprogram(Context->getScope()) == KV.first) { + Loc = Context; + break; + } + Context = DebugLoc(Context).getInlinedAt(); + } + + SmallPtrSet<Value *, 8> ReusedExprs; + OpInfoTy Counts, SharedCounts; + std::tie(Counts, SharedCounts) = + sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); + + OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, + cast<Instruction>(L)->getParent()); + + Rem << "Lowered with "; + Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " + << ore::NV("NumLoads", Counts.NumLoads) << " loads, " + << ore::NV("NumComputeOps", Counts.NumComputeOps) + << " compute ops"; + + if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || + SharedCounts.NumComputeOps > 0) { + Rem << ",\nadditionally " + << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " + << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " + << ore::NV("NumFPOps", SharedCounts.NumComputeOps) + << " compute ops" + << " are shared with other expressions"; + } + + Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); + ORE.emit(Rem); + } + } + } + + std::string + linearize(Value *L, + const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, + const SmallSetVector<Value *, 32> &ExprsInSubprogram, + const DataLayout &DL) { + ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); + Lin.linearizeExpr(L, 0, false, false); + return Lin.getResult(); + } + }; }; } // namespace PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TTI = AM.getResult<TargetIRAnalysis>(F); - LowerMatrixIntrinsics LMT(F, TTI); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + + LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); if (LMT.Visit()) { PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); @@ -869,15 +1889,24 @@ public: } bool runOnFunction(Function &F) override { - auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - LowerMatrixIntrinsics LMT(F, *TTI); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, 
ORE); bool C = LMT.Visit(); return C; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.setPreservesCFG(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); } }; } // namespace @@ -886,6 +1915,10 @@ static const char pass_name[] = "Lower the matrix intrinsics"; char LowerMatrixIntrinsicsLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, false, false) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, false, false) diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index c24fa40860eb..4b4196edc12b 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -27,7 +27,6 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" @@ -173,8 +172,8 @@ public: void addStore(int64_t OffsetFromFirst, StoreInst *SI) { int64_t StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType()); - addRange(OffsetFromFirst, StoreSize, - SI->getPointerOperand(), SI->getAlignment(), SI); + addRange(OffsetFromFirst, StoreSize, SI->getPointerOperand(), + SI->getAlign().value(), SI); } void addMemSet(int64_t OffsetFromFirst, MemSetInst *MSI) { @@ -387,13 +386,8 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, // Get the starting pointer of the block. StartPtr = Range.StartPtr; - // Determine alignment - const Align Alignment = DL.getValueOrABITypeAlignment( - MaybeAlign(Range.Alignment), - cast<PointerType>(StartPtr->getType())->getElementType()); - AMemSet = Builder.CreateMemSet(StartPtr, ByteVal, Range.End - Range.Start, - Alignment); + MaybeAlign(Range.Alignment)); LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI : Range.TheStores) dbgs() << *SI << '\n'; @@ -413,23 +407,6 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, return AMemSet; } -static Align findStoreAlignment(const DataLayout &DL, const StoreInst *SI) { - return DL.getValueOrABITypeAlignment(MaybeAlign(SI->getAlignment()), - SI->getOperand(0)->getType()); -} - -static Align findLoadAlignment(const DataLayout &DL, const LoadInst *LI) { - return DL.getValueOrABITypeAlignment(MaybeAlign(LI->getAlignment()), - LI->getType()); -} - -static Align findCommonAlignment(const DataLayout &DL, const StoreInst *SI, - const LoadInst *LI) { - Align StoreAlign = findStoreAlignment(DL, SI); - Align LoadAlign = findLoadAlignment(DL, LI); - return commonAlignment(StoreAlign, LoadAlign); -} - // This method try to lift a store instruction before position P. // It will lift the store and its argument + that anything that // may alias with these. 
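For context on the alignment changes in the surrounding MemCpyOptimizer hunks: the removed findStoreAlignment/findLoadAlignment/findCommonAlignment helpers only existed to resolve a possibly-unspecified legacy alignment against the ABI type alignment, whereas LoadInst/StoreInst::getAlign() always returns a defined Align. A minimal sketch of the equivalent computation, assuming only the Align APIs visible in this patch (the helper name is illustrative, not part of the change):

#include "llvm/IR/Instructions.h"
#include "llvm/Support/Alignment.h"

using namespace llvm;

// Illustrative only: with a defined Align guaranteed on both instructions,
// combining a store/load pair's alignments is a single commonAlignment call,
// no ABI-type fallback required.
static Align commonLoadStoreAlign(const StoreInst *SI, const LoadInst *LI) {
  return commonAlignment(SI->getAlign(), LI->getAlign());
}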
@@ -585,12 +562,12 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { Instruction *M; if (UseMemMove) M = Builder.CreateMemMove( - SI->getPointerOperand(), findStoreAlignment(DL, SI), - LI->getPointerOperand(), findLoadAlignment(DL, LI), Size); + SI->getPointerOperand(), SI->getAlign(), + LI->getPointerOperand(), LI->getAlign(), Size); else M = Builder.CreateMemCpy( - SI->getPointerOperand(), findStoreAlignment(DL, SI), - LI->getPointerOperand(), findLoadAlignment(DL, LI), Size); + SI->getPointerOperand(), SI->getAlign(), + LI->getPointerOperand(), LI->getAlign(), Size); LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " << *M << "\n"); @@ -642,7 +619,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { LI, SI->getPointerOperand()->stripPointerCasts(), LI->getPointerOperand()->stripPointerCasts(), DL.getTypeStoreSize(SI->getOperand(0)->getType()), - findCommonAlignment(DL, SI, LI).value(), C); + commonAlignment(SI->getAlign(), LI->getAlign()), C); if (changed) { MD->removeInstruction(SI); SI->eraseFromParent(); @@ -675,11 +652,9 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { auto *T = V->getType(); if (T->isAggregateType()) { uint64_t Size = DL.getTypeStoreSize(T); - const Align MA = - DL.getValueOrABITypeAlignment(MaybeAlign(SI->getAlignment()), T); IRBuilder<> Builder(SI); - auto *M = - Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size, MA); + auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size, + SI->getAlign()); LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); @@ -713,7 +688,7 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { /// the call write its result directly into the destination of the memcpy. bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, Value *cpySrc, uint64_t cpyLen, - unsigned cpyAlign, CallInst *C) { + Align cpyAlign, CallInst *C) { // The general transformation to keep in mind is // // call @func(..., src, ...) @@ -733,10 +708,6 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, if (F->isIntrinsic() && F->getIntrinsicID() == Intrinsic::lifetime_start) return false; - // Deliberately get the source and destination with bitcasts stripped away, - // because we'll need to do type comparisons based on the underlying type. - CallSite CS(C); - // Require that src be an alloca. This simplifies the reasoning considerably. AllocaInst *srcAlloca = dyn_cast<AllocaInst>(cpySrc); if (!srcAlloca) @@ -795,9 +766,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, } // Check that dest points to memory that is at least as aligned as src. - unsigned srcAlign = srcAlloca->getAlignment(); - if (!srcAlign) - srcAlign = DL.getABITypeAlignment(srcAlloca->getAllocatedType()); + Align srcAlign = srcAlloca->getAlign(); bool isDestSufficientlyAligned = srcAlign <= cpyAlign; // If dest is not aligned enough and we can't increase its alignment then // bail out. @@ -836,8 +805,8 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, // Check that src isn't captured by the called function since the // transformation can cause aliasing issues in that case. 
- for (unsigned i = 0, e = CS.arg_size(); i != e; ++i) - if (CS.getArgument(i) == cpySrc && !CS.doesNotCapture(i)) + for (unsigned ArgI = 0, E = C->arg_size(); ArgI != E; ++ArgI) + if (C->getArgOperand(ArgI) == cpySrc && !C->doesNotCapture(ArgI)) return false; // Since we're changing the parameter to the callsite, we need to make sure @@ -864,25 +833,26 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, if (cpySrc->getType()->getPointerAddressSpace() != cpyDest->getType()->getPointerAddressSpace()) return false; - for (unsigned i = 0; i < CS.arg_size(); ++i) - if (CS.getArgument(i)->stripPointerCasts() == cpySrc && + for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI) + if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc && cpySrc->getType()->getPointerAddressSpace() != - CS.getArgument(i)->getType()->getPointerAddressSpace()) + C->getArgOperand(ArgI)->getType()->getPointerAddressSpace()) return false; // All the checks have passed, so do the transformation. bool changedArgument = false; - for (unsigned i = 0; i < CS.arg_size(); ++i) - if (CS.getArgument(i)->stripPointerCasts() == cpySrc) { + for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI) + if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc) { Value *Dest = cpySrc->getType() == cpyDest->getType() ? cpyDest : CastInst::CreatePointerCast(cpyDest, cpySrc->getType(), cpyDest->getName(), C); changedArgument = true; - if (CS.getArgument(i)->getType() == Dest->getType()) - CS.setArgument(i, Dest); + if (C->getArgOperand(ArgI)->getType() == Dest->getType()) + C->setArgOperand(ArgI, Dest); else - CS.setArgument(i, CastInst::CreatePointerCast(Dest, - CS.getArgument(i)->getType(), Dest->getName(), C)); + C->setArgOperand(ArgI, CastInst::CreatePointerCast( + Dest, C->getArgOperand(ArgI)->getType(), + Dest->getName(), C)); } if (!changedArgument) @@ -891,7 +861,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, // If the destination wasn't sufficiently aligned then increase its alignment. if (!isDestSufficientlyAligned) { assert(isa<AllocaInst>(cpyDest) && "Can only increase alloca alignment!"); - cast<AllocaInst>(cpyDest)->setAlignment(MaybeAlign(srcAlign)); + cast<AllocaInst>(cpyDest)->setAlignment(srcAlign); } // Drop any cached information about the call, because we may have changed @@ -1127,15 +1097,16 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, /// B to be a memcpy from X to Z (or potentially a memmove, depending on /// circumstances). This allows later passes to remove the first memcpy /// altogether. -bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { +bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { // We can only optimize non-volatile memcpy's. if (M->isVolatile()) return false; // If the source and destination of the memcpy are the same, then zap it. if (M->getSource() == M->getDest()) { + ++BBI; MD->removeInstruction(M); M->eraseFromParent(); - return false; + return true; } // If copying from a constant, try to turn the memcpy into a memset. @@ -1176,10 +1147,10 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { if (CallInst *C = dyn_cast<CallInst>(DepInfo.getInst())) { // FIXME: Can we pass in either of dest/src alignment here instead // of conservatively taking the minimum? 
- unsigned Align = MinAlign(M->getDestAlignment(), M->getSourceAlignment()); + Align Alignment = std::min(M->getDestAlign().valueOrOne(), + M->getSourceAlign().valueOrOne()); if (performCallSlotOptzn(M, M->getDest(), M->getSource(), - CopySize->getZExtValue(), Align, - C)) { + CopySize->getZExtValue(), Alignment, C)) { MD->removeInstruction(M); M->eraseFromParent(); return true; @@ -1247,15 +1218,15 @@ bool MemCpyOptPass::processMemMove(MemMoveInst *M) { } /// This is called on every byval argument in call sites. -bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { - const DataLayout &DL = CS.getCaller()->getParent()->getDataLayout(); +bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { + const DataLayout &DL = CB.getCaller()->getParent()->getDataLayout(); // Find out what feeds this byval argument. - Value *ByValArg = CS.getArgument(ArgNo); + Value *ByValArg = CB.getArgOperand(ArgNo); Type *ByValTy = cast<PointerType>(ByValArg->getType())->getElementType(); uint64_t ByValSize = DL.getTypeAllocSize(ByValTy); MemDepResult DepInfo = MD->getPointerDependencyFrom( MemoryLocation(ByValArg, LocationSize::precise(ByValSize)), true, - CS.getInstruction()->getIterator(), CS.getInstruction()->getParent()); + CB.getIterator(), CB.getParent()); if (!DepInfo.isClobber()) return false; @@ -1274,16 +1245,17 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { // Get the alignment of the byval. If the call doesn't specify the alignment, // then it is some target specific value that we can't know. - unsigned ByValAlign = CS.getParamAlignment(ArgNo); - if (ByValAlign == 0) return false; + MaybeAlign ByValAlign = CB.getParamAlign(ArgNo); + if (!ByValAlign) return false; // If it is greater than the memcpy, then we check to see if we can force the // source of the memcpy to the alignment we need. If we fail, we bail out. AssumptionCache &AC = LookupAssumptionCache(); DominatorTree &DT = LookupDomTree(); - if (MDep->getSourceAlignment() < ByValAlign && - getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, - CS.getInstruction(), &AC, &DT) < ByValAlign) + MaybeAlign MemDepAlign = MDep->getSourceAlign(); + if ((!MemDepAlign || *MemDepAlign < *ByValAlign) && + getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, &CB, &AC, + &DT) < *ByValAlign) return false; // The address space of the memcpy source must match the byval argument @@ -1302,21 +1274,25 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { // not just the defining memcpy. MemDepResult SourceDep = MD->getPointerDependencyFrom( MemoryLocation::getForSource(MDep), false, - CS.getInstruction()->getIterator(), MDep->getParent()); + CB.getIterator(), MDep->getParent()); if (!SourceDep.isClobber() || SourceDep.getInst() != MDep) return false; Value *TmpCast = MDep->getSource(); - if (MDep->getSource()->getType() != ByValArg->getType()) - TmpCast = new BitCastInst(MDep->getSource(), ByValArg->getType(), - "tmpcast", CS.getInstruction()); + if (MDep->getSource()->getType() != ByValArg->getType()) { + BitCastInst *TmpBitCast = new BitCastInst(MDep->getSource(), ByValArg->getType(), + "tmpcast", &CB); + // Set the tmpcast's DebugLoc to MDep's + TmpBitCast->setDebugLoc(MDep->getDebugLoc()); + TmpCast = TmpBitCast; + } LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to byval:\n" << " " << *MDep << "\n" - << " " << *CS.getInstruction() << "\n"); + << " " << CB << "\n"); // Otherwise we're good! Update the byval argument. 
- CS.setArgument(ArgNo, TmpCast); + CB.setArgOperand(ArgNo, TmpCast); ++NumMemCpyInstr; return true; } @@ -1347,13 +1323,13 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) { else if (MemSetInst *M = dyn_cast<MemSetInst>(I)) RepeatInstruction = processMemSet(M, BI); else if (MemCpyInst *M = dyn_cast<MemCpyInst>(I)) - RepeatInstruction = processMemCpy(M); + RepeatInstruction = processMemCpy(M, BI); else if (MemMoveInst *M = dyn_cast<MemMoveInst>(I)) RepeatInstruction = processMemMove(M); - else if (auto CS = CallSite(I)) { - for (unsigned i = 0, e = CS.arg_size(); i != e; ++i) - if (CS.isByValArgument(i)) - MadeChange |= processByValArgument(CS, i); + else if (auto *CB = dyn_cast<CallBase>(I)) { + for (unsigned i = 0, e = CB->arg_size(); i != e; ++i) + if (CB->isByValArgument(i)) + MadeChange |= processByValArgument(*CB, i); } // Reprocess the instruction if desired. diff --git a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 6b0d0202d9bb..69aa0cebe170 100644 --- a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -354,15 +354,11 @@ bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) { // optimization opportunities. // This loop doesn't care about newly inserted/split blocks // since they never will be diamond heads. - for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE;) { - BasicBlock *BB = &*FI++; - + for (BasicBlock &BB : make_early_inc_range(F)) // Hoist equivalent loads and sink stores // outside diamonds when possible - if (isDiamondHead(BB)) { - Changed |= mergeStores(BB); - } - } + if (isDiamondHead(&BB)) + Changed |= mergeStores(&BB); return Changed; } diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index bba9082e31b2..4e010f8704d0 100644 --- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -213,7 +213,7 @@ bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_, return Changed; } -// Whitelist the instruction types NaryReassociate handles for now. +// Explicitly list the instruction types NaryReassociate handles for now. 
static bool isPotentiallyNaryReassociable(Instruction *I) { switch (I->getOpcode()) { case Instruction::Add: diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp index 6a643480f312..0ed1773373a7 100644 --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -106,6 +106,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVNExpression.h" +#include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PredicateInfo.h" #include "llvm/Transforms/Utils/VNCoercion.h" @@ -495,6 +496,7 @@ class NewGVN { AliasAnalysis *AA = nullptr; MemorySSA *MSSA = nullptr; MemorySSAWalker *MSSAWalker = nullptr; + AssumptionCache *AC = nullptr; const DataLayout &DL; std::unique_ptr<PredicateInfo> PredInfo; @@ -658,7 +660,7 @@ public: NewGVN(Function &F, DominatorTree *DT, AssumptionCache *AC, TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA, const DataLayout &DL) - : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL), + : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), AC(AC), DL(DL), PredInfo(std::make_unique<PredicateInfo>(F, *DT, *AC)), SQ(DL, TLI, DT, AC, /*CtxI=*/nullptr, /*UseInstrInfo=*/false) {} @@ -898,7 +900,7 @@ bool NewGVN::isBackedge(BasicBlock *From, BasicBlock *To) const { #ifndef NDEBUG static std::string getBlockName(const BasicBlock *B) { - return DOTGraphTraits<const Function *>::getSimpleNodeLabel(B, nullptr); + return DOTGraphTraits<DOTFuncInfo *>::getSimpleNodeLabel(B, nullptr); } #endif @@ -1334,8 +1336,6 @@ LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp, // Give store and loads same opcode so they value number together. E->setOpcode(0); E->op_push_back(PointerOp); - if (LI) - E->setAlignment(MaybeAlign(LI->getAlignment())); // TODO: Value number heap versions. We may be able to discover // things alias analysis can't on it's own (IE that a store and a @@ -1470,7 +1470,8 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, // undef value. This can happen when loading for a fresh allocation with no // intervening stores, for example. Note that this is only true in the case // that the result of the allocation is pointer equal to the load ptr. - if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI)) { + if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) || + isAlignedAllocLikeFn(DepInst, TLI)) { return createConstantExpression(UndefValue::get(LoadType)); } // If this load occurs either right after a lifetime begin, @@ -2030,10 +2031,12 @@ NewGVN::performSymbolicEvaluation(Value *V, case Instruction::Select: case Instruction::ExtractElement: case Instruction::InsertElement: - case Instruction::ShuffleVector: case Instruction::GetElementPtr: E = createExpression(I); break; + case Instruction::ShuffleVector: + // FIXME: Add support for shufflevector to createExpression. + return nullptr; default: return nullptr; } @@ -3433,7 +3436,7 @@ bool NewGVN::runGVN() { // Sort dominator tree children arrays into RPO. 
for (auto &B : RPOT) { auto *Node = DT->getNode(B); - if (Node->getChildren().size() > 1) + if (Node->getNumChildren() > 1) llvm::sort(Node->begin(), Node->end(), [&](const DomTreeNode *A, const DomTreeNode *B) { return RPOOrdering[A] < RPOOrdering[B]; @@ -3693,6 +3696,7 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) { Inst.replaceAllUsesWith(UndefValue::get(Inst.getType())); if (isa<LandingPadInst>(Inst)) continue; + salvageKnowledge(&Inst, AC); Inst.eraseFromParent(); ++NumGVNInstrDeleted; diff --git a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp index 5c4a89977c38..4553b23532f2 100644 --- a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -189,7 +189,8 @@ static bool needsStatepoint(CallBase *Call, const TargetLibraryInfo &TLI) { return false; } - return !(isStatepoint(Call) || isGCRelocate(Call) || isGCResult(Call)); + return !(isa<GCStatepointInst>(Call) || isa<GCRelocateInst>(Call) || + isa<GCResultInst>(Call)); } /// Returns true if this loop is known to contain a call safepoint which @@ -650,7 +651,7 @@ InsertSafepointPoll(Instruction *InsertBefore, // Do the actual inlining InlineFunctionInfo IFI; - bool InlineStatus = InlineFunction(PollCall, IFI); + bool InlineStatus = InlineFunction(*PollCall, IFI).isSuccess(); assert(InlineStatus && "inline must succeed"); (void)InlineStatus; // suppress warning in release-asserts diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index 41940e980faa..ba7f367267fe 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -29,6 +29,7 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" @@ -254,15 +255,15 @@ static BinaryOperator *CreateMul(Value *S1, Value *S2, const Twine &Name, } } -static BinaryOperator *CreateNeg(Value *S1, const Twine &Name, - Instruction *InsertBefore, Value *FlagsOp) { +static Instruction *CreateNeg(Value *S1, const Twine &Name, + Instruction *InsertBefore, Value *FlagsOp) { if (S1->getType()->isIntOrIntVectorTy()) return BinaryOperator::CreateNeg(S1, Name, InsertBefore); - else { - BinaryOperator *Res = BinaryOperator::CreateFNeg(S1, Name, InsertBefore); - Res->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags()); - return Res; - } + + if (auto *FMFSource = dyn_cast<Instruction>(FlagsOp)) + return UnaryOperator::CreateFNegFMF(S1, FMFSource, Name, InsertBefore); + + return UnaryOperator::CreateFNeg(S1, Name, InsertBefore); } /// Replace 0-X with X*-1. @@ -914,7 +915,7 @@ static Value *NegateValue(Value *V, Instruction *BI, // Insert a 'neg' instruction that subtracts the value from zero to get the // negation. - BinaryOperator *NewNeg = CreateNeg(V, V->getName() + ".neg", BI, BI); + Instruction *NewNeg = CreateNeg(V, V->getName() + ".neg", BI, BI); ToRedo.insert(NewNeg); return NewNeg; } @@ -975,7 +976,8 @@ static BinaryOperator *BreakUpSubtract(Instruction *Sub, /// this into a multiply by a constant to assist with further reassociation. 
static BinaryOperator *ConvertShiftToMul(Instruction *Shl) { Constant *MulCst = ConstantInt::get(Shl->getType(), 1); - MulCst = ConstantExpr::getShl(MulCst, cast<Constant>(Shl->getOperand(1))); + auto *SA = cast<ConstantInt>(Shl->getOperand(1)); + MulCst = ConstantExpr::getShl(MulCst, SA); BinaryOperator *Mul = BinaryOperator::CreateMul(Shl->getOperand(0), MulCst, "", Shl); @@ -988,10 +990,12 @@ static BinaryOperator *ConvertShiftToMul(Instruction *Shl) { // We can safely preserve the nuw flag in all cases. It's also safe to turn a // nuw nsw shl into a nuw nsw mul. However, nsw in isolation requires special - // handling. + // handling. It can be preserved as long as we're not left shifting by + // bitwidth - 1. bool NSW = cast<BinaryOperator>(Shl)->hasNoSignedWrap(); bool NUW = cast<BinaryOperator>(Shl)->hasNoUnsignedWrap(); - if (NSW && NUW) + unsigned BitWidth = Shl->getType()->getIntegerBitWidth(); + if (NSW && (NUW || SA->getValue().ult(BitWidth - 1))) Mul->setHasNoSignedWrap(true); Mul->setHasNoUnsignedWrap(NUW); return Mul; @@ -1076,7 +1080,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { const APFloat &F1 = FC1->getValueAPF(); APFloat F2(FC2->getValueAPF()); F2.changeSign(); - if (F1.compare(F2) == APFloat::cmpEqual) { + if (F1 == F2) { FoundFactor = NeedsNegate = true; Factors.erase(Factors.begin() + i); break; @@ -1721,7 +1725,7 @@ static bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, } /// Build a tree of multiplies, computing the product of Ops. -static Value *buildMultiplyTree(IRBuilder<> &Builder, +static Value *buildMultiplyTree(IRBuilderBase &Builder, SmallVectorImpl<Value*> &Ops) { if (Ops.size() == 1) return Ops.back(); @@ -1744,7 +1748,7 @@ static Value *buildMultiplyTree(IRBuilder<> &Builder, /// DAG of multiplies to compute the final product, and return that product /// value. Value * -ReassociatePass::buildMinimalMultiplyDAG(IRBuilder<> &Builder, +ReassociatePass::buildMinimalMultiplyDAG(IRBuilderBase &Builder, SmallVectorImpl<Factor> &Factors) { assert(Factors[0].Power); SmallVector<Value *, 4> OuterProduct; @@ -1899,7 +1903,7 @@ void ReassociatePass::RecursivelyEraseDeadInsts(Instruction *I, ValueRankMap.erase(I); Insts.remove(I); RedoInsts.remove(I); - llvm::salvageDebugInfoOrMarkUndef(*I); + llvm::salvageDebugInfo(*I); I->eraseFromParent(); for (auto Op : Ops) if (Instruction *OpInst = dyn_cast<Instruction>(Op)) @@ -1916,7 +1920,7 @@ void ReassociatePass::EraseInst(Instruction *I) { // Erase the dead instruction. ValueRankMap.erase(I); RedoInsts.remove(I); - llvm::salvageDebugInfoOrMarkUndef(*I); + llvm::salvageDebugInfo(*I); I->eraseFromParent(); // Optimize its operands. SmallPtrSet<Instruction *, 8> Visited; // Detect self-referential nodes. 
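For context on the nsw handling added to ConvertShiftToMul above: a shift amount of bitwidth - 1 is excluded because the multiply constant then becomes the minimum signed value, and the multiply can signed-wrap where the original shift did not. A worked i8 counterexample as a small standalone C++ sketch (values chosen for illustration, not taken from the patch):

#include <cstdint>
#include <cstdio>

// For i8 the mul constant 1 << 7 is -128. "shl nsw i8 -1, 7" yields -128 with
// no signed wrap, but the equivalent multiply -1 * -128 = 128 does not fit in
// i8, so tagging the mul as nsw would introduce poison the shl never had.
int main() {
  const int8_t X = -1;
  const int8_t MulCst = static_cast<int8_t>(1u << 7);       // -128
  const int Shifted = static_cast<int8_t>(uint8_t(X) << 7); // -128, representable
  const int Product = int(X) * int(MulCst);                 // 128, would wrap in i8
  std::printf("shl result = %d, widened mul result = %d\n", Shifted, Product);
  return Product > INT8_MAX; // non-zero exit flags the would-be signed wrap
}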
@@ -2457,6 +2461,8 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { if (MadeChange) { PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); + PA.preserve<AAManager>(); + PA.preserve<BasicAA>(); PA.preserve<GlobalsAA>(); return PA; } @@ -2487,6 +2493,8 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<BasicAAWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } }; diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index b242f100faff..dc2ad14ae61e 100644 --- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -271,7 +271,7 @@ struct PartiallyConstructedSafepointRecord { /// The *new* gc.statepoint instruction itself. This produces the token /// that normal path gc.relocates and the gc.result are tied to. - Instruction *StatepointToken; + GCStatepointInst *StatepointToken; /// Instruction to which exceptional gc relocates are attached /// Makes it easier to iterate through them during relocationViaAlloca. @@ -381,14 +381,19 @@ static void analyzeParsePointLiveness( dbgs() << " " << V->getName() << " " << *V << "\n"; } if (PrintLiveSetSize) { - dbgs() << "Safepoint For: " << Call->getCalledValue()->getName() << "\n"; + dbgs() << "Safepoint For: " << Call->getCalledOperand()->getName() << "\n"; dbgs() << "Number live values: " << LiveSet.size() << "\n"; } Result.LiveSet = LiveSet; } +// Returns true is V is a knownBaseResult. static bool isKnownBaseResult(Value *V); +// Returns true if V is a BaseResult that already exists in the IR, i.e. it is +// not created by the findBasePointers algorithm. +static bool isOriginalBaseResult(Value *V); + namespace { /// A single base defining value - An immediate base defining value for an @@ -633,15 +638,20 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) { return Def; } +/// This value is a base pointer that is not generated by RS4GC, i.e. it already +/// exists in the code. +static bool isOriginalBaseResult(Value *V) { + // no recursion possible + return !isa<PHINode>(V) && !isa<SelectInst>(V) && + !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) && + !isa<ShuffleVectorInst>(V); +} + /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV, /// is it known to be a base pointer? Or do we need to continue searching. static bool isKnownBaseResult(Value *V) { - if (!isa<PHINode>(V) && !isa<SelectInst>(V) && - !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) && - !isa<ShuffleVectorInst>(V)) { - // no recursion possible + if (isOriginalBaseResult(V)) return true; - } if (isa<Instruction>(V) && cast<Instruction>(V)->getMetadata("is_base_value")) { // This is a previously inserted base phi or select. We know @@ -653,6 +663,12 @@ static bool isKnownBaseResult(Value *V) { return false; } +// Returns true if First and Second values are both scalar or both vector. 
+static bool areBothVectorOrScalar(Value *First, Value *Second) { + return isa<VectorType>(First->getType()) == + isa<VectorType>(Second->getType()); +} + namespace { /// Models the state of a single base defining value in the findBasePointer @@ -762,7 +778,7 @@ static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) { static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { Value *Def = findBaseOrBDV(I, Cache); - if (isKnownBaseResult(Def)) + if (isKnownBaseResult(Def) && areBothVectorOrScalar(Def, I)) return Def; // Here's the rough algorithm: @@ -810,13 +826,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { States.insert({Def, BDVState()}); while (!Worklist.empty()) { Value *Current = Worklist.pop_back_val(); - assert(!isKnownBaseResult(Current) && "why did it get added?"); + assert(!isOriginalBaseResult(Current) && "why did it get added?"); auto visitIncomingValue = [&](Value *InVal) { Value *Base = findBaseOrBDV(InVal, Cache); - if (isKnownBaseResult(Base)) + if (isKnownBaseResult(Base) && areBothVectorOrScalar(Base, InVal)) // Known bases won't need new instructions introduced and can be - // ignored safely + // ignored safely. However, this can only be done when InVal and Base + // are both scalar or both vector. Otherwise, we need to find a + // correct BDV for InVal, by creating an entry in the lattice + // (States). return; assert(isExpectedBDVType(Base) && "the only non-base values " "we see should be base defining values"); @@ -853,10 +872,10 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // Return a phi state for a base defining value. We'll generate a new // base state for known bases and expect to find a cached state otherwise. - auto getStateForBDV = [&](Value *baseValue) { - if (isKnownBaseResult(baseValue)) - return BDVState(baseValue); - auto I = States.find(baseValue); + auto GetStateForBDV = [&](Value *BaseValue, Value *Input) { + if (isKnownBaseResult(BaseValue) && areBothVectorOrScalar(BaseValue, Input)) + return BDVState(BaseValue); + auto I = States.find(BaseValue); assert(I != States.end() && "lookup failed!"); return I->second; }; @@ -873,13 +892,18 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // much faster. for (auto Pair : States) { Value *BDV = Pair.first; - assert(!isKnownBaseResult(BDV) && "why did it get added?"); + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. + assert((!isKnownBaseResult(BDV) || + !areBothVectorOrScalar(BDV, Pair.second.getBaseValue())) && + "why did it get added?"); // Given an input value for the current instruction, return a BDVState // instance which represents the BDV of that value. auto getStateForInput = [&](Value *V) mutable { Value *BDV = findBaseOrBDV(V, Cache); - return getStateForBDV(BDV); + return GetStateForBDV(BDV, V); }; BDVState NewState; @@ -926,20 +950,26 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { } #endif - // Insert Phis for all conflicts - // TODO: adjust naming patterns to avoid this order of iteration dependency + // Handle all instructions that have a vector BDV, but the instruction itself + // is of scalar type. 
for (auto Pair : States) { Instruction *I = cast<Instruction>(Pair.first); BDVState State = Pair.second; - assert(!isKnownBaseResult(I) && "why did it get added?"); + auto *BaseValue = State.getBaseValue(); + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. + assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, BaseValue)) && + "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); + if (!State.isBase() || !isa<VectorType>(BaseValue->getType())) + continue; // extractelement instructions are a bit special in that we may need to // insert an extract even when we know an exact base for the instruction. // The problem is that we need to convert from a vector base to a scalar // base for the particular indice we're interested in. - if (State.isBase() && isa<ExtractElementInst>(I) && - isa<VectorType>(State.getBaseValue()->getType())) { + if (isa<ExtractElementInst>(I)) { auto *EE = cast<ExtractElementInst>(I); // TODO: In many cases, the new instruction is just EE itself. We should // exploit this, but can't do it here since it would break the invariant @@ -948,7 +978,27 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); States[I] = BDVState(BDVState::Base, BaseInst); + } else if (!isa<VectorType>(I->getType())) { + // We need to handle cases that have a vector base but the instruction is + // a scalar type (these could be phis or selects or any instruction that + // are of scalar type, but the base can be a vector type). We + // conservatively set this as conflict. Setting the base value for these + // conflicts is handled in the next loop which traverses States. + States[I] = BDVState(BDVState::Conflict); } + } + + // Insert Phis for all conflicts + // TODO: adjust naming patterns to avoid this order of iteration dependency + for (auto Pair : States) { + Instruction *I = cast<Instruction>(Pair.first); + BDVState State = Pair.second; + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. + assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, State.getBaseValue())) && + "why did it get added?"); + assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); // Since we're joining a vector and scalar base, they can never be the // same. As a result, we should always see insert element having reached @@ -987,7 +1037,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { auto *SV = cast<ShuffleVectorInst>(I); UndefValue *VecUndef = UndefValue::get(SV->getOperand(0)->getType()); std::string Name = suffixed_name_or(I, ".base", "base_sv"); - return new ShuffleVectorInst(VecUndef, VecUndef, SV->getOperand(2), + return new ShuffleVectorInst(VecUndef, VecUndef, SV->getShuffleMask(), Name, SV); } }; @@ -1008,7 +1058,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) { Value *BDV = findBaseOrBDV(Input, Cache); Value *Base = nullptr; - if (isKnownBaseResult(BDV)) { + if (isKnownBaseResult(BDV) && areBothVectorOrScalar(BDV, Input)) { Base = BDV; } else { // Either conflict or base. 
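The RewriteStatepointsForGC changes above extend the base-pointer lattice to derived values whose base disagrees on vector-vs-scalar shape; for an extractelement the pass materializes a scalar base by extracting the same lane from the vector base. A minimal standalone sketch of that lane conversion, mirroring the "base_ee" insertion visible in the hunk (the function name is illustrative, not part of the patch):

#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"

using namespace llvm;

// Illustrative sketch: given a known vector base and a derived scalar obtained
// via extractelement, build a scalar base by extracting the matching lane and
// tag it the way the pass marks materialized bases.
static Instruction *scalarizeVectorBase(Value *VectorBase,
                                        ExtractElementInst *DerivedEE) {
  auto *BaseEE = ExtractElementInst::Create(
      VectorBase, DerivedEE->getIndexOperand(), "base_ee", DerivedEE);
  BaseEE->setMetadata("is_base_value",
                      MDNode::get(DerivedEE->getContext(), {}));
  return BaseEE;
}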
@@ -1029,7 +1079,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { Instruction *BDV = cast<Instruction>(Pair.first); BDVState State = Pair.second; - assert(!isKnownBaseResult(BDV) && "why did it get added?"); + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. + assert((!isKnownBaseResult(BDV) || + !areBothVectorOrScalar(BDV, State.getBaseValue())) && + "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); if (!State.isConflict()) continue; @@ -1119,7 +1174,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { auto *BDV = Pair.first; Value *Base = Pair.second.getBaseValue(); assert(BDV && Base); - assert(!isKnownBaseResult(BDV) && "why did it get added?"); + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. + assert((!isKnownBaseResult(BDV) || !areBothVectorOrScalar(BDV, Base)) && + "why did it get added?"); LLVM_DEBUG( dbgs() << "Updating base value cache" @@ -1238,7 +1297,8 @@ normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent, // Create new attribute set containing only attributes which can be transferred // from original call to the safepoint. -static AttributeList legalizeCallAttributes(AttributeList AL) { +static AttributeList legalizeCallAttributes(LLVMContext &Ctx, + AttributeList AL) { if (AL.isEmpty()) return AL; @@ -1252,7 +1312,6 @@ static AttributeList legalizeCallAttributes(AttributeList AL) { } // Just skip parameter and return attributes for now - LLVMContext &Ctx = AL.getContext(); return AttributeList::get(Ctx, AttributeList::FunctionIndex, AttributeSet::get(Ctx, FnAttrs)); } @@ -1261,16 +1320,14 @@ static AttributeList legalizeCallAttributes(AttributeList AL) { /// statepoint. /// Inputs: /// liveVariables - list of variables to be relocated. -/// liveStart - index of the first live variable. /// basePtrs - base pointers. /// statepointToken - statepoint instruction to which relocates should be /// bound. /// Builder - Llvm IR builder to be used to construct new calls. 
static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, - const int LiveStart, ArrayRef<Value *> BasePtrs, Instruction *StatepointToken, - IRBuilder<> Builder) { + IRBuilder<> &Builder) { if (LiveVariables.empty()) return; @@ -1295,7 +1352,8 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, auto AS = Ty->getScalarType()->getPointerAddressSpace(); Type *NewTy = Type::getInt8PtrTy(M->getContext(), AS); if (auto *VT = dyn_cast<VectorType>(Ty)) - NewTy = VectorType::get(NewTy, VT->getNumElements()); + NewTy = FixedVectorType::get(NewTy, + cast<FixedVectorType>(VT)->getNumElements()); return Intrinsic::getDeclaration(M, Intrinsic::experimental_gc_relocate, {NewTy}); }; @@ -1307,9 +1365,8 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, for (unsigned i = 0; i < LiveVariables.size(); i++) { // Generate the gc.relocate call and save the result - Value *BaseIdx = - Builder.getInt32(LiveStart + FindIndex(LiveVariables, BasePtrs[i])); - Value *LiveIdx = Builder.getInt32(LiveStart + i); + Value *BaseIdx = Builder.getInt32(FindIndex(LiveVariables, BasePtrs[i])); + Value *LiveIdx = Builder.getInt32(i); Type *Ty = LiveVariables[i]->getType(); if (!TypeToDeclMap.count(Ty)) @@ -1431,12 +1488,14 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ uint32_t Flags = uint32_t(StatepointFlags::None); ArrayRef<Use> CallArgs(Call->arg_begin(), Call->arg_end()); - ArrayRef<Use> DeoptArgs = GetDeoptBundleOperands(Call); - ArrayRef<Use> TransitionArgs; - if (auto TransitionBundle = - Call->getOperandBundle(LLVMContext::OB_gc_transition)) { + Optional<ArrayRef<Use>> DeoptArgs; + if (auto Bundle = Call->getOperandBundle(LLVMContext::OB_deopt)) + DeoptArgs = Bundle->Inputs; + Optional<ArrayRef<Use>> TransitionArgs; + if (auto Bundle = Call->getOperandBundle(LLVMContext::OB_gc_transition)) { + TransitionArgs = Bundle->Inputs; + // TODO: This flag no longer serves a purpose and can be removed later Flags |= uint32_t(StatepointFlags::GCTransition); - TransitionArgs = TransitionBundle->Inputs; } // Instead of lowering calls to @llvm.experimental.deoptimize as normal calls @@ -1459,7 +1518,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ assert(DeoptLowering.equals("live-through") && "Unsupported value!"); } - Value *CallTarget = Call->getCalledValue(); + Value *CallTarget = Call->getCalledOperand(); if (Function *F = dyn_cast<Function>(CallTarget)) { if (F->getIntrinsicID() == Intrinsic::experimental_deoptimize) { // Calls to llvm.experimental.deoptimize are lowered to calls to the @@ -1485,7 +1544,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ } // Create the statepoint given all the arguments - Instruction *Token = nullptr; + GCStatepointInst *Token = nullptr; if (auto *CI = dyn_cast<CallInst>(Call)) { CallInst *SPCall = Builder.CreateGCStatepointCall( StatepointID, NumPatchBytes, CallTarget, Flags, CallArgs, @@ -1498,9 +1557,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ // function attributes. In case if we can handle this set of attributes - // set up function attrs directly on statepoint and return attrs later for // gc_result intrinsic. 
- SPCall->setAttributes(legalizeCallAttributes(CI->getAttributes())); + SPCall->setAttributes( + legalizeCallAttributes(CI->getContext(), CI->getAttributes())); - Token = SPCall; + Token = cast<GCStatepointInst>(SPCall); // Put the following gc_result and gc_relocate calls immediately after the // the old call (which we're about to delete) @@ -1524,9 +1584,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ // function attributes. In case if we can handle this set of attributes - // set up function attrs directly on statepoint and return attrs later for // gc_result intrinsic. - SPInvoke->setAttributes(legalizeCallAttributes(II->getAttributes())); + SPInvoke->setAttributes( + legalizeCallAttributes(II->getContext(), II->getAttributes())); - Token = SPInvoke; + Token = cast<GCStatepointInst>(SPInvoke); // Generate gc relocates in exceptional path BasicBlock *UnwindBlock = II->getUnwindDest(); @@ -1541,9 +1602,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ Instruction *ExceptionalToken = UnwindBlock->getLandingPadInst(); Result.UnwindToken = ExceptionalToken; - const unsigned LiveStartIdx = Statepoint(Token).gcArgsStartIdx(); - CreateGCRelocates(LiveVariables, LiveStartIdx, BasePtrs, ExceptionalToken, - Builder); + CreateGCRelocates(LiveVariables, BasePtrs, ExceptionalToken, Builder); // Generate gc relocates and returns for normal block BasicBlock *NormalDest = II->getNormalDest(); @@ -1589,8 +1648,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ Result.StatepointToken = Token; // Second, create a gc.relocate for every live variable - const unsigned LiveStartIdx = Statepoint(Token).gcArgsStartIdx(); - CreateGCRelocates(LiveVariables, LiveStartIdx, BasePtrs, Token, Builder); + CreateGCRelocates(LiveVariables, BasePtrs, Token, Builder); } // Replace an existing gc.statepoint with a new one and a set of gc.relocates @@ -1651,8 +1709,8 @@ insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs, cast<AllocaInst>(Alloca)->getAllocatedType(), suffixed_name_or(Relocate, ".casted", "")); - StoreInst *Store = new StoreInst(CastedRelocatedValue, Alloca); - Store->insertAfter(cast<Instruction>(CastedRelocatedValue)); + new StoreInst(CastedRelocatedValue, Alloca, + cast<Instruction>(CastedRelocatedValue)->getNextNode()); #ifndef NDEBUG VisitedLiveValues.insert(OriginalValue); @@ -1674,8 +1732,8 @@ static void insertRematerializationStores( "Can not find alloca for rematerialized value"); Value *Alloca = AllocaMap[OriginalValue]; - StoreInst *Store = new StoreInst(RematerializedValue, Alloca); - Store->insertAfter(RematerializedValue); + new StoreInst(RematerializedValue, Alloca, + RematerializedValue->getNextNode()); #ifndef NDEBUG VisitedLiveValues.insert(OriginalValue); @@ -1780,8 +1838,7 @@ static void relocationViaAlloca( for (auto *AI : ToClobber) { auto PT = cast<PointerType>(AI->getAllocatedType()); Constant *CPN = ConstantPointerNull::get(PT); - StoreInst *Store = new StoreInst(CPN, AI); - Store->insertBefore(IP); + new StoreInst(CPN, AI, IP); } }; @@ -1843,7 +1900,8 @@ static void relocationViaAlloca( // Emit store for the initial gc value. Store must be inserted after load, // otherwise store will be in alloca's use list and an extra load will be // inserted before it. 
- StoreInst *Store = new StoreInst(Def, Alloca); + StoreInst *Store = new StoreInst(Def, Alloca, /*volatile*/ false, + DL.getABITypeAlign(Def->getType())); if (Instruction *Inst = dyn_cast<Instruction>(Def)) { if (InvokeInst *Invoke = dyn_cast<InvokeInst>(Inst)) { // InvokeInst is a terminator so the store need to be inserted into its @@ -1966,7 +2024,9 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain, "non noop cast is found during rematerialization"); Type *SrcTy = CI->getOperand(0)->getType(); - Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy, CI); + Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy, + TargetTransformInfo::TCK_SizeAndLatency, + CI); } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) { // Cost of the address calculation @@ -2344,9 +2404,8 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // That Value* no longer exists and we need to use the new gc_result. // Thankfully, the live set is embedded in the statepoint (and updated), so // we just grab that. - Statepoint Statepoint(Info.StatepointToken); - Live.insert(Live.end(), Statepoint.gc_args_begin(), - Statepoint.gc_args_end()); + Live.insert(Live.end(), Info.StatepointToken->gc_args_begin(), + Info.StatepointToken->gc_args_end()); #ifndef NDEBUG // Do some basic sanity checks on our liveness results before performing // relocation. Relocation can and will turn mistakes in liveness results @@ -2354,7 +2413,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // TODO: It would be nice to test consistency as well assert(DT.isReachableFromEntry(Info.StatepointToken->getParent()) && "statepoint must be reachable or liveness is meaningless"); - for (Value *V : Statepoint.gc_args()) { + for (Value *V : Info.StatepointToken->gc_args()) { if (!isa<Instruction>(V)) // Non-instruction values trivial dominate all possible uses continue; @@ -2523,7 +2582,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, auto NeedsRewrite = [&TLI](Instruction &I) { if (const auto *Call = dyn_cast<CallBase>(&I)) - return !callsGCLeafFunction(Call, TLI) && !isStatepoint(Call); + return !callsGCLeafFunction(Call, TLI) && !isa<GCStatepointInst>(Call); return false; }; @@ -2608,10 +2667,10 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, unsigned VF = 0; for (unsigned i = 0; i < I.getNumOperands(); i++) - if (I.getOperand(i)->getType()->isVectorTy()) { + if (auto *OpndVTy = dyn_cast<VectorType>(I.getOperand(i)->getType())) { assert(VF == 0 || - VF == I.getOperand(i)->getType()->getVectorNumElements()); - VF = I.getOperand(i)->getType()->getVectorNumElements(); + VF == cast<FixedVectorType>(OpndVTy)->getNumElements()); + VF = cast<FixedVectorType>(OpndVTy)->getNumElements(); } // It's the vector to scalar traversal through the pointer operand which diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp index e696ea83a300..5ebd3b71fe78 100644 --- a/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -27,12 +27,13 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueLattice.h" #include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/IR/BasicBlock.h" -#include 
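The last RewriteStatepointsForGC hunk above walks an instruction's operands and insists that every fixed-vector operand share a single element count (VF), now going through FixedVectorType rather than querying the element count on the raw type. A small standalone sketch of that consistency check, where 0 stands for a scalar operand and all names are illustrative:

#include <cassert>
#include <vector>

// Returns the common vector factor of all vector operands, or 0 if every
// operand is scalar. Mixed vector widths are rejected, mirroring the assert
// in the hunk above.
static unsigned commonVectorFactor(const std::vector<unsigned> &OperandVFs) {
  unsigned VF = 0;
  for (unsigned OpVF : OperandVFs) {
    if (OpVF == 0)
      continue; // scalar operand, no constraint
    assert((VF == 0 || VF == OpVF) && "mixed vector widths not expected");
    VF = OpVF;
  }
  return VF;
}

int main() {
  std::vector<unsigned> Ops = {0, 4, 4}; // scalar, <4 x ...>, <4 x ...>
  return commonVectorFactor(Ops) == 4 ? 0 : 1;
}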
"llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -67,123 +68,44 @@ using namespace llvm; STATISTIC(NumInstRemoved, "Number of instructions removed"); STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable"); +STATISTIC(NumInstReplaced, + "Number of instructions replaced with (simpler) instruction"); STATISTIC(IPNumInstRemoved, "Number of instructions removed by IPSCCP"); STATISTIC(IPNumArgsElimed ,"Number of arguments constant propagated by IPSCCP"); STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP"); - +STATISTIC( + IPNumInstReplaced, + "Number of instructions replaced with (simpler) instruction by IPSCCP"); + +// The maximum number of range extensions allowed for operations requiring +// widening. +static const unsigned MaxNumRangeExtensions = 10; + +/// Returns MergeOptions with MaxWidenSteps set to MaxNumRangeExtensions. +static ValueLatticeElement::MergeOptions getMaxWidenStepsOpts() { + return ValueLatticeElement::MergeOptions().setMaxWidenSteps( + MaxNumRangeExtensions); +} namespace { -/// LatticeVal class - This class represents the different lattice values that -/// an LLVM value may occupy. It is a simple class with value semantics. -/// -class LatticeVal { - enum LatticeValueTy { - /// unknown - This LLVM Value has no known value yet. - unknown, - - /// constant - This LLVM Value has a specific constant value. - constant, - - /// forcedconstant - This LLVM Value was thought to be undef until - /// ResolvedUndefsIn. This is treated just like 'constant', but if merged - /// with another (different) constant, it goes to overdefined, instead of - /// asserting. - forcedconstant, - - /// overdefined - This instruction is not known to be constant, and we know - /// it has a value. - overdefined - }; - - /// Val: This stores the current lattice value along with the Constant* for - /// the constant if this is a 'constant' or 'forcedconstant' value. - PointerIntPair<Constant *, 2, LatticeValueTy> Val; - - LatticeValueTy getLatticeValue() const { - return Val.getInt(); - } - -public: - LatticeVal() : Val(nullptr, unknown) {} - - bool isUnknown() const { return getLatticeValue() == unknown; } - - bool isConstant() const { - return getLatticeValue() == constant || getLatticeValue() == forcedconstant; - } - - bool isOverdefined() const { return getLatticeValue() == overdefined; } - - Constant *getConstant() const { - assert(isConstant() && "Cannot get the constant of a non-constant!"); - return Val.getPointer(); - } - - /// markOverdefined - Return true if this is a change in status. - bool markOverdefined() { - if (isOverdefined()) - return false; - - Val.setInt(overdefined); - return true; - } - - /// markConstant - Return true if this is a change in status. - bool markConstant(Constant *V) { - if (getLatticeValue() == constant) { // Constant but not forcedconstant. - assert(getConstant() == V && "Marking constant with different value"); - return false; - } - - if (isUnknown()) { - Val.setInt(constant); - assert(V && "Marking constant with NULL"); - Val.setPointer(V); - } else { - assert(getLatticeValue() == forcedconstant && - "Cannot move from overdefined to constant!"); - // Stay at forcedconstant if the constant is the same. - if (V == getConstant()) return false; - - // Otherwise, we go to overdefined. Assumptions made based on the - // forced value are possibly wrong. Assuming this is another constant - // could expose a contradiction. 
- Val.setInt(overdefined); - } - return true; - } - - /// getConstantInt - If this is a constant with a ConstantInt value, return it - /// otherwise return null. - ConstantInt *getConstantInt() const { - if (isConstant()) - return dyn_cast<ConstantInt>(getConstant()); - return nullptr; - } - - /// getBlockAddress - If this is a constant with a BlockAddress value, return - /// it, otherwise return null. - BlockAddress *getBlockAddress() const { - if (isConstant()) - return dyn_cast<BlockAddress>(getConstant()); - return nullptr; - } - - void markForcedConstant(Constant *V) { - assert(isUnknown() && "Can't force a defined value!"); - Val.setInt(forcedconstant); - Val.setPointer(V); - } +// Helper to check if \p LV is either a constant or a constant +// range with a single element. This should cover exactly the same cases as the +// old ValueLatticeElement::isConstant() and is intended to be used in the +// transition to ValueLatticeElement. +bool isConstant(const ValueLatticeElement &LV) { + return LV.isConstant() || + (LV.isConstantRange() && LV.getConstantRange().isSingleElement()); +} - ValueLatticeElement toValueLattice() const { - if (isOverdefined()) - return ValueLatticeElement::getOverdefined(); - if (isConstant()) - return ValueLatticeElement::get(getConstant()); - return ValueLatticeElement(); - } -}; +// Helper to check if \p LV is either overdefined or a constant range with more +// than a single element. This should cover exactly the same cases as the old +// ValueLatticeElement::isOverdefined() and is intended to be used in the +// transition to ValueLatticeElement. +bool isOverdefined(const ValueLatticeElement &LV) { + return LV.isOverdefined() || + (LV.isConstantRange() && !LV.getConstantRange().isSingleElement()); +} //===----------------------------------------------------------------------===// // @@ -194,28 +116,28 @@ class SCCPSolver : public InstVisitor<SCCPSolver> { const DataLayout &DL; std::function<const TargetLibraryInfo &(Function &)> GetTLI; SmallPtrSet<BasicBlock *, 8> BBExecutable; // The BBs that are executable. - DenseMap<Value *, LatticeVal> ValueState; // The state each value is in. - // The state each parameter is in. - DenseMap<Value *, ValueLatticeElement> ParamState; + DenseMap<Value *, ValueLatticeElement> + ValueState; // The state each value is in. /// StructValueState - This maintains ValueState for values that have /// StructType, for example for formal arguments, calls, insertelement, etc. - DenseMap<std::pair<Value *, unsigned>, LatticeVal> StructValueState; + DenseMap<std::pair<Value *, unsigned>, ValueLatticeElement> StructValueState; /// GlobalValue - If we are tracking any values for the contents of a global /// variable, we keep a mapping from the constant accessor to the element of /// the global, to the currently known value. If the value becomes /// overdefined, it's entry is simply removed from this map. - DenseMap<GlobalVariable *, LatticeVal> TrackedGlobals; + DenseMap<GlobalVariable *, ValueLatticeElement> TrackedGlobals; /// TrackedRetVals - If we are tracking arguments into and the return /// value out of a function, it will have an entry in this map, indicating /// what the known return value for the function is. - MapVector<Function *, LatticeVal> TrackedRetVals; + MapVector<Function *, ValueLatticeElement> TrackedRetVals; /// TrackedMultipleRetVals - Same as TrackedRetVals, but used for functions /// that return multiple values. 
- MapVector<std::pair<Function *, unsigned>, LatticeVal> TrackedMultipleRetVals; + MapVector<std::pair<Function *, unsigned>, ValueLatticeElement> + TrackedMultipleRetVals; /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is /// represented here for efficient lookup. @@ -251,6 +173,8 @@ class SCCPSolver : public InstVisitor<SCCPSolver> { DenseMap<Function *, AnalysisResultsForFn> AnalysisResults; DenseMap<Value *, SmallPtrSet<User *, 2>> AdditionalUsers; + LLVMContext &Ctx; + public: void addAnalysis(Function &F, AnalysisResultsForFn A) { AnalysisResults.insert({&F, std::move(A)}); @@ -270,8 +194,9 @@ public: } SCCPSolver(const DataLayout &DL, - std::function<const TargetLibraryInfo &(Function &)> GetTLI) - : DL(DL), GetTLI(std::move(GetTLI)) {} + std::function<const TargetLibraryInfo &(Function &)> GetTLI, + LLVMContext &Ctx) + : DL(DL), GetTLI(std::move(GetTLI)), Ctx(Ctx) {} /// MarkBlockExecutable - This method can be used by clients to mark all of /// the blocks that are known to be intrinsically live in the processed unit. @@ -292,7 +217,7 @@ public: void TrackValueOfGlobalVariable(GlobalVariable *GV) { // We only track the contents of scalar globals. if (GV->getValueType()->isSingleValueType()) { - LatticeVal &IV = TrackedGlobals[GV]; + ValueLatticeElement &IV = TrackedGlobals[GV]; if (!isa<UndefValue>(GV->getInitializer())) IV.markConstant(GV->getInitializer()); } @@ -306,10 +231,10 @@ public: if (auto *STy = dyn_cast<StructType>(F->getReturnType())) { MRVFunctionsTracked.insert(F); for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) - TrackedMultipleRetVals.insert(std::make_pair(std::make_pair(F, i), - LatticeVal())); + TrackedMultipleRetVals.insert( + std::make_pair(std::make_pair(F, i), ValueLatticeElement())); } else - TrackedRetVals.insert(std::make_pair(F, LatticeVal())); + TrackedRetVals.insert(std::make_pair(F, ValueLatticeElement())); } /// AddMustTailCallee - If the SCCP solver finds that this function is called @@ -352,8 +277,8 @@ public: // block to the 'To' basic block is currently feasible. bool isEdgeFeasible(BasicBlock *From, BasicBlock *To); - std::vector<LatticeVal> getStructLatticeValueFor(Value *V) const { - std::vector<LatticeVal> StructValues; + std::vector<ValueLatticeElement> getStructLatticeValueFor(Value *V) const { + std::vector<ValueLatticeElement> StructValues; auto *STy = dyn_cast<StructType>(V->getType()); assert(STy && "getStructLatticeValueFor() can be called only on structs"); for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { @@ -364,23 +289,26 @@ public: return StructValues; } - const LatticeVal &getLatticeValueFor(Value *V) const { + void removeLatticeValueFor(Value *V) { ValueState.erase(V); } + + const ValueLatticeElement &getLatticeValueFor(Value *V) const { assert(!V->getType()->isStructTy() && "Should use getStructLatticeValueFor"); - DenseMap<Value *, LatticeVal>::const_iterator I = ValueState.find(V); + DenseMap<Value *, ValueLatticeElement>::const_iterator I = + ValueState.find(V); assert(I != ValueState.end() && "V not found in ValueState nor Paramstate map!"); return I->second; } /// getTrackedRetVals - Get the inferred return value map. - const MapVector<Function*, LatticeVal> &getTrackedRetVals() { + const MapVector<Function *, ValueLatticeElement> &getTrackedRetVals() { return TrackedRetVals; } /// getTrackedGlobals - Get and return the set of inferred initializers for /// global variables. 
- const DenseMap<GlobalVariable*, LatticeVal> &getTrackedGlobals() { + const DenseMap<GlobalVariable *, ValueLatticeElement> &getTrackedGlobals() { return TrackedGlobals; } @@ -407,32 +335,59 @@ public: } // isStructLatticeConstant - Return true if all the lattice values - // corresponding to elements of the structure are not overdefined, + // corresponding to elements of the structure are constants, // false otherwise. bool isStructLatticeConstant(Function *F, StructType *STy) { for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { const auto &It = TrackedMultipleRetVals.find(std::make_pair(F, i)); assert(It != TrackedMultipleRetVals.end()); - LatticeVal LV = It->second; - if (LV.isOverdefined()) + ValueLatticeElement LV = It->second; + if (!isConstant(LV)) return false; } return true; } + /// Helper to return a Constant if \p LV is either a constant or a constant + /// range with a single element. + Constant *getConstant(const ValueLatticeElement &LV) const { + if (LV.isConstant()) + return LV.getConstant(); + + if (LV.isConstantRange()) { + auto &CR = LV.getConstantRange(); + if (CR.getSingleElement()) + return ConstantInt::get(Ctx, *CR.getSingleElement()); + } + return nullptr; + } + private: - // pushToWorkList - Helper for markConstant/markForcedConstant/markOverdefined - void pushToWorkList(LatticeVal &IV, Value *V) { + ConstantInt *getConstantInt(const ValueLatticeElement &IV) const { + return dyn_cast_or_null<ConstantInt>(getConstant(IV)); + } + + // pushToWorkList - Helper for markConstant/markOverdefined + void pushToWorkList(ValueLatticeElement &IV, Value *V) { if (IV.isOverdefined()) return OverdefinedInstWorkList.push_back(V); InstWorkList.push_back(V); } + // Helper to push \p V to the worklist, after updating it to \p IV. Also + // prints a debug message with the updated value. + void pushToWorkListMsg(ValueLatticeElement &IV, Value *V) { + LLVM_DEBUG(dbgs() << "updated " << IV << ": " << *V << '\n'); + pushToWorkList(IV, V); + } + // markConstant - Make a value be marked as "constant". If the value // is not already a constant, add it to the instruction work list so that // the users of the instruction are updated later. - bool markConstant(LatticeVal &IV, Value *V, Constant *C) { - if (!IV.markConstant(C)) return false; + bool markConstant(ValueLatticeElement &IV, Value *V, Constant *C, + bool MayIncludeUndef = false) { + if (!IV.markConstant(C, MayIncludeUndef)) + return false; LLVM_DEBUG(dbgs() << "markConstant: " << *C << ": " << *V << '\n'); pushToWorkList(IV, V); return true; @@ -443,18 +398,10 @@ private: return markConstant(ValueState[V], V, C); } - void markForcedConstant(Value *V, Constant *C) { - assert(!V->getType()->isStructTy() && "structs should use mergeInValue"); - LatticeVal &IV = ValueState[V]; - IV.markForcedConstant(C); - LLVM_DEBUG(dbgs() << "markForcedConstant: " << *C << ": " << *V << '\n'); - pushToWorkList(IV, V); - } - // markOverdefined - Make a value be marked as "overdefined". If the // value is not already overdefined, add it to the overdefined instruction // work list so that the users of the instruction are updated later. - bool markOverdefined(LatticeVal &IV, Value *V) { + bool markOverdefined(ValueLatticeElement &IV, Value *V) { if (!IV.markOverdefined()) return false; LLVM_DEBUG(dbgs() << "markOverdefined: "; @@ -466,71 +413,59 @@ private: return true; } - bool mergeInValue(LatticeVal &IV, Value *V, LatticeVal MergeWithV) { - if (IV.isOverdefined() || MergeWithV.isUnknown()) - return false; // Noop. 
- if (MergeWithV.isOverdefined()) - return markOverdefined(IV, V); - if (IV.isUnknown()) - return markConstant(IV, V, MergeWithV.getConstant()); - if (IV.getConstant() != MergeWithV.getConstant()) - return markOverdefined(IV, V); + /// Merge \p MergeWithV into \p IV and push \p V to the worklist, if \p IV + /// changes. + bool mergeInValue(ValueLatticeElement &IV, Value *V, + ValueLatticeElement MergeWithV, + ValueLatticeElement::MergeOptions Opts = { + /*MayIncludeUndef=*/false, /*CheckWiden=*/false}) { + if (IV.mergeIn(MergeWithV, Opts)) { + pushToWorkList(IV, V); + LLVM_DEBUG(dbgs() << "Merged " << MergeWithV << " into " << *V << " : " + << IV << "\n"); + return true; + } return false; } - bool mergeInValue(Value *V, LatticeVal MergeWithV) { + bool mergeInValue(Value *V, ValueLatticeElement MergeWithV, + ValueLatticeElement::MergeOptions Opts = { + /*MayIncludeUndef=*/false, /*CheckWiden=*/false}) { assert(!V->getType()->isStructTy() && "non-structs should use markConstant"); - return mergeInValue(ValueState[V], V, MergeWithV); + return mergeInValue(ValueState[V], V, MergeWithV, Opts); } - /// getValueState - Return the LatticeVal object that corresponds to the - /// value. This function handles the case when the value hasn't been seen yet - /// by properly seeding constants etc. - LatticeVal &getValueState(Value *V) { + /// getValueState - Return the ValueLatticeElement object that corresponds to + /// the value. This function handles the case when the value hasn't been seen + /// yet by properly seeding constants etc. + ValueLatticeElement &getValueState(Value *V) { assert(!V->getType()->isStructTy() && "Should use getStructValueState"); - std::pair<DenseMap<Value*, LatticeVal>::iterator, bool> I = - ValueState.insert(std::make_pair(V, LatticeVal())); - LatticeVal &LV = I.first->second; + auto I = ValueState.insert(std::make_pair(V, ValueLatticeElement())); + ValueLatticeElement &LV = I.first->second; if (!I.second) return LV; // Common case, already in the map. - if (auto *C = dyn_cast<Constant>(V)) { - // Undef values remain unknown. - if (!isa<UndefValue>(V)) - LV.markConstant(C); // Constants are constant - } - - // All others are underdefined by default. - return LV; - } - - ValueLatticeElement &getParamState(Value *V) { - assert(!V->getType()->isStructTy() && "Should use getStructValueState"); - - std::pair<DenseMap<Value*, ValueLatticeElement>::iterator, bool> - PI = ParamState.insert(std::make_pair(V, ValueLatticeElement())); - ValueLatticeElement &LV = PI.first->second; - if (PI.second) - LV = getValueState(V).toValueLattice(); + if (auto *C = dyn_cast<Constant>(V)) + LV.markConstant(C); // Constants are constant + // All others are unknown by default. return LV; } - /// getStructValueState - Return the LatticeVal object that corresponds to the - /// value/field pair. This function handles the case when the value hasn't - /// been seen yet by properly seeding constants etc. - LatticeVal &getStructValueState(Value *V, unsigned i) { + /// getStructValueState - Return the ValueLatticeElement object that + /// corresponds to the value/field pair. This function handles the case when + /// the value hasn't been seen yet by properly seeding constants etc. 
+ ValueLatticeElement &getStructValueState(Value *V, unsigned i) { assert(V->getType()->isStructTy() && "Should use getValueState"); assert(i < cast<StructType>(V->getType())->getNumElements() && "Invalid element #"); - std::pair<DenseMap<std::pair<Value*, unsigned>, LatticeVal>::iterator, - bool> I = StructValueState.insert( - std::make_pair(std::make_pair(V, i), LatticeVal())); - LatticeVal &LV = I.first->second; + auto I = StructValueState.insert( + std::make_pair(std::make_pair(V, i), ValueLatticeElement())); + ValueLatticeElement &LV = I.first->second; if (!I.second) return LV; // Common case, already in the map. @@ -589,9 +524,20 @@ private: // Mark I's users as changed, including AdditionalUsers. void markUsersAsChanged(Value *I) { - for (User *U : I->users()) - if (auto *UI = dyn_cast<Instruction>(U)) - OperandChangedState(UI); + // Functions include their arguments in the use-list. Changed function + // values mean that the result of the function changed. We only need to + // update the call sites with the new function result and do not have to + // propagate the call arguments. + if (isa<Function>(I)) { + for (User *U : I->users()) { + if (auto *CB = dyn_cast<CallBase>(U)) + handleCallResult(*CB); + } + } else { + for (User *U : I->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + OperandChangedState(UI); + } auto Iter = AdditionalUsers.find(I); if (Iter != AdditionalUsers.end()) { @@ -600,6 +546,9 @@ private: OperandChangedState(UI); } } + void handleCallOverdefined(CallBase &CB); + void handleCallResult(CallBase &CB); + void handleCallArguments(CallBase &CB); private: friend class InstVisitor<SCCPSolver>; @@ -634,20 +583,20 @@ private: void visitGetElementPtrInst(GetElementPtrInst &I); void visitCallInst (CallInst &I) { - visitCallSite(&I); + visitCallBase(I); } void visitInvokeInst (InvokeInst &II) { - visitCallSite(&II); + visitCallBase(II); visitTerminator(II); } void visitCallBrInst (CallBrInst &CBI) { - visitCallSite(&CBI); + visitCallBase(CBI); visitTerminator(CBI); } - void visitCallSite (CallSite CS); + void visitCallBase (CallBase &CB); void visitResumeInst (ResumeInst &I) { /*returns void*/ } void visitUnreachableInst(UnreachableInst &I) { /*returns void*/ } void visitFenceInst (FenceInst &I) { /*returns void*/ } @@ -673,12 +622,12 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI, return; } - LatticeVal BCValue = getValueState(BI->getCondition()); - ConstantInt *CI = BCValue.getConstantInt(); + ValueLatticeElement BCValue = getValueState(BI->getCondition()); + ConstantInt *CI = getConstantInt(BCValue); if (!CI) { // Overdefined condition variables, and branches on unfoldable constant // conditions, mean the branch could go either way. - if (!BCValue.isUnknown()) + if (!BCValue.isUnknownOrUndef()) Succs[0] = Succs[1] = true; return; } @@ -699,12 +648,12 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI, Succs[0] = true; return; } - LatticeVal SCValue = getValueState(SI->getCondition()); - ConstantInt *CI = SCValue.getConstantInt(); + ValueLatticeElement SCValue = getValueState(SI->getCondition()); + ConstantInt *CI = getConstantInt(SCValue); if (!CI) { // Overdefined or unknown condition? // All destinations are executable! - if (!SCValue.isUnknown()) + if (!SCValue.isUnknownOrUndef()) Succs.assign(TI.getNumSuccessors(), true); return; } @@ -717,11 +666,11 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI, // the target as executable. if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) { // Casts are folded by visitCastInst. 
- LatticeVal IBRValue = getValueState(IBR->getAddress()); - BlockAddress *Addr = IBRValue.getBlockAddress(); + ValueLatticeElement IBRValue = getValueState(IBR->getAddress()); + BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(getConstant(IBRValue)); if (!Addr) { // Overdefined or unknown condition? // All destinations are executable! - if (!IBRValue.isUnknown()) + if (!IBRValue.isUnknownOrUndef()) Succs.assign(TI.getNumSuccessors(), true); return; } @@ -786,50 +735,43 @@ void SCCPSolver::visitPHINode(PHINode &PN) { return (void)markOverdefined(&PN); if (getValueState(&PN).isOverdefined()) - return; // Quick exit + return; // Quick exit // Super-extra-high-degree PHI nodes are unlikely to ever be marked constant, // and slow us down a lot. Just mark them overdefined. if (PN.getNumIncomingValues() > 64) return (void)markOverdefined(&PN); + unsigned NumActiveIncoming = 0; + // Look at all of the executable operands of the PHI node. If any of them // are overdefined, the PHI becomes overdefined as well. If they are all // constant, and they agree with each other, the PHI becomes the identical - // constant. If they are constant and don't agree, the PHI is overdefined. - // If there are no executable operands, the PHI remains unknown. - Constant *OperandVal = nullptr; + // constant. If they are constant and don't agree, the PHI is a constant + // range. If there are no executable operands, the PHI remains unknown. + ValueLatticeElement PhiState = getValueState(&PN); for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { - LatticeVal IV = getValueState(PN.getIncomingValue(i)); - if (IV.isUnknown()) continue; // Doesn't influence PHI node. - if (!isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent())) continue; - if (IV.isOverdefined()) // PHI node becomes overdefined! - return (void)markOverdefined(&PN); - - if (!OperandVal) { // Grab the first value. - OperandVal = IV.getConstant(); - continue; - } - - // There is already a reachable operand. If we conflict with it, - // then the PHI node becomes overdefined. If we agree with it, we - // can continue on. - - // Check to see if there are two different constants merging, if so, the PHI - // node is overdefined. - if (IV.getConstant() != OperandVal) - return (void)markOverdefined(&PN); - } - - // If we exited the loop, this means that the PHI node only has constant - // arguments that agree with each other(and OperandVal is the constant) or - // OperandVal is null because there are no defined incoming arguments. If - // this is the case, the PHI remains unknown. - if (OperandVal) - markConstant(&PN, OperandVal); // Acquire operand value + ValueLatticeElement IV = getValueState(PN.getIncomingValue(i)); + PhiState.mergeIn(IV); + NumActiveIncoming++; + if (PhiState.isOverdefined()) + break; + } + + // We allow up to 1 range extension per active incoming value and one + // additional extension. Note that we manually adjust the number of range + // extensions to match the number of active incoming values. This helps to + // limit multiple extensions caused by the same incoming value, if other + // incoming values are equal. 
+ mergeInValue(&PN, PhiState, + ValueLatticeElement::MergeOptions().setMaxWidenSteps( + NumActiveIncoming + 1)); + ValueLatticeElement &PhiStateRef = getValueState(&PN); + PhiStateRef.setNumRangeExtensions( + std::max(NumActiveIncoming, PhiStateRef.getNumRangeExtensions())); } void SCCPSolver::visitReturnInst(ReturnInst &I) { @@ -840,8 +782,7 @@ void SCCPSolver::visitReturnInst(ReturnInst &I) { // If we are tracking the return value of this function, merge it in. if (!TrackedRetVals.empty() && !ResultOp->getType()->isStructTy()) { - MapVector<Function*, LatticeVal>::iterator TFRVI = - TrackedRetVals.find(F); + auto TFRVI = TrackedRetVals.find(F); if (TFRVI != TrackedRetVals.end()) { mergeInValue(TFRVI->second, F, getValueState(ResultOp)); return; @@ -871,18 +812,28 @@ void SCCPSolver::visitTerminator(Instruction &TI) { } void SCCPSolver::visitCastInst(CastInst &I) { - LatticeVal OpSt = getValueState(I.getOperand(0)); - if (OpSt.isOverdefined()) // Inherit overdefinedness of operand - markOverdefined(&I); - else if (OpSt.isConstant()) { + // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would + // discover a concrete value later. + if (ValueState[&I].isOverdefined()) + return; + + ValueLatticeElement OpSt = getValueState(I.getOperand(0)); + if (Constant *OpC = getConstant(OpSt)) { // Fold the constant as we build. - Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpSt.getConstant(), - I.getType(), DL); + Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL); if (isa<UndefValue>(C)) return; // Propagate constant value markConstant(&I, C); - } + } else if (OpSt.isConstantRange() && I.getDestTy()->isIntegerTy()) { + auto &LV = getValueState(&I); + ConstantRange OpRange = OpSt.getConstantRange(); + Type *DestTy = I.getDestTy(); + ConstantRange Res = + OpRange.castOp(I.getOpcode(), DL.getTypeSizeInBits(DestTy)); + mergeInValue(LV, &I, ValueLatticeElement::getRange(Res)); + } else if (!OpSt.isUnknownOrUndef()) + markOverdefined(&I); } void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { @@ -891,6 +842,11 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { if (EVI.getType()->isStructTy()) return (void)markOverdefined(&EVI); + // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would + // discover a concrete value later. + if (ValueState[&EVI].isOverdefined()) + return (void)markOverdefined(&EVI); + // If this is extracting from more than one level of struct, we don't know. if (EVI.getNumIndices() != 1) return (void)markOverdefined(&EVI); @@ -898,7 +854,7 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { Value *AggVal = EVI.getAggregateOperand(); if (AggVal->getType()->isStructTy()) { unsigned i = *EVI.idx_begin(); - LatticeVal EltVal = getStructValueState(AggVal, i); + ValueLatticeElement EltVal = getStructValueState(AggVal, i); mergeInValue(getValueState(&EVI), &EVI, EltVal); } else { // Otherwise, must be extracting from an array. @@ -911,6 +867,11 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) { if (!STy) return (void)markOverdefined(&IVI); + // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would + // discover a concrete value later. + if (isOverdefined(ValueState[&IVI])) + return (void)markOverdefined(&IVI); + // If this has more than one index, we can't handle it, drive all results to // undef. 
if (IVI.getNumIndices() != 1) @@ -923,7 +884,7 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) { for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { // This passes through all values that aren't the inserted element. if (i != Idx) { - LatticeVal EltVal = getStructValueState(Aggr, i); + ValueLatticeElement EltVal = getStructValueState(Aggr, i); mergeInValue(getStructValueState(&IVI, i), &IVI, EltVal); continue; } @@ -933,7 +894,7 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) { // We don't track structs in structs. markOverdefined(getStructValueState(&IVI, i), &IVI); else { - LatticeVal InVal = getValueState(Val); + ValueLatticeElement InVal = getValueState(Val); mergeInValue(getStructValueState(&IVI, i), &IVI, InVal); } } @@ -945,11 +906,16 @@ void SCCPSolver::visitSelectInst(SelectInst &I) { if (I.getType()->isStructTy()) return (void)markOverdefined(&I); - LatticeVal CondValue = getValueState(I.getCondition()); - if (CondValue.isUnknown()) + // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would + // discover a concrete value later. + if (ValueState[&I].isOverdefined()) + return (void)markOverdefined(&I); + + ValueLatticeElement CondValue = getValueState(I.getCondition()); + if (CondValue.isUnknownOrUndef()) return; - if (ConstantInt *CondCB = CondValue.getConstantInt()) { + if (ConstantInt *CondCB = getConstantInt(CondValue)) { Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue(); mergeInValue(&I, getValueState(OpVal)); return; @@ -958,30 +924,27 @@ void SCCPSolver::visitSelectInst(SelectInst &I) { // Otherwise, the condition is overdefined or a constant we can't evaluate. // See if we can produce something better than overdefined based on the T/F // value. - LatticeVal TVal = getValueState(I.getTrueValue()); - LatticeVal FVal = getValueState(I.getFalseValue()); - - // select ?, C, C -> C. - if (TVal.isConstant() && FVal.isConstant() && - TVal.getConstant() == FVal.getConstant()) - return (void)markConstant(&I, FVal.getConstant()); - - if (TVal.isUnknown()) // select ?, undef, X -> X. - return (void)mergeInValue(&I, FVal); - if (FVal.isUnknown()) // select ?, X, undef -> X. - return (void)mergeInValue(&I, TVal); - markOverdefined(&I); + ValueLatticeElement TVal = getValueState(I.getTrueValue()); + ValueLatticeElement FVal = getValueState(I.getFalseValue()); + + bool Changed = ValueState[&I].mergeIn(TVal); + Changed |= ValueState[&I].mergeIn(FVal); + if (Changed) + pushToWorkListMsg(ValueState[&I], &I); } // Handle Unary Operators. void SCCPSolver::visitUnaryOperator(Instruction &I) { - LatticeVal V0State = getValueState(I.getOperand(0)); + ValueLatticeElement V0State = getValueState(I.getOperand(0)); - LatticeVal &IV = ValueState[&I]; - if (IV.isOverdefined()) return; + ValueLatticeElement &IV = ValueState[&I]; + // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would + // discover a concrete value later. + if (isOverdefined(IV)) + return (void)markOverdefined(&I); - if (V0State.isConstant()) { - Constant *C = ConstantExpr::get(I.getOpcode(), V0State.getConstant()); + if (isConstant(V0State)) { + Constant *C = ConstantExpr::get(I.getOpcode(), getConstant(V0State)); // op Y -> undef. if (isa<UndefValue>(C)) @@ -990,7 +953,7 @@ void SCCPSolver::visitUnaryOperator(Instruction &I) { } // If something is undef, wait for it to resolve. 
- if (!V0State.isOverdefined()) + if (!isOverdefined(V0State)) return; markOverdefined(&I); @@ -998,101 +961,90 @@ void SCCPSolver::visitUnaryOperator(Instruction &I) { // Handle Binary Operators. void SCCPSolver::visitBinaryOperator(Instruction &I) { - LatticeVal V1State = getValueState(I.getOperand(0)); - LatticeVal V2State = getValueState(I.getOperand(1)); + ValueLatticeElement V1State = getValueState(I.getOperand(0)); + ValueLatticeElement V2State = getValueState(I.getOperand(1)); - LatticeVal &IV = ValueState[&I]; - if (IV.isOverdefined()) return; - - if (V1State.isConstant() && V2State.isConstant()) { - Constant *C = ConstantExpr::get(I.getOpcode(), V1State.getConstant(), - V2State.getConstant()); - // X op Y -> undef. - if (isa<UndefValue>(C)) - return; - return (void)markConstant(IV, &I, C); - } + ValueLatticeElement &IV = ValueState[&I]; + if (IV.isOverdefined()) + return; // If something is undef, wait for it to resolve. - if (!V1State.isOverdefined() && !V2State.isOverdefined()) + if (V1State.isUnknownOrUndef() || V2State.isUnknownOrUndef()) return; - // Otherwise, one of our operands is overdefined. Try to produce something - // better than overdefined with some tricks. - // If this is 0 / Y, it doesn't matter that the second operand is - // overdefined, and we can replace it with zero. - if (I.getOpcode() == Instruction::UDiv || I.getOpcode() == Instruction::SDiv) - if (V1State.isConstant() && V1State.getConstant()->isNullValue()) - return (void)markConstant(IV, &I, V1State.getConstant()); - - // If this is: - // -> AND/MUL with 0 - // -> OR with -1 - // it doesn't matter that the other operand is overdefined. - if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Mul || - I.getOpcode() == Instruction::Or) { - LatticeVal *NonOverdefVal = nullptr; - if (!V1State.isOverdefined()) - NonOverdefVal = &V1State; - else if (!V2State.isOverdefined()) - NonOverdefVal = &V2State; - - if (NonOverdefVal) { - if (NonOverdefVal->isUnknown()) - return; + if (V1State.isOverdefined() && V2State.isOverdefined()) + return (void)markOverdefined(&I); - if (I.getOpcode() == Instruction::And || - I.getOpcode() == Instruction::Mul) { - // X and 0 = 0 - // X * 0 = 0 - if (NonOverdefVal->getConstant()->isNullValue()) - return (void)markConstant(IV, &I, NonOverdefVal->getConstant()); - } else { - // X or -1 = -1 - if (ConstantInt *CI = NonOverdefVal->getConstantInt()) - if (CI->isMinusOne()) - return (void)markConstant(IV, &I, NonOverdefVal->getConstant()); - } + // If either of the operands is a constant, try to fold it to a constant. + // TODO: Use information from notconstant better. + if ((V1State.isConstant() || V2State.isConstant())) { + Value *V1 = isConstant(V1State) ? getConstant(V1State) : I.getOperand(0); + Value *V2 = isConstant(V2State) ? getConstant(V2State) : I.getOperand(1); + Value *R = SimplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL)); + auto *C = dyn_cast_or_null<Constant>(R); + if (C) { + // X op Y -> undef. + if (isa<UndefValue>(C)) + return; + // Conservatively assume that the result may be based on operands that may + // be undef. Note that we use mergeInValue to combine the constant with + // the existing lattice value for I, as different constants might be found + // after one of the operands go to overdefined, e.g. due to one operand + // being a special floating value. 
+ ValueLatticeElement NewV; + NewV.markConstant(C, /*MayIncludeUndef=*/true); + return (void)mergeInValue(&I, NewV); } } - markOverdefined(&I); + // Only use ranges for binary operators on integers. + if (!I.getType()->isIntegerTy()) + return markOverdefined(&I); + + // Try to simplify to a constant range. + ConstantRange A = ConstantRange::getFull(I.getType()->getScalarSizeInBits()); + ConstantRange B = ConstantRange::getFull(I.getType()->getScalarSizeInBits()); + if (V1State.isConstantRange()) + A = V1State.getConstantRange(); + if (V2State.isConstantRange()) + B = V2State.getConstantRange(); + + ConstantRange R = A.binaryOp(cast<BinaryOperator>(&I)->getOpcode(), B); + mergeInValue(&I, ValueLatticeElement::getRange(R)); + + // TODO: Currently we do not exploit special values that produce something + // better than overdefined with an overdefined operand for vector or floating + // point types, like and <4 x i32> overdefined, zeroinitializer. } // Handle ICmpInst instruction. void SCCPSolver::visitCmpInst(CmpInst &I) { // Do not cache this lookup, getValueState calls later in the function might // invalidate the reference. - if (ValueState[&I].isOverdefined()) return; + if (isOverdefined(ValueState[&I])) + return (void)markOverdefined(&I); Value *Op1 = I.getOperand(0); Value *Op2 = I.getOperand(1); // For parameters, use ParamState which includes constant range info if // available. - auto V1Param = ParamState.find(Op1); - ValueLatticeElement V1State = (V1Param != ParamState.end()) - ? V1Param->second - : getValueState(Op1).toValueLattice(); - - auto V2Param = ParamState.find(Op2); - ValueLatticeElement V2State = V2Param != ParamState.end() - ? V2Param->second - : getValueState(Op2).toValueLattice(); + auto V1State = getValueState(Op1); + auto V2State = getValueState(Op2); Constant *C = V1State.getCompare(I.getPredicate(), I.getType(), V2State); if (C) { if (isa<UndefValue>(C)) return; - LatticeVal CV; + ValueLatticeElement CV; CV.markConstant(C); mergeInValue(&I, CV); return; } // If operands are still unknown, wait for it to resolve. - if (!V1State.isOverdefined() && !V2State.isOverdefined() && - !ValueState[&I].isConstant()) + if ((V1State.isUnknownOrUndef() || V2State.isUnknownOrUndef()) && + !isConstant(ValueState[&I])) return; markOverdefined(&I); @@ -1101,21 +1053,26 @@ void SCCPSolver::visitCmpInst(CmpInst &I) { // Handle getelementptr instructions. If all operands are constants then we // can turn this into a getelementptr ConstantExpr. void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) { - if (ValueState[&I].isOverdefined()) return; + if (isOverdefined(ValueState[&I])) + return (void)markOverdefined(&I); SmallVector<Constant*, 8> Operands; Operands.reserve(I.getNumOperands()); for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { - LatticeVal State = getValueState(I.getOperand(i)); - if (State.isUnknown()) + ValueLatticeElement State = getValueState(I.getOperand(i)); + if (State.isUnknownOrUndef()) return; // Operands are not resolved yet. 
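When neither operand folds to a constant, visitBinaryOperator above falls back to ConstantRange::binaryOp on the operands' ranges. A minimal model for addition on small unsigned intervals, ignoring overflow for clarity (the real ConstantRange arithmetic does model wrapping):

#include <cstdio>

struct URange { unsigned Lo, Hi; }; // inclusive interval, Lo <= Hi

// [a,b] + [c,d] = [a+c, b+d], assuming no unsigned overflow occurs; this toy
// deliberately skips the wrap handling that ConstantRange::binaryOp performs.
static URange addRanges(URange A, URange B) {
  return {A.Lo + B.Lo, A.Hi + B.Hi};
}

int main() {
  URange R = addRanges({1, 4}, {10, 20}); // -> [11, 24]
  std::printf("[%u, %u]\n", R.Lo, R.Hi);
  return 0;
}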
- if (State.isOverdefined()) + if (isOverdefined(State)) return (void)markOverdefined(&I); - assert(State.isConstant() && "Unknown state!"); - Operands.push_back(State.getConstant()); + if (Constant *C = getConstant(State)) { + Operands.push_back(C); + continue; + } + + return (void)markOverdefined(&I); } Constant *Ptr = Operands[0]; @@ -1136,230 +1093,297 @@ void SCCPSolver::visitStoreInst(StoreInst &SI) { return; GlobalVariable *GV = cast<GlobalVariable>(SI.getOperand(1)); - DenseMap<GlobalVariable*, LatticeVal>::iterator I = TrackedGlobals.find(GV); - if (I == TrackedGlobals.end() || I->second.isOverdefined()) return; + auto I = TrackedGlobals.find(GV); + if (I == TrackedGlobals.end()) + return; // Get the value we are storing into the global, then merge it. - mergeInValue(I->second, GV, getValueState(SI.getOperand(0))); + mergeInValue(I->second, GV, getValueState(SI.getOperand(0)), + ValueLatticeElement::MergeOptions().setCheckWiden(false)); if (I->second.isOverdefined()) TrackedGlobals.erase(I); // No need to keep tracking this! } +static ValueLatticeElement getValueFromMetadata(const Instruction *I) { + if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) + if (I->getType()->isIntegerTy()) + return ValueLatticeElement::getRange( + getConstantRangeFromMetadata(*Ranges)); + // TODO: Also handle MD_nonnull. + return ValueLatticeElement::getOverdefined(); +} + // Handle load instructions. If the operand is a constant pointer to a constant // global, we can replace the load with the loaded constant value! void SCCPSolver::visitLoadInst(LoadInst &I) { - // If this load is of a struct, just mark the result overdefined. - if (I.getType()->isStructTy()) + // If this load is of a struct or the load is volatile, just mark the result + // as overdefined. + if (I.getType()->isStructTy() || I.isVolatile()) return (void)markOverdefined(&I); - LatticeVal PtrVal = getValueState(I.getOperand(0)); - if (PtrVal.isUnknown()) return; // The pointer is not resolved yet! - - LatticeVal &IV = ValueState[&I]; - if (IV.isOverdefined()) return; + // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would + // discover a concrete value later. + if (ValueState[&I].isOverdefined()) + return (void)markOverdefined(&I); - if (!PtrVal.isConstant() || I.isVolatile()) - return (void)markOverdefined(IV, &I); + ValueLatticeElement PtrVal = getValueState(I.getOperand(0)); + if (PtrVal.isUnknownOrUndef()) + return; // The pointer is not resolved yet! - Constant *Ptr = PtrVal.getConstant(); + ValueLatticeElement &IV = ValueState[&I]; - // load null is undefined. - if (isa<ConstantPointerNull>(Ptr)) { - if (NullPointerIsDefined(I.getFunction(), I.getPointerAddressSpace())) - return (void)markOverdefined(IV, &I); - else - return; - } + if (isConstant(PtrVal)) { + Constant *Ptr = getConstant(PtrVal); - // Transform load (constant global) into the value loaded. - if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) { - if (!TrackedGlobals.empty()) { - // If we are tracking this global, merge in the known value for it. - DenseMap<GlobalVariable*, LatticeVal>::iterator It = - TrackedGlobals.find(GV); - if (It != TrackedGlobals.end()) { - mergeInValue(IV, &I, It->second); + // load null is undefined. + if (isa<ConstantPointerNull>(Ptr)) { + if (NullPointerIsDefined(I.getFunction(), I.getPointerAddressSpace())) + return (void)markOverdefined(IV, &I); + else return; + } + + // Transform load (constant global) into the value loaded. 
+ if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) { + if (!TrackedGlobals.empty()) { + // If we are tracking this global, merge in the known value for it. + auto It = TrackedGlobals.find(GV); + if (It != TrackedGlobals.end()) { + mergeInValue(IV, &I, It->second, getMaxWidenStepsOpts()); + return; + } } } - } - // Transform load from a constant into a constant if possible. - if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL)) { - if (isa<UndefValue>(C)) - return; - return (void)markConstant(IV, &I, C); + // Transform load from a constant into a constant if possible. + if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL)) { + if (isa<UndefValue>(C)) + return; + return (void)markConstant(IV, &I, C); + } } - // Otherwise we cannot say for certain what value this load will produce. - // Bail out. - markOverdefined(IV, &I); + // Fall back to metadata. + mergeInValue(&I, getValueFromMetadata(&I)); } -void SCCPSolver::visitCallSite(CallSite CS) { - Function *F = CS.getCalledFunction(); - Instruction *I = CS.getInstruction(); +void SCCPSolver::visitCallBase(CallBase &CB) { + handleCallResult(CB); + handleCallArguments(CB); +} - if (auto *II = dyn_cast<IntrinsicInst>(I)) { - if (II->getIntrinsicID() == Intrinsic::ssa_copy) { - if (ValueState[I].isOverdefined()) +void SCCPSolver::handleCallOverdefined(CallBase &CB) { + Function *F = CB.getCalledFunction(); + + // Void return and not tracking callee, just bail. + if (CB.getType()->isVoidTy()) + return; + + // Always mark struct return as overdefined. + if (CB.getType()->isStructTy()) + return (void)markOverdefined(&CB); + + // Otherwise, if we have a single return value case, and if the function is + // a declaration, maybe we can constant fold it. + if (F && F->isDeclaration() && canConstantFoldCallTo(&CB, F)) { + SmallVector<Constant *, 8> Operands; + for (auto AI = CB.arg_begin(), E = CB.arg_end(); AI != E; ++AI) { + if (AI->get()->getType()->isStructTy()) + return markOverdefined(&CB); // Can't handle struct args. + ValueLatticeElement State = getValueState(*AI); + + if (State.isUnknownOrUndef()) + return; // Operands are not resolved yet. + if (isOverdefined(State)) + return (void)markOverdefined(&CB); + assert(isConstant(State) && "Unknown state!"); + Operands.push_back(getConstant(State)); + } + + if (isOverdefined(getValueState(&CB))) + return (void)markOverdefined(&CB); + + // If we can constant fold this, mark the result of the call as a + // constant. + if (Constant *C = ConstantFoldCall(&CB, F, Operands, &GetTLI(*F))) { + // call -> undef. + if (isa<UndefValue>(C)) return; + return (void)markConstant(&CB, C); + } + } + + // Fall back to metadata. + mergeInValue(&CB, getValueFromMetadata(&CB)); +} + +void SCCPSolver::handleCallArguments(CallBase &CB) { + Function *F = CB.getCalledFunction(); + // If this is a local function that doesn't have its address taken, mark its + // entry block executable and merge in the actual arguments to the call into + // the formal arguments of the function. + if (!TrackingIncomingArguments.empty() && + TrackingIncomingArguments.count(F)) { + MarkBlockExecutable(&F->front()); + + // Propagate information from this call site into the callee. + auto CAI = CB.arg_begin(); + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); AI != E; + ++AI, ++CAI) { + // If this argument is byval, and if the function is not readonly, there + // will be an implicit copy formed of the input aggregate. 
+ if (AI->hasByValAttr() && !F->onlyReadsMemory()) { + markOverdefined(&*AI); + continue; + } + + if (auto *STy = dyn_cast<StructType>(AI->getType())) { + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + ValueLatticeElement CallArg = getStructValueState(*CAI, i); + mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg, + getMaxWidenStepsOpts()); + } + } else + mergeInValue(&*AI, getValueState(*CAI), getMaxWidenStepsOpts()); + } + } +} + +void SCCPSolver::handleCallResult(CallBase &CB) { + Function *F = CB.getCalledFunction(); - auto *PI = getPredicateInfoFor(I); - if (!PI) + if (auto *II = dyn_cast<IntrinsicInst>(&CB)) { + if (II->getIntrinsicID() == Intrinsic::ssa_copy) { + if (ValueState[&CB].isOverdefined()) return; - Value *CopyOf = I->getOperand(0); - auto *PBranch = dyn_cast<PredicateBranch>(PI); - if (!PBranch) { - mergeInValue(ValueState[I], I, getValueState(CopyOf)); + Value *CopyOf = CB.getOperand(0); + ValueLatticeElement CopyOfVal = getValueState(CopyOf); + auto *PI = getPredicateInfoFor(&CB); + assert(PI && "Missing predicate info for ssa.copy"); + + CmpInst *Cmp; + bool TrueEdge; + if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) { + Cmp = dyn_cast<CmpInst>(PBranch->Condition); + TrueEdge = PBranch->TrueEdge; + } else if (auto *PAssume = dyn_cast<PredicateAssume>(PI)) { + Cmp = dyn_cast<CmpInst>(PAssume->Condition); + TrueEdge = true; + } else { + mergeInValue(ValueState[&CB], &CB, CopyOfVal); return; } - Value *Cond = PBranch->Condition; - // Everything below relies on the condition being a comparison. - auto *Cmp = dyn_cast<CmpInst>(Cond); if (!Cmp) { - mergeInValue(ValueState[I], I, getValueState(CopyOf)); + mergeInValue(ValueState[&CB], &CB, CopyOfVal); return; } + Value *RenamedOp = PI->RenamedOp; Value *CmpOp0 = Cmp->getOperand(0); Value *CmpOp1 = Cmp->getOperand(1); - if (CopyOf != CmpOp0 && CopyOf != CmpOp1) { - mergeInValue(ValueState[I], I, getValueState(CopyOf)); + // Bail out if neither of the operands matches RenamedOp. + if (CmpOp0 != RenamedOp && CmpOp1 != RenamedOp) { + mergeInValue(ValueState[&CB], &CB, getValueState(CopyOf)); return; } - if (CmpOp0 != CopyOf) + auto Pred = Cmp->getPredicate(); + if (CmpOp1 == RenamedOp) { std::swap(CmpOp0, CmpOp1); + Pred = Cmp->getSwappedPredicate(); + } - LatticeVal OriginalVal = getValueState(CopyOf); - LatticeVal EqVal = getValueState(CmpOp1); - LatticeVal &IV = ValueState[I]; - if (PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_EQ) { - addAdditionalUser(CmpOp1, I); - if (OriginalVal.isConstant()) - mergeInValue(IV, I, OriginalVal); - else - mergeInValue(IV, I, EqVal); + // Wait until CmpOp1 is resolved. + if (getValueState(CmpOp1).isUnknown()) { + addAdditionalUser(CmpOp1, &CB); return; } - if (!PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_NE) { - addAdditionalUser(CmpOp1, I); - if (OriginalVal.isConstant()) - mergeInValue(IV, I, OriginalVal); - else - mergeInValue(IV, I, EqVal); + + // The code below relies on PredicateInfo only inserting copies for the + // true branch when the branch condition is an AND and only inserting + // copies for the false branch when the branch condition is an OR. This + // ensures we can intersect the range from the condition with the range of + // CopyOf. 
+ if (!TrueEdge) + Pred = CmpInst::getInversePredicate(Pred); + + ValueLatticeElement CondVal = getValueState(CmpOp1); + ValueLatticeElement &IV = ValueState[&CB]; + if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) { + auto ImposedCR = + ConstantRange::getFull(DL.getTypeSizeInBits(CopyOf->getType())); + + // Get the range imposed by the condition. + if (CondVal.isConstantRange()) + ImposedCR = ConstantRange::makeAllowedICmpRegion( + Pred, CondVal.getConstantRange()); + + // Combine range info for the original value with the new range from the + // condition. + auto CopyOfCR = CopyOfVal.isConstantRange() + ? CopyOfVal.getConstantRange() + : ConstantRange::getFull( + DL.getTypeSizeInBits(CopyOf->getType())); + auto NewCR = ImposedCR.intersectWith(CopyOfCR); + // If the existing information is != x, do not use the information from + // a chained predicate, as the != x information is more likely to be + // helpful in practice. + if (!CopyOfCR.contains(NewCR) && CopyOfCR.getSingleMissingElement()) + NewCR = CopyOfCR; + + addAdditionalUser(CmpOp1, &CB); + // TODO: Actually filp MayIncludeUndef for the created range to false, + // once most places in the optimizer respect the branches on + // undef/poison are UB rule. The reason why the new range cannot be + // undef is as follows below: + // The new range is based on a branch condition. That guarantees that + // neither of the compare operands can be undef in the branch targets, + // unless we have conditions that are always true/false (e.g. icmp ule + // i32, %a, i32_max). For the latter overdefined/empty range will be + // inferred, but the branch will get folded accordingly anyways. + mergeInValue( + IV, &CB, + ValueLatticeElement::getRange(NewCR, /*MayIncludeUndef=*/true)); + return; + } else if (Pred == CmpInst::ICMP_EQ && CondVal.isConstant()) { + // For non-integer values or integer constant expressions, only + // propagate equal constants. + addAdditionalUser(CmpOp1, &CB); + mergeInValue(IV, &CB, CondVal); return; } - return (void)mergeInValue(IV, I, getValueState(CopyOf)); + return (void)mergeInValue(IV, &CB, CopyOfVal); } } // The common case is that we aren't tracking the callee, either because we // are not doing interprocedural analysis or the callee is indirect, or is // external. Handle these cases first. - if (!F || F->isDeclaration()) { -CallOverdefined: - // Void return and not tracking callee, just bail. - if (I->getType()->isVoidTy()) return; - - // Otherwise, if we have a single return value case, and if the function is - // a declaration, maybe we can constant fold it. - if (F && F->isDeclaration() && !I->getType()->isStructTy() && - canConstantFoldCallTo(cast<CallBase>(CS.getInstruction()), F)) { - SmallVector<Constant*, 8> Operands; - for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); - AI != E; ++AI) { - if (AI->get()->getType()->isStructTy()) - return markOverdefined(I); // Can't handle struct args. - LatticeVal State = getValueState(*AI); - - if (State.isUnknown()) - return; // Operands are not resolved yet. - if (State.isOverdefined()) - return (void)markOverdefined(I); - assert(State.isConstant() && "Unknown state!"); - Operands.push_back(State.getConstant()); - } - - if (getValueState(I).isOverdefined()) - return; - - // If we can constant fold this, mark the result of the call as a - // constant. - if (Constant *C = ConstantFoldCall(cast<CallBase>(CS.getInstruction()), F, - Operands, &GetTLI(*F))) { - // call -> undef. 
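The ssa.copy handling above converts the dominating comparison into a range constraint via ConstantRange::makeAllowedICmpRegion and intersects it with what is already known about the copied value. A toy version for the predicate x u< C on non-wrapping unsigned intervals; the names and the simplification are illustrative, not LLVM's:

#include <algorithm>
#include <cstdio>

struct URange { unsigned Lo, Hi; }; // inclusive interval, Lo <= Hi

// Region allowed for x by "x u< C": [0, C-1]. C == 0 would make the region
// empty; this sketch assumes C > 0.
static URange allowedULT(unsigned C) { return {0, C - 1}; }

// Intersect the branch-imposed region with the range already known for x.
// Assumes the two intervals overlap; a real implementation also handles the
// empty result.
static URange intersect(URange A, URange B) {
  return {std::max(A.Lo, B.Lo), std::min(A.Hi, B.Hi)};
}

int main() {
  URange Known = {4, 100};         // prior range of the copied value
  URange Imposed = allowedULT(32); // condition on this edge: x u< 32
  URange Refined = intersect(Known, Imposed);
  std::printf("[%u, %u]\n", Refined.Lo, Refined.Hi); // [4, 31]
  return 0;
}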
- if (isa<UndefValue>(C)) - return; - return (void)markConstant(I, C); - } - } - - // Otherwise, we don't know anything about this call, mark it overdefined. - return (void)markOverdefined(I); - } - - // If this is a local function that doesn't have its address taken, mark its - // entry block executable and merge in the actual arguments to the call into - // the formal arguments of the function. - if (!TrackingIncomingArguments.empty() && TrackingIncomingArguments.count(F)){ - MarkBlockExecutable(&F->front()); - - // Propagate information from this call site into the callee. - CallSite::arg_iterator CAI = CS.arg_begin(); - for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); - AI != E; ++AI, ++CAI) { - // If this argument is byval, and if the function is not readonly, there - // will be an implicit copy formed of the input aggregate. - if (AI->hasByValAttr() && !F->onlyReadsMemory()) { - markOverdefined(&*AI); - continue; - } - - if (auto *STy = dyn_cast<StructType>(AI->getType())) { - for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { - LatticeVal CallArg = getStructValueState(*CAI, i); - mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg); - } - } else { - // Most other parts of the Solver still only use the simpler value - // lattice, so we propagate changes for parameters to both lattices. - LatticeVal ConcreteArgument = getValueState(*CAI); - bool ParamChanged = - getParamState(&*AI).mergeIn(ConcreteArgument.toValueLattice(), DL); - bool ValueChanged = mergeInValue(&*AI, ConcreteArgument); - // Add argument to work list, if the state of a parameter changes but - // ValueState does not change (because it is already overdefined there), - // We have to take changes in ParamState into account, as it is used - // when evaluating Cmp instructions. - if (!ValueChanged && ParamChanged) - pushToWorkList(ValueState[&*AI], &*AI); - } - } - } + if (!F || F->isDeclaration()) + return handleCallOverdefined(CB); // If this is a single/zero retval case, see if we're tracking the function. if (auto *STy = dyn_cast<StructType>(F->getReturnType())) { if (!MRVFunctionsTracked.count(F)) - goto CallOverdefined; // Not tracking this callee. + return handleCallOverdefined(CB); // Not tracking this callee. // If we are tracking this callee, propagate the result of the function // into this call site. for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) - mergeInValue(getStructValueState(I, i), I, - TrackedMultipleRetVals[std::make_pair(F, i)]); + mergeInValue(getStructValueState(&CB, i), &CB, + TrackedMultipleRetVals[std::make_pair(F, i)], + getMaxWidenStepsOpts()); } else { - MapVector<Function*, LatticeVal>::iterator TFRVI = TrackedRetVals.find(F); + auto TFRVI = TrackedRetVals.find(F); if (TFRVI == TrackedRetVals.end()) - goto CallOverdefined; // Not tracking this callee. + return handleCallOverdefined(CB); // Not tracking this callee. // If so, propagate the return value of the callee into this call result. - mergeInValue(I, TFRVI->second); + mergeInValue(&CB, TFRVI->second, getMaxWidenStepsOpts()); } } @@ -1429,10 +1453,8 @@ void SCCPSolver::Solve() { /// constraints on the condition of the branch, as that would impact other users /// of the value. /// -/// This scan also checks for values that use undefs, whose results are actually -/// defined. For example, 'zext i8 undef to i32' should produce all zeros -/// conservatively, as "(zext i8 X -> i32) & 0xFF00" must always return zero, -/// even if X isn't defined. 
+/// This scan also checks for values that use undefs. It conservatively marks +/// them as overdefined. bool SCCPSolver::ResolvedUndefsIn(Function &F) { for (BasicBlock &BB : F) { if (!BBExecutable.count(&BB)) @@ -1446,8 +1468,8 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // Only a few things that can be structs matter for undef. // Tracked calls must never be marked overdefined in ResolvedUndefsIn. - if (CallSite CS = CallSite(&I)) - if (Function *F = CS.getCalledFunction()) + if (auto *CB = dyn_cast<CallBase>(&I)) + if (Function *F = CB->getCalledFunction()) if (MRVFunctionsTracked.count(F)) continue; @@ -1455,19 +1477,18 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // tracked as precisely as their operands. if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I)) continue; - // Send the results of everything else to overdefined. We could be // more precise than this but it isn't worth bothering. for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { - LatticeVal &LV = getStructValueState(&I, i); - if (LV.isUnknown()) + ValueLatticeElement &LV = getStructValueState(&I, i); + if (LV.isUnknownOrUndef()) markOverdefined(LV, &I); } continue; } - LatticeVal &LV = getValueState(&I); - if (!LV.isUnknown()) + ValueLatticeElement &LV = getValueState(&I); + if (!LV.isUnknownOrUndef()) continue; // There are two reasons a call can have an undef result @@ -1475,195 +1496,20 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // 2. It could be constant-foldable. // Because of the way we solve return values, tracked calls must // never be marked overdefined in ResolvedUndefsIn. - if (CallSite CS = CallSite(&I)) { - if (Function *F = CS.getCalledFunction()) + if (auto *CB = dyn_cast<CallBase>(&I)) + if (Function *F = CB->getCalledFunction()) if (TrackedRetVals.count(F)) continue; - // If the call is constant-foldable, we mark it overdefined because - // we do not know what return values are valid. - markOverdefined(&I); - return true; - } - - // extractvalue is safe; check here because the argument is a struct. - if (isa<ExtractValueInst>(I)) - continue; - - // Compute the operand LatticeVals, for convenience below. - // Anything taking a struct is conservatively assumed to require - // overdefined markings. - if (I.getOperand(0)->getType()->isStructTy()) { - markOverdefined(&I); - return true; - } - LatticeVal Op0LV = getValueState(I.getOperand(0)); - LatticeVal Op1LV; - if (I.getNumOperands() == 2) { - if (I.getOperand(1)->getType()->isStructTy()) { - markOverdefined(&I); - return true; - } - - Op1LV = getValueState(I.getOperand(1)); - } - // If this is an instructions whose result is defined even if the input is - // not fully defined, propagate the information. - Type *ITy = I.getType(); - switch (I.getOpcode()) { - case Instruction::Add: - case Instruction::Sub: - case Instruction::Trunc: - case Instruction::FPTrunc: - case Instruction::BitCast: - break; // Any undef -> undef - case Instruction::FSub: - case Instruction::FAdd: - case Instruction::FMul: - case Instruction::FDiv: - case Instruction::FRem: - // Floating-point binary operation: be conservative. 
- if (Op0LV.isUnknown() && Op1LV.isUnknown()) - markForcedConstant(&I, Constant::getNullValue(ITy)); - else - markOverdefined(&I); - return true; - case Instruction::FNeg: - break; // fneg undef -> undef - case Instruction::ZExt: - case Instruction::SExt: - case Instruction::FPToUI: - case Instruction::FPToSI: - case Instruction::FPExt: - case Instruction::PtrToInt: - case Instruction::IntToPtr: - case Instruction::SIToFP: - case Instruction::UIToFP: - // undef -> 0; some outputs are impossible - markForcedConstant(&I, Constant::getNullValue(ITy)); - return true; - case Instruction::Mul: - case Instruction::And: - // Both operands undef -> undef - if (Op0LV.isUnknown() && Op1LV.isUnknown()) - break; - // undef * X -> 0. X could be zero. - // undef & X -> 0. X could be zero. - markForcedConstant(&I, Constant::getNullValue(ITy)); - return true; - case Instruction::Or: - // Both operands undef -> undef - if (Op0LV.isUnknown() && Op1LV.isUnknown()) - break; - // undef | X -> -1. X could be -1. - markForcedConstant(&I, Constant::getAllOnesValue(ITy)); - return true; - case Instruction::Xor: - // undef ^ undef -> 0; strictly speaking, this is not strictly - // necessary, but we try to be nice to people who expect this - // behavior in simple cases - if (Op0LV.isUnknown() && Op1LV.isUnknown()) { - markForcedConstant(&I, Constant::getNullValue(ITy)); - return true; - } - // undef ^ X -> undef - break; - case Instruction::SDiv: - case Instruction::UDiv: - case Instruction::SRem: - case Instruction::URem: - // X / undef -> undef. No change. - // X % undef -> undef. No change. - if (Op1LV.isUnknown()) break; - - // X / 0 -> undef. No change. - // X % 0 -> undef. No change. - if (Op1LV.isConstant() && Op1LV.getConstant()->isZeroValue()) - break; - - // undef / X -> 0. X could be maxint. - // undef % X -> 0. X could be 1. - markForcedConstant(&I, Constant::getNullValue(ITy)); - return true; - case Instruction::AShr: - // X >>a undef -> undef. - if (Op1LV.isUnknown()) break; - - // Shifting by the bitwidth or more is undefined. - if (Op1LV.isConstant()) { - if (auto *ShiftAmt = Op1LV.getConstantInt()) - if (ShiftAmt->getLimitedValue() >= - ShiftAmt->getType()->getScalarSizeInBits()) - break; - } - - // undef >>a X -> 0 - markForcedConstant(&I, Constant::getNullValue(ITy)); - return true; - case Instruction::LShr: - case Instruction::Shl: - // X << undef -> undef. - // X >> undef -> undef. - if (Op1LV.isUnknown()) break; - - // Shifting by the bitwidth or more is undefined. - if (Op1LV.isConstant()) { - if (auto *ShiftAmt = Op1LV.getConstantInt()) - if (ShiftAmt->getLimitedValue() >= - ShiftAmt->getType()->getScalarSizeInBits()) - break; - } - - // undef << X -> 0 - // undef >> X -> 0 - markForcedConstant(&I, Constant::getNullValue(ITy)); - return true; - case Instruction::Select: - Op1LV = getValueState(I.getOperand(1)); - // undef ? X : Y -> X or Y. There could be commonality between X/Y. - if (Op0LV.isUnknown()) { - if (!Op1LV.isConstant()) // Pick the constant one if there is any. - Op1LV = getValueState(I.getOperand(2)); - } else if (Op1LV.isUnknown()) { - // c ? undef : undef -> undef. No change. - Op1LV = getValueState(I.getOperand(2)); - if (Op1LV.isUnknown()) - break; - // Otherwise, c ? undef : x -> x. - } else { - // Leave Op1LV as Operand(1)'s LatticeValue. 
- } - - if (Op1LV.isConstant()) - markForcedConstant(&I, Op1LV.getConstant()); - else - markOverdefined(&I); - return true; - case Instruction::Load: + if (isa<LoadInst>(I)) { // A load here means one of two things: a load of undef from a global, // a load from an unknown pointer. Either way, having it return undef // is okay. - break; - case Instruction::ICmp: - // X == undef -> undef. Other comparisons get more complicated. - Op0LV = getValueState(I.getOperand(0)); - Op1LV = getValueState(I.getOperand(1)); - - if ((Op0LV.isUnknown() || Op1LV.isUnknown()) && - cast<ICmpInst>(&I)->isEquality()) - break; - markOverdefined(&I); - return true; - case Instruction::Call: - case Instruction::Invoke: - case Instruction::CallBr: - llvm_unreachable("Call-like instructions should have be handled early"); - default: - // If we don't know what should happen here, conservatively mark it - // overdefined. - markOverdefined(&I); - return true; + continue; } + + markOverdefined(&I); + return true; } // Check to see if we have a branch or switch on an undefined value. If so @@ -1672,7 +1518,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { Instruction *TI = BB.getTerminator(); if (auto *BI = dyn_cast<BranchInst>(TI)) { if (!BI->isConditional()) continue; - if (!getValueState(BI->getCondition()).isUnknown()) + if (!getValueState(BI->getCondition()).isUnknownOrUndef()) continue; // If the input to SCCP is actually branch on undef, fix the undef to @@ -1700,7 +1546,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { if (IBR->getNumSuccessors() < 1) continue; - if (!getValueState(IBR->getAddress()).isUnknown()) + if (!getValueState(IBR->getAddress()).isUnknownOrUndef()) continue; // If the input to SCCP is actually branch on undef, fix the undef to @@ -1724,7 +1570,8 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { } if (auto *SI = dyn_cast<SwitchInst>(TI)) { - if (!SI->getNumCases() || !getValueState(SI->getCondition()).isUnknown()) + if (!SI->getNumCases() || + !getValueState(SI->getCondition()).isUnknownOrUndef()) continue; // If the input to SCCP is actually switch on undef, fix the undef to @@ -1753,25 +1600,26 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { Constant *Const = nullptr; if (V->getType()->isStructTy()) { - std::vector<LatticeVal> IVs = Solver.getStructLatticeValueFor(V); - if (llvm::any_of(IVs, - [](const LatticeVal &LV) { return LV.isOverdefined(); })) + std::vector<ValueLatticeElement> IVs = Solver.getStructLatticeValueFor(V); + if (any_of(IVs, + [](const ValueLatticeElement &LV) { return isOverdefined(LV); })) return false; std::vector<Constant *> ConstVals; auto *ST = cast<StructType>(V->getType()); for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { - LatticeVal V = IVs[i]; - ConstVals.push_back(V.isConstant() - ? V.getConstant() + ValueLatticeElement V = IVs[i]; + ConstVals.push_back(isConstant(V) + ? Solver.getConstant(V) : UndefValue::get(ST->getElementType(i))); } Const = ConstantStruct::get(ST, ConstVals); } else { - const LatticeVal &IV = Solver.getLatticeValueFor(V); - if (IV.isOverdefined()) + const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); + if (isOverdefined(IV)) return false; - Const = IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType()); + Const = + isConstant(IV) ? 
Solver.getConstant(IV) : UndefValue::get(V->getType()); } assert(Const && "Constant is nullptr here!"); @@ -1779,8 +1627,7 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { // unless the call itself can be removed CallInst *CI = dyn_cast<CallInst>(V); if (CI && CI->isMustTailCall() && !CI->isSafeToRemove()) { - CallSite CS(CI); - Function *F = CS.getCalledFunction(); + Function *F = CI->getCalledFunction(); // Don't zap returns of the callee if (F) @@ -1798,13 +1645,49 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { return true; } +static bool simplifyInstsInBlock(SCCPSolver &Solver, BasicBlock &BB, + SmallPtrSetImpl<Value *> &InsertedValues, + Statistic &InstRemovedStat, + Statistic &InstReplacedStat) { + bool MadeChanges = false; + for (Instruction &Inst : make_early_inc_range(BB)) { + if (Inst.getType()->isVoidTy()) + continue; + if (tryToReplaceWithConstant(Solver, &Inst)) { + if (Inst.isSafeToRemove()) + Inst.eraseFromParent(); + // Hey, we just changed something! + MadeChanges = true; + ++InstRemovedStat; + } else if (isa<SExtInst>(&Inst)) { + Value *ExtOp = Inst.getOperand(0); + if (isa<Constant>(ExtOp) || InsertedValues.count(ExtOp)) + continue; + const ValueLatticeElement &IV = Solver.getLatticeValueFor(ExtOp); + if (!IV.isConstantRange(/*UndefAllowed=*/false)) + continue; + if (IV.getConstantRange().isAllNonNegative()) { + auto *ZExt = new ZExtInst(ExtOp, Inst.getType(), "", &Inst); + InsertedValues.insert(ZExt); + Inst.replaceAllUsesWith(ZExt); + Solver.removeLatticeValueFor(&Inst); + Inst.eraseFromParent(); + InstReplacedStat++; + MadeChanges = true; + } + } + } + return MadeChanges; +} + // runSCCP() - Run the Sparse Conditional Constant Propagation algorithm, // and return true if the function was modified. static bool runSCCP(Function &F, const DataLayout &DL, const TargetLibraryInfo *TLI) { LLVM_DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n"); SCCPSolver Solver( - DL, [TLI](Function &F) -> const TargetLibraryInfo & { return *TLI; }); + DL, [TLI](Function &F) -> const TargetLibraryInfo & { return *TLI; }, + F.getContext()); // Mark the first block of the function as being executable. Solver.MarkBlockExecutable(&F.front()); @@ -1827,6 +1710,7 @@ static bool runSCCP(Function &F, const DataLayout &DL, // delete their contents now. Note that we cannot actually delete the blocks, // as we cannot modify the CFG of the function. + SmallPtrSet<Value *, 32> InsertedValues; for (BasicBlock &BB : F) { if (!Solver.isBlockExecutable(&BB)) { LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << BB); @@ -1838,21 +1722,8 @@ static bool runSCCP(Function &F, const DataLayout &DL, continue; } - // Iterate over all of the instructions in a function, replacing them with - // constants if we have found them to be of constant values. - for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { - Instruction *Inst = &*BI++; - if (Inst->getType()->isVoidTy() || Inst->isTerminator()) - continue; - - if (tryToReplaceWithConstant(Solver, Inst)) { - if (isInstructionTriviallyDead(Inst)) - Inst->eraseFromParent(); - // Hey, we just changed something! - MadeChanges = true; - ++NumInstRemoved; - } - } + MadeChanges |= simplifyInstsInBlock(Solver, BB, InsertedValues, + NumInstRemoved, NumInstReplaced); } return MadeChanges; @@ -1942,14 +1813,15 @@ static void findReturnsToZap(Function &F, // uses (like blockaddresses) could stuck around, without being // used in the underlying IR, meaning we do not have lattice // values for them. 
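
// Standalone sketch (not part of this commit): a plain C++ sanity check of the
// sext -> zext rewrite added in simplifyInstsInBlock above. When the solver
// proves an i8 operand's range is all non-negative, sign extension and zero
// extension agree bit-for-bit, so the two casts are interchangeable.
#include <cassert>
#include <cstdint>

int main() {
  for (int V = 0; V <= 127; ++V) { // the non-negative i8 values
    int8_t X = static_cast<int8_t>(V);
    int32_t BySExt = static_cast<int32_t>(X);                       // sext i8 -> i32
    int32_t ByZExt = static_cast<int32_t>(static_cast<uint8_t>(X)); // zext i8 -> i32
    assert(BySExt == ByZExt);
  }
  return 0;
}
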
- if (!CallSite(U)) + if (!isa<CallBase>(U)) return true; if (U->getType()->isStructTy()) { - return all_of( - Solver.getStructLatticeValueFor(U), - [](const LatticeVal &LV) { return !LV.isOverdefined(); }); + return all_of(Solver.getStructLatticeValueFor(U), + [](const ValueLatticeElement &LV) { + return !isOverdefined(LV); + }); } - return !Solver.getLatticeValueFor(U).isOverdefined(); + return !isOverdefined(Solver.getLatticeValueFor(U)); }) && "We can only zap functions where all live users have a concrete value"); @@ -2006,7 +1878,7 @@ bool llvm::runIPSCCP( Module &M, const DataLayout &DL, std::function<const TargetLibraryInfo &(Function &)> GetTLI, function_ref<AnalysisResultsForFn(Function &)> getAnalysis) { - SCCPSolver Solver(DL, GetTLI); + SCCPSolver Solver(DL, GetTLI, M.getContext()); // Loop over all functions, marking arguments to those with their addresses // taken or that are external as overdefined. @@ -2080,30 +1952,21 @@ bool llvm::runIPSCCP( } } - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { - if (!Solver.isBlockExecutable(&*BB)) { - LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << *BB); + SmallPtrSet<Value *, 32> InsertedValues; + for (BasicBlock &BB : F) { + if (!Solver.isBlockExecutable(&BB)) { + LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << BB); ++NumDeadBlocks; MadeChanges = true; - if (&*BB != &F.front()) - BlocksToErase.push_back(&*BB); + if (&BB != &F.front()) + BlocksToErase.push_back(&BB); continue; } - for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) { - Instruction *Inst = &*BI++; - if (Inst->getType()->isVoidTy()) - continue; - if (tryToReplaceWithConstant(Solver, Inst)) { - if (Inst->isSafeToRemove()) - Inst->eraseFromParent(); - // Hey, we just changed something! - MadeChanges = true; - ++IPNumInstRemoved; - } - } + MadeChanges |= simplifyInstsInBlock(Solver, BB, InsertedValues, + IPNumInstRemoved, IPNumInstReplaced); } DomTreeUpdater DTU = Solver.getDTU(F); @@ -2189,10 +2052,9 @@ bool llvm::runIPSCCP( // whether other functions are optimizable. SmallVector<ReturnInst*, 8> ReturnsToZap; - const MapVector<Function*, LatticeVal> &RV = Solver.getTrackedRetVals(); - for (const auto &I : RV) { + for (const auto &I : Solver.getTrackedRetVals()) { Function *F = I.first; - if (I.second.isOverdefined() || F->getReturnType()->isVoidTy()) + if (isOverdefined(I.second) || F->getReturnType()->isVoidTy()) continue; findReturnsToZap(*F, ReturnsToZap, Solver); } @@ -2213,17 +2075,16 @@ bool llvm::runIPSCCP( // If we inferred constant or undef values for globals variables, we can // delete the global and any stores that remain to it. 
- const DenseMap<GlobalVariable*, LatticeVal> &TG = Solver.getTrackedGlobals(); - for (DenseMap<GlobalVariable*, LatticeVal>::const_iterator I = TG.begin(), - E = TG.end(); I != E; ++I) { - GlobalVariable *GV = I->first; - assert(!I->second.isOverdefined() && - "Overdefined values should have been taken out of the map!"); + for (auto &I : make_early_inc_range(Solver.getTrackedGlobals())) { + GlobalVariable *GV = I.first; + if (isOverdefined(I.second)) + continue; LLVM_DEBUG(dbgs() << "Found that GV '" << GV->getName() << "' is constant!\n"); while (!GV->use_empty()) { StoreInst *SI = cast<StoreInst>(GV->user_back()); SI->eraseFromParent(); + MadeChanges = true; } M.getGlobalList().erase(GV); ++IPNumGlobalConst; diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 89916e43fce2..89f324deef9f 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -94,11 +94,6 @@ #include <utility> #include <vector> -#ifndef NDEBUG -// We only use this for a debug check. -#include <random> -#endif - using namespace llvm; using namespace llvm::sroa; @@ -115,11 +110,6 @@ STATISTIC(NumLoadsSpeculated, "Number of loads speculated to allow promotion"); STATISTIC(NumDeleted, "Number of instructions deleted"); STATISTIC(NumVectorized, "Number of vectorized aggregates"); -/// Hidden option to enable randomly shuffling the slices to help uncover -/// instability in their order. -static cl::opt<bool> SROARandomShuffleSlices("sroa-random-shuffle-slices", - cl::init(false), cl::Hidden); - /// Hidden option to experiment with completely strict handling of inbounds /// GEPs. static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), @@ -129,7 +119,7 @@ namespace { /// A custom IRBuilder inserter which prefixes all names, but only in /// Assert builds. -class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter { +class IRBuilderPrefixedInserter final : public IRBuilderDefaultInserter { std::string Prefix; const Twine getNameWithPrefix(const Twine &Name) const { @@ -139,9 +129,8 @@ class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter { public: void SetNamePrefix(const Twine &P) { Prefix = P.str(); } -protected: void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, - BasicBlock::iterator InsertPt) const { + BasicBlock::iterator InsertPt) const override { IRBuilderDefaultInserter::InsertHelper(I, getNameWithPrefix(Name), BB, InsertPt); } @@ -663,7 +652,8 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> { public: SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS) : PtrUseVisitor<SliceBuilder>(DL), - AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {} + AllocSize(DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize()), + AS(AS) {} private: void markAsDead(Instruction &I) { @@ -752,8 +742,10 @@ private: // For array or vector indices, scale the index by the size of the // type. 
APInt Index = OpC->getValue().sextOrTrunc(Offset.getBitWidth()); - GEPOffset += Index * APInt(Offset.getBitWidth(), - DL.getTypeAllocSize(GTI.getIndexedType())); + GEPOffset += + Index * + APInt(Offset.getBitWidth(), + DL.getTypeAllocSize(GTI.getIndexedType()).getFixedSize()); } // If this index has computed an intermediate pointer which is not @@ -788,7 +780,7 @@ private: LI.getPointerAddressSpace() != DL.getAllocaAddrSpace()) return PI.setAborted(&LI); - uint64_t Size = DL.getTypeStoreSize(LI.getType()); + uint64_t Size = DL.getTypeStoreSize(LI.getType()).getFixedSize(); return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile()); } @@ -803,7 +795,7 @@ private: SI.getPointerAddressSpace() != DL.getAllocaAddrSpace()) return PI.setAborted(&SI); - uint64_t Size = DL.getTypeStoreSize(ValOp->getType()); + uint64_t Size = DL.getTypeStoreSize(ValOp->getType()).getFixedSize(); // If this memory access can be shown to *statically* extend outside the // bounds of the allocation, it's behavior is undefined, so simply @@ -1069,17 +1061,9 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI) llvm::remove_if(Slices, [](const Slice &S) { return S.isDead(); }), Slices.end()); -#ifndef NDEBUG - if (SROARandomShuffleSlices) { - std::mt19937 MT(static_cast<unsigned>( - std::chrono::system_clock::now().time_since_epoch().count())); - std::shuffle(Slices.begin(), Slices.end(), MT); - } -#endif - // Sort the uses. This arranges for the offsets to be in ascending order, // and the sizes to be in descending order. - llvm::sort(Slices); + std::stable_sort(Slices.begin(), Slices.end()); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1200,7 +1184,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) { // TODO: Allow recursive phi users. // TODO: Allow stores. BasicBlock *BB = PN.getParent(); - MaybeAlign MaxAlign; + Align MaxAlign; uint64_t APWidth = DL.getIndexTypeSizeInBits(PN.getType()); APInt MaxSize(APWidth, 0); bool HaveLoad = false; @@ -1221,8 +1205,8 @@ static bool isSafePHIToSpeculate(PHINode &PN) { if (BBI->mayWriteToMemory()) return false; - uint64_t Size = DL.getTypeStoreSize(LI->getType()); - MaxAlign = std::max(MaxAlign, MaybeAlign(LI->getAlignment())); + uint64_t Size = DL.getTypeStoreSize(LI->getType()).getFixedSize(); + MaxAlign = std::max(MaxAlign, LI->getAlign()); MaxSize = MaxSize.ult(Size) ? APInt(APWidth, Size) : MaxSize; HaveLoad = true; } @@ -1273,7 +1257,7 @@ static void speculatePHINodeLoads(PHINode &PN) { // matter which one we get and if any differ. AAMDNodes AATags; SomeLoad->getAAMetadata(AATags); - const MaybeAlign Align = MaybeAlign(SomeLoad->getAlignment()); + Align Alignment = SomeLoad->getAlign(); // Rewrite all loads of the PN to use the new PHI. while (!PN.use_empty()) { @@ -1300,11 +1284,10 @@ static void speculatePHINodeLoads(PHINode &PN) { Instruction *TI = Pred->getTerminator(); IRBuilderTy PredBuilder(TI); - LoadInst *Load = PredBuilder.CreateLoad( - LoadTy, InVal, + LoadInst *Load = PredBuilder.CreateAlignedLoad( + LoadTy, InVal, Alignment, (PN.getName() + ".sroa.speculate.load." + Pred->getName())); ++NumLoadsSpeculated; - Load->setAlignment(Align); if (AATags) Load->setAAMetadata(AATags); NewPN->addIncoming(Load, Pred); @@ -1342,10 +1325,10 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) { // absolutely (e.g. allocas) or at this point because we can see other // accesses to it. 
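
// Standalone sketch (not part of this commit) of the std::stable_sort call now
// used in the AllocaSlices constructor above: slices that compare equal keep
// their original relative order, which std::sort does not guarantee. The pair
// layout below is made up for illustration.
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

int main() {
  // first = sort key (think "begin offset"), second = insertion index.
  std::vector<std::pair<int, int>> Slices = {{8, 0}, {0, 1}, {8, 2}, {0, 3}};
  std::stable_sort(Slices.begin(), Slices.end(),
                   [](const std::pair<int, int> &A, const std::pair<int, int> &B) {
                     return A.first < B.first;
                   });
  // Equal keys retain insertion order: index 1 before 3, and 0 before 2.
  assert(Slices[0].second == 1 && Slices[1].second == 3);
  assert(Slices[2].second == 0 && Slices[3].second == 2);
  return 0;
}
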
if (!isSafeToLoadUnconditionally(TValue, LI->getType(), - MaybeAlign(LI->getAlignment()), DL, LI)) + LI->getAlign(), DL, LI)) return false; if (!isSafeToLoadUnconditionally(FValue, LI->getType(), - MaybeAlign(LI->getAlignment()), DL, LI)) + LI->getAlign(), DL, LI)) return false; } @@ -1371,8 +1354,8 @@ static void speculateSelectInstLoads(SelectInst &SI) { NumLoadsSpeculated += 2; // Transfer alignment and AA info if present. - TL->setAlignment(MaybeAlign(LI->getAlignment())); - FL->setAlignment(MaybeAlign(LI->getAlignment())); + TL->setAlignment(LI->getAlign()); + FL->setAlignment(LI->getAlign()); AAMDNodes Tags; LI->getAAMetadata(Tags); @@ -1479,14 +1462,15 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL, // extremely poorly defined currently. The long-term goal is to remove GEPing // over a vector from the IR completely. if (VectorType *VecTy = dyn_cast<VectorType>(Ty)) { - unsigned ElementSizeInBits = DL.getTypeSizeInBits(VecTy->getScalarType()); + unsigned ElementSizeInBits = + DL.getTypeSizeInBits(VecTy->getScalarType()).getFixedSize(); if (ElementSizeInBits % 8 != 0) { // GEPs over non-multiple of 8 size vector elements are invalid. return nullptr; } APInt ElementSize(Offset.getBitWidth(), ElementSizeInBits / 8); APInt NumSkippedElements = Offset.sdiv(ElementSize); - if (NumSkippedElements.ugt(VecTy->getNumElements())) + if (NumSkippedElements.ugt(cast<FixedVectorType>(VecTy)->getNumElements())) return nullptr; Offset -= NumSkippedElements * ElementSize; Indices.push_back(IRB.getInt(NumSkippedElements)); @@ -1496,7 +1480,8 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL, if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) { Type *ElementTy = ArrTy->getElementType(); - APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy)); + APInt ElementSize(Offset.getBitWidth(), + DL.getTypeAllocSize(ElementTy).getFixedSize()); APInt NumSkippedElements = Offset.sdiv(ElementSize); if (NumSkippedElements.ugt(ArrTy->getNumElements())) return nullptr; @@ -1518,7 +1503,7 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL, unsigned Index = SL->getElementContainingOffset(StructOffset); Offset -= APInt(Offset.getBitWidth(), SL->getElementOffset(Index)); Type *ElementTy = STy->getElementType(Index); - if (Offset.uge(DL.getTypeAllocSize(ElementTy))) + if (Offset.uge(DL.getTypeAllocSize(ElementTy).getFixedSize())) return nullptr; // The offset points into alignment padding. Indices.push_back(IRB.getInt32(Index)); @@ -1550,7 +1535,8 @@ static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL, Type *ElementTy = Ty->getElementType(); if (!ElementTy->isSized()) return nullptr; // We can't GEP through an unsized element. - APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy)); + APInt ElementSize(Offset.getBitWidth(), + DL.getTypeAllocSize(ElementTy).getFixedSize()); if (ElementSize == 0) return nullptr; // Zero-length arrays can't help us build a natural GEP. APInt NumSkippedElements = Offset.sdiv(ElementSize); @@ -1681,20 +1667,8 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr, } /// Compute the adjusted alignment for a load or store from an offset. 
-static Align getAdjustedAlignment(Instruction *I, uint64_t Offset, - const DataLayout &DL) { - MaybeAlign Alignment; - Type *Ty; - if (auto *LI = dyn_cast<LoadInst>(I)) { - Alignment = MaybeAlign(LI->getAlignment()); - Ty = LI->getType(); - } else if (auto *SI = dyn_cast<StoreInst>(I)) { - Alignment = MaybeAlign(SI->getAlignment()); - Ty = SI->getValueOperand()->getType(); - } else { - llvm_unreachable("Only loads and stores are allowed!"); - } - return commonAlignment(DL.getValueOrABITypeAlignment(Alignment, Ty), Offset); +static Align getAdjustedAlignment(Instruction *I, uint64_t Offset) { + return commonAlignment(getLoadStoreAlignment(I), Offset); } /// Test whether we can convert a value from the old to the new type. @@ -1717,7 +1691,8 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) { return false; } - if (DL.getTypeSizeInBits(NewTy) != DL.getTypeSizeInBits(OldTy)) + if (DL.getTypeSizeInBits(NewTy).getFixedSize() != + DL.getTypeSizeInBits(OldTy).getFixedSize()) return false; if (!NewTy->isSingleValueType() || !OldTy->isSingleValueType()) return false; @@ -1728,8 +1703,15 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) { NewTy = NewTy->getScalarType(); if (NewTy->isPointerTy() || OldTy->isPointerTy()) { if (NewTy->isPointerTy() && OldTy->isPointerTy()) { - return cast<PointerType>(NewTy)->getPointerAddressSpace() == - cast<PointerType>(OldTy)->getPointerAddressSpace(); + unsigned OldAS = OldTy->getPointerAddressSpace(); + unsigned NewAS = NewTy->getPointerAddressSpace(); + // Convert pointers if they are pointers from the same address space or + // different integral (not non-integral) address spaces with the same + // pointer size. + return OldAS == NewAS || + (!DL.isNonIntegralAddressSpace(OldAS) && + !DL.isNonIntegralAddressSpace(NewAS) && + DL.getPointerSize(OldAS) == DL.getPointerSize(NewAS)); } // We can convert integers to integral pointers, but not to non-integral @@ -1765,36 +1747,40 @@ static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V, assert(!(isa<IntegerType>(OldTy) && isa<IntegerType>(NewTy)) && "Integer types must be the exact same to convert."); - // See if we need inttoptr for this type pair. A cast involving both scalars - // and vectors requires and additional bitcast. + // See if we need inttoptr for this type pair. May require additional bitcast. if (OldTy->isIntOrIntVectorTy() && NewTy->isPtrOrPtrVectorTy()) { // Expand <2 x i32> to i8* --> <2 x i32> to i64 to i8* - if (OldTy->isVectorTy() && !NewTy->isVectorTy()) - return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)), - NewTy); - // Expand i128 to <2 x i8*> --> i128 to <2 x i64> to <2 x i8*> - if (!OldTy->isVectorTy() && NewTy->isVectorTy()) - return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)), - NewTy); - - return IRB.CreateIntToPtr(V, NewTy); + // Expand <4 x i32> to <2 x i8*> --> <4 x i32> to <2 x i64> to <2 x i8*> + // Directly handle i64 to i8* + return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)), + NewTy); } - // See if we need ptrtoint for this type pair. A cast involving both scalars - // and vectors requires and additional bitcast. + // See if we need ptrtoint for this type pair. May require additional bitcast. 
if (OldTy->isPtrOrPtrVectorTy() && NewTy->isIntOrIntVectorTy()) { // Expand <2 x i8*> to i128 --> <2 x i8*> to <2 x i64> to i128 - if (OldTy->isVectorTy() && !NewTy->isVectorTy()) - return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)), - NewTy); - // Expand i8* to <2 x i32> --> i8* to i64 to <2 x i32> - if (!OldTy->isVectorTy() && NewTy->isVectorTy()) - return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)), - NewTy); + // Expand <2 x i8*> to <4 x i32> --> <2 x i8*> to <2 x i64> to <4 x i32> + // Expand i8* to i64 --> i8* to i64 to i64 + return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)), + NewTy); + } - return IRB.CreatePtrToInt(V, NewTy); + if (OldTy->isPtrOrPtrVectorTy() && NewTy->isPtrOrPtrVectorTy()) { + unsigned OldAS = OldTy->getPointerAddressSpace(); + unsigned NewAS = NewTy->getPointerAddressSpace(); + // To convert pointers with different address spaces (they are already + // checked convertible, i.e. they have the same pointer size), so far we + // cannot use `bitcast` (which has restrict on the same address space) or + // `addrspacecast` (which is not always no-op casting). Instead, use a pair + // of no-op `ptrtoint`/`inttoptr` casts through an integer with the same bit + // size. + if (OldAS != NewAS) { + assert(DL.getPointerSize(OldAS) == DL.getPointerSize(NewAS)); + return IRB.CreateIntToPtr(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)), + NewTy); + } } return IRB.CreateBitCast(V, NewTy); @@ -1813,19 +1799,20 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S, std::max(S.beginOffset(), P.beginOffset()) - P.beginOffset(); uint64_t BeginIndex = BeginOffset / ElementSize; if (BeginIndex * ElementSize != BeginOffset || - BeginIndex >= Ty->getNumElements()) + BeginIndex >= cast<FixedVectorType>(Ty)->getNumElements()) return false; uint64_t EndOffset = std::min(S.endOffset(), P.endOffset()) - P.beginOffset(); uint64_t EndIndex = EndOffset / ElementSize; - if (EndIndex * ElementSize != EndOffset || EndIndex > Ty->getNumElements()) + if (EndIndex * ElementSize != EndOffset || + EndIndex > cast<FixedVectorType>(Ty)->getNumElements()) return false; assert(EndIndex > BeginIndex && "Empty vector!"); uint64_t NumElements = EndIndex - BeginIndex; Type *SliceTy = (NumElements == 1) ? Ty->getElementType() - : VectorType::get(Ty->getElementType(), NumElements); + : FixedVectorType::get(Ty->getElementType(), NumElements); Type *SplitIntTy = Type::getIntNTy(Ty->getContext(), NumElements * ElementSize * 8); @@ -1890,7 +1877,8 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { // Return if bitcast to vectors is different for total size in bits. if (!CandidateTys.empty()) { VectorType *V = CandidateTys[0]; - if (DL.getTypeSizeInBits(VTy) != DL.getTypeSizeInBits(V)) { + if (DL.getTypeSizeInBits(VTy).getFixedSize() != + DL.getTypeSizeInBits(V).getFixedSize()) { CandidateTys.clear(); return; } @@ -1936,13 +1924,15 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { // they're all integer vectors. We sort by ascending number of elements. 
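
// Standalone sketch (not part of this commit) of the address-space rule added
// to canConvertValue/convertValue above: pointers in different address spaces
// are treated as convertible only when neither space is non-integral and the
// pointer sizes match, and the conversion is emitted as a ptrtoint/inttoptr
// pair. The data layout string below is a made-up example ("ni:2" declares
// address space 2 non-integral).
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  using namespace llvm;
  DataLayout DL("e-p:64:64-p1:64:64-p2:64:64-ni:2");
  auto Convertible = [&DL](unsigned OldAS, unsigned NewAS) {
    return OldAS == NewAS ||
           (!DL.isNonIntegralAddressSpace(OldAS) &&
            !DL.isNonIntegralAddressSpace(NewAS) &&
            DL.getPointerSize(OldAS) == DL.getPointerSize(NewAS));
  };
  outs() << "AS0 -> AS1: " << (Convertible(0, 1) ? "ok" : "no") << "\n"; // ok
  outs() << "AS0 -> AS2: " << (Convertible(0, 2) ? "ok" : "no") << "\n"; // no
  return 0;
}
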
auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) { (void)DL; - assert(DL.getTypeSizeInBits(RHSTy) == DL.getTypeSizeInBits(LHSTy) && + assert(DL.getTypeSizeInBits(RHSTy).getFixedSize() == + DL.getTypeSizeInBits(LHSTy).getFixedSize() && "Cannot have vector types of different sizes!"); assert(RHSTy->getElementType()->isIntegerTy() && "All non-integer types eliminated!"); assert(LHSTy->getElementType()->isIntegerTy() && "All non-integer types eliminated!"); - return RHSTy->getNumElements() < LHSTy->getNumElements(); + return cast<FixedVectorType>(RHSTy)->getNumElements() < + cast<FixedVectorType>(LHSTy)->getNumElements(); }; llvm::sort(CandidateTys, RankVectorTypes); CandidateTys.erase( @@ -1964,13 +1954,14 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { // Try each vector type, and return the one which works. auto CheckVectorTypeForPromotion = [&](VectorType *VTy) { - uint64_t ElementSize = DL.getTypeSizeInBits(VTy->getElementType()); + uint64_t ElementSize = + DL.getTypeSizeInBits(VTy->getElementType()).getFixedSize(); // While the definition of LLVM vectors is bitpacked, we don't support sizes // that aren't byte sized. if (ElementSize % 8) return false; - assert((DL.getTypeSizeInBits(VTy) % 8) == 0 && + assert((DL.getTypeSizeInBits(VTy).getFixedSize() % 8) == 0 && "vector size not a multiple of element size?"); ElementSize /= 8; @@ -2000,7 +1991,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, Type *AllocaTy, const DataLayout &DL, bool &WholeAllocaOp) { - uint64_t Size = DL.getTypeStoreSize(AllocaTy); + uint64_t Size = DL.getTypeStoreSize(AllocaTy).getFixedSize(); uint64_t RelBegin = S.beginOffset() - AllocBeginOffset; uint64_t RelEnd = S.endOffset() - AllocBeginOffset; @@ -2016,7 +2007,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, if (LI->isVolatile()) return false; // We can't handle loads that extend past the allocated memory. - if (DL.getTypeStoreSize(LI->getType()) > Size) + if (DL.getTypeStoreSize(LI->getType()).getFixedSize() > Size) return false; // So far, AllocaSliceRewriter does not support widening split slice tails // in rewriteIntegerLoad. @@ -2028,7 +2019,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, if (!isa<VectorType>(LI->getType()) && RelBegin == 0 && RelEnd == Size) WholeAllocaOp = true; if (IntegerType *ITy = dyn_cast<IntegerType>(LI->getType())) { - if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy)) + if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize()) return false; } else if (RelBegin != 0 || RelEnd != Size || !canConvertValue(DL, AllocaTy, LI->getType())) { @@ -2041,7 +2032,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, if (SI->isVolatile()) return false; // We can't handle stores that extend past the allocated memory. - if (DL.getTypeStoreSize(ValueTy) > Size) + if (DL.getTypeStoreSize(ValueTy).getFixedSize() > Size) return false; // So far, AllocaSliceRewriter does not support widening split slice tails // in rewriteIntegerStore. 
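
// Standalone sketch (not part of this commit) of the getFixedSize() pattern
// used throughout the SROA changes in this diff: DataLayout size queries now
// return a TypeSize, and for ordinary fixed-width types the scalar value is
// read with getFixedSize(). The default data layout and the types below are
// arbitrary examples.
#include <cstdint>
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  using namespace llvm;
  LLVMContext Ctx;
  DataLayout DL(""); // default layout is enough for this illustration
  Type *I32 = Type::getInt32Ty(Ctx);
  auto *V16I8 = FixedVectorType::get(Type::getInt8Ty(Ctx), 16);

  uint64_t I32Bits = DL.getTypeSizeInBits(I32).getFixedSize();   // 32
  uint64_t VecBytes = DL.getTypeStoreSize(V16I8).getFixedSize(); // 16
  outs() << "i32 bits: " << I32Bits << ", <16 x i8> store bytes: " << VecBytes
         << "\n";
  return 0;
}
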
@@ -2053,7 +2044,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, if (!isa<VectorType>(ValueTy) && RelBegin == 0 && RelEnd == Size) WholeAllocaOp = true; if (IntegerType *ITy = dyn_cast<IntegerType>(ValueTy)) { - if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy)) + if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize()) return false; } else if (RelBegin != 0 || RelEnd != Size || !canConvertValue(DL, ValueTy, AllocaTy)) { @@ -2084,13 +2075,13 @@ static bool isIntegerWideningViableForSlice(const Slice &S, /// promote the resulting alloca. static bool isIntegerWideningViable(Partition &P, Type *AllocaTy, const DataLayout &DL) { - uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy); + uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy).getFixedSize(); // Don't create integer types larger than the maximum bitwidth. if (SizeInBits > IntegerType::MAX_INT_BITS) return false; // Don't try to handle allocas with bit-padding. - if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy)) + if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy).getFixedSize()) return false; // We need to ensure that an integer type with the appropriate bitwidth can @@ -2129,11 +2120,13 @@ static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V, const Twine &Name) { LLVM_DEBUG(dbgs() << " start: " << *V << "\n"); IntegerType *IntTy = cast<IntegerType>(V->getType()); - assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && + assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <= + DL.getTypeStoreSize(IntTy).getFixedSize() && "Element extends past full value"); uint64_t ShAmt = 8 * Offset; if (DL.isBigEndian()) - ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); + ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() - + DL.getTypeStoreSize(Ty).getFixedSize() - Offset); if (ShAmt) { V = IRB.CreateLShr(V, ShAmt, Name + ".shift"); LLVM_DEBUG(dbgs() << " shifted: " << *V << "\n"); @@ -2158,11 +2151,13 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old, V = IRB.CreateZExt(V, IntTy, Name + ".ext"); LLVM_DEBUG(dbgs() << " extended: " << *V << "\n"); } - assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && + assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <= + DL.getTypeStoreSize(IntTy).getFixedSize() && "Element store outside of alloca store"); uint64_t ShAmt = 8 * Offset; if (DL.isBigEndian()) - ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); + ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() - + DL.getTypeStoreSize(Ty).getFixedSize() - Offset); if (ShAmt) { V = IRB.CreateShl(V, ShAmt, Name + ".shift"); LLVM_DEBUG(dbgs() << " shifted: " << *V << "\n"); @@ -2180,7 +2175,7 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old, static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name) { - VectorType *VecTy = cast<VectorType>(V->getType()); + auto *VecTy = cast<FixedVectorType>(V->getType()); unsigned NumElements = EndIndex - BeginIndex; assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); @@ -2194,12 +2189,12 @@ static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, return V; } - SmallVector<Constant *, 8> Mask; + SmallVector<int, 8> Mask; Mask.reserve(NumElements); for (unsigned i = BeginIndex; i != EndIndex; ++i) - Mask.push_back(IRB.getInt32(i)); - V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), 
- ConstantVector::get(Mask), Name + ".extract"); + Mask.push_back(i); + V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), Mask, + Name + ".extract"); LLVM_DEBUG(dbgs() << " shuffle: " << *V << "\n"); return V; } @@ -2218,21 +2213,23 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, return V; } - assert(Ty->getNumElements() <= VecTy->getNumElements() && + assert(cast<FixedVectorType>(Ty)->getNumElements() <= + cast<FixedVectorType>(VecTy)->getNumElements() && "Too many elements!"); - if (Ty->getNumElements() == VecTy->getNumElements()) { + if (cast<FixedVectorType>(Ty)->getNumElements() == + cast<FixedVectorType>(VecTy)->getNumElements()) { assert(V->getType() == VecTy && "Vector type mismatch"); return V; } - unsigned EndIndex = BeginIndex + Ty->getNumElements(); + unsigned EndIndex = BeginIndex + cast<FixedVectorType>(Ty)->getNumElements(); // When inserting a smaller vector into the larger to store, we first // use a shuffle vector to widen it with undef elements, and then // a second shuffle vector to select between the loaded vector and the // incoming vector. SmallVector<Constant *, 8> Mask; - Mask.reserve(VecTy->getNumElements()); - for (unsigned i = 0; i != VecTy->getNumElements(); ++i) + Mask.reserve(cast<FixedVectorType>(VecTy)->getNumElements()); + for (unsigned i = 0; i != cast<FixedVectorType>(VecTy)->getNumElements(); ++i) if (i >= BeginIndex && i < EndIndex) Mask.push_back(IRB.getInt32(i - BeginIndex)); else @@ -2242,7 +2239,7 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, LLVM_DEBUG(dbgs() << " shuffle: " << *V << "\n"); Mask.clear(); - for (unsigned i = 0; i != VecTy->getNumElements(); ++i) + for (unsigned i = 0; i != cast<FixedVectorType>(VecTy)->getNumElements(); ++i) Mask.push_back(IRB.getInt1(i >= BeginIndex && i < EndIndex)); V = IRB.CreateSelect(ConstantVector::get(Mask), V, Old, Name + "blend"); @@ -2325,18 +2322,20 @@ public: NewAllocaBeginOffset(NewAllocaBeginOffset), NewAllocaEndOffset(NewAllocaEndOffset), NewAllocaTy(NewAI.getAllocatedType()), - IntTy(IsIntegerPromotable - ? Type::getIntNTy( - NewAI.getContext(), - DL.getTypeSizeInBits(NewAI.getAllocatedType())) - : nullptr), + IntTy( + IsIntegerPromotable + ? Type::getIntNTy(NewAI.getContext(), + DL.getTypeSizeInBits(NewAI.getAllocatedType()) + .getFixedSize()) + : nullptr), VecTy(PromotableVecTy), ElementTy(VecTy ? VecTy->getElementType() : nullptr), - ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy) / 8 : 0), + ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8 + : 0), PHIUsers(PHIUsers), SelectUsers(SelectUsers), IRB(NewAI.getContext(), ConstantFolder()) { if (VecTy) { - assert((DL.getTypeSizeInBits(ElementTy) % 8) == 0 && + assert((DL.getTypeSizeInBits(ElementTy).getFixedSize() % 8) == 0 && "Only multiple-of-8 sized vector elements are viable"); ++NumVectorized; } @@ -2368,7 +2367,8 @@ public: Instruction *OldUserI = cast<Instruction>(OldUse->getUser()); IRB.SetInsertPoint(OldUserI); IRB.SetCurrentDebugLocation(OldUserI->getDebugLoc()); - IRB.SetNamePrefix(Twine(NewAI.getName()) + "." + Twine(BeginOffset) + "."); + IRB.getInserter().SetNamePrefix( + Twine(NewAI.getName()) + "." + Twine(BeginOffset) + "."); CanSROA &= visit(cast<Instruction>(OldUse->getUser())); if (VecTy || IntTy) @@ -2429,14 +2429,9 @@ private: /// /// You can optionally pass a type to this routine and if that type's ABI /// alignment is itself suitable, this will return zero. 
- MaybeAlign getSliceAlign(Type *Ty = nullptr) { - const MaybeAlign NewAIAlign = DL.getValueOrABITypeAlignment( - MaybeAlign(NewAI.getAlignment()), NewAI.getAllocatedType()); - const MaybeAlign Align = - commonAlignment(NewAIAlign, NewBeginOffset - NewAllocaBeginOffset); - return (Ty && Align && Align->value() == DL.getABITypeAlignment(Ty)) - ? None - : Align; + Align getSliceAlign() { + return commonAlignment(NewAI.getAlign(), + NewBeginOffset - NewAllocaBeginOffset); } unsigned getIndex(uint64_t Offset) { @@ -2460,7 +2455,7 @@ private: assert(EndIndex > BeginIndex && "Empty vector!"); Value *V = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "load"); + NewAI.getAlign(), "load"); return extractVector(IRB, V, BeginIndex, EndIndex, "vec"); } @@ -2468,7 +2463,7 @@ private: assert(IntTy && "We cannot insert an integer to the alloca"); assert(!LI.isVolatile()); Value *V = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "load"); + NewAI.getAlign(), "load"); V = convertValue(DL, IRB, V, IntTy); assert(NewBeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; @@ -2500,7 +2495,8 @@ private: Type *TargetTy = IsSplit ? Type::getIntNTy(LI.getContext(), SliceSize * 8) : LI.getType(); - const bool IsLoadPastEnd = DL.getTypeStoreSize(TargetTy) > SliceSize; + const bool IsLoadPastEnd = + DL.getTypeStoreSize(TargetTy).getFixedSize() > SliceSize; bool IsPtrAdjusted = false; Value *V; if (VecTy) { @@ -2513,12 +2509,14 @@ private: (IsLoadPastEnd && NewAllocaTy->isIntegerTy() && TargetTy->isIntegerTy()))) { LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), - LI.isVolatile(), LI.getName()); + NewAI.getAlign(), LI.isVolatile(), + LI.getName()); if (AATags) NewLI->setAAMetadata(AATags); if (LI.isVolatile()) NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); + if (NewLI->isAtomic()) + NewLI->setAlignment(LI.getAlign()); // Any !nonnull metadata or !range metadata on the old load is also valid // on the new load. This is even true in some cases even when the loads @@ -2549,9 +2547,9 @@ private: } } else { Type *LTy = TargetTy->getPointerTo(AS); - LoadInst *NewLI = IRB.CreateAlignedLoad( - TargetTy, getNewAllocaSlicePtr(IRB, LTy), getSliceAlign(TargetTy), - LI.isVolatile(), LI.getName()); + LoadInst *NewLI = + IRB.CreateAlignedLoad(TargetTy, getNewAllocaSlicePtr(IRB, LTy), + getSliceAlign(), LI.isVolatile(), LI.getName()); if (AATags) NewLI->setAAMetadata(AATags); if (LI.isVolatile()) @@ -2566,7 +2564,7 @@ private: assert(!LI.isVolatile()); assert(LI.getType()->isIntegerTy() && "Only integer type loads and stores are split"); - assert(SliceSize < DL.getTypeStoreSize(LI.getType()) && + assert(SliceSize < DL.getTypeStoreSize(LI.getType()).getFixedSize() && "Split load isn't smaller than original load"); assert(DL.typeSizeEqualsStoreSize(LI.getType()) && "Non-byte-multiple bit width"); @@ -2577,7 +2575,8 @@ private: // the computed value, and then replace the placeholder with LI, leaving // LI only used for this computation. 
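
// Standalone sketch (not part of this commit) of the commonAlignment()
// computation now used by getSliceAlign()/getAdjustedAlignment() above: the
// usable alignment at a slice is the largest power of two that divides both
// the base alignment and the slice offset. The offsets are arbitrary examples.
#include <cstdint>
#include "llvm/Support/Alignment.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  using namespace llvm;
  Align BaseAlign(16); // e.g. the new alloca's alignment
  const uint64_t Offsets[] = {0, 4, 8, 12, 32};
  for (uint64_t Offset : Offsets) {
    Align SliceAlign = commonAlignment(BaseAlign, Offset);
    outs() << "offset " << Offset << " -> align " << SliceAlign.value() << "\n";
  }
  return 0; // prints alignments 16, 4, 8, 4, 16
}
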
Value *Placeholder = new LoadInst( - LI.getType(), UndefValue::get(LI.getType()->getPointerTo(AS))); + LI.getType(), UndefValue::get(LI.getType()->getPointerTo(AS)), "", + false, Align(1)); V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset, "insert"); LI.replaceAllUsesWith(V); @@ -2600,19 +2599,20 @@ private: unsigned EndIndex = getIndex(NewEndOffset); assert(EndIndex > BeginIndex && "Empty vector!"); unsigned NumElements = EndIndex - BeginIndex; - assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); + assert(NumElements <= cast<FixedVectorType>(VecTy)->getNumElements() && + "Too many elements!"); Type *SliceTy = (NumElements == 1) ? ElementTy - : VectorType::get(ElementTy, NumElements); + : FixedVectorType::get(ElementTy, NumElements); if (V->getType() != SliceTy) V = convertValue(DL, IRB, V, SliceTy); // Mix in the existing elements. Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "load"); + NewAI.getAlign(), "load"); V = insertVector(IRB, Old, V, BeginIndex, "vec"); } - StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); + StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign()); if (AATags) Store->setAAMetadata(AATags); Pass.DeadInsts.insert(&SI); @@ -2624,16 +2624,17 @@ private: bool rewriteIntegerStore(Value *V, StoreInst &SI, AAMDNodes AATags) { assert(IntTy && "We cannot extract an integer from the alloca"); assert(!SI.isVolatile()); - if (DL.getTypeSizeInBits(V->getType()) != IntTy->getBitWidth()) { + if (DL.getTypeSizeInBits(V->getType()).getFixedSize() != + IntTy->getBitWidth()) { Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "oldload"); + NewAI.getAlign(), "oldload"); Old = convertValue(DL, IRB, Old, IntTy); assert(BeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); uint64_t Offset = BeginOffset - NewAllocaBeginOffset; V = insertInteger(DL, IRB, Old, SI.getValueOperand(), Offset, "insert"); } V = convertValue(DL, IRB, V, NewAllocaTy); - StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); + StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign()); Store->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); if (AATags) @@ -2659,7 +2660,7 @@ private: if (AllocaInst *AI = dyn_cast<AllocaInst>(V->stripInBoundsOffsets())) Pass.PostPromotionWorklist.insert(AI); - if (SliceSize < DL.getTypeStoreSize(V->getType())) { + if (SliceSize < DL.getTypeStoreSize(V->getType()).getFixedSize()) { assert(!SI.isVolatile()); assert(V->getType()->isIntegerTy() && "Only integer type loads and stores are split"); @@ -2675,7 +2676,8 @@ private: if (IntTy && V->getType()->isIntegerTy()) return rewriteIntegerStore(V, SI, AATags); - const bool IsStorePastEnd = DL.getTypeStoreSize(V->getType()) > SliceSize; + const bool IsStorePastEnd = + DL.getTypeStoreSize(V->getType()).getFixedSize() > SliceSize; StoreInst *NewSI; if (NewBeginOffset == NewAllocaBeginOffset && NewEndOffset == NewAllocaEndOffset && @@ -2695,13 +2697,13 @@ private: } V = convertValue(DL, IRB, V, NewAllocaTy); - NewSI = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(), - SI.isVolatile()); + NewSI = + IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign(), SI.isVolatile()); } else { unsigned AS = SI.getPointerAddressSpace(); Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS)); - NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(V->getType()), - 
SI.isVolatile()); + NewSI = + IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(), SI.isVolatile()); } NewSI->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); @@ -2709,6 +2711,8 @@ private: NewSI->setAAMetadata(AATags); if (SI.isVolatile()) NewSI->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); + if (NewSI->isAtomic()) + NewSI->setAlignment(SI.getAlign()); Pass.DeadInsts.insert(&SI); deleteIfTriviallyDead(OldOp); @@ -2786,9 +2790,9 @@ private: return false; const auto Len = C->getZExtValue(); auto *Int8Ty = IntegerType::getInt8Ty(NewAI.getContext()); - auto *SrcTy = VectorType::get(Int8Ty, Len); + auto *SrcTy = FixedVectorType::get(Int8Ty, Len); return canConvertValue(DL, SrcTy, AllocaTy) && - DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy)); + DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy).getFixedSize()); }(); // If this doesn't map cleanly onto the alloca type, and that type isn't @@ -2820,16 +2824,17 @@ private: unsigned EndIndex = getIndex(NewEndOffset); assert(EndIndex > BeginIndex && "Empty vector!"); unsigned NumElements = EndIndex - BeginIndex; - assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); + assert(NumElements <= cast<FixedVectorType>(VecTy)->getNumElements() && + "Too many elements!"); - Value *Splat = - getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ElementTy) / 8); + Value *Splat = getIntegerSplat( + II.getValue(), DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8); Splat = convertValue(DL, IRB, Splat, ElementTy); if (NumElements > 1) Splat = getVectorSplat(Splat, NumElements); Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "oldload"); + NewAI.getAlign(), "oldload"); V = insertVector(IRB, Old, Splat, BeginIndex, "vec"); } else if (IntTy) { // If this is a memset on an alloca where we can widen stores, insert the @@ -2842,7 +2847,7 @@ private: if (IntTy && (BeginOffset != NewAllocaBeginOffset || EndOffset != NewAllocaBeginOffset)) { Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "oldload"); + NewAI.getAlign(), "oldload"); Old = convertValue(DL, IRB, Old, IntTy); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; V = insertInteger(DL, IRB, Old, V, Offset, "insert"); @@ -2856,15 +2861,17 @@ private: assert(NewBeginOffset == NewAllocaBeginOffset); assert(NewEndOffset == NewAllocaEndOffset); - V = getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ScalarTy) / 8); + V = getIntegerSplat(II.getValue(), + DL.getTypeSizeInBits(ScalarTy).getFixedSize() / 8); if (VectorType *AllocaVecTy = dyn_cast<VectorType>(AllocaTy)) - V = getVectorSplat(V, AllocaVecTy->getNumElements()); + V = getVectorSplat( + V, cast<FixedVectorType>(AllocaVecTy)->getNumElements()); V = convertValue(DL, IRB, V, AllocaTy); } - StoreInst *New = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(), - II.isVolatile()); + StoreInst *New = + IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign(), II.isVolatile()); if (AATags) New->setAAMetadata(AATags); LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); @@ -2919,7 +2926,8 @@ private: bool EmitMemCpy = !VecTy && !IntTy && (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset || - SliceSize != DL.getTypeStoreSize(NewAI.getAllocatedType()) || + SliceSize != + DL.getTypeStoreSize(NewAI.getAllocatedType()).getFixedSize() || !NewAI.getAllocatedType()->isSingleValueType()); // If we're just going to emit a memcpy, the alloca hasn't changed, and the @@ -2955,7 +2963,7 @@ private: unsigned 
OffsetWidth = DL.getIndexSizeInBits(OtherAS); APInt OtherOffset(OffsetWidth, NewBeginOffset - BeginOffset); Align OtherAlign = - assumeAligned(IsDest ? II.getSourceAlignment() : II.getDestAlignment()); + (IsDest ? II.getSourceAlign() : II.getDestAlign()).valueOrOne(); OtherAlign = commonAlignment(OtherAlign, OtherOffset.zextOrTrunc(64).getZExtValue()); @@ -3007,7 +3015,7 @@ private: if (NumElements == 1) OtherTy = VecTy->getElementType(); else - OtherTy = VectorType::get(VecTy->getElementType(), NumElements); + OtherTy = FixedVectorType::get(VecTy->getElementType(), NumElements); } else if (IntTy && !IsWholeAlloca) { OtherTy = SubIntTy; } else { @@ -3028,11 +3036,11 @@ private: Value *Src; if (VecTy && !IsWholeAlloca && !IsDest) { Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "load"); + NewAI.getAlign(), "load"); Src = extractVector(IRB, Src, BeginIndex, EndIndex, "vec"); } else if (IntTy && !IsWholeAlloca && !IsDest) { Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "load"); + NewAI.getAlign(), "load"); Src = convertValue(DL, IRB, Src, IntTy); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; Src = extractInteger(DL, IRB, Src, SubIntTy, Offset, "extract"); @@ -3046,11 +3054,11 @@ private: if (VecTy && !IsWholeAlloca && IsDest) { Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "oldload"); + NewAI.getAlign(), "oldload"); Src = insertVector(IRB, Old, Src, BeginIndex, "vec"); } else if (IntTy && !IsWholeAlloca && IsDest) { Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, - NewAI.getAlignment(), "oldload"); + NewAI.getAlign(), "oldload"); Old = convertValue(DL, IRB, Old, IntTy); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; Src = insertInteger(DL, IRB, Old, Src, Offset, "insert"); @@ -3115,17 +3123,12 @@ private: Instruction *I = Uses.pop_back_val(); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - MaybeAlign LoadAlign = DL.getValueOrABITypeAlignment( - MaybeAlign(LI->getAlignment()), LI->getType()); - LI->setAlignment(std::min(LoadAlign, getSliceAlign())); + LI->setAlignment(std::min(LI->getAlign(), getSliceAlign())); continue; } if (StoreInst *SI = dyn_cast<StoreInst>(I)) { - Value *Op = SI->getOperand(0); - MaybeAlign StoreAlign = DL.getValueOrABITypeAlignment( - MaybeAlign(SI->getAlignment()), Op->getType()); - SI->setAlignment(std::min(StoreAlign, getSliceAlign())); - continue; + SI->setAlignment(std::min(SI->getAlign(), getSliceAlign())); + continue; } assert(isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I) || @@ -3146,14 +3149,14 @@ private: // as local as possible to the PHI. To do that, we re-use the location of // the old pointer, which necessarily must be in the right position to // dominate the PHI. - IRBuilderTy PtrBuilder(IRB); + IRBuilderBase::InsertPointGuard Guard(IRB); if (isa<PHINode>(OldPtr)) - PtrBuilder.SetInsertPoint(&*OldPtr->getParent()->getFirstInsertionPt()); + IRB.SetInsertPoint(&*OldPtr->getParent()->getFirstInsertionPt()); else - PtrBuilder.SetInsertPoint(OldPtr); - PtrBuilder.SetCurrentDebugLocation(OldPtr->getDebugLoc()); + IRB.SetInsertPoint(OldPtr); + IRB.SetCurrentDebugLocation(OldPtr->getDebugLoc()); - Value *NewPtr = getNewAllocaSlicePtr(PtrBuilder, OldPtr->getType()); + Value *NewPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType()); // Replace the operands which were using the old pointer. 
std::replace(PN.op_begin(), PN.op_end(), cast<Value>(OldPtr), NewPtr); @@ -3357,7 +3360,7 @@ private: Value *GEP = IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep"); LoadInst *Load = - IRB.CreateAlignedLoad(Ty, GEP, Alignment.value(), Name + ".load"); + IRB.CreateAlignedLoad(Ty, GEP, Alignment, Name + ".load"); if (AATags) Load->setAAMetadata(AATags); Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert"); @@ -3375,9 +3378,10 @@ private: AAMDNodes AATags; LI.getAAMetadata(AATags); LoadOpSplitter Splitter(&LI, *U, LI.getType(), AATags, - getAdjustedAlignment(&LI, 0, DL), DL); + getAdjustedAlignment(&LI, 0), DL); Value *V = UndefValue::get(LI.getType()); Splitter.emitSplitOps(LI.getType(), V, LI.getName() + ".fca"); + Visited.erase(&LI); LI.replaceAllUsesWith(V); LI.eraseFromParent(); return true; @@ -3403,7 +3407,7 @@ private: Value *InBoundsGEP = IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep"); StoreInst *Store = - IRB.CreateAlignedStore(ExtractValue, InBoundsGEP, Alignment.value()); + IRB.CreateAlignedStore(ExtractValue, InBoundsGEP, Alignment); if (AATags) Store->setAAMetadata(AATags); LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); @@ -3422,8 +3426,9 @@ private: AAMDNodes AATags; SI.getAAMetadata(AATags); StoreOpSplitter Splitter(&SI, *U, V->getType(), AATags, - getAdjustedAlignment(&SI, 0, DL), DL); + getAdjustedAlignment(&SI, 0), DL); Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca"); + Visited.erase(&SI); SI.eraseFromParent(); return true; } @@ -3438,7 +3443,110 @@ private: return false; } + // Fold gep (select cond, ptr1, ptr2) => select cond, gep(ptr1), gep(ptr2) + bool foldGEPSelect(GetElementPtrInst &GEPI) { + if (!GEPI.hasAllConstantIndices()) + return false; + + SelectInst *Sel = cast<SelectInst>(GEPI.getPointerOperand()); + + LLVM_DEBUG(dbgs() << " Rewriting gep(select) -> select(gep):" + << "\n original: " << *Sel + << "\n " << GEPI); + + IRBuilderTy Builder(&GEPI); + SmallVector<Value *, 4> Index(GEPI.idx_begin(), GEPI.idx_end()); + bool IsInBounds = GEPI.isInBounds(); + + Value *True = Sel->getTrueValue(); + Value *NTrue = + IsInBounds + ? Builder.CreateInBoundsGEP(True, Index, + True->getName() + ".sroa.gep") + : Builder.CreateGEP(True, Index, True->getName() + ".sroa.gep"); + + Value *False = Sel->getFalseValue(); + + Value *NFalse = + IsInBounds + ? 
Builder.CreateInBoundsGEP(False, Index, + False->getName() + ".sroa.gep") + : Builder.CreateGEP(False, Index, False->getName() + ".sroa.gep"); + + Value *NSel = Builder.CreateSelect(Sel->getCondition(), NTrue, NFalse, + Sel->getName() + ".sroa.sel"); + Visited.erase(&GEPI); + GEPI.replaceAllUsesWith(NSel); + GEPI.eraseFromParent(); + Instruction *NSelI = cast<Instruction>(NSel); + Visited.insert(NSelI); + enqueueUsers(*NSelI); + + LLVM_DEBUG(dbgs() << "\n to: " << *NTrue + << "\n " << *NFalse + << "\n " << *NSel << '\n'); + + return true; + } + + // Fold gep (phi ptr1, ptr2) => phi gep(ptr1), gep(ptr2) + bool foldGEPPhi(GetElementPtrInst &GEPI) { + if (!GEPI.hasAllConstantIndices()) + return false; + + PHINode *PHI = cast<PHINode>(GEPI.getPointerOperand()); + if (GEPI.getParent() != PHI->getParent() || + llvm::any_of(PHI->incoming_values(), [](Value *In) + { Instruction *I = dyn_cast<Instruction>(In); + return !I || isa<GetElementPtrInst>(I) || isa<PHINode>(I) || + succ_empty(I->getParent()) || + !I->getParent()->isLegalToHoistInto(); + })) + return false; + + LLVM_DEBUG(dbgs() << " Rewriting gep(phi) -> phi(gep):" + << "\n original: " << *PHI + << "\n " << GEPI + << "\n to: "); + + SmallVector<Value *, 4> Index(GEPI.idx_begin(), GEPI.idx_end()); + bool IsInBounds = GEPI.isInBounds(); + IRBuilderTy PHIBuilder(GEPI.getParent()->getFirstNonPHI()); + PHINode *NewPN = PHIBuilder.CreatePHI(GEPI.getType(), + PHI->getNumIncomingValues(), + PHI->getName() + ".sroa.phi"); + for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { + Instruction *In = cast<Instruction>(PHI->getIncomingValue(I)); + + IRBuilderTy B(In->getParent(), std::next(In->getIterator())); + Value *NewVal = IsInBounds + ? B.CreateInBoundsGEP(In, Index, In->getName() + ".sroa.gep") + : B.CreateGEP(In, Index, In->getName() + ".sroa.gep"); + NewPN->addIncoming(NewVal, PHI->getIncomingBlock(I)); + } + + Visited.erase(&GEPI); + GEPI.replaceAllUsesWith(NewPN); + GEPI.eraseFromParent(); + Visited.insert(NewPN); + enqueueUsers(*NewPN); + + LLVM_DEBUG(for (Value *In : NewPN->incoming_values()) + dbgs() << "\n " << *In; + dbgs() << "\n " << *NewPN << '\n'); + + return true; + } + bool visitGetElementPtrInst(GetElementPtrInst &GEPI) { + if (isa<SelectInst>(GEPI.getPointerOperand()) && + foldGEPSelect(GEPI)) + return true; + + if (isa<PHINode>(GEPI.getPointerOperand()) && + foldGEPPhi(GEPI)) + return true; + enqueueUsers(GEPI); return false; } @@ -3465,8 +3573,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) { if (Ty->isSingleValueType()) return Ty; - uint64_t AllocSize = DL.getTypeAllocSize(Ty); - uint64_t TypeSize = DL.getTypeSizeInBits(Ty); + uint64_t AllocSize = DL.getTypeAllocSize(Ty).getFixedSize(); + uint64_t TypeSize = DL.getTypeSizeInBits(Ty).getFixedSize(); Type *InnerTy; if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) { @@ -3479,8 +3587,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) { return Ty; } - if (AllocSize > DL.getTypeAllocSize(InnerTy) || - TypeSize > DL.getTypeSizeInBits(InnerTy)) + if (AllocSize > DL.getTypeAllocSize(InnerTy).getFixedSize() || + TypeSize > DL.getTypeSizeInBits(InnerTy).getFixedSize()) return Ty; return stripAggregateTypeWrapping(DL, InnerTy); @@ -3501,17 +3609,28 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) { /// return a type if necessary. 
static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, uint64_t Size) { - if (Offset == 0 && DL.getTypeAllocSize(Ty) == Size) + if (Offset == 0 && DL.getTypeAllocSize(Ty).getFixedSize() == Size) return stripAggregateTypeWrapping(DL, Ty); - if (Offset > DL.getTypeAllocSize(Ty) || - (DL.getTypeAllocSize(Ty) - Offset) < Size) + if (Offset > DL.getTypeAllocSize(Ty).getFixedSize() || + (DL.getTypeAllocSize(Ty).getFixedSize() - Offset) < Size) return nullptr; - if (SequentialType *SeqTy = dyn_cast<SequentialType>(Ty)) { - Type *ElementTy = SeqTy->getElementType(); - uint64_t ElementSize = DL.getTypeAllocSize(ElementTy); + if (isa<ArrayType>(Ty) || isa<VectorType>(Ty)) { + Type *ElementTy; + uint64_t TyNumElements; + if (auto *AT = dyn_cast<ArrayType>(Ty)) { + ElementTy = AT->getElementType(); + TyNumElements = AT->getNumElements(); + } else { + // FIXME: This isn't right for vectors with non-byte-sized or + // non-power-of-two sized elements. + auto *VT = cast<FixedVectorType>(Ty); + ElementTy = VT->getElementType(); + TyNumElements = VT->getNumElements(); + } + uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize(); uint64_t NumSkippedElements = Offset / ElementSize; - if (NumSkippedElements >= SeqTy->getNumElements()) + if (NumSkippedElements >= TyNumElements) return nullptr; Offset -= NumSkippedElements * ElementSize; @@ -3549,7 +3668,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, Offset -= SL->getElementOffset(Index); Type *ElementTy = STy->getElementType(Index); - uint64_t ElementSize = DL.getTypeAllocSize(ElementTy); + uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize(); if (Offset >= ElementSize) return nullptr; // The offset points into alignment padding. @@ -3860,7 +3979,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { getAdjustedPtr(IRB, DL, BasePtr, APInt(DL.getIndexSizeInBits(AS), PartOffset), PartPtrTy, BasePtr->getName() + "."), - getAdjustedAlignment(LI, PartOffset, DL).value(), + getAdjustedAlignment(LI, PartOffset), /*IsVolatile*/ false, LI->getName()); PLoad->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); @@ -3918,7 +4037,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { getAdjustedPtr(IRB, DL, StoreBasePtr, APInt(DL.getIndexSizeInBits(AS), PartOffset), PartPtrTy, StoreBasePtr->getName() + "."), - getAdjustedAlignment(SI, PartOffset, DL).value(), + getAdjustedAlignment(SI, PartOffset), /*IsVolatile*/ false); PStore->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); @@ -4003,7 +4122,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { getAdjustedPtr(IRB, DL, LoadBasePtr, APInt(DL.getIndexSizeInBits(AS), PartOffset), LoadPartPtrTy, LoadBasePtr->getName() + "."), - getAdjustedAlignment(LI, PartOffset, DL).value(), + getAdjustedAlignment(LI, PartOffset), /*IsVolatile*/ false, LI->getName()); } @@ -4015,7 +4134,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { getAdjustedPtr(IRB, DL, StoreBasePtr, APInt(DL.getIndexSizeInBits(AS), PartOffset), StorePartPtrTy, StoreBasePtr->getName() + "."), - getAdjustedAlignment(SI, PartOffset, DL).value(), + getAdjustedAlignment(SI, PartOffset), /*IsVolatile*/ false); // Now build a new slice for the alloca. 
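To illustrate the gep(select) and gep(phi) folds added to visitGetElementPtrInst above: when a GEP with all-constant indices addresses through a select (or a suitable phi) of pointers, the GEP is pushed into each incoming pointer, so SROA's use-walker sees plain GEPs on the underlying allocas. A hand-written LLVM IR sketch of the select case (names follow the ".sroa.gep"/".sroa.sel" pattern in the code but are otherwise illustrative, not actual pass output):

  ; before
  %sel = select i1 %cond, [4 x i32]* %a, [4 x i32]* %b
  %gep = getelementptr inbounds [4 x i32], [4 x i32]* %sel, i64 0, i64 1

  ; after (roughly): gep(select) -> select(gep), as in foldGEPSelect
  %a.sroa.gep = getelementptr inbounds [4 x i32], [4 x i32]* %a, i64 0, i64 1
  %b.sroa.gep = getelementptr inbounds [4 x i32], [4 x i32]* %b, i64 0, i64 1
  %sel.sroa.sel = select i1 %cond, i32* %a.sroa.gep, i32* %b.sroa.gep

The phi case is analogous: a GEP is emitted after each incoming instruction and a new phi collects the results.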
@@ -4117,7 +4236,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, Type *SliceTy = nullptr; const DataLayout &DL = AI.getModule()->getDataLayout(); if (Type *CommonUseTy = findCommonType(P.begin(), P.end(), P.endOffset())) - if (DL.getTypeAllocSize(CommonUseTy) >= P.size()) + if (DL.getTypeAllocSize(CommonUseTy).getFixedSize() >= P.size()) SliceTy = CommonUseTy; if (!SliceTy) if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(), @@ -4129,7 +4248,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, SliceTy = Type::getIntNTy(*C, P.size() * 8); if (!SliceTy) SliceTy = ArrayType::get(Type::getInt8Ty(*C), P.size()); - assert(DL.getTypeAllocSize(SliceTy) >= P.size()); + assert(DL.getTypeAllocSize(SliceTy).getFixedSize() >= P.size()); bool IsIntegerPromotable = isIntegerWideningViable(P, SliceTy, DL); @@ -4151,19 +4270,14 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // FIXME: We might want to defer PHI speculation until after here. // FIXME: return nullptr; } else { - // If alignment is unspecified we fallback on the one required by the ABI - // for this type. We also make sure the alignment is compatible with - // P.beginOffset(). - const Align Alignment = commonAlignment( - DL.getValueOrABITypeAlignment(MaybeAlign(AI.getAlignment()), - AI.getAllocatedType()), - P.beginOffset()); + // Make sure the alignment is compatible with P.beginOffset(). + const Align Alignment = commonAlignment(AI.getAlign(), P.beginOffset()); // If we will get at least this much alignment from the type alone, leave // the alloca's alignment unconstrained. - const bool IsUnconstrained = Alignment <= DL.getABITypeAlignment(SliceTy); + const bool IsUnconstrained = Alignment <= DL.getABITypeAlign(SliceTy); NewAI = new AllocaInst( SliceTy, AI.getType()->getAddressSpace(), nullptr, - IsUnconstrained ? MaybeAlign() : Alignment, + IsUnconstrained ? DL.getPrefTypeAlign(SliceTy) : Alignment, AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), &AI); // Copy the old AI debug location over to the new one. NewAI->setDebugLoc(AI.getDebugLoc()); @@ -4270,7 +4384,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { // to be rewritten into a partition. bool IsSorted = true; - uint64_t AllocaSize = DL.getTypeAllocSize(AI.getAllocatedType()); + uint64_t AllocaSize = + DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize(); const uint64_t MaxBitVectorSize = 1024; if (AllocaSize <= MaxBitVectorSize) { // If a byte boundary is included in any load or store, a slice starting or @@ -4334,7 +4449,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { Changed = true; if (NewAI != &AI) { uint64_t SizeOfByte = 8; - uint64_t AllocaSize = DL.getTypeSizeInBits(NewAI->getAllocatedType()); + uint64_t AllocaSize = + DL.getTypeSizeInBits(NewAI->getAllocatedType()).getFixedSize(); // Don't include any padding. 
uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte); Fragments.push_back(Fragment(NewAI, P.beginOffset() * SizeOfByte, Size)); @@ -4354,7 +4470,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { auto *Expr = DbgDeclares.front()->getExpression(); auto VarSize = Var->getSizeInBits(); DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); - uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType()); + uint64_t AllocaSize = + DL.getTypeSizeInBits(AI.getAllocatedType()).getFixedSize(); for (auto Fragment : Fragments) { // Create a fragment expression describing the new partition or reuse AI's // expression if there is only one partition. @@ -4442,8 +4559,9 @@ bool SROA::runOnAlloca(AllocaInst &AI) { const DataLayout &DL = AI.getModule()->getDataLayout(); // Skip alloca forms that this analysis can't handle. - if (AI.isArrayAllocation() || !AI.getAllocatedType()->isSized() || - DL.getTypeAllocSize(AI.getAllocatedType()) == 0) + auto *AT = AI.getAllocatedType(); + if (AI.isArrayAllocation() || !AT->isSized() || isa<ScalableVectorType>(AT) || + DL.getTypeAllocSize(AT).getFixedSize() == 0) return false; bool Changed = false; @@ -4563,8 +4681,14 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, BasicBlock &EntryBB = F.getEntryBlock(); for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end()); I != E; ++I) { - if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) - Worklist.insert(AI); + if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { + if (isa<ScalableVectorType>(AI->getAllocatedType())) { + if (isAllocaPromotable(AI)) + PromotableAllocas.push_back(AI); + } else { + Worklist.insert(AI); + } + } } bool Changed = false; diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index c25c6c632b8f..851bd79cd6d8 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -22,8 +22,8 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/Dominators.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" @@ -41,6 +41,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> #include <iterator> @@ -51,6 +52,11 @@ using namespace llvm; #define DEBUG_TYPE "scalarizer" +static cl::opt<bool> ScalarizeVariableInsertExtract( + "scalarize-variable-insert-extract", cl::init(true), cl::Hidden, + cl::desc("Allow the scalarizer pass to scalarize " + "insertelement/extractelement with variable index")); + // This is disabled by default because having separate loads and stores // makes it more likely that the -combiner-alias-analysis limits will be // reached. @@ -156,8 +162,8 @@ struct VectorLayout { VectorLayout() = default; // Return the alignment of element I. - uint64_t getElemAlign(unsigned I) { - return MinAlign(VecAlign, I * ElemSize); + Align getElemAlign(unsigned I) { + return commonAlignment(VecAlign, I * ElemSize); } // The type of the vector. @@ -167,7 +173,7 @@ struct VectorLayout { Type *ElemTy = nullptr; // The alignment of the vector. - uint64_t VecAlign = 0; + Align VecAlign; // The size of each element. 
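To illustrate getElemAlign above, which now returns commonAlignment(VecAlign, I * ElemSize) as an Align: scalarizing a <4 x i32> load that is only 8-byte aligned gives per-element alignments of 8, 4, 8, 4. A hand-written sketch of the scalarized form (pointer names and the exact GEP shape are illustrative assumptions, not actual pass output):

  ; before
  %v = load <4 x i32>, <4 x i32>* %p, align 8

  ; after (roughly): element I gets commonAlignment(8, 4 * I)
  %p.i0 = bitcast <4 x i32>* %p to i32*
  %v.i0 = load i32, i32* %p.i0, align 8
  %p.i1 = getelementptr i32, i32* %p.i0, i32 1
  %v.i1 = load i32, i32* %p.i1, align 4
  %p.i2 = getelementptr i32, i32* %p.i0, i32 2
  %v.i2 = load i32, i32* %p.i2, align 8
  %p.i3 = getelementptr i32, i32* %p.i0, i32 3
  %v.i3 = load i32, i32* %p.i3, align 4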
uint64_t ElemSize = 0; @@ -192,6 +198,8 @@ public: bool visitGetElementPtrInst(GetElementPtrInst &GEPI); bool visitCastInst(CastInst &CI); bool visitBitCastInst(BitCastInst &BCI); + bool visitInsertElementInst(InsertElementInst &IEI); + bool visitExtractElementInst(ExtractElementInst &EEI); bool visitShuffleVectorInst(ShuffleVectorInst &SVI); bool visitPHINode(PHINode &PHI); bool visitLoadInst(LoadInst &LI); @@ -203,8 +211,8 @@ private: void gather(Instruction *Op, const ValueVector &CV); bool canTransferMetadata(unsigned Kind); void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV); - bool getVectorLayout(Type *Ty, unsigned Alignment, VectorLayout &Layout, - const DataLayout &DL); + Optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment, + const DataLayout &DL); bool finish(); template<typename T> bool splitUnary(Instruction &, const T &); @@ -215,6 +223,8 @@ private: ScatterMap Scattered; GatherList Gathered; + SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs; + unsigned ParallelLoopAccessMDKind; DominatorTree *DT; @@ -252,7 +262,7 @@ Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, PtrTy = dyn_cast<PointerType>(Ty); if (PtrTy) Ty = PtrTy->getElementType(); - Size = Ty->getVectorNumElements(); + Size = cast<FixedVectorType>(Ty)->getNumElements(); if (!CachePtr) Tmp.resize(Size, nullptr); else if (CachePtr->empty()) @@ -269,7 +279,7 @@ Value *Scatterer::operator[](unsigned I) { return CV[I]; IRBuilder<> Builder(BB, BBI); if (PtrTy) { - Type *ElTy = PtrTy->getElementType()->getVectorElementType(); + Type *ElTy = cast<VectorType>(PtrTy->getElementType())->getElementType(); if (!CV[0]) { Type *NewPtrTy = PointerType::get(ElTy, PtrTy->getAddressSpace()); CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0"); @@ -376,11 +386,6 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) { // so that we can avoid creating the gathered form if all uses of Op are // replaced with uses of CV. void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) { - // Since we're not deleting Op yet, stub out its operands, so that it - // doesn't make anything live unnecessarily. - for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) - Op->setOperand(I, UndefValue::get(Op->getOperand(I)->getType())); - transferMetadataAndIRFlags(Op, CV); // If we already have a scattered form of Op (created from ExtractElements @@ -389,13 +394,13 @@ void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) { if (!SV.empty()) { for (unsigned I = 0, E = SV.size(); I != E; ++I) { Value *V = SV[I]; - if (V == nullptr) + if (V == nullptr || SV[I] == CV[I]) continue; Instruction *Old = cast<Instruction>(V); CV[I]->takeName(Old); Old->replaceAllUsesWith(CV[I]); - Old->eraseFromParent(); + PotentiallyDeadInstrs.emplace_back(Old); } } SV = CV; @@ -434,25 +439,22 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op, } // Try to fill in Layout from Ty, returning true on success. Alignment is -// the alignment of the vector, or 0 if the ABI default should be used. -bool ScalarizerVisitor::getVectorLayout(Type *Ty, unsigned Alignment, - VectorLayout &Layout, const DataLayout &DL) { +// the alignment of the vector, or None if the ABI default should be used. +Optional<VectorLayout> +ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment, + const DataLayout &DL) { + VectorLayout Layout; // Make sure we're dealing with a vector. 
Layout.VecTy = dyn_cast<VectorType>(Ty); if (!Layout.VecTy) - return false; - + return None; // Check that we're dealing with full-byte elements. Layout.ElemTy = Layout.VecTy->getElementType(); if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy)) - return false; - - if (Alignment) - Layout.VecAlign = Alignment; - else - Layout.VecAlign = DL.getABITypeAlignment(Layout.VecTy); + return None; + Layout.VecAlign = Alignment; Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy); - return true; + return Layout; } // Scalarize one-operand instruction I, using Split(Builder, X, Name) @@ -463,7 +465,7 @@ bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) { if (!VT) return false; - unsigned NumElems = VT->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); IRBuilder<> Builder(&I); Scatterer Op = scatter(&I, I.getOperand(0)); assert(Op.size() == NumElems && "Mismatched unary operation"); @@ -483,17 +485,19 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) { if (!VT) return false; - unsigned NumElems = VT->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); IRBuilder<> Builder(&I); - Scatterer Op0 = scatter(&I, I.getOperand(0)); - Scatterer Op1 = scatter(&I, I.getOperand(1)); - assert(Op0.size() == NumElems && "Mismatched binary operation"); - assert(Op1.size() == NumElems && "Mismatched binary operation"); + Scatterer VOp0 = scatter(&I, I.getOperand(0)); + Scatterer VOp1 = scatter(&I, I.getOperand(1)); + assert(VOp0.size() == NumElems && "Mismatched binary operation"); + assert(VOp1.size() == NumElems && "Mismatched binary operation"); ValueVector Res; Res.resize(NumElems); - for (unsigned Elem = 0; Elem < NumElems; ++Elem) - Res[Elem] = Split(Builder, Op0[Elem], Op1[Elem], - I.getName() + ".i" + Twine(Elem)); + for (unsigned Elem = 0; Elem < NumElems; ++Elem) { + Value *Op0 = VOp0[Elem]; + Value *Op1 = VOp1[Elem]; + Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem)); + } gather(&I, Res); return true; } @@ -524,7 +528,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID)) return false; - unsigned NumElems = VT->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); unsigned NumArgs = CI.getNumArgOperands(); ValueVector ScalarOperands(NumArgs); @@ -574,26 +578,33 @@ bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) { if (!VT) return false; - unsigned NumElems = VT->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); IRBuilder<> Builder(&SI); - Scatterer Op1 = scatter(&SI, SI.getOperand(1)); - Scatterer Op2 = scatter(&SI, SI.getOperand(2)); - assert(Op1.size() == NumElems && "Mismatched select"); - assert(Op2.size() == NumElems && "Mismatched select"); + Scatterer VOp1 = scatter(&SI, SI.getOperand(1)); + Scatterer VOp2 = scatter(&SI, SI.getOperand(2)); + assert(VOp1.size() == NumElems && "Mismatched select"); + assert(VOp2.size() == NumElems && "Mismatched select"); ValueVector Res; Res.resize(NumElems); if (SI.getOperand(0)->getType()->isVectorTy()) { - Scatterer Op0 = scatter(&SI, SI.getOperand(0)); - assert(Op0.size() == NumElems && "Mismatched select"); - for (unsigned I = 0; I < NumElems; ++I) - Res[I] = Builder.CreateSelect(Op0[I], Op1[I], Op2[I], + Scatterer VOp0 = scatter(&SI, SI.getOperand(0)); + assert(VOp0.size() == NumElems && "Mismatched select"); + for (unsigned I = 0; I < NumElems; ++I) { + Value *Op0 = VOp0[I]; + Value *Op1 
= VOp1[I]; + Value *Op2 = VOp2[I]; + Res[I] = Builder.CreateSelect(Op0, Op1, Op2, SI.getName() + ".i" + Twine(I)); + } } else { Value *Op0 = SI.getOperand(0); - for (unsigned I = 0; I < NumElems; ++I) - Res[I] = Builder.CreateSelect(Op0, Op1[I], Op2[I], + for (unsigned I = 0; I < NumElems; ++I) { + Value *Op1 = VOp1[I]; + Value *Op2 = VOp2[I]; + Res[I] = Builder.CreateSelect(Op0, Op1, Op2, SI.getName() + ".i" + Twine(I)); + } } gather(&SI, Res); return true; @@ -621,7 +632,7 @@ bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { return false; IRBuilder<> Builder(&GEPI); - unsigned NumElems = VT->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); unsigned NumIndices = GEPI.getNumIndices(); // The base pointer might be scalar even if it's a vector GEP. In those cases, @@ -666,7 +677,7 @@ bool ScalarizerVisitor::visitCastInst(CastInst &CI) { if (!VT) return false; - unsigned NumElems = VT->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); IRBuilder<> Builder(&CI); Scatterer Op0 = scatter(&CI, CI.getOperand(0)); assert(Op0.size() == NumElems && "Mismatched cast"); @@ -685,8 +696,8 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) { if (!DstVT || !SrcVT) return false; - unsigned DstNumElems = DstVT->getNumElements(); - unsigned SrcNumElems = SrcVT->getNumElements(); + unsigned DstNumElems = cast<FixedVectorType>(DstVT)->getNumElements(); + unsigned SrcNumElems = cast<FixedVectorType>(SrcVT)->getNumElements(); IRBuilder<> Builder(&BCI); Scatterer Op0 = scatter(&BCI, BCI.getOperand(0)); ValueVector Res; @@ -700,7 +711,7 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) { // <M x t1> -> <N*M x t2>. Convert each t1 to <N x t2> and copy the // individual elements to the destination. unsigned FanOut = DstNumElems / SrcNumElems; - Type *MidTy = VectorType::get(DstVT->getElementType(), FanOut); + auto *MidTy = FixedVectorType::get(DstVT->getElementType(), FanOut); unsigned ResI = 0; for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) { Value *V = Op0[Op0I]; @@ -718,7 +729,7 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) { } else { // <N*M x t1> -> <M x t2>. Convert each group of <N x t1> into a t2. unsigned FanIn = SrcNumElems / DstNumElems; - Type *MidTy = VectorType::get(SrcVT->getElementType(), FanIn); + auto *MidTy = FixedVectorType::get(SrcVT->getElementType(), FanIn); unsigned Op0I = 0; for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) { Value *V = UndefValue::get(MidTy); @@ -734,12 +745,79 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) { return true; } +bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) { + VectorType *VT = dyn_cast<VectorType>(IEI.getType()); + if (!VT) + return false; + + unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); + IRBuilder<> Builder(&IEI); + Scatterer Op0 = scatter(&IEI, IEI.getOperand(0)); + Value *NewElt = IEI.getOperand(1); + Value *InsIdx = IEI.getOperand(2); + + ValueVector Res; + Res.resize(NumElems); + + if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) { + for (unsigned I = 0; I < NumElems; ++I) + Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I]; + } else { + if (!ScalarizeVariableInsertExtract) + return false; + + for (unsigned I = 0; I < NumElems; ++I) { + Value *ShouldReplace = + Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I), + InsIdx->getName() + ".is." 
+ Twine(I)); + Value *OldElt = Op0[I]; + Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt, + IEI.getName() + ".i" + Twine(I)); + } + } + + gather(&IEI, Res); + return true; +} + +bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { + VectorType *VT = dyn_cast<VectorType>(EEI.getOperand(0)->getType()); + if (!VT) + return false; + + unsigned NumSrcElems = cast<FixedVectorType>(VT)->getNumElements(); + IRBuilder<> Builder(&EEI); + Scatterer Op0 = scatter(&EEI, EEI.getOperand(0)); + Value *ExtIdx = EEI.getOperand(1); + + if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) { + Value *Res = Op0[CI->getValue().getZExtValue()]; + gather(&EEI, {Res}); + return true; + } + + if (!ScalarizeVariableInsertExtract) + return false; + + Value *Res = UndefValue::get(VT->getElementType()); + for (unsigned I = 0; I < NumSrcElems; ++I) { + Value *ShouldExtract = + Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I), + ExtIdx->getName() + ".is." + Twine(I)); + Value *Elt = Op0[I]; + Res = Builder.CreateSelect(ShouldExtract, Elt, Res, + EEI.getName() + ".upto" + Twine(I)); + } + gather(&EEI, {Res}); + return true; +} + bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) { VectorType *VT = dyn_cast<VectorType>(SVI.getType()); if (!VT) return false; - unsigned NumElems = VT->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); Scatterer Op0 = scatter(&SVI, SVI.getOperand(0)); Scatterer Op1 = scatter(&SVI, SVI.getOperand(1)); ValueVector Res; @@ -763,7 +841,7 @@ bool ScalarizerVisitor::visitPHINode(PHINode &PHI) { if (!VT) return false; - unsigned NumElems = VT->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); IRBuilder<> Builder(&PHI); ValueVector Res; Res.resize(NumElems); @@ -789,20 +867,20 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) { if (!LI.isSimple()) return false; - VectorLayout Layout; - if (!getVectorLayout(LI.getType(), LI.getAlignment(), Layout, - LI.getModule()->getDataLayout())) + Optional<VectorLayout> Layout = getVectorLayout( + LI.getType(), LI.getAlign(), LI.getModule()->getDataLayout()); + if (!Layout) return false; - unsigned NumElems = Layout.VecTy->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements(); IRBuilder<> Builder(&LI); Scatterer Ptr = scatter(&LI, LI.getPointerOperand()); ValueVector Res; Res.resize(NumElems); for (unsigned I = 0; I < NumElems; ++I) - Res[I] = Builder.CreateAlignedLoad(Layout.VecTy->getElementType(), Ptr[I], - Layout.getElemAlign(I), + Res[I] = Builder.CreateAlignedLoad(Layout->VecTy->getElementType(), Ptr[I], + Align(Layout->getElemAlign(I)), LI.getName() + ".i" + Twine(I)); gather(&LI, Res); return true; @@ -814,22 +892,23 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) { if (!SI.isSimple()) return false; - VectorLayout Layout; Value *FullValue = SI.getValueOperand(); - if (!getVectorLayout(FullValue->getType(), SI.getAlignment(), Layout, - SI.getModule()->getDataLayout())) + Optional<VectorLayout> Layout = getVectorLayout( + FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout()); + if (!Layout) return false; - unsigned NumElems = Layout.VecTy->getNumElements(); + unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements(); IRBuilder<> Builder(&SI); - Scatterer Ptr = scatter(&SI, SI.getPointerOperand()); - Scatterer Val = scatter(&SI, FullValue); + Scatterer VPtr = scatter(&SI, SI.getPointerOperand()); + Scatterer VVal = scatter(&SI, FullValue); 
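To illustrate the new visitExtractElementInst/visitInsertElementInst handling above (variable indices are only scalarized when scalarize-variable-insert-extract, on by default, allows it): a variable-index extract becomes a chain of icmp/select over the scattered elements. A hand-written sketch, assuming %v has already been scattered into %v.i0 .. %v.i3 (names mirror the ".is."/".upto" suffixes in the code but the snippet is not actual pass output):

  ; before
  %e = extractelement <4 x i32> %v, i32 %idx

  ; after (roughly): compare the index against each lane and chain selects
  %idx.is.0 = icmp eq i32 %idx, 0
  %e.upto0 = select i1 %idx.is.0, i32 %v.i0, i32 undef
  %idx.is.1 = icmp eq i32 %idx, 1
  %e.upto1 = select i1 %idx.is.1, i32 %v.i1, i32 %e.upto0
  %idx.is.2 = icmp eq i32 %idx, 2
  %e.upto2 = select i1 %idx.is.2, i32 %v.i2, i32 %e.upto1
  %idx.is.3 = icmp eq i32 %idx, 3
  %e.upto3 = select i1 %idx.is.3, i32 %v.i3, i32 %e.upto2

A variable-index insertelement is handled the same way, selecting per lane between the new element and the old one.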
ValueVector Stores; Stores.resize(NumElems); for (unsigned I = 0; I < NumElems; ++I) { - unsigned Align = Layout.getElemAlign(I); - Stores[I] = Builder.CreateAlignedStore(Val[I], Ptr[I], Align); + Value *Val = VVal[I]; + Value *Ptr = VPtr[I]; + Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I)); } transferMetadataAndIRFlags(&SI, Stores); return true; @@ -852,23 +931,32 @@ bool ScalarizerVisitor::finish() { if (!Op->use_empty()) { // The value is still needed, so recreate it using a series of // InsertElements. - Type *Ty = Op->getType(); - Value *Res = UndefValue::get(Ty); - BasicBlock *BB = Op->getParent(); - unsigned Count = Ty->getVectorNumElements(); - IRBuilder<> Builder(Op); - if (isa<PHINode>(Op)) - Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); - for (unsigned I = 0; I < Count; ++I) - Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I), - Op->getName() + ".upto" + Twine(I)); + Value *Res = UndefValue::get(Op->getType()); + if (auto *Ty = dyn_cast<VectorType>(Op->getType())) { + BasicBlock *BB = Op->getParent(); + unsigned Count = cast<FixedVectorType>(Ty)->getNumElements(); + IRBuilder<> Builder(Op); + if (isa<PHINode>(Op)) + Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); + for (unsigned I = 0; I < Count; ++I) + Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I), + Op->getName() + ".upto" + Twine(I)); + } else { + assert(CV.size() == 1 && Op->getType() == CV[0]->getType()); + Res = CV[0]; + if (Op == Res) + continue; + } Res->takeName(Op); Op->replaceAllUsesWith(Res); } - Op->eraseFromParent(); + PotentiallyDeadInstrs.emplace_back(Op); } Gathered.clear(); Scattered.clear(); + + RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); + return true; } diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 2a1a040bf83e..f1d2e3c1ecfa 100644 --- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -431,8 +431,10 @@ private: bool reuniteExts(Instruction *I); /// Find the closest dominator of <Dominatee> that is equivalent to <Key>. - Instruction *findClosestMatchingDominator(const SCEV *Key, - Instruction *Dominatee); + Instruction *findClosestMatchingDominator( + const SCEV *Key, Instruction *Dominatee, + DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs); + /// Verify F is free of dead code. void verifyNoDeadCode(Function &F); @@ -456,7 +458,8 @@ private: /// multiple GEPs with a single index. bool LowerGEP; - DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingExprs; + DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingAdds; + DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingSubs; }; } // end anonymous namespace @@ -519,7 +522,7 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended, // sext(a + b) = sext(a) + sext(b) // even if the addition is not marked nsw. // - // Leveraging this invarient, we can trace into an sext'ed inbound GEP + // Leveraging this invariant, we can trace into an sext'ed inbound GEP // index if the constant offset is non-negative. // // Verified in @sext_add in split-gep.ll. @@ -549,6 +552,9 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended, APInt ConstantOffsetExtractor::findInEitherOperand(BinaryOperator *BO, bool SignExtended, bool ZeroExtended) { + // Save off the current height of the chain, in case we need to restore it. 
+ size_t ChainLength = UserChain.size(); + // BO being non-negative does not shed light on whether its operands are // non-negative. Clear the NonNegative flag here. APInt ConstantOffset = find(BO->getOperand(0), SignExtended, ZeroExtended, @@ -559,12 +565,22 @@ APInt ConstantOffsetExtractor::findInEitherOperand(BinaryOperator *BO, // However, such cases are probably already handled by -instcombine, // given this pass runs after the standard optimizations. if (ConstantOffset != 0) return ConstantOffset; + + // Reset the chain back to where it was when we started exploring this node, + // since visiting the LHS didn't pan out. + UserChain.resize(ChainLength); + ConstantOffset = find(BO->getOperand(1), SignExtended, ZeroExtended, /* NonNegative */ false); // If U is a sub operator, negate the constant offset found in the right // operand. if (BO->getOpcode() == Instruction::Sub) ConstantOffset = -ConstantOffset; + + // If RHS wasn't a suitable candidate either, reset the chain again. + if (ConstantOffset == 0) + UserChain.resize(ChainLength); + return ConstantOffset; } @@ -688,7 +704,7 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) { } BinaryOperator *BO = cast<BinaryOperator>(UserChain[ChainIndex]); - assert(BO->getNumUses() <= 1 && + assert((BO->use_empty() || BO->hasOneUse()) && "distributeExtsAndCloneChain clones each BinaryOperator in " "UserChain, so no one should be used more than " "once"); @@ -1141,7 +1157,8 @@ bool SeparateConstOffsetFromGEP::runOnFunction(Function &F) { } Instruction *SeparateConstOffsetFromGEP::findClosestMatchingDominator( - const SCEV *Key, Instruction *Dominatee) { + const SCEV *Key, Instruction *Dominatee, + DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs) { auto Pos = DominatingExprs.find(Key); if (Pos == DominatingExprs.end()) return nullptr; @@ -1169,12 +1186,23 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { // If Dom can't sign overflow and Dom dominates I, optimize I to sext(Dom). // TODO: handle zext Value *LHS = nullptr, *RHS = nullptr; - if (match(I, m_Add(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS)))) || - match(I, m_Sub(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) { + if (match(I, m_Add(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) { if (LHS->getType() == RHS->getType()) { const SCEV *Key = SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); - if (auto *Dom = findClosestMatchingDominator(Key, I)) { + if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingAdds)) { + Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I); + NewSExt->takeName(I); + I->replaceAllUsesWith(NewSExt); + RecursivelyDeleteTriviallyDeadInstructions(I); + return true; + } + } + } else if (match(I, m_Sub(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) { + if (LHS->getType() == RHS->getType()) { + const SCEV *Key = + SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); + if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingSubs)) { Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I); NewSExt->takeName(I); I->replaceAllUsesWith(NewSExt); @@ -1185,12 +1213,17 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { } // Add I to DominatingExprs if it's an add/sub that can't sign overflow. 
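To illustrate the sext reassociation performed by reuniteExts, which the changes here now track in separate DominatingAdds and DominatingSubs maps (so an add of sexts can only reuse a dominating nsw add, and a sub of sexts only a dominating nsw sub), a hand-written sketch with illustrative names:

  ; dominating narrow add; nsw can be trusted because poison here would make
  ; the program undefined (e.g. the result feeds a dereferenced inbounds GEP)
  %ab = add nsw i32 %a, %b

  ; before
  %a.ext = sext i32 %a to i64
  %b.ext = sext i32 %b to i64
  %sum = add i64 %a.ext, %b.ext

  ; after (roughly): the wide add is rewritten to reuse the dominating %ab
  %sum.2 = sext i32 %ab to i64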
- if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS))) || - match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) { - if (programUndefinedIfFullPoison(I)) { + if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS)))) { + if (programUndefinedIfPoison(I)) { + const SCEV *Key = + SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); + DominatingAdds[Key].push_back(I); + } + } else if (match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) { + if (programUndefinedIfPoison(I)) { const SCEV *Key = SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); - DominatingExprs[Key].push_back(I); + DominatingSubs[Key].push_back(I); } } return false; @@ -1198,7 +1231,8 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { bool SeparateConstOffsetFromGEP::reuniteExts(Function &F) { bool Changed = false; - DominatingExprs.clear(); + DominatingAdds.clear(); + DominatingSubs.clear(); for (const auto Node : depth_first(DT)) { BasicBlock *BB = Node->getBlock(); for (auto I = BB->begin(); I != BB->end(); ) { diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index d7a34acb4318..6c6d6ca9cf65 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -26,7 +26,6 @@ #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" -#include "llvm/Analysis/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -36,6 +35,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" @@ -182,7 +182,7 @@ static void buildPartialUnswitchConditionalBranch(BasicBlock &BB, BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc) { IRBuilder<> IRB(&BB); - + Value *Cond = Direction ? IRB.CreateOr(Invariants) : IRB.CreateAnd(Invariants); IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, @@ -598,19 +598,36 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, auto *ParentBB = SI.getParent(); + // The same check must be used both for the default and the exit cases. We + // should never leave edges from the switch instruction to a basic block that + // we are unswitching, hence the condition used to determine the default case + // needs to also be used to populate ExitCaseIndices, which is then used to + // remove cases from the switch. + auto IsTriviallyUnswitchableExitBlock = [&](BasicBlock &BBToCheck) { + // BBToCheck is not an exit block if it is inside loop L. + if (L.contains(&BBToCheck)) + return false; + // BBToCheck is not trivial to unswitch if its phis aren't loop invariant. + if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, BBToCheck)) + return false; + // We do not unswitch a block that only has an unreachable statement, as + // it's possible this is a previously unswitched block. Only unswitch if + // either the terminator is not unreachable, or, if it is, it's not the only + // instruction in the block. 
+ auto *TI = BBToCheck.getTerminator(); + bool isUnreachable = isa<UnreachableInst>(TI); + return !isUnreachable || + (isUnreachable && (BBToCheck.getFirstNonPHIOrDbg() != TI)); + }; + SmallVector<int, 4> ExitCaseIndices; - for (auto Case : SI.cases()) { - auto *SuccBB = Case.getCaseSuccessor(); - if (!L.contains(SuccBB) && - areLoopExitPHIsLoopInvariant(L, *ParentBB, *SuccBB)) + for (auto Case : SI.cases()) + if (IsTriviallyUnswitchableExitBlock(*Case.getCaseSuccessor())) ExitCaseIndices.push_back(Case.getCaseIndex()); - } BasicBlock *DefaultExitBB = nullptr; SwitchInstProfUpdateWrapper::CaseWeightOpt DefaultCaseWeight = SwitchInstProfUpdateWrapper::getSuccessorWeight(SI, 0); - if (!L.contains(SI.getDefaultDest()) && - areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) && - !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) { + if (IsTriviallyUnswitchableExitBlock(*SI.getDefaultDest())) { DefaultExitBB = SI.getDefaultDest(); } else if (ExitCaseIndices.empty()) return false; @@ -1557,6 +1574,11 @@ static void deleteDeadBlocksFromLoop(Loop &L, // Check that the dominator tree has already been updated. assert(!DT.getNode(BB) && "Should already have cleared domtree!"); LI.changeLoopFor(BB, nullptr); + // Drop all uses of the instructions to make sure we won't have dangling + // uses in other blocks. + for (auto &I : *BB) + if (!I.use_empty()) + I.replaceAllUsesWith(UndefValue::get(I.getType())); BB->dropAllReferences(); } @@ -2465,7 +2487,7 @@ turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, /// unswitch candidates, making adequate predictions instead of wild guesses. /// That requires knowing not just the number of "remaining" candidates but /// also costs of unswitching for each of these candidates. -static int calculateUnswitchCostMultiplier( +static int CalculateUnswitchCostMultiplier( Instruction &TI, Loop &L, LoopInfo &LI, DominatorTree &DT, ArrayRef<std::pair<Instruction *, TinyPtrVector<Value *>>> UnswitchCandidates) { @@ -2656,11 +2678,11 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB)) return false; - if (auto CS = CallSite(&I)) - if (CS.isConvergent() || CS.cannotDuplicate()) + if (auto *CB = dyn_cast<CallBase>(&I)) + if (CB->isConvergent() || CB->cannotDuplicate()) return false; - Cost += TTI.getUserCost(&I); + Cost += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize); } assert(Cost >= 0 && "Must not have negative costs!"); LoopCost += Cost; @@ -2754,7 +2776,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, // exponential behavior of loop-unswitch. if (EnableUnswitchCostMultiplier) { int CostMultiplier = - calculateUnswitchCostMultiplier(TI, L, LI, DT, UnswitchCandidates); + CalculateUnswitchCostMultiplier(TI, L, LI, DT, UnswitchCandidates); assert( (CostMultiplier > 0 && CostMultiplier <= UnswitchThreshold) && "cost multiplier needs to be in the range of 1..UnswitchThreshold"); @@ -2868,7 +2890,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, // Save the current loop name in a variable so that we can report it even // after it has been deleted. 
- std::string LoopName = L.getName(); + std::string LoopName = std::string(L.getName()); auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid, ArrayRef<Loop *> NewLoops) { @@ -2983,10 +3005,6 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { if (MSSA && VerifyMemorySSA) MSSA->verifyMemorySSA(); - // If anything was unswitched, also clear any cached information about this - // loop. - LPM.deleteSimpleAnalysisLoop(L); - // Historically this pass has had issues with the dominator tree so verify it // in asserts builds. assert(DT.verify(DominatorTree::VerificationLevel::Fast)); diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 623a8b711ed8..2e459c9a64d4 100644 --- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -104,6 +104,21 @@ static bool mergeEmptyReturnBlocks(Function &F) { continue; } + // Skip merging if this would result in a CallBr instruction with a + // duplicate destination. FIXME: See note in CodeGenPrepare.cpp. + bool SkipCallBr = false; + for (pred_iterator PI = pred_begin(&BB), E = pred_end(&BB); + PI != E && !SkipCallBr; ++PI) { + if (auto *CBI = dyn_cast<CallBrInst>((*PI)->getTerminator())) + for (unsigned i = 0, e = CBI->getNumSuccessors(); i != e; ++i) + if (RetBlock == CBI->getSuccessor(i)) { + SkipCallBr = true; + break; + } + } + if (SkipCallBr) + continue; + // Otherwise, we found a duplicate return block. Merge the two. Changed = true; @@ -266,6 +281,14 @@ struct CFGSimplifyPass : public FunctionPass { return false; Options.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + if (F.hasFnAttribute(Attribute::OptForFuzzing)) { + Options.setSimplifyCondBranch(false) + .setFoldTwoEntryPHINode(false); + } else { + Options.setSimplifyCondBranch(true) + .setFoldTwoEntryPHINode(true); + } + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); return simplifyFunctionCFG(F, TTI, Options); } diff --git a/llvm/lib/Transforms/Scalar/Sink.cpp b/llvm/lib/Transforms/Scalar/Sink.cpp index 677d86f8c7b4..48f289c8f17d 100644 --- a/llvm/lib/Transforms/Scalar/Sink.cpp +++ b/llvm/lib/Transforms/Scalar/Sink.cpp @@ -166,8 +166,8 @@ static bool SinkInstruction(Instruction *Inst, // dominated by one of the successors. // Look at all the dominated blocks and see if we can sink it in one. DomTreeNode *DTN = DT.getNode(Inst->getParent()); - for (DomTreeNode::iterator I = DTN->begin(), E = DTN->end(); - I != E && SuccToSinkTo == nullptr; ++I) { + for (auto I = DTN->begin(), E = DTN->end(); I != E && SuccToSinkTo == nullptr; + ++I) { BasicBlock *Candidate = (*I)->getBlock(); // A node always immediate-dominates its children on the dominator // tree. 
diff --git a/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp b/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp index cd7bfb2f20dc..8258b92a716d 100644 --- a/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp +++ b/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp @@ -67,8 +67,8 @@ isSafeToSpeculatePHIUsers(PHINode &PN, DominatorTree &DT, return false; } - if (auto CS = ImmutableCallSite(UI)) { - if (CS.isConvergent() || CS.cannotDuplicate()) { + if (const auto *CS = dyn_cast<CallBase>(UI)) { + if (CS->isConvergent() || CS->cannotDuplicate()) { LLVM_DEBUG(dbgs() << " Unsafe: convergent " "callsite cannot de duplicated: " << *UI << '\n'); return false; @@ -232,7 +232,8 @@ static bool isSafeAndProfitableToSpeculateAroundPHI( continue; int &MatCost = InsertResult.first->second.MatCost; - MatCost = TTI.getIntImmCost(IncomingC->getValue(), IncomingC->getType()); + MatCost = TTI.getIntImmCost(IncomingC->getValue(), IncomingC->getType(), + TargetTransformInfo::TCK_SizeAndLatency); NonFreeMat |= MatCost != TTI.TCC_Free; } if (!NonFreeMat) { @@ -283,12 +284,15 @@ static bool isSafeAndProfitableToSpeculateAroundPHI( int MatCost = IncomingConstantAndCostsAndCount.second.MatCost; int &FoldedCost = IncomingConstantAndCostsAndCount.second.FoldedCost; if (IID) - FoldedCost += TTI.getIntImmCostIntrin(IID, Idx, IncomingC->getValue(), - IncomingC->getType()); + FoldedCost += + TTI.getIntImmCostIntrin(IID, Idx, IncomingC->getValue(), + IncomingC->getType(), + TargetTransformInfo::TCK_SizeAndLatency); else FoldedCost += TTI.getIntImmCostInst(UserI->getOpcode(), Idx, - IncomingC->getValue(), IncomingC->getType()); + IncomingC->getValue(), IncomingC->getType(), + TargetTransformInfo::TCK_SizeAndLatency); // If we accumulate more folded cost for this incoming constant than // materialized cost, then we'll regress any edge with this constant so @@ -465,7 +469,7 @@ findProfitablePHIs(ArrayRef<PHINode *> PNs, if (CostMapIt != SpecCostMap.end()) Cost += CostMapIt->second; } - Cost += TTI.getUserCost(I); + Cost += TTI.getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency); bool Inserted = SpecCostMap.insert({I, Cost}).second; (void)Inserted; assert(Inserted && "Must not re-insert a cost during the DFS!"); diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp index c8d899bb4871..f82a2936c762 100644 --- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -65,6 +65,7 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/InitializePasses.h" @@ -244,19 +245,35 @@ static unsigned ComputeSpeculationCost(const Instruction *I, case Instruction::FNeg: case Instruction::ICmp: case Instruction::FCmp: - return TTI.getUserCost(I); + return TTI.getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency); default: - return UINT_MAX; // Disallow anything not whitelisted. + return UINT_MAX; // Disallow anything not explicitly listed. 
  }
}
 bool SpeculativeExecutionPass::considerHoistingFromTo(
     BasicBlock &FromBlock, BasicBlock &ToBlock) {
   SmallPtrSet<const Instruction *, 8> NotHoisted;
-  const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](User *U) {
-    for (Value* V : U->operand_values()) {
-      if (Instruction *I = dyn_cast<Instruction>(V)) {
+  const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](const User *U) {
+    // Debug variable has special operand to check it's not hoisted.
+    if (const auto *DVI = dyn_cast<DbgVariableIntrinsic>(U)) {
+      if (const auto *I =
+              dyn_cast_or_null<Instruction>(DVI->getVariableLocation()))
+        if (NotHoisted.count(I) == 0)
+          return true;
+      return false;
+    }
+
+    // Usually a debug label intrinsic corresponds to a label in LLVM IR. In
+    // these cases we should not move it here.
+    // TODO: We may need special processing to detect whether it is related to
+    // a hoisted instruction.
+    if (isa<DbgLabelInst>(U))
+      return false;
+
+    for (const Value *V : U->operand_values()) {
+      if (const Instruction *I = dyn_cast<Instruction>(V)) {
         if (NotHoisted.count(I) > 0)
           return false;
       }
@@ -265,7 +282,8 @@ bool SpeculativeExecutionPass::considerHoistingFromTo(
   };
   unsigned TotalSpeculationCost = 0;
-  for (auto& I : FromBlock) {
+  unsigned NotHoistedInstCount = 0;
+  for (const auto &I : FromBlock) {
     const unsigned Cost = ComputeSpeculationCost(&I, *TTI);
     if (Cost != UINT_MAX && isSafeToSpeculativelyExecute(&I) &&
         AllPrecedingUsesFromBlockHoisted(&I)) {
@@ -273,15 +291,15 @@ bool SpeculativeExecutionPass::considerHoistingFromTo(
       if (TotalSpeculationCost > SpecExecMaxSpeculationCost)
         return false; // too much to hoist
     } else {
-      NotHoisted.insert(&I);
-      if (NotHoisted.size() > SpecExecMaxNotHoisted)
+      // Debug info intrinsics should not be counted toward the threshold.
+      if (!isa<DbgInfoIntrinsic>(I))
+        NotHoistedInstCount++;
+      if (NotHoistedInstCount > SpecExecMaxNotHoisted)
         return false; // too much left behind
+      NotHoisted.insert(&I);
     }
   }
-  if (TotalSpeculationCost == 0)
-    return false; // nothing to hoist
-
   for (auto I = FromBlock.begin(); I != FromBlock.end();) {
     // We have to increment I before moving Current as moving Current
     // changes the list that I is iterating through.
diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index 4ce4ce46f67a..c20e57b02c1a 100644 --- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -8,13 +8,12 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LegacyDivergenceAnalysis.h" -#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" #include "llvm/Analysis/RegionPass.h" @@ -34,6 +33,7 @@ #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" @@ -43,6 +43,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> #include <cassert> @@ -88,6 +89,59 @@ using BBPredicates = DenseMap<BasicBlock *, Value *>; using PredMap = DenseMap<BasicBlock *, BBPredicates>; using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>; +// A traits type that is intended to be used in graph algorithms. The graph +// traits starts at an entry node, and traverses the RegionNodes that are in +// the Nodes set. +struct SubGraphTraits { + using NodeRef = std::pair<RegionNode *, SmallDenseSet<RegionNode *> *>; + using BaseSuccIterator = GraphTraits<RegionNode *>::ChildIteratorType; + + // This wraps a set of Nodes into the iterator, so we know which edges to + // filter out. + class WrappedSuccIterator + : public iterator_adaptor_base< + WrappedSuccIterator, BaseSuccIterator, + typename std::iterator_traits<BaseSuccIterator>::iterator_category, + NodeRef, std::ptrdiff_t, NodeRef *, NodeRef> { + SmallDenseSet<RegionNode *> *Nodes; + + public: + WrappedSuccIterator(BaseSuccIterator It, SmallDenseSet<RegionNode *> *Nodes) + : iterator_adaptor_base(It), Nodes(Nodes) {} + + NodeRef operator*() const { return {*I, Nodes}; } + }; + + static bool filterAll(const NodeRef &N) { return true; } + static bool filterSet(const NodeRef &N) { return N.second->count(N.first); } + + using ChildIteratorType = + filter_iterator<WrappedSuccIterator, bool (*)(const NodeRef &)>; + + static NodeRef getEntryNode(Region *R) { + return {GraphTraits<Region *>::getEntryNode(R), nullptr}; + } + + static NodeRef getEntryNode(NodeRef N) { return N; } + + static iterator_range<ChildIteratorType> children(const NodeRef &N) { + auto *filter = N.second ? &filterSet : &filterAll; + return make_filter_range( + make_range<WrappedSuccIterator>( + {GraphTraits<RegionNode *>::child_begin(N.first), N.second}, + {GraphTraits<RegionNode *>::child_end(N.first), N.second}), + filter); + } + + static ChildIteratorType child_begin(const NodeRef &N) { + return children(N).begin(); + } + + static ChildIteratorType child_end(const NodeRef &N) { + return children(N).end(); + } +}; + /// Finds the nearest common dominator of a set of BasicBlocks. 
/// /// For every BB you add to the set, you can specify whether we "remember" the @@ -192,11 +246,11 @@ class StructurizeCFG : public RegionPass { LegacyDivergenceAnalysis *DA; DominatorTree *DT; - LoopInfo *LI; SmallVector<RegionNode *, 8> Order; BBSet Visited; + SmallVector<WeakVH, 8> AffectedPhis; BBPhiMap DeletedPhis; BB2BBVecMap AddedPhis; @@ -211,13 +265,8 @@ class StructurizeCFG : public RegionPass { void orderNodes(); - Loop *getAdjustedLoop(RegionNode *RN); - unsigned getAdjustedLoopDepth(RegionNode *RN); - void analyzeLoops(RegionNode *N); - Value *invert(Value *Condition); - Value *buildCondition(BranchInst *Term, unsigned Idx, bool Invert); void gatherPredicates(RegionNode *N); @@ -232,6 +281,8 @@ class StructurizeCFG : public RegionPass { void setPhiValues(); + void simplifyAffectedPhis(); + void killTerminator(BasicBlock *BB); void changeExit(RegionNode *Node, BasicBlock *NewExit, @@ -279,7 +330,6 @@ public: AU.addRequired<LegacyDivergenceAnalysis>(); AU.addRequiredID(LowerSwitchID); AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); RegionPass::getAnalysisUsage(AU); @@ -311,75 +361,60 @@ bool StructurizeCFG::doInitialization(Region *R, RGPassManager &RGM) { return false; } -/// Use the exit block to determine the loop if RN is a SubRegion. -Loop *StructurizeCFG::getAdjustedLoop(RegionNode *RN) { - if (RN->isSubRegion()) { - Region *SubRegion = RN->getNodeAs<Region>(); - return LI->getLoopFor(SubRegion->getExit()); - } - - return LI->getLoopFor(RN->getEntry()); -} - -/// Use the exit block to determine the loop depth if RN is a SubRegion. -unsigned StructurizeCFG::getAdjustedLoopDepth(RegionNode *RN) { - if (RN->isSubRegion()) { - Region *SubR = RN->getNodeAs<Region>(); - return LI->getLoopDepth(SubR->getExit()); - } - - return LI->getLoopDepth(RN->getEntry()); -} - -/// Build up the general order of nodes +/// Build up the general order of nodes, by performing a topological sort of the +/// parent region's nodes, while ensuring that there is no outer cycle node +/// between any two inner cycle nodes. void StructurizeCFG::orderNodes() { - ReversePostOrderTraversal<Region*> RPOT(ParentRegion); - SmallDenseMap<Loop*, unsigned, 8> LoopBlocks; - - // The reverse post-order traversal of the list gives us an ordering close - // to what we want. The only problem with it is that sometimes backedges - // for outer loops will be visited before backedges for inner loops. - for (RegionNode *RN : RPOT) { - Loop *Loop = getAdjustedLoop(RN); - ++LoopBlocks[Loop]; - } - - unsigned CurrentLoopDepth = 0; - Loop *CurrentLoop = nullptr; - for (auto I = RPOT.begin(), E = RPOT.end(); I != E; ++I) { - RegionNode *RN = cast<RegionNode>(*I); - unsigned LoopDepth = getAdjustedLoopDepth(RN); - - if (is_contained(Order, *I)) - continue; - - if (LoopDepth < CurrentLoopDepth) { - // Make sure we have visited all blocks in this loop before moving back to - // the outer loop. + Order.resize(std::distance(GraphTraits<Region *>::nodes_begin(ParentRegion), + GraphTraits<Region *>::nodes_end(ParentRegion))); + if (Order.empty()) + return; - auto LoopI = I; - while (unsigned &BlockCount = LoopBlocks[CurrentLoop]) { - LoopI++; - if (getAdjustedLoop(cast<RegionNode>(*LoopI)) == CurrentLoop) { - --BlockCount; - Order.push_back(*LoopI); - } + SmallDenseSet<RegionNode *> Nodes; + auto EntryNode = SubGraphTraits::getEntryNode(ParentRegion); + + // A list of range indices of SCCs in Order, to be processed. 
+ SmallVector<std::pair<unsigned, unsigned>, 8> WorkList; + unsigned I = 0, E = Order.size(); + while (true) { + // Run through all the SCCs in the subgraph starting with Entry. + for (auto SCCI = + scc_iterator<SubGraphTraits::NodeRef, SubGraphTraits>::begin( + EntryNode); + !SCCI.isAtEnd(); ++SCCI) { + auto &SCC = *SCCI; + + // An SCC up to the size of 2, can be reduced to an entry (the last node), + // and a possible additional node. Therefore, it is already in order, and + // there is no need to add it to the work-list. + unsigned Size = SCC.size(); + if (Size > 2) + WorkList.emplace_back(I, I + Size); + + // Add the SCC nodes to the Order array. + for (auto &N : SCC) { + assert(I < E && "SCC size mismatch!"); + Order[I++] = N.first; } } + assert(I == E && "SCC size mismatch!"); - CurrentLoop = getAdjustedLoop(RN); - if (CurrentLoop) - LoopBlocks[CurrentLoop]--; + // If there are no more SCCs to order, then we are done. + if (WorkList.empty()) + break; - CurrentLoopDepth = LoopDepth; - Order.push_back(*I); - } + std::tie(I, E) = WorkList.pop_back_val(); + + // Collect the set of nodes in the SCC's subgraph. These are only the + // possible child nodes; we do not add the entry (last node) otherwise we + // will have the same exact SCC all over again. + Nodes.clear(); + Nodes.insert(Order.begin() + I, Order.begin() + E - 1); - // This pass originally used a post-order traversal and then operated on - // the list in reverse. Now that we are using a reverse post-order traversal - // rather than re-working the whole pass to operate on the list in order, - // we just reverse the list and continue to operate on it in reverse. - std::reverse(Order.begin(), Order.end()); + // Update the entry node. + EntryNode.first = Order[E - 1]; + EntryNode.second = &Nodes; + } } /// Determine the end of the loops @@ -401,39 +436,6 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) { } } -/// Invert the given condition -Value *StructurizeCFG::invert(Value *Condition) { - // First: Check if it's a constant - if (Constant *C = dyn_cast<Constant>(Condition)) - return ConstantExpr::getNot(C); - - // Second: If the condition is already inverted, return the original value - Value *NotCondition; - if (match(Condition, m_Not(m_Value(NotCondition)))) - return NotCondition; - - if (Instruction *Inst = dyn_cast<Instruction>(Condition)) { - // Third: Check all the users for an invert - BasicBlock *Parent = Inst->getParent(); - for (User *U : Condition->users()) - if (Instruction *I = dyn_cast<Instruction>(U)) - if (I->getParent() == Parent && match(I, m_Not(m_Specific(Condition)))) - return I; - - // Last option: Create a new instruction - return BinaryOperator::CreateNot(Condition, "", Parent->getTerminator()); - } - - if (Argument *Arg = dyn_cast<Argument>(Condition)) { - BasicBlock &EntryBlock = Arg->getParent()->getEntryBlock(); - return BinaryOperator::CreateNot(Condition, - Arg->getName() + ".inv", - EntryBlock.getTerminator()); - } - - llvm_unreachable("Unhandled condition to invert"); -} - /// Build the condition for one edge Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx, bool Invert) { @@ -442,7 +444,7 @@ Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx, Cond = Term->getCondition(); if (Idx != (unsigned)Invert) - Cond = invert(Cond); + Cond = invertCondition(Cond); } return Cond; } @@ -520,8 +522,7 @@ void StructurizeCFG::collectInfos() { for (RegionNode *RN : reverse(Order)) { LLVM_DEBUG(dbgs() << "Visiting: " << (RN->isSubRegion() ? 
"SubRegion with entry: " : "") - << RN->getEntry()->getName() << " Loop Depth: " - << LI->getLoopDepth(RN->getEntry()) << "\n"); + << RN->getEntry()->getName() << "\n"); // Analyze all the conditions leading to a node gatherPredicates(RN); @@ -585,9 +586,14 @@ void StructurizeCFG::insertConditions(bool Loops) { void StructurizeCFG::delPhiValues(BasicBlock *From, BasicBlock *To) { PhiMap &Map = DeletedPhis[To]; for (PHINode &Phi : To->phis()) { + bool Recorded = false; while (Phi.getBasicBlockIndex(From) != -1) { Value *Deleted = Phi.removeIncomingValue(From, false); Map[&Phi].push_back(std::make_pair(From, Deleted)); + if (!Recorded) { + AffectedPhis.push_back(&Phi); + Recorded = true; + } } } } @@ -632,28 +638,29 @@ void StructurizeCFG::setPhiValues() { for (BasicBlock *FI : From) Phi->setIncomingValueForBlock(FI, Updater.GetValueAtEndOfBlock(FI)); + AffectedPhis.push_back(Phi); } DeletedPhis.erase(To); } assert(DeletedPhis.empty()); - // Simplify any phis inserted by the SSAUpdater if possible + AffectedPhis.append(InsertedPhis.begin(), InsertedPhis.end()); +} + +void StructurizeCFG::simplifyAffectedPhis() { bool Changed; do { Changed = false; - SimplifyQuery Q(Func->getParent()->getDataLayout()); Q.DT = DT; - for (size_t i = 0; i < InsertedPhis.size(); ++i) { - PHINode *Phi = InsertedPhis[i]; - if (Value *V = SimplifyInstruction(Phi, Q)) { - Phi->replaceAllUsesWith(V); - Phi->eraseFromParent(); - InsertedPhis[i] = InsertedPhis.back(); - InsertedPhis.pop_back(); - i--; - Changed = true; + for (WeakVH VH : AffectedPhis) { + if (auto Phi = dyn_cast_or_null<PHINode>(VH)) { + if (auto NewValue = SimplifyInstruction(Phi, Q)) { + Phi->replaceAllUsesWith(NewValue); + Phi->eraseFromParent(); + Changed = true; + } } } } while (Changed); @@ -886,6 +893,7 @@ void StructurizeCFG::createFlow() { BasicBlock *Exit = ParentRegion->getExit(); bool EntryDominatesExit = DT->dominates(ParentRegion->getEntry(), Exit); + AffectedPhis.clear(); DeletedPhis.clear(); AddedPhis.clear(); Conditions.clear(); @@ -1036,7 +1044,6 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { ParentRegion = R; DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); orderNodes(); collectInfos(); @@ -1044,6 +1051,7 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { insertConditions(false); insertConditions(true); setPhiValues(); + simplifyAffectedPhis(); rebuildSSA(); // Cleanup diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index 9f0ab9103d42..5bb1d54d7d12 100644 --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -64,7 +64,6 @@ #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" @@ -126,16 +125,16 @@ struct AllocaDerivedValueTracker { switch (I->getOpcode()) { case Instruction::Call: case Instruction::Invoke: { - CallSite CS(I); + auto &CB = cast<CallBase>(*I); // If the alloca-derived argument is passed byval it is not an escape // point, or a use of an alloca. Calling with byval copies the contents // of the alloca into argument registers or stack slots, which exist // beyond the lifetime of the current frame. 
- if (CS.isArgOperand(U) && CS.isByValArgument(CS.getArgumentNo(U))) + if (CB.isArgOperand(U) && CB.isByValArgument(CB.getArgOperandNo(U))) continue; bool IsNocapture = - CS.isDataOperand(U) && CS.doesNotCapture(CS.getDataOperandNo(U)); - callUsesLocalStack(CS, IsNocapture); + CB.isDataOperand(U) && CB.doesNotCapture(CB.getDataOperandNo(U)); + callUsesLocalStack(CB, IsNocapture); if (IsNocapture) { // If the alloca-derived argument is passed in as nocapture, then it // can't propagate to the call's return. That would be capturing. @@ -168,17 +167,17 @@ struct AllocaDerivedValueTracker { } } - void callUsesLocalStack(CallSite CS, bool IsNocapture) { + void callUsesLocalStack(CallBase &CB, bool IsNocapture) { // Add it to the list of alloca users. - AllocaUsers.insert(CS.getInstruction()); + AllocaUsers.insert(&CB); // If it's nocapture then it can't capture this alloca. if (IsNocapture) return; // If it can write to memory, it can leak the alloca value. - if (!CS.onlyReadsMemory()) - EscapePoints.insert(CS.getInstruction()); + if (!CB.onlyReadsMemory()) + EscapePoints.insert(&CB); } SmallPtrSet<Instruction *, 32> AllocaUsers; @@ -342,7 +341,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { const DataLayout &DL = L->getModule()->getDataLayout(); if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) || !isSafeToLoadUnconditionally(L->getPointerOperand(), L->getType(), - MaybeAlign(L->getAlignment()), DL, L)) + L->getAlign(), DL, L)) return false; } } @@ -355,89 +354,23 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { return !is_contained(I->operands(), CI); } -/// Return true if the specified value is the same when the return would exit -/// as it was when the initial iteration of the recursive function was executed. -/// -/// We currently handle static constants and arguments that are not modified as -/// part of the recursion. -static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) { - if (isa<Constant>(V)) return true; // Static constants are always dyn consts - - // Check to see if this is an immutable argument, if so, the value - // will be available to initialize the accumulator. - if (Argument *Arg = dyn_cast<Argument>(V)) { - // Figure out which argument number this is... - unsigned ArgNo = 0; - Function *F = CI->getParent()->getParent(); - for (Function::arg_iterator AI = F->arg_begin(); &*AI != Arg; ++AI) - ++ArgNo; - - // If we are passing this argument into call as the corresponding - // argument operand, then the argument is dynamically constant. - // Otherwise, we cannot transform this function safely. - if (CI->getArgOperand(ArgNo) == Arg) - return true; - } - - // Switch cases are always constant integers. If the value is being switched - // on and the return is only reachable from one of its cases, it's - // effectively constant. - if (BasicBlock *UniquePred = RI->getParent()->getUniquePredecessor()) - if (SwitchInst *SI = dyn_cast<SwitchInst>(UniquePred->getTerminator())) - if (SI->getCondition() == V) - return SI->getDefaultDest() != RI->getParent(); - - // Not a constant or immutable argument, we can't safely transform. - return false; -} - -/// Check to see if the function containing the specified tail call consistently -/// returns the same runtime-constant value at all exit points except for -/// IgnoreRI. If so, return the returned value. 
-static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) { - Function *F = CI->getParent()->getParent(); - Value *ReturnedValue = nullptr; - - for (BasicBlock &BBI : *F) { - ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator()); - if (RI == nullptr || RI == IgnoreRI) continue; - - // We can only perform this transformation if the value returned is - // evaluatable at the start of the initial invocation of the function, - // instead of at the end of the evaluation. - // - Value *RetOp = RI->getOperand(0); - if (!isDynamicConstant(RetOp, CI, RI)) - return nullptr; - - if (ReturnedValue && RetOp != ReturnedValue) - return nullptr; // Cannot transform if differing values are returned. - ReturnedValue = RetOp; - } - return ReturnedValue; -} +static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) { + if (!I->isAssociative() || !I->isCommutative()) + return false; -/// If the specified instruction can be transformed using accumulator recursion -/// elimination, return the constant which is the start of the accumulator -/// value. Otherwise return null. -static Value *canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) { - if (!I->isAssociative() || !I->isCommutative()) return nullptr; assert(I->getNumOperands() == 2 && "Associative/commutative operations should have 2 args!"); // Exactly one operand should be the result of the call instruction. if ((I->getOperand(0) == CI && I->getOperand(1) == CI) || (I->getOperand(0) != CI && I->getOperand(1) != CI)) - return nullptr; + return false; // The only user of this instruction we allow is a single return instruction. if (!I->hasOneUse() || !isa<ReturnInst>(I->user_back())) - return nullptr; + return false; - // Ok, now we have to check all of the other return instructions in this - // function. If they return non-constants or differing values, then we cannot - // transform the function safely. - return getCommonReturnValue(cast<ReturnInst>(I->user_back()), CI); + return true; } static Instruction *firstNonDbg(BasicBlock::iterator I) { @@ -446,11 +379,73 @@ static Instruction *firstNonDbg(BasicBlock::iterator I) { return &*I; } -static CallInst *findTRECandidate(Instruction *TI, - bool CannotTailCallElimCallsMarkedTail, - const TargetTransformInfo *TTI) { +namespace { +class TailRecursionEliminator { + Function &F; + const TargetTransformInfo *TTI; + AliasAnalysis *AA; + OptimizationRemarkEmitter *ORE; + DomTreeUpdater &DTU; + + // The below are shared state we want to have available when eliminating any + // calls in the function. These values should be populated by + // createTailRecurseLoopHeader the first time we find a call we can eliminate. + BasicBlock *HeaderBB = nullptr; + SmallVector<PHINode *, 8> ArgumentPHIs; + bool RemovableCallsMustBeMarkedTail = false; + + // PHI node to store our return value. + PHINode *RetPN = nullptr; + + // i1 PHI node to track if we have a valid return value stored in RetPN. + PHINode *RetKnownPN = nullptr; + + // Vector of select instructions we inserted. These selects use RetKnownPN + // to either propagate RetPN or select a new return value. + SmallVector<SelectInst *, 8> RetSelects; + + // The below are shared state needed when performing accumulator recursion. + // These values should be populated by insertAccumulator the first time we + // find an elimination that requires an accumulator. + + // PHI node to store our current accumulated value. + PHINode *AccPN = nullptr; + + // The instruction doing the accumulating.
+ Instruction *AccumulatorRecursionInstr = nullptr; + + TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI, + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU) + : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {} + + CallInst *findTRECandidate(Instruction *TI, + bool CannotTailCallElimCallsMarkedTail); + + void createTailRecurseLoopHeader(CallInst *CI); + + void insertAccumulator(Instruction *AccRecInstr); + + bool eliminateCall(CallInst *CI); + + bool foldReturnAndProcessPred(ReturnInst *Ret, + bool CannotTailCallElimCallsMarkedTail); + + bool processReturningBlock(ReturnInst *Ret, + bool CannotTailCallElimCallsMarkedTail); + + void cleanupAndFinalize(); + +public: + static bool eliminate(Function &F, const TargetTransformInfo *TTI, + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU); +}; +} // namespace + +CallInst *TailRecursionEliminator::findTRECandidate( + Instruction *TI, bool CannotTailCallElimCallsMarkedTail) { BasicBlock *BB = TI->getParent(); - Function *F = BB->getParent(); if (&BB->front() == TI) // Make sure there is something before the terminator. return nullptr; @@ -461,7 +456,7 @@ static CallInst *findTRECandidate(Instruction *TI, BasicBlock::iterator BBI(TI); while (true) { CI = dyn_cast<CallInst>(BBI); - if (CI && CI->getCalledFunction() == F) + if (CI && CI->getCalledFunction() == &F) break; if (BBI == BB->begin()) @@ -478,16 +473,14 @@ static CallInst *findTRECandidate(Instruction *TI, // double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call // and disable this xform in this case, because the code generator will // lower the call to fabs into inline code. - if (BB == &F->getEntryBlock() && + if (BB == &F.getEntryBlock() && firstNonDbg(BB->front().getIterator()) == CI && firstNonDbg(std::next(BB->begin())) == TI && CI->getCalledFunction() && !TTI->isLoweredToCall(CI->getCalledFunction())) { // A single-block function with just a call and a return. Check that // the arguments match. - CallSite::arg_iterator I = CallSite(CI).arg_begin(), - E = CallSite(CI).arg_end(); - Function::arg_iterator FI = F->arg_begin(), - FE = F->arg_end(); + auto I = CI->arg_begin(), E = CI->arg_end(); + Function::arg_iterator FI = F.arg_begin(), FE = F.arg_end(); for (; I != E && FI != FE; ++I, ++FI) if (*I != &*FI) break; if (I == E && FI == FE) @@ -497,27 +490,106 @@ static CallInst *findTRECandidate(Instruction *TI, return CI; } -static bool eliminateRecursiveTailCall( - CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { - // If we are introducing accumulator recursion to eliminate operations after - // the call instruction that are both associative and commutative, the initial - // value for the accumulator is placed in this variable. If this value is set - // then we actually perform accumulator recursion elimination instead of - // simple tail recursion elimination. If the operation is an LLVM instruction - // (eg: "add") then it is recorded in AccumulatorRecursionInstr. If not, then - // we are handling the case when the return instruction returns a constant C - // which is different to the constant returned by other return instructions - // (which is recorded in AccumulatorRecursionEliminationInitVal). This is a - // special case of accumulator recursion, the operation being "return C". 
- Value *AccumulatorRecursionEliminationInitVal = nullptr; - Instruction *AccumulatorRecursionInstr = nullptr; +void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) { + HeaderBB = &F.getEntryBlock(); + BasicBlock *NewEntry = BasicBlock::Create(F.getContext(), "", &F, HeaderBB); + NewEntry->takeName(HeaderBB); + HeaderBB->setName("tailrecurse"); + BranchInst *BI = BranchInst::Create(HeaderBB, NewEntry); + BI->setDebugLoc(CI->getDebugLoc()); + + // If this function has self recursive calls in the tail position where some + // are marked tail and some are not, only transform one flavor or another. + // We have to choose whether we move allocas in the entry block to the new + // entry block or not, so we can't make a good choice for both. We make this + // decision here based on whether the first call we found to remove is + // marked tail. + // NOTE: We could do slightly better here in the case that the function has + // no entry block allocas. + RemovableCallsMustBeMarkedTail = CI->isTailCall(); + + // If this tail call is marked 'tail' and if there are any allocas in the + // entry block, move them up to the new entry block. + if (RemovableCallsMustBeMarkedTail) + // Move all fixed sized allocas from HeaderBB to NewEntry. + for (BasicBlock::iterator OEBI = HeaderBB->begin(), E = HeaderBB->end(), + NEBI = NewEntry->begin(); + OEBI != E;) + if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++)) + if (isa<ConstantInt>(AI->getArraySize())) + AI->moveBefore(&*NEBI); + + // Now that we have created a new block, which jumps to the entry + // block, insert a PHI node for each argument of the function. + // For now, we initialize each PHI to only have the real arguments + // which are passed in. + Instruction *InsertPos = &HeaderBB->front(); + for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) { + PHINode *PN = + PHINode::Create(I->getType(), 2, I->getName() + ".tr", InsertPos); + I->replaceAllUsesWith(PN); // Everyone use the PHI node now! + PN->addIncoming(&*I, NewEntry); + ArgumentPHIs.push_back(PN); + } + + // If the function doesn't return void, create the RetPN and RetKnownPN PHI + // nodes to track our return value. We initialize RetPN with undef and + // RetKnownPN with false since we can't know our return value at function + // entry. + Type *RetType = F.getReturnType(); + if (!RetType->isVoidTy()) { + Type *BoolType = Type::getInt1Ty(F.getContext()); + RetPN = PHINode::Create(RetType, 2, "ret.tr", InsertPos); + RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr", InsertPos); + + RetPN->addIncoming(UndefValue::get(RetType), NewEntry); + RetKnownPN->addIncoming(ConstantInt::getFalse(BoolType), NewEntry); + } + + // The entry block was changed from HeaderBB to NewEntry. + // The forward DominatorTree needs to be recalculated when the EntryBB is + // changed. In this corner-case we recalculate the entire tree. + DTU.recalculate(*NewEntry->getParent()); +} + +void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) { + assert(!AccPN && "Trying to insert multiple accumulators"); + + AccumulatorRecursionInstr = AccRecInstr; + + // Start by inserting a new PHI node for the accumulator. + pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB); + AccPN = PHINode::Create(F.getReturnType(), std::distance(PB, PE) + 1, + "accumulator.tr", &HeaderBB->front()); + + // Loop over all of the predecessors of the tail recursion block.
For the + // real entry into the function we seed the PHI with the identity constant for + // the accumulation operation. For any other existing branches to this block + // (due to other tail recursions eliminated) the accumulator is not modified. + // Because we haven't added the branch in the current block to HeaderBB yet, + // it will not show up as a predecessor. + for (pred_iterator PI = PB; PI != PE; ++PI) { + BasicBlock *P = *PI; + if (P == &F.getEntryBlock()) { + Constant *Identity = ConstantExpr::getBinOpIdentity( + AccRecInstr->getOpcode(), AccRecInstr->getType()); + AccPN->addIncoming(Identity, P); + } else { + AccPN->addIncoming(AccPN, P); + } + } + + ++NumAccumAdded; +} + +bool TailRecursionEliminator::eliminateCall(CallInst *CI) { + ReturnInst *Ret = cast<ReturnInst>(CI->getParent()->getTerminator()); // Ok, we found a potential tail call. We can currently only transform the // tail call if all of the instructions between the call and the return are // movable to above the call itself, leaving the call next to the return. // Check that this is the case now. + Instruction *AccRecInstr = nullptr; BasicBlock::iterator BBI(CI); for (++BBI; &*BBI != Ret; ++BBI) { if (canMoveAboveCall(&*BBI, CI, AA)) @@ -526,39 +598,16 @@ static bool eliminateRecursiveTailCall( // If we can't move the instruction above the call, it might be because it // is an associative and commutative operation that could be transformed // using accumulator recursion elimination. Check to see if this is the - // case, and if so, remember the initial accumulator value for later. - if ((AccumulatorRecursionEliminationInitVal = - canTransformAccumulatorRecursion(&*BBI, CI))) { - // Yes, this is accumulator recursion. Remember which instruction - // accumulates. - AccumulatorRecursionInstr = &*BBI; - } else { - return false; // Otherwise, we cannot eliminate the tail recursion! - } - } + // case, and if so, remember which instruction accumulates for later. + if (AccPN || !canTransformAccumulatorRecursion(&*BBI, CI)) + return false; // We cannot eliminate the tail recursion! - // We can only transform call/return pairs that either ignore the return value - // of the call and return void, ignore the value of the call and return a - // constant, return the value returned by the tail call, or that are being - // accumulator recursion variable eliminated. - if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI && - !isa<UndefValue>(Ret->getReturnValue()) && - AccumulatorRecursionEliminationInitVal == nullptr && - !getCommonReturnValue(nullptr, CI)) { - // One case remains that we are able to handle: the current return - // instruction returns a constant, and all other return instructions - // return a different constant. - if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret)) - return false; // Current return instruction does not return a constant. - // Check that all other return instructions return a common constant. If - // so, record it in AccumulatorRecursionEliminationInitVal. - AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI); - if (!AccumulatorRecursionEliminationInitVal) - return false; + // Yes, this is accumulator recursion. Remember which instruction + // accumulates. + AccRecInstr = &*BBI; } BasicBlock *BB = Ret->getParent(); - Function *F = BB->getParent(); using namespace ore; ORE->emit([&]() { @@ -568,51 +617,10 @@ static bool eliminateRecursiveTailCall( // OK! We can transform this tail call. 
If this is the first one found, // create the new entry block, allowing us to branch back to the old entry. - if (!OldEntry) { - OldEntry = &F->getEntryBlock(); - BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry); - NewEntry->takeName(OldEntry); - OldEntry->setName("tailrecurse"); - BranchInst *BI = BranchInst::Create(OldEntry, NewEntry); - BI->setDebugLoc(CI->getDebugLoc()); - - // If this tail call is marked 'tail' and if there are any allocas in the - // entry block, move them up to the new entry block. - TailCallsAreMarkedTail = CI->isTailCall(); - if (TailCallsAreMarkedTail) - // Move all fixed sized allocas from OldEntry to NewEntry. - for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(), - NEBI = NewEntry->begin(); OEBI != E; ) - if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++)) - if (isa<ConstantInt>(AI->getArraySize())) - AI->moveBefore(&*NEBI); - - // Now that we have created a new block, which jumps to the entry - // block, insert a PHI node for each argument of the function. - // For now, we initialize each PHI to only have the real arguments - // which are passed in. - Instruction *InsertPos = &OldEntry->front(); - for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); - I != E; ++I) { - PHINode *PN = PHINode::Create(I->getType(), 2, - I->getName() + ".tr", InsertPos); - I->replaceAllUsesWith(PN); // Everyone use the PHI node now! - PN->addIncoming(&*I, NewEntry); - ArgumentPHIs.push_back(PN); - } - // The entry block was changed from OldEntry to NewEntry. - // The forward DominatorTree needs to be recalculated when the EntryBB is - // changed. In this corner-case we recalculate the entire tree. - DTU.recalculate(*NewEntry->getParent()); - } + if (!HeaderBB) + createTailRecurseLoopHeader(CI); - // If this function has self recursive calls in the tail position where some - // are marked tail and some are not, only transform one flavor or another. We - // have to choose whether we move allocas in the entry block to the new entry - // block or not, so we can't make a good choice for both. NOTE: We could do - // slightly better here in the case that the function has no entry block - // allocas. - if (TailCallsAreMarkedTail && !CI->isTailCall()) + if (RemovableCallsMustBeMarkedTail && !CI->isTailCall()) return false; // Ok, now that we know we have a pseudo-entry block WITH all of the @@ -621,74 +629,53 @@ static bool eliminateRecursiveTailCall( for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) ArgumentPHIs[i]->addIncoming(CI->getArgOperand(i), BB); - // If we are introducing an accumulator variable to eliminate the recursion, - // do so now. Note that we _know_ that no subsequent tail recursion - // eliminations will happen on this function because of the way the - // accumulator recursion predicate is set up. - // - if (AccumulatorRecursionEliminationInitVal) { - Instruction *AccRecInstr = AccumulatorRecursionInstr; - // Start by inserting a new PHI node for the accumulator. - pred_iterator PB = pred_begin(OldEntry), PE = pred_end(OldEntry); - PHINode *AccPN = PHINode::Create( - AccumulatorRecursionEliminationInitVal->getType(), - std::distance(PB, PE) + 1, "accumulator.tr", &OldEntry->front()); - - // Loop over all of the predecessors of the tail recursion block. For the - // real entry into the function we seed the PHI with the initial value, - // computed earlier. For any other existing branches to this block (due to - // other tail recursions eliminated) the accumulator is not modified. 
- // Because we haven't added the branch in the current block to OldEntry yet, - // it will not show up as a predecessor. - for (pred_iterator PI = PB; PI != PE; ++PI) { - BasicBlock *P = *PI; - if (P == &F->getEntryBlock()) - AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P); - else - AccPN->addIncoming(AccPN, P); - } + if (AccRecInstr) { + insertAccumulator(AccRecInstr); - if (AccRecInstr) { - // Add an incoming argument for the current block, which is computed by - // our associative and commutative accumulator instruction. - AccPN->addIncoming(AccRecInstr, BB); + // Rewrite the accumulator recursion instruction so that it does not use + // the result of the call anymore, instead, use the PHI node we just + // inserted. + AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN); + } - // Next, rewrite the accumulator recursion instruction so that it does not - // use the result of the call anymore, instead, use the PHI node we just - // inserted. - AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN); + // Update our return value tracking + if (RetPN) { + if (Ret->getReturnValue() == CI || AccRecInstr) { + // Defer selecting a return value + RetPN->addIncoming(RetPN, BB); + RetKnownPN->addIncoming(RetKnownPN, BB); } else { - // Add an incoming argument for the current block, which is just the - // constant returned by the current return instruction. - AccPN->addIncoming(Ret->getReturnValue(), BB); + // We found a return value we want to use, insert a select instruction to + // select it if we don't already know what our return value will be and + // store the result in our return value PHI node. + SelectInst *SI = SelectInst::Create( + RetKnownPN, RetPN, Ret->getReturnValue(), "current.ret.tr", Ret); + RetSelects.push_back(SI); + + RetPN->addIncoming(SI, BB); + RetKnownPN->addIncoming(ConstantInt::getTrue(RetKnownPN->getType()), BB); } - // Finally, rewrite any return instructions in the program to return the PHI - // node instead of the "initval" that they do currently. This loop will - // actually rewrite the return value we are destroying, but that's ok. - for (BasicBlock &BBI : *F) - if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator())) - RI->setOperand(0, AccPN); - ++NumAccumAdded; + if (AccPN) + AccPN->addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB); } // Now that all of the PHI nodes are in place, remove the call and // ret instructions, replacing them with an unconditional branch. - BranchInst *NewBI = BranchInst::Create(OldEntry, Ret); + BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret); NewBI->setDebugLoc(CI->getDebugLoc()); BB->getInstList().erase(Ret); // Remove return. BB->getInstList().erase(CI); // Remove call. - DTU.applyUpdates({{DominatorTree::Insert, BB, OldEntry}}); + DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}}); ++NumEliminated; return true; } -static bool foldReturnAndProcessPred( - BasicBlock *BB, ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { +bool TailRecursionEliminator::foldReturnAndProcessPred( + ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) { + BasicBlock *BB = Ret->getParent(); + bool Change = false; // Make sure this block is a trivial return block. 
@@ -711,10 +698,11 @@ static bool foldReturnAndProcessPred( while (!UncondBranchPreds.empty()) { BranchInst *BI = UncondBranchPreds.pop_back_val(); BasicBlock *Pred = BI->getParent(); - if (CallInst *CI = findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)){ + if (CallInst *CI = + findTRECandidate(BI, CannotTailCallElimCallsMarkedTail)) { LLVM_DEBUG(dbgs() << "FOLDING: " << *BB << "INTO UNCOND BRANCH PRED: " << *Pred); - ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU); + FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU); // Cleanup: if all predecessors of BB have been eliminated by // FoldReturnIntoUncondBranch, delete it. It is important to empty it, @@ -723,8 +711,7 @@ static bool foldReturnAndProcessPred( if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB)) DTU.deleteBB(BB); - eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA, ORE, DTU); + eliminateCall(CI); ++NumRetDuped; Change = true; } @@ -733,23 +720,92 @@ static bool foldReturnAndProcessPred( return Change; } -static bool processReturningBlock( - ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { - CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI); +bool TailRecursionEliminator::processReturningBlock( + ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) { + CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail); if (!CI) return false; - return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA, ORE, DTU); + return eliminateCall(CI); +} + +void TailRecursionEliminator::cleanupAndFinalize() { + // If we eliminated any tail recursions, it's possible that we inserted some + // silly PHI nodes which just merge an initial value (the incoming operand) + // with themselves. Check to see if we did and clean up our mess if so. This + // occurs when a function passes an argument straight through to its tail + // call. + for (PHINode *PN : ArgumentPHIs) { + // If the PHI Node is a dynamic constant, replace it with the value it is. + if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { + PN->replaceAllUsesWith(PNV); + PN->eraseFromParent(); + } + } + + if (RetPN) { + if (RetSelects.empty()) { + // If we didn't insert any select instructions, then we know we didn't + // store a return value and we can remove the PHI nodes we inserted. + RetPN->dropAllReferences(); + RetPN->eraseFromParent(); + + RetKnownPN->dropAllReferences(); + RetKnownPN->eraseFromParent(); + + if (AccPN) { + // We need to insert a copy of our accumulator instruction before any + // return in the function, and return its result instead. + Instruction *AccRecInstr = AccumulatorRecursionInstr; + for (BasicBlock &BB : F) { + ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator()); + if (!RI) + continue; + + Instruction *AccRecInstrNew = AccRecInstr->clone(); + AccRecInstrNew->setName("accumulator.ret.tr"); + AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, + RI->getOperand(0)); + AccRecInstrNew->insertBefore(RI); + RI->setOperand(0, AccRecInstrNew); + } + } + } else { + // We need to insert a select instruction before any return left in the + // function to select our stored return value if we have one. 
+ for (BasicBlock &BB : F) { + ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator()); + if (!RI) + continue; + + SelectInst *SI = SelectInst::Create( + RetKnownPN, RetPN, RI->getOperand(0), "current.ret.tr", RI); + RetSelects.push_back(SI); + RI->setOperand(0, SI); + } + + if (AccPN) { + // We need to insert a copy of our accumulator instruction before any + // of the selects we inserted, and select its result instead. + Instruction *AccRecInstr = AccumulatorRecursionInstr; + for (SelectInst *SI : RetSelects) { + Instruction *AccRecInstrNew = AccRecInstr->clone(); + AccRecInstrNew->setName("accumulator.ret.tr"); + AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, + SI->getFalseValue()); + AccRecInstrNew->insertBefore(SI); + SI->setFalseValue(AccRecInstrNew); + } + } + } + } } -static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, - AliasAnalysis *AA, - OptimizationRemarkEmitter *ORE, - DomTreeUpdater &DTU) { +bool TailRecursionEliminator::eliminate(Function &F, + const TargetTransformInfo *TTI, + AliasAnalysis *AA, + OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU) { if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true") return false; @@ -762,17 +818,15 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, // If this function is a varargs function, we won't be able to PHI the args // right, so don't even try to convert it... if (F.getFunctionType()->isVarArg()) - return false; - - BasicBlock *OldEntry = nullptr; - bool TailCallsAreMarkedTail = false; - SmallVector<PHINode*, 8> ArgumentPHIs; + return MadeChange; // If false, we cannot perform TRE on tail calls marked with the 'tail' // attribute, because doing so would cause the stack size to increase (real // TRE would deallocate variable sized allocas, TRE doesn't). bool CanTRETailMarkedCall = canTRE(F); + TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU); + // Change any tail recursive calls to loops. // // FIXME: The code generator produces really bad code when an 'escaping @@ -782,29 +836,14 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) { BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB. if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) { - bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, !CanTRETailMarkedCall, - TTI, AA, ORE, DTU); + bool Change = TRE.processReturningBlock(Ret, !CanTRETailMarkedCall); if (!Change && BB->getFirstNonPHIOrDbg() == Ret) - Change = foldReturnAndProcessPred( - BB, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, - !CanTRETailMarkedCall, TTI, AA, ORE, DTU); + Change = TRE.foldReturnAndProcessPred(Ret, !CanTRETailMarkedCall); MadeChange |= Change; } } - // If we eliminated any tail recursions, it's possible that we inserted some - // silly PHI nodes which just merge an initial value (the incoming operand) - // with themselves. Check to see if we did and clean up our mess if so. This - // occurs when a function passes an argument straight through to its tail - // call. - for (PHINode *PN : ArgumentPHIs) { - // If the PHI Node is a dynamic constant, replace it with the value it is. 
- if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { - PN->replaceAllUsesWith(PNV); - PN->eraseFromParent(); - } - } + TRE.cleanupAndFinalize(); return MadeChange; } @@ -838,7 +877,7 @@ struct TailCallElim : public FunctionPass { // UpdateStrategy to Lazy if we find it profitable later. DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager); - return eliminateTailRecursion( + return TailRecursionEliminator::eliminate( F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F), &getAnalysis<AAResultsWrapperPass>().getAAResults(), &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU); @@ -871,7 +910,7 @@ PreservedAnalyses TailCallElimPass::run(Function &F, // UpdateStrategy based on some test results. It is feasible to switch the // UpdateStrategy to Lazy if we find it profitable later. DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager); - bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE, DTU); + bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU); if (!Changed) return PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp index c8461fdc1608..7c81e6352dec 100644 --- a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp +++ b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/WarnMissedTransforms.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/InitializePasses.h" #include "llvm/Transforms/Utils/LoopUtils.h" |
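The TailRecursionElimination.cpp changes above drop the old getCommonReturnValue/isDynamicConstant analysis and instead seed the accumulator PHI with the identity constant of the accumulating operation (via ConstantExpr::getBinOpIdentity), re-applying a clone of that operation in front of every return. The sketch below is only a rough, source-level illustration of that effect under the assumption of a simple integer-multiply accumulator; the pass itself rewrites LLVM IR, and the names fac, fac_tre, accumulator and n are illustrative, not taken from the patch.

// Before: the recursive call feeds an associative, commutative
// instruction, i.e. the accumulator recursion case.
int fac(int n) {
  if (n <= 1)
    return 1;
  return n * fac(n - 1);
}

// After, conceptually: the recursion becomes a branch back to a
// "tailrecurse"-style header; the accumulator starts at the identity
// of '*' and the accumulating multiply is re-applied to each original
// return value before returning.
int fac_tre(int n) {
  int accumulator = 1;              // identity constant for '*'
  for (;;) {
    if (n <= 1)
      return accumulator * 1;       // cloned accumulator op before the return
    accumulator = n * accumulator;  // accumulating op now reads the accumulator PHI
    n = n - 1;                      // argument PHI takes the recursive call's argument
  }
}

Because the accumulator is seeded with the operation's identity rather than with a value derived from the other returns, the per-return constant analysis that getCommonReturnValue and isDynamicConstant used to perform is no longer needed, which is why both helpers are deleted in this patch.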