| author | Dimitry Andric <dim@FreeBSD.org> | 2022-07-03 14:10:23 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2022-07-03 14:10:23 +0000 |
| commit | 145449b1e420787bb99721a429341fa6be3adfb6 (patch) | |
| tree | 1d56ae694a6de602e348dd80165cf881a36600ed /llvm/lib/Analysis | |
| parent | ecbca9f5fb7d7613d2b94982c4825eb0d33d6842 (diff) | |
| download | src-145449b1e420787bb99721a429341fa6be3adfb6.tar.gz, src-145449b1e420787bb99721a429341fa6be3adfb6.zip | |
Diffstat (limited to 'llvm/lib/Analysis')
104 files changed, 5130 insertions, 3162 deletions
diff --git a/llvm/lib/Analysis/AliasAnalysis.cpp b/llvm/lib/Analysis/AliasAnalysis.cpp
index a8132e5abf54..e249c38ecd34 100644
--- a/llvm/lib/Analysis/AliasAnalysis.cpp
+++ b/llvm/lib/Analysis/AliasAnalysis.cpp
@@ -42,7 +42,6 @@
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
-#include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
@@ -680,7 +679,7 @@ ModRefInfo AAResults::getModRefInfo(const Instruction *I,
     }
   }

-  const MemoryLocation &Loc = OptLoc.getValueOr(MemoryLocation());
+  const MemoryLocation &Loc = OptLoc.value_or(MemoryLocation());

   switch (I->getOpcode()) {
   case Instruction::VAArg:
@@ -988,6 +987,28 @@ bool llvm::isIdentifiedFunctionLocal(const Value *V) {
   return isa<AllocaInst>(V) || isNoAliasCall(V) || isNoAliasOrByValArgument(V);
 }

+bool llvm::isEscapeSource(const Value *V) {
+  if (auto *CB = dyn_cast<CallBase>(V))
+    return !isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(CB,
+                                                                        true);
+
+  // The load case works because isNonEscapingLocalObject considers all
+  // stores to be escapes (it passes true for the StoreCaptures argument
+  // to PointerMayBeCaptured).
+  if (isa<LoadInst>(V))
+    return true;
+
+  // The inttoptr case works because isNonEscapingLocalObject considers all
+  // means of converting or equating a pointer to an int (ptrtoint, ptr store
+  // which could be followed by an integer load, ptr<->int compare) as
+  // escaping, and objects located at well-known addresses via platform-specific
+  // means cannot be considered non-escaping local objects.
+  if (isa<IntToPtrInst>(V))
+    return true;
+
+  return false;
+}
+
 bool llvm::isNotVisibleOnUnwind(const Value *Object,
                                 bool &RequiresNoCaptureBeforeUnwind) {
   RequiresNoCaptureBeforeUnwind = false;
diff --git a/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp b/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp
index 1577f1eb70b1..e3446a1f3130 100644
--- a/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp
+++ b/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp
@@ -9,9 +9,7 @@
 #include "llvm/Analysis/AliasAnalysisEvaluator.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
@@ -19,7 +17,6 @@
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"

 using namespace llvm;
@@ -41,30 +38,48 @@ static cl::opt<bool> PrintMustModRef("print-mustmodref", cl::ReallyHidden);

 static cl::opt<bool> EvalAAMD("evaluate-aa-metadata", cl::ReallyHidden);

-static void PrintResults(AliasResult AR, bool P, const Value *V1,
-                         const Value *V2, const Module *M) {
+static void PrintResults(AliasResult AR, bool P,
+                         std::pair<const Value *, Type *> Loc1,
+                         std::pair<const Value *, Type *> Loc2,
+                         const Module *M) {
   if (PrintAll || P) {
+    Type *Ty1 = Loc1.second, *Ty2 = Loc2.second;
+    unsigned AS1 = Loc1.first->getType()->getPointerAddressSpace();
+    unsigned AS2 = Loc2.first->getType()->getPointerAddressSpace();
     std::string o1, o2;
     {
       raw_string_ostream os1(o1), os2(o2);
-      V1->printAsOperand(os1, true, M);
-      V2->printAsOperand(os2, true, M);
+      Loc1.first->printAsOperand(os1, false, M);
+      Loc2.first->printAsOperand(os2, false, M);
     }

     if (o2 < o1) {
       std::swap(o1, o2);
+      std::swap(Ty1, Ty2);
+      std::swap(AS1, AS2);
       // Change offset sign for the local AR, for printing only.
       AR.swap();
     }
-    errs() << "  " << AR << ":\t" << o1 << ", " << o2 << "\n";
+    errs() << "  " << AR << ":\t";
+    Ty1->print(errs(), false, /* NoDetails */ true);
+    if (AS1 != 0)
+      errs() << " addrspace(" << AS1 << ")";
+    errs() << "* " << o1 << ", ";
+    Ty2->print(errs(), false, /* NoDetails */ true);
+    if (AS2 != 0)
+      errs() << " addrspace(" << AS2 << ")";
+    errs() << "* " << o2 << "\n";
   }
 }

-static inline void PrintModRefResults(const char *Msg, bool P, Instruction *I,
-                                      Value *Ptr, Module *M) {
+static inline void PrintModRefResults(
+    const char *Msg, bool P, Instruction *I,
+    std::pair<const Value *, Type *> Loc, Module *M) {
   if (PrintAll || P) {
     errs() << "  " << Msg << ":  Ptr: ";
-    Ptr->printAsOperand(errs(), true, M);
+    Loc.second->print(errs(), false, /* NoDetails */ true);
+    errs() << "* ";
+    Loc.first->printAsOperand(errs(), false, M);
     errs() << "\t<->" << *I << '\n';
   }
 }
@@ -84,11 +99,6 @@ static inline void PrintLoadStoreResults(AliasResult AR, bool P,
   }
 }

-static inline bool isInterestingPointer(Value *V) {
-  return V->getType()->isPointerTy()
-      && !isa<ConstantPointerNull>(V);
-}
-
 PreservedAnalyses AAEvaluator::run(Function &F, FunctionAnalysisManager &AM) {
   runInternal(F, AM.getResult<AAManager>(F));
   return PreservedAnalyses::all();
@@ -99,38 +109,21 @@ void AAEvaluator::runInternal(Function &F, AAResults &AA) {

   ++FunctionCount;

-  SetVector<Value *> Pointers;
+  SetVector<std::pair<const Value *, Type *>> Pointers;
   SmallSetVector<CallBase *, 16> Calls;
   SetVector<Value *> Loads;
   SetVector<Value *> Stores;

-  for (auto &I : F.args())
-    if (I.getType()->isPointerTy())    // Add all pointer arguments.
-      Pointers.insert(&I);
-
   for (Instruction &Inst : instructions(F)) {
-    if (Inst.getType()->isPointerTy()) // Add all pointer instructions.
-      Pointers.insert(&Inst);
-    if (EvalAAMD && isa<LoadInst>(&Inst))
-      Loads.insert(&Inst);
-    if (EvalAAMD && isa<StoreInst>(&Inst))
-      Stores.insert(&Inst);
-    if (auto *Call = dyn_cast<CallBase>(&Inst)) {
-      Value *Callee = Call->getCalledOperand();
-      // Skip actual functions for direct function calls.
-      if (!isa<Function>(Callee) && isInterestingPointer(Callee))
-        Pointers.insert(Callee);
-      // Consider formals.
-      for (Use &DataOp : Call->data_ops())
-        if (isInterestingPointer(DataOp))
-          Pointers.insert(DataOp);
-      Calls.insert(Call);
-    } else {
-      // Consider all operands.
-      for (Use &Op : Inst.operands())
-        if (isInterestingPointer(Op))
-          Pointers.insert(Op);
-    }
+    if (auto *LI = dyn_cast<LoadInst>(&Inst)) {
+      Pointers.insert({LI->getPointerOperand(), LI->getType()});
+      Loads.insert(LI);
+    } else if (auto *SI = dyn_cast<StoreInst>(&Inst)) {
+      Pointers.insert({SI->getPointerOperand(),
+                       SI->getValueOperand()->getType()});
+      Stores.insert(SI);
+    } else if (auto *CB = dyn_cast<CallBase>(&Inst))
+      Calls.insert(CB);
   }

   if (PrintAll || PrintNoAlias || PrintMayAlias || PrintPartialAlias ||
@@ -139,20 +132,12 @@ void AAEvaluator::runInternal(Function &F, AAResults &AA) {
          << " pointers, " << Calls.size() << " call sites\n";

   // iterate over the worklist, and run the full (n^2)/2 disambiguations
-  for (SetVector<Value *>::iterator I1 = Pointers.begin(), E = Pointers.end();
-       I1 != E; ++I1) {
-    auto I1Size = LocationSize::afterPointer();
-    Type *I1ElTy = (*I1)->getType()->getPointerElementType();
-    if (I1ElTy->isSized())
-      I1Size = LocationSize::precise(DL.getTypeStoreSize(I1ElTy));
-
-    for (SetVector<Value *>::iterator I2 = Pointers.begin(); I2 != I1; ++I2) {
-      auto I2Size = LocationSize::afterPointer();
-      Type *I2ElTy = (*I2)->getType()->getPointerElementType();
-      if (I2ElTy->isSized())
-        I2Size = LocationSize::precise(DL.getTypeStoreSize(I2ElTy));
-
-      AliasResult AR = AA.alias(*I1, I1Size, *I2, I2Size);
+  for (auto I1 = Pointers.begin(), E = Pointers.end(); I1 != E; ++I1) {
+    LocationSize Size1 = LocationSize::precise(DL.getTypeStoreSize(I1->second));
+    for (auto I2 = Pointers.begin(); I2 != I1; ++I2) {
+      LocationSize Size2 =
+          LocationSize::precise(DL.getTypeStoreSize(I2->second));
+      AliasResult AR = AA.alias(I1->first, Size1, I2->first, Size2);
       switch (AR) {
       case AliasResult::NoAlias:
         PrintResults(AR, PrintNoAlias, *I1, *I2, F.getParent());
@@ -231,13 +216,10 @@ void AAEvaluator::runInternal(Function &F, AAResults &AA) {

   // Mod/ref alias analysis: compare all pairs of calls and values
   for (CallBase *Call : Calls) {
-    for (auto Pointer : Pointers) {
-      auto Size = LocationSize::afterPointer();
-      Type *ElTy = Pointer->getType()->getPointerElementType();
-      if (ElTy->isSized())
-        Size = LocationSize::precise(DL.getTypeStoreSize(ElTy));
-
-      switch (AA.getModRefInfo(Call, Pointer, Size)) {
+    for (const auto &Pointer : Pointers) {
+      LocationSize Size =
+          LocationSize::precise(DL.getTypeStoreSize(Pointer.second));
+      switch (AA.getModRefInfo(Call, Pointer.first, Size)) {
       case ModRefInfo::NoModRef:
         PrintModRefResults("NoModRef", PrintNoModRef, Call, Pointer,
                            F.getParent());
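Note on the AAEvaluator hunks above: with opaque pointers a pointer type no longer carries an element type, so the evaluator now tracks explicit (pointer, accessed type) pairs gathered from loads and stores, and derives the query size from the accessed type. A minimal sketch of that pattern; `makeLoc` is an illustrative helper name, not part of the commit:

```cpp
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"

using namespace llvm;

// Build the precise location the evaluator now queries for an access of
// AccessTy through Ptr. getTypeStoreSize(AccessTy) is the byte count the
// access actually touches, so the (Ptr, AccessTy) pair fully determines it.
static MemoryLocation makeLoc(const Value *Ptr, Type *AccessTy,
                              const DataLayout &DL) {
  return MemoryLocation(Ptr,
                        LocationSize::precise(DL.getTypeStoreSize(AccessTy)));
}
```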
diff --git a/llvm/lib/Analysis/AliasSetTracker.cpp b/llvm/lib/Analysis/AliasSetTracker.cpp
index 5dc6c7780a0c..234a73bff6a8 100644
--- a/llvm/lib/Analysis/AliasSetTracker.cpp
+++ b/llvm/lib/Analysis/AliasSetTracker.cpp
@@ -13,16 +13,12 @@
 #include "llvm/Analysis/AliasSetTracker.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/GuardUtils.h"
-#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Config/llvm-config.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Value.h"
@@ -237,8 +233,8 @@ bool AliasSet::aliasesUnknownInst(const Instruction *Inst,
   if (AliasAny)
     return true;

-  assert(Inst->mayReadOrWriteMemory() &&
-         "Instruction must either read or write memory.");
+  if (!Inst->mayReadOrWriteMemory())
+    return false;

   for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) {
     if (auto *UnknownInst = getUnknownInst(i)) {
@@ -258,31 +254,6 @@ bool AliasSet::aliasesUnknownInst(const Instruction *Inst,
   return false;
 }

-Instruction* AliasSet::getUniqueInstruction() {
-  if (AliasAny)
-    // May have collapses alias set
-    return nullptr;
-  if (begin() != end()) {
-    if (!UnknownInsts.empty())
-      // Another instruction found
-      return nullptr;
-    if (std::next(begin()) != end())
-      // Another instruction found
-      return nullptr;
-    Value *Addr = begin()->getValue();
-    assert(!Addr->user_empty() &&
-           "where's the instruction which added this pointer?");
-    if (std::next(Addr->user_begin()) != Addr->user_end())
-      // Another instruction found -- this is really restrictive
-      // TODO: generalize!
-      return nullptr;
-    return cast<Instruction>(*(Addr->user_begin()));
-  }
-  if (1 != UnknownInsts.size())
-    return nullptr;
-  return cast<Instruction>(UnknownInsts[0]);
-}
-
 void AliasSetTracker::clear() {
   // Delete all the PointerRec entries.
   for (auto &I : PointerMap)
diff --git a/llvm/lib/Analysis/Analysis.cpp b/llvm/lib/Analysis/Analysis.cpp
index 177f38af13d8..460dddceaf17 100644
--- a/llvm/lib/Analysis/Analysis.cpp
+++ b/llvm/lib/Analysis/Analysis.cpp
@@ -40,14 +40,14 @@ void llvm::initializeAnalysis(PassRegistry &Registry) {
   initializeDelinearizationPass(Registry);
   initializeDemandedBitsWrapperPassPass(Registry);
   initializeDominanceFrontierWrapperPassPass(Registry);
-  initializeDomViewerPass(Registry);
-  initializeDomPrinterPass(Registry);
-  initializeDomOnlyViewerPass(Registry);
-  initializePostDomViewerPass(Registry);
-  initializeDomOnlyPrinterPass(Registry);
-  initializePostDomPrinterPass(Registry);
-  initializePostDomOnlyViewerPass(Registry);
-  initializePostDomOnlyPrinterPass(Registry);
+  initializeDomViewerWrapperPassPass(Registry);
+  initializeDomPrinterWrapperPassPass(Registry);
+  initializeDomOnlyViewerWrapperPassPass(Registry);
+  initializePostDomViewerWrapperPassPass(Registry);
+  initializeDomOnlyPrinterWrapperPassPass(Registry);
+  initializePostDomPrinterWrapperPassPass(Registry);
+  initializePostDomOnlyViewerWrapperPassPass(Registry);
+  initializePostDomOnlyPrinterWrapperPassPass(Registry);
   initializeAAResultsWrapperPassPass(Registry);
   initializeGlobalsAAWrapperPassPass(Registry);
   initializeIVUsersWrapperPassPass(Registry);
diff --git a/llvm/lib/Analysis/AssumeBundleQueries.cpp b/llvm/lib/Analysis/AssumeBundleQueries.cpp
index 9d4fe1225b33..7440dbd29ccf 100644
--- a/llvm/lib/Analysis/AssumeBundleQueries.cpp
+++ b/llvm/lib/Analysis/AssumeBundleQueries.cpp
@@ -10,8 +10,8 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/ValueTracking.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/DebugCounter.h"
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index 3e0214e21ecd..e7e476dfb572 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -11,18 +11,17 @@
 //
 //===----------------------------------------------------------------------===//

-#include "llvm/Analysis/AssumeBundleQueries.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AssumeBundleQueries.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
-#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
@@ -31,7 +30,6 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
-#include <algorithm>
 #include <cassert>
 #include <utility>
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
index 0a0b53796add..c78f822b8bcf 100644
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -22,7 +22,6 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/CaptureTracking.h"
-#include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/PhiValues.h"
@@ -45,7 +44,6 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
-#include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/User.h"
@@ -105,29 +103,6 @@ bool BasicAAResult::invalidate(Function &Fn, const PreservedAnalyses &PA,
 // Useful predicates
 //===----------------------------------------------------------------------===//

-/// Returns true if the pointer is one which would have been considered an
-/// escape by isNonEscapingLocalObject.
-static bool isEscapeSource(const Value *V) {
-  if (isa<CallBase>(V))
-    return true;
-
-  // The load case works because isNonEscapingLocalObject considers all
-  // stores to be escapes (it passes true for the StoreCaptures argument
-  // to PointerMayBeCaptured).
-  if (isa<LoadInst>(V))
-    return true;
-
-  // The inttoptr case works because isNonEscapingLocalObject considers all
-  // means of converting or equating a pointer to an int (ptrtoint, ptr store
-  // which could be followed by an integer load, ptr<->int compare) as
-  // escaping, and objects located at well-known addresses via platform-specific
-  // means cannot be considered non-escaping local objects.
-  if (isa<IntToPtrInst>(V))
-    return true;
-
-  return false;
-}
-
 /// Returns the size of the object specified by V or UnknownSize if unknown.
 static uint64_t getObjectSize(const Value *V, const DataLayout &DL,
                               const TargetLibraryInfo &TLI,
@@ -234,7 +209,7 @@ bool EarliestEscapeInfo::isNotCapturedBeforeOrAt(const Value *Object,
   if (Iter.second) {
     Instruction *EarliestCapture = FindEarliestCapture(
         Object, *const_cast<Function *>(I->getFunction()),
-        /*ReturnCaptures=*/false, /*StoreCaptures=*/true, DT);
+        /*ReturnCaptures=*/false, /*StoreCaptures=*/true, DT, EphValues);
     if (EarliestCapture) {
       auto Ins = Inst2Obj.insert({EarliestCapture, {}});
       Ins.first->second.push_back(Object);
@@ -661,8 +636,8 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
       unsigned TypeSize =
           DL.getTypeAllocSize(GTI.getIndexedType()).getFixedSize();
       LE = LE.mul(APInt(IndexSize, TypeSize), GEPOp->isInBounds());
-      Decomposed.Offset += LE.Offset.sextOrSelf(MaxIndexSize);
-      APInt Scale = LE.Scale.sextOrSelf(MaxIndexSize);
+      Decomposed.Offset += LE.Offset.sext(MaxIndexSize);
+      APInt Scale = LE.Scale.sext(MaxIndexSize);

       // If we already had an occurrence of this index variable, merge this
       // scale into it.  For example, we want to handle:
@@ -1299,8 +1274,31 @@ AliasResult BasicAAResult::aliasGEP(
       const VariableGEPIndex &Var = DecompGEP1.VarIndices[0];
       if (Var.Val.TruncBits == 0 &&
           isKnownNonZero(Var.Val.V, DL, 0, &AC, Var.CxtI, DT)) {
-        // If V != 0 then abs(VarIndex) >= abs(Scale).
-        MinAbsVarIndex = Var.Scale.abs();
+        // If V != 0, then abs(VarIndex) > 0.
+        MinAbsVarIndex = APInt(Var.Scale.getBitWidth(), 1);
+
+        // Check if abs(V*Scale) >= abs(Scale) holds in the presence of
+        // potentially wrapping math.
+        auto MultiplyByScaleNoWrap = [](const VariableGEPIndex &Var) {
+          if (Var.IsNSW)
+            return true;
+
+          int ValOrigBW = Var.Val.V->getType()->getPrimitiveSizeInBits();
+          // If Scale is small enough so that abs(V*Scale) >= abs(Scale) holds.
+          // The max value of abs(V) is 2^ValOrigBW - 1. Multiplying with a
+          // constant smaller than 2^(bitwidth(Val) - ValOrigBW) won't wrap.
+          int MaxScaleValueBW = Var.Val.getBitWidth() - ValOrigBW;
+          if (MaxScaleValueBW <= 0)
+            return false;
+          return Var.Scale.ule(
+              APInt::getMaxValue(MaxScaleValueBW).zext(Var.Scale.getBitWidth()));
+        };
+        // Refine MinAbsVarIndex, if abs(Scale*V) >= abs(Scale) holds in the
+        // presence of potentially wrapping math.
+        if (MultiplyByScaleNoWrap(Var)) {
+          // If V != 0 then abs(VarIndex) >= abs(Scale).
+          MinAbsVarIndex = Var.Scale.abs();
+        }
       }
     } else if (DecompGEP1.VarIndices.size() == 2) {
       // VarIndex = Scale*V0 + (-Scale)*V1.
@@ -1370,15 +1368,15 @@ BasicAAResult::aliasSelect(const SelectInst *SI, LocationSize SISize,
   // If both arms of the Select node NoAlias or MustAlias V2, then returns
   // NoAlias / MustAlias. Otherwise, returns MayAlias.
-  AliasResult Alias = getBestAAResults().alias(
-      MemoryLocation(V2, V2Size),
-      MemoryLocation(SI->getTrueValue(), SISize), AAQI);
+  AliasResult Alias =
+      getBestAAResults().alias(MemoryLocation(SI->getTrueValue(), SISize),
+                               MemoryLocation(V2, V2Size), AAQI);
   if (Alias == AliasResult::MayAlias)
     return AliasResult::MayAlias;

-  AliasResult ThisAlias = getBestAAResults().alias(
-      MemoryLocation(V2, V2Size),
-      MemoryLocation(SI->getFalseValue(), SISize), AAQI);
+  AliasResult ThisAlias =
+      getBestAAResults().alias(MemoryLocation(SI->getFalseValue(), SISize),
+                               MemoryLocation(V2, V2Size), AAQI);
   return MergeAliasResults(ThisAlias, Alias);
 }
@@ -1500,8 +1498,7 @@ AliasResult BasicAAResult::aliasPHI(const PHINode *PN, LocationSize PNSize,
   AAQueryInfo *UseAAQI = BlockInserted ? &NewAAQI : &AAQI;

   AliasResult Alias = getBestAAResults().alias(
-      MemoryLocation(V2, V2Size),
-      MemoryLocation(V1Srcs[0], PNSize), *UseAAQI);
+      MemoryLocation(V1Srcs[0], PNSize), MemoryLocation(V2, V2Size), *UseAAQI);

   // Early exit if the check of the first PHI source against V2 is MayAlias.
   // Other results are not possible.
@@ -1518,7 +1515,7 @@ AliasResult BasicAAResult::aliasPHI(const PHINode *PN, LocationSize PNSize,
     Value *V = V1Srcs[i];

     AliasResult ThisAlias = getBestAAResults().alias(
-        MemoryLocation(V2, V2Size), MemoryLocation(V, PNSize), *UseAAQI);
+        MemoryLocation(V, PNSize), MemoryLocation(V2, V2Size), *UseAAQI);
     Alias = MergeAliasResults(ThisAlias, Alias);
     if (Alias == AliasResult::MayAlias)
       break;
diff --git a/llvm/lib/Analysis/BlockFrequencyInfo.cpp b/llvm/lib/Analysis/BlockFrequencyInfo.cpp
index b464071a33e6..436b01764033 100644
--- a/llvm/lib/Analysis/BlockFrequencyInfo.cpp
+++ b/llvm/lib/Analysis/BlockFrequencyInfo.cpp
@@ -25,7 +25,6 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/GraphWriter.h"
 #include "llvm/Support/raw_ostream.h"
-#include <algorithm>
 #include <cassert>
 #include <string>
diff --git a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
index 2a5e1f65d731..ec8d318b675b 100644
--- a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
+++ b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
@@ -13,7 +13,6 @@
 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/GraphTraits.h"
 #include "llvm/ADT/None.h"
 #include "llvm/ADT/SCCIterator.h"
 #include "llvm/Config/llvm-config.h"
@@ -22,8 +21,8 @@
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/ScaledNumber.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/ScaledNumber.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <cassert>
@@ -48,7 +47,7 @@ cl::opt<bool> CheckBFIUnknownBlockQueries(
         "for debugging missed BFI updates"));

 cl::opt<bool> UseIterativeBFIInference(
-    "use-iterative-bfi-inference", cl::init(false), cl::Hidden, cl::ZeroOrMore,
+    "use-iterative-bfi-inference", cl::Hidden,
     cl::desc("Apply an iterative post-processing to infer correct BFI counts"));

 cl::opt<unsigned> IterativeBFIMaxIterationsPerBlock(
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index ffb80134749a..1d880424e55c 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -414,8 +414,7 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
     const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
     auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
     if (EstimatedWeight &&
-        EstimatedWeight.getValue() <=
-            static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
+        *EstimatedWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
       UnreachableIdxs.push_back(I - 1);
     else
       ReachableIdxs.push_back(I - 1);
@@ -688,7 +687,7 @@ Optional<uint32_t> BranchProbabilityInfo::getMaxEstimatedEdgeWeight(
     if (!Weight)
       return None;

-    if (!MaxWeight || MaxWeight.getValue() < Weight.getValue())
+    if (!MaxWeight || *MaxWeight < *Weight)
       MaxWeight = Weight;
   }

@@ -852,8 +851,7 @@ void BranchProbabilityInfo::computeEestimateBlockWeight(
       if (LoopWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
         LoopWeight = static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO);

-      EstimatedLoopWeight.insert(
-          {LoopBB.getLoopData(), LoopWeight.getValue()});
+      EstimatedLoopWeight.insert({LoopBB.getLoopData(), *LoopWeight});
       // Add all blocks entering the loop into working list.
       getLoopEnterBlocks(LoopBB, BlockWorkList);
     }
@@ -875,7 +873,7 @@ void BranchProbabilityInfo::computeEestimateBlockWeight(
       auto MaxWeight = getMaxEstimatedEdgeWeight(LoopBB, successors(BB));

       if (MaxWeight)
-        propagateEstimatedBlockWeight(LoopBB, DT, PDT, MaxWeight.getValue(),
+        propagateEstimatedBlockWeight(LoopBB, DT, PDT, *MaxWeight,
                                       BlockWorkList, LoopWorkList);
     }
   } while (!BlockWorkList.empty() || !LoopWorkList.empty());
@@ -913,7 +911,7 @@ bool BranchProbabilityInfo::calcEstimatedHeuristics(const BasicBlock *BB) {
       // Scale down loop exiting weight by trip count.
       Weight = std::max(
           static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO),
-          Weight.getValueOr(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) /
+          Weight.value_or(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) /
               TC);
     }
     bool IsUnlikelyEdge = LoopBB.getLoop() && UnlikelyBlocks.contains(SuccBB);
@@ -923,15 +921,14 @@ bool BranchProbabilityInfo::calcEstimatedHeuristics(const BasicBlock *BB) {
       // 'Unlikely' blocks have twice lower weight.
       Weight = std::max(
           static_cast<uint32_t>(BlockExecWeight::LOWEST_NON_ZERO),
-          Weight.getValueOr(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) /
-              2);
+          Weight.value_or(static_cast<uint32_t>(BlockExecWeight::DEFAULT)) / 2);
     }

     if (Weight)
       FoundEstimatedWeight = true;

     auto WeightVal =
-        Weight.getValueOr(static_cast<uint32_t>(BlockExecWeight::DEFAULT));
+        Weight.value_or(static_cast<uint32_t>(BlockExecWeight::DEFAULT));
     TotalWeight += WeightVal;
     SuccWeights.push_back(WeightVal);
   }
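Note: many BranchProbabilityInfo.cpp hunks (and the CFL* hunks below) are mechanical llvm::Optional API migrations toward the std::optional spelling: hasValue() becomes an explicit bool test, getValue() becomes operator*, and getValueOr() becomes value_or(). A small illustration of the after-state using std::optional, whose interface llvm::Optional mirrors at this revision; the constants are illustrative stand-ins for BlockExecWeight values:

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

// Stand-ins for BlockExecWeight::LOWEST_NON_ZERO and BlockExecWeight::DEFAULT.
constexpr uint32_t LowestNonZero = 1;
constexpr uint32_t Default = 16;

// Same shape as the calcEstimatedHeuristics hunk: scale a loop-exiting edge's
// weight down by the trip count, but never below the lowest non-zero weight.
// Before the migration this read Weight.getValueOr(Default) / TC.
uint32_t scaleByTripCount(std::optional<uint32_t> Weight, uint32_t TC) {
  return std::max(LowestNonZero, Weight.value_or(Default) / TC);
}
```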
diff --git a/llvm/lib/Analysis/CFG.cpp b/llvm/lib/Analysis/CFG.cpp
index ec25ee161e2c..1902d72f2f89 100644
--- a/llvm/lib/Analysis/CFG.cpp
+++ b/llvm/lib/Analysis/CFG.cpp
@@ -127,11 +127,7 @@ bool llvm::isCriticalEdge(const Instruction *TI, const BasicBlock *Dest,
 // the outermost loop in the loop nest that contains BB.
 static const Loop *getOutermostLoop(const LoopInfo *LI, const BasicBlock *BB) {
   const Loop *L = LI->getLoopFor(BB);
-  if (L) {
-    while (const Loop *Parent = L->getParentLoop())
-      L = Parent;
-  }
-  return L;
+  return L ? L->getOutermostLoop() : nullptr;
 }

 bool llvm::isPotentiallyReachableFromMany(
diff --git a/llvm/lib/Analysis/CFGPrinter.cpp b/llvm/lib/Analysis/CFGPrinter.cpp
index 04ccdc590845..f8eba1a00f28 100644
--- a/llvm/lib/Analysis/CFGPrinter.cpp
+++ b/llvm/lib/Analysis/CFGPrinter.cpp
@@ -23,7 +23,7 @@
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FileSystem.h"
-#include <algorithm>
+#include "llvm/Support/GraphWriter.h"

 using namespace llvm;
diff --git a/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp b/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp
index 1216d03e448b..602a01867f3b 100644
--- a/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp
@@ -831,14 +831,14 @@ CFLAndersAAResult::ensureCached(const Function &Fn) {
     scan(Fn);
     Iter = Cache.find(&Fn);
     assert(Iter != Cache.end());
-    assert(Iter->second.hasValue());
+    assert(Iter->second);
   }
   return Iter->second;
 }

 const AliasSummary *CFLAndersAAResult::getAliasSummary(const Function &Fn) {
   auto &FunInfo = ensureCached(Fn);
-  if (FunInfo.hasValue())
+  if (FunInfo)
     return &FunInfo->getAliasSummary();
   else
     return nullptr;
diff --git a/llvm/lib/Analysis/CFLGraph.h b/llvm/lib/Analysis/CFLGraph.h
index 02a13d673f40..60fc8d18678c 100644
--- a/llvm/lib/Analysis/CFLGraph.h
+++ b/llvm/lib/Analysis/CFLGraph.h
@@ -403,7 +403,7 @@ template <typename CFLAA> class CFLGraphBuilder {
       auto &RetParamRelations = Summary->RetParamRelations;
       for (auto &Relation : RetParamRelations) {
         auto IRelation = instantiateExternalRelation(Relation, Call);
-        if (IRelation.hasValue()) {
+        if (IRelation) {
           Graph.addNode(IRelation->From);
           Graph.addNode(IRelation->To);
           Graph.addEdge(IRelation->From, IRelation->To);
@@ -413,7 +413,7 @@ template <typename CFLAA> class CFLGraphBuilder {
       auto &RetParamAttributes = Summary->RetParamAttributes;
       for (auto &Attribute : RetParamAttributes) {
         auto IAttr = instantiateExternalAttribute(Attribute, Call);
-        if (IAttr.hasValue())
+        if (IAttr)
           Graph.addNode(IAttr->IValue, IAttr->Attr);
       }
     }
diff --git a/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp b/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp
index 090dccc53b6e..f92869c2ec63 100644
--- a/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp
@@ -165,7 +165,7 @@ CFLSteensAAResult::FunctionInfo::FunctionInfo(
     assert(RetVal != nullptr);
     assert(RetVal->getType()->isPointerTy());
     auto RetInfo = Sets.find(InstantiatedValue{RetVal, 0});
-    if (RetInfo.hasValue())
+    if (RetInfo)
       AddToRetParamRelations(0, RetInfo->Index);
   }

@@ -174,7 +174,7 @@ CFLSteensAAResult::FunctionInfo::FunctionInfo(
   for (auto &Param : Fn.args()) {
     if (Param.getType()->isPointerTy()) {
       auto ParamInfo = Sets.find(InstantiatedValue{&Param, 0});
-      if (ParamInfo.hasValue())
+      if (ParamInfo)
         AddToRetParamRelations(I + 1, ParamInfo->Index);
     }
     ++I;
@@ -250,14 +250,14 @@ CFLSteensAAResult::ensureCached(Function *Fn) {
     scan(Fn);
     Iter = Cache.find(Fn);
     assert(Iter != Cache.end());
-    assert(Iter->second.hasValue());
+    assert(Iter->second);
   }
   return Iter->second;
 }

 const AliasSummary *CFLSteensAAResult::getAliasSummary(Function &Fn) {
   auto &FunInfo = ensureCached(&Fn);
-  if (FunInfo.hasValue())
+  if (FunInfo)
     return &FunInfo->getAliasSummary();
   else
     return nullptr;
@@ -293,15 +293,15 @@ AliasResult CFLSteensAAResult::query(const MemoryLocation &LocA,
   assert(Fn != nullptr);
   auto &MaybeInfo = ensureCached(Fn);
-  assert(MaybeInfo.hasValue());
+  assert(MaybeInfo);

   auto &Sets = MaybeInfo->getStratifiedSets();
   auto MaybeA = Sets.find(InstantiatedValue{ValA, 0});
-  if (!MaybeA.hasValue())
+  if (!MaybeA)
     return AliasResult::MayAlias;

   auto MaybeB = Sets.find(InstantiatedValue{ValB, 0});
-  if (!MaybeB.hasValue())
+  if (!MaybeB)
     return AliasResult::MayAlias;

   auto SetA = *MaybeA;
diff --git a/llvm/lib/Analysis/CGSCCPassManager.cpp b/llvm/lib/Analysis/CGSCCPassManager.cpp
index c60b70ae5b69..b2e7422bbf8b 100644
--- a/llvm/lib/Analysis/CGSCCPassManager.cpp
+++ b/llvm/lib/Analysis/CGSCCPassManager.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Analysis/CGSCCPassManager.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/PriorityWorklist.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -27,7 +28,6 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/TimeProfiler.h"
 #include "llvm/Support/raw_ostream.h"
-#include <algorithm>
 #include <cassert>
 #include <iterator>
@@ -164,9 +164,9 @@ ModuleToPostOrderCGSCCPassAdaptor::run(Module &M, ModuleAnalysisManager &AM) {
       InlinedInternalEdges;

   CGSCCUpdateResult UR = {
-      RCWorklist, CWorklist, InvalidRefSCCSet, InvalidSCCSet,
-      nullptr,    nullptr,   PreservedAnalyses::all(), InlinedInternalEdges,
-      {}};
+      RCWorklist,           CWorklist, InvalidRefSCCSet,
+      InvalidSCCSet,        nullptr,   PreservedAnalyses::all(),
+      InlinedInternalEdges, {}};

   // Request PassInstrumentation from analysis manager, will use it to run
   // instrumenting callbacks for the passes later.
@@ -174,9 +174,8 @@ ModuleToPostOrderCGSCCPassAdaptor::run(Module &M, ModuleAnalysisManager &AM) {
   PreservedAnalyses PA = PreservedAnalyses::all();
   CG.buildRefSCCs();
-  for (auto RCI = CG.postorder_ref_scc_begin(),
-            RCE = CG.postorder_ref_scc_end();
-       RCI != RCE;) {
+  for (LazyCallGraph::RefSCC &RC :
+       llvm::make_early_inc_range(CG.postorder_ref_sccs())) {
     assert(RCWorklist.empty() &&
            "Should always start with an empty RefSCC worklist");
     // The postorder_ref_sccs range we are walking is lazily constructed, so
@@ -190,7 +189,7 @@ ModuleToPostOrderCGSCCPassAdaptor::run(Module &M, ModuleAnalysisManager &AM) {
     //
     // We also eagerly increment the iterator to the next position because
     // the CGSCC passes below may delete the current RefSCC.
-    RCWorklist.insert(&*RCI++);
+    RCWorklist.insert(&RC);

     do {
       LazyCallGraph::RefSCC *RC = RCWorklist.pop_back_val();
@@ -230,11 +229,15 @@ ModuleToPostOrderCGSCCPassAdaptor::run(Module &M, ModuleAnalysisManager &AM) {
           LLVM_DEBUG(dbgs() << "Skipping redundant run on SCC: " << *C << "\n");
           continue;
         }
-        if (&C->getOuterRefSCC() != RC) {
-          LLVM_DEBUG(dbgs() << "Skipping an SCC that is now part of some other "
-                               "RefSCC...\n");
-          continue;
-        }
+        // We used to also check if the current SCC is part of the current
+        // RefSCC and bail if it wasn't, since it should be in RCWorklist.
+        // However, this can cause compile time explosions in some cases on
+        // modules with a huge RefSCC. If a non-trivial amount of SCCs in the
+        // huge RefSCC can become their own child RefSCC, we create one child
+        // RefSCC, bail on the current RefSCC, visit the child RefSCC, revisit
+        // the huge RefSCC, and repeat. By visiting all SCCs in the original
+        // RefSCC we create all the child RefSCCs in one pass of the RefSCC,
+        // rather one pass of the RefSCC creating one child RefSCC at a time.

         // Ensure we can proxy analysis updates from the CGSCC analysis manager
         // into the the Function analysis manager by getting a proxy here.
@@ -264,11 +267,8 @@ ModuleToPostOrderCGSCCPassAdaptor::run(Module &M, ModuleAnalysisManager &AM) {
         // Check that we didn't miss any update scenario.
         assert(!InvalidSCCSet.count(C) && "Processing an invalid SCC!");
         assert(C->begin() != C->end() && "Cannot have an empty SCC!");
-        assert(&C->getOuterRefSCC() == RC &&
-               "Processing an SCC in a different RefSCC!");

         LastUpdatedC = UR.UpdatedC;
-        UR.UpdatedRC = nullptr;
         UR.UpdatedC = nullptr;

         // Check the PassInstrumentation's BeforePass callbacks before
@@ -290,7 +290,6 @@ ModuleToPostOrderCGSCCPassAdaptor::run(Module &M, ModuleAnalysisManager &AM) {

         // Update the SCC and RefSCC if necessary.
         C = UR.UpdatedC ? UR.UpdatedC : C;
-        RC = UR.UpdatedRC ? UR.UpdatedRC : RC;

         if (UR.UpdatedC) {
           // If we're updating the SCC, also update the FAM inside the proxy's
@@ -1213,10 +1212,8 @@ static LazyCallGraph::SCC &updateCGAndAnalysisManagerForPass(
   assert(!UR.InvalidatedRefSCCs.count(RC) &&
          "Invalidated the current RefSCC!");
   assert(&C->getOuterRefSCC() == RC && "Current SCC not in current RefSCC!");

-  // Record the current RefSCC and SCC for higher layers of the CGSCC pass
-  // manager now that all the updates have been applied.
-  if (RC != &InitialRC)
-    UR.UpdatedRC = RC;
+  // Record the current SCC for higher layers of the CGSCC pass manager now that
+  // all the updates have been applied.
   if (C != &InitialC)
     UR.UpdatedC = C;
diff --git a/llvm/lib/Analysis/CallGraph.cpp b/llvm/lib/Analysis/CallGraph.cpp
index dfbd29b7d636..f85527122b2a 100644
--- a/llvm/lib/Analysis/CallGraph.cpp
+++ b/llvm/lib/Analysis/CallGraph.cpp
@@ -21,7 +21,6 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include <algorithm>
 #include <cassert>

 using namespace llvm;
@@ -70,8 +69,7 @@ bool CallGraph::invalidate(Module &, const PreservedAnalyses &PA,
   // Check whether the analysis, all analyses on functions, or the function's
   // CFG have been preserved.
   auto PAC = PA.getChecker<CallGraphAnalysis>();
-  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Module>>() ||
-           PAC.preservedSet<CFGAnalyses>());
+  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Module>>());
 }

 void CallGraph::addToCallGraph(Function *F) {
diff --git a/llvm/lib/Analysis/CallGraphSCCPass.cpp b/llvm/lib/Analysis/CallGraphSCCPass.cpp
index 930cb13c0cb3..8438f33f4712 100644
--- a/llvm/lib/Analysis/CallGraphSCCPass.cpp
+++ b/llvm/lib/Analysis/CallGraphSCCPass.cpp
@@ -28,7 +28,6 @@
 #include "llvm/IR/OptBisect.h"
 #include "llvm/IR/PassTimingInfo.h"
 #include "llvm/IR/PrintPasses.h"
-#include "llvm/IR/StructuralHash.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -271,7 +270,7 @@ bool CGPassManager::RefreshCallGraph(const CallGraphSCC &CurSCC, CallGraph &CG,
               Calls.count(Call) ||

               // If the call edge is not from a call or invoke, or it is a
-              // instrinsic call, then the function pass RAUW'd a call with
+              // intrinsic call, then the function pass RAUW'd a call with
               // another value. This can happen when constant folding happens
               // of well known functions etc.
               (Call->getCalledFunction() &&
@@ -470,7 +469,7 @@ bool CGPassManager::RunAllPassesOnSCC(CallGraphSCC &CurSCC, CallGraph &CG,
       initializeAnalysisImpl(P);

 #ifdef EXPENSIVE_CHECKS
-      uint64_t RefHash = StructuralHash(CG.getModule());
+      uint64_t RefHash = P->structuralHash(CG.getModule());
 #endif

       // Actually run this pass on the current SCC.
@@ -480,7 +479,7 @@ bool CGPassManager::RunAllPassesOnSCC(CallGraphSCC &CurSCC, CallGraph &CG,
       Changed |= LocalChanged;

 #ifdef EXPENSIVE_CHECKS
-      if (!LocalChanged && (RefHash != StructuralHash(CG.getModule()))) {
+      if (!LocalChanged && (RefHash != P->structuralHash(CG.getModule()))) {
         llvm::errs() << "Pass modifies its input and doesn't report it: "
                      << P->getPassName() << "\n";
         llvm_unreachable("Pass modifies its input and doesn't report it");
diff --git a/llvm/lib/Analysis/CallPrinter.cpp b/llvm/lib/Analysis/CallPrinter.cpp
index 829532a0fa10..65e3184fad91 100644
--- a/llvm/lib/Analysis/CallPrinter.cpp
+++ b/llvm/lib/Analysis/CallPrinter.cpp
@@ -14,18 +14,23 @@
 //===----------------------------------------------------------------------===//

 #include "llvm/Analysis/CallPrinter.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/Analysis/BlockFrequencyInfo.h"
-#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/CallGraph.h"
-#include "llvm/Analysis/DOTGraphTraitsPass.h"
 #include "llvm/Analysis/HeatUtils.h"
-#include "llvm/Support/CommandLine.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/InitializePasses.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/DOTGraphTraits.h"
+#include "llvm/Support/GraphWriter.h"

 using namespace llvm;

+namespace llvm {
+template <class GraphType> struct GraphTraits;
+}
+
 // This option shows static (relative) call counts.
 // FIXME:
 // Need to show real counts when profile data is available
@@ -213,6 +218,71 @@ struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
 } // end llvm namespace

 namespace {
+void doCallGraphDOTPrinting(
+    Module &M, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
+  std::string Filename;
+  if (!CallGraphDotFilenamePrefix.empty())
+    Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
+  else
+    Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
+  errs() << "Writing '" << Filename << "'...";
+
+  std::error_code EC;
+  raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
+
+  CallGraph CG(M);
+  CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
+
+  if (!EC)
+    WriteGraph(File, &CFGInfo);
+  else
+    errs() << "  error opening file for writing!";
+  errs() << "\n";
+}
+
+void viewCallGraph(Module &M,
+                   function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
+  CallGraph CG(M);
+  CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
+
+  std::string Title =
+      DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
+  ViewGraph(&CFGInfo, "callgraph", true, Title);
+}
+} // namespace
+
+namespace llvm {
+PreservedAnalyses CallGraphDOTPrinterPass::run(Module &M,
+                                               ModuleAnalysisManager &AM) {
+  FunctionAnalysisManager &FAM =
+      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+
+  auto LookupBFI = [&FAM](Function &F) {
+    return &FAM.getResult<BlockFrequencyAnalysis>(F);
+  };
+
+  doCallGraphDOTPrinting(M, LookupBFI);
+
+  return PreservedAnalyses::all();
+}
+
+PreservedAnalyses CallGraphViewerPass::run(Module &M,
+                                           ModuleAnalysisManager &AM) {
+
+  FunctionAnalysisManager &FAM =
+      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+
+  auto LookupBFI = [&FAM](Function &F) {
+    return &FAM.getResult<BlockFrequencyAnalysis>(F);
+  };
+
+  viewCallGraph(M, LookupBFI);
+
+  return PreservedAnalyses::all();
+}
+} // namespace llvm
+
+namespace {
 // Viewer
 class CallGraphViewer : public ModulePass {
 public:
@@ -234,12 +304,7 @@ bool CallGraphViewer::runOnModule(Module &M) {
     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
   };

-  CallGraph CG(M);
-  CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
-
-  std::string Title =
-      DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
-  ViewGraph(&CFGInfo, "callgraph", true, Title);
+  viewCallGraph(M, LookupBFI);

   return false;
 }
@@ -266,24 +331,7 @@ bool CallGraphDOTPrinter::runOnModule(Module &M) {
     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
   };

-  std::string Filename;
-  if (!CallGraphDotFilenamePrefix.empty())
-    Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
-  else
-    Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
-  errs() << "Writing '" << Filename << "'...";
-
-  std::error_code EC;
-  raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
-
-  CallGraph CG(M);
-  CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
-
-  if (!EC)
-    WriteGraph(File, &CFGInfo);
-  else
-    errs() << "  error opening file for writing!";
-  errs() << "\n";
+  doCallGraphDOTPrinting(M, LookupBFI);

   return false;
 }
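Note: the EXPENSIVE_CHECKS hunks in CallGraphSCCPass.cpp above switch from a free StructuralHash(Module &) to the pass's own structuralHash() hook, but the guard pattern is unchanged: hash the IR before running a pass, and treat it as a hard error if the IR changed while the pass reported no change. A generic sketch of that shape; Module, HashModule, and RunPass here are stand-ins, not the LLVM API:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <functional>

struct Module; // stand-in for llvm::Module

// Hash before the pass, run it, and re-hash if the pass claims "no change".
// Pass managers rely on the returned flag for analysis invalidation, so a
// mismatch here indicates a real correctness bug in the pass.
void runPassChecked(Module &M, const char *PassName,
                    const std::function<uint64_t(Module &)> &HashModule,
                    const std::function<bool(Module &)> &RunPass) {
  uint64_t RefHash = HashModule(M);
  bool LocalChanged = RunPass(M);
  if (!LocalChanged && RefHash != HashModule(M)) {
    std::fprintf(stderr, "Pass modifies its input and doesn't report it: %s\n",
                 PassName);
    std::abort();
  }
}
```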
diff --git a/llvm/lib/Analysis/CaptureTracking.cpp b/llvm/lib/Analysis/CaptureTracking.cpp
index ba8462e659d5..f4fd660ac7e0 100644
--- a/llvm/lib/Analysis/CaptureTracking.cpp
+++ b/llvm/lib/Analysis/CaptureTracking.cpp
@@ -16,6 +16,7 @@
 //===----------------------------------------------------------------------===//

 #include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
@@ -44,15 +45,15 @@ STATISTIC(NumNotCapturedBefore, "Number of pointers not captured before");
 /// use it where possible. The caching version can use much higher limit or
 /// don't have this cap at all.
 static cl::opt<unsigned>
-DefaultMaxUsesToExplore("capture-tracking-max-uses-to-explore", cl::Hidden,
-                        cl::desc("Maximal number of uses to explore."),
-                        cl::init(20));
+    DefaultMaxUsesToExplore("capture-tracking-max-uses-to-explore", cl::Hidden,
+                            cl::desc("Maximal number of uses to explore."),
+                            cl::init(100));

 unsigned llvm::getDefaultMaxUsesToExploreForCaptureTracking() {
   return DefaultMaxUsesToExplore;
 }

-CaptureTracker::~CaptureTracker() {}
+CaptureTracker::~CaptureTracker() = default;

 bool CaptureTracker::shouldExplore(const Use *U) { return true; }

@@ -74,8 +75,10 @@ bool CaptureTracker::isDereferenceableOrNull(Value *O, const DataLayout &DL) {
 namespace {
   struct SimpleCaptureTracker : public CaptureTracker {
-    explicit SimpleCaptureTracker(bool ReturnCaptures)
-      : ReturnCaptures(ReturnCaptures) {}
+    explicit SimpleCaptureTracker(
+
+        const SmallPtrSetImpl<const Value *> &EphValues, bool ReturnCaptures)
+        : EphValues(EphValues), ReturnCaptures(ReturnCaptures) {}

     void tooManyUses() override { Captured = true; }

@@ -83,10 +86,15 @@ namespace {
       if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures)
         return false;

+      if (EphValues.contains(U->getUser()))
+        return false;
+
       Captured = true;
       return true;
     }

+    const SmallPtrSetImpl<const Value *> &EphValues;
+
     bool ReturnCaptures;

     bool Captured = false;
@@ -154,8 +162,9 @@ namespace {
   // escape are not in a cycle.
   struct EarliestCaptures : public CaptureTracker {

-    EarliestCaptures(bool ReturnCaptures, Function &F, const DominatorTree &DT)
-        : DT(DT), ReturnCaptures(ReturnCaptures), F(F) {}
+    EarliestCaptures(bool ReturnCaptures, Function &F, const DominatorTree &DT,
+                     const SmallPtrSetImpl<const Value *> &EphValues)
+        : EphValues(EphValues), DT(DT), ReturnCaptures(ReturnCaptures), F(F) {}

     void tooManyUses() override {
       Captured = true;
@@ -167,6 +176,9 @@ namespace {
       if (isa<ReturnInst>(I) && !ReturnCaptures)
         return false;

+      if (EphValues.contains(I))
+        return false;
+
       if (!EarliestCapture) {
         EarliestCapture = I;
       } else if (EarliestCapture->getParent() == I->getParent()) {
@@ -193,6 +205,8 @@ namespace {
       return false;
     }

+    const SmallPtrSetImpl<const Value *> &EphValues;
+
     Instruction *EarliestCapture = nullptr;

     const DominatorTree &DT;
@@ -212,8 +226,18 @@ namespace {
 /// counts as capturing it or not. The boolean StoreCaptures specified whether
 /// storing the value (or part of it) into memory anywhere automatically
 /// counts as capturing it or not.
-bool llvm::PointerMayBeCaptured(const Value *V,
-                                bool ReturnCaptures, bool StoreCaptures,
+bool llvm::PointerMayBeCaptured(const Value *V, bool ReturnCaptures,
+                                bool StoreCaptures,
                                 unsigned MaxUsesToExplore) {
+  SmallPtrSet<const Value *, 1> Empty;
+  return PointerMayBeCaptured(V, ReturnCaptures, StoreCaptures, Empty,
+                              MaxUsesToExplore);
+}
+
+/// Variant of the above function which accepts a set of Values that are
+/// ephemeral and cannot cause pointers to escape.
+bool llvm::PointerMayBeCaptured(const Value *V, bool ReturnCaptures,
+                                bool StoreCaptures,
+                                const SmallPtrSetImpl<const Value *> &EphValues,
+                                unsigned MaxUsesToExplore) {
   assert(!isa<GlobalValue>(V) &&
          "It doesn't make sense to ask whether a global is captured.");

@@ -224,7 +248,7 @@ bool llvm::PointerMayBeCaptured(const Value *V,
   // take advantage of this.
   (void)StoreCaptures;

-  SimpleCaptureTracker SCT(ReturnCaptures);
+  SimpleCaptureTracker SCT(EphValues, ReturnCaptures);
   PointerMayBeCaptured(V, &SCT, MaxUsesToExplore);
   if (SCT.Captured)
     ++NumCaptured;
@@ -266,14 +290,16 @@ bool llvm::PointerMayBeCapturedBefore(const Value *V, bool ReturnCaptures,
   return CB.Captured;
 }

-Instruction *llvm::FindEarliestCapture(const Value *V, Function &F,
-                                       bool ReturnCaptures, bool StoreCaptures,
-                                       const DominatorTree &DT,
-                                       unsigned MaxUsesToExplore) {
+Instruction *
+llvm::FindEarliestCapture(const Value *V, Function &F, bool ReturnCaptures,
+                          bool StoreCaptures, const DominatorTree &DT,
+
+                          const SmallPtrSetImpl<const Value *> &EphValues,
+                          unsigned MaxUsesToExplore) {
   assert(!isa<GlobalValue>(V) &&
          "It doesn't make sense to ask whether a global is captured.");

-  EarliestCaptures CB(ReturnCaptures, F, DT);
+  EarliestCaptures CB(ReturnCaptures, F, DT, EphValues);
   PointerMayBeCaptured(V, &CB, MaxUsesToExplore);
   if (CB.Captured)
     ++NumCapturedBefore;
@@ -282,6 +308,132 @@ Instruction *llvm::FindEarliestCapture(const Value *V, Function &F,
   return CB.EarliestCapture;
 }

+UseCaptureKind llvm::DetermineUseCaptureKind(
+    const Use &U,
+    function_ref<bool(Value *, const DataLayout &)> IsDereferenceableOrNull) {
+  Instruction *I = cast<Instruction>(U.getUser());
+
+  switch (I->getOpcode()) {
+  case Instruction::Call:
+  case Instruction::Invoke: {
+    auto *Call = cast<CallBase>(I);
+    // Not captured if the callee is readonly, doesn't return a copy through
+    // its return value and doesn't unwind (a readonly function can leak bits
+    // by throwing an exception or not depending on the input value).
+    if (Call->onlyReadsMemory() && Call->doesNotThrow() &&
+        Call->getType()->isVoidTy())
+      return UseCaptureKind::NO_CAPTURE;
+
+    // The pointer is not captured if returned pointer is not captured.
+    // NOTE: CaptureTracking users should not assume that only functions
+    // marked with nocapture do not capture. This means that places like
+    // getUnderlyingObject in ValueTracking or DecomposeGEPExpression
+    // in BasicAA also need to know about this property.
+    if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(Call, true))
+      return UseCaptureKind::PASSTHROUGH;
+
+    // Volatile operations effectively capture the memory location that they
+    // load and store to.
+    if (auto *MI = dyn_cast<MemIntrinsic>(Call))
+      if (MI->isVolatile())
+        return UseCaptureKind::MAY_CAPTURE;
+
+    // Calling a function pointer does not in itself cause the pointer to
+    // be captured. This is a subtle point considering that (for example)
+    // the callee might return its own address. It is analogous to saying
+    // that loading a value from a pointer does not cause the pointer to be
+    // captured, even though the loaded value might be the pointer itself
+    // (think of self-referential objects).
+    if (Call->isCallee(&U))
+      return UseCaptureKind::NO_CAPTURE;
+
+    // Not captured if only passed via 'nocapture' arguments.
+    if (Call->isDataOperand(&U) &&
+        !Call->doesNotCapture(Call->getDataOperandNo(&U))) {
+      // The parameter is not marked 'nocapture' - captured.
+      return UseCaptureKind::MAY_CAPTURE;
+    }
+    return UseCaptureKind::NO_CAPTURE;
+  }
+  case Instruction::Load:
+    // Volatile loads make the address observable.
+    if (cast<LoadInst>(I)->isVolatile())
+      return UseCaptureKind::MAY_CAPTURE;
+    return UseCaptureKind::NO_CAPTURE;
+  case Instruction::VAArg:
+    // "va-arg" from a pointer does not cause it to be captured.
+    return UseCaptureKind::NO_CAPTURE;
+  case Instruction::Store:
+    // Stored the pointer - conservatively assume it may be captured.
+    // Volatile stores make the address observable.
+    if (U.getOperandNo() == 0 || cast<StoreInst>(I)->isVolatile())
+      return UseCaptureKind::MAY_CAPTURE;
+    return UseCaptureKind::NO_CAPTURE;
+  case Instruction::AtomicRMW: {
+    // atomicrmw conceptually includes both a load and store from
+    // the same location.
+    // As with a store, the location being accessed is not captured,
+    // but the value being stored is.
+    // Volatile stores make the address observable.
+    auto *ARMWI = cast<AtomicRMWInst>(I);
+    if (U.getOperandNo() == 1 || ARMWI->isVolatile())
+      return UseCaptureKind::MAY_CAPTURE;
+    return UseCaptureKind::NO_CAPTURE;
+  }
+  case Instruction::AtomicCmpXchg: {
+    // cmpxchg conceptually includes both a load and store from
+    // the same location.
+    // As with a store, the location being accessed is not captured,
+    // but the value being stored is.
+    // Volatile stores make the address observable.
+    auto *ACXI = cast<AtomicCmpXchgInst>(I);
+    if (U.getOperandNo() == 1 || U.getOperandNo() == 2 || ACXI->isVolatile())
+      return UseCaptureKind::MAY_CAPTURE;
+    return UseCaptureKind::NO_CAPTURE;
+  }
+  case Instruction::BitCast:
+  case Instruction::GetElementPtr:
+  case Instruction::PHI:
+  case Instruction::Select:
+  case Instruction::AddrSpaceCast:
+    // The original value is not captured via this if the new value isn't.
+    return UseCaptureKind::PASSTHROUGH;
+  case Instruction::ICmp: {
+    unsigned Idx = U.getOperandNo();
+    unsigned OtherIdx = 1 - Idx;
+    if (auto *CPN = dyn_cast<ConstantPointerNull>(I->getOperand(OtherIdx))) {
+      // Don't count comparisons of a no-alias return value against null as
+      // captures. This allows us to ignore comparisons of malloc results
+      // with null, for example.
+      if (CPN->getType()->getAddressSpace() == 0)
+        if (isNoAliasCall(U.get()->stripPointerCasts()))
+          return UseCaptureKind::NO_CAPTURE;
+      if (!I->getFunction()->nullPointerIsDefined()) {
+        auto *O = I->getOperand(Idx)->stripPointerCastsSameRepresentation();
+        // Comparing a dereferenceable_or_null pointer against null cannot
+        // lead to pointer escapes, because if it is not null it must be a
+        // valid (in-bounds) pointer.
+        const DataLayout &DL = I->getModule()->getDataLayout();
+        if (IsDereferenceableOrNull && IsDereferenceableOrNull(O, DL))
+          return UseCaptureKind::NO_CAPTURE;
+      }
+    }
+    // Comparison against value stored in global variable. Given the pointer
+    // does not escape, its value cannot be guessed and stored separately in a
+    // global variable.
+    auto *LI = dyn_cast<LoadInst>(I->getOperand(OtherIdx));
+    if (LI && isa<GlobalVariable>(LI->getPointerOperand()))
+      return UseCaptureKind::NO_CAPTURE;
+    // Otherwise, be conservative. There are crazy ways to capture pointers
+    // using comparisons.
+    return UseCaptureKind::MAY_CAPTURE;
+  }
+  default:
+    // Something else - be conservative and say it is captured.
+    return UseCaptureKind::MAY_CAPTURE;
+  }
+}
+
 void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
                                 unsigned MaxUsesToExplore) {
   assert(V->getType()->isPointerTy() && "Capture is for pointers only!");
@@ -293,11 +445,10 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
   SmallSet<const Use *, 20> Visited;

   auto AddUses = [&](const Value *V) {
-    unsigned Count = 0;
     for (const Use &U : V->uses()) {
       // If there are lots of uses, conservatively say that the value
       // is captured to avoid taking too much compile time.
-      if (Count++ >= MaxUsesToExplore) {
+      if (Visited.size() >= MaxUsesToExplore) {
         Tracker->tooManyUses();
         return false;
       }
@@ -312,144 +463,22 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
   if (!AddUses(V))
     return;

+  auto IsDereferenceableOrNull = [Tracker](Value *V, const DataLayout &DL) {
+    return Tracker->isDereferenceableOrNull(V, DL);
+  };
   while (!Worklist.empty()) {
     const Use *U = Worklist.pop_back_val();
-    Instruction *I = cast<Instruction>(U->getUser());
-
-    switch (I->getOpcode()) {
-    case Instruction::Call:
-    case Instruction::Invoke: {
-      auto *Call = cast<CallBase>(I);
-      // Not captured if the callee is readonly, doesn't return a copy through
-      // its return value and doesn't unwind (a readonly function can leak bits
-      // by throwing an exception or not depending on the input value).
-      if (Call->onlyReadsMemory() && Call->doesNotThrow() &&
-          Call->getType()->isVoidTy())
-        break;
-
-      // The pointer is not captured if returned pointer is not captured.
-      // NOTE: CaptureTracking users should not assume that only functions
-      // marked with nocapture do not capture. This means that places like
-      // getUnderlyingObject in ValueTracking or DecomposeGEPExpression
-      // in BasicAA also need to know about this property.
-      if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(Call,
-                                                                      true)) {
-        if (!AddUses(Call))
-          return;
-        break;
-      }
-
-      // Volatile operations effectively capture the memory location that they
-      // load and store to.
-      if (auto *MI = dyn_cast<MemIntrinsic>(Call))
-        if (MI->isVolatile())
-          if (Tracker->captured(U))
-            return;
-
-      // Calling a function pointer does not in itself cause the pointer to
-      // be captured. This is a subtle point considering that (for example)
-      // the callee might return its own address. It is analogous to saying
-      // that loading a value from a pointer does not cause the pointer to be
-      // captured, even though the loaded value might be the pointer itself
-      // (think of self-referential objects).
-      if (Call->isCallee(U))
-        break;
-
-      // Not captured if only passed via 'nocapture' arguments.
-      if (Call->isDataOperand(U) &&
-          !Call->doesNotCapture(Call->getDataOperandNo(U))) {
-        // The parameter is not marked 'nocapture' - captured.
-        if (Tracker->captured(U))
-          return;
-      }
-      break;
-    }
-    case Instruction::Load:
-      // Volatile loads make the address observable.
-      if (cast<LoadInst>(I)->isVolatile())
-        if (Tracker->captured(U))
-          return;
-      break;
-    case Instruction::VAArg:
-      // "va-arg" from a pointer does not cause it to be captured.
-      break;
-    case Instruction::Store:
-      // Stored the pointer - conservatively assume it may be captured.
-      // Volatile stores make the address observable.
-      if (U->getOperandNo() == 0 || cast<StoreInst>(I)->isVolatile())
-        if (Tracker->captured(U))
-          return;
-      break;
-    case Instruction::AtomicRMW: {
-      // atomicrmw conceptually includes both a load and store from
-      // the same location.
-      // As with a store, the location being accessed is not captured,
-      // but the value being stored is.
-      // Volatile stores make the address observable.
-      auto *ARMWI = cast<AtomicRMWInst>(I);
-      if (U->getOperandNo() == 1 || ARMWI->isVolatile())
-        if (Tracker->captured(U))
-          return;
-      break;
-    }
-    case Instruction::AtomicCmpXchg: {
-      // cmpxchg conceptually includes both a load and store from
-      // the same location.
-      // As with a store, the location being accessed is not captured,
-      // but the value being stored is.
-      // Volatile stores make the address observable.
-      auto *ACXI = cast<AtomicCmpXchgInst>(I);
-      if (U->getOperandNo() == 1 || U->getOperandNo() == 2 ||
-          ACXI->isVolatile())
-        if (Tracker->captured(U))
-          return;
-      break;
-    }
-    case Instruction::BitCast:
-    case Instruction::GetElementPtr:
-    case Instruction::PHI:
-    case Instruction::Select:
-    case Instruction::AddrSpaceCast:
-      // The original value is not captured via this if the new value isn't.
-      if (!AddUses(I))
-        return;
-      break;
-    case Instruction::ICmp: {
-      unsigned Idx = U->getOperandNo();
-      unsigned OtherIdx = 1 - Idx;
-      if (auto *CPN = dyn_cast<ConstantPointerNull>(I->getOperand(OtherIdx))) {
-        // Don't count comparisons of a no-alias return value against null as
-        // captures. This allows us to ignore comparisons of malloc results
-        // with null, for example.
-        if (CPN->getType()->getAddressSpace() == 0)
-          if (isNoAliasCall(U->get()->stripPointerCasts()))
-            break;
-        if (!I->getFunction()->nullPointerIsDefined()) {
-          auto *O = I->getOperand(Idx)->stripPointerCastsSameRepresentation();
-          // Comparing a dereferenceable_or_null pointer against null cannot
-          // lead to pointer escapes, because if it is not null it must be a
-          // valid (in-bounds) pointer.
-          if (Tracker->isDereferenceableOrNull(O, I->getModule()->getDataLayout()))
-            break;
-        }
-      }
-      // Comparison against value stored in global variable. Given the pointer
-      // does not escape, its value cannot be guessed and stored separately in a
-      // global variable.
-      auto *LI = dyn_cast<LoadInst>(I->getOperand(OtherIdx));
-      if (LI && isa<GlobalVariable>(LI->getPointerOperand()))
-        break;
-      // Otherwise, be conservative. There are crazy ways to capture pointers
-      // using comparisons.
+    switch (DetermineUseCaptureKind(*U, IsDereferenceableOrNull)) {
+    case UseCaptureKind::NO_CAPTURE:
+      continue;
+    case UseCaptureKind::MAY_CAPTURE:
       if (Tracker->captured(U))
         return;
-      break;
-    }
-    default:
-      // Something else - be conservative and say it is captured.
-      if (Tracker->captured(U))
+      continue;
+    case UseCaptureKind::PASSTHROUGH:
+      if (!AddUses(U->getUser()))
         return;
-      break;
+      continue;
     }
   }
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 5b951980a0aa..20b1df6e1495 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -18,9 +18,7 @@
 using namespace llvm;

-unsigned llvm::getICmpCode(const ICmpInst *ICI, bool InvertPred) {
-  ICmpInst::Predicate Pred = InvertPred ? ICI->getInversePredicate()
-                                        : ICI->getPredicate();
+unsigned llvm::getICmpCode(CmpInst::Predicate Pred) {
   switch (Pred) {
       // False -> 0
   case ICmpInst::ICMP_UGT: return 1;  // 001
@@ -63,6 +61,18 @@ bool llvm::predicatesFoldable(ICmpInst::Predicate P1, ICmpInst::Predicate P2) {
          (CmpInst::isSigned(P2) && ICmpInst::isEquality(P1));
 }

+Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
+                                   CmpInst::Predicate &Pred) {
+  Pred = static_cast<FCmpInst::Predicate>(Code);
+  assert(FCmpInst::FCMP_FALSE <= Pred && Pred <= FCmpInst::FCMP_TRUE &&
+         "Unexpected FCmp predicate!");
+  if (Pred == FCmpInst::FCMP_FALSE)
+    return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 0);
+  if (Pred == FCmpInst::FCMP_TRUE)
+    return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 1);
+  return nullptr;
+}
+
 bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
                                 CmpInst::Predicate &Pred,
                                 Value *&X, APInt &Mask,
                                 bool LookThruTrunc) {
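Note: the large CaptureTracking.cpp hunk above factors the per-use classification out of PointerMayBeCaptured into the new DetermineUseCaptureKind, which returns NO_CAPTURE, MAY_CAPTURE, or PASSTHROUGH. A sketch of how a caller might consume it on a single use; the dereferenceability callback here is a trivial stand-in (a real caller delegates to its CaptureTracker, as the new driver loop does):

```cpp
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/ErrorHandling.h"

using namespace llvm;

// Classify one use of a pointer. PASSTHROUGH means the user merely forwards
// the pointer (gep/bitcast/phi/select), so the caller must recurse into the
// user's own uses before concluding anything.
static bool useMayCapture(const Use &U) {
  auto IsDerefOrNull = [](Value *, const DataLayout &) {
    return false; // stand-in; a real tracker can prove more comparisons safe
  };
  switch (DetermineUseCaptureKind(U, IsDerefOrNull)) {
  case UseCaptureKind::NO_CAPTURE:
    return false;
  case UseCaptureKind::MAY_CAPTURE:
    return true;
  case UseCaptureKind::PASSTHROUGH:
    return false; // not a capture here, but visit U.getUser()'s uses next
  }
  llvm_unreachable("covered switch");
}
```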
@@ -184,8 +176,7 @@ void CodeMetrics::analyzeBasicBlock(
       if (InvI->cannotDuplicate())
         notDuplicatable = true;
-    NumInstsProxy += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize);
-    NumInsts = *NumInstsProxy.getValue();
+    NumInsts += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize);
   }
   if (isa<ReturnInst>(BB->getTerminator()))
@@ -205,6 +196,6 @@ void CodeMetrics::analyzeBasicBlock(
   notDuplicatable |= isa<IndirectBrInst>(BB->getTerminator());
   // Remember NumInsts for this BB.
-  InstructionCost NumInstsThisBB = NumInstsProxy - NumInstsBeforeThisBB;
-  NumBBInsts[BB] = *NumInstsThisBB.getValue();
+  InstructionCost NumInstsThisBB = NumInsts - NumInstsBeforeThisBB;
+  NumBBInsts[BB] = NumInstsThisBB;
 }
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 7cf69f613c66..a81041845052 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -57,7 +57,6 @@
 #include <cerrno>
 #include <cfenv>
 #include <cmath>
-#include <cstddef>
 #include <cstdint>
 using namespace llvm;
@@ -92,7 +91,7 @@ static Constant *foldConstVectorToAPInt(APInt &Result, Type *DestTy,
       return ConstantExpr::getBitCast(C, DestTy);
     Result <<= BitShift;
-    Result |= ElementCI->getValue().zextOrSelf(Result.getBitWidth());
+    Result |= ElementCI->getValue().zext(Result.getBitWidth());
   }
   return nullptr;
@@ -589,14 +588,17 @@ Constant *FoldReinterpretLoadFromConst(Constant *C, Type *LoadTy,
   if (BytesLoaded > 32 || BytesLoaded == 0)
     return nullptr;
-  int64_t InitializerSize = DL.getTypeAllocSize(C->getType()).getFixedSize();
-
   // If we're not accessing anything in this constant, the result is undefined.
   if (Offset <= -1 * static_cast<int64_t>(BytesLoaded))
     return UndefValue::get(IntType);
+  // TODO: We should be able to support scalable types.
+  TypeSize InitializerSize = DL.getTypeAllocSize(C->getType());
+  if (InitializerSize.isScalable())
+    return nullptr;
+
   // If we're not accessing anything in this constant, the result is undefined.
-  if (Offset >= InitializerSize)
+  if (Offset >= (int64_t)InitializerSize.getFixedValue())
     return UndefValue::get(IntType);
   unsigned char RawBytes[32] = {0};
@@ -631,6 +633,39 @@ Constant *FoldReinterpretLoadFromConst(Constant *C, Type *LoadTy,
   return ConstantInt::get(IntType->getContext(), ResultVal);
 }
+} // anonymous namespace
+
+// If GV is a constant with an initializer, read its representation starting
+// at Offset and return it as a constant array of unsigned char. Otherwise
+// return null.
+Constant *llvm::ReadByteArrayFromGlobal(const GlobalVariable *GV,
+                                        uint64_t Offset) {
+  if (!GV->isConstant() || !GV->hasDefinitiveInitializer())
+    return nullptr;
+
+  const DataLayout &DL = GV->getParent()->getDataLayout();
+  Constant *Init = const_cast<Constant *>(GV->getInitializer());
+  TypeSize InitSize = DL.getTypeAllocSize(Init->getType());
+  if (InitSize < Offset)
+    return nullptr;
+
+  uint64_t NBytes = InitSize - Offset;
+  if (NBytes > UINT16_MAX)
+    // Bail for large initializers in excess of 64K to avoid allocating
+    // too much memory.
+    // Offset is assumed to be less than or equal to InitSize (this
+    // is enforced in ReadDataFromGlobal).
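// (Editor's worked example of the arithmetic above, with assumed numbers:
// InitSize == 100 and Offset == 40 give NBytes == 60, well under the
// UINT16_MAX cap, so the read below proceeds; a 1 MiB initializer with
// Offset == 0 gives NBytes > UINT16_MAX and bails out here instead.)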
+ return nullptr; + + SmallVector<unsigned char, 256> RawBytes(static_cast<size_t>(NBytes)); + unsigned char *CurPtr = RawBytes.data(); + + if (!ReadDataFromGlobal(Init, Offset, CurPtr, NBytes, DL)) + return nullptr; + + return ConstantDataArray::get(GV->getContext(), RawBytes); +} + /// If this Offset points exactly to the start of an aggregate element, return /// that element, otherwise return nullptr. Constant *getConstantAtOffset(Constant *Base, APInt Offset, @@ -659,8 +694,6 @@ Constant *getConstantAtOffset(Constant *Base, APInt Offset, return C; } -} // end anonymous namespace - Constant *llvm::ConstantFoldLoadFromConst(Constant *C, Type *Ty, const APInt &Offset, const DataLayout &DL) { @@ -864,21 +897,6 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP, Type *IntIdxTy = DL.getIndexType(Ptr->getType()); - // If this is "gep i8* Ptr, (sub 0, V)", fold this as: - // "inttoptr (sub (ptrtoint Ptr), V)" - if (Ops.size() == 2 && ResElemTy->isIntegerTy(8)) { - auto *CE = dyn_cast<ConstantExpr>(Ops[1]); - assert((!CE || CE->getType() == IntIdxTy) && - "CastGEPIndices didn't canonicalize index types!"); - if (CE && CE->getOpcode() == Instruction::Sub && - CE->getOperand(0)->isNullValue()) { - Constant *Res = ConstantExpr::getPtrToInt(Ptr, CE->getType()); - Res = ConstantExpr::getSub(Res, CE->getOperand(1)); - Res = ConstantExpr::getIntToPtr(Res, ResTy); - return ConstantFoldConstant(Res, DL, TLI); - } - } - for (unsigned i = 1, e = Ops.size(); i != e; ++i) if (!isa<ConstantInt>(Ops[i])) return nullptr; @@ -1012,8 +1030,24 @@ Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode, if (Instruction::isUnaryOp(Opcode)) return ConstantFoldUnaryOpOperand(Opcode, Ops[0], DL); - if (Instruction::isBinaryOp(Opcode)) + if (Instruction::isBinaryOp(Opcode)) { + switch (Opcode) { + default: + break; + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + // Handle floating point instructions separately to account for denormals + // TODO: If a constant expression is being folded rather than an + // instruction, denormals will not be flushed/treated as zero + if (const auto *I = dyn_cast<Instruction>(InstOrCE)) { + return ConstantFoldFPInstOperands(Opcode, Ops[0], Ops[1], DL, I); + } + } return ConstantFoldBinaryOpOperands(Opcode, Ops[0], Ops[1], DL); + } if (Instruction::isCast(Opcode)) return ConstantFoldCastOperand(Opcode, Ops[0], DestTy, DL); @@ -1027,13 +1061,21 @@ Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode, GEP->getInRangeIndex()); } - if (auto *CE = dyn_cast<ConstantExpr>(InstOrCE)) + if (auto *CE = dyn_cast<ConstantExpr>(InstOrCE)) { + if (CE->isCompare()) + return ConstantFoldCompareInstOperands(CE->getPredicate(), Ops[0], Ops[1], + DL, TLI); return CE->getWithOperands(Ops); + } switch (Opcode) { default: return nullptr; case Instruction::ICmp: - case Instruction::FCmp: llvm_unreachable("Invalid for compares"); + case Instruction::FCmp: { + auto *C = cast<CmpInst>(InstOrCE); + return ConstantFoldCompareInstOperands(C->getPredicate(), Ops[0], Ops[1], + DL, TLI, C); + } case Instruction::Freeze: return isGuaranteedNotToBeUndefOrPoison(Ops[0]) ? 
Ops[0] : nullptr; case Instruction::Call: @@ -1048,13 +1090,22 @@ Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode, case Instruction::ExtractElement: return ConstantExpr::getExtractElement(Ops[0], Ops[1]); case Instruction::ExtractValue: - return ConstantExpr::getExtractValue( + return ConstantFoldExtractValueInstruction( Ops[0], cast<ExtractValueInst>(InstOrCE)->getIndices()); case Instruction::InsertElement: return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]); + case Instruction::InsertValue: + return ConstantFoldInsertValueInstruction( + Ops[0], Ops[1], cast<InsertValueInst>(InstOrCE)->getIndices()); case Instruction::ShuffleVector: return ConstantExpr::getShuffleVector( Ops[0], Ops[1], cast<ShuffleVectorInst>(InstOrCE)->getShuffleMask()); + case Instruction::Load: { + const auto *LI = dyn_cast<LoadInst>(InstOrCE); + if (LI->isVolatile()) + return nullptr; + return ConstantFoldLoadFromConstPtr(Ops[0], LI->getType(), DL); + } } } @@ -1091,13 +1142,8 @@ ConstantFoldConstantImpl(const Constant *C, const DataLayout &DL, Ops.push_back(NewC); } - if (auto *CE = dyn_cast<ConstantExpr>(C)) { - if (CE->isCompare()) - return ConstantFoldCompareInstOperands(CE->getPredicate(), Ops[0], Ops[1], - DL, TLI); - + if (auto *CE = dyn_cast<ConstantExpr>(C)) return ConstantFoldInstOperandsImpl(CE, CE->getOpcode(), Ops, DL, TLI); - } assert(isa<ConstantVector>(C)); return ConstantVector::get(Ops); @@ -1150,22 +1196,6 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, const DataLayout &DL, Ops.push_back(Op); } - if (const auto *CI = dyn_cast<CmpInst>(I)) - return ConstantFoldCompareInstOperands(CI->getPredicate(), Ops[0], Ops[1], - DL, TLI); - - if (const auto *LI = dyn_cast<LoadInst>(I)) { - if (LI->isVolatile()) - return nullptr; - return ConstantFoldLoadFromConstPtr(Ops[0], LI->getType(), DL); - } - - if (auto *IVI = dyn_cast<InsertValueInst>(I)) - return ConstantExpr::getInsertValue(Ops[0], Ops[1], IVI->getIndices()); - - if (auto *EVI = dyn_cast<ExtractValueInst>(I)) - return ConstantExpr::getExtractValue(Ops[0], EVI->getIndices()); - return ConstantFoldInstOperands(I, Ops, DL, TLI); } @@ -1182,10 +1212,9 @@ Constant *llvm::ConstantFoldInstOperands(Instruction *I, return ConstantFoldInstOperandsImpl(I, I->getOpcode(), Ops, DL, TLI); } -Constant *llvm::ConstantFoldCompareInstOperands(unsigned IntPredicate, - Constant *Ops0, Constant *Ops1, - const DataLayout &DL, - const TargetLibraryInfo *TLI) { +Constant *llvm::ConstantFoldCompareInstOperands( + unsigned IntPredicate, Constant *Ops0, Constant *Ops1, const DataLayout &DL, + const TargetLibraryInfo *TLI, const Instruction *I) { CmpInst::Predicate Predicate = (CmpInst::Predicate)IntPredicate; // fold: icmp (inttoptr x), null -> icmp x, 0 // fold: icmp null, (inttoptr x) -> icmp 0, x @@ -1287,6 +1316,11 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned IntPredicate, return ConstantFoldCompareInstOperands(Predicate, Ops1, Ops0, DL, TLI); } + // Flush any denormal constant float input according to denormal handling + // mode. 
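// (Editor's sketch of the effect, with an assumed attribute value: in a
// function carrying "denormal-fp-math"="preserve-sign,preserve-sign", a
// denormal ConstantFP operand is replaced by a zero of the same sign before
// the compare is folded, so `fcmp oeq` of a positive denormal against 0.0
// folds to true, matching the runtime flushing behaviour.)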
+  Ops0 = FlushFPConstant(Ops0, I, /* IsOutput */ false);
+  Ops1 = FlushFPConstant(Ops1, I, /* IsOutput */ false);
+
   return ConstantExpr::getCompare(Predicate, Ops0, Ops1);
 }
@@ -1308,6 +1342,63 @@ Constant *llvm::ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS,
   return ConstantExpr::get(Opcode, LHS, RHS);
 }
+Constant *llvm::FlushFPConstant(Constant *Operand, const Instruction *I,
+                                bool IsOutput) {
+  if (!I || !I->getParent() || !I->getFunction())
+    return Operand;
+
+  ConstantFP *CFP = dyn_cast<ConstantFP>(Operand);
+  if (!CFP)
+    return Operand;
+
+  const APFloat &APF = CFP->getValueAPF();
+  Type *Ty = CFP->getType();
+  DenormalMode DenormMode =
+      I->getFunction()->getDenormalMode(Ty->getFltSemantics());
+  DenormalMode::DenormalModeKind Mode =
+      IsOutput ? DenormMode.Output : DenormMode.Input;
+  switch (Mode) {
+  default:
+    llvm_unreachable("unknown denormal mode");
+    return Operand;
+  case DenormalMode::IEEE:
+    return Operand;
+  case DenormalMode::PreserveSign:
+    if (APF.isDenormal()) {
+      return ConstantFP::get(
+          Ty->getContext(),
+          APFloat::getZero(Ty->getFltSemantics(), APF.isNegative()));
+    }
+    return Operand;
+  case DenormalMode::PositiveZero:
+    if (APF.isDenormal()) {
+      return ConstantFP::get(Ty->getContext(),
+                             APFloat::getZero(Ty->getFltSemantics(), false));
+    }
+    return Operand;
+  }
+  return Operand;
+}
+
+Constant *llvm::ConstantFoldFPInstOperands(unsigned Opcode, Constant *LHS,
+                                           Constant *RHS, const DataLayout &DL,
+                                           const Instruction *I) {
+  if (Instruction::isBinaryOp(Opcode)) {
+    // Flush denormal inputs if needed.
+    Constant *Op0 = FlushFPConstant(LHS, I, /* IsOutput */ false);
+    Constant *Op1 = FlushFPConstant(RHS, I, /* IsOutput */ false);
+
+    // Calculate constant result.
+    Constant *C = ConstantFoldBinaryOpOperands(Opcode, Op0, Op1, DL);
+
+    // Flush denormal output if needed.
+    return FlushFPConstant(C, I, /* IsOutput */ true);
+  }
+  // If the instruction lacks a parent/function and the denormal mode cannot
+  // be determined, use the default (IEEE).
+  return ConstantFoldBinaryOpOperands(Opcode, LHS, RHS, DL);
+}
+
 Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
                                         Type *DestTy, const DataLayout &DL) {
   assert(Instruction::isCast(Opcode));
@@ -1334,6 +1425,19 @@ Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
             DL, BaseOffset, /*AllowNonInbounds=*/true));
     if (Base->isNullValue()) {
       FoldedValue = ConstantInt::get(CE->getContext(), BaseOffset);
+    } else {
+      // ptrtoint (gep i8, Ptr, (sub 0, V)) -> sub (ptrtoint Ptr), V
+      if (GEP->getNumIndices() == 1 &&
+          GEP->getSourceElementType()->isIntegerTy(8)) {
+        auto *Ptr = cast<Constant>(GEP->getPointerOperand());
+        auto *Sub = dyn_cast<ConstantExpr>(GEP->getOperand(1));
+        Type *IntIdxTy = DL.getIndexType(Ptr->getType());
+        if (Sub && Sub->getType() == IntIdxTy &&
+            Sub->getOpcode() == Instruction::Sub &&
+            Sub->getOperand(0)->isNullValue())
+          FoldedValue = ConstantExpr::getSub(
+              ConstantExpr::getPtrToInt(Ptr, IntIdxTy), Sub->getOperand(1));
+      }
     }
   }
   if (FoldedValue) {
@@ -1386,6 +1490,8 @@ Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
 bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   if (Call->isNoBuiltin())
     return false;
+  if (Call->getFunctionType() != F->getFunctionType())
+    return false;
   switch (F->getIntrinsicID()) {
   // Operations that do not operate on floating-point numbers and do not
   // depend on FP environment can be folded even in strictfp functions.
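// Editor's usage sketch for the new ConstantFoldFPInstOperands entry point
// added above; the helper below is an assumption for illustration, not code
// from this commit:
//
//   // Fold a constant fadd while honouring the parent function's
//   // denormal-fp-math attribute; I == nullptr falls back to IEEE.
//   Constant *foldConstantFAdd(Instruction *I, const DataLayout &DL) {
//     auto *LHS = cast<Constant>(I->getOperand(0));
//     auto *RHS = cast<Constant>(I->getOperand(1));
//     return ConstantFoldFPInstOperands(Instruction::FAdd, LHS, RHS, DL, I);
//   }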
@@ -1527,6 +1633,8 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) { case Intrinsic::experimental_constrained_trunc: case Intrinsic::experimental_constrained_nearbyint: case Intrinsic::experimental_constrained_rint: + case Intrinsic::experimental_constrained_fcmp: + case Intrinsic::experimental_constrained_fcmps: return true; default: return false; @@ -1798,12 +1906,12 @@ static bool mayFoldConstrained(ConstrainedFPIntrinsic *CI, // If evaluation raised FP exception, the result can depend on rounding // mode. If the latter is unknown, folding is not possible. - if (!ORM || *ORM == RoundingMode::Dynamic) + if (ORM && *ORM == RoundingMode::Dynamic) return false; // If FP exceptions are ignored, fold the call, even if such exception is // raised. - if (!EB || *EB != fp::ExceptionBehavior::ebStrict) + if (EB && *EB != fp::ExceptionBehavior::ebStrict) return true; // Leave the calculation for runtime so that exception flags be correctly set @@ -1979,7 +2087,7 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, case Intrinsic::experimental_constrained_rint: { auto CI = cast<ConstrainedFPIntrinsic>(Call); RM = CI->getRoundingMode(); - if (!RM || RM.getValue() == RoundingMode::Dynamic) + if (!RM || *RM == RoundingMode::Dynamic) return nullptr; break; } @@ -2301,6 +2409,24 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, return nullptr; } +static Constant *evaluateCompare(const APFloat &Op1, const APFloat &Op2, + const ConstrainedFPIntrinsic *Call) { + APFloat::opStatus St = APFloat::opOK; + auto *FCmp = cast<ConstrainedFPCmpIntrinsic>(Call); + FCmpInst::Predicate Cond = FCmp->getPredicate(); + if (FCmp->isSignaling()) { + if (Op1.isNaN() || Op2.isNaN()) + St = APFloat::opInvalidOp; + } else { + if (Op1.isSignaling() || Op2.isSignaling()) + St = APFloat::opInvalidOp; + } + bool Result = FCmpInst::compare(Op1, Op2, Cond); + if (mayFoldConstrained(const_cast<ConstrainedFPCmpIntrinsic *>(FCmp), St)) + return ConstantInt::get(Call->getType()->getScalarType(), Result); + return nullptr; +} + static Constant *ConstantFoldScalarCall2(StringRef Name, Intrinsic::ID IntrinsicID, Type *Ty, @@ -2329,8 +2455,6 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, } if (const auto *Op1 = dyn_cast<ConstantFP>(Operands[0])) { - if (!Ty->isFloatingPointTy()) - return nullptr; const APFloat &Op1V = Op1->getValueAPF(); if (const auto *Op2 = dyn_cast<ConstantFP>(Operands[1])) { @@ -2360,6 +2484,9 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, case Intrinsic::experimental_constrained_frem: St = Res.mod(Op2V); break; + case Intrinsic::experimental_constrained_fcmp: + case Intrinsic::experimental_constrained_fcmps: + return evaluateCompare(Op1V, Op2V, ConstrIntr); } if (mayFoldConstrained(const_cast<ConstrainedFPIntrinsic *>(ConstrIntr), St)) @@ -2484,6 +2611,11 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, case Intrinsic::smin: case Intrinsic::umax: case Intrinsic::umin: + // This is the same as for binary ops - poison propagates. + // TODO: Poison handling should be consolidated. + if (isa<PoisonValue>(Operands[0]) || isa<PoisonValue>(Operands[1])) + return PoisonValue::get(Ty); + if (!C0 && !C1) return UndefValue::get(Ty); if (!C0 || !C1) @@ -2550,6 +2682,11 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, } case Intrinsic::uadd_sat: case Intrinsic::sadd_sat: + // This is the same as for binary ops - poison propagates. + // TODO: Poison handling should be consolidated. 
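// (Editor's illustration of the checks added below: smax(poison, 42) now
// folds to poison, while smax(undef, undef) still folds to undef through the
// pre-existing !C0 && !C1 path.)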
+ if (isa<PoisonValue>(Operands[0]) || isa<PoisonValue>(Operands[1])) + return PoisonValue::get(Ty); + if (!C0 && !C1) return UndefValue::get(Ty); if (!C0 || !C1) @@ -2560,6 +2697,11 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, return ConstantInt::get(Ty, C0->sadd_sat(*C1)); case Intrinsic::usub_sat: case Intrinsic::ssub_sat: + // This is the same as for binary ops - poison propagates. + // TODO: Poison handling should be consolidated. + if (isa<PoisonValue>(Operands[0]) || isa<PoisonValue>(Operands[1])) + return PoisonValue::get(Ty); + if (!C0 && !C1) return UndefValue::get(Ty); if (!C0 || !C1) @@ -2840,11 +2982,11 @@ static Constant *ConstantFoldScalarCall3(StringRef Name, unsigned Width = C0->getBitWidth(); assert(Scale < Width && "Illegal scale."); unsigned ExtendedWidth = Width * 2; - APInt Product = (C0->sextOrSelf(ExtendedWidth) * - C1->sextOrSelf(ExtendedWidth)).ashr(Scale); + APInt Product = + (C0->sext(ExtendedWidth) * C1->sext(ExtendedWidth)).ashr(Scale); if (IntrinsicID == Intrinsic::smul_fix_sat) { - APInt Max = APInt::getSignedMaxValue(Width).sextOrSelf(ExtendedWidth); - APInt Min = APInt::getSignedMinValue(Width).sextOrSelf(ExtendedWidth); + APInt Max = APInt::getSignedMaxValue(Width).sext(ExtendedWidth); + APInt Min = APInt::getSignedMinValue(Width).sext(ExtendedWidth); Product = APIntOps::smin(Product, Max); Product = APIntOps::smax(Product, Min); } @@ -2998,7 +3140,7 @@ static Constant *ConstantFoldFixedVectorCall( // Gather a column of constants. for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) { // Some intrinsics use a scalar type for certain arguments. - if (hasVectorInstrinsicScalarOpd(IntrinsicID, J)) { + if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J)) { Lane[J] = Operands[J]; continue; } diff --git a/llvm/lib/Analysis/ConstraintSystem.cpp b/llvm/lib/Analysis/ConstraintSystem.cpp index 773f71ada0ee..dc774728ab3d 100644 --- a/llvm/lib/Analysis/ConstraintSystem.cpp +++ b/llvm/lib/Analysis/ConstraintSystem.cpp @@ -12,7 +12,6 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" -#include <algorithm> #include <string> using namespace llvm; diff --git a/llvm/lib/Analysis/CostModel.cpp b/llvm/lib/Analysis/CostModel.cpp index 326bacad01fe..52e424ae324b 100644 --- a/llvm/lib/Analysis/CostModel.cpp +++ b/llvm/lib/Analysis/CostModel.cpp @@ -17,7 +17,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/CostModel.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/Passes.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Function.h" @@ -25,7 +24,6 @@ #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -119,7 +117,7 @@ void CostModelAnalysis::print(raw_ostream &OS, const Module*) const { PreservedAnalyses CostModelPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TTI = AM.getResult<TargetIRAnalysis>(F); - OS << "Cost Model for function '" << F.getName() << "'\n"; + OS << "Printing analysis 'Cost Model Analysis' for function '" << F.getName() << "':\n"; for (BasicBlock &B : F) { for (Instruction &Inst : B) { // TODO: Use a pass parameter instead of cl::opt CostKind to determine diff --git a/llvm/lib/Analysis/CycleAnalysis.cpp b/llvm/lib/Analysis/CycleAnalysis.cpp index 09c7ee67e05c..17998123fce7 100644 --- a/llvm/lib/Analysis/CycleAnalysis.cpp +++ b/llvm/lib/Analysis/CycleAnalysis.cpp @@ -8,11 +8,15 
@@ #include "llvm/Analysis/CycleAnalysis.h" #include "llvm/ADT/GenericCycleImpl.h" -#include "llvm/IR/CFG.h" +#include "llvm/IR/CFG.h" // for successors found by ADL in GenericCycleImpl.h #include "llvm/InitializePasses.h" using namespace llvm; +namespace llvm { +class Module; +} + template class llvm::GenericCycleInfo<SSAContext>; template class llvm::GenericCycle<SSAContext>; diff --git a/llvm/lib/Analysis/DDG.cpp b/llvm/lib/Analysis/DDG.cpp index 7e1357959a3f..998c888dd2d9 100644 --- a/llvm/lib/Analysis/DDG.cpp +++ b/llvm/lib/Analysis/DDG.cpp @@ -17,13 +17,12 @@ using namespace llvm; static cl::opt<bool> SimplifyDDG( - "ddg-simplify", cl::init(true), cl::Hidden, cl::ZeroOrMore, + "ddg-simplify", cl::init(true), cl::Hidden, cl::desc( "Simplify DDG by merging nodes that have less interesting edges.")); -static cl::opt<bool> - CreatePiBlocks("ddg-pi-blocks", cl::init(true), cl::Hidden, cl::ZeroOrMore, - cl::desc("Create pi-block nodes.")); +static cl::opt<bool> CreatePiBlocks("ddg-pi-blocks", cl::init(true), cl::Hidden, + cl::desc("Create pi-block nodes.")); #define DEBUG_TYPE "ddg" @@ -34,7 +33,7 @@ template class llvm::DirectedGraph<DDGNode, DDGEdge>; //===--------------------------------------------------------------------===// // DDGNode implementation //===--------------------------------------------------------------------===// -DDGNode::~DDGNode() {} +DDGNode::~DDGNode() = default; bool DDGNode::collectInstructions( llvm::function_ref<bool(Instruction *)> const &Pred, diff --git a/llvm/lib/Analysis/DDGPrinter.cpp b/llvm/lib/Analysis/DDGPrinter.cpp index 0d5a936723ce..6b5acd204ec7 100644 --- a/llvm/lib/Analysis/DDGPrinter.cpp +++ b/llvm/lib/Analysis/DDGPrinter.cpp @@ -18,8 +18,8 @@ using namespace llvm; -static cl::opt<bool> DotOnly("dot-ddg-only", cl::init(false), cl::Hidden, - cl::ZeroOrMore, cl::desc("simple ddg dot graph")); +static cl::opt<bool> DotOnly("dot-ddg-only", cl::Hidden, + cl::desc("simple ddg dot graph")); static cl::opt<std::string> DDGDotFilenamePrefix( "dot-ddg-filename-prefix", cl::init("ddg"), cl::Hidden, cl::desc("The prefix used for the DDG dot file names.")); diff --git a/llvm/lib/Analysis/Delinearization.cpp b/llvm/lib/Analysis/Delinearization.cpp index 670532c6d9a8..c36e1d922915 100644 --- a/llvm/lib/Analysis/Delinearization.cpp +++ b/llvm/lib/Analysis/Delinearization.cpp @@ -24,9 +24,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/PassManager.h" -#include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" @@ -523,6 +521,44 @@ bool llvm::getIndexExpressionsFromGEP(ScalarEvolution &SE, return !Subscripts.empty(); } +bool llvm::tryDelinearizeFixedSizeImpl( + ScalarEvolution *SE, Instruction *Inst, const SCEV *AccessFn, + SmallVectorImpl<const SCEV *> &Subscripts, SmallVectorImpl<int> &Sizes) { + Value *SrcPtr = getLoadStorePointerOperand(Inst); + + // Check the simple case where the array dimensions are fixed size. + auto *SrcGEP = dyn_cast<GetElementPtrInst>(SrcPtr); + if (!SrcGEP) + return false; + + getIndexExpressionsFromGEP(*SE, SrcGEP, Subscripts, Sizes); + + // Check that the two size arrays are non-empty and equal in length and + // value. + // TODO: it would be better to let the caller to clear Subscripts, similar + // to how we handle Sizes. 
+ if (Sizes.empty() || Subscripts.size() <= 1) { + Subscripts.clear(); + return false; + } + + // Check that for identical base pointers we do not miss index offsets + // that have been added before this GEP is applied. + Value *SrcBasePtr = SrcGEP->getOperand(0)->stripPointerCasts(); + const SCEVUnknown *SrcBase = + dyn_cast<SCEVUnknown>(SE->getPointerBase(AccessFn)); + if (!SrcBase || SrcBasePtr != SrcBase->getValue()) { + Subscripts.clear(); + return false; + } + + assert(Subscripts.size() == Sizes.size() + 1 && + "Expected equal number of entries in the list of size and " + "subscript."); + + return true; +} + namespace { class Delinearization : public FunctionPass { diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp index 117b12fc0701..e01ed48be376 100644 --- a/llvm/lib/Analysis/DemandedBits.cpp +++ b/llvm/lib/Analysis/DemandedBits.cpp @@ -21,19 +21,13 @@ #include "llvm/Analysis/DemandedBits.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp index f827f74d5367..3d2d84ecadb4 100644 --- a/llvm/lib/Analysis/DependenceAnalysis.cpp +++ b/llvm/lib/Analysis/DependenceAnalysis.cpp @@ -50,7 +50,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/DependenceAnalysis.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/Delinearization.h" @@ -58,10 +57,8 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/Config/llvm-config.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -109,11 +106,10 @@ STATISTIC(BanerjeeIndependence, "Banerjee independence"); STATISTIC(BanerjeeSuccesses, "Banerjee successes"); static cl::opt<bool> - Delinearize("da-delinearize", cl::init(true), cl::Hidden, cl::ZeroOrMore, + Delinearize("da-delinearize", cl::init(true), cl::Hidden, cl::desc("Try to delinearize array references.")); static cl::opt<bool> DisableDelinearizationChecks( - "da-disable-delinearization-checks", cl::init(false), cl::Hidden, - cl::ZeroOrMore, + "da-disable-delinearization-checks", cl::Hidden, cl::desc( "Disable checks that try to statically verify validity of " "delinearized subscripts. 
Enabling this option may result in incorrect "
@@ -121,7 +117,7 @@ static cl::opt<bool> DisableDelinearizationChecks(
         "dimension to underflow or overflow into another dimension."));
 static cl::opt<unsigned> MIVMaxLevelThreshold(
-    "da-miv-max-level-threshold", cl::init(7), cl::Hidden, cl::ZeroOrMore,
+    "da-miv-max-level-threshold", cl::init(7), cl::Hidden,
     cl::desc("Maximum depth allowed for the recursive algorithm used to "
              "explore MIV direction vectors."));
@@ -787,6 +783,8 @@ unsigned DependenceInfo::mapSrcLoop(const Loop *SrcLoop) const {
 unsigned DependenceInfo::mapDstLoop(const Loop *DstLoop) const {
   unsigned D = DstLoop->getLoopDepth();
   if (D > CommonLevels)
+    // This tries to make sure that we assign unique numbers to src and dst when
+    // the memory accesses reside in different loops that have the same depth.
     return D - CommonLevels + SrcLevels;
   else
     return D;
@@ -796,10 +794,16 @@
 // Returns true if Expression is loop invariant in LoopNest.
 bool DependenceInfo::isLoopInvariant(const SCEV *Expression,
                                      const Loop *LoopNest) const {
+  // Unlike ScalarEvolution::isLoopInvariant(), we consider an access outside
+  // of any loop as invariant, because we only consider expression evaluation
+  // at a specific position (where the array access takes place), and not
+  // across the entire function.
   if (!LoopNest)
     return true;
-  return SE->isLoopInvariant(Expression, LoopNest) &&
-         isLoopInvariant(Expression, LoopNest->getParentLoop());
+
+  // If the expression is invariant in the outermost loop of the loop nest, it
+  // is invariant anywhere in the loop nest.
+  return SE->isLoopInvariant(Expression, LoopNest->getOutermostLoop());
 }
@@ -890,13 +894,25 @@ void DependenceInfo::removeMatchingExtensions(Subscript *Pair) {
   }
 }
-// Examine the scev and return true iff it's linear.
+// Examine the scev and return true iff it's affine.
 // Collect any loops mentioned in the set of "Loops".
 bool DependenceInfo::checkSubscript(const SCEV *Expr, const Loop *LoopNest,
                                     SmallBitVector &Loops, bool IsSrc) {
   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr);
   if (!AddRec)
     return isLoopInvariant(Expr, LoopNest);
+
+  // The AddRec must depend on one of the containing loops. Otherwise,
+  // mapSrcLoop and mapDstLoop return indices outside the intended range. This
+  // can happen when a subscript in one loop references an IV from a sibling
+  // loop that could not be replaced with a concrete exit value by
+  // getSCEVAtScope.
+  const Loop *L = LoopNest;
+  while (L && AddRec->getLoop() != L)
+    L = L->getParentLoop();
+  if (!L)
+    return false;
+
   const SCEV *Start = AddRec->getStart();
   const SCEV *Step = AddRec->getStepRecurrence(*SE);
   const SCEV *UB = SE->getBackedgeTakenCount(AddRec->getLoop());
@@ -3318,59 +3334,45 @@ bool DependenceInfo::tryDelinearize(Instruction *Src, Instruction *Dst,
   return true;
 }
+/// Try to delinearize \p SrcAccessFn and \p DstAccessFn if the underlying
+/// arrays accessed are fixed-size arrays. Return true if delinearization was
+/// successful.
bool DependenceInfo::tryDelinearizeFixedSize( Instruction *Src, Instruction *Dst, const SCEV *SrcAccessFn, const SCEV *DstAccessFn, SmallVectorImpl<const SCEV *> &SrcSubscripts, SmallVectorImpl<const SCEV *> &DstSubscripts) { - - Value *SrcPtr = getLoadStorePointerOperand(Src); - Value *DstPtr = getLoadStorePointerOperand(Dst); - const SCEVUnknown *SrcBase = - dyn_cast<SCEVUnknown>(SE->getPointerBase(SrcAccessFn)); - const SCEVUnknown *DstBase = - dyn_cast<SCEVUnknown>(SE->getPointerBase(DstAccessFn)); - assert(SrcBase && DstBase && SrcBase == DstBase && - "expected src and dst scev unknowns to be equal"); - - // Check the simple case where the array dimensions are fixed size. - auto *SrcGEP = dyn_cast<GetElementPtrInst>(SrcPtr); - auto *DstGEP = dyn_cast<GetElementPtrInst>(DstPtr); - if (!SrcGEP || !DstGEP) + LLVM_DEBUG({ + const SCEVUnknown *SrcBase = + dyn_cast<SCEVUnknown>(SE->getPointerBase(SrcAccessFn)); + const SCEVUnknown *DstBase = + dyn_cast<SCEVUnknown>(SE->getPointerBase(DstAccessFn)); + assert(SrcBase && DstBase && SrcBase == DstBase && + "expected src and dst scev unknowns to be equal"); + }); + + SmallVector<int, 4> SrcSizes; + SmallVector<int, 4> DstSizes; + if (!tryDelinearizeFixedSizeImpl(SE, Src, SrcAccessFn, SrcSubscripts, + SrcSizes) || + !tryDelinearizeFixedSizeImpl(SE, Dst, DstAccessFn, DstSubscripts, + DstSizes)) return false; - SmallVector<int, 4> SrcSizes, DstSizes; - getIndexExpressionsFromGEP(*SE, SrcGEP, SrcSubscripts, SrcSizes); - getIndexExpressionsFromGEP(*SE, DstGEP, DstSubscripts, DstSizes); - // Check that the two size arrays are non-empty and equal in length and // value. - if (SrcSizes.empty() || SrcSubscripts.size() <= 1 || - SrcSizes.size() != DstSizes.size() || + if (SrcSizes.size() != DstSizes.size() || !std::equal(SrcSizes.begin(), SrcSizes.end(), DstSizes.begin())) { SrcSubscripts.clear(); DstSubscripts.clear(); return false; } - Value *SrcBasePtr = SrcGEP->getOperand(0); - Value *DstBasePtr = DstGEP->getOperand(0); - while (auto *PCast = dyn_cast<BitCastInst>(SrcBasePtr)) - SrcBasePtr = PCast->getOperand(0); - while (auto *PCast = dyn_cast<BitCastInst>(DstBasePtr)) - DstBasePtr = PCast->getOperand(0); - - // Check that for identical base pointers we do not miss index offsets - // that have been added before this GEP is applied. 
- if (SrcBasePtr != SrcBase->getValue() || DstBasePtr != DstBase->getValue()) { - SrcSubscripts.clear(); - DstSubscripts.clear(); - return false; - } - assert(SrcSubscripts.size() == DstSubscripts.size() && - SrcSubscripts.size() == SrcSizes.size() + 1 && - "Expected equal number of entries in the list of sizes and " - "subscripts."); + "Expected equal number of entries in the list of SrcSubscripts and " + "DstSubscripts."); + + Value *SrcPtr = getLoadStorePointerOperand(Src); + Value *DstPtr = getLoadStorePointerOperand(Dst); // In general we cannot safely assume that the subscripts recovered from GEPs // are in the range of values defined for their corresponding array @@ -3406,8 +3408,8 @@ bool DependenceInfo::tryDelinearizeFixedSize( } LLVM_DEBUG({ dbgs() << "Delinearized subscripts of fixed-size array\n" - << "SrcGEP:" << *SrcGEP << "\n" - << "DstGEP:" << *DstGEP << "\n"; + << "SrcGEP:" << *SrcPtr << "\n" + << "DstGEP:" << *DstPtr << "\n"; }); return true; } diff --git a/llvm/lib/Analysis/DependenceGraphBuilder.cpp b/llvm/lib/Analysis/DependenceGraphBuilder.cpp index 6b90db4bafe1..7ee2adf49ebb 100644 --- a/llvm/lib/Analysis/DependenceGraphBuilder.cpp +++ b/llvm/lib/Analysis/DependenceGraphBuilder.cpp @@ -12,6 +12,7 @@ #include "llvm/Analysis/DependenceGraphBuilder.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/EnumeratedArray.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/DDG.h" diff --git a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp index 4a792fce51d1..79ea160afc22 100644 --- a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp +++ b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp @@ -11,7 +11,6 @@ // //===----------------------------------------------------------------------===// #include "llvm/Config/config.h" -#include "llvm/Support/Casting.h" #if defined(LLVM_HAVE_TF_API) #include "llvm/ADT/BitVector.h" @@ -273,8 +272,8 @@ static const std::vector<TensorSpec> TrainingOnlyFeatures{ static const std::vector<TensorSpec> getInputFeatures() { std::vector<TensorSpec> InputSpecs; for (size_t I = 0; I < NumberOfFeatures; ++I) - InputSpecs.push_back( - TensorSpec::createSpec<int64_t>(TFFeedPrefix + FeatureNameMap[I], {1})); + InputSpecs.push_back(TensorSpec::createSpec<int64_t>( + TFFeedPrefix + FeatureMap[I].name(), FeatureMap[I].shape())); append_range(InputSpecs, TrainingOnlyFeatures); return InputSpecs; } @@ -290,8 +289,7 @@ TrainingLogger::TrainingLogger(StringRef LogFileName, std::vector<LoggedFeatureSpec> FT; for (size_t I = 0; I < NumberOfFeatures; ++I) - FT.push_back( - {TensorSpec::createSpec<int64_t>(FeatureNameMap.at(I), {1}), None}); + FT.push_back({FeatureMap.at(I), None}); if (MUTR && MUTR->outputLoggedFeatureSpecs().size() > 1) append_range(FT, drop_begin(MUTR->outputLoggedFeatureSpecs())); diff --git a/llvm/lib/Analysis/DivergenceAnalysis.cpp b/llvm/lib/Analysis/DivergenceAnalysis.cpp index 39e80c2ad51c..1a4b09e0cac2 100644 --- a/llvm/lib/Analysis/DivergenceAnalysis.cpp +++ b/llvm/lib/Analysis/DivergenceAnalysis.cpp @@ -73,15 +73,14 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/Passes.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" #include 
"llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Value.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" diff --git a/llvm/lib/Analysis/DomPrinter.cpp b/llvm/lib/Analysis/DomPrinter.cpp index 6088de53028d..e9f5103e1276 100644 --- a/llvm/lib/Analysis/DomPrinter.cpp +++ b/llvm/lib/Analysis/DomPrinter.cpp @@ -24,74 +24,6 @@ using namespace llvm; -namespace llvm { -template<> -struct DOTGraphTraits<DomTreeNode*> : public DefaultDOTGraphTraits { - - DOTGraphTraits (bool isSimple=false) - : DefaultDOTGraphTraits(isSimple) {} - - std::string getNodeLabel(DomTreeNode *Node, DomTreeNode *Graph) { - - BasicBlock *BB = Node->getBlock(); - - if (!BB) - return "Post dominance root node"; - - - if (isSimple()) - return DOTGraphTraits<DOTFuncInfo *> - ::getSimpleNodeLabel(BB, nullptr); - else - return DOTGraphTraits<DOTFuncInfo *> - ::getCompleteNodeLabel(BB, nullptr); - } -}; - -template<> -struct DOTGraphTraits<DominatorTree*> : public DOTGraphTraits<DomTreeNode*> { - - DOTGraphTraits (bool isSimple=false) - : DOTGraphTraits<DomTreeNode*>(isSimple) {} - - static std::string getGraphName(DominatorTree *DT) { - return "Dominator tree"; - } - - std::string getNodeLabel(DomTreeNode *Node, DominatorTree *G) { - return DOTGraphTraits<DomTreeNode*>::getNodeLabel(Node, G->getRootNode()); - } -}; - -template<> -struct DOTGraphTraits<PostDominatorTree*> - : public DOTGraphTraits<DomTreeNode*> { - - DOTGraphTraits (bool isSimple=false) - : DOTGraphTraits<DomTreeNode*>(isSimple) {} - - static std::string getGraphName(PostDominatorTree *DT) { - return "Post dominator tree"; - } - - std::string getNodeLabel(DomTreeNode *Node, PostDominatorTree *G ) { - return DOTGraphTraits<DomTreeNode*>::getNodeLabel(Node, G->getRootNode()); - } -}; -} - -PreservedAnalyses DomTreePrinterPass::run(Function &F, - FunctionAnalysisManager &AM) { - WriteDOTGraphToFile(F, &AM.getResult<DominatorTreeAnalysis>(F), "dom", false); - return PreservedAnalyses::all(); -} - -PreservedAnalyses DomTreeOnlyPrinterPass::run(Function &F, - FunctionAnalysisManager &AM) { - WriteDOTGraphToFile(F, &AM.getResult<DominatorTreeAnalysis>(F), "domonly", - true); - return PreservedAnalyses::all(); -} void DominatorTree::viewGraph(const Twine &Name, const Twine &Title) { #ifndef NDEBUG @@ -110,166 +42,167 @@ void DominatorTree::viewGraph() { } namespace { -struct DominatorTreeWrapperPassAnalysisGraphTraits { +struct LegacyDominatorTreeWrapperPassAnalysisGraphTraits { static DominatorTree *getGraph(DominatorTreeWrapperPass *DTWP) { return &DTWP->getDomTree(); } }; -struct DomViewer : public DOTGraphTraitsViewer< - DominatorTreeWrapperPass, false, DominatorTree *, - DominatorTreeWrapperPassAnalysisGraphTraits> { +struct DomViewerWrapperPass + : public DOTGraphTraitsViewerWrapperPass< + DominatorTreeWrapperPass, false, DominatorTree *, + LegacyDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; - DomViewer() - : DOTGraphTraitsViewer<DominatorTreeWrapperPass, false, DominatorTree *, - DominatorTreeWrapperPassAnalysisGraphTraits>( - "dom", ID) { - initializeDomViewerPass(*PassRegistry::getPassRegistry()); + DomViewerWrapperPass() + : DOTGraphTraitsViewerWrapperPass< + DominatorTreeWrapperPass, false, DominatorTree *, + LegacyDominatorTreeWrapperPassAnalysisGraphTraits>("dom", ID) { + initializeDomViewerWrapperPassPass(*PassRegistry::getPassRegistry()); } }; -struct DomOnlyViewer : public DOTGraphTraitsViewer< - DominatorTreeWrapperPass, 
true, DominatorTree *, - DominatorTreeWrapperPassAnalysisGraphTraits> { +struct DomOnlyViewerWrapperPass + : public DOTGraphTraitsViewerWrapperPass< + DominatorTreeWrapperPass, true, DominatorTree *, + LegacyDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; - DomOnlyViewer() - : DOTGraphTraitsViewer<DominatorTreeWrapperPass, true, DominatorTree *, - DominatorTreeWrapperPassAnalysisGraphTraits>( - "domonly", ID) { - initializeDomOnlyViewerPass(*PassRegistry::getPassRegistry()); + DomOnlyViewerWrapperPass() + : DOTGraphTraitsViewerWrapperPass< + DominatorTreeWrapperPass, true, DominatorTree *, + LegacyDominatorTreeWrapperPassAnalysisGraphTraits>("domonly", ID) { + initializeDomOnlyViewerWrapperPassPass(*PassRegistry::getPassRegistry()); } }; -struct PostDominatorTreeWrapperPassAnalysisGraphTraits { +struct LegacyPostDominatorTreeWrapperPassAnalysisGraphTraits { static PostDominatorTree *getGraph(PostDominatorTreeWrapperPass *PDTWP) { return &PDTWP->getPostDomTree(); } }; -struct PostDomViewer : public DOTGraphTraitsViewer< - PostDominatorTreeWrapperPass, false, - PostDominatorTree *, - PostDominatorTreeWrapperPassAnalysisGraphTraits> { +struct PostDomViewerWrapperPass + : public DOTGraphTraitsViewerWrapperPass< + PostDominatorTreeWrapperPass, false, PostDominatorTree *, + LegacyPostDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; - PostDomViewer() : - DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, false, - PostDominatorTree *, - PostDominatorTreeWrapperPassAnalysisGraphTraits>( - "postdom", ID){ - initializePostDomViewerPass(*PassRegistry::getPassRegistry()); - } + PostDomViewerWrapperPass() + : DOTGraphTraitsViewerWrapperPass< + PostDominatorTreeWrapperPass, false, PostDominatorTree *, + LegacyPostDominatorTreeWrapperPassAnalysisGraphTraits>("postdom", + ID) { + initializePostDomViewerWrapperPassPass(*PassRegistry::getPassRegistry()); + } }; -struct PostDomOnlyViewer : public DOTGraphTraitsViewer< - PostDominatorTreeWrapperPass, true, - PostDominatorTree *, - PostDominatorTreeWrapperPassAnalysisGraphTraits> { +struct PostDomOnlyViewerWrapperPass + : public DOTGraphTraitsViewerWrapperPass< + PostDominatorTreeWrapperPass, true, PostDominatorTree *, + LegacyPostDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; - PostDomOnlyViewer() : - DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, true, - PostDominatorTree *, - PostDominatorTreeWrapperPassAnalysisGraphTraits>( - "postdomonly", ID){ - initializePostDomOnlyViewerPass(*PassRegistry::getPassRegistry()); - } + PostDomOnlyViewerWrapperPass() + : DOTGraphTraitsViewerWrapperPass< + PostDominatorTreeWrapperPass, true, PostDominatorTree *, + LegacyPostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdomonly", ID) { + initializePostDomOnlyViewerWrapperPassPass( + *PassRegistry::getPassRegistry()); + } }; } // end anonymous namespace -char DomViewer::ID = 0; -INITIALIZE_PASS(DomViewer, "view-dom", +char DomViewerWrapperPass::ID = 0; +INITIALIZE_PASS(DomViewerWrapperPass, "view-dom", "View dominance tree of function", false, false) -char DomOnlyViewer::ID = 0; -INITIALIZE_PASS(DomOnlyViewer, "view-dom-only", +char DomOnlyViewerWrapperPass::ID = 0; +INITIALIZE_PASS(DomOnlyViewerWrapperPass, "view-dom-only", "View dominance tree of function (with no function bodies)", false, false) -char PostDomViewer::ID = 0; -INITIALIZE_PASS(PostDomViewer, "view-postdom", +char PostDomViewerWrapperPass::ID = 0; +INITIALIZE_PASS(PostDomViewerWrapperPass, "view-postdom", "View postdominance tree of function", false, false) 
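// (Editor's note on the renames above, hedged: only the C++ pass classes
// gain the WrapperPass suffix; the registered command-line names such as
// "view-dom" and "view-postdom" are unchanged, so an invocation like
// `opt -view-dom` -- an assumed illustration -- behaves as before.)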
-char PostDomOnlyViewer::ID = 0; -INITIALIZE_PASS(PostDomOnlyViewer, "view-postdom-only", +char PostDomOnlyViewerWrapperPass::ID = 0; +INITIALIZE_PASS(PostDomOnlyViewerWrapperPass, "view-postdom-only", "View postdominance tree of function " "(with no function bodies)", false, false) namespace { -struct DomPrinter : public DOTGraphTraitsPrinter< - DominatorTreeWrapperPass, false, DominatorTree *, - DominatorTreeWrapperPassAnalysisGraphTraits> { +struct DomPrinterWrapperPass + : public DOTGraphTraitsPrinterWrapperPass< + DominatorTreeWrapperPass, false, DominatorTree *, + LegacyDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; - DomPrinter() - : DOTGraphTraitsPrinter<DominatorTreeWrapperPass, false, DominatorTree *, - DominatorTreeWrapperPassAnalysisGraphTraits>( - "dom", ID) { - initializeDomPrinterPass(*PassRegistry::getPassRegistry()); + DomPrinterWrapperPass() + : DOTGraphTraitsPrinterWrapperPass< + DominatorTreeWrapperPass, false, DominatorTree *, + LegacyDominatorTreeWrapperPassAnalysisGraphTraits>("dom", ID) { + initializeDomPrinterWrapperPassPass(*PassRegistry::getPassRegistry()); } }; -struct DomOnlyPrinter : public DOTGraphTraitsPrinter< - DominatorTreeWrapperPass, true, DominatorTree *, - DominatorTreeWrapperPassAnalysisGraphTraits> { +struct DomOnlyPrinterWrapperPass + : public DOTGraphTraitsPrinterWrapperPass< + DominatorTreeWrapperPass, true, DominatorTree *, + LegacyDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; - DomOnlyPrinter() - : DOTGraphTraitsPrinter<DominatorTreeWrapperPass, true, DominatorTree *, - DominatorTreeWrapperPassAnalysisGraphTraits>( - "domonly", ID) { - initializeDomOnlyPrinterPass(*PassRegistry::getPassRegistry()); + DomOnlyPrinterWrapperPass() + : DOTGraphTraitsPrinterWrapperPass< + DominatorTreeWrapperPass, true, DominatorTree *, + LegacyDominatorTreeWrapperPassAnalysisGraphTraits>("domonly", ID) { + initializeDomOnlyPrinterWrapperPassPass(*PassRegistry::getPassRegistry()); } }; -struct PostDomPrinter - : public DOTGraphTraitsPrinter< - PostDominatorTreeWrapperPass, false, - PostDominatorTree *, - PostDominatorTreeWrapperPassAnalysisGraphTraits> { +struct PostDomPrinterWrapperPass + : public DOTGraphTraitsPrinterWrapperPass< + PostDominatorTreeWrapperPass, false, PostDominatorTree *, + LegacyPostDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; - PostDomPrinter() : - DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, false, - PostDominatorTree *, - PostDominatorTreeWrapperPassAnalysisGraphTraits>( - "postdom", ID) { - initializePostDomPrinterPass(*PassRegistry::getPassRegistry()); - } + PostDomPrinterWrapperPass() + : DOTGraphTraitsPrinterWrapperPass< + PostDominatorTreeWrapperPass, false, PostDominatorTree *, + LegacyPostDominatorTreeWrapperPassAnalysisGraphTraits>("postdom", + ID) { + initializePostDomPrinterWrapperPassPass(*PassRegistry::getPassRegistry()); + } }; -struct PostDomOnlyPrinter - : public DOTGraphTraitsPrinter< - PostDominatorTreeWrapperPass, true, - PostDominatorTree *, - PostDominatorTreeWrapperPassAnalysisGraphTraits> { +struct PostDomOnlyPrinterWrapperPass + : public DOTGraphTraitsPrinterWrapperPass< + PostDominatorTreeWrapperPass, true, PostDominatorTree *, + LegacyPostDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; - PostDomOnlyPrinter() : - DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, true, - PostDominatorTree *, - PostDominatorTreeWrapperPassAnalysisGraphTraits>( - "postdomonly", ID) { - initializePostDomOnlyPrinterPass(*PassRegistry::getPassRegistry()); - } + 
PostDomOnlyPrinterWrapperPass() + : DOTGraphTraitsPrinterWrapperPass< + PostDominatorTreeWrapperPass, true, PostDominatorTree *, + LegacyPostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdomonly", ID) { + initializePostDomOnlyPrinterWrapperPassPass( + *PassRegistry::getPassRegistry()); + } }; } // end anonymous namespace +char DomPrinterWrapperPass::ID = 0; +INITIALIZE_PASS(DomPrinterWrapperPass, "dot-dom", + "Print dominance tree of function to 'dot' file", false, false) - -char DomPrinter::ID = 0; -INITIALIZE_PASS(DomPrinter, "dot-dom", - "Print dominance tree of function to 'dot' file", - false, false) - -char DomOnlyPrinter::ID = 0; -INITIALIZE_PASS(DomOnlyPrinter, "dot-dom-only", +char DomOnlyPrinterWrapperPass::ID = 0; +INITIALIZE_PASS(DomOnlyPrinterWrapperPass, "dot-dom-only", "Print dominance tree of function to 'dot' file " "(with no function bodies)", false, false) -char PostDomPrinter::ID = 0; -INITIALIZE_PASS(PostDomPrinter, "dot-postdom", - "Print postdominance tree of function to 'dot' file", - false, false) +char PostDomPrinterWrapperPass::ID = 0; +INITIALIZE_PASS(PostDomPrinterWrapperPass, "dot-postdom", + "Print postdominance tree of function to 'dot' file", false, + false) -char PostDomOnlyPrinter::ID = 0; -INITIALIZE_PASS(PostDomOnlyPrinter, "dot-postdom-only", +char PostDomOnlyPrinterWrapperPass::ID = 0; +INITIALIZE_PASS(PostDomOnlyPrinterWrapperPass, "dot-postdom-only", "Print postdominance tree of function to 'dot' file " "(with no function bodies)", false, false) @@ -278,34 +211,34 @@ INITIALIZE_PASS(PostDomOnlyPrinter, "dot-postdom-only", // "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by // the link time optimization. -FunctionPass *llvm::createDomPrinterPass() { - return new DomPrinter(); +FunctionPass *llvm::createDomPrinterWrapperPassPass() { + return new DomPrinterWrapperPass(); } -FunctionPass *llvm::createDomOnlyPrinterPass() { - return new DomOnlyPrinter(); +FunctionPass *llvm::createDomOnlyPrinterWrapperPassPass() { + return new DomOnlyPrinterWrapperPass(); } -FunctionPass *llvm::createDomViewerPass() { - return new DomViewer(); +FunctionPass *llvm::createDomViewerWrapperPassPass() { + return new DomViewerWrapperPass(); } -FunctionPass *llvm::createDomOnlyViewerPass() { - return new DomOnlyViewer(); +FunctionPass *llvm::createDomOnlyViewerWrapperPassPass() { + return new DomOnlyViewerWrapperPass(); } -FunctionPass *llvm::createPostDomPrinterPass() { - return new PostDomPrinter(); +FunctionPass *llvm::createPostDomPrinterWrapperPassPass() { + return new PostDomPrinterWrapperPass(); } -FunctionPass *llvm::createPostDomOnlyPrinterPass() { - return new PostDomOnlyPrinter(); +FunctionPass *llvm::createPostDomOnlyPrinterWrapperPassPass() { + return new PostDomOnlyPrinterWrapperPass(); } -FunctionPass *llvm::createPostDomViewerPass() { - return new PostDomViewer(); +FunctionPass *llvm::createPostDomViewerWrapperPassPass() { + return new PostDomViewerWrapperPass(); } -FunctionPass *llvm::createPostDomOnlyViewerPass() { - return new PostDomOnlyViewer(); +FunctionPass *llvm::createPostDomOnlyViewerWrapperPassPass() { + return new PostDomOnlyViewerWrapperPass(); } diff --git a/llvm/lib/Analysis/DomTreeUpdater.cpp b/llvm/lib/Analysis/DomTreeUpdater.cpp index 6e299263e66d..888c16723208 100644 --- a/llvm/lib/Analysis/DomTreeUpdater.cpp +++ b/llvm/lib/Analysis/DomTreeUpdater.cpp @@ -14,6 +14,7 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Analysis/PostDominators.h" +#include "llvm/IR/Constants.h" 
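// (Editor's note on the DomTreeUpdater hunk just below, hedged: the
// point-wise insertEdge/insertEdgeRelaxed/deleteEdge/deleteEdgeRelaxed
// methods are removed; the presumed migration path, which this diff does not
// show, is the existing batched interface, e.g.
//   DTU.applyUpdates({{DominatorTree::Insert, From, To}});
// which works under both the eager and the lazy update strategy.)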
#include "llvm/IR/Instructions.h" #include "llvm/Support/GenericDomTree.h" #include <algorithm> @@ -314,98 +315,6 @@ PostDominatorTree &DomTreeUpdater::getPostDomTree() { return *PDT; } -void DomTreeUpdater::insertEdge(BasicBlock *From, BasicBlock *To) { - -#ifndef NDEBUG - assert(isUpdateValid({DominatorTree::Insert, From, To}) && - "Inserted edge does not appear in the CFG"); -#endif - - if (!DT && !PDT) - return; - - // Won't affect DomTree and PostDomTree; discard update. - if (From == To) - return; - - if (Strategy == UpdateStrategy::Eager) { - if (DT) - DT->insertEdge(From, To); - if (PDT) - PDT->insertEdge(From, To); - return; - } - - PendUpdates.push_back({DominatorTree::Insert, From, To}); -} - -void DomTreeUpdater::insertEdgeRelaxed(BasicBlock *From, BasicBlock *To) { - if (From == To) - return; - - if (!DT && !PDT) - return; - - if (!isUpdateValid({DominatorTree::Insert, From, To})) - return; - - if (Strategy == UpdateStrategy::Eager) { - if (DT) - DT->insertEdge(From, To); - if (PDT) - PDT->insertEdge(From, To); - return; - } - - PendUpdates.push_back({DominatorTree::Insert, From, To}); -} - -void DomTreeUpdater::deleteEdge(BasicBlock *From, BasicBlock *To) { - -#ifndef NDEBUG - assert(isUpdateValid({DominatorTree::Delete, From, To}) && - "Deleted edge still exists in the CFG!"); -#endif - - if (!DT && !PDT) - return; - - // Won't affect DomTree and PostDomTree; discard update. - if (From == To) - return; - - if (Strategy == UpdateStrategy::Eager) { - if (DT) - DT->deleteEdge(From, To); - if (PDT) - PDT->deleteEdge(From, To); - return; - } - - PendUpdates.push_back({DominatorTree::Delete, From, To}); -} - -void DomTreeUpdater::deleteEdgeRelaxed(BasicBlock *From, BasicBlock *To) { - if (From == To) - return; - - if (!DT && !PDT) - return; - - if (!isUpdateValid({DominatorTree::Delete, From, To})) - return; - - if (Strategy == UpdateStrategy::Eager) { - if (DT) - DT->deleteEdge(From, To); - if (PDT) - PDT->deleteEdge(From, To); - return; - } - - PendUpdates.push_back({DominatorTree::Delete, From, To}); -} - void DomTreeUpdater::dropOutOfDateUpdates() { if (Strategy == DomTreeUpdater::UpdateStrategy::Eager) return; diff --git a/llvm/lib/Analysis/DominanceFrontier.cpp b/llvm/lib/Analysis/DominanceFrontier.cpp index a8806fe5a480..ccba913ccfe5 100644 --- a/llvm/lib/Analysis/DominanceFrontier.cpp +++ b/llvm/lib/Analysis/DominanceFrontier.cpp @@ -15,7 +15,6 @@ #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Compiler.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; diff --git a/llvm/lib/Analysis/EHPersonalities.cpp b/llvm/lib/Analysis/EHPersonalities.cpp index df8b7e12e8d7..277ff6ba735f 100644 --- a/llvm/lib/Analysis/EHPersonalities.cpp +++ b/llvm/lib/Analysis/EHPersonalities.cpp @@ -8,6 +8,7 @@ #include "llvm/Analysis/EHPersonalities.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/Triple.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -67,7 +68,10 @@ StringRef llvm::getEHPersonalityName(EHPersonality Pers) { } EHPersonality llvm::getDefaultEHPersonality(const Triple &T) { - return EHPersonality::GNU_C; + if (T.isPS5()) + return EHPersonality::GNU_CXX; + else + return EHPersonality::GNU_C; } bool llvm::canSimplifyInvokeNoUnwind(const Function *F) { diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp index 33519038e225..782c11937507 100644 --- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp 
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp @@ -12,48 +12,87 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/FunctionPropertiesAnalysis.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" +#include <deque> using namespace llvm; -FunctionPropertiesInfo -FunctionPropertiesInfo::getFunctionPropertiesInfo(const Function &F, - const LoopInfo &LI) { - - FunctionPropertiesInfo FPI; +namespace { +int64_t getNrBlocksFromCond(const BasicBlock &BB) { + int64_t Ret = 0; + if (const auto *BI = dyn_cast<BranchInst>(BB.getTerminator())) { + if (BI->isConditional()) + Ret += BI->getNumSuccessors(); + } else if (const auto *SI = dyn_cast<SwitchInst>(BB.getTerminator())) { + Ret += (SI->getNumCases() + (nullptr != SI->getDefaultDest())); + } + return Ret; +} - FPI.Uses = ((!F.hasLocalLinkage()) ? 1 : 0) + F.getNumUses(); +int64_t getUses(const Function &F) { + return ((!F.hasLocalLinkage()) ? 1 : 0) + F.getNumUses(); +} +} // namespace - for (const auto &BB : F) { - ++FPI.BasicBlockCount; +void FunctionPropertiesInfo::reIncludeBB(const BasicBlock &BB) { + updateForBB(BB, +1); +} - if (const auto *BI = dyn_cast<BranchInst>(BB.getTerminator())) { - if (BI->isConditional()) - FPI.BlocksReachedFromConditionalInstruction += BI->getNumSuccessors(); - } else if (const auto *SI = dyn_cast<SwitchInst>(BB.getTerminator())) { - FPI.BlocksReachedFromConditionalInstruction += - (SI->getNumCases() + (nullptr != SI->getDefaultDest())); +void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB, + int64_t Direction) { + assert(Direction == 1 || Direction == -1); + BasicBlockCount += Direction; + BlocksReachedFromConditionalInstruction += + (Direction * getNrBlocksFromCond(BB)); + for (const auto &I : BB) { + if (auto *CS = dyn_cast<CallBase>(&I)) { + const auto *Callee = CS->getCalledFunction(); + if (Callee && !Callee->isIntrinsic() && !Callee->isDeclaration()) + DirectCallsToDefinedFunctions += Direction; } - - for (const auto &I : BB) { - if (auto *CS = dyn_cast<CallBase>(&I)) { - const auto *Callee = CS->getCalledFunction(); - if (Callee && !Callee->isIntrinsic() && !Callee->isDeclaration()) - ++FPI.DirectCallsToDefinedFunctions; - } - if (I.getOpcode() == Instruction::Load) { - ++FPI.LoadInstCount; - } else if (I.getOpcode() == Instruction::Store) { - ++FPI.StoreInstCount; - } + if (I.getOpcode() == Instruction::Load) { + LoadInstCount += Direction; + } else if (I.getOpcode() == Instruction::Store) { + StoreInstCount += Direction; } - // Loop Depth of the Basic Block - int64_t LoopDepth; - LoopDepth = LI.getLoopDepth(&BB); - if (FPI.MaxLoopDepth < LoopDepth) - FPI.MaxLoopDepth = LoopDepth; } - FPI.TopLevelLoopCount += llvm::size(LI); + TotalInstructionCount += Direction * BB.sizeWithoutDebug(); +} + +void FunctionPropertiesInfo::updateAggregateStats(const Function &F, + const LoopInfo &LI) { + + Uses = getUses(F); + TopLevelLoopCount = llvm::size(LI); + MaxLoopDepth = 0; + std::deque<const Loop *> Worklist; + llvm::append_range(Worklist, LI); + while (!Worklist.empty()) { + const auto *L = Worklist.front(); + MaxLoopDepth = + std::max(MaxLoopDepth, static_cast<int64_t>(L->getLoopDepth())); + Worklist.pop_front(); + llvm::append_range(Worklist, L->getSubLoops()); + } +} + +FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo( + const Function &F, FunctionAnalysisManager &FAM) { + + 
FunctionPropertiesInfo FPI; + // The const casts are due to the getResult API - there's no mutation of F. + const auto &LI = FAM.getResult<LoopAnalysis>(const_cast<Function &>(F)); + const auto &DT = + FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F)); + for (const auto &BB : F) + if (DT.isReachableFromEntry(&BB)) + FPI.reIncludeBB(BB); + FPI.updateAggregateStats(F, LI); return FPI; } @@ -67,15 +106,15 @@ void FunctionPropertiesInfo::print(raw_ostream &OS) const { << "LoadInstCount: " << LoadInstCount << "\n" << "StoreInstCount: " << StoreInstCount << "\n" << "MaxLoopDepth: " << MaxLoopDepth << "\n" - << "TopLevelLoopCount: " << TopLevelLoopCount << "\n\n"; + << "TopLevelLoopCount: " << TopLevelLoopCount << "\n" + << "TotalInstructionCount: " << TotalInstructionCount << "\n\n"; } AnalysisKey FunctionPropertiesAnalysis::Key; FunctionPropertiesInfo FunctionPropertiesAnalysis::run(Function &F, FunctionAnalysisManager &FAM) { - return FunctionPropertiesInfo::getFunctionPropertiesInfo( - F, FAM.getResult<LoopAnalysis>(F)); + return FunctionPropertiesInfo::getFunctionPropertiesInfo(F, FAM); } PreservedAnalyses @@ -86,3 +125,127 @@ FunctionPropertiesPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { AM.getResult<FunctionPropertiesAnalysis>(F).print(OS); return PreservedAnalyses::all(); } + +FunctionPropertiesUpdater::FunctionPropertiesUpdater( + FunctionPropertiesInfo &FPI, const CallBase &CB) + : FPI(FPI), CallSiteBB(*CB.getParent()), Caller(*CallSiteBB.getParent()) { + assert(isa<CallInst>(CB) || isa<InvokeInst>(CB)); + // For BBs that are likely to change, we subtract from feature totals their + // contribution. Some features, like max loop counts or depths, are left + // invalid, as they will be updated post-inlining. + SmallPtrSet<const BasicBlock *, 4> LikelyToChangeBBs; + // The CB BB will change - it'll either be split or the callee's body (single + // BB) will be pasted in. + LikelyToChangeBBs.insert(&CallSiteBB); + + // The caller's entry BB may change due to new alloca instructions. + LikelyToChangeBBs.insert(&*Caller.begin()); + + // The successors may become unreachable in the case of `invoke` inlining. + // We track successors separately, too, because they form a boundary, together + // with the CB BB ('Entry') between which the inlined callee will be pasted. + Successors.insert(succ_begin(&CallSiteBB), succ_end(&CallSiteBB)); + + // Inlining only handles invoke and calls. If this is an invoke, and inlining + // it pulls another invoke, the original landing pad may get split, so as to + // share its content with other potential users. So the edge up to which we + // need to invalidate and then re-account BB data is the successors of the + // current landing pad. We can leave the current lp, too - if it doesn't get + // split, then it will be the place traversal stops. Either way, the + // discounted BBs will be checked if reachable and re-added. + if (const auto *II = dyn_cast<InvokeInst>(&CB)) { + const auto *UnwindDest = II->getUnwindDest(); + Successors.insert(succ_begin(UnwindDest), succ_end(UnwindDest)); + } + + // Exclude the CallSiteBB, if it happens to be its own successor (1-BB loop). + // We are only interested in BBs the graph moves past the callsite BB to + // define the frontier past which we don't want to re-process BBs. Including + // the callsite BB in this case would prematurely stop the traversal in + // finish(). + Successors.erase(&CallSiteBB); + + for (const auto *BB : Successors) + LikelyToChangeBBs.insert(BB); + + // Commit the change. 
While some of the BBs accounted for above may play dual + // role - e.g. caller's entry BB may be the same as the callsite BB - set + // insertion semantics make sure we account them once. This needs to be + // followed in `finish`, too. + for (const auto *BB : LikelyToChangeBBs) + FPI.updateForBB(*BB, -1); +} + +void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const { + // Update feature values from the BBs that were copied from the callee, or + // might have been modified because of inlining. The latter have been + // subtracted in the FunctionPropertiesUpdater ctor. + // There could be successors that were reached before but now are only + // reachable from elsewhere in the CFG. + // One example is the following diamond CFG (lines are arrows pointing down): + // A + // / \ + // B C + // | | + // | D + // | | + // | E + // \ / + // F + // There's a call site in C that is inlined. Upon doing that, it turns out + // it expands to + // call void @llvm.trap() + // unreachable + // F isn't reachable from C anymore, but we did discount it when we set up + // FunctionPropertiesUpdater, so we need to re-include it here. + // At the same time, D and E were reachable before, but now are not anymore, + // so we need to leave D out (we discounted it at setup), and explicitly + // remove E. + SetVector<const BasicBlock *> Reinclude; + SetVector<const BasicBlock *> Unreachable; + const auto &DT = + FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(Caller)); + + if (&CallSiteBB != &*Caller.begin()) + Reinclude.insert(&*Caller.begin()); + + // Distribute the successors to the 2 buckets. + for (const auto *Succ : Successors) + if (DT.isReachableFromEntry(Succ)) + Reinclude.insert(Succ); + else + Unreachable.insert(Succ); + + // For reinclusion, we want to stop at the reachable successors, who are at + // the beginning of the worklist; but, starting from the callsite bb and + // ending at those successors, we also want to perform a traversal. + // IncludeSuccessorsMark is the index after which we include successors. + const auto IncludeSuccessorsMark = Reinclude.size(); + bool CSInsertion = Reinclude.insert(&CallSiteBB); + (void)CSInsertion; + assert(CSInsertion); + for (size_t I = 0; I < Reinclude.size(); ++I) { + const auto *BB = Reinclude[I]; + FPI.reIncludeBB(*BB); + if (I >= IncludeSuccessorsMark) + Reinclude.insert(succ_begin(BB), succ_end(BB)); + } + + // For exclusion, we don't need to exclude the set of BBs that were successors + // before and are now unreachable, because we already did that at setup. For + // the rest, as long as a successor is unreachable, we want to explicitly + // exclude it. 
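For reference, the ctor/finish() pair above implements a delta update over per-block statistics: the constructor subtracts the contribution of every block inlining is likely to rewrite, and finish() re-adds whatever the rebuilt dominator tree still proves reachable. A minimal sketch of that pattern, using a single hypothetical counter (BlockStats, reincludeReachable are illustrative names, not the in-tree API):

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"

// Illustrative stand-in for FunctionPropertiesInfo: one counter, updated
// with Direction = -1 to discount a block and +1 to re-include it.
struct BlockStats {
  int64_t InstrCount = 0;
  void updateForBB(const llvm::BasicBlock &BB, int64_t Direction) {
    InstrCount += Direction * int64_t(BB.sizeWithoutDebug());
  }
};

// After inlining, only blocks the fresh dominator tree can still reach get
// their contribution added back; newly dead blocks stay discounted.
void reincludeReachable(BlockStats &Stats, const llvm::Function &F,
                        const llvm::DominatorTree &DT) {
  for (const llvm::BasicBlock &BB : F)
    if (DT.isReachableFromEntry(&BB))
      Stats.updateForBB(BB, +1);
}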
+ const auto AlreadyExcludedMark = Unreachable.size(); + for (size_t I = 0; I < Unreachable.size(); ++I) { + const auto *U = Unreachable[I]; + if (I >= AlreadyExcludedMark) + FPI.updateForBB(*U, -1); + for (const auto *Succ : successors(U)) + if (!DT.isReachableFromEntry(Succ)) + Unreachable.insert(Succ); + } + + const auto &LI = FAM.getResult<LoopAnalysis>(const_cast<Function &>(Caller)); + FPI.updateAggregateStats(Caller, LI); + assert(FPI == FunctionPropertiesInfo::getFunctionPropertiesInfo(Caller, FAM)); +} diff --git a/llvm/lib/Analysis/GlobalsModRef.cpp b/llvm/lib/Analysis/GlobalsModRef.cpp index 6869530148c5..e82d2fae9356 100644 --- a/llvm/lib/Analysis/GlobalsModRef.cpp +++ b/llvm/lib/Analysis/GlobalsModRef.cpp @@ -21,11 +21,11 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -68,8 +68,8 @@ class GlobalsAAResult::FunctionInfo { /// should provide this much alignment at least, but this makes it clear we /// specifically rely on this amount of alignment. struct alignas(8) AlignedMap { - AlignedMap() {} - AlignedMap(const AlignedMap &Arg) : Map(Arg.Map) {} + AlignedMap() = default; + AlignedMap(const AlignedMap &Arg) = default; GlobalInfoMapType Map; }; @@ -102,7 +102,7 @@ class GlobalsAAResult::FunctionInfo { "Insufficient low bits to store our flag and ModRef info."); public: - FunctionInfo() {} + FunctionInfo() = default; ~FunctionInfo() { delete Info.getPointer(); } @@ -511,6 +511,18 @@ void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) { Handles.front().I = Handles.begin(); bool KnowNothing = false; + // Intrinsics, like any other synchronizing function, can make effects + // of other threads visible. Without nosync we know nothing really. + // Similarly, if `nocallback` is missing the function, or intrinsic, + // can call into the module arbitrarily. If both are set the function + // has an effect but will not interact with accesses of internal + // globals inside the module. We are conservative here for optnone + // functions, might not be necessary. + auto MaySyncOrCallIntoModule = [](const Function &F) { + return !F.isDeclaration() || !F.hasNoSync() || + !F.hasFnAttribute(Attribute::NoCallback); + }; + // Collect the mod/ref properties due to called functions. We only compute // one mod-ref set. for (unsigned i = 0, e = SCC.size(); i != e && !KnowNothing; ++i) { @@ -525,7 +537,7 @@ void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) { // Can't do better than that! } else if (F->onlyReadsMemory()) { FI.addModRefInfo(ModRefInfo::Ref); - if (!F->isIntrinsic() && !F->onlyAccessesArgMemory()) + if (!F->onlyAccessesArgMemory() && MaySyncOrCallIntoModule(*F)) // This function might call back into the module and read a global - // consider every global as possibly being read by this function. 
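The MaySyncOrCallIntoModule lambda above encodes when GlobalsModRef must give up on a callee: a declaration may be looked through only if it is both nosync (cannot make other threads' writes visible) and nocallback (cannot re-enter the module). Restated as a free function for clarity (a sketch, not the in-tree API):

#include "llvm/IR/Function.h"

// Sketch: true when the callee can synchronize with other threads or call
// back into the current module, i.e. when per-global reasoning is unsafe.
bool maySyncOrCallIntoModule(const llvm::Function &F) {
  return !F.isDeclaration() || !F.hasNoSync() ||
         !F.hasFnAttribute(llvm::Attribute::NoCallback);
}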
FI.setMayReadAnyGlobal(); @@ -533,7 +545,7 @@ void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) { FI.addModRefInfo(ModRefInfo::ModRef); if (!F->onlyAccessesArgMemory()) FI.setMayReadAnyGlobal(); - if (!F->isIntrinsic()) { + if (MaySyncOrCallIntoModule(*F)) { KnowNothing = true; break; } @@ -585,12 +597,7 @@ void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) { // We handle calls specially because the graph-relevant aspects are // handled above. if (auto *Call = dyn_cast<CallBase>(&I)) { - auto &TLI = GetTLI(*Node->getFunction()); - if (isAllocationFn(Call, &TLI) || isFreeCall(Call, &TLI)) { - // FIXME: It is completely unclear why this is necessary and not - // handled by the above graph code. - FI.addModRefInfo(ModRefInfo::ModRef); - } else if (Function *Callee = Call->getCalledFunction()) { + if (Function *Callee = Call->getCalledFunction()) { // The callgraph doesn't include intrinsic calls. if (Callee->isIntrinsic()) { if (isa<DbgInfoIntrinsic>(Call)) @@ -979,7 +986,7 @@ GlobalsAAResult::GlobalsAAResult(GlobalsAAResult &&Arg) } } -GlobalsAAResult::~GlobalsAAResult() {} +GlobalsAAResult::~GlobalsAAResult() = default; /*static*/ GlobalsAAResult GlobalsAAResult::analyzeModule( Module &M, std::function<const TargetLibraryInfo &(Function &F)> GetTLI, @@ -1010,6 +1017,24 @@ GlobalsAAResult GlobalsAA::run(Module &M, ModuleAnalysisManager &AM) { AM.getResult<CallGraphAnalysis>(M)); } +PreservedAnalyses RecomputeGlobalsAAPass::run(Module &M, + ModuleAnalysisManager &AM) { + if (auto *G = AM.getCachedResult<GlobalsAA>(M)) { + auto &CG = AM.getResult<CallGraphAnalysis>(M); + G->NonAddressTakenGlobals.clear(); + G->UnknownFunctionsWithLocalLinkage = false; + G->IndirectGlobals.clear(); + G->AllocsForIndirectGlobals.clear(); + G->FunctionInfos.clear(); + G->FunctionToSCCMap.clear(); + G->Handles.clear(); + G->CollectSCCMembership(CG); + G->AnalyzeGlobals(M); + G->AnalyzeCallGraph(CG, M); + } + return PreservedAnalyses::all(); +} + char GlobalsAAWrapperPass::ID = 0; INITIALIZE_PASS_BEGIN(GlobalsAAWrapperPass, "globals-aa", "Globals Alias Analysis", false, true) diff --git a/llvm/lib/Analysis/IRSimilarityIdentifier.cpp b/llvm/lib/Analysis/IRSimilarityIdentifier.cpp index 01681c47418a..3d51042f4da8 100644 --- a/llvm/lib/Analysis/IRSimilarityIdentifier.cpp +++ b/llvm/lib/Analysis/IRSimilarityIdentifier.cpp @@ -64,7 +64,7 @@ void IRInstructionData::initializeInstruction() { // Here we collect the operands and their types for determining whether // the structure of the operand use matches between two different candidates. for (Use &OI : Inst->operands()) { - if (isa<CmpInst>(Inst) && RevisedPredicate.hasValue()) { + if (isa<CmpInst>(Inst) && RevisedPredicate) { // If we have a CmpInst where the predicate is reversed, it means the // operands must be reversed as well. 
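RevisedPredicate exists because similarity detection canonicalizes compares: `a > b` and `b < a` are the same operation once the predicate is swapped, but then the operand order must be swapped as well, which is what the insert-at-front above does. A hedged sketch of that canonicalization (the helper name and tie-break rule are illustrative):

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/InstrTypes.h"

// Sketch: pick a canonical (predicate, operands) pair so that compares
// differing only by operand order receive the same encoding.
llvm::CmpInst::Predicate
canonicalPredicate(llvm::CmpInst &CI,
                   llvm::SmallVectorImpl<llvm::Value *> &Ops) {
  llvm::CmpInst::Predicate P = CI.getPredicate();
  llvm::CmpInst::Predicate Swapped = llvm::CmpInst::getSwappedPredicate(P);
  if (Swapped < P) { // arbitrary but deterministic tie-break
    Ops.push_back(CI.getOperand(1)); // reversed to match the swap
    Ops.push_back(CI.getOperand(0));
    return Swapped;
  }
  Ops.push_back(CI.getOperand(0));
  Ops.push_back(CI.getOperand(1));
  return P;
}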
OperVals.insert(OperVals.begin(), OI.get()); @@ -183,7 +183,7 @@ CmpInst::Predicate IRInstructionData::getPredicate() const { assert(isa<CmpInst>(Inst) && "Can only get a predicate from a compare instruction"); - if (RevisedPredicate.hasValue()) + if (RevisedPredicate) return RevisedPredicate.getValue(); return cast<CmpInst>(Inst)->getPredicate(); @@ -193,7 +193,7 @@ StringRef IRInstructionData::getCalleeName() const { assert(isa<CallInst>(Inst) && "Can only get a name from a call instruction"); - assert(CalleeName.hasValue() && "CalleeName has not been set"); + assert(CalleeName && "CalleeName has not been set"); return *CalleeName; } @@ -289,14 +289,12 @@ void IRInstructionMapper::convertToUnsignedVec( } } - if (HaveLegalRange) { - if (AddedIllegalLastTime) - mapToIllegalUnsigned(It, IntegerMappingForBB, InstrListForBB, true); - for (IRInstructionData *ID : InstrListForBB) - this->IDL->push_back(*ID); - llvm::append_range(InstrList, InstrListForBB); - llvm::append_range(IntegerMapping, IntegerMappingForBB); - } + if (AddedIllegalLastTime) + mapToIllegalUnsigned(It, IntegerMappingForBB, InstrListForBB, true); + for (IRInstructionData *ID : InstrListForBB) + this->IDL->push_back(*ID); + llvm::append_range(InstrList, InstrListForBB); + llvm::append_range(IntegerMapping, IntegerMappingForBB); } // TODO: This is the same as the MachineOutliner, and should be consolidated @@ -461,6 +459,18 @@ IRSimilarityCandidate::IRSimilarityCandidate(unsigned StartIdx, unsigned Len, // that both of these instructions are not nullptrs. FirstInst = FirstInstIt; LastInst = LastInstIt; + + // Add the basic blocks contained in the set into the global value numbering. + DenseSet<BasicBlock *> BBSet; + getBasicBlocks(BBSet); + for (BasicBlock *BB : BBSet) { + if (ValueToNumber.find(BB) != ValueToNumber.end()) + continue; + + ValueToNumber.try_emplace(BB, LocalValNumber); + NumberToValue.try_emplace(LocalValNumber, BB); + LocalValNumber++; + } } bool IRSimilarityCandidate::isSimilar(const IRSimilarityCandidate &A, @@ -516,19 +526,13 @@ static bool checkNumberingAndReplaceCommutative( for (Value *V : SourceOperands) { ArgVal = SourceValueToNumberMapping.find(V)->second; + // Instead of finding a current mapping, we attempt to insert a set. std::tie(ValueMappingIt, WasInserted) = CurrentSrcTgtNumberMapping.insert( std::make_pair(ArgVal, TargetValueNumbers)); - // Instead of finding a current mapping, we inserted a set. This means a - // mapping did not exist for the source Instruction operand, it has no - // current constraints we need to check. - if (WasInserted) - continue; - - // If a mapping already exists for the source operand to the values in the - // other IRSimilarityCandidate we need to iterate over the items in other - // IRSimilarityCandidate's Instruction to determine whether there is a valid - // mapping of Value to Value. + // We need to iterate over the items in other IRSimilarityCandidate's + // Instruction to determine whether there is a valid mapping of + // Value to Value. DenseSet<unsigned> NewSet; for (unsigned &Curr : ValueMappingIt->second) // If we can find the value in the mapping, we add it to the new set. @@ -548,7 +552,6 @@ static bool checkNumberingAndReplaceCommutative( if (ValueMappingIt->second.size() != 1) continue; - unsigned ValToRemove = *ValueMappingIt->second.begin(); // When there is only one item left in the mapping for and operand, remove // the value from the other operands. 
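Commutative operands make the value-numbering correspondence one-to-many at first sight: each further occurrence narrows the set of admissible target numbers by intersection, and the mapping stays viable only while every set is non-empty. A compact sketch of that refinement step (refineMapping is an illustrative name):

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"

// Sketch: refine the mapping of one source value number against the target
// numbers seen at the current instruction; false means the two candidates
// cannot be made to correspond.
bool refineMapping(
    llvm::DenseMap<unsigned, llvm::DenseSet<unsigned>> &Mapping,
    unsigned SrcNum, const llvm::DenseSet<unsigned> &TargetNums) {
  auto Res = Mapping.try_emplace(SrcNum, TargetNums);
  if (!Res.second) { // existing entry: keep only numbers seen both times
    llvm::DenseSet<unsigned> Narrowed;
    for (unsigned T : Res.first->second)
      if (TargetNums.contains(T))
        Narrowed.insert(T);
    Res.first->second = std::move(Narrowed);
  }
  return !Res.first->second.empty();
}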
If it results in there being no @@ -791,7 +794,8 @@ bool IRSimilarityCandidate::compareStructure( // We have different paths for commutative instructions and non-commutative // instructions since commutative instructions could allow multiple mappings // to certain values. - if (IA->isCommutative() && !isa<FPMathOperator>(IA)) { + if (IA->isCommutative() && !isa<FPMathOperator>(IA) && + !isa<IntrinsicInst>(IA)) { if (!compareCommutativeOperandMapping( {A, OperValsA, ValueNumberMappingA}, {B, OperValsB, ValueNumberMappingB})) @@ -1008,6 +1012,40 @@ void IRSimilarityCandidate::createCanonicalRelationFrom( CanonNumToNumber.insert(std::make_pair(CanonNum, SourceGVN)); NumberToCanonNum.insert(std::make_pair(SourceGVN, CanonNum)); } + + DenseSet<BasicBlock *> BBSet; + getBasicBlocks(BBSet); + // Find canonical numbers for the BasicBlocks in the current candidate. + // This is done by finding the corresponding value for the first instruction + // in the block in the current candidate, finding the matching value in the + // source candidate. Then by finding the parent of this value, use the + // canonical number of the block in the source candidate for the canonical + // number in the current candidate. + for (BasicBlock *BB : BBSet) { + unsigned BBGVNForCurrCand = ValueToNumber.find(BB)->second; + + // We can skip the BasicBlock if the canonical numbering has already been + // found in a separate instruction. + if (NumberToCanonNum.find(BBGVNForCurrCand) != NumberToCanonNum.end()) + continue; + + // If the basic block is the starting block, then the shared instruction may + // not be the first instruction in the block, it will be the first + // instruction in the similarity region. + Value *FirstOutlineInst = BB == getStartBB() + ? frontInstruction() + : &*BB->instructionsWithoutDebug().begin(); + + unsigned FirstInstGVN = *getGVN(FirstOutlineInst); + unsigned FirstInstCanonNum = *getCanonicalNum(FirstInstGVN); + unsigned SourceGVN = *SourceCand.fromCanonicalNum(FirstInstCanonNum); + Value *SourceV = *SourceCand.fromGVN(SourceGVN); + BasicBlock *SourceBB = cast<Instruction>(SourceV)->getParent(); + unsigned SourceBBGVN = *SourceCand.getGVN(SourceBB); + unsigned SourceCanonBBGVN = *SourceCand.getCanonicalNum(SourceBBGVN); + CanonNumToNumber.insert(std::make_pair(SourceCanonBBGVN, BBGVNForCurrCand)); + NumberToCanonNum.insert(std::make_pair(BBGVNForCurrCand, SourceCanonBBGVN)); + } } void IRSimilarityCandidate::createCanonicalMappingFor( @@ -1162,11 +1200,12 @@ SimilarityGroupList &IRSimilarityIdentifier::findSimilarity( Mapper.InstClassifier.EnableIndirectCalls = EnableIndirectCalls; Mapper.EnableMatchCallsByName = EnableMatchingCallsByName; Mapper.InstClassifier.EnableIntrinsics = EnableIntrinsics; + Mapper.InstClassifier.EnableMustTailCalls = EnableMustTailCalls; populateMapper(Modules, InstrList, IntegerMapping); findCandidates(InstrList, IntegerMapping); - return SimilarityCandidates.getValue(); + return *SimilarityCandidates; } SimilarityGroupList &IRSimilarityIdentifier::findSimilarity(Module &M) { @@ -1175,6 +1214,7 @@ SimilarityGroupList &IRSimilarityIdentifier::findSimilarity(Module &M) { Mapper.InstClassifier.EnableIndirectCalls = EnableIndirectCalls; Mapper.EnableMatchCallsByName = EnableMatchingCallsByName; Mapper.InstClassifier.EnableIntrinsics = EnableIntrinsics; + Mapper.InstClassifier.EnableMustTailCalls = EnableMustTailCalls; std::vector<IRInstructionData *> InstrList; std::vector<unsigned> IntegerMapping; @@ -1182,7 +1222,7 @@ SimilarityGroupList 
&IRSimilarityIdentifier::findSimilarity(Module &M) { populateMapper(M, InstrList, IntegerMapping); findCandidates(InstrList, IntegerMapping); - return SimilarityCandidates.getValue(); + return *SimilarityCandidates; } INITIALIZE_PASS(IRSimilarityIdentifierWrapperPass, "ir-similarity-identifier", @@ -1196,7 +1236,8 @@ IRSimilarityIdentifierWrapperPass::IRSimilarityIdentifierWrapperPass() bool IRSimilarityIdentifierWrapperPass::doInitialization(Module &M) { IRSI.reset(new IRSimilarityIdentifier(!DisableBranches, !DisableIndirectCalls, - MatchCallsByName, !DisableIntrinsics)); + MatchCallsByName, !DisableIntrinsics, + false)); return false; } @@ -1214,7 +1255,8 @@ AnalysisKey IRSimilarityAnalysis::Key; IRSimilarityIdentifier IRSimilarityAnalysis::run(Module &M, ModuleAnalysisManager &) { auto IRSI = IRSimilarityIdentifier(!DisableBranches, !DisableIndirectCalls, - MatchCallsByName, !DisableIntrinsics); + MatchCallsByName, !DisableIntrinsics, + false); IRSI.findSimilarity(M); return IRSI; } diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp index 44b1d94ebdc8..e4d706ab045c 100644 --- a/llvm/lib/Analysis/IVDescriptors.cpp +++ b/llvm/lib/Analysis/IVDescriptors.cpp @@ -11,26 +11,16 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/IVDescriptors.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/DemandedBits.h" -#include "llvm/Analysis/DomTreeUpdater.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" @@ -237,12 +227,10 @@ static bool checkOrderedReduction(RecurKind Kind, Instruction *ExactFPMathInst, return true; } -bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind, - Loop *TheLoop, FastMathFlags FuncFMF, - RecurrenceDescriptor &RedDes, - DemandedBits *DB, - AssumptionCache *AC, - DominatorTree *DT) { +bool RecurrenceDescriptor::AddReductionVar( + PHINode *Phi, RecurKind Kind, Loop *TheLoop, FastMathFlags FuncFMF, + RecurrenceDescriptor &RedDes, DemandedBits *DB, AssumptionCache *AC, + DominatorTree *DT, ScalarEvolution *SE) { if (Phi->getNumIncomingValues() != 2) return false; @@ -259,6 +247,12 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind, // This includes users of the reduction, variables (which form a cycle // which ends in the phi node). Instruction *ExitInstruction = nullptr; + + // Variable to keep last visited store instruction. By the end of the + // algorithm this variable will be either empty or having intermediate + // reduction value stored in invariant address. + StoreInst *IntermediateStore = nullptr; + // Indicates that we found a reduction operation in our scan. bool FoundReduxOp = false; @@ -324,6 +318,10 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind, // - By instructions outside of the loop (safe). 
// * One value may have several outside users, but all outside // uses must be of the same value. + // - By store instructions with a loop invariant address (safe with + // the following restrictions): + // * If there are several stores, all must have the same address. + // * Final value should be stored in that loop invariant address. // - By an instruction that is not part of the reduction (not safe). // This is either: // * An instruction type other than PHI or the reduction operation. @@ -331,6 +329,43 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind, while (!Worklist.empty()) { Instruction *Cur = Worklist.pop_back_val(); + // Store instructions are allowed iff it is the store of the reduction + // value to the same loop invariant memory location. + if (auto *SI = dyn_cast<StoreInst>(Cur)) { + if (!SE) { + LLVM_DEBUG(dbgs() << "Store instructions are not processed without " + << "Scalar Evolution Analysis\n"); + return false; + } + + const SCEV *PtrScev = SE->getSCEV(SI->getPointerOperand()); + // Check it is the same address as previous stores + if (IntermediateStore) { + const SCEV *OtherScev = + SE->getSCEV(IntermediateStore->getPointerOperand()); + + if (OtherScev != PtrScev) { + LLVM_DEBUG(dbgs() << "Storing reduction value to different addresses " + << "inside the loop: " << *SI->getPointerOperand() + << " and " + << *IntermediateStore->getPointerOperand() << '\n'); + return false; + } + } + + // Check the pointer is loop invariant + if (!SE->isLoopInvariant(PtrScev, TheLoop)) { + LLVM_DEBUG(dbgs() << "Storing reduction value to non-uniform address " + << "inside the loop: " << *SI->getPointerOperand() + << '\n'); + return false; + } + + // IntermediateStore is always the last store in the loop. + IntermediateStore = SI; + continue; + } + // No Users. // If the instruction has no users then this is a broken chain and can't be // a reduction variable. @@ -453,10 +488,17 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind, // reductions which are represented as a cmp followed by a select. InstDesc IgnoredVal(false, nullptr); if (VisitedInsts.insert(UI).second) { - if (isa<PHINode>(UI)) + if (isa<PHINode>(UI)) { PHIs.push_back(UI); - else + } else { + StoreInst *SI = dyn_cast<StoreInst>(UI); + if (SI && SI->getPointerOperand() == Cur) { + // Reduction variable chain can only be stored somewhere but it + // can't be used as an address. + return false; + } NonPHIs.push_back(UI); + } } else if (!isa<PHINode>(UI) && ((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) && !isa<SelectInst>(UI)) || @@ -476,7 +518,7 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind, // This means we have seen one but not the other instruction of the // pattern or more than just a select and cmp. Zero implies that we saw a - // llvm.min/max instrinsic, which is always OK. + // llvm.min/max intrinsic, which is always OK. if (isMinMaxRecurrenceKind(Kind) && NumCmpSelectPatternInst != 2 && NumCmpSelectPatternInst != 0) return false; @@ -484,6 +526,32 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind, if (isSelectCmpRecurrenceKind(Kind) && NumCmpSelectPatternInst != 1) return false; + if (IntermediateStore) { + // Check that stored value goes to the phi node again. This way we make sure + // that the value stored in IntermediateStore is indeed the final reduction + // value. 
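The store handling added in this hunk boils down to two ScalarEvolution queries per store plus bookkeeping of the last store seen; a hedged sketch of those checks, pulled out of the AddReductionVar worklist loop for readability (the function name is illustrative):

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instructions.h"

// Sketch: accept a store on the reduction chain only if every such store
// targets one loop-invariant address; Last ends up holding the candidate
// store of the final reduction value.
bool recordIntermediateStore(llvm::StoreInst *SI, llvm::StoreInst *&Last,
                             llvm::ScalarEvolution &SE, llvm::Loop *L) {
  const llvm::SCEV *Ptr = SE.getSCEV(SI->getPointerOperand());
  if (Last && SE.getSCEV(Last->getPointerOperand()) != Ptr)
    return false; // reduction value stored to two different addresses
  if (!SE.isLoopInvariant(Ptr, L))
    return false; // address varies per iteration
  Last = SI;
  return true;
}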
+ if (!is_contained(Phi->operands(), IntermediateStore->getValueOperand())) { + LLVM_DEBUG(dbgs() << "Not a final reduction value stored: " + << *IntermediateStore << '\n'); + return false; + } + + // If there is an exit instruction it's value should be stored in + // IntermediateStore + if (ExitInstruction && + IntermediateStore->getValueOperand() != ExitInstruction) { + LLVM_DEBUG(dbgs() << "Last store Instruction of reduction value does not " + "store last calculated value of the reduction: " + << *IntermediateStore << '\n'); + return false; + } + + // If all uses are inside the loop (intermediate stores), then the + // reduction value after the loop will be the one used in the last store. + if (!ExitInstruction) + ExitInstruction = cast<Instruction>(IntermediateStore->getValueOperand()); + } + if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction) return false; @@ -545,9 +613,9 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind, // is saved as part of the RecurrenceDescriptor. // Save the description of this reduction variable. - RecurrenceDescriptor RD(RdxStart, ExitInstruction, Kind, FMF, ExactFPMathInst, - RecurrenceType, IsSigned, IsOrdered, CastInsts, - MinWidthCastToRecurrenceType); + RecurrenceDescriptor RD(RdxStart, ExitInstruction, IntermediateStore, Kind, + FMF, ExactFPMathInst, RecurrenceType, IsSigned, + IsOrdered, CastInsts, MinWidthCastToRecurrenceType); RedDes = RD; return true; @@ -771,7 +839,8 @@ bool RecurrenceDescriptor::hasMultipleUsesOf( bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, RecurrenceDescriptor &RedDes, DemandedBits *DB, AssumptionCache *AC, - DominatorTree *DT) { + DominatorTree *DT, + ScalarEvolution *SE) { BasicBlock *Header = TheLoop->getHeader(); Function &F = *Header->getParent(); FastMathFlags FMF; @@ -780,72 +849,85 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, FMF.setNoSignedZeros( F.getFnAttribute("no-signed-zeros-fp-math").getValueAsBool()); - if (AddReductionVar(Phi, RecurKind::Add, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::Add, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::Mul, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::Mul, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::Or, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::Or, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::And, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::And, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::Xor, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::Xor, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found a XOR reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::SMax, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::SMax, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found a SMAX reduction PHI." 
<< *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::SMin, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::SMin, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found a SMIN reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::UMax, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::UMax, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found a UMAX reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::UMin, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::UMin, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found a UMIN reduction PHI." << *Phi << "\n"); return true; } if (AddReductionVar(Phi, RecurKind::SelectICmp, TheLoop, FMF, RedDes, DB, AC, - DT)) { + DT, SE)) { LLVM_DEBUG(dbgs() << "Found an integer conditional select reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::FAdd, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::FAdd, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::FMax, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::FMax, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found a float MAX reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::FMin, TheLoop, FMF, RedDes, DB, AC, DT)) { + if (AddReductionVar(Phi, RecurKind::FMin, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found a float MIN reduction PHI." << *Phi << "\n"); return true; } if (AddReductionVar(Phi, RecurKind::SelectFCmp, TheLoop, FMF, RedDes, DB, AC, - DT)) { + DT, SE)) { LLVM_DEBUG(dbgs() << "Found a float conditional select reduction PHI." << " PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::FMulAdd, TheLoop, FMF, RedDes, DB, AC, - DT)) { + if (AddReductionVar(Phi, RecurKind::FMulAdd, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { LLVM_DEBUG(dbgs() << "Found an FMulAdd reduction PHI." << *Phi << "\n"); return true; } @@ -917,12 +999,37 @@ bool RecurrenceDescriptor::isFirstOrderRecurrence( SinkCandidate->mayReadFromMemory() || SinkCandidate->isTerminator()) return false; - // Do not try to sink an instruction multiple times (if multiple operands - // are first order recurrences). - // TODO: We can support this case, by sinking the instruction after the - // 'deepest' previous instruction. - if (SinkAfter.find(SinkCandidate) != SinkAfter.end()) - return false; + // Avoid sinking an instruction multiple times (if multiple operands are + // first order recurrences) by sinking once - after the latest 'previous' + // instruction. + auto It = SinkAfter.find(SinkCandidate); + if (It != SinkAfter.end()) { + auto *OtherPrev = It->second; + // Find the earliest entry in the 'sink-after' chain. The last entry in + // the chain is the original 'Previous' for a recurrence handled earlier. 
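Following the sink-after chain can be pictured as chasing map entries back to the earliest 'previous' instruction, bailing out the moment dominance order is violated. A hedged sketch of that walk (earliestPrevious is an illustrative name; the in-tree code also decides between keeping and re-sinking the candidate):

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"

// Sketch: chase SinkAfter links back to the earliest 'previous'
// instruction; nullptr signals that program order was not preserved and
// the caller must refuse to sink.
llvm::Instruction *earliestPrevious(
    llvm::Instruction *Prev,
    const llvm::DenseMap<llvm::Instruction *, llvm::Instruction *> &SinkAfter,
    const llvm::DominatorTree &DT) {
  for (auto It = SinkAfter.find(Prev); It != SinkAfter.end();
       It = SinkAfter.find(Prev)) {
    llvm::Instruction *Earlier = It->second;
    if (!DT.dominates(Earlier, Prev))
      return nullptr;
    Prev = Earlier;
  }
  return Prev;
}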
+ auto EarlierIt = SinkAfter.find(OtherPrev); + while (EarlierIt != SinkAfter.end()) { + Instruction *EarlierInst = EarlierIt->second; + EarlierIt = SinkAfter.find(EarlierInst); + // Bail out if order has not been preserved. + if (EarlierIt != SinkAfter.end() && + !DT->dominates(EarlierInst, OtherPrev)) + return false; + OtherPrev = EarlierInst; + } + // Bail out if order has not been preserved. + if (OtherPrev != It->second && !DT->dominates(It->second, OtherPrev)) + return false; + + // SinkCandidate is already being sunk after an instruction after + // Previous. Nothing left to do. + if (DT->dominates(Previous, OtherPrev) || Previous == OtherPrev) + return true; + // Otherwise, Previous comes after OtherPrev and SinkCandidate needs to be + // re-sunk to Previous, instead of sinking to OtherPrev. Remove + // SinkCandidate from SinkAfter to ensure it's insert position is updated. + SinkAfter.erase(SinkCandidate); + } // If we reach a PHI node that is not dominated by Previous, we reached a // header PHI. No need for sinking. @@ -1052,7 +1159,7 @@ RecurrenceDescriptor::getReductionOpChain(PHINode *Phi, Loop *L) const { // to check for a pair of icmp/select, for which we use getNextInstruction and // isCorrectOpcode functions to step the right number of instruction, and // check the icmp/select pair. - // FIXME: We also do not attempt to look through Phi/Select's yet, which might + // FIXME: We also do not attempt to look through Select's yet, which might // be part of the reduction chain, or attempt to looks through And's to find a // smaller bitwidth. Subs are also currently not allowed (which are usually // treated as part of a add reduction) as they are expected to generally be @@ -1062,16 +1169,21 @@ RecurrenceDescriptor::getReductionOpChain(PHINode *Phi, Loop *L) const { if (RedOp == Instruction::ICmp || RedOp == Instruction::FCmp) ExpectedUses = 2; - auto getNextInstruction = [&](Instruction *Cur) { - if (RedOp == Instruction::ICmp || RedOp == Instruction::FCmp) { - // We are expecting a icmp/select pair, which we go to the next select - // instruction if we can. We already know that Cur has 2 uses. - if (isa<SelectInst>(*Cur->user_begin())) - return cast<Instruction>(*Cur->user_begin()); - else - return cast<Instruction>(*std::next(Cur->user_begin())); + auto getNextInstruction = [&](Instruction *Cur) -> Instruction * { + for (auto User : Cur->users()) { + Instruction *UI = cast<Instruction>(User); + if (isa<PHINode>(UI)) + continue; + if (RedOp == Instruction::ICmp || RedOp == Instruction::FCmp) { + // We are expecting a icmp/select pair, which we go to the next select + // instruction if we can. We already know that Cur has 2 uses. 
+ if (isa<SelectInst>(UI)) + return UI; + continue; + } + return UI; } - return cast<Instruction>(*Cur->user_begin()); + return nullptr; }; auto isCorrectOpcode = [&](Instruction *Cur) { if (RedOp == Instruction::ICmp || RedOp == Instruction::FCmp) { @@ -1086,22 +1198,46 @@ RecurrenceDescriptor::getReductionOpChain(PHINode *Phi, Loop *L) const { return Cur->getOpcode() == RedOp; }; + // Attempt to look through Phis which are part of the reduction chain + unsigned ExtraPhiUses = 0; + Instruction *RdxInstr = LoopExitInstr; + if (auto ExitPhi = dyn_cast<PHINode>(LoopExitInstr)) { + if (ExitPhi->getNumIncomingValues() != 2) + return {}; + + Instruction *Inc0 = dyn_cast<Instruction>(ExitPhi->getIncomingValue(0)); + Instruction *Inc1 = dyn_cast<Instruction>(ExitPhi->getIncomingValue(1)); + + Instruction *Chain = nullptr; + if (Inc0 == Phi) + Chain = Inc1; + else if (Inc1 == Phi) + Chain = Inc0; + else + return {}; + + RdxInstr = Chain; + ExtraPhiUses = 1; + } + // The loop exit instruction we check first (as a quick test) but add last. We // check the opcode is correct (and dont allow them to be Subs) and that they // have expected to have the expected number of uses. They will have one use // from the phi and one from a LCSSA value, no matter the type. - if (!isCorrectOpcode(LoopExitInstr) || !LoopExitInstr->hasNUses(2)) + if (!isCorrectOpcode(RdxInstr) || !LoopExitInstr->hasNUses(2)) return {}; - // Check that the Phi has one (or two for min/max) uses. - if (!Phi->hasNUses(ExpectedUses)) + // Check that the Phi has one (or two for min/max) uses, plus an extra use + // for conditional reductions. + if (!Phi->hasNUses(ExpectedUses + ExtraPhiUses)) return {}; + Instruction *Cur = getNextInstruction(Phi); // Each other instruction in the chain should have the expected number of uses // and be the correct opcode. - while (Cur != LoopExitInstr) { - if (!isCorrectOpcode(Cur) || !Cur->hasNUses(ExpectedUses)) + while (Cur != RdxInstr) { + if (!Cur || !isCorrectOpcode(Cur) || !Cur->hasNUses(ExpectedUses)) return {}; ReductionOperations.push_back(Cur); @@ -1428,10 +1564,14 @@ bool InductionDescriptor::isInductionPHI( ConstantInt *CV = ConstStep->getValue(); const DataLayout &DL = Phi->getModule()->getDataLayout(); - int64_t Size = static_cast<int64_t>(DL.getTypeAllocSize(ElementType)); - if (!Size) + TypeSize TySize = DL.getTypeAllocSize(ElementType); + // TODO: We could potentially support this for scalable vectors if we can + // prove at compile time that the constant step is always a multiple of + // the scalable type. 
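The guard introduced below exists because getTypeAllocSize returns a TypeSize: for a scalable vector such as <vscale x 4 x i32> the byte size is a multiple of a runtime-only vscale, so the divisibility test against the constant step cannot be evaluated at compile time. A sketch of the fixed-size-only check (the wrapper function is illustrative):

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Type.h"

// Sketch: the step divides evenly only when the element size is a known,
// non-zero compile-time constant.
bool stepIsMultipleOfElementSize(const llvm::DataLayout &DL,
                                 llvm::Type *ElemTy, int64_t StepInBytes) {
  llvm::TypeSize TySize = DL.getTypeAllocSize(ElemTy);
  if (TySize.isZero() || TySize.isScalable())
    return false;
  return StepInBytes % int64_t(TySize.getFixedSize()) == 0;
}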
+ if (TySize.isZero() || TySize.isScalable()) return false; + int64_t Size = static_cast<int64_t>(TySize.getFixedSize()); int64_t CVSize = CV->getSExtValue(); if (CVSize % Size) return false; diff --git a/llvm/lib/Analysis/IVUsers.cpp b/llvm/lib/Analysis/IVUsers.cpp index 0f3929f45506..5bde947bd851 100644 --- a/llvm/lib/Analysis/IVUsers.cpp +++ b/llvm/lib/Analysis/IVUsers.cpp @@ -12,25 +12,21 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/IVUsers.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Config/llvm-config.h" -#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include <algorithm> using namespace llvm; #define DEBUG_TYPE "iv-users" diff --git a/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp b/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp index b112ed2e4439..ebfa1c8fc08e 100644 --- a/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp +++ b/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp @@ -13,12 +13,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/IndirectCallPromotionAnalysis.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Analysis/IndirectCallVisitor.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/InstVisitor.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Instruction.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -31,7 +26,7 @@ using namespace llvm; // The percent threshold for the direct-call target (this call site vs the // remaining call count) for it to be considered as the promotion target. static cl::opt<unsigned> ICPRemainingPercentThreshold( - "icp-remaining-percent-threshold", cl::init(30), cl::Hidden, cl::ZeroOrMore, + "icp-remaining-percent-threshold", cl::init(30), cl::Hidden, cl::desc("The percentage threshold against remaining unpromoted indirect " "call count for the promotion")); @@ -39,14 +34,14 @@ static cl::opt<unsigned> ICPRemainingPercentThreshold( // total call count) for it to be considered as the promotion target. static cl::opt<unsigned> ICPTotalPercentThreshold("icp-total-percent-threshold", cl::init(5), - cl::Hidden, cl::ZeroOrMore, + cl::Hidden, cl::desc("The percentage threshold against total " "count for the promotion")); // Set the maximum number of targets to promote for a single indirect-call // callsite. 
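The two percentage options above combine into a single profitability gate per candidate target; a hedged arithmetic sketch with the default values (30% of the not-yet-promoted remainder, 5% of the total count; the function name and exact rounding are illustrative):

#include <cstdint>

// Sketch: a target is worth promoting only if its count clears both
// percentage bars; defaults mirror the option initializers above.
bool worthPromoting(uint64_t Count, uint64_t TotalCount,
                    uint64_t RemainingCount, unsigned RemainingPct = 30,
                    unsigned TotalPct = 5) {
  return Count * 100 >= uint64_t(RemainingPct) * RemainingCount &&
         Count * 100 >= uint64_t(TotalPct) * TotalCount;
}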
static cl::opt<unsigned> - MaxNumPromotions("icp-max-prom", cl::init(3), cl::Hidden, cl::ZeroOrMore, + MaxNumPromotions("icp-max-prom", cl::init(3), cl::Hidden, cl::desc("Max number of promotions for a single indirect " "call callsite")); diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp index f6e3dd354ff8..cf8592c41eda 100644 --- a/llvm/lib/Analysis/InlineAdvisor.cpp +++ b/llvm/lib/Analysis/InlineAdvisor.cpp @@ -13,14 +13,15 @@ #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ReplayInlineAdvisor.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/Utils/ImportedFunctionsInliningStatistics.h" #include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" @@ -55,6 +56,11 @@ static cl::opt<int> cl::desc("Scale to limit the cost of inline deferral"), cl::init(2), cl::Hidden); +static cl::opt<bool> AnnotateInlinePhase( + "annotate-inline-phase", cl::Hidden, cl::init(false), + cl::desc("If true, annotate inline advisor remarks " + "with LTO and pass information.")); + extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats; namespace { @@ -80,7 +86,8 @@ private: void recordUnsuccessfulInliningImpl(const InlineResult &Result) override { if (IsInliningRecommended) ORE.emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) + return OptimizationRemarkMissed(Advisor->getAnnotatedInlinePassName(), + "NotInlined", DLoc, Block) << "'" << NV("Callee", Callee) << "' is not AlwaysInline into '" << NV("Caller", Caller) << "': " << NV("Reason", Result.getFailureReason()); @@ -99,7 +106,8 @@ void DefaultInlineAdvice::recordUnsuccessfulInliningImpl( llvm::setInlineRemark(*OriginalCB, std::string(Result.getFailureReason()) + "; " + inlineCostStr(*OIC)); ORE.emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) + return OptimizationRemarkMissed(Advisor->getAnnotatedInlinePassName(), + "NotInlined", DLoc, Block) << "'" << NV("Callee", Callee) << "' is not inlined into '" << NV("Caller", Caller) << "': " << NV("Reason", Result.getFailureReason()); @@ -108,12 +116,16 @@ void DefaultInlineAdvice::recordUnsuccessfulInliningImpl( void DefaultInlineAdvice::recordInliningWithCalleeDeletedImpl() { if (EmitRemarks) - emitInlinedIntoBasedOnCost(ORE, DLoc, Block, *Callee, *Caller, *OIC); + emitInlinedIntoBasedOnCost(ORE, DLoc, Block, *Callee, *Caller, *OIC, + /* ForProfileContext= */ false, + Advisor->getAnnotatedInlinePassName()); } void DefaultInlineAdvice::recordInliningImpl() { if (EmitRemarks) - emitInlinedIntoBasedOnCost(ORE, DLoc, Block, *Callee, *Caller, *OIC); + emitInlinedIntoBasedOnCost(ORE, DLoc, Block, *Callee, *Caller, *OIC, + /* ForProfileContext= */ false, + Advisor->getAnnotatedInlinePassName()); } llvm::Optional<llvm::InlineCost> static getDefaultInlineAdvice( @@ -146,7 +158,7 @@ llvm::Optional<llvm::InlineCost> static getDefaultInlineAdvice( }; return llvm::shouldInline( CB, GetInlineCost, ORE, - Params.EnableDeferral.getValueOr(EnableInlineDeferral)); + Params.EnableDeferral.value_or(EnableInlineDeferral)); } std::unique_ptr<InlineAdvice> @@ -185,18 +197,18 @@ AnalysisKey 
InlineAdvisorAnalysis::Key; bool InlineAdvisorAnalysis::Result::tryCreate( InlineParams Params, InliningAdvisorMode Mode, - const ReplayInlinerSettings &ReplaySettings) { + const ReplayInlinerSettings &ReplaySettings, InlineContext IC) { auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); switch (Mode) { case InliningAdvisorMode::Default: LLVM_DEBUG(dbgs() << "Using default inliner heuristic.\n"); - Advisor.reset(new DefaultInlineAdvisor(M, FAM, Params)); + Advisor.reset(new DefaultInlineAdvisor(M, FAM, Params, IC)); // Restrict replay to default advisor, ML advisors are stateful so // replay will need augmentations to interleave with them correctly. if (!ReplaySettings.ReplayFile.empty()) { Advisor = llvm::getReplayInlineAdvisor(M, FAM, M.getContext(), std::move(Advisor), ReplaySettings, - /* EmitRemarks =*/true); + /* EmitRemarks =*/true, IC); } break; case InliningAdvisorMode::Development: @@ -442,7 +454,7 @@ std::string llvm::formatCallSiteLocation(DebugLoc DLoc, } void llvm::addLocationToRemarks(OptimizationRemark &Remark, DebugLoc DLoc) { - if (!DLoc.get()) { + if (!DLoc) { return; } @@ -499,8 +511,11 @@ void llvm::emitInlinedIntoBasedOnCost( PassName); } -InlineAdvisor::InlineAdvisor(Module &M, FunctionAnalysisManager &FAM) - : M(M), FAM(FAM) { +InlineAdvisor::InlineAdvisor(Module &M, FunctionAnalysisManager &FAM, + Optional<InlineContext> IC) + : M(M), FAM(FAM), IC(IC), + AnnotatedInlinePassName((IC && AnnotateInlinePhase) ? llvm::AnnotateInlinePassName(*IC) + : DEBUG_TYPE) { if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) { ImportedFunctionsStats = std::make_unique<ImportedFunctionsInliningStatistics>(); @@ -522,6 +537,48 @@ std::unique_ptr<InlineAdvice> InlineAdvisor::getMandatoryAdvice(CallBase &CB, Advice); } +static inline const char *getLTOPhase(ThinOrFullLTOPhase LTOPhase) { + switch (LTOPhase) { + case (ThinOrFullLTOPhase::None): + return "main"; + case (ThinOrFullLTOPhase::ThinLTOPreLink): + case (ThinOrFullLTOPhase::FullLTOPreLink): + return "prelink"; + case (ThinOrFullLTOPhase::ThinLTOPostLink): + case (ThinOrFullLTOPhase::FullLTOPostLink): + return "postlink"; + } + llvm_unreachable("unreachable"); +} + +static inline const char *getInlineAdvisorContext(InlinePass IP) { + switch (IP) { + case (InlinePass::AlwaysInliner): + return "always-inline"; + case (InlinePass::CGSCCInliner): + return "cgscc-inline"; + case (InlinePass::EarlyInliner): + return "early-inline"; + case (InlinePass::MLInliner): + return "ml-inline"; + case (InlinePass::ModuleInliner): + return "module-inline"; + case (InlinePass::ReplayCGSCCInliner): + return "replay-cgscc-inline"; + case (InlinePass::ReplaySampleProfileInliner): + return "replay-sample-profile-inline"; + case (InlinePass::SampleProfileInliner): + return "sample-profile-inline"; + } + + llvm_unreachable("unreachable"); +} + +std::string llvm::AnnotateInlinePassName(InlineContext IC) { + return std::string(getLTOPhase(IC.LTOPhase)) + "-" + + std::string(getInlineAdvisorContext(IC.Pass)); +} + InlineAdvisor::MandatoryInliningKind InlineAdvisor::getMandatoryKind(CallBase &CB, FunctionAnalysisManager &FAM, OptimizationRemarkEmitter &ORE) { @@ -536,7 +593,7 @@ InlineAdvisor::getMandatoryKind(CallBase &CB, FunctionAnalysisManager &FAM, auto TrivialDecision = llvm::getAttributeBasedInliningDecision(CB, &Callee, TIR, GetTLI); - if (TrivialDecision.hasValue()) { + if (TrivialDecision) { if (TrivialDecision->isSuccess()) return MandatoryInliningKind::Always; else @@ -568,3 +625,22 @@ 
InlineAdvisorAnalysisPrinterPass::run(Module &M, ModuleAnalysisManager &MAM) { IA->getAdvisor()->print(OS); return PreservedAnalyses::all(); } + +PreservedAnalyses InlineAdvisorAnalysisPrinterPass::run( + LazyCallGraph::SCC &InitialC, CGSCCAnalysisManager &AM, LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + const auto &MAMProxy = + AM.getResult<ModuleAnalysisManagerCGSCCProxy>(InitialC, CG); + + if (InitialC.size() == 0) { + OS << "SCC is empty!\n"; + return PreservedAnalyses::all(); + } + Module &M = *InitialC.begin()->getFunction().getParent(); + const auto *IA = MAMProxy.getCachedResult<InlineAdvisorAnalysis>(M); + if (!IA) + OS << "No Inline Advisor\n"; + else + IA->getAdvisor()->print(OS); + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp index d5411d916c77..e63497260e6e 100644 --- a/llvm/lib/Analysis/InlineCost.cpp +++ b/llvm/lib/Analysis/InlineCost.cpp @@ -18,11 +18,11 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -42,6 +42,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/FormattedStream.h" #include "llvm/Support/raw_ostream.h" +#include <limits> using namespace llvm; @@ -51,24 +52,33 @@ STATISTIC(NumCallsAnalyzed, "Number of call sites analyzed"); static cl::opt<int> DefaultThreshold("inlinedefault-threshold", cl::Hidden, cl::init(225), - cl::ZeroOrMore, cl::desc("Default amount of inlining to perform")); +// We introduce this option since there is a minor compile-time win by avoiding +// addition of TTI attributes (target-features in particular) to inline +// candidates when they are guaranteed to be the same as top level methods in +// some use cases. If we avoid adding the attribute, we need an option to avoid +// checking these attributes. 
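The flag defined next short-circuits the target's caller/callee compatibility test; when left off, the query the inliner normally performs looks roughly like this (a sketch of the default gate only, not the full functionsHaveCompatibleAttributes logic):

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"

// Sketch: by default the target decides whether the callee's code may be
// pasted into the caller (e.g. matching target-features); the new flag
// skips that query entirely.
bool ttiInlineCompatible(const llvm::TargetTransformInfo &CallerTTI,
                         const llvm::Function *Caller,
                         const llvm::Function *Callee,
                         bool IgnoreTTICheck) {
  return IgnoreTTICheck || CallerTTI.areInlineCompatible(Caller, Callee);
}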
+static cl::opt<bool> IgnoreTTIInlineCompatible( + "ignore-tti-inline-compatible", cl::Hidden, cl::init(false), + cl::desc("Ignore TTI attributes compatibility check between callee/caller " + "during inline cost calculation")); + static cl::opt<bool> PrintInstructionComments( "print-instruction-comments", cl::Hidden, cl::init(false), cl::desc("Prints comments for instruction based on inline cost analysis")); static cl::opt<int> InlineThreshold( - "inline-threshold", cl::Hidden, cl::init(225), cl::ZeroOrMore, + "inline-threshold", cl::Hidden, cl::init(225), cl::desc("Control the amount of inlining to perform (default = 225)")); static cl::opt<int> HintThreshold( - "inlinehint-threshold", cl::Hidden, cl::init(325), cl::ZeroOrMore, + "inlinehint-threshold", cl::Hidden, cl::init(325), cl::desc("Threshold for inlining functions with inline hint")); static cl::opt<int> ColdCallSiteThreshold("inline-cold-callsite-threshold", cl::Hidden, - cl::init(45), cl::ZeroOrMore, + cl::init(45), cl::desc("Threshold for inlining cold callsites")); static cl::opt<bool> InlineEnableCostBenefitAnalysis( @@ -76,12 +86,11 @@ static cl::opt<bool> InlineEnableCostBenefitAnalysis( cl::desc("Enable the cost-benefit analysis for the inliner")); static cl::opt<int> InlineSavingsMultiplier( - "inline-savings-multiplier", cl::Hidden, cl::init(8), cl::ZeroOrMore, + "inline-savings-multiplier", cl::Hidden, cl::init(8), cl::desc("Multiplier to multiply cycle savings by during inlining")); static cl::opt<int> InlineSizeAllowance("inline-size-allowance", cl::Hidden, cl::init(100), - cl::ZeroOrMore, cl::desc("The maximum size of a callee that get's " "inlined without sufficient cycle savings")); @@ -89,26 +98,25 @@ static cl::opt<int> // PGO before we actually hook up inliner with analysis passes such as BPI and // BFI. 
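The hint and cold knobs around here pull the working threshold in opposite directions; a hedged sketch of the shape of that adjustment (the in-tree updateThreshold also folds in PGO hotness and size bonuses, omitted here):

#include <algorithm>

// Sketch: inline hints can only raise the threshold, coldness can only
// lower it; defaults mirror inlinehint-threshold and inlinecold-threshold.
int adjustedThreshold(int Threshold, bool HasInlineHint, bool CalleeIsCold,
                      int HintThresh = 325, int ColdThresh = 45) {
  if (HasInlineHint)
    Threshold = std::max(Threshold, HintThresh);
  if (CalleeIsCold)
    Threshold = std::min(Threshold, ColdThresh);
  return Threshold;
}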
static cl::opt<int> ColdThreshold( - "inlinecold-threshold", cl::Hidden, cl::init(45), cl::ZeroOrMore, + "inlinecold-threshold", cl::Hidden, cl::init(45), cl::desc("Threshold for inlining functions with cold attribute")); static cl::opt<int> HotCallSiteThreshold("hot-callsite-threshold", cl::Hidden, cl::init(3000), - cl::ZeroOrMore, cl::desc("Threshold for hot callsites ")); static cl::opt<int> LocallyHotCallSiteThreshold( - "locally-hot-callsite-threshold", cl::Hidden, cl::init(525), cl::ZeroOrMore, + "locally-hot-callsite-threshold", cl::Hidden, cl::init(525), cl::desc("Threshold for locally hot callsites ")); static cl::opt<int> ColdCallSiteRelFreq( - "cold-callsite-rel-freq", cl::Hidden, cl::init(2), cl::ZeroOrMore, + "cold-callsite-rel-freq", cl::Hidden, cl::init(2), cl::desc("Maximum block frequency, expressed as a percentage of caller's " "entry frequency, for a callsite to be cold in the absence of " "profile information.")); static cl::opt<int> HotCallSiteRelFreq( - "hot-callsite-rel-freq", cl::Hidden, cl::init(60), cl::ZeroOrMore, + "hot-callsite-rel-freq", cl::Hidden, cl::init(60), cl::desc("Minimum block frequency, expressed as a multiple of caller's " "entry frequency, for a callsite to be hot in the absence of " "profile information.")); @@ -117,14 +125,19 @@ static cl::opt<int> CallPenalty( "inline-call-penalty", cl::Hidden, cl::init(25), cl::desc("Call penalty that is applied per callsite when inlining")); +static cl::opt<size_t> + StackSizeThreshold("inline-max-stacksize", cl::Hidden, + cl::init(std::numeric_limits<size_t>::max()), + cl::desc("Do not inline functions with a stack size " + "that exceeds the specified limit")); + static cl::opt<bool> OptComputeFullInlineCost( - "inline-cost-full", cl::Hidden, cl::init(false), cl::ZeroOrMore, + "inline-cost-full", cl::Hidden, cl::desc("Compute the full inline cost of a call site even when the cost " "exceeds the threshold.")); static cl::opt<bool> InlineCallerSupersetNoBuiltin( "inline-caller-superset-nobuiltin", cl::Hidden, cl::init(true), - cl::ZeroOrMore, cl::desc("Allow inlining when caller has a superset of callee's nobuiltin " "attributes.")); @@ -132,33 +145,18 @@ static cl::opt<bool> DisableGEPConstOperand( "disable-gep-const-evaluation", cl::Hidden, cl::init(false), cl::desc("Disables evaluation of GetElementPtr with constant operands")); -namespace { -class InlineCostCallAnalyzer; - -/// This function behaves more like CallBase::hasFnAttr: when it looks for the -/// requested attribute, it check both the call instruction and the called -/// function (if it's available and operand bundles don't prohibit that). -Attribute getFnAttr(CallBase &CB, StringRef AttrKind) { - Attribute CallAttr = CB.getFnAttr(AttrKind); - if (CallAttr.isValid()) - return CallAttr; - - // Operand bundles override attributes on the called function, but don't - // override attributes directly present on the call instruction. 
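The local getFnAttr wrapper being deleted here duplicated what CallBase::getFnAttr already does (prefer the call-site attribute, then fall back to the callee unless operand bundles forbid it), so the surviving helper only has to parse the string payload. A hedged usage sketch of that retained path (the free-function name is illustrative):

#include "llvm/ADT/Optional.h"
#include "llvm/IR/InstrTypes.h"

// Sketch: fetch the attribute via CallBase, then parse its string payload
// as a base-10 integer.
llvm::Optional<int> stringFnAttrAsInt(llvm::CallBase &CB,
                                      llvm::StringRef AttrKind) {
  llvm::Attribute A = CB.getFnAttr(AttrKind);
  int Value;
  // StringRef::getAsInteger returns true on parse failure.
  if (!A.isValid() || A.getValueAsString().getAsInteger(10, Value))
    return llvm::None;
  return Value;
}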
- if (!CB.isFnAttrDisallowedByOpBundle(AttrKind)) - if (const Function *F = CB.getCalledFunction()) - return F->getFnAttribute(AttrKind); - - return {}; -} - +namespace llvm { Optional<int> getStringFnAttrAsInt(CallBase &CB, StringRef AttrKind) { - Attribute Attr = getFnAttr(CB, AttrKind); + Attribute Attr = CB.getFnAttr(AttrKind); int AttrValue; if (Attr.getValueAsString().getAsInteger(10, AttrValue)) return None; return AttrValue; } +} // namespace llvm + +namespace { +class InlineCostCallAnalyzer; // This struct is used to store information about inline cost of a // particular instruction @@ -198,7 +196,7 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { friend class InstVisitor<CallAnalyzer, bool>; protected: - virtual ~CallAnalyzer() {} + virtual ~CallAnalyzer() = default; /// The TargetTransformInfo available for this compilation. const TargetTransformInfo &TTI; @@ -352,7 +350,7 @@ protected: DenseMap<Value *, std::pair<Value *, APInt>> ConstantOffsetPtrs; /// Keep track of dead blocks due to the constant arguments. - SetVector<BasicBlock *> DeadBlocks; + SmallPtrSet<BasicBlock *, 16> DeadBlocks; /// The mapping of the blocks to their known unique successors due to the /// constant arguments. @@ -385,8 +383,7 @@ protected: bool canFoldInboundsGEP(GetElementPtrInst &I); bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset); bool simplifyCallSite(Function *F, CallBase &Call); - template <typename Callable> - bool simplifyInstruction(Instruction &I, Callable Evaluate); + bool simplifyInstruction(Instruction &I); bool simplifyIntrinsicCallIsConstant(CallBase &CB); ConstantInt *stripAndComputeInBoundsConstantOffsets(Value *&V); @@ -704,7 +701,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer { BlockFrequencyInfo *BFI = &(GetBFI(F)); assert(BFI && "BFI must be available"); auto ProfileCount = BFI->getBlockProfileCount(BB); - assert(ProfileCount.hasValue()); + assert(ProfileCount); if (ProfileCount.getValue() == 0) ColdSize += Cost - CostAtBBStart; } @@ -829,14 +826,14 @@ class InlineCostCallAnalyzer final : public CallAnalyzer { } auto ProfileCount = CalleeBFI->getBlockProfileCount(&BB); - assert(ProfileCount.hasValue()); + assert(ProfileCount); CurrentSavings *= ProfileCount.getValue(); CycleSavings += CurrentSavings; } // Compute the cycle savings per call. auto EntryProfileCount = F.getEntryCount(); - assert(EntryProfileCount.hasValue() && EntryProfileCount->getCount()); + assert(EntryProfileCount && EntryProfileCount->getCount()); auto EntryCount = EntryProfileCount->getCount(); CycleSavings += EntryCount / 2; CycleSavings = CycleSavings.udiv(EntryCount); @@ -845,7 +842,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer { auto *CallerBB = CandidateCall.getParent(); BlockFrequencyInfo *CallerBFI = &(GetBFI(*(CallerBB->getParent()))); CycleSavings += getCallsiteCost(this->CandidateCall, DL); - CycleSavings *= CallerBFI->getBlockProfileCount(CallerBB).getValue(); + CycleSavings *= *CallerBFI->getBlockProfileCount(CallerBB); // Remove the cost of the cold basic blocks. 
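The cost-benefit arithmetic just above normalizes block-weighted savings to a per-call figure, rounding to nearest by adding half the entry count before the division. In scalar form (a sketch only; the real code uses APInt to avoid overflow and asserts a non-zero entry count):

#include <cstdint>

// Sketch: per-call cycle savings = round(weighted savings / entry count).
uint64_t perCallCycleSavings(uint64_t WeightedSavings, uint64_t EntryCount) {
  if (EntryCount == 0)
    return 0;
  return (WeightedSavings + EntryCount / 2) / EntryCount;
}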
int Size = Cost - ColdSize; @@ -904,13 +901,18 @@ class InlineCostCallAnalyzer final : public CallAnalyzer { getStringFnAttrAsInt(CandidateCall, "function-inline-cost")) Cost = *AttrCost; + if (Optional<int> AttrCostMult = getStringFnAttrAsInt( + CandidateCall, + InlineConstants::FunctionInlineCostMultiplierAttributeName)) + Cost *= *AttrCostMult; + if (Optional<int> AttrThreshold = getStringFnAttrAsInt(CandidateCall, "function-inline-threshold")) Threshold = *AttrThreshold; if (auto Result = costBenefitAnalysis()) { DecidedByCostBenefit = true; - if (Result.getValue()) + if (*Result) return InlineResult::success(); else return InlineResult::failure("Cost over threshold."); @@ -978,6 +980,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer { if (F.getCallingConv() == CallingConv::Cold) Cost += InlineConstants::ColdccPenalty; + LLVM_DEBUG(dbgs() << " Initial cost: " << Cost << "\n"); + // Check if we're done. This can happen due to bonuses and penalties. if (Cost >= Threshold && !ComputeFullInlineCost) return InlineResult::failure("high cost"); @@ -1002,7 +1006,7 @@ public: BoostIndirectCalls(BoostIndirect), IgnoreThreshold(IgnoreThreshold), CostBenefitAnalysisEnabled(isCostBenefitAnalysisEnabled()), Writer(this) { - AllowRecursiveCall = Params.AllowRecursiveCall.getValue(); + AllowRecursiveCall = *Params.AllowRecursiveCall; } /// Annotation Writer for instruction details @@ -1020,7 +1024,7 @@ public: return None; } - virtual ~InlineCostCallAnalyzer() {} + virtual ~InlineCostCallAnalyzer() = default; int getThreshold() const { return Threshold; } int getCost() const { return Cost; } Optional<CostBenefitPair> getCostBenefitPair() { return CostBenefit; } @@ -1203,6 +1207,10 @@ private: set(InlineCostFeatureIndex::ColdCcPenalty, (F.getCallingConv() == CallingConv::Cold)); + set(InlineCostFeatureIndex::LastCallToStaticBonus, + (F.hasLocalLinkage() && F.hasOneLiveUse() && + &F == CandidateCall.getCalledFunction())); + // FIXME: we shouldn't repeat this logic in both the Features and Cost // analyzer - instead, we should abstract it to a common method in the // CallAnalyzer @@ -1262,7 +1270,7 @@ void InlineCostAnnotationWriter::emitInstructionAnnot( auto C = ICCA->getSimplifiedValue(const_cast<Instruction *>(I)); if (C) { OS << ", simplified to "; - C.getValue()->print(OS, true); + (*C)->print(OS, true); } OS << "\n"; } @@ -1501,13 +1509,7 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) { }; if (!DisableGEPConstOperand) - if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { - SmallVector<Constant *, 2> Indices; - for (unsigned int Index = 1; Index < COps.size(); ++Index) - Indices.push_back(COps[Index]); - return ConstantExpr::getGetElementPtr( - I.getSourceElementType(), COps[0], Indices, I.isInBounds()); - })) + if (simplifyInstruction(I)) return true; if ((I.isInBounds() && canFoldInboundsGEP(I)) || IsGEPOffsetConstant(I)) { @@ -1525,11 +1527,8 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) { } /// Simplify \p I if its operands are constants and update SimplifiedValues. -/// \p Evaluate is a callable specific to instruction type that evaluates the -/// instruction when all the operands are constants. 
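The rewrite below collapses all the per-opcode folding lambdas into one call to ConstantFoldInstOperands, which dispatches on the instruction itself. A hedged sketch of the consolidated pattern (foldIfConstantOperands is an illustrative name; the in-tree version also consults SimplifiedValues for non-constant operands):

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instruction.h"

// Sketch: fold an instruction iff every operand is a constant; one generic
// entry point replaces the per-opcode ConstantExpr::get* lambdas.
llvm::Constant *foldIfConstantOperands(llvm::Instruction &I,
                                       const llvm::DataLayout &DL) {
  llvm::SmallVector<llvm::Constant *> COps;
  for (llvm::Value *Op : I.operands()) {
    auto *C = llvm::dyn_cast<llvm::Constant>(Op);
    if (!C)
      return nullptr; // a non-constant operand: nothing to fold
    COps.push_back(C);
  }
  return llvm::ConstantFoldInstOperands(&I, COps, DL);
}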
-template <typename Callable> -bool CallAnalyzer::simplifyInstruction(Instruction &I, Callable Evaluate) { - SmallVector<Constant *, 2> COps; +bool CallAnalyzer::simplifyInstruction(Instruction &I) { + SmallVector<Constant *> COps; for (Value *Op : I.operands()) { Constant *COp = dyn_cast<Constant>(Op); if (!COp) @@ -1538,7 +1537,7 @@ bool CallAnalyzer::simplifyInstruction(Instruction &I, Callable Evaluate) { return false; COps.push_back(COp); } - auto *C = Evaluate(COps); + auto *C = ConstantFoldInstOperands(&I, COps, DL); if (!C) return false; SimplifiedValues[&I] = C; @@ -1568,9 +1567,7 @@ bool CallAnalyzer::simplifyIntrinsicCallIsConstant(CallBase &CB) { bool CallAnalyzer::visitBitCast(BitCastInst &I) { // Propagate constants through bitcasts. - if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { - return ConstantExpr::getBitCast(COps[0], I.getType()); - })) + if (simplifyInstruction(I)) return true; // Track base/offsets through casts @@ -1590,9 +1587,7 @@ bool CallAnalyzer::visitBitCast(BitCastInst &I) { bool CallAnalyzer::visitPtrToInt(PtrToIntInst &I) { // Propagate constants through ptrtoint. - if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { - return ConstantExpr::getPtrToInt(COps[0], I.getType()); - })) + if (simplifyInstruction(I)) return true; // Track base/offset pairs when converted to a plain integer provided the @@ -1622,9 +1617,7 @@ bool CallAnalyzer::visitPtrToInt(PtrToIntInst &I) { bool CallAnalyzer::visitIntToPtr(IntToPtrInst &I) { // Propagate constants through ptrtoint. - if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { - return ConstantExpr::getIntToPtr(COps[0], I.getType()); - })) + if (simplifyInstruction(I)) return true; // Track base/offset pairs when round-tripped through a pointer without @@ -1647,9 +1640,7 @@ bool CallAnalyzer::visitIntToPtr(IntToPtrInst &I) { bool CallAnalyzer::visitCastInst(CastInst &I) { // Propagate constants through casts. - if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { - return ConstantExpr::getCast(I.getOpcode(), COps[0], I.getType()); - })) + if (simplifyInstruction(I)) return true; // Disable SROA in the face of arbitrary casts we don't explicitly list @@ -1855,7 +1846,7 @@ void InlineCostCallAnalyzer::updateThreshold(CallBase &Call, Function &Callee) { // current threshold, but AutoFDO + ThinLTO currently relies on this // behavior to prevent inlining of hot callsites during ThinLTO // compile phase. - Threshold = HotCallSiteThreshold.getValue(); + Threshold = *HotCallSiteThreshold; } else if (isColdCallSite(Call, CallerBFI)) { LLVM_DEBUG(dbgs() << "Cold callsite.\n"); // Do not apply bonuses for a cold callsite including the @@ -1906,9 +1897,7 @@ void InlineCostCallAnalyzer::updateThreshold(CallBase &Call, Function &Callee) { bool CallAnalyzer::visitCmpInst(CmpInst &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); // First try to handle simplified comparisons. - if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { - return ConstantExpr::getCompare(I.getPredicate(), COps[0], COps[1]); - })) + if (simplifyInstruction(I)) return true; if (I.getOpcode() == Instruction::FCmp) @@ -1984,11 +1973,11 @@ bool CallAnalyzer::visitBinaryOperator(BinaryOperator &I) { Value *SimpleV = nullptr; if (auto FI = dyn_cast<FPMathOperator>(&I)) - SimpleV = SimplifyBinOp(I.getOpcode(), CLHS ? CLHS : LHS, CRHS ? CRHS : RHS, + SimpleV = simplifyBinOp(I.getOpcode(), CLHS ? CLHS : LHS, CRHS ? 
CRHS : RHS, FI->getFastMathFlags(), DL); else SimpleV = - SimplifyBinOp(I.getOpcode(), CLHS ? CLHS : LHS, CRHS ? CRHS : RHS, DL); + simplifyBinOp(I.getOpcode(), CLHS ? CLHS : LHS, CRHS ? CRHS : RHS, DL); if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) SimplifiedValues[&I] = C; @@ -2018,7 +2007,7 @@ bool CallAnalyzer::visitFNeg(UnaryOperator &I) { if (!COp) COp = SimplifiedValues.lookup(Op); - Value *SimpleV = SimplifyFNegInst( + Value *SimpleV = simplifyFNegInst( COp ? COp : Op, cast<FPMathOperator>(I).getFastMathFlags(), DL); if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) @@ -2067,9 +2056,7 @@ bool CallAnalyzer::visitStore(StoreInst &I) { bool CallAnalyzer::visitExtractValue(ExtractValueInst &I) { // Constant folding for extract value is trivial. - if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { - return ConstantExpr::getExtractValue(COps[0], I.getIndices()); - })) + if (simplifyInstruction(I)) return true; // SROA can't look through these, but they may be free. @@ -2078,11 +2065,7 @@ bool CallAnalyzer::visitExtractValue(ExtractValueInst &I) { bool CallAnalyzer::visitInsertValue(InsertValueInst &I) { // Constant folding for insert value is trivial. - if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { - return ConstantExpr::getInsertValue(/*AggregateOperand*/ COps[0], - /*InsertedValueOperand*/ COps[1], - I.getIndices()); - })) + if (simplifyInstruction(I)) return true; // SROA can't look through these, but they may be free. @@ -2136,14 +2119,14 @@ bool CallAnalyzer::visitCallBase(CallBase &Call) { if (isa<CallInst>(Call) && cast<CallInst>(Call).cannotDuplicate()) ContainsNoDuplicateCall = true; - Value *Callee = Call.getCalledOperand(); - Function *F = dyn_cast_or_null<Function>(Callee); + Function *F = Call.getCalledFunction(); bool IsIndirectCall = !F; if (IsIndirectCall) { // Check if this happens to be an indirect function call to a known function // in this inline context. If not, we've done all we can. + Value *Callee = Call.getCalledOperand(); F = dyn_cast_or_null<Function>(SimplifiedValues.lookup(Callee)); - if (!F) { + if (!F || F->getFunctionType() != Call.getFunctionType()) { onCallArgumentSetup(Call); if (!Call.onlyReadsMemory()) @@ -2552,7 +2535,7 @@ void CallAnalyzer::findDeadBlocks(BasicBlock *CurrBB, BasicBlock *NextBB) { NewDead.push_back(Succ); while (!NewDead.empty()) { BasicBlock *Dead = NewDead.pop_back_val(); - if (DeadBlocks.insert(Dead)) + if (DeadBlocks.insert(Dead).second) // Continue growing the dead block lists. for (BasicBlock *S : successors(Dead)) if (IsNewlyDead(S)) @@ -2707,6 +2690,11 @@ InlineResult CallAnalyzer::analyze() { if (!OnlyOneCallAndLocalLinkage && ContainsNoDuplicateCall) return InlineResult::failure("noduplicate"); + // If the callee's stack size exceeds the user-specified threshold, + // do not let it be inlined. + if (AllocatedSize > StackSizeThreshold) + return InlineResult::failure("stacksize"); + return finalizeAnalysis(); } @@ -2745,7 +2733,8 @@ static bool functionsHaveCompatibleAttributes( // object, and always returns the same object (which is overwritten on each // GetTLI call). Therefore we copy the first result. 
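The findDeadBlocks hunk above switches DeadBlocks to a SmallPtrSet and keys worklist expansion off insert(...).second, the standard visited-set idiom: a block's successors are pushed only on its first insertion, so each block is expanded exactly once. The same pattern in self-contained C++, over an invented successor map:

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

int main() {
  // Toy CFG: block name -> successors.
  std::map<std::string, std::vector<std::string>> Succs = {
      {"dead1", {"dead2", "dead3"}}, {"dead2", {"dead3"}}, {"dead3", {}}};

  std::set<std::string> DeadBlocks;             // analogue of the SmallPtrSet
  std::vector<std::string> NewDead = {"dead1"}; // worklist seed

  while (!NewDead.empty()) {
    std::string Dead = NewDead.back();
    NewDead.pop_back();
    // insert(...).second is true only on first insertion, so "dead3" is
    // expanded once even though both predecessors push it.
    if (DeadBlocks.insert(Dead).second)
      for (const std::string &S : Succs[Dead])
        NewDead.push_back(S);
  }
  std::cout << DeadBlocks.size() << '\n'; // prints 3
}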
auto CalleeTLI = GetTLI(*Callee); - return TTI.areInlineCompatible(Caller, Callee) && + return (IgnoreTTIInlineCompatible || + TTI.areInlineCompatible(Caller, Callee)) && GetTLI(*Caller).areInlineCompatible(CalleeTLI, InlineCallerSupersetNoBuiltin) && AttributeFuncs::areInlineCompatible(*Caller, *Callee); @@ -2864,6 +2853,9 @@ Optional<InlineResult> llvm::getAttributeBasedInliningDecision( // Calls to functions with always-inline attributes should be inlined // whenever possible. if (Call.hasFnAttr(Attribute::AlwaysInline)) { + if (Call.getAttributes().hasFnAttr(Attribute::NoInline)) + return InlineResult::failure("noinline call site attribute"); + auto IsViable = isInlineViable(*Callee); if (IsViable.isSuccess()) return InlineResult::success(); @@ -2911,7 +2903,7 @@ InlineCost llvm::getInlineCost( auto UserDecision = llvm::getAttributeBasedInliningDecision(Call, Callee, CalleeTTI, GetTLI); - if (UserDecision.hasValue()) { + if (UserDecision) { if (UserDecision->isSuccess()) return llvm::InlineCost::getAlways("always inline attribute"); return llvm::InlineCost::getNever(UserDecision->getFailureReason()); diff --git a/llvm/lib/Analysis/InlineSizeEstimatorAnalysis.cpp b/llvm/lib/Analysis/InlineSizeEstimatorAnalysis.cpp index a2e231e2d0f4..2371ecbba615 100644 --- a/llvm/lib/Analysis/InlineSizeEstimatorAnalysis.cpp +++ b/llvm/lib/Analysis/InlineSizeEstimatorAnalysis.cpp @@ -15,33 +15,32 @@ #ifdef LLVM_HAVE_TF_API #include "llvm/Analysis/Utils/TFUtils.h" #endif +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +AnalysisKey InlineSizeEstimatorAnalysis::Key; + +#ifdef LLVM_HAVE_TF_API #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/PassManager.h" #include "llvm/MC/MCAsmLayout.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/raw_ostream.h" - #include <algorithm> #include <deque> -using namespace llvm; - -AnalysisKey InlineSizeEstimatorAnalysis::Key; - -#define DEBUG_TYPE "inline-size-estimator" - -#ifdef LLVM_HAVE_TF_API cl::opt<std::string> TFIR2NativeModelPath( "ml-inliner-ir2native-model", cl::Hidden, cl::desc("Path to saved model evaluating native size from IR.")); +#define DEBUG_TYPE "inline-size-estimator" namespace { unsigned getMaxInstructionID() { #define LAST_OTHER_INST(NR) return NR; @@ -261,10 +260,10 @@ InlineSizeEstimatorAnalysis::InlineSizeEstimatorAnalysis( namespace llvm { class TFModelEvaluator {}; } // namespace llvm -InlineSizeEstimatorAnalysis::InlineSizeEstimatorAnalysis() {} +InlineSizeEstimatorAnalysis::InlineSizeEstimatorAnalysis() = default; InlineSizeEstimatorAnalysis ::InlineSizeEstimatorAnalysis( InlineSizeEstimatorAnalysis &&) {} -InlineSizeEstimatorAnalysis::~InlineSizeEstimatorAnalysis() {} +InlineSizeEstimatorAnalysis::~InlineSizeEstimatorAnalysis() = default; InlineSizeEstimatorAnalysis::Result InlineSizeEstimatorAnalysis::run(const Function &F, FunctionAnalysisManager &FAM) { diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 4775340b3438..013e4d6489fa 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -20,7 +20,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" -#include 
"llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" @@ -36,13 +35,10 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" -#include "llvm/IR/GlobalAlias.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include <algorithm> using namespace llvm; @@ -52,28 +48,30 @@ using namespace llvm::PatternMatch; enum { RecursionLimit = 3 }; -STATISTIC(NumExpand, "Number of expansions"); +STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumReassoc, "Number of reassociations"); -static Value *SimplifyAndInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *simplifyAndInst(Value *, Value *, const SimplifyQuery &, + unsigned); static Value *simplifyUnOp(unsigned, Value *, const SimplifyQuery &, unsigned); static Value *simplifyFPUnOp(unsigned, Value *, const FastMathFlags &, const SimplifyQuery &, unsigned); -static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, +static Value *simplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); -static Value *SimplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &, +static Value *simplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &, const SimplifyQuery &, unsigned); -static Value *SimplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, +static Value *simplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); -static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse); -static Value *SimplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned); -static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, unsigned); -static Value *SimplifyCastInst(unsigned, Value *, Type *, - const SimplifyQuery &, unsigned); -static Value *SimplifyGEPInst(Type *, Value *, ArrayRef<Value *>, bool, +static Value *simplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *simplifyXorInst(Value *, Value *, const SimplifyQuery &, + unsigned); +static Value *simplifyCastInst(unsigned, Value *, Type *, const SimplifyQuery &, + unsigned); +static Value *simplifyGEPInst(Type *, Value *, ArrayRef<Value *>, bool, const SimplifyQuery &, unsigned); -static Value *SimplifySelectInst(Value *, Value *, Value *, +static Value *simplifySelectInst(Value *, Value *, Value *, const SimplifyQuery &, unsigned); static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal, @@ -120,15 +118,11 @@ static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal, /// For a boolean type or a vector of boolean type, return false or a vector /// with every element false. -static Constant *getFalse(Type *Ty) { - return ConstantInt::getFalse(Ty); -} +static Constant *getFalse(Type *Ty) { return ConstantInt::getFalse(Ty); } /// For a boolean type or a vector of boolean type, return true or a vector /// with every element true. -static Constant *getTrue(Type *Ty) { - return ConstantInt::getTrue(Ty); -} +static Constant *getTrue(Type *Ty) { return ConstantInt::getTrue(Ty); } /// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"? 
static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS, @@ -141,7 +135,7 @@ static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS, if (CPred == Pred && CLHS == LHS && CRHS == RHS) return true; return CPred == CmpInst::getSwappedPredicate(Pred) && CLHS == RHS && - CRHS == LHS; + CRHS == LHS; } /// Simplify comparison with true or false branch of select: @@ -153,7 +147,7 @@ static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS, Value *RHS, Value *Cond, const SimplifyQuery &Q, unsigned MaxRecurse, Constant *TrueOrFalse) { - Value *SimplifiedCmp = SimplifyCmpInst(Pred, LHS, RHS, Q, MaxRecurse); + Value *SimplifiedCmp = simplifyCmpInst(Pred, LHS, RHS, Q, MaxRecurse); if (SimplifiedCmp == Cond) { // %cmp simplified to the select condition (%cond). return TrueOrFalse; @@ -196,17 +190,17 @@ static Value *handleOtherCmpSelSimplifications(Value *TCmp, Value *FCmp, // checks whether folding it does not convert a well-defined value into // poison. if (match(FCmp, m_Zero()) && impliesPoison(TCmp, Cond)) - if (Value *V = SimplifyAndInst(Cond, TCmp, Q, MaxRecurse)) + if (Value *V = simplifyAndInst(Cond, TCmp, Q, MaxRecurse)) return V; // If the true value simplified to true, then the result of the compare // is equal to "Cond || FCmp". if (match(TCmp, m_One()) && impliesPoison(FCmp, Cond)) - if (Value *V = SimplifyOrInst(Cond, FCmp, Q, MaxRecurse)) + if (Value *V = simplifyOrInst(Cond, FCmp, Q, MaxRecurse)) return V; // Finally, if the false value simplified to true and the true value to // false, then the result of the compare is equal to "!Cond". if (match(FCmp, m_One()) && match(TCmp, m_Zero())) - if (Value *V = SimplifyXorInst( + if (Value *V = simplifyXorInst( Cond, Constant::getAllOnesValue(Cond->getType()), Q, MaxRecurse)) return V; return nullptr; @@ -248,12 +242,12 @@ static Value *expandBinOp(Instruction::BinaryOps Opcode, Value *V, if (!B || B->getOpcode() != OpcodeToExpand) return nullptr; Value *B0 = B->getOperand(0), *B1 = B->getOperand(1); - Value *L = SimplifyBinOp(Opcode, B0, OtherOp, Q.getWithoutUndef(), - MaxRecurse); + Value *L = + simplifyBinOp(Opcode, B0, OtherOp, Q.getWithoutUndef(), MaxRecurse); if (!L) return nullptr; - Value *R = SimplifyBinOp(Opcode, B1, OtherOp, Q.getWithoutUndef(), - MaxRecurse); + Value *R = + simplifyBinOp(Opcode, B1, OtherOp, Q.getWithoutUndef(), MaxRecurse); if (!R) return nullptr; @@ -265,7 +259,7 @@ static Value *expandBinOp(Instruction::BinaryOps Opcode, Value *V, } // Otherwise, return "L op' R" if it simplifies. - Value *S = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse); + Value *S = simplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse); if (!S) return nullptr; @@ -275,8 +269,8 @@ static Value *expandBinOp(Instruction::BinaryOps Opcode, Value *V, /// Try to simplify binops of form "A op (B op' C)" or the commuted variant by /// distributing op over op'. -static Value *expandCommutativeBinOp(Instruction::BinaryOps Opcode, - Value *L, Value *R, +static Value *expandCommutativeBinOp(Instruction::BinaryOps Opcode, Value *L, + Value *R, Instruction::BinaryOps OpcodeToExpand, const SimplifyQuery &Q, unsigned MaxRecurse) { @@ -293,7 +287,7 @@ static Value *expandCommutativeBinOp(Instruction::BinaryOps Opcode, /// Generic simplifications for associative binary operations. /// Returns the simpler value, or null if none was found. 
-static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, +static Value *simplifyAssociativeBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { @@ -313,12 +307,13 @@ static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, Value *C = RHS; // Does "B op C" simplify? - if (Value *V = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { + if (Value *V = simplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { // It does! Return "A op V" if it simplifies or is already available. // If V equals B then "A op V" is just the LHS. - if (V == B) return LHS; + if (V == B) + return LHS; // Otherwise return "A op V" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, A, V, Q, MaxRecurse)) { + if (Value *W = simplifyBinOp(Opcode, A, V, Q, MaxRecurse)) { ++NumReassoc; return W; } @@ -332,12 +327,13 @@ static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, Value *C = Op1->getOperand(1); // Does "A op B" simplify? - if (Value *V = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) { + if (Value *V = simplifyBinOp(Opcode, A, B, Q, MaxRecurse)) { // It does! Return "V op C" if it simplifies or is already available. // If V equals B then "V op C" is just the RHS. - if (V == B) return RHS; + if (V == B) + return RHS; // Otherwise return "V op C" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, V, C, Q, MaxRecurse)) { + if (Value *W = simplifyBinOp(Opcode, V, C, Q, MaxRecurse)) { ++NumReassoc; return W; } @@ -355,12 +351,13 @@ static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, Value *C = RHS; // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { + if (Value *V = simplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { // It does! Return "V op B" if it simplifies or is already available. // If V equals A then "V op B" is just the LHS. - if (V == A) return LHS; + if (V == A) + return LHS; // Otherwise return "V op B" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, V, B, Q, MaxRecurse)) { + if (Value *W = simplifyBinOp(Opcode, V, B, Q, MaxRecurse)) { ++NumReassoc; return W; } @@ -374,12 +371,13 @@ static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, Value *C = Op1->getOperand(1); // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { + if (Value *V = simplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { // It does! Return "B op V" if it simplifies or is already available. // If V equals C then "B op V" is just the RHS. - if (V == C) return RHS; + if (V == C) + return RHS; // Otherwise return "B op V" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, B, V, Q, MaxRecurse)) { + if (Value *W = simplifyBinOp(Opcode, B, V, Q, MaxRecurse)) { ++NumReassoc; return W; } @@ -393,7 +391,7 @@ static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, /// try to simplify the binop by seeing whether evaluating it on both branches /// of the select results in the same value. Returns the common value if so, /// otherwise returns null. -static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, +static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. 
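simplifyAssociativeBinOp, renamed in this hunk, walks each reassociation of "(A op B) op C", keeps a result only when a subexpression genuinely simplifies, and charges every recursive attempt against MaxRecurse. A compressed sketch of the first case for xor, using an invented Val type and modeling only the fold (a ^ b) ^ b -> a:

#include <iostream>
#include <optional>
#include <string>

// Toy values: a symbolic name, or a known constant.
struct Val { std::string Name; std::optional<int> Const; };

// Try to simplify X ^ Y without creating new instructions; null-on-failure
// mirrors the simplifier's API, and a zero budget forces an early bailout.
static std::optional<Val> simplifyXor(const Val &X, const Val &Y,
                                      unsigned MaxRecurse) {
  if (MaxRecurse == 0)
    return std::nullopt; // recursion budget exhausted: bail out at once
  if (X.Const && Y.Const)
    return Val{"", *X.Const ^ *Y.Const};
  if (!X.Name.empty() && X.Name == Y.Name)
    return Val{"", 0};   // X ^ X -> 0
  if (Y.Const && *Y.Const == 0)
    return X;            // X ^ 0 -> X
  if (X.Const && *X.Const == 0)
    return Y;
  return std::nullopt;   // no simpler form found
}

// (A ^ B) ^ C: if "B ^ C" simplifies to V, try "A ^ V" -- the first case in
// the simplifyAssociativeBinOp hunk above.
static std::optional<Val> reassociate(const Val &A, const Val &B, const Val &C,
                                      unsigned MaxRecurse) {
  if (auto V = simplifyXor(B, C, MaxRecurse))
    if (auto W = simplifyXor(A, *V, MaxRecurse - 1))
      return W;
  return std::nullopt;
}

int main() {
  Val A{"a", {}}, B{"b", {}}, C{"b", {}};
  if (auto R = reassociate(A, B, C, 3)) // (a ^ b) ^ b -> a
    std::cout << R->Name << '\n';       // prints "a"
}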
@@ -412,11 +410,11 @@ static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, Value *TV; Value *FV; if (SI == LHS) { - TV = SimplifyBinOp(Opcode, SI->getTrueValue(), RHS, Q, MaxRecurse); - FV = SimplifyBinOp(Opcode, SI->getFalseValue(), RHS, Q, MaxRecurse); + TV = simplifyBinOp(Opcode, SI->getTrueValue(), RHS, Q, MaxRecurse); + FV = simplifyBinOp(Opcode, SI->getFalseValue(), RHS, Q, MaxRecurse); } else { - TV = SimplifyBinOp(Opcode, LHS, SI->getTrueValue(), Q, MaxRecurse); - FV = SimplifyBinOp(Opcode, LHS, SI->getFalseValue(), Q, MaxRecurse); + TV = simplifyBinOp(Opcode, LHS, SI->getTrueValue(), Q, MaxRecurse); + FV = simplifyBinOp(Opcode, LHS, SI->getFalseValue(), Q, MaxRecurse); } // If they simplified to the same value, then return the common value. @@ -471,7 +469,7 @@ static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, /// We can simplify %cmp1 to true, because both branches of select are /// less than 3. We compose new comparison by substituting %tmp with both /// branches of select and see if it can be simplified. -static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, +static Value *threadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. @@ -517,7 +515,7 @@ static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, /// try to simplify the binop by seeing whether evaluating it on the incoming /// phi values yields the same result for every value. If so returns the common /// value, otherwise returns null. -static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, +static Value *threadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. @@ -542,10 +540,10 @@ static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, Value *CommonValue = nullptr; for (Value *Incoming : PI->incoming_values()) { // If the incoming value is the phi node itself, it can safely be skipped. - if (Incoming == PI) continue; - Value *V = PI == LHS ? - SimplifyBinOp(Opcode, Incoming, RHS, Q, MaxRecurse) : - SimplifyBinOp(Opcode, LHS, Incoming, Q, MaxRecurse); + if (Incoming == PI) + continue; + Value *V = PI == LHS ? simplifyBinOp(Opcode, Incoming, RHS, Q, MaxRecurse) + : simplifyBinOp(Opcode, LHS, Incoming, Q, MaxRecurse); // If the operation failed to simplify, or simplified to a different value // to previously, then give up. if (!V || (CommonValue && V != CommonValue)) @@ -560,7 +558,7 @@ static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, /// comparison by seeing whether comparing with all of the incoming phi values /// yields the same result every time. If so returns the common result, /// otherwise returns null. -static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, +static Value *threadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) @@ -584,11 +582,12 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, Value *Incoming = PI->getIncomingValue(u); Instruction *InTI = PI->getIncomingBlock(u)->getTerminator(); // If the incoming value is the phi node itself, it can safely be skipped. 
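threadBinOpOverSelect, renamed above, evaluates the operation against both arms of a select and succeeds only when the two results agree, at which point the select condition is irrelevant. In miniature, assuming both arms already folded to integers (a deliberately tiny model, not the symbolic LLVM version):

#include <iostream>
#include <optional>

// select(Cond, TrueVal, FalseVal) + RHS folds to one value exactly when
// evaluating the add on both arms yields the same result.
static std::optional<int> threadAddOverSelect(int TrueVal, int FalseVal,
                                              int RHS) {
  int TV = TrueVal + RHS;  // evaluate on the true arm
  int FV = FalseVal + RHS; // evaluate on the false arm
  if (TV == FV)
    return TV;             // common value: the condition no longer matters
  return std::nullopt;     // arms disagree: no simplification
}

int main() {
  // select(c, 7, 7) + 1 -> 8 regardless of c.
  std::cout << threadAddOverSelect(7, 7, 1).value_or(-1) << '\n'; // 8
  // select(c, 1, 2) + 1 still depends on c, so nothing folds.
  std::cout << threadAddOverSelect(1, 2, 1).value_or(-1) << '\n'; // -1
}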
- if (Incoming == PI) continue; + if (Incoming == PI) + continue; // Change the context instruction to the "edge" that flows into the phi. // This is important because that is where incoming is actually "evaluated" // even though it is used later somewhere else. - Value *V = SimplifyCmpInst(Pred, Incoming, RHS, Q.getWithInstruction(InTI), + Value *V = simplifyCmpInst(Pred, Incoming, RHS, Q.getWithInstruction(InTI), MaxRecurse); // If the operation failed to simplify, or simplified to a different value // to previously, then give up. @@ -604,8 +603,20 @@ static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode, Value *&Op0, Value *&Op1, const SimplifyQuery &Q) { if (auto *CLHS = dyn_cast<Constant>(Op0)) { - if (auto *CRHS = dyn_cast<Constant>(Op1)) + if (auto *CRHS = dyn_cast<Constant>(Op1)) { + switch (Opcode) { + default: + break; + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + if (Q.CxtI != nullptr) + return ConstantFoldFPInstOperands(Opcode, CLHS, CRHS, Q.DL, Q.CxtI); + } return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); + } // Canonicalize the constant to the RHS if this is a commutative operation. if (Instruction::isCommutative(Opcode)) @@ -616,7 +627,7 @@ static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode, /// Given operands for an Add, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, +static Value *simplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q)) return C; @@ -647,8 +658,7 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, // X + ~X -> -1 since ~X = -X-1 Type *Ty = Op0->getType(); - if (match(Op0, m_Not(m_Specific(Op1))) || - match(Op1, m_Not(m_Specific(Op0)))) + if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) return Constant::getAllOnesValue(Ty); // add nsw/nuw (xor Y, signmask), signmask --> Y @@ -664,12 +674,12 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, /// i1 add -> xor. if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) - if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) + if (Value *V = simplifyXorInst(Op0, Op1, Q, MaxRecurse - 1)) return V; // Try some generic simplifications for associative operations. - if (Value *V = SimplifyAssociativeBinOp(Instruction::Add, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + simplifyAssociativeBinOp(Instruction::Add, Op0, Op1, Q, MaxRecurse)) return V; // Threading Add over selects and phi nodes is pointless, so don't bother. @@ -684,45 +694,37 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, return nullptr; } -Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, +Value *llvm::simplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, const SimplifyQuery &Query) { - return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit); + return ::simplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit); } /// Compute the base pointer and cumulative constant offsets for V. /// /// This strips all constant offsets off of V, leaving it the base pointer, and -/// accumulates the total constant offset applied in the returned constant. 
It -/// returns 0 if V is not a pointer, and returns the constant '0' if there are -/// no constant offsets applied. +/// accumulates the total constant offset applied in the returned constant. +/// It returns zero if there are no constant offsets applied. /// -/// This is very similar to GetPointerBaseWithConstantOffset except it doesn't -/// follow non-inbounds geps. This allows it to remain usable for icmp ult/etc. -/// folding. -static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V, - bool AllowNonInbounds = false) { +/// This is very similar to stripAndAccumulateConstantOffsets(), except it +/// normalizes the offset bitwidth to the stripped pointer type, not the +/// original pointer type. +static APInt stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V, + bool AllowNonInbounds = false) { assert(V->getType()->isPtrOrPtrVectorTy()); APInt Offset = APInt::getZero(DL.getIndexTypeSizeInBits(V->getType())); - V = V->stripAndAccumulateConstantOffsets(DL, Offset, AllowNonInbounds); // As that strip may trace through `addrspacecast`, need to sext or trunc // the offset calculated. - Type *IntIdxTy = DL.getIndexType(V->getType())->getScalarType(); - Offset = Offset.sextOrTrunc(IntIdxTy->getIntegerBitWidth()); - - Constant *OffsetIntPtr = ConstantInt::get(IntIdxTy, Offset); - if (VectorType *VecTy = dyn_cast<VectorType>(V->getType())) - return ConstantVector::getSplat(VecTy->getElementCount(), OffsetIntPtr); - return OffsetIntPtr; + return Offset.sextOrTrunc(DL.getIndexTypeSizeInBits(V->getType())); } /// Compute the constant difference between two pointer values. /// If the difference is not a constant, returns zero. static Constant *computePointerDifference(const DataLayout &DL, Value *LHS, Value *RHS) { - Constant *LHSOffset = stripAndComputeConstantOffsets(DL, LHS); - Constant *RHSOffset = stripAndComputeConstantOffsets(DL, RHS); + APInt LHSOffset = stripAndComputeConstantOffsets(DL, LHS); + APInt RHSOffset = stripAndComputeConstantOffsets(DL, RHS); // If LHS and RHS are not related via constant offsets to the same base // value, there is nothing we can do here. @@ -733,12 +735,15 @@ static Constant *computePointerDifference(const DataLayout &DL, Value *LHS, // LHS - RHS // = (LHSOffset + Base) - (RHSOffset + Base) // = LHSOffset - RHSOffset - return ConstantExpr::getSub(LHSOffset, RHSOffset); + Constant *Res = ConstantInt::get(LHS->getContext(), LHSOffset - RHSOffset); + if (auto *VecTy = dyn_cast<VectorType>(LHS->getType())) + Res = ConstantVector::getSplat(VecTy->getElementCount(), Res); + return Res; } /// Given operands for a Sub, see if we can fold the result. /// If not, this returns null. -static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, +static Value *simplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q)) return C; @@ -784,17 +789,17 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *X = nullptr, *Y = nullptr, *Z = Op1; if (MaxRecurse && match(Op0, m_Add(m_Value(X), m_Value(Y)))) { // (X + Y) - Z // See if "V === Y - Z" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, Y, Z, Q, MaxRecurse-1)) + if (Value *V = simplifyBinOp(Instruction::Sub, Y, Z, Q, MaxRecurse - 1)) // It does! Now see if "X + V" simplifies. 
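computePointerDifference, just above, now subtracts APInt offsets directly: strip constant offsets off both pointers, and if what remains is the same base value, the pointer difference is exactly the offset difference. The same idea over an invented (base name, offset) representation:

#include <initializer_list>
#include <iostream>
#include <optional>
#include <string>

struct StrippedPtr { std::string Base; long Offset; };

// Analogue of stripAndComputeConstantOffsets: peel "base + K" spellings
// down to a base plus one accumulated constant offset.
static StrippedPtr strip(const std::string &Base,
                         std::initializer_list<long> Geps) {
  long Off = 0;
  for (long G : Geps)
    Off += G; // each constant GEP adds to the running total
  return {Base, Off};
}

// Analogue of computePointerDifference: defined only for a common base.
static std::optional<long> pointerDifference(const StrippedPtr &L,
                                             const StrippedPtr &R) {
  if (L.Base != R.Base)
    return std::nullopt; // unrelated bases: nothing we can do here
  // (LOffset + Base) - (ROffset + Base) == LOffset - ROffset
  return L.Offset - R.Offset;
}

int main() {
  StrippedPtr L = strip("p", {8, 4}); // p + 12
  StrippedPtr R = strip("p", {4});    // p + 4
  std::cout << pointerDifference(L, R).value_or(0) << '\n'; // prints 8
}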
- if (Value *W = SimplifyBinOp(Instruction::Add, X, V, Q, MaxRecurse-1)) { + if (Value *W = simplifyBinOp(Instruction::Add, X, V, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; } // See if "V === X - Z" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1)) + if (Value *V = simplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse - 1)) // It does! Now see if "Y + V" simplifies. - if (Value *W = SimplifyBinOp(Instruction::Add, Y, V, Q, MaxRecurse-1)) { + if (Value *W = simplifyBinOp(Instruction::Add, Y, V, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; @@ -806,17 +811,17 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, X = Op0; if (MaxRecurse && match(Op1, m_Add(m_Value(Y), m_Value(Z)))) { // X - (Y + Z) // See if "V === X - Y" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) + if (Value *V = simplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse - 1)) // It does! Now see if "V - Z" simplifies. - if (Value *W = SimplifyBinOp(Instruction::Sub, V, Z, Q, MaxRecurse-1)) { + if (Value *W = simplifyBinOp(Instruction::Sub, V, Z, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; } // See if "V === X - Z" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1)) + if (Value *V = simplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse - 1)) // It does! Now see if "V - Y" simplifies. - if (Value *W = SimplifyBinOp(Instruction::Sub, V, Y, Q, MaxRecurse-1)) { + if (Value *W = simplifyBinOp(Instruction::Sub, V, Y, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; @@ -828,9 +833,9 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Z = Op0; if (MaxRecurse && match(Op1, m_Sub(m_Value(X), m_Value(Y)))) // Z - (X - Y) // See if "V === Z - X" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, Z, X, Q, MaxRecurse-1)) + if (Value *V = simplifyBinOp(Instruction::Sub, Z, X, Q, MaxRecurse - 1)) // It does! Now see if "V + Y" simplifies. - if (Value *W = SimplifyBinOp(Instruction::Add, V, Y, Q, MaxRecurse-1)) { + if (Value *W = simplifyBinOp(Instruction::Add, V, Y, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; @@ -841,22 +846,21 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, match(Op1, m_Trunc(m_Value(Y)))) if (X->getType() == Y->getType()) // See if "V === X - Y" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) + if (Value *V = simplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse - 1)) // It does! Now see if "trunc V" simplifies. - if (Value *W = SimplifyCastInst(Instruction::Trunc, V, Op0->getType(), + if (Value *W = simplifyCastInst(Instruction::Trunc, V, Op0->getType(), Q, MaxRecurse - 1)) // It does, return the simplified "trunc V". return W; // Variations on GEP(base, I, ...) - GEP(base, i, ...) -> GEP(null, I-i, ...). - if (match(Op0, m_PtrToInt(m_Value(X))) && - match(Op1, m_PtrToInt(m_Value(Y)))) + if (match(Op0, m_PtrToInt(m_Value(X))) && match(Op1, m_PtrToInt(m_Value(Y)))) if (Constant *Result = computePointerDifference(Q.DL, X, Y)) return ConstantExpr::getIntegerCast(Result, Op0->getType(), true); // i1 sub -> xor. 
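The "i1 sub -> xor" fallback here, like the earlier "i1 add -> xor" and "i1 mul -> and" cases, is just arithmetic mod 2. A brute-force check over both bit values:

#include <cassert>
#include <iostream>

int main() {
  for (unsigned A = 0; A <= 1; ++A)
    for (unsigned B = 0; B <= 1; ++B) {
      assert(((A + B) & 1u) == (A ^ B)); // i1 add -> xor
      assert(((A - B) & 1u) == (A ^ B)); // i1 sub -> xor (mod 2, -1 == 1)
      assert(((A * B) & 1u) == (A & B)); // i1 mul -> and
    }
  std::cout << "i1 identities hold\n";
}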
if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) - if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) + if (Value *V = simplifyXorInst(Op0, Op1, Q, MaxRecurse - 1)) return V; // Threading Sub over selects and phi nodes is pointless, so don't bother. @@ -871,14 +875,14 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, return nullptr; } -Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, +Value *llvm::simplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const SimplifyQuery &Q) { - return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); + return ::simplifySubInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); } /// Given operands for a Mul, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +static Value *simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q)) return C; @@ -906,12 +910,12 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // i1 mul -> and. if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) - if (Value *V = SimplifyAndInst(Op0, Op1, Q, MaxRecurse-1)) + if (Value *V = simplifyAndInst(Op0, Op1, Q, MaxRecurse - 1)) return V; // Try some generic simplifications for associative operations. - if (Value *V = SimplifyAssociativeBinOp(Instruction::Mul, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + simplifyAssociativeBinOp(Instruction::Mul, Op0, Op1, Q, MaxRecurse)) return V; // Mul distributes over Add. Try some generic simplifications based on this. @@ -922,22 +926,22 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) - if (Value *V = ThreadBinOpOverSelect(Instruction::Mul, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + threadBinOpOverSelect(Instruction::Mul, Op0, Op1, Q, MaxRecurse)) return V; // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) - if (Value *V = ThreadBinOpOverPHI(Instruction::Mul, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + threadBinOpOverPHI(Instruction::Mul, Op0, Op1, Q, MaxRecurse)) return V; return nullptr; } -Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::SimplifyMulInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::simplifyMulInst(Op0, Op1, Q, RecursionLimit); } /// Check for common or similar folds of integer division or integer remainder. @@ -1026,7 +1030,7 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0, /// when we can prove a relationship between the operands. 
static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { - Value *V = SimplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse); + Value *V = simplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse); Constant *C = dyn_cast_or_null<Constant>(V); return (C && C->isAllOnesValue()); } @@ -1122,13 +1126,13 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) - if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) + if (Value *V = threadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) return V; // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) - if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) + if (Value *V = threadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) return V; if (isDivZero(Op0, Op1, Q, MaxRecurse, IsSigned)) @@ -1164,13 +1168,13 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) - if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) + if (Value *V = threadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) return V; // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) - if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) + if (Value *V = threadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) return V; // If X / Y == 0, then X % Y == X. @@ -1182,7 +1186,7 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, /// Given operands for an SDiv, see if we can fold the result. /// If not, this returns null. -static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +static Value *simplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { // If two operands are negated and no signed overflow, return -1. if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true)) @@ -1191,24 +1195,24 @@ static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse); } -Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::SimplifySDivInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::simplifySDivInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for a UDiv, see if we can fold the result. /// If not, this returns null. 
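simplifyRem above retains the fold "if X / Y == 0 then X % Y == X", which follows from X == (X / Y) * Y + X % Y. A quick exhaustive check for small unsigned values:

#include <cassert>
#include <iostream>

int main() {
  for (unsigned X = 0; X < 64; ++X)
    for (unsigned Y = 1; Y < 64; ++Y) // Y == 0 is undefined for / and %
      if (X / Y == 0)
        assert(X % Y == X);           // the simplifyRem fold
  std::cout << "X % Y == X whenever X / Y == 0\n";
}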
-static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +static Value *simplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { return simplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse); } -Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::SimplifyUDivInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::simplifyUDivInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for an SRem, see if we can fold the result. /// If not, this returns null. -static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +static Value *simplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { // If the divisor is 0, the result is undefined, so assume the divisor is -1. // srem Op0, (sext i1 X) --> srem Op0, -1 --> 0 @@ -1223,19 +1227,19 @@ static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return simplifyRem(Instruction::SRem, Op0, Op1, Q, MaxRecurse); } -Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::SimplifySRemInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::simplifySRemInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for a URem, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +static Value *simplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { return simplifyRem(Instruction::URem, Op0, Op1, Q, MaxRecurse); } -Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::SimplifyURemInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::simplifyURemInst(Op0, Op1, Q, RecursionLimit); } /// Returns true if a shift by \c Amount always yields poison. @@ -1268,7 +1272,7 @@ static bool isPoisonShift(Value *Amount, const SimplifyQuery &Q) { /// Given operands for an Shl, LShr or AShr, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, +static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, bool IsNSW, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) @@ -1297,13 +1301,13 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) - if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) + if (Value *V = threadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) return V; // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. 
if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) - if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) + if (Value *V = threadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) return V; // If any bits in the shift amount make that value greater than or equal to @@ -1338,11 +1342,11 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, /// Given operands for an Shl, LShr or AShr, see if we can /// fold the result. If not, this returns null. -static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, - Value *Op1, bool isExact, const SimplifyQuery &Q, - unsigned MaxRecurse) { +static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, bool isExact, + const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = - SimplifyShift(Opcode, Op0, Op1, /*IsNSW*/ false, Q, MaxRecurse)) + simplifyShift(Opcode, Op0, Op1, /*IsNSW*/ false, Q, MaxRecurse)) return V; // X >> X -> 0 @@ -1356,7 +1360,8 @@ static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, // The low bit cannot be shifted out of an exact shift if it is set. if (isExact) { - KnownBits Op0Known = computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + KnownBits Op0Known = + computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); if (Op0Known.One[0]) return Op0; } @@ -1366,10 +1371,10 @@ static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, /// Given operands for an Shl, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, +static Value *simplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = - SimplifyShift(Instruction::Shl, Op0, Op1, isNSW, Q, MaxRecurse)) + simplifyShift(Instruction::Shl, Op0, Op1, isNSW, Q, MaxRecurse)) return V; // undef << X -> 0 @@ -1392,18 +1397,18 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, return nullptr; } -Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, +Value *llvm::simplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const SimplifyQuery &Q) { - return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); + return ::simplifyShlInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); } /// Given operands for an LShr, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, +static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool isExact, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Value *V = SimplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q, + if (Value *V = simplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q, MaxRecurse)) - return V; + return V; // (X << A) >> A -> X Value *X; @@ -1429,16 +1434,16 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, return nullptr; } -Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, +Value *llvm::simplifyLShrInst(Value *Op0, Value *Op1, bool isExact, const SimplifyQuery &Q) { - return ::SimplifyLShrInst(Op0, Op1, isExact, Q, RecursionLimit); + return ::simplifyLShrInst(Op0, Op1, isExact, Q, RecursionLimit); } /// Given operands for an AShr, see if we can fold the result. /// If not, this returns null. 
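The shift hunks above must honor that a shift amount of at least the bit width yields poison in LLVM IR (and is undefined behavior in C++), so no concrete value can be folded in that case. A guarded helper makes the precondition explicit; shlChecked is an invented sketch, not an LLVM API:

#include <cstdint>
#include <iostream>
#include <optional>

// Shift left, refusing the amounts that would be poison in LLVM IR
// (Amount >= bit width) rather than silently invoking C++ UB.
static std::optional<std::uint32_t> shlChecked(std::uint32_t X,
                                               std::uint32_t Amount) {
  if (Amount >= 32)
    return std::nullopt; // would be poison: no defined result to fold to
  return X << Amount;
}

int main() {
  std::cout << shlChecked(1, 4).value_or(0) << '\n';  // prints 16
  std::cout << shlChecked(1, 32).has_value() << '\n'; // prints 0 (poison)
}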
-static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, +static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool isExact, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Value *V = SimplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q, + if (Value *V = simplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q, MaxRecurse)) return V; @@ -1462,9 +1467,9 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, return nullptr; } -Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, +Value *llvm::simplifyAShrInst(Value *Op0, Value *Op1, bool isExact, const SimplifyQuery &Q) { - return ::SimplifyAShrInst(Op0, Op1, isExact, Q, RecursionLimit); + return ::simplifyAShrInst(Op0, Op1, isExact, Q, RecursionLimit); } /// Commuted variants are assumed to be handled by calling this function again @@ -1581,7 +1586,7 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp, /// with the parameters swapped. static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { ICmpInst::Predicate Pred0, Pred1; - Value *A ,*B; + Value *A, *B; if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) return nullptr; @@ -1606,7 +1611,7 @@ static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { /// with the parameters swapped. static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { ICmpInst::Predicate Pred0, Pred1; - Value *A ,*B; + Value *A, *B; if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) return nullptr; @@ -1812,6 +1817,27 @@ static Value *simplifyAndOrOfICmpsWithLimitConst(ICmpInst *Cmp0, ICmpInst *Cmp1, return nullptr; } +/// Try to simplify and/or of icmp with ctpop intrinsic. 
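The simplifyAndOrOfICmpsWithCtpop helper introduced just below relies on ctpop(X) == 0 holding exactly when X == 0, so for C > 0 the compare against C can never change the answer on the X == 0 side. A brute-force check with a portable popcount:

#include <cassert>
#include <iostream>

static unsigned popcount(unsigned X) {
  unsigned N = 0;
  for (; X; X &= X - 1) // clear the lowest set bit
    ++N;
  return N;
}

int main() {
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned C = 1; C < 9; ++C) { // the fold requires C > 0
      bool AndForm = (popcount(X) != C) && (X == 0);
      assert(AndForm == (X == 0)); // (ctpop(X) != C) && (X == 0) --> X == 0
      bool OrForm = (popcount(X) == C) || (X != 0);
      assert(OrForm == (X != 0)); // (ctpop(X) == C) || (X != 0) --> X != 0
    }
  std::cout << "ctpop folds verified\n";
}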
+static Value *simplifyAndOrOfICmpsWithCtpop(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool IsAnd) { + ICmpInst::Predicate Pred0, Pred1; + Value *X; + const APInt *C; + if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)), + m_APInt(C))) || + !match(Cmp1, m_ICmp(Pred1, m_Specific(X), m_ZeroInt())) || C->isZero()) + return nullptr; + + // (ctpop(X) == C) || (X != 0) --> X != 0 where C > 0 + if (!IsAnd && Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_NE) + return Cmp1; + // (ctpop(X) != C) && (X == 0) --> X == 0 where C > 0 + if (IsAnd && Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_EQ) + return Cmp1; + + return nullptr; +} + static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1, const SimplifyQuery &Q) { if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true, Q)) @@ -1833,6 +1859,11 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1, if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true)) return X; + if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op0, Op1, true)) + return X; + if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op1, Op0, true)) + return X; + if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1, Q.IIQ)) return X; if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0, Q.IIQ)) @@ -1909,6 +1940,11 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1, if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false)) return X; + if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op0, Op1, false)) + return X; + if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op1, Op0, false)) + return X; + if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1, Q.IIQ)) return X; if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0, Q.IIQ)) @@ -1917,8 +1953,8 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1, return nullptr; } -static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, - FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { +static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, FCmpInst *LHS, + FCmpInst *RHS, bool IsAnd) { Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); if (LHS0->getType() != RHS0->getType()) @@ -1955,8 +1991,8 @@ static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, return nullptr; } -static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, - Value *Op0, Value *Op1, bool IsAnd) { +static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, Value *Op0, + Value *Op1, bool IsAnd) { // Look through casts of the 'and' operands to find compares. auto *Cast0 = dyn_cast<CastInst>(Op0); auto *Cast1 = dyn_cast<CastInst>(Op1); @@ -2017,7 +2053,7 @@ static Value *simplifyLogicOfAddSub(Value *Op0, Value *Op1, /// Given operands for an And, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q)) return C; @@ -2043,8 +2079,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return Op0; // A & ~A = ~A & A = 0 - if (match(Op0, m_Not(m_Specific(Op1))) || - match(Op1, m_Not(m_Specific(Op0)))) + if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) return Constant::getNullValue(Op0->getType()); // (A | ?) 
& A = A @@ -2117,8 +2152,8 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return V; // Try some generic simplifications for associative operations. - if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + simplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, MaxRecurse)) return V; // And distributes over Or. Try some generic simplifications based on this. @@ -2142,16 +2177,16 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // If the operation is with the result of a select instruction, check // whether operating on either branch of the select always yields the same // value. - if (Value *V = ThreadBinOpOverSelect(Instruction::And, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + threadBinOpOverSelect(Instruction::And, Op0, Op1, Q, MaxRecurse)) return V; } // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) - if (Value *V = ThreadBinOpOverPHI(Instruction::And, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + threadBinOpOverPHI(Instruction::And, Op0, Op1, Q, MaxRecurse)) return V; // Assuming the effective width of Y is not larger than A, i.e. all bits @@ -2174,8 +2209,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); const unsigned EffWidthY = YKnown.countMaxActiveBits(); if (EffWidthY <= ShftCnt) { - const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, - Q.DT); + const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); const unsigned EffWidthX = XKnown.countMaxActiveBits(); const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY); const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt; @@ -2197,11 +2231,20 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, match(Op1, m_c_Xor(m_Specific(Or), m_Specific(Y)))) return Constant::getNullValue(Op0->getType()); + if (Op0->getType()->isIntOrIntVectorTy(1)) { + // Op0&Op1 -> Op0 where Op0 implies Op1 + if (isImpliedCondition(Op0, Op1, Q.DL).value_or(false)) + return Op0; + // Op0&Op1 -> Op1 where Op1 implies Op0 + if (isImpliedCondition(Op1, Op0, Q.DL).value_or(false)) + return Op1; + } + return nullptr; } -Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::SimplifyAndInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::simplifyAndInst(Op0, Op1, Q, RecursionLimit); } static Value *simplifyOrLogic(Value *X, Value *Y) { @@ -2289,7 +2332,7 @@ static Value *simplifyOrLogic(Value *X, Value *Y) { /// Given operands for an Or, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q)) return C; @@ -2334,6 +2377,31 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, } } + // A funnel shift (rotate) can be decomposed into simpler shifts. See if we + // are mixing in another shift that is redundant with the funnel shift. 
+ + // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y + // (shl X, Y) | (fshl X, ?, Y) --> fshl X, ?, Y + if (match(Op0, + m_Intrinsic<Intrinsic::fshl>(m_Value(X), m_Value(), m_Value(Y))) && + match(Op1, m_Shl(m_Specific(X), m_Specific(Y)))) + return Op0; + if (match(Op1, + m_Intrinsic<Intrinsic::fshl>(m_Value(X), m_Value(), m_Value(Y))) && + match(Op0, m_Shl(m_Specific(X), m_Specific(Y)))) + return Op1; + + // (fshr ?, X, Y) | (lshr X, Y) --> fshr ?, X, Y + // (lshr X, Y) | (fshr ?, X, Y) --> fshr ?, X, Y + if (match(Op0, + m_Intrinsic<Intrinsic::fshr>(m_Value(), m_Value(X), m_Value(Y))) && + match(Op1, m_LShr(m_Specific(X), m_Specific(Y)))) + return Op0; + if (match(Op1, + m_Intrinsic<Intrinsic::fshr>(m_Value(), m_Value(X), m_Value(Y))) && + match(Op0, m_LShr(m_Specific(X), m_Specific(Y)))) + return Op1; + if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false)) return V; @@ -2346,8 +2414,8 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return Op0; // Try some generic simplifications for associative operations. - if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + simplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, MaxRecurse)) return V; // Or distributes over And. Try some generic simplifications based on this. @@ -2366,8 +2434,8 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // If the operation is with the result of a select instruction, check // whether operating on either branch of the select always yields the same // value. - if (Value *V = ThreadBinOpOverSelect(Instruction::Or, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + threadBinOpOverSelect(Instruction::Or, Op0, Op1, Q, MaxRecurse)) return V; } @@ -2389,8 +2457,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return A; } // Or commutes, try both ways. - if (C1->isMask() && - match(B, m_c_Add(m_Specific(A), m_Value(N)))) { + if (C1->isMask() && match(B, m_c_Add(m_Specific(A), m_Value(N)))) { // Add commutes, try both ways. if (MaskedValueIsZero(N, *C1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return B; @@ -2401,19 +2468,28 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) - if (Value *V = ThreadBinOpOverPHI(Instruction::Or, Op0, Op1, Q, MaxRecurse)) + if (Value *V = threadBinOpOverPHI(Instruction::Or, Op0, Op1, Q, MaxRecurse)) return V; + if (Op0->getType()->isIntOrIntVectorTy(1)) { + // Op0|Op1 -> Op1 where Op0 implies Op1 + if (isImpliedCondition(Op0, Op1, Q.DL).value_or(false)) + return Op1; + // Op0|Op1 -> Op0 where Op1 implies Op0 + if (isImpliedCondition(Op1, Op0, Q.DL).value_or(false)) + return Op0; + } + return nullptr; } -Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::SimplifyOrInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::simplifyOrInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for a Xor, see if we can fold the result. /// If not, this returns null. 
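The i1 hunks above use isImpliedCondition: when Op0 implies Op1, "Op0 & Op1" is just Op0 and "Op0 | Op1" is just Op1 (value_or(false) keeps the fold conservative when the analysis is inconclusive). A truth-table check, encoding "A implies B" as !A || B:

#include <cassert>
#include <iostream>

int main() {
  for (int A = 0; A <= 1; ++A)
    for (int B = 0; B <= 1; ++B) {
      bool Implies = !A || B; // "A implies B" as a boolean fact
      if (Implies) {
        assert((A & B) == A); // Op0 & Op1 -> Op0 when Op0 implies Op1
        assert((A | B) == B); // Op0 | Op1 -> Op1 when Op0 implies Op1
      }
    }
  std::cout << "implication folds verified\n";
}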
-static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+static Value *simplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                               unsigned MaxRecurse) {
   if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q))
     return C;
@@ -2435,8 +2511,7 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
     return Constant::getNullValue(Op0->getType());
 
   // A ^ ~A = ~A ^ A = -1
-  if (match(Op0, m_Not(m_Specific(Op1))) ||
-      match(Op1, m_Not(m_Specific(Op0))))
+  if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0))))
     return Constant::getAllOnesValue(Op0->getType());
 
   auto foldAndOrNot = [](Value *X, Value *Y) -> Value * {
@@ -2467,8 +2542,8 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
     return V;
 
   // Try some generic simplifications for associative operations.
-  if (Value *V = SimplifyAssociativeBinOp(Instruction::Xor, Op0, Op1, Q,
-                                          MaxRecurse))
+  if (Value *V =
+          simplifyAssociativeBinOp(Instruction::Xor, Op0, Op1, Q, MaxRecurse))
     return V;
 
   // Threading Xor over selects and phi nodes is pointless, so don't bother.
@@ -2483,19 +2558,18 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
   return nullptr;
 }
 
-Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
-  return ::SimplifyXorInst(Op0, Op1, Q, RecursionLimit);
+Value *llvm::simplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
+  return ::simplifyXorInst(Op0, Op1, Q, RecursionLimit);
 }
 
-
-static Type *GetCompareTy(Value *Op) {
+static Type *getCompareTy(Value *Op) {
   return CmpInst::makeCmpResultType(Op->getType());
 }
 
 /// Rummage around inside V looking for something equivalent to the comparison
 /// "LHS Pred RHS". Return such a value if found, otherwise return null.
 /// Helper function for analyzing max/min idioms.
-static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
+static Value *extractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
                                          Value *LHS, Value *RHS) {
   SelectInst *SI = dyn_cast<SelectInst>(V);
   if (!SI)
     return nullptr;
@@ -2512,6 +2586,70 @@ static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
   return nullptr;
 }
 
+/// Return true if the underlying object (storage) must be disjoint from
+/// storage returned by any noalias return call.
+static bool isAllocDisjoint(const Value *V) {
+  // For allocas, we consider only static ones (dynamic allocas might be
+  // transformed into calls to malloc that are not simultaneously live with
+  // the compared-to allocation). For globals, we exclude symbols that might
+  // be resolved lazily to symbols in another dynamically-loaded library
+  // (and, thus, could be malloc'ed by the implementation).
+  if (const AllocaInst *AI = dyn_cast<AllocaInst>(V))
+    return AI->getParent() && AI->getFunction() && AI->isStaticAlloca();
+  if (const GlobalValue *GV = dyn_cast<GlobalValue>(V))
+    return (GV->hasLocalLinkage() || GV->hasHiddenVisibility() ||
+            GV->hasProtectedVisibility() || GV->hasGlobalUnnamedAddr()) &&
+           !GV->isThreadLocal();
+  if (const Argument *A = dyn_cast<Argument>(V))
+    return A->hasByValAttr();
+  return false;
+}
+
+/// Return true if V1 and V2 are each the base of some distinct storage region
+/// [V, object_size(V)], and the two regions do not overlap. Note that
+/// zero-sized regions *are* possible, and that a zero-sized region does not
+/// overlap with any other region.
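// In particular, two distinct zero-sized objects may be placed at the same
// address, which is why this predicate must be paired with a bounds check:
// in computePointerICmp below, each offset must satisfy `Offset ult Size`,
// which no offset can when the object size is zero, so the "pointers must
// differ" conclusion is never drawn for empty objects.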
+static bool haveNonOverlappingStorage(const Value *V1, const Value *V2) {
+  // Global variables always exist, so they always exist during the lifetime
+  // of each other and all allocas. Global variables themselves usually have
+  // non-overlapping storage, but since their addresses are constants, the
+  // case involving two globals does not reach here and is instead handled in
+  // constant folding.
+  //
+  // Two different allocas usually have different addresses...
+  //
+  // However, if there's an @llvm.stackrestore dynamically in between two
+  // allocas, they may have the same address. It's tempting to reduce the
+  // scope of the problem by only looking at *static* allocas here. That would
+  // cover the majority of allocas while significantly reducing the likelihood
+  // of having an @llvm.stackrestore pop up in the middle. However, it's not
+  // actually impossible for an @llvm.stackrestore to pop up in the middle of
+  // an entry block. Also, if we have a block that's not attached to a
+  // function, we can't tell if it's "static" under the current definition.
+  // Theoretically, this problem could be fixed by creating a new instruction
+  // kind specifically for static allocas. Such a new instruction could be
+  // required to be at the top of the entry block, thus preventing it from
+  // being subject to a @llvm.stackrestore. Instcombine could even convert
+  // regular allocas into these special allocas. It'd be nifty. However,
+  // until then, this problem remains open.
+  //
+  // So, we'll assume that two non-empty allocas have different addresses
+  // for now.
+  auto isByValArg = [](const Value *V) {
+    const Argument *A = dyn_cast<Argument>(V);
+    return A && A->hasByValAttr();
+  };
+
+  // Byval args are backed by storage that does not overlap with the storage
+  // of other byval args, allocas, or globals.
+  if (isByValArg(V1))
+    return isa<AllocaInst>(V2) || isa<GlobalVariable>(V2) || isByValArg(V2);
+  if (isByValArg(V2))
+    return isa<AllocaInst>(V1) || isa<GlobalVariable>(V1) || isByValArg(V1);
+
+  return isa<AllocaInst>(V1) &&
+         (isa<AllocaInst>(V2) || isa<GlobalVariable>(V2));
+}
+
 // A significant optimization not implemented here is assuming that alloca
 // addresses are not equal to incoming argument values. They don't *alias*,
 // as we say, but that doesn't mean they aren't equal, so we take a
@@ -2540,9 +2678,8 @@ static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
 // If the C and C++ standards are ever made sufficiently restrictive in this
 // area, it may be possible to update LLVM's semantics accordingly and reinstate
 // this optimization.
-static Constant *
-computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
-                   const SimplifyQuery &Q) {
+static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
+                                    Value *RHS, const SimplifyQuery &Q) {
   const DataLayout &DL = Q.DL;
   const TargetLibraryInfo *TLI = Q.TLI;
   const DominatorTree *DT = Q.DT;
@@ -2557,8 +2694,7 @@ computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
   if (isa<ConstantPointerNull>(RHS) && ICmpInst::isEquality(Pred) &&
       llvm::isKnownNonZero(LHS, DL, 0, nullptr, nullptr, nullptr,
                            IIQ.UseInstrInfo))
-    return ConstantInt::get(GetCompareTy(LHS),
-                            !CmpInst::isTrueWhenEqual(Pred));
+    return ConstantInt::get(getCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred));
 
   // We can only fold certain predicates on pointer comparisons.
   switch (Pred) {
@@ -2588,88 +2724,47 @@ computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
   // numerous hazards. AliasAnalysis and its utilities rely on special rules
   // governing loads and stores which don't apply to icmps. Also, AliasAnalysis
   // doesn't need to guarantee pointer inequality when it says NoAlias.
-  Constant *LHSOffset = stripAndComputeConstantOffsets(DL, LHS);
-  Constant *RHSOffset = stripAndComputeConstantOffsets(DL, RHS);
+
+  // Even if a non-inbounds GEP occurs along the path we can still optimize
+  // equality comparisons concerning the result.
+  bool AllowNonInbounds = ICmpInst::isEquality(Pred);
+  APInt LHSOffset = stripAndComputeConstantOffsets(DL, LHS, AllowNonInbounds);
+  APInt RHSOffset = stripAndComputeConstantOffsets(DL, RHS, AllowNonInbounds);
 
   // If LHS and RHS are related via constant offsets to the same base
   // value, we can replace it with an icmp which just compares the offsets.
   if (LHS == RHS)
-    return ConstantExpr::getICmp(Pred, LHSOffset, RHSOffset);
+    return ConstantInt::get(getCompareTy(LHS),
+                            ICmpInst::compare(LHSOffset, RHSOffset, Pred));
 
   // Various optimizations for (in)equality comparisons.
   if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE) {
     // Different non-empty allocations that exist at the same time have
-    // different addresses (if the program can tell). Global variables always
-    // exist, so they always exist during the lifetime of each other and all
-    // allocas. Two different allocas usually have different addresses...
-    //
-    // However, if there's an @llvm.stackrestore dynamically in between two
-    // allocas, they may have the same address. It's tempting to reduce the
-    // scope of the problem by only looking at *static* allocas here. That would
-    // cover the majority of allocas while significantly reducing the likelihood
-    // of having an @llvm.stackrestore pop up in the middle. However, it's not
-    // actually impossible for an @llvm.stackrestore to pop up in the middle of
-    // an entry block. Also, if we have a block that's not attached to a
-    // function, we can't tell if it's "static" under the current definition.
-    // Theoretically, this problem could be fixed by creating a new kind of
-    // instruction kind specifically for static allocas. Such a new instruction
-    // could be required to be at the top of the entry block, thus preventing it
-    // from being subject to a @llvm.stackrestore. Instcombine could even
-    // convert regular allocas into these special allocas. It'd be nifty.
-    // However, until then, this problem remains open.
-    //
-    // So, we'll assume that two non-empty allocas have different addresses
-    // for now.
-    //
-    // With all that, if the offsets are within the bounds of their allocations
-    // (and not one-past-the-end! so we can't use inbounds!), and their
-    // allocations aren't the same, the pointers are not equal.
-    //
-    // Note that it's not necessary to check for LHS being a global variable
-    // address, due to canonicalization and constant folding.
-    if (isa<AllocaInst>(LHS) &&
-        (isa<AllocaInst>(RHS) || isa<GlobalVariable>(RHS))) {
-      ConstantInt *LHSOffsetCI = dyn_cast<ConstantInt>(LHSOffset);
-      ConstantInt *RHSOffsetCI = dyn_cast<ConstantInt>(RHSOffset);
+    // different addresses (if the program can tell). If the offsets are
+    // within the bounds of their allocations (and not one-past-the-end!
+    // so we can't use inbounds!), and their allocations aren't the same,
+    // the pointers are not equal.
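// Concretely, for in-bounds offsets into two different allocas:
//   %a = alloca [8 x i8]
//   %b = alloca [8 x i8]
//   %pa = getelementptr [8 x i8], [8 x i8]* %a, i64 0, i64 1
//   %pb = getelementptr [8 x i8], [8 x i8]* %b, i64 0, i64 2
//   icmp eq i8* %pa, %pb   --> false
// since offsets 1 and 2 both lie strictly inside their 8-byte allocations.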
+ if (haveNonOverlappingStorage(LHS, RHS)) { uint64_t LHSSize, RHSSize; ObjectSizeOpts Opts; - Opts.NullIsUnknownSize = - NullPointerIsDefined(cast<AllocaInst>(LHS)->getFunction()); - if (LHSOffsetCI && RHSOffsetCI && - getObjectSize(LHS, LHSSize, DL, TLI, Opts) && - getObjectSize(RHS, RHSSize, DL, TLI, Opts)) { - const APInt &LHSOffsetValue = LHSOffsetCI->getValue(); - const APInt &RHSOffsetValue = RHSOffsetCI->getValue(); - if (!LHSOffsetValue.isNegative() && - !RHSOffsetValue.isNegative() && - LHSOffsetValue.ult(LHSSize) && - RHSOffsetValue.ult(RHSSize)) { - return ConstantInt::get(GetCompareTy(LHS), - !CmpInst::isTrueWhenEqual(Pred)); - } - } - - // Repeat the above check but this time without depending on DataLayout - // or being able to compute a precise size. - if (!cast<PointerType>(LHS->getType())->isEmptyTy() && - !cast<PointerType>(RHS->getType())->isEmptyTy() && - LHSOffset->isNullValue() && - RHSOffset->isNullValue()) - return ConstantInt::get(GetCompareTy(LHS), + Opts.EvalMode = ObjectSizeOpts::Mode::Min; + auto *F = [](Value *V) -> Function * { + if (auto *I = dyn_cast<Instruction>(V)) + return I->getFunction(); + if (auto *A = dyn_cast<Argument>(V)) + return A->getParent(); + return nullptr; + }(LHS); + Opts.NullIsUnknownSize = F ? NullPointerIsDefined(F) : true; + if (getObjectSize(LHS, LHSSize, DL, TLI, Opts) && + getObjectSize(RHS, RHSSize, DL, TLI, Opts) && + !LHSOffset.isNegative() && !RHSOffset.isNegative() && + LHSOffset.ult(LHSSize) && RHSOffset.ult(RHSSize)) { + return ConstantInt::get(getCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred)); + } } - // Even if an non-inbounds GEP occurs along the path we can still optimize - // equality comparisons concerning the result. We avoid walking the whole - // chain again by starting where the last calls to - // stripAndComputeConstantOffsets left off and accumulate the offsets. - Constant *LHSNoBound = stripAndComputeConstantOffsets(DL, LHS, true); - Constant *RHSNoBound = stripAndComputeConstantOffsets(DL, RHS, true); - if (LHS == RHS) - return ConstantExpr::getICmp(Pred, - ConstantExpr::getAdd(LHSOffset, LHSNoBound), - ConstantExpr::getAdd(RHSOffset, RHSNoBound)); - // If one side of the equality comparison must come from a noalias call // (meaning a system memory allocation function), and the other side must // come from a pointer that cannot overlap with dynamically-allocated @@ -2685,29 +2780,16 @@ computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, }; // Is the set of underlying objects all things which must be disjoint from - // noalias calls. For allocas, we consider only static ones (dynamic - // allocas might be transformed into calls to malloc not simultaneously - // live with the compared-to allocation). For globals, we exclude symbols - // that might be resolve lazily to symbols in another dynamically-loaded - // library (and, thus, could be malloc'ed by the implementation). + // noalias calls. We assume that indexing from such disjoint storage + // into the heap is undefined, and thus offsets can be safely ignored. 
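// For example, a static alloca or an internal-linkage global @g can never be
// the block of storage returned by a noalias call such as malloc, so
//   %p = call noalias i8* @malloc(i64 8)
//   icmp eq i8* %p, getelementptr inbounds ([8 x i8], [8 x i8]* @g, i64 0, i64 1) --> false
// even though %p carries an unknown offset: indexing from @g into the
// malloc'ed block would itself be undefined behavior.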
auto IsAllocDisjoint = [](ArrayRef<const Value *> Objects) { - return all_of(Objects, [](const Value *V) { - if (const AllocaInst *AI = dyn_cast<AllocaInst>(V)) - return AI->getParent() && AI->getFunction() && AI->isStaticAlloca(); - if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) - return (GV->hasLocalLinkage() || GV->hasHiddenVisibility() || - GV->hasProtectedVisibility() || GV->hasGlobalUnnamedAddr()) && - !GV->isThreadLocal(); - if (const Argument *A = dyn_cast<Argument>(V)) - return A->hasByValAttr(); - return false; - }); + return all_of(Objects, ::isAllocDisjoint); }; if ((IsNAC(LHSUObjs) && IsAllocDisjoint(RHSUObjs)) || (IsNAC(RHSUObjs) && IsAllocDisjoint(LHSUObjs))) - return ConstantInt::get(GetCompareTy(LHS), - !CmpInst::isTrueWhenEqual(Pred)); + return ConstantInt::get(getCompareTy(LHS), + !CmpInst::isTrueWhenEqual(Pred)); // Fold comparisons for non-escaping pointer even if the allocation call // cannot be elided. We cannot fold malloc comparison to null. Also, the @@ -2724,7 +2806,7 @@ computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, // FIXME: We should also fold the compare when the pointer escapes, but the // compare dominates the pointer escape if (MI && !PointerMayBeCaptured(MI, true, true)) - return ConstantInt::get(GetCompareTy(LHS), + return ConstantInt::get(getCompareTy(LHS), CmpInst::isFalseWhenEqual(Pred)); } @@ -2735,7 +2817,7 @@ computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, /// Fold an icmp when its operands have i1 scalar type. static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q) { - Type *ITy = GetCompareTy(LHS); // The return type. + Type *ITy = getCompareTy(LHS); // The return type. Type *OpTy = LHS->getType(); // The operand type. if (!OpTy->isIntOrIntVectorTy(1)) return nullptr; @@ -2773,7 +2855,8 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, case CmpInst::ICMP_SLE: // X <=s 0 -> true return getTrue(ITy); - default: break; + default: + break; } } else if (match(RHS, m_One())) { switch (Pred) { @@ -2797,7 +2880,8 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, case CmpInst::ICMP_SGE: // X >=s -1 -> true return getTrue(ITy); - default: break; + default: + break; } } @@ -2805,7 +2889,7 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, default: break; case ICmpInst::ICMP_UGE: - if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) + if (isImpliedCondition(RHS, LHS, Q.DL).value_or(false)) return getTrue(ITy); break; case ICmpInst::ICMP_SGE: @@ -2816,11 +2900,11 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, /// 0 | 1 | 1 (0 >= -1) | 1 /// 1 | 0 | 0 (-1 >= 0) | 0 /// 1 | 1 | 1 (-1 >= -1) | 1 - if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) + if (isImpliedCondition(LHS, RHS, Q.DL).value_or(false)) return getTrue(ITy); break; case ICmpInst::ICMP_ULE: - if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) + if (isImpliedCondition(LHS, RHS, Q.DL).value_or(false)) return getTrue(ITy); break; } @@ -2834,7 +2918,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, if (!match(RHS, m_Zero())) return nullptr; - Type *ITy = GetCompareTy(LHS); // The return type. + Type *ITy = getCompareTy(LHS); // The return type. 
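// These folds depend only on the predicate plus, in some cases, known bits
// of the other operand, e.g. for any %x:
//   icmp ult %x, 0 --> false      (no unsigned value is below zero)
//   icmp uge %x, 0 --> true
// and, when isKnownNonZero proves %x != 0:
//   icmp eq  %x, 0 --> false
//   icmp ugt %x, 0 --> true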
switch (Pred) { default: llvm_unreachable("Unknown ICmp predicate!"); @@ -2893,7 +2977,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const InstrInfoQuery &IIQ) { - Type *ITy = GetCompareTy(RHS); // The return type. + Type *ITy = getCompareTy(RHS); // The return type. Value *X; // Sign-bit checks can be optimized to true/false after unsigned @@ -2940,10 +3024,11 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, return nullptr; } -static Value *simplifyICmpWithBinOpOnLHS( - CmpInst::Predicate Pred, BinaryOperator *LBO, Value *RHS, - const SimplifyQuery &Q, unsigned MaxRecurse) { - Type *ITy = GetCompareTy(RHS); // The return type. +static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, + BinaryOperator *LBO, Value *RHS, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + Type *ITy = getCompareTy(RHS); // The return type. Value *Y = nullptr; // icmp pred (or X, Y), X @@ -3078,7 +3163,6 @@ static Value *simplifyICmpWithBinOpOnLHS( return nullptr; } - // If only one of the icmp's operands has NSW flags, try to prove that: // // icmp slt (x + C1), (x +nsw C2) @@ -3113,7 +3197,6 @@ static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS, (C2->slt(*C1) && C1->isNonPositive()); } - /// TODO: A large part of this logic is duplicated in InstCombine's /// foldICmpBinOp(). We should be able to share that and avoid the code /// duplication. @@ -3150,7 +3233,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. if ((A == RHS || B == RHS) && NoLHSWrapProblem) - if (Value *V = SimplifyICmpInst(Pred, A == RHS ? B : A, + if (Value *V = simplifyICmpInst(Pred, A == RHS ? B : A, Constant::getNullValue(RHS->getType()), Q, MaxRecurse - 1)) return V; @@ -3158,7 +3241,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. if ((C == LHS || D == LHS) && NoRHSWrapProblem) if (Value *V = - SimplifyICmpInst(Pred, Constant::getNullValue(LHS->getType()), + simplifyICmpInst(Pred, Constant::getNullValue(LHS->getType()), C == LHS ? 
D : C, Q, MaxRecurse - 1)) return V; @@ -3186,7 +3269,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, Y = A; Z = C; } - if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse - 1)) + if (Value *V = simplifyICmpInst(Pred, Y, Z, Q, MaxRecurse - 1)) return V; } } @@ -3206,15 +3289,15 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, if (match(RHS, m_APInt(C))) { if (C->isStrictlyPositive()) { if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_NE) - return ConstantInt::getTrue(GetCompareTy(RHS)); + return ConstantInt::getTrue(getCompareTy(RHS)); if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_EQ) - return ConstantInt::getFalse(GetCompareTy(RHS)); + return ConstantInt::getFalse(getCompareTy(RHS)); } if (C->isNonNegative()) { if (Pred == ICmpInst::ICMP_SLE) - return ConstantInt::getTrue(GetCompareTy(RHS)); + return ConstantInt::getTrue(getCompareTy(RHS)); if (Pred == ICmpInst::ICMP_SGT) - return ConstantInt::getFalse(GetCompareTy(RHS)); + return ConstantInt::getFalse(getCompareTy(RHS)); } } } @@ -3237,9 +3320,9 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO)) || match(LHS, m_Shl(m_One(), m_Value())) || !C->isZero()) { if (Pred == ICmpInst::ICMP_EQ) - return ConstantInt::getFalse(GetCompareTy(RHS)); + return ConstantInt::getFalse(getCompareTy(RHS)); if (Pred == ICmpInst::ICMP_NE) - return ConstantInt::getTrue(GetCompareTy(RHS)); + return ConstantInt::getTrue(getCompareTy(RHS)); } } @@ -3248,9 +3331,9 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // (1 << X) <=u 0x8000 --> true if (match(LHS, m_Shl(m_One(), m_Value())) && match(RHS, m_SignMask())) { if (Pred == ICmpInst::ICMP_UGT) - return ConstantInt::getFalse(GetCompareTy(RHS)); + return ConstantInt::getFalse(getCompareTy(RHS)); if (Pred == ICmpInst::ICMP_ULE) - return ConstantInt::getTrue(GetCompareTy(RHS)); + return ConstantInt::getTrue(getCompareTy(RHS)); } if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && @@ -3263,22 +3346,22 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, if (ICmpInst::isSigned(Pred) || !Q.IIQ.isExact(LBO) || !Q.IIQ.isExact(RBO)) break; - if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + if (Value *V = simplifyICmpInst(Pred, LBO->getOperand(0), RBO->getOperand(0), Q, MaxRecurse - 1)) - return V; + return V; break; case Instruction::SDiv: if (!ICmpInst::isEquality(Pred) || !Q.IIQ.isExact(LBO) || !Q.IIQ.isExact(RBO)) break; - if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + if (Value *V = simplifyICmpInst(Pred, LBO->getOperand(0), RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; case Instruction::AShr: if (!Q.IIQ.isExact(LBO) || !Q.IIQ.isExact(RBO)) break; - if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + if (Value *V = simplifyICmpInst(Pred, LBO->getOperand(0), RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; @@ -3289,7 +3372,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, break; if (!NSW && ICmpInst::isSigned(Pred)) break; - if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + if (Value *V = simplifyICmpInst(Pred, LBO->getOperand(0), RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; @@ -3299,12 +3382,12 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return nullptr; } -/// Simplify integer comparisons where at least one operand of the compare +/// simplify integer 
comparisons where at least one operand of the compare /// matches an integer min/max idiom. static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { - Type *ITy = GetCompareTy(LHS); // The return type. + Type *ITy = getCompareTy(LHS); // The return type. Value *A, *B; CmpInst::Predicate P = CmpInst::BAD_ICMP_PREDICATE; CmpInst::Predicate EqP; // Chosen so that "A == max/min(A,B)" iff "A EqP B". @@ -3349,13 +3432,13 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, case CmpInst::ICMP_SLE: // Equivalent to "A EqP B". This may be the same as the condition tested // in the max/min; if so, we can just return that. - if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B)) + if (Value *V = extractEquivalentCondition(LHS, EqP, A, B)) return V; - if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B)) + if (Value *V = extractEquivalentCondition(RHS, EqP, A, B)) return V; // Otherwise, see if "A EqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) + if (Value *V = simplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) return V; break; case CmpInst::ICMP_NE: @@ -3363,13 +3446,13 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP); // Equivalent to "A InvEqP B". This may be the same as the condition // tested in the max/min; if so, we can just return that. - if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B)) + if (Value *V = extractEquivalentCondition(LHS, InvEqP, A, B)) return V; - if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B)) + if (Value *V = extractEquivalentCondition(RHS, InvEqP, A, B)) return V; // Otherwise, see if "A InvEqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) + if (Value *V = simplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) return V; break; } @@ -3423,13 +3506,13 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, case CmpInst::ICMP_ULE: // Equivalent to "A EqP B". This may be the same as the condition tested // in the max/min; if so, we can just return that. - if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B)) + if (Value *V = extractEquivalentCondition(LHS, EqP, A, B)) return V; - if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B)) + if (Value *V = extractEquivalentCondition(RHS, EqP, A, B)) return V; // Otherwise, see if "A EqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) + if (Value *V = simplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) return V; break; case CmpInst::ICMP_NE: @@ -3437,13 +3520,13 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP); // Equivalent to "A InvEqP B". This may be the same as the condition // tested in the max/min; if so, we can just return that. - if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B)) + if (Value *V = extractEquivalentCondition(LHS, InvEqP, A, B)) return V; - if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B)) + if (Value *V = extractEquivalentCondition(RHS, InvEqP, A, B)) return V; // Otherwise, see if "A InvEqP B" simplifies. 
if (MaxRecurse) - if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) + if (Value *V = simplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) return V; break; } @@ -3499,11 +3582,10 @@ static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate, continue; CallInst *Assume = cast<CallInst>(AssumeVH); - if (Optional<bool> Imp = - isImpliedCondition(Assume->getArgOperand(0), Predicate, LHS, RHS, - Q.DL)) + if (Optional<bool> Imp = isImpliedCondition(Assume->getArgOperand(0), + Predicate, LHS, RHS, Q.DL)) if (isValidAssumeForContext(Assume, Q.CxtI, Q.DT)) - return ConstantInt::get(GetCompareTy(LHS), *Imp); + return ConstantInt::get(getCompareTy(LHS), *Imp); } } @@ -3512,7 +3594,7 @@ static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate, /// Given operands for an ICmpInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); @@ -3527,7 +3609,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } assert(!isa<UndefValue>(LHS) && "Unexpected icmp undef,%X"); - Type *ITy = GetCompareTy(LHS); // The return type. + Type *ITy = getCompareTy(LHS); // The return type. // icmp poison, X -> poison if (isa<PoisonValue>(RHS)) @@ -3589,15 +3671,15 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) { if (Constant *RHSC = dyn_cast<Constant>(RHS)) { // Transfer the cast to the constant. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, + if (Value *V = simplifyICmpInst(Pred, SrcOp, ConstantExpr::getIntToPtr(RHSC, SrcTy), - Q, MaxRecurse-1)) + Q, MaxRecurse - 1)) return V; } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) { if (RI->getOperand(0)->getType() == SrcTy) // Compare without the cast. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), - Q, MaxRecurse-1)) + if (Value *V = simplifyICmpInst(Pred, SrcOp, RI->getOperand(0), Q, + MaxRecurse - 1)) return V; } } @@ -3608,9 +3690,9 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) { if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) // Compare X and Y. Note that signed predicates become unsigned. - if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), - SrcOp, RI->getOperand(0), Q, - MaxRecurse-1)) + if (Value *V = + simplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), SrcOp, + RI->getOperand(0), Q, MaxRecurse - 1)) return V; } // Fold (zext X) ule (sext X), (zext X) sge (sext X) to true. @@ -3633,15 +3715,16 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If the re-extended constant didn't change then this is effectively // also a case of comparing two zero-extended values. if (RExt == CI && MaxRecurse) - if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), - SrcOp, Trunc, Q, MaxRecurse-1)) + if (Value *V = simplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), + SrcOp, Trunc, Q, MaxRecurse - 1)) return V; // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit // there. Use this to work out the result of the comparison. 
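// For example, comparing (zext i8 %x to i32) against 300: truncating and
// re-extending yields 44, not 300, so the constant has bits above bit 7
// while the zero-extended value is at most 255; hence
//   icmp eq  (zext i8 %x to i32), 300 --> false
//   icmp ult (zext i8 %x to i32), 300 --> true
// and the signed forms depend only on the sign of the constant (300 is
// non-negative, so slt --> true here as well).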
if (RExt != CI) { switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); + default: + llvm_unreachable("Unknown ICmp predicate!"); // LHS <u RHS. case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGT: @@ -3657,15 +3740,15 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // is non-negative then LHS <s RHS. case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() ? - ConstantInt::getTrue(CI->getContext()) : - ConstantInt::getFalse(CI->getContext()); + return CI->getValue().isNegative() + ? ConstantInt::getTrue(CI->getContext()) + : ConstantInt::getFalse(CI->getContext()); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() ? - ConstantInt::getFalse(CI->getContext()) : - ConstantInt::getTrue(CI->getContext()); + return CI->getValue().isNegative() + ? ConstantInt::getFalse(CI->getContext()) + : ConstantInt::getTrue(CI->getContext()); } } } @@ -3677,8 +3760,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) { if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) // Compare X and Y. Note that the predicate does not change. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), - Q, MaxRecurse-1)) + if (Value *V = simplifyICmpInst(Pred, SrcOp, RI->getOperand(0), Q, + MaxRecurse - 1)) return V; } // Fold (sext X) uge (zext X), (sext X) sle (zext X) to true. @@ -3701,14 +3784,16 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If the re-extended constant didn't change then this is effectively // also a case of comparing two sign-extended values. if (RExt == CI && MaxRecurse) - if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1)) + if (Value *V = + simplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse - 1)) return V; // Otherwise the upper bits of LHS are all equal, while RHS has varying // bits there. Use this to work out the result of the comparison. if (RExt != CI) { switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); + default: + llvm_unreachable("Unknown ICmp predicate!"); case ICmpInst::ICMP_EQ: return ConstantInt::getFalse(CI->getContext()); case ICmpInst::ICMP_NE: @@ -3718,14 +3803,14 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // LHS >s RHS. case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() ? - ConstantInt::getTrue(CI->getContext()) : - ConstantInt::getFalse(CI->getContext()); + return CI->getValue().isNegative() + ? ConstantInt::getTrue(CI->getContext()) + : ConstantInt::getFalse(CI->getContext()); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() ? - ConstantInt::getFalse(CI->getContext()) : - ConstantInt::getTrue(CI->getContext()); + return CI->getValue().isNegative() + ? ConstantInt::getFalse(CI->getContext()) + : ConstantInt::getTrue(CI->getContext()); // If LHS is non-negative then LHS <u RHS. If LHS is negative then // LHS >u RHS. @@ -3733,18 +3818,18 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, case ICmpInst::ICMP_UGE: // Comparison is true iff the LHS <s 0. 
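// (For i8 %x sign-extended to i32: non-negative %x lands in [0, 127] while
// negative %x lands in [0xFFFFFF80, 0xFFFFFFFF], so against a constant that
// is not itself a sign-extended i8 value, e.g. 1000, the unsigned compare
//   icmp ugt (sext i8 %x to i32), 1000
// holds exactly when %x <s 0, which is the sign test issued below.)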
if (MaxRecurse) - if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp, - Constant::getNullValue(SrcTy), - Q, MaxRecurse-1)) + if (Value *V = simplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp, + Constant::getNullValue(SrcTy), Q, + MaxRecurse - 1)) return V; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: // Comparison is true iff the LHS >=s 0. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp, - Constant::getNullValue(SrcTy), - Q, MaxRecurse-1)) + if (Value *V = simplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp, + Constant::getNullValue(SrcTy), Q, + MaxRecurse - 1)) return V; break; } @@ -3788,26 +3873,26 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If the comparison is with the result of a select instruction, check whether // comparing with either branch of the select always yields the same value. if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) - if (Value *V = ThreadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse)) + if (Value *V = threadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse)) return V; // If the comparison is with the result of a phi instruction, check whether // doing the compare with each incoming phi value yields a common result. if (isa<PHINode>(LHS) || isa<PHINode>(RHS)) - if (Value *V = ThreadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse)) + if (Value *V = threadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse)) return V; return nullptr; } -Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +Value *llvm::simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q) { - return ::SimplifyICmpInst(Predicate, LHS, RHS, Q, RecursionLimit); + return ::simplifyICmpInst(Predicate, LHS, RHS, Q, RecursionLimit); } /// Given operands for an FCmpInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; @@ -3815,7 +3900,8 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (Constant *CLHS = dyn_cast<Constant>(LHS)) { if (Constant *CRHS = dyn_cast<Constant>(RHS)) - return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI); + return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI, + Q.CxtI); // If we have a constant, make sure it is on the RHS. std::swap(LHS, RHS); @@ -3823,7 +3909,7 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // Fold trivial predicates. - Type *RetTy = GetCompareTy(LHS); + Type *RetTy = getCompareTy(LHS); if (Pred == FCmpInst::FCMP_FALSE) return getFalse(RetTy); if (Pred == FCmpInst::FCMP_TRUE) @@ -3943,23 +4029,29 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, // The ordered relationship and minnum/maxnum guarantee that we do not // have NaN constants, so ordered/unordered preds are handled the same. 
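// For example, with %m = minnum(%x, 1.0) compared against 2.0: %m is at
// most 1.0 for every %x, including NaN (minnum(NaN, 1.0) == 1.0), so
//   fcmp oeq %m, 2.0 --> false
//   fcmp olt %m, 2.0 --> true
// and symmetrically for maxnum(%x, GreaterC) against a smaller constant.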
switch (Pred) { - case FCmpInst::FCMP_OEQ: case FCmpInst::FCMP_UEQ: + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_UEQ: // minnum(X, LesserC) == C --> false // maxnum(X, GreaterC) == C --> false return getFalse(RetTy); - case FCmpInst::FCMP_ONE: case FCmpInst::FCMP_UNE: + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_UNE: // minnum(X, LesserC) != C --> true // maxnum(X, GreaterC) != C --> true return getTrue(RetTy); - case FCmpInst::FCMP_OGE: case FCmpInst::FCMP_UGE: - case FCmpInst::FCMP_OGT: case FCmpInst::FCMP_UGT: + case FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_OGT: + case FCmpInst::FCMP_UGT: // minnum(X, LesserC) >= C --> false // minnum(X, LesserC) > C --> false // maxnum(X, GreaterC) >= C --> true // maxnum(X, GreaterC) > C --> true return ConstantInt::get(RetTy, IsMaxNum); - case FCmpInst::FCMP_OLE: case FCmpInst::FCMP_ULE: - case FCmpInst::FCMP_OLT: case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_OLE: + case FCmpInst::FCMP_ULE: + case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_ULT: // minnum(X, LesserC) <= C --> true // minnum(X, LesserC) < C --> true // maxnum(X, GreaterC) <= C --> false @@ -3997,21 +4089,21 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If the comparison is with the result of a select instruction, check whether // comparing with either branch of the select always yields the same value. if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) - if (Value *V = ThreadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse)) + if (Value *V = threadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse)) return V; // If the comparison is with the result of a phi instruction, check whether // doing the compare with each incoming phi value yields a common result. if (isa<PHINode>(LHS) || isa<PHINode>(RHS)) - if (Value *V = ThreadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse)) + if (Value *V = threadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse)) return V; return nullptr; } -Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +Value *llvm::simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q) { - return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); + return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); } static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, @@ -4078,22 +4170,21 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, }; if (auto *B = dyn_cast<BinaryOperator>(I)) - return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), NewOps[0], + return PreventSelfSimplify(simplifyBinOp(B->getOpcode(), NewOps[0], NewOps[1], Q, MaxRecurse - 1)); if (CmpInst *C = dyn_cast<CmpInst>(I)) - return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), NewOps[0], + return PreventSelfSimplify(simplifyCmpInst(C->getPredicate(), NewOps[0], NewOps[1], Q, MaxRecurse - 1)); if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) - return PreventSelfSimplify(SimplifyGEPInst( + return PreventSelfSimplify(simplifyGEPInst( GEP->getSourceElementType(), NewOps[0], makeArrayRef(NewOps).slice(1), GEP->isInBounds(), Q, MaxRecurse - 1)); if (isa<SelectInst>(I)) - return PreventSelfSimplify( - SimplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q, - MaxRecurse - 1)); + return PreventSelfSimplify(simplifySelectInst( + NewOps[0], NewOps[1], NewOps[2], Q, MaxRecurse - 1)); // TODO: We could hand off more cases to instsimplify here. 
} @@ -4119,14 +4210,6 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (!AllowRefinement && canCreatePoison(cast<Operator>(I))) return nullptr; - if (CmpInst *C = dyn_cast<CmpInst>(I)) - return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0], - ConstOps[1], Q.DL, Q.TLI); - - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - if (!LI->isVolatile()) - return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL); - return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); } @@ -4189,7 +4272,8 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS, /// Try to simplify a select instruction when its condition operand is an /// integer comparison. static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, - Value *FalseVal, const SimplifyQuery &Q, + Value *FalseVal, + const SimplifyQuery &Q, unsigned MaxRecurse) { ICmpInst::Predicate Pred; Value *CmpLHS, *CmpRHS; @@ -4209,7 +4293,8 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, Value *X, *Y; SelectPatternFlavor SPF = matchDecomposedSelectPattern(cast<ICmpInst>(CondVal), TrueVal, FalseVal, - X, Y).Flavor; + X, Y) + .Flavor; if (SelectPatternResult::isMinOrMax(SPF) && Pred == getMinMaxPred(SPF)) { APInt LimitC = getMinMaxLimit(getInverseMinMaxFlavor(SPF), X->getType()->getScalarSizeInBits()); @@ -4261,8 +4346,8 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, } // Check for other compares that behave like bit test. - if (Value *V = simplifySelectWithFakeICmpEq(CmpLHS, CmpRHS, Pred, - TrueVal, FalseVal)) + if (Value *V = + simplifySelectWithFakeICmpEq(CmpLHS, CmpRHS, Pred, TrueVal, FalseVal)) return V; // If we have a scalar equality comparison, then we know the value in one of @@ -4272,18 +4357,18 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, // because each element of a vector select is chosen independently. if (Pred == ICmpInst::ICMP_EQ && !CondVal->getType()->isVectorTy()) { if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, - /* AllowRefinement */ false, MaxRecurse) == - TrueVal || + /* AllowRefinement */ false, + MaxRecurse) == TrueVal || simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, - /* AllowRefinement */ false, MaxRecurse) == - TrueVal) + /* AllowRefinement */ false, + MaxRecurse) == TrueVal) return FalseVal; if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, - /* AllowRefinement */ true, MaxRecurse) == - FalseVal || + /* AllowRefinement */ true, + MaxRecurse) == FalseVal || simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, - /* AllowRefinement */ true, MaxRecurse) == - FalseVal) + /* AllowRefinement */ true, + MaxRecurse) == FalseVal) return FalseVal; } @@ -4302,11 +4387,11 @@ static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F, // This transform is safe if we do not have (do not care about) -0.0 or if // at least one operand is known to not be -0.0. Otherwise, the select can // change the sign of a zero operand. - bool HasNoSignedZeros = Q.CxtI && isa<FPMathOperator>(Q.CxtI) && - Q.CxtI->hasNoSignedZeros(); + bool HasNoSignedZeros = + Q.CxtI && isa<FPMathOperator>(Q.CxtI) && Q.CxtI->hasNoSignedZeros(); const APFloat *C; if (HasNoSignedZeros || (match(T, m_APFloat(C)) && C->isNonZero()) || - (match(F, m_APFloat(C)) && C->isNonZero())) { + (match(F, m_APFloat(C)) && C->isNonZero())) { // (T == F) ? T : F --> F // (F == T) ? 
T : F --> F if (Pred == FCmpInst::FCMP_OEQ) @@ -4323,7 +4408,7 @@ static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F, /// Given operands for a SelectInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, +static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) { if (auto *CondC = dyn_cast<Constant>(Cond)) { if (auto *TrueC = dyn_cast<Constant>(TrueVal)) @@ -4439,14 +4524,14 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, return nullptr; } -Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, +Value *llvm::simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, const SimplifyQuery &Q) { - return ::SimplifySelectInst(Cond, TrueVal, FalseVal, Q, RecursionLimit); + return ::simplifySelectInst(Cond, TrueVal, FalseVal, Q, RecursionLimit); } /// Given operands for an GetElementPtrInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyGEPInst(Type *SrcTy, Value *Ptr, +static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, ArrayRef<Value *> Indices, bool InBounds, const SimplifyQuery &Q, unsigned) { // The type of the GEP pointer operand. @@ -4473,6 +4558,13 @@ static Value *SimplifyGEPInst(Type *SrcTy, Value *Ptr, } } + // For opaque pointers an all-zero GEP is a no-op. For typed pointers, + // it may be equivalent to a bitcast. + if (Ptr->getType()->getScalarType()->isOpaquePointerTy() && + Ptr->getType() == GEPTy && + all_of(Indices, [](const auto *V) { return match(V, m_Zero()); })) + return Ptr; + // getelementptr poison, idx -> poison // getelementptr baseptr, poison -> poison if (isa<PoisonValue>(Ptr) || @@ -4577,16 +4669,16 @@ static Value *SimplifyGEPInst(Type *SrcTy, Value *Ptr, return ConstantFoldConstant(CE, Q.DL); } -Value *llvm::SimplifyGEPInst(Type *SrcTy, Value *Ptr, ArrayRef<Value *> Indices, +Value *llvm::simplifyGEPInst(Type *SrcTy, Value *Ptr, ArrayRef<Value *> Indices, bool InBounds, const SimplifyQuery &Q) { - return ::SimplifyGEPInst(SrcTy, Ptr, Indices, InBounds, Q, RecursionLimit); + return ::simplifyGEPInst(SrcTy, Ptr, Indices, InBounds, Q, RecursionLimit); } /// Given operands for an InsertValueInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, - ArrayRef<unsigned> Idxs, const SimplifyQuery &Q, - unsigned) { +static Value *simplifyInsertValueInst(Value *Agg, Value *Val, + ArrayRef<unsigned> Idxs, + const SimplifyQuery &Q, unsigned) { if (Constant *CAgg = dyn_cast<Constant>(Agg)) if (Constant *CVal = dyn_cast<Constant>(Val)) return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs); @@ -4611,13 +4703,13 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, return nullptr; } -Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val, +Value *llvm::simplifyInsertValueInst(Value *Agg, Value *Val, ArrayRef<unsigned> Idxs, const SimplifyQuery &Q) { - return ::SimplifyInsertValueInst(Agg, Val, Idxs, Q, RecursionLimit); + return ::simplifyInsertValueInst(Agg, Val, Idxs, Q, RecursionLimit); } -Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx, +Value *llvm::simplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx, const SimplifyQuery &Q) { // Try to constant fold. 
auto *VecC = dyn_cast<Constant>(Vec); @@ -4654,7 +4746,7 @@ Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx, /// Given operands for an ExtractValueInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, +static Value *simplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, const SimplifyQuery &, unsigned) { if (auto *CAgg = dyn_cast<Constant>(Agg)) return ConstantFoldExtractValueInstruction(CAgg, Idxs); @@ -4677,14 +4769,14 @@ static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, return nullptr; } -Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, +Value *llvm::simplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, const SimplifyQuery &Q) { - return ::SimplifyExtractValueInst(Agg, Idxs, Q, RecursionLimit); + return ::simplifyExtractValueInst(Agg, Idxs, Q, RecursionLimit); } /// Given operands for an ExtractElementInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, +static Value *simplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQuery &Q, unsigned) { auto *VecVTy = cast<VectorType>(Vec->getType()); if (auto *CVec = dyn_cast<Constant>(Vec)) { @@ -4721,13 +4813,13 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, return nullptr; } -Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx, +Value *llvm::simplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQuery &Q) { - return ::SimplifyExtractElementInst(Vec, Idx, Q, RecursionLimit); + return ::simplifyExtractElementInst(Vec, Idx, Q, RecursionLimit); } /// See if we can fold the given phi. If not, returns null. -static Value *SimplifyPHINode(PHINode *PN, ArrayRef<Value *> IncomingValues, +static Value *simplifyPHINode(PHINode *PN, ArrayRef<Value *> IncomingValues, const SimplifyQuery &Q) { // WARNING: no matter how worthwhile it may seem, we can not perform PHI CSE // here, because the PHI we may succeed simplifying to was not @@ -4739,14 +4831,15 @@ static Value *SimplifyPHINode(PHINode *PN, ArrayRef<Value *> IncomingValues, bool HasUndefInput = false; for (Value *Incoming : IncomingValues) { // If the incoming value is the phi node itself, it can safely be skipped. - if (Incoming == PN) continue; + if (Incoming == PN) + continue; if (Q.isUndefValue(Incoming)) { // Remember that we saw an undef value, but otherwise ignore them. HasUndefInput = true; continue; } if (CommonValue && Incoming != CommonValue) - return nullptr; // Not the same, bail out. + return nullptr; // Not the same, bail out. CommonValue = Incoming; } @@ -4755,17 +4848,24 @@ static Value *SimplifyPHINode(PHINode *PN, ArrayRef<Value *> IncomingValues, if (!CommonValue) return UndefValue::get(PN->getType()); - // If we have a PHI node like phi(X, undef, X), where X is defined by some - // instruction, we cannot return X as the result of the PHI node unless it - // dominates the PHI block. - if (HasUndefInput) + if (HasUndefInput) { + // We cannot start executing a trapping constant expression on more control + // flow paths. + auto *C = dyn_cast<Constant>(CommonValue); + if (C && C->canTrap()) + return nullptr; + + // If we have a PHI node like phi(X, undef, X), where X is defined by some + // instruction, we cannot return X as the result of the PHI node unless it + // dominates the PHI block. return valueDominatesPHI(CommonValue, PN, Q.DT) ? 
CommonValue : nullptr; + } return CommonValue; } -static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, - Type *Ty, const SimplifyQuery &Q, unsigned MaxRecurse) { +static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, + const SimplifyQuery &Q, unsigned MaxRecurse) { if (auto *C = dyn_cast<Constant>(Op)) return ConstantFoldCastOperand(CastOpc, C, Ty, Q.DL); @@ -4798,9 +4898,9 @@ static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, return nullptr; } -Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, +Value *llvm::simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, const SimplifyQuery &Q) { - return ::SimplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit); + return ::simplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit); } /// For the given destination element of a shuffle, peek through shuffles to @@ -4854,7 +4954,7 @@ static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, return RootVec; } -static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, +static Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask, Type *RetTy, const SimplifyQuery &Q, unsigned MaxRecurse) { @@ -4970,14 +5070,14 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, } /// Given operands for a ShuffleVectorInst, fold the result or return null. -Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, +Value *llvm::simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask, Type *RetTy, const SimplifyQuery &Q) { - return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit); + return ::simplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit); } -static Constant *foldConstant(Instruction::UnaryOps Opcode, - Value *&Op, const SimplifyQuery &Q) { +static Constant *foldConstant(Instruction::UnaryOps Opcode, Value *&Op, + const SimplifyQuery &Q) { if (auto *C = dyn_cast<Constant>(Op)) return ConstantFoldUnaryOpOperand(Opcode, C, Q.DL); return nullptr; @@ -4998,7 +5098,7 @@ static Value *simplifyFNegInst(Value *Op, FastMathFlags FMF, return nullptr; } -Value *llvm::SimplifyFNegInst(Value *Op, FastMathFlags FMF, +Value *llvm::simplifyFNegInst(Value *Op, FastMathFlags FMF, const SimplifyQuery &Q) { return ::simplifyFNegInst(Op, FMF, Q, RecursionLimit); } @@ -5049,15 +5149,10 @@ static Constant *simplifyFPOp(ArrayRef<Value *> Ops, FastMathFlags FMF, return nullptr; } -// TODO: Move this out to a header file: -static inline bool canIgnoreSNaN(fp::ExceptionBehavior EB, FastMathFlags FMF) { - return (EB == fp::ebIgnore || FMF.noNaNs()); -} - /// Given operands for an FAdd, see if we can fold the result. If not, this /// returns null. static Value * -SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, +simplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven) { @@ -5119,7 +5214,7 @@ SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null. 
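// A note on the rounding-mode guards introduced below: `fsub X, +0.0 --> X`
// is not sound under round-toward-negative unless signed zeros are ignored,
// because in that mode (+0.0) - (+0.0) evaluates to -0.0; the code therefore
// requires either nsz or a rounding mode that cannot be TowardNegative
// before folding.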
static Value * -SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, +simplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven) { @@ -5130,24 +5225,28 @@ SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding)) return C; - if (!isDefaultFPEnvironment(ExBehavior, Rounding)) - return nullptr; - // fsub X, +0 ==> X - if (match(Op1, m_PosZeroFP())) - return Op0; + if (canIgnoreSNaN(ExBehavior, FMF) && + (!canRoundingModeBe(Rounding, RoundingMode::TowardNegative) || + FMF.noSignedZeros())) + if (match(Op1, m_PosZeroFP())) + return Op0; // fsub X, -0 ==> X, when we know X is not -0 - if (match(Op1, m_NegZeroFP()) && - (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) - return Op0; + if (canIgnoreSNaN(ExBehavior, FMF)) + if (match(Op1, m_NegZeroFP()) && + (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) + return Op0; // fsub -0.0, (fsub -0.0, X) ==> X // fsub -0.0, (fneg X) ==> X Value *X; - if (match(Op0, m_NegZeroFP()) && - match(Op1, m_FNeg(m_Value(X)))) - return X; + if (canIgnoreSNaN(ExBehavior, FMF)) + if (match(Op0, m_NegZeroFP()) && match(Op1, m_FNeg(m_Value(X)))) + return X; + + if (!isDefaultFPEnvironment(ExBehavior, Rounding)) + return nullptr; // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. // fsub 0.0, (fneg X) ==> X if signed zeros are ignored. @@ -5170,7 +5269,7 @@ SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, return nullptr; } -static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, +static Value *simplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse, fp::ExceptionBehavior ExBehavior, RoundingMode Rounding) { @@ -5201,8 +5300,8 @@ static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, // 2. Ignore non-zero negative numbers because sqrt would produce NAN. // 3. Ignore -0.0 because sqrt(-0.0) == -0.0, but -0.0 * -0.0 == 0.0. Value *X; - if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && - FMF.allowReassoc() && FMF.noNaNs() && FMF.noSignedZeros()) + if (Op0 == Op1 && match(Op0, m_Sqrt(m_Value(X))) && FMF.allowReassoc() && + FMF.noNaNs() && FMF.noSignedZeros()) return X; return nullptr; @@ -5210,7 +5309,7 @@ static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, /// Given the operands for an FMul, see if we can fold the result static Value * -SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, +simplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven) { @@ -5219,43 +5318,43 @@ SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, return C; // Now apply simplifications that do not require rounding. 
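// Chief among these rounding-independent folds (implemented in
// simplifyFMAFMul above) is sqrt(X) * sqrt(X) --> X, which needs all three
// flags: reassoc to combine the two square roots, nnan because a negative X
// would yield NaN * NaN == NaN rather than X, and nsz because
// sqrt(-0.0) == -0.0 while (-0.0) * (-0.0) == +0.0.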
- return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse, ExBehavior, Rounding); + return simplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse, ExBehavior, Rounding); } -Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, +Value *llvm::simplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, fp::ExceptionBehavior ExBehavior, RoundingMode Rounding) { - return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + return ::simplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, Rounding); } -Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, +Value *llvm::simplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, fp::ExceptionBehavior ExBehavior, RoundingMode Rounding) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + return ::simplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, Rounding); } -Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, +Value *llvm::simplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, fp::ExceptionBehavior ExBehavior, RoundingMode Rounding) { - return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + return ::simplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, Rounding); } -Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, +Value *llvm::simplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, fp::ExceptionBehavior ExBehavior, RoundingMode Rounding) { - return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + return ::simplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, Rounding); } static Value * -SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, +simplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven) { @@ -5301,16 +5400,16 @@ SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, return nullptr; } -Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, +Value *llvm::simplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, fp::ExceptionBehavior ExBehavior, RoundingMode Rounding) { - return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + return ::simplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, Rounding); } static Value * -SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, +simplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven) { @@ -5339,11 +5438,11 @@ SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, return nullptr; } -Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, +Value *llvm::simplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, fp::ExceptionBehavior ExBehavior, RoundingMode Rounding) { - return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + return ::simplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, Rounding); } @@ -5365,8 +5464,8 @@ static Value *simplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q, /// If not, this returns null. /// Try to use FastMathFlags when folding the result. 
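// e.g. the unary path folds fneg (fneg %x) --> %x with no flag requirements
// at all, since fneg only toggles the sign bit and applying it twice is a
// bit-exact identity.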
static Value *simplifyFPUnOp(unsigned Opcode, Value *Op, - const FastMathFlags &FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { + const FastMathFlags &FMF, const SimplifyQuery &Q, + unsigned MaxRecurse) { switch (Opcode) { case Instruction::FNeg: return simplifyFNegInst(Op, FMF, Q, MaxRecurse); @@ -5375,56 +5474,56 @@ static Value *simplifyFPUnOp(unsigned Opcode, Value *Op, } } -Value *llvm::SimplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q) { +Value *llvm::simplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q) { return ::simplifyUnOp(Opcode, Op, Q, RecursionLimit); } -Value *llvm::SimplifyUnOp(unsigned Opcode, Value *Op, FastMathFlags FMF, +Value *llvm::simplifyUnOp(unsigned Opcode, Value *Op, FastMathFlags FMF, const SimplifyQuery &Q) { return ::simplifyFPUnOp(Opcode, Op, FMF, Q, RecursionLimit); } /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, +static Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { switch (Opcode) { case Instruction::Add: - return SimplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); + return simplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::Sub: - return SimplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); + return simplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::Mul: - return SimplifyMulInst(LHS, RHS, Q, MaxRecurse); + return simplifyMulInst(LHS, RHS, Q, MaxRecurse); case Instruction::SDiv: - return SimplifySDivInst(LHS, RHS, Q, MaxRecurse); + return simplifySDivInst(LHS, RHS, Q, MaxRecurse); case Instruction::UDiv: - return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse); + return simplifyUDivInst(LHS, RHS, Q, MaxRecurse); case Instruction::SRem: - return SimplifySRemInst(LHS, RHS, Q, MaxRecurse); + return simplifySRemInst(LHS, RHS, Q, MaxRecurse); case Instruction::URem: - return SimplifyURemInst(LHS, RHS, Q, MaxRecurse); + return simplifyURemInst(LHS, RHS, Q, MaxRecurse); case Instruction::Shl: - return SimplifyShlInst(LHS, RHS, false, false, Q, MaxRecurse); + return simplifyShlInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::LShr: - return SimplifyLShrInst(LHS, RHS, false, Q, MaxRecurse); + return simplifyLShrInst(LHS, RHS, false, Q, MaxRecurse); case Instruction::AShr: - return SimplifyAShrInst(LHS, RHS, false, Q, MaxRecurse); + return simplifyAShrInst(LHS, RHS, false, Q, MaxRecurse); case Instruction::And: - return SimplifyAndInst(LHS, RHS, Q, MaxRecurse); + return simplifyAndInst(LHS, RHS, Q, MaxRecurse); case Instruction::Or: - return SimplifyOrInst(LHS, RHS, Q, MaxRecurse); + return simplifyOrInst(LHS, RHS, Q, MaxRecurse); case Instruction::Xor: - return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); + return simplifyXorInst(LHS, RHS, Q, MaxRecurse); case Instruction::FAdd: - return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return simplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case Instruction::FSub: - return SimplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return simplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case Instruction::FMul: - return SimplifyFMulInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return simplifyFMulInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case Instruction::FDiv: - return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return simplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case 
Instruction::FRem: - return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return simplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); default: llvm_unreachable("Unexpected opcode"); } @@ -5433,49 +5532,50 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. /// Try to use FastMathFlags when folding the result. -static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, +static Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const FastMathFlags &FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { switch (Opcode) { case Instruction::FAdd: - return SimplifyFAddInst(LHS, RHS, FMF, Q, MaxRecurse); + return simplifyFAddInst(LHS, RHS, FMF, Q, MaxRecurse); case Instruction::FSub: - return SimplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse); + return simplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse); case Instruction::FMul: - return SimplifyFMulInst(LHS, RHS, FMF, Q, MaxRecurse); + return simplifyFMulInst(LHS, RHS, FMF, Q, MaxRecurse); case Instruction::FDiv: - return SimplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse); + return simplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse); default: - return SimplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse); + return simplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse); } } -Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, +Value *llvm::simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q) { - return ::SimplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit); + return ::simplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit); } -Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, +Value *llvm::simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q) { - return ::SimplifyBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit); + return ::simplifyBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit); } /// Given operands for a CmpInst, see if we can fold the result. 
-static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +static Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate)) - return SimplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse); - return SimplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return simplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse); + return simplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse); } -Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +Value *llvm::simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q) { - return ::SimplifyCmpInst(Predicate, LHS, RHS, Q, RecursionLimit); + return ::simplifyCmpInst(Predicate, LHS, RHS, Q, RecursionLimit); } -static bool IsIdempotent(Intrinsic::ID ID) { +static bool isIdempotent(Intrinsic::ID ID) { switch (ID) { - default: return false; + default: + return false; // Unary idempotent: f(f(x)) = f(x) case Intrinsic::fabs: @@ -5491,7 +5591,7 @@ static bool IsIdempotent(Intrinsic::ID ID) { } } -static Value *SimplifyRelativeLoad(Constant *Ptr, Constant *Offset, +static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset, const DataLayout &DL) { GlobalValue *PtrSym; APInt PtrOffset; @@ -5551,7 +5651,7 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, const SimplifyQuery &Q) { // Idempotent functions return the same result when called repeatedly. Intrinsic::ID IID = F->getIntrinsicID(); - if (IsIdempotent(IID)) + if (isIdempotent(IID)) if (auto *II = dyn_cast<IntrinsicInst>(Op0)) if (II->getIntrinsicID() == IID) return II; @@ -5559,15 +5659,18 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, Value *X; switch (IID) { case Intrinsic::fabs: - if (SignBitMustBeZero(Op0, Q.TLI)) return Op0; + if (SignBitMustBeZero(Op0, Q.TLI)) + return Op0; break; case Intrinsic::bswap: // bswap(bswap(x)) -> x - if (match(Op0, m_BSwap(m_Value(X)))) return X; + if (match(Op0, m_BSwap(m_Value(X)))) + return X; break; case Intrinsic::bitreverse: // bitreverse(bitreverse(x)) -> x - if (match(Op0, m_BitReverse(m_Value(X)))) return X; + if (match(Op0, m_BitReverse(m_Value(X)))) + return X; break; case Intrinsic::ctpop: { // If everything but the lowest bit is zero, that bit is the pop-count. 
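The isIdempotent predicate renamed above drives the unary-intrinsic fold that follows it: when an intrinsic's operand is a call to the same idempotent intrinsic, the inner call's result is returned directly, since f(f(x)) == f(x). The identity is easy to check against the host libm (standalone, not LLVM code):

#include <cassert>
#include <cmath>

int main() {
  for (double x : {-2.5, -0.0, 3.75}) {
    assert(std::fabs(std::fabs(x)) == std::fabs(x));    // fabs(fabs(x)) == fabs(x)
    assert(std::trunc(std::trunc(x)) == std::trunc(x)); // rounding ops likewise
  }
  return 0;
}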
Ex: @@ -5581,30 +5684,34 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, case Intrinsic::exp: // exp(log(x)) -> x if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic<Intrinsic::log>(m_Value(X)))) return X; + match(Op0, m_Intrinsic<Intrinsic::log>(m_Value(X)))) + return X; break; case Intrinsic::exp2: // exp2(log2(x)) -> x if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic<Intrinsic::log2>(m_Value(X)))) return X; + match(Op0, m_Intrinsic<Intrinsic::log2>(m_Value(X)))) + return X; break; case Intrinsic::log: // log(exp(x)) -> x if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X)))) return X; + match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X)))) + return X; break; case Intrinsic::log2: // log2(exp2(x)) -> x if (Q.CxtI->hasAllowReassoc() && (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) || - match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(2.0), - m_Value(X))))) return X; + match(Op0, + m_Intrinsic<Intrinsic::pow>(m_SpecificFP(2.0), m_Value(X))))) + return X; break; case Intrinsic::log10: // log10(pow(10.0, x)) -> x if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0), - m_Value(X)))) return X; + match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0), m_Value(X)))) + return X; break; case Intrinsic::floor: case Intrinsic::trunc: @@ -5826,7 +5933,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, case Intrinsic::load_relative: if (auto *C0 = dyn_cast<Constant>(Op0)) if (auto *C1 = dyn_cast<Constant>(Op1)) - return SimplifyRelativeLoad(C0, C1, Q.DL); + return simplifyRelativeLoad(C0, C1, Q.DL); break; case Intrinsic::powi: if (auto *Power = dyn_cast<ConstantInt>(Op1)) { @@ -5853,7 +5960,8 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, case Intrinsic::maximum: case Intrinsic::minimum: { // If the arguments are the same, this is a no-op. - if (Op0 == Op1) return Op0; + if (Op0 == Op1) + return Op0; // Canonicalize constant operand as Op1. 
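The exp/log and log/exp cancellations above are all gated on hasAllowReassoc(), and for good reason: the identities are only approximate in floating point and are plainly wrong outside the logarithm's domain. A standalone demonstration (host libm, not LLVM code):

#include <cmath>
#include <cstdio>

int main() {
  std::printf("exp(log(-1.0)) = %f\n", std::exp(std::log(-1.0))); // nan, not -1.0
  double x = 0.1;
  std::printf("exp(log(0.1)) - 0.1 = %.17g\n", std::exp(std::log(x)) - x);
  // The second line may print a tiny nonzero rounding error; allow-reassoc
  // is the permission to ignore it.
  return 0;
}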
if (isa<Constant>(Op0)) @@ -5906,14 +6014,14 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, break; } - case Intrinsic::experimental_vector_extract: { + case Intrinsic::vector_extract: { Type *ReturnType = F->getReturnType(); // (extract_vector (insert_vector _, X, 0), 0) -> X unsigned IdxN = cast<ConstantInt>(Op1)->getZExtValue(); Value *X = nullptr; - if (match(Op0, m_Intrinsic<Intrinsic::experimental_vector_insert>( - m_Value(), m_Value(X), m_Zero())) && + if (match(Op0, m_Intrinsic<Intrinsic::vector_insert>(m_Value(), m_Value(X), + m_Zero())) && IdxN == 0 && X->getType() == ReturnType) return X; @@ -6054,7 +6162,7 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { return nullptr; } - case Intrinsic::experimental_vector_insert: { + case Intrinsic::vector_insert: { Value *Vec = Call->getArgOperand(0); Value *SubVec = Call->getArgOperand(1); Value *Idx = Call->getArgOperand(2); @@ -6064,8 +6172,8 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { // where: Y is X, or Y is undef unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue(); Value *X = nullptr; - if (match(SubVec, m_Intrinsic<Intrinsic::experimental_vector_extract>( - m_Value(X), m_Zero())) && + if (match(SubVec, + m_Intrinsic<Intrinsic::vector_extract>(m_Value(X), m_Zero())) && (Q.isUndefValue(Vec) || Vec == X) && IdxN == 0 && X->getType() == ReturnType) return X; @@ -6074,43 +6182,38 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { } case Intrinsic::experimental_constrained_fadd: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return SimplifyFAddInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + return simplifyFAddInst(FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), Q, FPI->getExceptionBehavior().getValue(), FPI->getRoundingMode().getValue()); - break; } case Intrinsic::experimental_constrained_fsub: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return SimplifyFSubInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + return simplifyFSubInst(FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), Q, FPI->getExceptionBehavior().getValue(), FPI->getRoundingMode().getValue()); - break; } case Intrinsic::experimental_constrained_fmul: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return SimplifyFMulInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + return simplifyFMulInst(FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), Q, FPI->getExceptionBehavior().getValue(), FPI->getRoundingMode().getValue()); - break; } case Intrinsic::experimental_constrained_fdiv: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return SimplifyFDivInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + return simplifyFDivInst(FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), Q, FPI->getExceptionBehavior().getValue(), FPI->getRoundingMode().getValue()); - break; } case Intrinsic::experimental_constrained_frem: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return SimplifyFRemInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + return simplifyFRemInst(FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), Q, FPI->getExceptionBehavior().getValue(), FPI->getRoundingMode().getValue()); - break; } default: return nullptr; @@ -6138,7 +6241,7 @@ static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) { return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI); } -Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) 
{ +Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) { // musttail calls can only be simplified if they are also DCEd. // As we can't guarantee this here, don't simplify them. if (Call->isMustTailCall()) @@ -6161,8 +6264,17 @@ Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) { return nullptr; } +Value *llvm::simplifyConstrainedFPCall(CallBase *Call, const SimplifyQuery &Q) { + assert(isa<ConstrainedFPIntrinsic>(Call)); + if (Value *V = tryConstantFoldCall(Call, Q)) + return V; + if (Value *Ret = simplifyIntrinsic(Call, Q)) + return Ret; + return nullptr; +} + /// Given operands for a Freeze, see if we can fold the result. -static Value *SimplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) { +static Value *simplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) { // Use a utility function defined in ValueTracking. if (llvm::isGuaranteedNotToBeUndefOrPoison(Op0, Q.AC, Q.CxtI, Q.DT)) return Op0; @@ -6170,11 +6282,11 @@ static Value *SimplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) { return nullptr; } -Value *llvm::SimplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) { - return ::SimplifyFreezeInst(Op0, Q); +Value *llvm::simplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) { + return ::simplifyFreezeInst(Op0, Q); } -static Value *SimplifyLoadInst(LoadInst *LI, Value *PtrOp, +static Value *simplifyLoadInst(LoadInst *LI, Value *PtrOp, const SimplifyQuery &Q) { if (LI->isVolatile()) return nullptr; @@ -6218,134 +6330,134 @@ static Value *simplifyInstructionWithOperands(Instruction *I, } break; case Instruction::FNeg: - Result = SimplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q); + Result = simplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q); break; case Instruction::FAdd: - Result = SimplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + Result = simplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Add: - Result = SimplifyAddInst( + Result = simplifyAddInst( NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); break; case Instruction::FSub: - Result = SimplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + Result = simplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Sub: - Result = SimplifySubInst( + Result = simplifySubInst( NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); break; case Instruction::FMul: - Result = SimplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + Result = simplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Mul: - Result = SimplifyMulInst(NewOps[0], NewOps[1], Q); + Result = simplifyMulInst(NewOps[0], NewOps[1], Q); break; case Instruction::SDiv: - Result = SimplifySDivInst(NewOps[0], NewOps[1], Q); + Result = simplifySDivInst(NewOps[0], NewOps[1], Q); break; case Instruction::UDiv: - Result = SimplifyUDivInst(NewOps[0], NewOps[1], Q); + Result = simplifyUDivInst(NewOps[0], NewOps[1], Q); break; case Instruction::FDiv: - Result = SimplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + Result = simplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::SRem: - Result = SimplifySRemInst(NewOps[0], NewOps[1], Q); + Result = simplifySRemInst(NewOps[0], NewOps[1], Q); break; case Instruction::URem: - Result = SimplifyURemInst(NewOps[0], NewOps[1], Q); + Result = simplifyURemInst(NewOps[0], 
NewOps[1], Q); break; case Instruction::FRem: - Result = SimplifyFRemInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + Result = simplifyFRemInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Shl: - Result = SimplifyShlInst( + Result = simplifyShlInst( NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); break; case Instruction::LShr: - Result = SimplifyLShrInst(NewOps[0], NewOps[1], + Result = simplifyLShrInst(NewOps[0], NewOps[1], Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); break; case Instruction::AShr: - Result = SimplifyAShrInst(NewOps[0], NewOps[1], + Result = simplifyAShrInst(NewOps[0], NewOps[1], Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); break; case Instruction::And: - Result = SimplifyAndInst(NewOps[0], NewOps[1], Q); + Result = simplifyAndInst(NewOps[0], NewOps[1], Q); break; case Instruction::Or: - Result = SimplifyOrInst(NewOps[0], NewOps[1], Q); + Result = simplifyOrInst(NewOps[0], NewOps[1], Q); break; case Instruction::Xor: - Result = SimplifyXorInst(NewOps[0], NewOps[1], Q); + Result = simplifyXorInst(NewOps[0], NewOps[1], Q); break; case Instruction::ICmp: - Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), NewOps[0], + Result = simplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), NewOps[0], NewOps[1], Q); break; case Instruction::FCmp: - Result = SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), NewOps[0], + Result = simplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Select: - Result = SimplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q); + Result = simplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q); break; case Instruction::GetElementPtr: { auto *GEPI = cast<GetElementPtrInst>(I); Result = - SimplifyGEPInst(GEPI->getSourceElementType(), NewOps[0], + simplifyGEPInst(GEPI->getSourceElementType(), NewOps[0], makeArrayRef(NewOps).slice(1), GEPI->isInBounds(), Q); break; } case Instruction::InsertValue: { InsertValueInst *IV = cast<InsertValueInst>(I); - Result = SimplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q); + Result = simplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q); break; } case Instruction::InsertElement: { - Result = SimplifyInsertElementInst(NewOps[0], NewOps[1], NewOps[2], Q); + Result = simplifyInsertElementInst(NewOps[0], NewOps[1], NewOps[2], Q); break; } case Instruction::ExtractValue: { auto *EVI = cast<ExtractValueInst>(I); - Result = SimplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q); + Result = simplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q); break; } case Instruction::ExtractElement: { - Result = SimplifyExtractElementInst(NewOps[0], NewOps[1], Q); + Result = simplifyExtractElementInst(NewOps[0], NewOps[1], Q); break; } case Instruction::ShuffleVector: { auto *SVI = cast<ShuffleVectorInst>(I); - Result = SimplifyShuffleVectorInst( + Result = simplifyShuffleVectorInst( NewOps[0], NewOps[1], SVI->getShuffleMask(), SVI->getType(), Q); break; } case Instruction::PHI: - Result = SimplifyPHINode(cast<PHINode>(I), NewOps, Q); + Result = simplifyPHINode(cast<PHINode>(I), NewOps, Q); break; case Instruction::Call: { // TODO: Use NewOps - Result = SimplifyCall(cast<CallInst>(I), Q); + Result = simplifyCall(cast<CallInst>(I), Q); break; } case Instruction::Freeze: - Result = llvm::SimplifyFreezeInst(NewOps[0], Q); + Result = llvm::simplifyFreezeInst(NewOps[0], Q); break; #define HANDLE_CAST_INST(num, opc, clas) case 
Instruction::opc: #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST - Result = SimplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q); + Result = simplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q); break; case Instruction::Alloca: // No simplifications for Alloca and it can't be constant folded. Result = nullptr; break; case Instruction::Load: - Result = SimplifyLoadInst(cast<LoadInst>(I), NewOps[0], Q); + Result = simplifyLoadInst(cast<LoadInst>(I), NewOps[0], Q); break; } @@ -6355,7 +6467,7 @@ static Value *simplifyInstructionWithOperands(Instruction *I, return Result == I ? UndefValue::get(I->getType()) : Result; } -Value *llvm::SimplifyInstructionWithOperands(Instruction *I, +Value *llvm::simplifyInstructionWithOperands(Instruction *I, ArrayRef<Value *> NewOps, const SimplifyQuery &SQ, OptimizationRemarkEmitter *ORE) { @@ -6364,7 +6476,7 @@ Value *llvm::SimplifyInstructionWithOperands(Instruction *I, return ::simplifyInstructionWithOperands(I, NewOps, SQ, ORE); } -Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, +Value *llvm::simplifyInstruction(Instruction *I, const SimplifyQuery &SQ, OptimizationRemarkEmitter *ORE) { SmallVector<Value *, 8> Ops(I->operands()); return ::simplifyInstructionWithOperands(I, Ops, SQ, ORE); @@ -6415,7 +6527,7 @@ static bool replaceAndRecursivelySimplifyImpl( I = Worklist[Idx]; // See if this instruction simplifies. - SimpleV = SimplifyInstruction(I, {DL, TLI, DT, AC}); + SimpleV = simplifyInstruction(I, {DL, TLI, DT, AC}); if (!SimpleV) { if (UnsimplifiedUsers) UnsimplifiedUsers->insert(I); @@ -6478,6 +6590,6 @@ const SimplifyQuery getBestSimplifyQuery(AnalysisManager<T, TArgs...> &AM, } template const SimplifyQuery getBestSimplifyQuery(AnalysisManager<Function> &, Function &); -} +} // namespace llvm void InstSimplifyFolder::anchor() {} diff --git a/llvm/lib/Analysis/Interval.cpp b/llvm/lib/Analysis/Interval.cpp index e228ec4f2126..f7fffcb3d5e6 100644 --- a/llvm/lib/Analysis/Interval.cpp +++ b/llvm/lib/Analysis/Interval.cpp @@ -13,7 +13,6 @@ #include "llvm/Analysis/Interval.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CFG.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; diff --git a/llvm/lib/Analysis/LazyCallGraph.cpp b/llvm/lib/Analysis/LazyCallGraph.cpp index e8e9593d7030..20a905e04a9d 100644 --- a/llvm/lib/Analysis/LazyCallGraph.cpp +++ b/llvm/lib/Analysis/LazyCallGraph.cpp @@ -9,14 +9,13 @@ #include "llvm/Analysis/LazyCallGraph.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/VectorUtils.h" #include "llvm/Config/llvm-config.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstIterator.h" @@ -30,12 +29,15 @@ #include "llvm/Support/raw_ostream.h" #include <algorithm> #include <cassert> -#include <cstddef> #include <iterator> #include <string> #include <tuple> #include <utility> +#ifdef EXPENSIVE_CHECKS +#include "llvm/ADT/ScopeExit.h" +#endif + using namespace llvm; #define DEBUG_TYPE "lcg" diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp index e311b40ab25c..8a8e9e923b7c 100644 --- a/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/llvm/lib/Analysis/LazyValueInfo.cpp @@ -38,7 +38,6 @@ #include "llvm/Support/FormattedStream.h" 
#include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" -#include <map> using namespace llvm; using namespace PatternMatch; @@ -919,7 +918,7 @@ Optional<ValueLatticeElement> LazyValueInfoImpl::solveBlockValueCast( // transfer rule on the full set since we may be able to locally infer // interesting facts. Optional<ConstantRange> LHSRes = getRangeFor(CI->getOperand(0), CI, BB); - if (!LHSRes.hasValue()) + if (!LHSRes) // More work to do before applying this transfer rule. return None; const ConstantRange &LHSRange = LHSRes.getValue(); @@ -943,7 +942,7 @@ Optional<ValueLatticeElement> LazyValueInfoImpl::solveBlockValueBinaryOpImpl( // @foo()), 32" Optional<ConstantRange> LHSRes = getRangeFor(I->getOperand(0), I, BB); Optional<ConstantRange> RHSRes = getRangeFor(I->getOperand(1), I, BB); - if (!LHSRes.hasValue() || !RHSRes.hasValue()) + if (!LHSRes || !RHSRes) // More work to do before applying this transfer rule. return None; @@ -956,13 +955,6 @@ Optional<ValueLatticeElement> LazyValueInfoImpl::solveBlockValueBinaryOp( BinaryOperator *BO, BasicBlock *BB) { assert(BO->getOperand(0)->getType()->isSized() && "all operands to binary operators are sized"); - if (BO->getOpcode() == Instruction::Xor) { - // Xor is the only operation not supported by ConstantRange::binaryOp(). - LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() - << "' - overdefined (unknown binary operator).\n"); - return ValueLatticeElement::getOverdefined(); - } - if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(BO)) { unsigned NoWrapKind = 0; if (OBO->hasNoUnsignedWrap()) @@ -1020,7 +1012,7 @@ Optional<ValueLatticeElement> LazyValueInfoImpl::solveBlockValueExtractValue( // Handle extractvalue of insertvalue to allow further simplification // based on replaced with.overflow intrinsics. - if (Value *V = SimplifyExtractValueInst( + if (Value *V = simplifyExtractValueInst( EVI->getAggregateOperand(), EVI->getIndices(), EVI->getModule()->getDataLayout())) return getBlockValue(V, BB, EVI); @@ -1141,7 +1133,7 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI, ConstantRange CR = ConstantRange::makeExactICmpRegion(EdgePred, *C); if (!CR.isEmptySet()) return ValueLatticeElement::getRange(ConstantRange::getNonEmpty( - CR.getUnsignedMin().zextOrSelf(BitWidth), APInt(BitWidth, 0))); + CR.getUnsignedMin().zext(BitWidth), APInt(BitWidth, 0))); } return ValueLatticeElement::getOverdefined(); @@ -1278,7 +1270,7 @@ static ValueLatticeElement constantFoldUser(User *Usr, Value *Op, if (auto *CI = dyn_cast<CastInst>(Usr)) { assert(CI->getOperand(0) == Op && "Operand 0 isn't Op"); if (auto *C = dyn_cast_or_null<ConstantInt>( - SimplifyCastInst(CI->getOpcode(), OpConst, + simplifyCastInst(CI->getOpcode(), OpConst, CI->getDestTy(), DL))) { return ValueLatticeElement::getRange(ConstantRange(C->getValue())); } @@ -1290,7 +1282,7 @@ static ValueLatticeElement constantFoldUser(User *Usr, Value *Op, Value *LHS = Op0Match ? OpConst : BO->getOperand(0); Value *RHS = Op1Match ? 
OpConst : BO->getOperand(1); if (auto *C = dyn_cast_or_null<ConstantInt>( - SimplifyBinOp(BO->getOpcode(), LHS, RHS, DL))) { + simplifyBinOp(BO->getOpcode(), LHS, RHS, DL))) { return ValueLatticeElement::getRange(ConstantRange(C->getValue())); } } else if (isa<FreezeInst>(Usr)) { @@ -1361,7 +1353,7 @@ static Optional<ValueLatticeElement> getEdgeValueLocal(Value *Val, ValueLatticeElement OpLatticeVal = getValueFromCondition(Op, Condition, isTrueDest); if (Optional<APInt> OpConst = OpLatticeVal.asConstantInteger()) { - Result = constantFoldUser(Usr, Op, OpConst.getValue(), DL); + Result = constantFoldUser(Usr, Op, *OpConst, DL); break; } } @@ -1432,8 +1424,9 @@ Optional<ValueLatticeElement> LazyValueInfoImpl::getEdgeValue( if (Constant *VC = dyn_cast<Constant>(Val)) return ValueLatticeElement::get(VC); - ValueLatticeElement LocalResult = getEdgeValueLocal(Val, BBFrom, BBTo) - .getValueOr(ValueLatticeElement::getOverdefined()); + ValueLatticeElement LocalResult = + getEdgeValueLocal(Val, BBFrom, BBTo) + .value_or(ValueLatticeElement::getOverdefined()); if (hasSingleValue(LocalResult)) // Can't get any more precise here return LocalResult; @@ -1886,6 +1879,11 @@ void LazyValueInfo::eraseBlock(BasicBlock *BB) { } } +void LazyValueInfo::clear(const Module *M) { + if (PImpl) { + getImpl(PImpl, AC, M).clear(); + } +} void LazyValueInfo::printLVI(Function &F, DominatorTree &DTree, raw_ostream &OS) { if (PImpl) { diff --git a/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp b/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp index 031bf3bae51d..491d44335f22 100644 --- a/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp +++ b/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp @@ -68,6 +68,7 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/Passes.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" diff --git a/llvm/lib/Analysis/Lint.cpp b/llvm/lib/Analysis/Lint.cpp index f9a7a5bdf434..9cfb91a22b7d 100644 --- a/llvm/lib/Analysis/Lint.cpp +++ b/llvm/lib/Analysis/Lint.cpp @@ -44,7 +44,6 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryLocation.h" -#include "llvm/Analysis/Passes.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" @@ -69,9 +68,7 @@ #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include <cassert> #include <cstdint> @@ -169,8 +166,8 @@ public: }; } // end anonymous namespace -// Assert - We know that cond should be true, if not print an error message. -#define Assert(C, ...) \ +// Check - We know that cond should be true, if not print an error message. +#define Check(C, ...) \ do { \ if (!(C)) { \ CheckFailed(__VA_ARGS__); \ @@ -181,8 +178,8 @@ public: void Lint::visitFunction(Function &F) { // This isn't undefined behavior, it's just a little unusual, and it's a // fairly common mistake to neglect to name a function. - Assert(F.hasName() || F.hasLocalLinkage(), - "Unusual: Unnamed function with non-local linkage", &F); + Check(F.hasName() || F.hasLocalLinkage(), + "Unusual: Unnamed function with non-local linkage", &F); // TODO: Check for irreducible control flow. 
} @@ -195,23 +192,23 @@ void Lint::visitCallBase(CallBase &I) { if (Function *F = dyn_cast<Function>(findValue(Callee, /*OffsetOk=*/false))) { - Assert(I.getCallingConv() == F->getCallingConv(), - "Undefined behavior: Caller and callee calling convention differ", - &I); + Check(I.getCallingConv() == F->getCallingConv(), + "Undefined behavior: Caller and callee calling convention differ", + &I); FunctionType *FT = F->getFunctionType(); unsigned NumActualArgs = I.arg_size(); - Assert(FT->isVarArg() ? FT->getNumParams() <= NumActualArgs - : FT->getNumParams() == NumActualArgs, - "Undefined behavior: Call argument count mismatches callee " - "argument count", - &I); + Check(FT->isVarArg() ? FT->getNumParams() <= NumActualArgs + : FT->getNumParams() == NumActualArgs, + "Undefined behavior: Call argument count mismatches callee " + "argument count", + &I); - Assert(FT->getReturnType() == I.getType(), - "Undefined behavior: Call return type mismatches " - "callee return type", - &I); + Check(FT->getReturnType() == I.getType(), + "Undefined behavior: Call return type mismatches " + "callee return type", + &I); // Check argument types (in case the callee was casted) and attributes. // TODO: Verify that caller and callee attributes are compatible. @@ -221,10 +218,10 @@ void Lint::visitCallBase(CallBase &I) { Value *Actual = *AI; if (PI != PE) { Argument *Formal = &*PI++; - Assert(Formal->getType() == Actual->getType(), - "Undefined behavior: Call argument type mismatches " - "callee parameter type", - &I); + Check(Formal->getType() == Actual->getType(), + "Undefined behavior: Call argument type mismatches " + "callee parameter type", + &I); // Check that noalias arguments don't alias other arguments. This is // not fully precise because we don't know the sizes of the dereferenced @@ -242,9 +239,9 @@ void Lint::visitCallBase(CallBase &I) { continue; if (AI != BI && (*BI)->getType()->isPointerTy()) { AliasResult Result = AA->alias(*AI, *BI); - Assert(Result != AliasResult::MustAlias && - Result != AliasResult::PartialAlias, - "Unusual: noalias argument aliases another argument", &I); + Check(Result != AliasResult::MustAlias && + Result != AliasResult::PartialAlias, + "Unusual: noalias argument aliases another argument", &I); } } } @@ -271,10 +268,10 @@ void Lint::visitCallBase(CallBase &I) { if (PAL.hasParamAttr(ArgNo++, Attribute::ByVal)) continue; Value *Obj = findValue(Arg, /*OffsetOk=*/true); - Assert(!isa<AllocaInst>(Obj), - "Undefined behavior: Call with \"tail\" keyword references " - "alloca", - &I); + Check(!isa<AllocaInst>(Obj), + "Undefined behavior: Call with \"tail\" keyword references " + "alloca", + &I); } } } @@ -302,9 +299,9 @@ void Lint::visitCallBase(CallBase &I) { /*OffsetOk=*/false))) if (Len->getValue().isIntN(32)) Size = LocationSize::precise(Len->getValue().getZExtValue()); - Assert(AA->alias(MCI->getSource(), Size, MCI->getDest(), Size) != - AliasResult::MustAlias, - "Undefined behavior: memcpy source and destination overlap", &I); + Check(AA->alias(MCI->getSource(), Size, MCI->getDest(), Size) != + AliasResult::MustAlias, + "Undefined behavior: memcpy source and destination overlap", &I); break; } case Intrinsic::memcpy_inline: { @@ -319,9 +316,9 @@ void Lint::visitCallBase(CallBase &I) { // isn't expressive enough for what we really want to do. Known partial // overlap is not distinguished from the case where nothing is known. 
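Every Assert in Lint.cpp becomes Check in this commit; the macro (its body is abridged in the hunk above) reports via CheckFailed and then bails out of the current visitor. A minimal standalone reconstruction of the pattern, assuming the elided tail ends with a return as Lint's macro does:

#include <cstdio>

static void CheckFailed(const char *Msg) { std::fprintf(stderr, "lint: %s\n", Msg); }

#define Check(C, Msg)                                                          \
  do {                                                                         \
    if (!(C)) {                                                                \
      CheckFailed(Msg);                                                        \
      return;                                                                  \
    }                                                                          \
  } while (false)

static void visitShift(unsigned ShiftAmount, unsigned BitWidth) {
  Check(ShiftAmount < BitWidth, "Undefined result: Shift count out of range");
  std::printf("shift by %u is in range\n", ShiftAmount); // only if the check held
}

int main() {
  visitShift(2, 32);  // passes
  visitShift(40, 32); // reports and returns early
  return 0;
}

Bailing out after the first failure keeps later checks in the same visitor from cascading on already-bogus state.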
const LocationSize LS = LocationSize::precise(Size); - Assert(AA->alias(MCII->getSource(), LS, MCII->getDest(), LS) != - AliasResult::MustAlias, - "Undefined behavior: memcpy source and destination overlap", &I); + Check(AA->alias(MCII->getSource(), LS, MCII->getDest(), LS) != + AliasResult::MustAlias, + "Undefined behavior: memcpy source and destination overlap", &I); break; } case Intrinsic::memmove: { @@ -338,11 +335,17 @@ void Lint::visitCallBase(CallBase &I) { MSI->getDestAlign(), nullptr, MemRef::Write); break; } + case Intrinsic::memset_inline: { + MemSetInlineInst *MSII = cast<MemSetInlineInst>(&I); + visitMemoryReference(I, MemoryLocation::getForDest(MSII), + MSII->getDestAlign(), nullptr, MemRef::Write); + break; + } case Intrinsic::vastart: - Assert(I.getParent()->getParent()->isVarArg(), - "Undefined behavior: va_start called in a non-varargs function", - &I); + Check(I.getParent()->getParent()->isVarArg(), + "Undefined behavior: va_start called in a non-varargs function", + &I); visitMemoryReference(I, MemoryLocation::getForArgument(&I, 0, TLI), None, nullptr, MemRef::Read | MemRef::Write); @@ -367,20 +370,22 @@ void Lint::visitCallBase(CallBase &I) { break; case Intrinsic::get_active_lane_mask: if (auto *TripCount = dyn_cast<ConstantInt>(I.getArgOperand(1))) - Assert(!TripCount->isZero(), "get_active_lane_mask: operand #2 " - "must be greater than 0", &I); + Check(!TripCount->isZero(), + "get_active_lane_mask: operand #2 " + "must be greater than 0", + &I); break; } } void Lint::visitReturnInst(ReturnInst &I) { Function *F = I.getParent()->getParent(); - Assert(!F->doesNotReturn(), - "Unusual: Return statement in function with noreturn attribute", &I); + Check(!F->doesNotReturn(), + "Unusual: Return statement in function with noreturn attribute", &I); if (Value *V = I.getReturnValue()) { Value *Obj = findValue(V, /*OffsetOk=*/true); - Assert(!isa<AllocaInst>(Obj), "Unusual: Returning alloca value", &I); + Check(!isa<AllocaInst>(Obj), "Unusual: Returning alloca value", &I); } } @@ -395,39 +400,39 @@ void Lint::visitMemoryReference(Instruction &I, const MemoryLocation &Loc, Value *Ptr = const_cast<Value *>(Loc.Ptr); Value *UnderlyingObject = findValue(Ptr, /*OffsetOk=*/true); - Assert(!isa<ConstantPointerNull>(UnderlyingObject), - "Undefined behavior: Null pointer dereference", &I); - Assert(!isa<UndefValue>(UnderlyingObject), - "Undefined behavior: Undef pointer dereference", &I); - Assert(!isa<ConstantInt>(UnderlyingObject) || - !cast<ConstantInt>(UnderlyingObject)->isMinusOne(), - "Unusual: All-ones pointer dereference", &I); - Assert(!isa<ConstantInt>(UnderlyingObject) || - !cast<ConstantInt>(UnderlyingObject)->isOne(), - "Unusual: Address one pointer dereference", &I); + Check(!isa<ConstantPointerNull>(UnderlyingObject), + "Undefined behavior: Null pointer dereference", &I); + Check(!isa<UndefValue>(UnderlyingObject), + "Undefined behavior: Undef pointer dereference", &I); + Check(!isa<ConstantInt>(UnderlyingObject) || + !cast<ConstantInt>(UnderlyingObject)->isMinusOne(), + "Unusual: All-ones pointer dereference", &I); + Check(!isa<ConstantInt>(UnderlyingObject) || + !cast<ConstantInt>(UnderlyingObject)->isOne(), + "Unusual: Address one pointer dereference", &I); if (Flags & MemRef::Write) { if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(UnderlyingObject)) - Assert(!GV->isConstant(), "Undefined behavior: Write to read-only memory", - &I); - Assert(!isa<Function>(UnderlyingObject) && - !isa<BlockAddress>(UnderlyingObject), - "Undefined behavior: Write to text 
section", &I); + Check(!GV->isConstant(), "Undefined behavior: Write to read-only memory", + &I); + Check(!isa<Function>(UnderlyingObject) && + !isa<BlockAddress>(UnderlyingObject), + "Undefined behavior: Write to text section", &I); } if (Flags & MemRef::Read) { - Assert(!isa<Function>(UnderlyingObject), "Unusual: Load from function body", - &I); - Assert(!isa<BlockAddress>(UnderlyingObject), - "Undefined behavior: Load from block address", &I); + Check(!isa<Function>(UnderlyingObject), "Unusual: Load from function body", + &I); + Check(!isa<BlockAddress>(UnderlyingObject), + "Undefined behavior: Load from block address", &I); } if (Flags & MemRef::Callee) { - Assert(!isa<BlockAddress>(UnderlyingObject), - "Undefined behavior: Call to block address", &I); + Check(!isa<BlockAddress>(UnderlyingObject), + "Undefined behavior: Call to block address", &I); } if (Flags & MemRef::Branchee) { - Assert(!isa<Constant>(UnderlyingObject) || - isa<BlockAddress>(UnderlyingObject), - "Undefined behavior: Branch to non-blockaddress", &I); + Check(!isa<Constant>(UnderlyingObject) || + isa<BlockAddress>(UnderlyingObject), + "Undefined behavior: Branch to non-blockaddress", &I); } // Check for buffer overflows and misalignment. @@ -461,17 +466,17 @@ void Lint::visitMemoryReference(Instruction &I, const MemoryLocation &Loc, // Accesses from before the start or after the end of the object are not // defined. - Assert(!Loc.Size.hasValue() || BaseSize == MemoryLocation::UnknownSize || - (Offset >= 0 && Offset + Loc.Size.getValue() <= BaseSize), - "Undefined behavior: Buffer overflow", &I); + Check(!Loc.Size.hasValue() || BaseSize == MemoryLocation::UnknownSize || + (Offset >= 0 && Offset + Loc.Size.getValue() <= BaseSize), + "Undefined behavior: Buffer overflow", &I); // Accesses that say that the memory is more aligned than it is are not // defined. 
if (!Align && Ty && Ty->isSized()) Align = DL->getABITypeAlign(Ty); if (BaseAlign && Align) - Assert(*Align <= commonAlignment(*BaseAlign, Offset), - "Undefined behavior: Memory reference address is misaligned", &I); + Check(*Align <= commonAlignment(*BaseAlign, Offset), + "Undefined behavior: Memory reference address is misaligned", &I); } } @@ -486,34 +491,34 @@ void Lint::visitStoreInst(StoreInst &I) { } void Lint::visitXor(BinaryOperator &I) { - Assert(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)), - "Undefined result: xor(undef, undef)", &I); + Check(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)), + "Undefined result: xor(undef, undef)", &I); } void Lint::visitSub(BinaryOperator &I) { - Assert(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)), - "Undefined result: sub(undef, undef)", &I); + Check(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)), + "Undefined result: sub(undef, undef)", &I); } void Lint::visitLShr(BinaryOperator &I) { if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(1), /*OffsetOk=*/false))) - Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), - "Undefined result: Shift count out of range", &I); + Check(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), + "Undefined result: Shift count out of range", &I); } void Lint::visitAShr(BinaryOperator &I) { if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(1), /*OffsetOk=*/false))) - Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), - "Undefined result: Shift count out of range", &I); + Check(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), + "Undefined result: Shift count out of range", &I); } void Lint::visitShl(BinaryOperator &I) { if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(1), /*OffsetOk=*/false))) - Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), - "Undefined result: Shift count out of range", &I); + Check(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), + "Undefined result: Shift count out of range", &I); } static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, @@ -554,30 +559,30 @@ static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, } void Lint::visitSDiv(BinaryOperator &I) { - Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), - "Undefined behavior: Division by zero", &I); + Check(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), + "Undefined behavior: Division by zero", &I); } void Lint::visitUDiv(BinaryOperator &I) { - Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), - "Undefined behavior: Division by zero", &I); + Check(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), + "Undefined behavior: Division by zero", &I); } void Lint::visitSRem(BinaryOperator &I) { - Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), - "Undefined behavior: Division by zero", &I); + Check(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), + "Undefined behavior: Division by zero", &I); } void Lint::visitURem(BinaryOperator &I) { - Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), - "Undefined behavior: Division by zero", &I); + Check(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), + "Undefined behavior: Division by zero", &I); } void Lint::visitAllocaInst(AllocaInst &I) { if 
(isa<ConstantInt>(I.getArraySize())) // This isn't undefined behavior, it's just an obvious pessimization. - Assert(&I.getParent()->getParent()->getEntryBlock() == I.getParent(), - "Pessimization: Static alloca outside of entry block", &I); + Check(&I.getParent()->getParent()->getEntryBlock() == I.getParent(), + "Pessimization: Static alloca outside of entry block", &I); // TODO: Check for an unusual size (MSB set?) } @@ -591,14 +596,14 @@ void Lint::visitIndirectBrInst(IndirectBrInst &I) { visitMemoryReference(I, MemoryLocation::getAfter(I.getAddress()), None, nullptr, MemRef::Branchee); - Assert(I.getNumDestinations() != 0, - "Undefined behavior: indirectbr with no destinations", &I); + Check(I.getNumDestinations() != 0, + "Undefined behavior: indirectbr with no destinations", &I); } void Lint::visitExtractElementInst(ExtractElementInst &I) { if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getIndexOperand(), /*OffsetOk=*/false))) - Assert( + Check( CI->getValue().ult( cast<FixedVectorType>(I.getVectorOperandType())->getNumElements()), "Undefined result: extractelement index out of range", &I); @@ -607,18 +612,18 @@ void Lint::visitExtractElementInst(ExtractElementInst &I) { void Lint::visitInsertElementInst(InsertElementInst &I) { if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(2), /*OffsetOk=*/false))) - Assert(CI->getValue().ult( - cast<FixedVectorType>(I.getType())->getNumElements()), - "Undefined result: insertelement index out of range", &I); + Check(CI->getValue().ult( + cast<FixedVectorType>(I.getType())->getNumElements()), + "Undefined result: insertelement index out of range", &I); } void Lint::visitUnreachableInst(UnreachableInst &I) { // This isn't undefined behavior, it's merely suspicious. - Assert(&I == &I.getParent()->front() || - std::prev(I.getIterator())->mayHaveSideEffects(), - "Unusual: unreachable immediately preceded by instruction without " - "side effects", - &I); + Check(&I == &I.getParent()->front() || + std::prev(I.getIterator())->mayHaveSideEffects(), + "Unusual: unreachable immediately preceded by instruction without " + "side effects", + &I); } /// findValue - Look through bitcasts and simple memory reference patterns @@ -681,17 +686,12 @@ Value *Lint::findValueImpl(Value *V, bool OffsetOk, CE->getOperand(0)->getType(), CE->getType(), *DL)) return findValueImpl(CE->getOperand(0), OffsetOk, Visited); - } else if (CE->getOpcode() == Instruction::ExtractValue) { - ArrayRef<unsigned> Indices = CE->getIndices(); - if (Value *W = FindInsertedValue(CE->getOperand(0), Indices)) - if (W != V) - return findValueImpl(W, OffsetOk, Visited); } } // As a last resort, try SimplifyInstruction or constant folding. 
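findValueImpl, trimmed above, chases bitcasts, loads, and similar forwarding steps back to an underlying value; the Visited set it threads through is what keeps self-referential chains (possible in unreachable code) from recursing forever. A toy standalone sketch of that walk, using an integer graph in place of llvm::Value:

#include <cassert>
#include <unordered_map>
#include <unordered_set>

static int findValue(int V, const std::unordered_map<int, int> &Fwd,
                     std::unordered_set<int> &Visited) {
  if (!Visited.insert(V).second)
    return V;                                // seen before: stop at the cycle
  auto It = Fwd.find(V);
  return It == Fwd.end() ? V : findValue(It->second, Fwd, Visited);
}

int main() {
  std::unordered_map<int, int> Fwd{{1, 2}, {2, 3}, {3, 1}}; // 1 -> 2 -> 3 -> 1
  std::unordered_set<int> Visited;
  assert(findValue(1, Fwd, Visited) == 1); // terminates despite the cycle
  return 0;
}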
if (Instruction *Inst = dyn_cast<Instruction>(V)) { - if (Value *W = SimplifyInstruction(Inst, {*DL, TLI, DT, AC})) + if (Value *W = simplifyInstruction(Inst, {*DL, TLI, DT, AC})) return findValueImpl(W, OffsetOk, Visited); } else if (auto *C = dyn_cast<Constant>(V)) { Value *W = ConstantFoldConstant(C, *DL, TLI); diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp index cd0d4d6b9ca8..bc1d82cf1480 100644 --- a/llvm/lib/Analysis/Loads.cpp +++ b/llvm/lib/Analysis/Loads.cpp @@ -13,19 +13,14 @@ #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumeBundleQueries.h" -#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/GlobalAlias.h" -#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" @@ -509,8 +504,8 @@ static Value *getAvailableLoadStore(Instruction *Inst, const Value *Ptr, if (CastInst::isBitOrNoopPointerCastable(Val->getType(), AccessTy, DL)) return Val; - TypeSize StoreSize = DL.getTypeStoreSize(Val->getType()); - TypeSize LoadSize = DL.getTypeStoreSize(AccessTy); + TypeSize StoreSize = DL.getTypeSizeInBits(Val->getType()); + TypeSize LoadSize = DL.getTypeSizeInBits(AccessTy); if (TypeSize::isKnownLE(LoadSize, StoreSize)) if (auto *C = dyn_cast<Constant>(Val)) return ConstantFoldLoadFromConst(C, AccessTy, DL); diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 2ab78d2b7ee2..79161db9b5e4 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -47,6 +47,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" @@ -60,12 +61,12 @@ #include <algorithm> #include <cassert> #include <cstdint> -#include <cstdlib> #include <iterator> #include <utility> #include <vector> using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "loop-accesses" @@ -172,7 +173,8 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup( : High(RtCheck.Pointers[Index].End), Low(RtCheck.Pointers[Index].Start), AddressSpace(RtCheck.Pointers[Index] .PointerValue->getType() - ->getPointerAddressSpace()) { + ->getPointerAddressSpace()), + NeedsFreeze(RtCheck.Pointers[Index].NeedsFreeze) { Members.push_back(Index); } @@ -189,21 +191,20 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup( /// /// There is no conflict when the intervals are disjoint: /// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End) -void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, +void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr, + Type *AccessTy, bool WritePtr, unsigned DepSetId, unsigned ASId, - const ValueToValueMap &Strides, - PredicatedScalarEvolution &PSE) { - // Get the stride replaced scev. 
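In the Loads.cpp hunk above, getAvailableLoadStore now compares type sizes in bits rather than store sizes. The distinction matters because store sizes round up to whole bytes, so two types can have equal store sizes while one carries more live bits than the other. Standalone arithmetic illustrating the gap (the i5/i6 pair is an illustrative assumption, not taken from the patch):

#include <cassert>
#include <cstdint>

static uint64_t storeSizeInBytes(uint64_t Bits) { return (Bits + 7) / 8; }

int main() {
  uint64_t StoredBits = 5, LoadedBits = 6; // e.g. an i5 store feeding an i6 load
  assert(storeSizeInBytes(StoredBits) == storeSizeInBytes(LoadedBits)); // both 1 byte
  assert(!(LoadedBits <= StoredBits)); // yet the load wants a bit never stored
  return 0;
}

Comparing in bytes would have accepted this pair; comparing in bits correctly rejects it.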
- const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); + PredicatedScalarEvolution &PSE, + bool NeedsFreeze) { ScalarEvolution *SE = PSE.getSE(); const SCEV *ScStart; const SCEV *ScEnd; - if (SE->isLoopInvariant(Sc, Lp)) { - ScStart = ScEnd = Sc; + if (SE->isLoopInvariant(PtrExpr, Lp)) { + ScStart = ScEnd = PtrExpr; } else { - const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrExpr); assert(AR && "Invalid addrec expression"); const SCEV *Ex = PSE.getBackedgeTakenCount(); @@ -227,15 +228,100 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, // Add the size of the pointed element to ScEnd. auto &DL = Lp->getHeader()->getModule()->getDataLayout(); Type *IdxTy = DL.getIndexType(Ptr->getType()); - const SCEV *EltSizeSCEV = - SE->getStoreSizeOfExpr(IdxTy, Ptr->getType()->getPointerElementType()); + const SCEV *EltSizeSCEV = SE->getStoreSizeOfExpr(IdxTy, AccessTy); ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV); - Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc); + Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, PtrExpr, + NeedsFreeze); } -SmallVector<RuntimePointerCheck, 4> -RuntimePointerChecking::generateChecks() const { +void RuntimePointerChecking::tryToCreateDiffCheck( + const RuntimeCheckingPtrGroup &CGI, const RuntimeCheckingPtrGroup &CGJ) { + if (!CanUseDiffCheck) + return; + + // If either group contains multiple different pointers, bail out. + // TODO: Support multiple pointers by using the minimum or maximum pointer, + // depending on src & sink. + if (CGI.Members.size() != 1 || CGJ.Members.size() != 1) { + CanUseDiffCheck = false; + return; + } + + PointerInfo *Src = &Pointers[CGI.Members[0]]; + PointerInfo *Sink = &Pointers[CGJ.Members[0]]; + + // If either pointer is read and written, multiple checks may be needed. Bail + // out. + if (!DC.getOrderForAccess(Src->PointerValue, !Src->IsWritePtr).empty() || + !DC.getOrderForAccess(Sink->PointerValue, !Sink->IsWritePtr).empty()) { + CanUseDiffCheck = false; + return; + } + + ArrayRef<unsigned> AccSrc = + DC.getOrderForAccess(Src->PointerValue, Src->IsWritePtr); + ArrayRef<unsigned> AccSink = + DC.getOrderForAccess(Sink->PointerValue, Sink->IsWritePtr); + // If either pointer is accessed multiple times, there may not be a clear + // src/sink relation. Bail out for now. + if (AccSrc.size() != 1 || AccSink.size() != 1) { + CanUseDiffCheck = false; + return; + } + // If the sink is accessed before src, swap src/sink. 
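tryToCreateDiffCheck, begun above and finished in the next hunk, trades the usual two-ended range check for a single pointer-difference comparison when the access pattern is simple enough: one pointer per group, one access each, and a constant step equal to the allocation size. A plain-C++ model of the underlying idea only (the exact emitted predicate lives in the vectorizer and involves the vectorization factor; the iteration-count form and the up-front ordering here are illustrative assumptions):

#include <cassert>
#include <cstdint>

static bool noOverlapByDiff(uint64_t SrcStart, uint64_t SinkStart,
                            uint64_t AllocSize, uint64_t Iterations) {
  // The patch swaps src/sink up front so one unsigned comparison suffices;
  // this model just declares the reversed case out of scope.
  if (SinkStart < SrcStart)
    return false;
  return SinkStart - SrcStart >= AllocSize * Iterations;
}

int main() {
  // Two 4-byte-element streams 400 bytes apart over 100 iterations: independent.
  assert(noOverlapByDiff(0x1000, 0x1190, 4, 100));
  // Only 64 bytes apart: the streams collide within 100 iterations.
  assert(!noOverlapByDiff(0x1000, 0x1040, 4, 100));
  return 0;
}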
+ if (AccSink[0] < AccSrc[0]) + std::swap(Src, Sink); + + auto *SrcAR = dyn_cast<SCEVAddRecExpr>(Src->Expr); + auto *SinkAR = dyn_cast<SCEVAddRecExpr>(Sink->Expr); + if (!SrcAR || !SinkAR) { + CanUseDiffCheck = false; + return; + } + + const DataLayout &DL = + SinkAR->getLoop()->getHeader()->getModule()->getDataLayout(); + SmallVector<Instruction *, 4> SrcInsts = + DC.getInstructionsForAccess(Src->PointerValue, Src->IsWritePtr); + SmallVector<Instruction *, 4> SinkInsts = + DC.getInstructionsForAccess(Sink->PointerValue, Sink->IsWritePtr); + Type *SrcTy = getLoadStoreType(SrcInsts[0]); + Type *DstTy = getLoadStoreType(SinkInsts[0]); + if (isa<ScalableVectorType>(SrcTy) || isa<ScalableVectorType>(DstTy)) + return; + unsigned AllocSize = + std::max(DL.getTypeAllocSize(SrcTy), DL.getTypeAllocSize(DstTy)); + IntegerType *IntTy = + IntegerType::get(Src->PointerValue->getContext(), + DL.getPointerSizeInBits(CGI.AddressSpace)); + + // Only matching constant steps matching the AllocSize are supported at the + // moment. This simplifies the difference computation. Can be extended in the + // future. + auto *Step = dyn_cast<SCEVConstant>(SinkAR->getStepRecurrence(*SE)); + if (!Step || Step != SrcAR->getStepRecurrence(*SE) || + Step->getAPInt().abs() != AllocSize) { + CanUseDiffCheck = false; + return; + } + + // When counting down, the dependence distance needs to be swapped. + if (Step->getValue()->isNegative()) + std::swap(SinkAR, SrcAR); + + const SCEV *SinkStartInt = SE->getPtrToIntExpr(SinkAR->getStart(), IntTy); + const SCEV *SrcStartInt = SE->getPtrToIntExpr(SrcAR->getStart(), IntTy); + if (isa<SCEVCouldNotCompute>(SinkStartInt) || + isa<SCEVCouldNotCompute>(SrcStartInt)) { + CanUseDiffCheck = false; + return; + } + DiffChecks.emplace_back(SrcStartInt, SinkStartInt, AllocSize, + Src->NeedsFreeze || Sink->NeedsFreeze); +} + +SmallVector<RuntimePointerCheck, 4> RuntimePointerChecking::generateChecks() { SmallVector<RuntimePointerCheck, 4> Checks; for (unsigned I = 0; I < CheckingGroups.size(); ++I) { @@ -243,8 +329,10 @@ RuntimePointerChecking::generateChecks() const { const RuntimeCheckingPtrGroup &CGI = CheckingGroups[I]; const RuntimeCheckingPtrGroup &CGJ = CheckingGroups[J]; - if (needsChecking(CGI, CGJ)) + if (needsChecking(CGI, CGJ)) { + tryToCreateDiffCheck(CGI, CGJ); Checks.push_back(std::make_pair(&CGI, &CGJ)); + } } } return Checks; @@ -285,11 +373,12 @@ bool RuntimeCheckingPtrGroup::addPointer(unsigned Index, return addPointer( Index, RtCheck.Pointers[Index].Start, RtCheck.Pointers[Index].End, RtCheck.Pointers[Index].PointerValue->getType()->getPointerAddressSpace(), - *RtCheck.SE); + RtCheck.Pointers[Index].NeedsFreeze, *RtCheck.SE); } bool RuntimeCheckingPtrGroup::addPointer(unsigned Index, const SCEV *Start, const SCEV *End, unsigned AS, + bool NeedsFreeze, ScalarEvolution &SE) { assert(AddressSpace == AS && "all pointers in a checking group must be in the same address space"); @@ -314,6 +403,7 @@ bool RuntimeCheckingPtrGroup::addPointer(unsigned Index, const SCEV *Start, High = End; Members.push_back(Index); + this->NeedsFreeze |= NeedsFreeze; return true; } @@ -371,9 +461,11 @@ void RuntimePointerChecking::groupChecks( unsigned TotalComparisons = 0; - DenseMap<Value *, unsigned> PositionMap; - for (unsigned Index = 0; Index < Pointers.size(); ++Index) - PositionMap[Pointers[Index].PointerValue] = Index; + DenseMap<Value *, SmallVector<unsigned>> PositionMap; + for (unsigned Index = 0; Index < Pointers.size(); ++Index) { + auto Iter = 
PositionMap.insert({Pointers[Index].PointerValue, {}}); + Iter.first->second.push_back(Index); + } // We need to keep track of what pointers we've already seen so we // don't process them twice. @@ -404,34 +496,35 @@ void RuntimePointerChecking::groupChecks( auto PointerI = PositionMap.find(MI->getPointer()); assert(PointerI != PositionMap.end() && "pointer in equivalence class not found in PositionMap"); - unsigned Pointer = PointerI->second; - bool Merged = false; - // Mark this pointer as seen. - Seen.insert(Pointer); - - // Go through all the existing sets and see if we can find one - // which can include this pointer. - for (RuntimeCheckingPtrGroup &Group : Groups) { - // Don't perform more than a certain amount of comparisons. - // This should limit the cost of grouping the pointers to something - // reasonable. If we do end up hitting this threshold, the algorithm - // will create separate groups for all remaining pointers. - if (TotalComparisons > MemoryCheckMergeThreshold) - break; - - TotalComparisons++; - - if (Group.addPointer(Pointer, *this)) { - Merged = true; - break; + for (unsigned Pointer : PointerI->second) { + bool Merged = false; + // Mark this pointer as seen. + Seen.insert(Pointer); + + // Go through all the existing sets and see if we can find one + // which can include this pointer. + for (RuntimeCheckingPtrGroup &Group : Groups) { + // Don't perform more than a certain amount of comparisons. + // This should limit the cost of grouping the pointers to something + // reasonable. If we do end up hitting this threshold, the algorithm + // will create separate groups for all remaining pointers. + if (TotalComparisons > MemoryCheckMergeThreshold) + break; + + TotalComparisons++; + + if (Group.addPointer(Pointer, *this)) { + Merged = true; + break; + } } - } - if (!Merged) - // We couldn't add this pointer to any existing set or the threshold - // for the number of comparisons has been reached. Create a new group - // to hold the current pointer. - Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this)); + if (!Merged) + // We couldn't add this pointer to any existing set or the threshold + // for the number of comparisons has been reached. Create a new group + // to hold the current pointer. + Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this)); + } } // We've computed the grouped checks for this partition. @@ -522,19 +615,19 @@ public: : TheLoop(TheLoop), AST(*AA), LI(LI), DepCands(DA), PSE(PSE) {} /// Register a load and whether it is only read from. - void addLoad(MemoryLocation &Loc, bool IsReadOnly) { + void addLoad(MemoryLocation &Loc, Type *AccessTy, bool IsReadOnly) { Value *Ptr = const_cast<Value*>(Loc.Ptr); AST.add(Ptr, LocationSize::beforeOrAfterPointer(), Loc.AATags); - Accesses.insert(MemAccessInfo(Ptr, false)); + Accesses[MemAccessInfo(Ptr, false)].insert(AccessTy); if (IsReadOnly) ReadOnlyPtr.insert(Ptr); } /// Register a store. - void addStore(MemoryLocation &Loc) { + void addStore(MemoryLocation &Loc, Type *AccessTy) { Value *Ptr = const_cast<Value*>(Loc.Ptr); AST.add(Ptr, LocationSize::beforeOrAfterPointer(), Loc.AATags); - Accesses.insert(MemAccessInfo(Ptr, true)); + Accesses[MemAccessInfo(Ptr, true)].insert(AccessTy); } /// Check if we can emit a run-time no-alias check for \p Access. @@ -545,12 +638,11 @@ public: /// we will attempt to use additional run-time checks in order to get /// the bounds of the pointer. 
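groupChecks above now keys the position map by pointer value but stores a vector of indices, since one pointer can contribute several SCEV expressions after the select translation introduced below. The merging loop itself keeps its shape: each pointer joins the first group that can absorb it, under a global comparison budget. A standalone greedy sketch of that shape (plain interval overlap stands in for the real SCEV-comparable bounds test):

#include <algorithm>
#include <cstdint>
#include <vector>

struct Group { uint64_t Low, High; };

static std::vector<Group> groupIntervals(const std::vector<Group> &Ptrs,
                                         unsigned Budget) {
  std::vector<Group> Groups;
  unsigned Comparisons = 0;
  for (const Group &P : Ptrs) {
    bool Merged = false;
    for (Group &G : Groups) {
      if (Comparisons++ > Budget)
        break;                                  // over budget: stop merging
      if (P.Low <= G.High && G.Low <= P.High) { // toy mergeability test
        G.Low = std::min(G.Low, P.Low);         // widen the group bounds
        G.High = std::max(G.High, P.High);
        Merged = true;
        break;
      }
    }
    if (!Merged)
      Groups.push_back(P); // no group could take it: open a new one
  }
  return Groups;
}

int main() {
  std::vector<Group> Ptrs = {{0, 8}, {4, 12}, {100, 108}};
  return groupIntervals(Ptrs, 16).size() == 2 ? 0 : 1;
}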
bool createCheckForAccess(RuntimePointerChecking &RtCheck, - MemAccessInfo Access, + MemAccessInfo Access, Type *AccessTy, const ValueToValueMap &Strides, DenseMap<Value *, unsigned> &DepSetId, Loop *TheLoop, unsigned &RunningDepId, - unsigned ASId, bool ShouldCheckStride, - bool Assume); + unsigned ASId, bool ShouldCheckStride, bool Assume); /// Check whether we can check the pointers at runtime for /// non-intersection. @@ -559,7 +651,7 @@ public: /// (i.e. the pointers have computable bounds). bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE, Loop *TheLoop, const ValueToValueMap &Strides, - bool ShouldCheckWrap = false); + Value *&UncomputablePtr, bool ShouldCheckWrap = false); /// Goes over all memory accesses, checks whether a RT check is needed /// and builds sets of dependent accesses. @@ -583,14 +675,15 @@ public: MemAccessInfoList &getDependenciesToCheck() { return CheckDeps; } private: - typedef SetVector<MemAccessInfo> PtrAccessSet; + typedef MapVector<MemAccessInfo, SmallSetVector<Type *, 1>> PtrAccessMap; /// Go over all memory access and check whether runtime pointer checks /// are needed and build sets of dependency check candidates. void processMemAccesses(); - /// Set of all accesses. - PtrAccessSet Accesses; + /// Map of all accesses. Values are the types used to access memory pointed to + /// by the pointer. + PtrAccessMap Accesses; /// The loop being checked. const Loop *TheLoop; @@ -630,11 +723,8 @@ private: /// Check whether a pointer can participate in a runtime bounds check. /// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr /// by adding run-time checks (overflow checks) if necessary. -static bool hasComputableBounds(PredicatedScalarEvolution &PSE, - const ValueToValueMap &Strides, Value *Ptr, - Loop *L, bool Assume) { - const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); - +static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr, + const SCEV *PtrScev, Loop *L, bool Assume) { // The bounds for loop-invariant pointer is trivial. if (PSE.getSE()->isLoopInvariant(PtrScev, L)) return true; @@ -652,12 +742,12 @@ static bool hasComputableBounds(PredicatedScalarEvolution &PSE, /// Check whether a pointer address cannot wrap. static bool isNoWrap(PredicatedScalarEvolution &PSE, - const ValueToValueMap &Strides, Value *Ptr, Loop *L) { + const ValueToValueMap &Strides, Value *Ptr, Type *AccessTy, + Loop *L) { const SCEV *PtrScev = PSE.getSCEV(Ptr); if (PSE.getSE()->isLoopInvariant(PtrScev, L)) return true; - Type *AccessTy = Ptr->getType()->getPointerElementType(); int64_t Stride = getPtrStride(PSE, AccessTy, Ptr, L, Strides); if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW)) return true; @@ -689,7 +779,7 @@ static void visitPointers(Value *StartPtr, const Loop &InnermostLoop, } bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck, - MemAccessInfo Access, + MemAccessInfo Access, Type *AccessTy, const ValueToValueMap &StridesMap, DenseMap<Value *, unsigned> &DepSetId, Loop *TheLoop, unsigned &RunningDepId, @@ -697,42 +787,75 @@ bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck, bool Assume) { Value *Ptr = Access.getPointer(); - if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume)) - return false; + ScalarEvolution &SE = *PSE.getSE(); + SmallVector<std::pair<const SCEV *, bool>> TranslatedPtrs; + auto *SI = dyn_cast<SelectInst>(Ptr); + // Look through selects in the current loop. 
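// A hedged sketch (illustration, not from the patch) of the pattern the
// select handling below targets: a pointer that is itself a loop-varying
// select between two addresses, e.g.
//
//   for (unsigned I = 0; I != N; ++I) {
//     int *P = C[I] ? &A[I] : &B[I]; // select inside the loop body
//     *P = 0;
//   }
//
// Instead of giving up on the select, both arms get their own SCEV so each
// can receive its own runtime bound check; an arm that is not known to be
// free of undef/poison is tagged NeedsFreeze.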
+ if (SI && !TheLoop->isLoopInvariant(SI)) { + TranslatedPtrs = { + std::make_pair(SE.getSCEV(SI->getOperand(1)), + !isGuaranteedNotToBeUndefOrPoison(SI->getOperand(1))), + std::make_pair(SE.getSCEV(SI->getOperand(2)), + !isGuaranteedNotToBeUndefOrPoison(SI->getOperand(2)))}; + } else + TranslatedPtrs = { + std::make_pair(replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr), false)}; - // When we run after a failing dependency check we have to make sure - // we don't have wrapping pointers. - if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop)) { - auto *Expr = PSE.getSCEV(Ptr); - if (!Assume || !isa<SCEVAddRecExpr>(Expr)) + for (auto &P : TranslatedPtrs) { + const SCEV *PtrExpr = P.first; + if (!hasComputableBounds(PSE, Ptr, PtrExpr, TheLoop, Assume)) return false; - PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + + // When we run after a failing dependency check we have to make sure + // we don't have wrapping pointers. + if (ShouldCheckWrap) { + // Skip wrap checking when translating pointers. + if (TranslatedPtrs.size() > 1) + return false; + + if (!isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) { + auto *Expr = PSE.getSCEV(Ptr); + if (!Assume || !isa<SCEVAddRecExpr>(Expr)) + return false; + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + } + } + // If there's only one option for Ptr, look it up after bounds and wrap + // checking, because assumptions might have been added to PSE. + if (TranslatedPtrs.size() == 1) + TranslatedPtrs[0] = std::make_pair( + replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr), false); } - // The id of the dependence set. - unsigned DepId; + for (auto &P : TranslatedPtrs) { + const SCEV *PtrExpr = P.first; - if (isDependencyCheckNeeded()) { - Value *Leader = DepCands.getLeaderValue(Access).getPointer(); - unsigned &LeaderId = DepSetId[Leader]; - if (!LeaderId) - LeaderId = RunningDepId++; - DepId = LeaderId; - } else - // Each access has its own dependence set. - DepId = RunningDepId++; + // The id of the dependence set. + unsigned DepId; - bool IsWrite = Access.getInt(); - RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE); - LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); + if (isDependencyCheckNeeded()) { + Value *Leader = DepCands.getLeaderValue(Access).getPointer(); + unsigned &LeaderId = DepSetId[Leader]; + if (!LeaderId) + LeaderId = RunningDepId++; + DepId = LeaderId; + } else + // Each access has its own dependence set. + DepId = RunningDepId++; + + bool IsWrite = Access.getInt(); + RtCheck.insert(TheLoop, Ptr, PtrExpr, AccessTy, IsWrite, DepId, ASId, PSE, + P.second); + LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); + } return true; - } +} bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE, Loop *TheLoop, const ValueToValueMap &StridesMap, - bool ShouldCheckWrap) { + Value *&UncomputablePtr, bool ShouldCheckWrap) { // Find pointers with computable bounds. We are going to use this information // to place a runtime bound check. 
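// For orientation (a conceptual summary, not code from this patch): a
// runtime bound check for two pointer groups A and B ultimately tests
// interval overlap of the accessed ranges, roughly
//
//   NoConflict = (A.End <= B.Start) || (B.End <= A.Start)
//
// and vectorization proceeds only when every required pair of groups is
// proven conflict-free at runtime.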
bool CanDoRT = true; @@ -788,12 +911,15 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck, } for (auto &Access : AccessInfos) { - if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, TheLoop, - RunningDepId, ASId, ShouldCheckWrap, false)) { - LLVM_DEBUG(dbgs() << "LAA: Can't find bounds for ptr:" - << *Access.getPointer() << '\n'); - Retries.push_back(Access); - CanDoAliasSetRT = false; + for (auto &AccessTy : Accesses[Access]) { + if (!createCheckForAccess(RtCheck, Access, AccessTy, StridesMap, + DepSetId, TheLoop, RunningDepId, ASId, + ShouldCheckWrap, false)) { + LLVM_DEBUG(dbgs() << "LAA: Can't find bounds for ptr:" + << *Access.getPointer() << '\n'); + Retries.push_back(Access); + CanDoAliasSetRT = false; + } } } @@ -815,13 +941,17 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck, // We know that we need these checks, so we can now be more aggressive // and add further checks if required (overflow checks). CanDoAliasSetRT = true; - for (auto Access : Retries) - if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, - TheLoop, RunningDepId, ASId, - ShouldCheckWrap, /*Assume=*/true)) { - CanDoAliasSetRT = false; - break; + for (auto Access : Retries) { + for (auto &AccessTy : Accesses[Access]) { + if (!createCheckForAccess(RtCheck, Access, AccessTy, StridesMap, + DepSetId, TheLoop, RunningDepId, ASId, + ShouldCheckWrap, /*Assume=*/true)) { + CanDoAliasSetRT = false; + UncomputablePtr = Access.getPointer(); + break; + } } + } } CanDoRT &= CanDoAliasSetRT; @@ -886,9 +1016,12 @@ void AccessAnalysis::processMemAccesses() { LLVM_DEBUG(dbgs() << "LAA: Accesses(" << Accesses.size() << "):\n"); LLVM_DEBUG({ for (auto A : Accesses) - dbgs() << "\t" << *A.getPointer() << " (" << - (A.getInt() ? "write" : (ReadOnlyPtr.count(A.getPointer()) ? - "read-only" : "read")) << ")\n"; + dbgs() << "\t" << *A.first.getPointer() << " (" + << (A.first.getInt() + ? "write" + : (ReadOnlyPtr.count(A.first.getPointer()) ? "read-only" + : "read")) + << ")\n"; }); // The AliasSetTracker has nicely partitioned our pointers by metadata @@ -907,13 +1040,13 @@ void AccessAnalysis::processMemAccesses() { UnderlyingObjToAccessMap ObjToLastAccess; // Set of access to check after all writes have been processed. - PtrAccessSet DeferredAccesses; + PtrAccessMap DeferredAccesses; // Iterate over each alias set twice, once to process read/write pointers, // and then to process read-only pointers. for (int SetIteration = 0; SetIteration < 2; ++SetIteration) { bool UseDeferred = SetIteration > 0; - PtrAccessSet &S = UseDeferred ? DeferredAccesses : Accesses; + PtrAccessMap &S = UseDeferred ? DeferredAccesses : Accesses; for (const auto &AV : AS) { Value *Ptr = AV.getValue(); @@ -921,10 +1054,10 @@ void AccessAnalysis::processMemAccesses() { // For a single memory access in AliasSetTracker, Accesses may contain // both read and write, and they both need to be handled for CheckDeps. for (const auto &AC : S) { - if (AC.getPointer() != Ptr) + if (AC.first.getPointer() != Ptr) continue; - bool IsWrite = AC.getInt(); + bool IsWrite = AC.first.getInt(); // If we're using the deferred access set, then it contains only // reads. @@ -946,7 +1079,9 @@ void AccessAnalysis::processMemAccesses() { // consecutive as "read-only" pointers (so that we check // "a[b[i]] +="). Hence, we need the second check for "!IsWrite". if (!UseDeferred && IsReadOnlyPtr) { - DeferredAccesses.insert(Access); + // We only use the pointer keys, the types vector values don't + // matter. 
+ DeferredAccesses.insert({Access, {}}); continue; } @@ -1445,13 +1580,13 @@ static bool isSafeDependenceDistance(const DataLayout &DL, ScalarEvolution &SE, const SCEV *CastedDist = &Dist; const SCEV *CastedProduct = Product; - uint64_t DistTypeSize = DL.getTypeAllocSize(Dist.getType()); - uint64_t ProductTypeSize = DL.getTypeAllocSize(Product->getType()); + uint64_t DistTypeSizeBits = DL.getTypeSizeInBits(Dist.getType()); + uint64_t ProductTypeSizeBits = DL.getTypeSizeInBits(Product->getType()); // The dependence distance can be positive/negative, so we sign extend Dist; // The multiplication of the absolute stride in bytes and the // backedgeTakenCount is non-negative, so we zero extend Product. - if (DistTypeSize > ProductTypeSize) + if (DistTypeSizeBits > ProductTypeSizeBits) CastedProduct = SE.getZeroExtendExpr(Product, Dist.getType()); else CastedDist = SE.getNoopOrSignExtend(&Dist, Product->getType()); @@ -1518,8 +1653,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, Value *BPtr = B.getPointer(); bool AIsWrite = A.getInt(); bool BIsWrite = B.getInt(); - Type *ATy = APtr->getType()->getPointerElementType(); - Type *BTy = BPtr->getType()->getPointerElementType(); + Type *ATy = getLoadStoreType(InstMap[AIdx]); + Type *BTy = getLoadStoreType(InstMap[BIdx]); // Two reads are independent. if (!AIsWrite && !BIsWrite) @@ -1842,8 +1977,6 @@ bool LoopAccessInfo::canAnalyzeLoop() { void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI, const TargetLibraryInfo *TLI, DominatorTree *DT) { - typedef SmallPtrSet<Value*, 16> ValueSet; - // Holds the Load and Store instructions. SmallVector<LoadInst *, 16> Loads; SmallVector<StoreInst *, 16> Stores; @@ -1975,22 +2108,26 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI, // for read and once for write, it will only appear once (on the write // list). This is okay, since we are going to check for conflicts between // writes and between reads and writes, but not between reads and reads. - ValueSet Seen; + SmallSet<std::pair<Value *, Type *>, 16> Seen; // Record uniform store addresses to identify if we have multiple stores // to the same address. - ValueSet UniformStores; + SmallPtrSet<Value *, 16> UniformStores; for (StoreInst *ST : Stores) { Value *Ptr = ST->getPointerOperand(); - if (isUniform(Ptr)) + if (isUniform(Ptr)) { + // Record store instructions to loop invariant addresses + StoresToInvariantAddresses.push_back(ST); HasDependenceInvolvingLoopInvariantAddress |= !UniformStores.insert(Ptr).second; + } // If we did *not* see this pointer before, insert it to the read-write // list. At this phase it is only a 'write' list. - if (Seen.insert(Ptr).second) { + Type *AccessTy = getLoadStoreType(ST); + if (Seen.insert({Ptr, AccessTy}).second) { ++NumReadWrites; MemoryLocation Loc = MemoryLocation::get(ST); @@ -2001,9 +2138,9 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI, Loc.AATags.TBAA = nullptr; visitPointers(const_cast<Value *>(Loc.Ptr), *TheLoop, - [&Accesses, Loc](Value *Ptr) { + [&Accesses, AccessTy, Loc](Value *Ptr) { MemoryLocation NewLoc = Loc.getWithNewPtr(Ptr); - Accesses.addStore(NewLoc); + Accesses.addStore(NewLoc, AccessTy); }); } } @@ -2027,7 +2164,8 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI, // read a few words, modify, and write a few words, and some of the // words may be written to the same address. 
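// A reader's note on the classification below (names illustrative): in
// 'A[i] = B[i];' the B load's (pointer, access type) pair never appears
// among the recorded stores, so the Seen insert succeeds and B becomes a
// read-only candidate, as does any load whose stride cannot be computed.
// Reads never conflict with other reads, so read-only pointers only need
// to be checked against writes.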
bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr).second || + Type *AccessTy = getLoadStoreType(LD); + if (Seen.insert({Ptr, AccessTy}).second || !getPtrStride(*PSE, LD->getType(), Ptr, TheLoop, SymbolicStrides)) { ++NumReads; IsReadOnlyPtr = true; @@ -2049,9 +2187,9 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI, Loc.AATags.TBAA = nullptr; visitPointers(const_cast<Value *>(Loc.Ptr), *TheLoop, - [&Accesses, Loc, IsReadOnlyPtr](Value *Ptr) { + [&Accesses, AccessTy, Loc, IsReadOnlyPtr](Value *Ptr) { MemoryLocation NewLoc = Loc.getWithNewPtr(Ptr); - Accesses.addLoad(NewLoc, IsReadOnlyPtr); + Accesses.addLoad(NewLoc, AccessTy, IsReadOnlyPtr); }); } @@ -2069,10 +2207,14 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI, // Find pointers with computable bounds. We are going to use this information // to place a runtime bound check. - bool CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, PSE->getSE(), - TheLoop, SymbolicStrides); + Value *UncomputablePtr = nullptr; + bool CanDoRTIfNeeded = + Accesses.canCheckPtrAtRT(*PtrRtChecking, PSE->getSE(), TheLoop, + SymbolicStrides, UncomputablePtr, false); if (!CanDoRTIfNeeded) { - recordAnalysis("CantIdentifyArrayBounds") << "cannot identify array bounds"; + auto *I = dyn_cast_or_null<Instruction>(UncomputablePtr); + recordAnalysis("CantIdentifyArrayBounds", I) + << "cannot identify array bounds"; LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because we can't find " << "the array bounds.\n"); CanVecMem = false; @@ -2099,12 +2241,14 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI, PtrRtChecking->Need = true; auto *SE = PSE->getSE(); - CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, SE, TheLoop, - SymbolicStrides, true); + UncomputablePtr = nullptr; + CanDoRTIfNeeded = Accesses.canCheckPtrAtRT( + *PtrRtChecking, SE, TheLoop, SymbolicStrides, UncomputablePtr, true); // Check that we found the bounds for the pointer. if (!CanDoRTIfNeeded) { - recordAnalysis("CantCheckMemDepsAtRunTime") + auto *I = dyn_cast_or_null<Instruction>(UncomputablePtr); + recordAnalysis("CantCheckMemDepsAtRunTime", I) << "cannot check memory dependencies at runtime"; LLVM_DEBUG(dbgs() << "LAA: Can't vectorize with memory checks\n"); CanVecMem = false; @@ -2129,13 +2273,61 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI, dbgs() << "LAA: No unsafe dependent memory operations in loop. We" << (PtrRtChecking->Need ? "" : " don't") << " need runtime memory checks.\n"); - else { - recordAnalysis("UnsafeMemDep") - << "unsafe dependent memory operations in loop. 
Use " - "#pragma loop distribute(enable) to allow loop distribution " - "to attempt to isolate the offending operations into a separate " - "loop"; - LLVM_DEBUG(dbgs() << "LAA: unsafe dependent memory operations in loop\n"); + else + emitUnsafeDependenceRemark(); +} + +void LoopAccessInfo::emitUnsafeDependenceRemark() { + auto Deps = getDepChecker().getDependences(); + if (!Deps) + return; + auto Found = std::find_if( + Deps->begin(), Deps->end(), [](const MemoryDepChecker::Dependence &D) { + return MemoryDepChecker::Dependence::isSafeForVectorization(D.Type) != + MemoryDepChecker::VectorizationSafetyStatus::Safe; + }); + if (Found == Deps->end()) + return; + MemoryDepChecker::Dependence Dep = *Found; + + LLVM_DEBUG(dbgs() << "LAA: unsafe dependent memory operations in loop\n"); + + // Emit remark for first unsafe dependence + OptimizationRemarkAnalysis &R = + recordAnalysis("UnsafeDep", Dep.getDestination(*this)) + << "unsafe dependent memory operations in loop. Use " + "#pragma loop distribute(enable) to allow loop distribution " + "to attempt to isolate the offending operations into a separate " + "loop"; + + switch (Dep.Type) { + case MemoryDepChecker::Dependence::NoDep: + case MemoryDepChecker::Dependence::Forward: + case MemoryDepChecker::Dependence::BackwardVectorizable: + llvm_unreachable("Unexpected dependence"); + case MemoryDepChecker::Dependence::Backward: + R << "\nBackward loop carried data dependence."; + break; + case MemoryDepChecker::Dependence::ForwardButPreventsForwarding: + R << "\nForward loop carried data dependence that prevents " + "store-to-load forwarding."; + break; + case MemoryDepChecker::Dependence::BackwardVectorizableButPreventsForwarding: + R << "\nBackward loop carried data dependence that prevents " + "store-to-load forwarding."; + break; + case MemoryDepChecker::Dependence::Unknown: + R << "\nUnknown data dependence."; + break; + } + + if (Instruction *I = Dep.getSource(*this)) { + DebugLoc SourceLoc = I->getDebugLoc(); + if (auto *DD = dyn_cast_or_null<Instruction>(getPointerOperand(I))) + SourceLoc = DD->getDebugLoc(); + if (SourceLoc) + R << " Memory location is the same as accessed at " + << ore::NV("Location", SourceLoc); } } @@ -2212,12 +2404,12 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { // The Stride can be positive/negative, so we sign extend Stride; // The backedgeTakenCount is non-negative, so we zero extend BETakenCount. 
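// Concretely (types illustrative only): with an i32 StrideExpr and an i64
// BETakenCount, the stride is sign-extended to i64 because a stride may be
// negative; in the opposite case the backedge-taken count is zero-extended
// because it is known non-negative. Either way both operands end up in one
// common type before they are compared.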
const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); - uint64_t StrideTypeSize = DL.getTypeAllocSize(StrideExpr->getType()); - uint64_t BETypeSize = DL.getTypeAllocSize(BETakenCount->getType()); + uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType()); + uint64_t BETypeSizeBits = DL.getTypeSizeInBits(BETakenCount->getType()); const SCEV *CastedStride = StrideExpr; const SCEV *CastedBECount = BETakenCount; ScalarEvolution *SE = PSE->getSE(); - if (BETypeSize >= StrideTypeSize) + if (BETypeSizeBits >= StrideTypeSizeBits) CastedStride = SE->getNoopOrSignExtend(StrideExpr, BETakenCount->getType()); else CastedBECount = SE->getZeroExtendExpr(BETakenCount, StrideExpr->getType()); @@ -2232,7 +2424,7 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { "at most once.\n"); return; } - LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version."); + LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n"); SymbolicStrides[Ptr] = Stride; StrideSet.insert(Stride); @@ -2242,10 +2434,12 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, const TargetLibraryInfo *TLI, AAResults *AA, DominatorTree *DT, LoopInfo *LI) : PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)), - PtrRtChecking(std::make_unique<RuntimePointerChecking>(SE)), + PtrRtChecking(nullptr), DepChecker(std::make_unique<MemoryDepChecker>(*PSE, L)), TheLoop(L) { - if (canAnalyzeLoop()) + PtrRtChecking = std::make_unique<RuntimePointerChecking>(*DepChecker, SE); + if (canAnalyzeLoop()) { analyzeLoop(AA, LI, TLI, DT); + } } void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const { @@ -2283,7 +2477,7 @@ void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const { << "found in loop.\n"; OS.indent(Depth) << "SCEV assumptions:\n"; - PSE->getUnionPredicate().print(OS, Depth); + PSE->getPredicate().print(OS, Depth); OS << "\n"; @@ -2301,7 +2495,7 @@ const LoopAccessInfo &LoopAccessLegacyAnalysis::getInfo(Loop *L) { if (!LAI) LAI = std::make_unique<LoopAccessInfo>(L, SE, TLI, AA, DT, LI); - return *LAI.get(); + return *LAI; } void LoopAccessLegacyAnalysis::print(raw_ostream &OS, const Module *M) const { diff --git a/llvm/lib/Analysis/LoopAnalysisManager.cpp b/llvm/lib/Analysis/LoopAnalysisManager.cpp index 4d6f8a64329a..8d71b31ca393 100644 --- a/llvm/lib/Analysis/LoopAnalysisManager.cpp +++ b/llvm/lib/Analysis/LoopAnalysisManager.cpp @@ -8,12 +8,9 @@ #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/PassManagerImpl.h" diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp index ba014bd08c98..2cbf1f7f2d28 100644 --- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp +++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -103,14 +103,24 @@ static bool isOneDimensionalArray(const SCEV &AccessFn, const SCEV &ElemSize, return StepRec == &ElemSize; } -/// Compute the trip count for the given loop \p L. Return the SCEV expression -/// for the trip count or nullptr if it cannot be computed. 
-static const SCEV *computeTripCount(const Loop &L, ScalarEvolution &SE) { +/// Compute the trip count for the given loop \p L or assume a default value if +/// it is not a compile time constant. Return the SCEV expression for the trip +/// count. +static const SCEV *computeTripCount(const Loop &L, const SCEV &ElemSize, + ScalarEvolution &SE) { const SCEV *BackedgeTakenCount = SE.getBackedgeTakenCount(&L); - if (isa<SCEVCouldNotCompute>(BackedgeTakenCount) || - !isa<SCEVConstant>(BackedgeTakenCount)) - return nullptr; - return SE.getTripCountFromExitCount(BackedgeTakenCount); + const SCEV *TripCount = (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && + isa<SCEVConstant>(BackedgeTakenCount)) + ? SE.getTripCountFromExitCount(BackedgeTakenCount) + : nullptr; + + if (!TripCount) { + LLVM_DEBUG(dbgs() << "Trip count of loop " << L.getName() + << " could not be computed, using DefaultTripCount\n"); + TripCount = SE.getConstant(ElemSize.getType(), DefaultTripCount); + } + + return TripCount; } //===----------------------------------------------------------------------===// @@ -274,22 +284,18 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, return 1; } - const SCEV *TripCount = computeTripCount(L, SE); - if (!TripCount) { - LLVM_DEBUG(dbgs() << "Trip count of loop " << L.getName() - << " could not be computed, using DefaultTripCount\n"); - const SCEV *ElemSize = Sizes.back(); - TripCount = SE.getConstant(ElemSize->getType(), DefaultTripCount); - } + const SCEV *TripCount = computeTripCount(L, *Sizes.back(), SE); + assert(TripCount && "Expecting valid TripCount"); LLVM_DEBUG(dbgs() << "TripCount=" << *TripCount << "\n"); - // If the indexed reference is 'consecutive' the cost is - // (TripCount*Stride)/CLS, otherwise the cost is TripCount. - const SCEV *RefCost = TripCount; - + const SCEV *RefCost = nullptr; if (isConsecutive(L, CLS)) { + // If the indexed reference is 'consecutive' the cost is + // (TripCount*Stride)/CLS. const SCEV *Coeff = getLastCoefficient(); const SCEV *ElemSize = Sizes.back(); + assert(Coeff->getType() == ElemSize->getType() && + "Expecting the same type"); const SCEV *Stride = SE.getMulExpr(Coeff, ElemSize); Type *WiderType = SE.getWiderType(Stride->getType(), TripCount->getType()); const SCEV *CacheLineSize = SE.getConstant(WiderType, CLS); @@ -303,10 +309,33 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, LLVM_DEBUG(dbgs().indent(4) << "Access is consecutive: RefCost=(TripCount*Stride)/CLS=" << *RefCost << "\n"); - } else + } else { + // If the indexed reference is not 'consecutive' the cost is proportional to + // the trip count and the depth of the dimension which the subject loop + // subscript is accessing. We try to estimate this by multiplying the cost + // by the trip counts of loops corresponding to the inner dimensions. For + // example, given the indexed reference 'A[i][j][k]', and assuming the + // i-loop is in the innermost position, the cost would be equal to the + // iterations of the i-loop multiplied by iterations of the j-loop. 
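// A concrete sketch of the computation below (trip counts illustrative):
// for 'A[i][j][k]' with the i-loop in the innermost position,
// getSubscriptIndex locates the i subscript, and the loop multiplies in
// the trip counts of the loops indexing the inner dimensions up to, but
// not including, the last subscript, yielding
// RefCost = TripCount(i) * TripCount(j).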
+    RefCost = TripCount;
+
+    int Index = getSubscriptIndex(L);
+    assert(Index >= 0 && "Could not locate a valid Index");
+
+    for (unsigned I = Index + 1; I < getNumSubscripts() - 1; ++I) {
+      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(getSubscript(I));
+      assert(AR && AR->getLoop() && "Expecting valid loop");
+      const SCEV *TripCount =
+          computeTripCount(*AR->getLoop(), *Sizes.back(), SE);
+      Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType());
+      RefCost = SE.getMulExpr(SE.getNoopOrAnyExtend(RefCost, WiderType),
+                              SE.getNoopOrAnyExtend(TripCount, WiderType));
+    }
+
     LLVM_DEBUG(dbgs().indent(4)
-               << "Access is not consecutive: RefCost=TripCount=" << *RefCost
-               << "\n");
+               << "Access is not consecutive: RefCost=" << *RefCost << "\n");
+  }
+  assert(RefCost && "Expecting a valid RefCost");

   // Attempt to fold RefCost into a constant.
   if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
@@ -319,6 +348,26 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
   return CacheCost::InvalidCost;
 }

+bool IndexedReference::tryDelinearizeFixedSize(
+    const SCEV *AccessFn, SmallVectorImpl<const SCEV *> &Subscripts) {
+  SmallVector<int, 4> ArraySizes;
+  if (!tryDelinearizeFixedSizeImpl(&SE, &StoreOrLoadInst, AccessFn, Subscripts,
+                                   ArraySizes))
+    return false;
+
+  // Populate Sizes with SCEV expressions to be used in calculations later.
+  for (auto Idx : seq<unsigned>(1, Subscripts.size()))
+    Sizes.push_back(
+        SE.getConstant(Subscripts[Idx]->getType(), ArraySizes[Idx - 1]));
+
+  LLVM_DEBUG({
+    dbgs() << "Delinearized subscripts of fixed-size array\n"
+           << "GEP:" << *getLoadStorePointerOperand(&StoreOrLoadInst)
+           << "\n";
+  });
+  return true;
+}
+
 bool IndexedReference::delinearize(const LoopInfo &LI) {
   assert(Subscripts.empty() && "Subscripts should be empty");
   assert(Sizes.empty() && "Sizes should be empty");
@@ -340,13 +389,25 @@ bool IndexedReference::delinearize(const LoopInfo &LI) {
       return false;
     }

-    AccessFn = SE.getMinusSCEV(AccessFn, BasePointer);
+    bool IsFixedSize = false;
+    // Try to delinearize fixed-size arrays.
+    if (tryDelinearizeFixedSize(AccessFn, Subscripts)) {
+      IsFixedSize = true;
+      // The last element of Sizes is the element size.
+      Sizes.push_back(ElemSize);
+      LLVM_DEBUG(dbgs().indent(2) << "In Loop '" << L->getName()
+                                  << "', AccessFn: " << *AccessFn << "\n");
+    }

-    LLVM_DEBUG(dbgs().indent(2) << "In Loop '" << L->getName()
-                                << "', AccessFn: " << *AccessFn << "\n");
+    AccessFn = SE.getMinusSCEV(AccessFn, BasePointer);

-    llvm::delinearize(SE, AccessFn, Subscripts, Sizes,
-                      SE.getElementSize(&StoreOrLoadInst));
+    // Try to delinearize parametric-size arrays.
+ if (!IsFixedSize) { + LLVM_DEBUG(dbgs().indent(2) << "In Loop '" << L->getName() + << "', AccessFn: " << *AccessFn << "\n"); + llvm::delinearize(SE, AccessFn, Subscripts, Sizes, + SE.getElementSize(&StoreOrLoadInst)); + } if (Subscripts.empty() || Sizes.empty() || Subscripts.size() != Sizes.size()) { @@ -424,6 +485,16 @@ bool IndexedReference::isConsecutive(const Loop &L, unsigned CLS) const { return SE.isKnownPredicate(ICmpInst::ICMP_ULT, Stride, CacheLineSize); } +int IndexedReference::getSubscriptIndex(const Loop &L) const { + for (auto Idx : seq<int>(0, getNumSubscripts())) { + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(getSubscript(Idx)); + if (AR && AR->getLoop() == &L) { + return Idx; + } + } + return -1; +} + const SCEV *IndexedReference::getLastCoefficient() const { const SCEV *LastSubscript = getLastSubscript(); auto *AR = cast<SCEVAddRecExpr>(LastSubscript); @@ -550,7 +621,7 @@ bool CacheCost::populateReferenceGroups(ReferenceGroupsTy &RefGroups) const { bool Added = false; for (ReferenceGroupTy &RefGroup : RefGroups) { - const IndexedReference &Representative = *RefGroup.front().get(); + const IndexedReference &Representative = *RefGroup.front(); LLVM_DEBUG({ dbgs() << "References:\n"; dbgs().indent(2) << *R << "\n"; @@ -574,8 +645,8 @@ bool CacheCost::populateReferenceGroups(ReferenceGroupsTy &RefGroups) const { Optional<bool> HasSpacialReuse = R->hasSpacialReuse(Representative, CLS, AA); - if ((HasTemporalReuse.hasValue() && *HasTemporalReuse) || - (HasSpacialReuse.hasValue() && *HasSpacialReuse)) { + if ((HasTemporalReuse && *HasTemporalReuse) || + (HasSpacialReuse && *HasSpacialReuse)) { RefGroup.push_back(std::move(R)); Added = true; break; diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp index b161c490a6bc..29c2437ff5ea 100644 --- a/llvm/lib/Analysis/LoopInfo.cpp +++ b/llvm/lib/Analysis/LoopInfo.cpp @@ -14,7 +14,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/LoopInfo.h" -#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/IVDescriptors.h" @@ -30,7 +29,6 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/IRPrintingPasses.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" @@ -38,9 +36,7 @@ #include "llvm/IR/PrintPasses.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include <algorithm> using namespace llvm; // Explicitly instantiate methods in LoopInfoImpl.h for IR-level Loops. @@ -740,6 +736,7 @@ void UnloopUpdater::updateBlockParents() { bool Changed = FoundIB; for (unsigned NIters = 0; Changed; ++NIters) { assert(NIters < Unloop.getNumBlocks() && "runaway iterative algorithm"); + (void) NIters; // Iterate over the postorder list of blocks, propagating the nearest loop // from successors to predecessors as before. 
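// A general C++ note on the '(void) NIters;' added in the hunk above (an
// editor's illustration, not from the patch): NIters is written on every
// iteration but only read inside assert(), which compiles away under
// NDEBUG, so release builds would emit -Wunused-but-set-variable. The cast
// adds a harmless use:
//
//   unsigned Count = 0;
//   for (; Changed; ++Count) { /* ... */ }
//   assert(Count < Limit && "runaway iterative algorithm");
//   (void)Count; // keeps release builds warning-free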
@@ -1085,13 +1082,13 @@ Optional<bool> llvm::getOptionalBoolLoopAttribute(const Loop *TheLoop, } bool llvm::getBooleanLoopAttribute(const Loop *TheLoop, StringRef Name) { - return getOptionalBoolLoopAttribute(TheLoop, Name).getValueOr(false); + return getOptionalBoolLoopAttribute(TheLoop, Name).value_or(false); } llvm::Optional<int> llvm::getOptionalIntLoopAttribute(const Loop *TheLoop, StringRef Name) { const MDOperand *AttrMD = - findStringMetadataForLoop(TheLoop, Name).getValueOr(nullptr); + findStringMetadataForLoop(TheLoop, Name).value_or(nullptr); if (!AttrMD) return None; @@ -1104,7 +1101,7 @@ llvm::Optional<int> llvm::getOptionalIntLoopAttribute(const Loop *TheLoop, int llvm::getIntLoopAttribute(const Loop *TheLoop, StringRef Name, int Default) { - return getOptionalIntLoopAttribute(TheLoop, Name).getValueOr(Default); + return getOptionalIntLoopAttribute(TheLoop, Name).value_or(Default); } bool llvm::isFinite(const Loop *L) { diff --git a/llvm/lib/Analysis/LoopNestAnalysis.cpp b/llvm/lib/Analysis/LoopNestAnalysis.cpp index 675bb7a7749c..bff796f339ab 100644 --- a/llvm/lib/Analysis/LoopNestAnalysis.cpp +++ b/llvm/lib/Analysis/LoopNestAnalysis.cpp @@ -13,8 +13,7 @@ #include "llvm/Analysis/LoopNestAnalysis.h" #include "llvm/ADT/BreadthFirstIterator.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/PostDominators.h" +#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/Analysis/ValueTracking.h" using namespace llvm; diff --git a/llvm/lib/Analysis/LoopPass.cpp b/llvm/lib/Analysis/LoopPass.cpp index b720bab454e9..5d824aece488 100644 --- a/llvm/lib/Analysis/LoopPass.cpp +++ b/llvm/lib/Analysis/LoopPass.cpp @@ -13,14 +13,12 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/OptBisect.h" -#include "llvm/IR/PassManager.h" #include "llvm/IR/PassTimingInfo.h" #include "llvm/IR/PrintPasses.h" -#include "llvm/IR/StructuralHash.h" #include "llvm/InitializePasses.h" #include "llvm/Support/Debug.h" #include "llvm/Support/TimeProfiler.h" @@ -192,12 +190,12 @@ bool LPPassManager::runOnFunction(Function &F) { PassManagerPrettyStackEntry X(P, *CurrentLoop->getHeader()); TimeRegion PassTimer(getPassTimer(P)); #ifdef EXPENSIVE_CHECKS - uint64_t RefHash = StructuralHash(F); + uint64_t RefHash = P->structuralHash(F); #endif LocalChanged = P->runOnLoop(CurrentLoop, *this); #ifdef EXPENSIVE_CHECKS - if (!LocalChanged && (RefHash != StructuralHash(F))) { + if (!LocalChanged && (RefHash != P->structuralHash(F))) { llvm::errs() << "Pass modifies its input and doesn't report it: " << P->getPassName() << "\n"; llvm_unreachable("Pass modifies its input and doesn't report it"); diff --git a/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp b/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp index 15095d67d385..84f1eff9a732 100644 --- a/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp +++ b/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp @@ -13,7 +13,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/LoopUnrollAnalyzer.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/Operator.h" using namespace llvm; @@ -84,9 +87,9 @@ bool UnrolledInstAnalyzer::visitBinaryOperator(BinaryOperator &I) { const DataLayout &DL = 
I.getModule()->getDataLayout(); if (auto FI = dyn_cast<FPMathOperator>(&I)) SimpleV = - SimplifyBinOp(I.getOpcode(), LHS, RHS, FI->getFastMathFlags(), DL); + simplifyBinOp(I.getOpcode(), LHS, RHS, FI->getFastMathFlags(), DL); else - SimpleV = SimplifyBinOp(I.getOpcode(), LHS, RHS, DL); + SimpleV = simplifyBinOp(I.getOpcode(), LHS, RHS, DL); if (SimpleV) { SimplifiedValues[&I] = SimpleV; @@ -155,7 +158,7 @@ bool UnrolledInstAnalyzer::visitCastInst(CastInst &I) { // i32 0). if (CastInst::castIsValid(I.getOpcode(), Op, I.getType())) { const DataLayout &DL = I.getModule()->getDataLayout(); - if (Value *V = SimplifyCastInst(I.getOpcode(), Op, I.getType(), DL)) { + if (Value *V = simplifyCastInst(I.getOpcode(), Op, I.getType(), DL)) { SimplifiedValues[&I] = V; return true; } @@ -192,7 +195,7 @@ bool UnrolledInstAnalyzer::visitCmpInst(CmpInst &I) { } const DataLayout &DL = I.getModule()->getDataLayout(); - if (Value *V = SimplifyCmpInst(I.getPredicate(), LHS, RHS, DL)) { + if (Value *V = simplifyCmpInst(I.getPredicate(), LHS, RHS, DL)) { SimplifiedValues[&I] = V; return true; } diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp index 0480c1cd2842..f55de71ea98a 100644 --- a/llvm/lib/Analysis/MLInlineAdvisor.cpp +++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp @@ -13,30 +13,25 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/MLInlineAdvisor.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/FunctionPropertiesAnalysis.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/InlineModelFeatureMaps.h" #include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MLModelRunner.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" -#include "llvm/Analysis/ReleaseModeModelRunner.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Config/config.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Path.h" - -#include <limits> -#include <unordered_map> -#include <unordered_set> using namespace llvm; #if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL) +#include "llvm/Analysis/ReleaseModeModelRunner.h" // codegen-ed file #include "InlinerSizeModel.h" // NOLINT @@ -44,7 +39,7 @@ std::unique_ptr<InlineAdvisor> llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM) { auto AOTRunner = std::make_unique<ReleaseModeModelRunner<llvm::InlinerSizeModel>>( - M.getContext(), FeatureNameMap, DecisionName); + M.getContext(), FeatureMap, DecisionName); return std::make_unique<MLInlineAdvisor>(M, MAM, std::move(AOTRunner)); } #endif @@ -57,15 +52,21 @@ static cl::opt<float> SizeIncreaseThreshold( "blocking any further inlining."), cl::init(2.0)); +static cl::opt<bool> KeepFPICache( + "ml-advisor-keep-fpi-cache", cl::Hidden, + cl::desc( + "For test - keep the ML Inline advisor's FunctionPropertiesInfo cache"), + cl::init(false)); + // clang-format off -const std::array<std::string, NumberOfFeatures> llvm::FeatureNameMap{ +const std::array<TensorSpec, NumberOfFeatures> llvm::FeatureMap{ +#define POPULATE_NAMES(_, NAME) TensorSpec::createSpec<int64_t>(NAME, {1} ), // InlineCost features - these must come first -#define POPULATE_NAMES(INDEX_NAME, NAME) NAME, 
INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES) #undef POPULATE_NAMES // Non-cost features -#define POPULATE_NAMES(INDEX_NAME, NAME, COMMENT) NAME, +#define POPULATE_NAMES(_, NAME, __) TensorSpec::createSpec<int64_t>(NAME, {1} ), INLINE_FEATURE_ITERATOR(POPULATE_NAMES) #undef POPULATE_NAMES }; @@ -138,7 +139,10 @@ unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const { return CG.lookup(F) ? FunctionLevels.at(CG.lookup(F)) : 0; } -void MLInlineAdvisor::onPassEntry() { +void MLInlineAdvisor::onPassEntry(LazyCallGraph::SCC *LastSCC) { + if (!LastSCC || ForceStop) + return; + FPICache.clear(); // Function passes executed between InlinerPass runs may have changed the // module-wide features. // The cgscc pass manager rules are such that: @@ -154,8 +158,8 @@ void MLInlineAdvisor::onPassEntry() { // care about the nature of the Edge (call or ref). NodeCount -= static_cast<int64_t>(NodesInLastSCC.size()); while (!NodesInLastSCC.empty()) { - const auto *N = NodesInLastSCC.front(); - NodesInLastSCC.pop_front(); + const auto *N = *NodesInLastSCC.begin(); + NodesInLastSCC.erase(N); // The Function wrapped by N could have been deleted since we last saw it. if (N->isDead()) { assert(!N->getFunction().isDeclaration()); @@ -168,34 +172,52 @@ void MLInlineAdvisor::onPassEntry() { assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration()); auto I = AllNodes.insert(AdjNode); if (I.second) - NodesInLastSCC.push_back(AdjNode); + NodesInLastSCC.insert(AdjNode); } } EdgeCount -= EdgesOfLastSeenNodes; EdgesOfLastSeenNodes = 0; + + // (Re)use NodesInLastSCC to remember the nodes in the SCC right now, + // in case the SCC is split before onPassExit and some nodes are split out + assert(NodesInLastSCC.empty()); + for (const auto &N : *LastSCC) + NodesInLastSCC.insert(&N); } void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *LastSCC) { - if (!LastSCC) + // No need to keep this around - function passes will invalidate it. + if (!KeepFPICache) + FPICache.clear(); + if (!LastSCC || ForceStop) return; // Keep track of the nodes and edges we last saw. Then, in onPassEntry, // we update the node count and edge count from the subset of these nodes that // survived. - assert(NodesInLastSCC.empty()); - assert(NodeCount >= LastSCC->size()); EdgesOfLastSeenNodes = 0; + + // Check on nodes that were in SCC onPassEntry + for (auto I = NodesInLastSCC.begin(); I != NodesInLastSCC.end();) { + if ((*I)->isDead()) + NodesInLastSCC.erase(*I++); + else + EdgesOfLastSeenNodes += getLocalCalls((*I++)->getFunction()); + } + + // Check on nodes that may have got added to SCC for (const auto &N : *LastSCC) { assert(!N.isDead()); - EdgesOfLastSeenNodes += getLocalCalls(N.getFunction()); - NodesInLastSCC.push_back(&N); + auto I = NodesInLastSCC.insert(&N); + if (I.second) + EdgesOfLastSeenNodes += getLocalCalls(N.getFunction()); } + assert(NodeCount >= NodesInLastSCC.size()); assert(EdgeCount >= EdgesOfLastSeenNodes); } int64_t MLInlineAdvisor::getLocalCalls(Function &F) { - return FAM.getResult<FunctionPropertiesAnalysis>(F) - .DirectCallsToDefinedFunctions; + return getCachedFPI(F).DirectCallsToDefinedFunctions; } // Update the internal state of the advisor, and force invalidate feature @@ -208,13 +230,15 @@ void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice, assert(!ForceStop); Function *Caller = Advice.getCaller(); Function *Callee = Advice.getCallee(); - // The caller features aren't valid anymore. 
{ PreservedAnalyses PA = PreservedAnalyses::all(); PA.abandon<FunctionPropertiesAnalysis>(); + PA.abandon<DominatorTreeAnalysis>(); + PA.abandon<LoopAnalysis>(); FAM.invalidate(*Caller, PA); } + Advice.updateCachedCallerFPI(FAM); int64_t IRSizeAfter = getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize); CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize); @@ -227,15 +251,13 @@ void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice, // For edges, we 'forget' the edges that the caller and callee used to have // before inlining, and add back what they currently have together. int64_t NewCallerAndCalleeEdges = - FAM.getResult<FunctionPropertiesAnalysis>(*Caller) - .DirectCallsToDefinedFunctions; + getCachedFPI(*Caller).DirectCallsToDefinedFunctions; if (CalleeWasDeleted) --NodeCount; else NewCallerAndCalleeEdges += - FAM.getResult<FunctionPropertiesAnalysis>(*Callee) - .DirectCallsToDefinedFunctions; + getCachedFPI(*Callee).DirectCallsToDefinedFunctions; EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges); assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0); } @@ -248,7 +270,19 @@ int64_t MLInlineAdvisor::getModuleIRSize() const { return Ret; } +FunctionPropertiesInfo &MLInlineAdvisor::getCachedFPI(Function &F) const { + auto InsertPair = + FPICache.insert(std::make_pair(&F, FunctionPropertiesInfo())); + if (!InsertPair.second) + return InsertPair.first->second; + InsertPair.first->second = FAM.getResult<FunctionPropertiesAnalysis>(F); + return InsertPair.first->second; +} + std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) { + if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB)) + return Skip; + auto &Caller = *CB.getCaller(); auto &Callee = *CB.getCalledFunction(); @@ -307,8 +341,8 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) { NrCtantParams += (isa<Constant>(*I)); } - auto &CallerBefore = FAM.getResult<FunctionPropertiesAnalysis>(Caller); - auto &CalleeBefore = FAM.getResult<FunctionPropertiesAnalysis>(Callee); + auto &CallerBefore = getCachedFPI(Caller); + auto &CalleeBefore = getCachedFPI(Callee); *ModelRunner->getTensor<int64_t>(FeatureIndex::CalleeBasicBlockCount) = CalleeBefore.BasicBlockCount; @@ -348,9 +382,19 @@ MLInlineAdvisor::getAdviceFromModel(CallBase &CB, this, CB, ORE, static_cast<bool>(ModelRunner->evaluate<int64_t>())); } +std::unique_ptr<InlineAdvice> +MLInlineAdvisor::getSkipAdviceIfUnreachableCallsite(CallBase &CB) { + if (!FAM.getResult<DominatorTreeAnalysis>(*CB.getCaller()) + .isReachableFromEntry(CB.getParent())) + return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false); + return nullptr; +} + std::unique_ptr<InlineAdvice> MLInlineAdvisor::getMandatoryAdvice(CallBase &CB, bool Advice) { // Make sure we track inlinings in all cases - mandatory or not. 
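// A hedged sketch of getSkipAdviceIfUnreachableCallsite, used below: it
// consults the caller's dominator tree, and a call site in a block that is
// not reachable from the function entry can never execute, so no inlining
// advice (and no model evaluation) is spent on it, conceptually
//
//   if (!DT.isReachableFromEntry(CB.getParent()))
//     return /* "do not inline" advice */;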
+  if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
+    return Skip;
   if (Advice && !ForceStop)
     return getMandatoryAdviceImpl(CB);

@@ -366,16 +410,47 @@ MLInlineAdvisor::getMandatoryAdviceImpl(CallBase &CB) {
   return std::make_unique<MLInlineAdvice>(this, CB, getCallerORE(CB), true);
 }

+void MLInlineAdvisor::print(raw_ostream &OS) const {
+  OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
+     << " EdgesOfLastSeenNodes: " << EdgesOfLastSeenNodes << "\n";
+  OS << "[MLInlineAdvisor] FPI:\n";
+  for (auto I : FPICache) {
+    OS << I.getFirst()->getName() << ":\n";
+    I.getSecond().print(OS);
+    OS << "\n";
+  }
+  OS << "\n";
+}
+
+MLInlineAdvice::MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
+                               OptimizationRemarkEmitter &ORE,
+                               bool Recommendation)
+    : InlineAdvice(Advisor, CB, ORE, Recommendation),
+      CallerIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Caller)),
+      CalleeIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Callee)),
+      CallerAndCalleeEdges(Advisor->isForcedToStop()
+                               ? 0
+                               : (Advisor->getLocalCalls(*Caller) +
+                                  Advisor->getLocalCalls(*Callee))),
+      PreInlineCallerFPI(Advisor->getCachedFPI(*Caller)) {
+  if (Recommendation)
+    FPU.emplace(Advisor->getCachedFPI(*getCaller()), CB);
+}
+
 void MLInlineAdvice::reportContextForRemark(
     DiagnosticInfoOptimizationBase &OR) {
   using namespace ore;
   OR << NV("Callee", Callee->getName());
   for (size_t I = 0; I < NumberOfFeatures; ++I)
-    OR << NV(FeatureNameMap[I],
+    OR << NV(FeatureMap[I].name(),
             *getAdvisor()->getModelRunner().getTensor<int64_t>(I));
   OR << NV("ShouldInline", isInliningRecommended());
 }

+void MLInlineAdvice::updateCachedCallerFPI(FunctionAnalysisManager &FAM) const {
+  FPU->finish(FAM);
+}
+
 void MLInlineAdvice::recordInliningImpl() {
   ORE.emit([&]() {
     OptimizationRemark R(DEBUG_TYPE, "InliningSuccess", DLoc, Block);
@@ -397,6 +472,7 @@ void MLInlineAdvice::recordInliningWithCalleeDeletedImpl() {

 void MLInlineAdvice::recordUnsuccessfulInliningImpl(
     const InlineResult &Result) {
+  getAdvisor()->getCachedFPI(*Caller) = PreInlineCallerFPI;
   ORE.emit([&]() {
     OptimizationRemarkMissed R(DEBUG_TYPE, "InliningAttemptedAndUnsuccessful",
                                DLoc, Block);
@@ -405,6 +481,7 @@ void MLInlineAdvice::recordUnsuccessfulInliningImpl(
   });
 }
 void MLInlineAdvice::recordUnattemptedInliningImpl() {
+  assert(!FPU);
   ORE.emit([&]() {
     OptimizationRemarkMissed R(DEBUG_TYPE, "InliningNotAttempted", DLoc, Block);
     reportContextForRemark(R);
diff --git a/llvm/lib/Analysis/MemDepPrinter.cpp b/llvm/lib/Analysis/MemDepPrinter.cpp
index 00642347102a..305ae3e2a992 100644
--- a/llvm/lib/Analysis/MemDepPrinter.cpp
+++ b/llvm/lib/Analysis/MemDepPrinter.cpp
@@ -15,7 +15,6 @@
 #include "llvm/Analysis/Passes.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
-#include "llvm/IR/LLVMContext.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
diff --git a/llvm/lib/Analysis/MemDerefPrinter.cpp b/llvm/lib/Analysis/MemDerefPrinter.cpp
index 82617c7256a5..4dd5c76cc604 100644
--- a/llvm/lib/Analysis/MemDerefPrinter.cpp
+++ b/llvm/lib/Analysis/MemDerefPrinter.cpp
@@ -9,14 +9,11 @@
 #include "llvm/Analysis/MemDerefPrinter.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/Passes.h"
-#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
-#include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
-#include "llvm/Support/ErrorHandling.h"
 #include
"llvm/Support/raw_ostream.h" using namespace llvm; diff --git a/llvm/lib/Analysis/MemoryBuiltins.cpp b/llvm/lib/Analysis/MemoryBuiltins.cpp index 208f93aa1ac6..91501b04448e 100644 --- a/llvm/lib/Analysis/MemoryBuiltins.cpp +++ b/llvm/lib/Analysis/MemoryBuiltins.cpp @@ -17,7 +17,7 @@ #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/Utils/Local.h" @@ -43,6 +43,8 @@ #include <cassert> #include <cstdint> #include <iterator> +#include <numeric> +#include <type_traits> #include <utility> using namespace llvm; @@ -62,6 +64,42 @@ enum AllocType : uint8_t { AnyAlloc = AllocLike | ReallocLike }; +enum class MallocFamily { + Malloc, + CPPNew, // new(unsigned int) + CPPNewAligned, // new(unsigned int, align_val_t) + CPPNewArray, // new[](unsigned int) + CPPNewArrayAligned, // new[](unsigned long, align_val_t) + MSVCNew, // new(unsigned int) + MSVCArrayNew, // new[](unsigned int) + VecMalloc, + KmpcAllocShared, +}; + +StringRef mangledNameForMallocFamily(const MallocFamily &Family) { + switch (Family) { + case MallocFamily::Malloc: + return "malloc"; + case MallocFamily::CPPNew: + return "_Znwm"; + case MallocFamily::CPPNewAligned: + return "_ZnwmSt11align_val_t"; + case MallocFamily::CPPNewArray: + return "_Znam"; + case MallocFamily::CPPNewArrayAligned: + return "_ZnamSt11align_val_t"; + case MallocFamily::MSVCNew: + return "??2@YAPAXI@Z"; + case MallocFamily::MSVCArrayNew: + return "??_U@YAPAXI@Z"; + case MallocFamily::VecMalloc: + return "vec_malloc"; + case MallocFamily::KmpcAllocShared: + return "__kmpc_alloc_shared"; + } + llvm_unreachable("missing an alloc family"); +} + struct AllocFnsTy { AllocType AllocTy; unsigned NumParams; @@ -69,50 +107,55 @@ struct AllocFnsTy { int FstParam, SndParam; // Alignment parameter for aligned_alloc and aligned new int AlignParam; + // Name of default allocator function to group malloc/free calls by family + MallocFamily Family; }; +// clang-format off // FIXME: certain users need more information. E.g., SimplifyLibCalls needs to // know which functions are nounwind, noalias, nocapture parameters, etc. 
 static const std::pair<LibFunc, AllocFnsTy> AllocationFnData[] = {
-  {LibFunc_malloc, {MallocLike, 1, 0, -1, -1}},
-  {LibFunc_vec_malloc, {MallocLike, 1, 0, -1, -1}},
-  {LibFunc_valloc, {MallocLike, 1, 0, -1, -1}},
-  {LibFunc_Znwj, {OpNewLike, 1, 0, -1, -1}}, // new(unsigned int)
-  {LibFunc_ZnwjRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1}}, // new(unsigned int, nothrow)
-  {LibFunc_ZnwjSt11align_val_t, {OpNewLike, 2, 0, -1, 1}}, // new(unsigned int, align_val_t)
-  {LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1}}, // new(unsigned int, align_val_t, nothrow)
-  {LibFunc_Znwm, {OpNewLike, 1, 0, -1, -1}}, // new(unsigned long)
-  {LibFunc_ZnwmRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1}}, // new(unsigned long, nothrow)
-  {LibFunc_ZnwmSt11align_val_t, {OpNewLike, 2, 0, -1, 1}}, // new(unsigned long, align_val_t)
-  {LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1}}, // new(unsigned long, align_val_t, nothrow)
-  {LibFunc_Znaj, {OpNewLike, 1, 0, -1, -1}}, // new[](unsigned int)
-  {LibFunc_ZnajRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1}}, // new[](unsigned int, nothrow)
-  {LibFunc_ZnajSt11align_val_t, {OpNewLike, 2, 0, -1, 1}}, // new[](unsigned int, align_val_t)
-  {LibFunc_ZnajSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1}}, // new[](unsigned int, align_val_t, nothrow)
-  {LibFunc_Znam, {OpNewLike, 1, 0, -1, -1}}, // new[](unsigned long)
-  {LibFunc_ZnamRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1}}, // new[](unsigned long, nothrow)
-  {LibFunc_ZnamSt11align_val_t, {OpNewLike, 2, 0, -1, 1}}, // new[](unsigned long, align_val_t)
-  {LibFunc_ZnamSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1}}, // new[](unsigned long, align_val_t, nothrow)
-  {LibFunc_msvc_new_int, {OpNewLike, 1, 0, -1, -1}}, // new(unsigned int)
-  {LibFunc_msvc_new_int_nothrow, {MallocLike, 2, 0, -1, -1}}, // new(unsigned int, nothrow)
-  {LibFunc_msvc_new_longlong, {OpNewLike, 1, 0, -1, -1}}, // new(unsigned long long)
-  {LibFunc_msvc_new_longlong_nothrow, {MallocLike, 2, 0, -1, -1}}, // new(unsigned long long, nothrow)
-  {LibFunc_msvc_new_array_int, {OpNewLike, 1, 0, -1, -1}}, // new[](unsigned int)
-  {LibFunc_msvc_new_array_int_nothrow, {MallocLike, 2, 0, -1, -1}}, // new[](unsigned int, nothrow)
-  {LibFunc_msvc_new_array_longlong, {OpNewLike, 1, 0, -1, -1}}, // new[](unsigned long long)
-  {LibFunc_msvc_new_array_longlong_nothrow, {MallocLike, 2, 0, -1, -1}}, // new[](unsigned long long, nothrow)
-  {LibFunc_aligned_alloc, {AlignedAllocLike, 2, 1, -1, 0}},
-  {LibFunc_memalign, {AlignedAllocLike, 2, 1, -1, 0}},
-  {LibFunc_calloc, {CallocLike, 2, 0, 1, -1}},
-  {LibFunc_vec_calloc, {CallocLike, 2, 0, 1, -1}},
-  {LibFunc_realloc, {ReallocLike, 2, 1, -1, -1}},
-  {LibFunc_vec_realloc, {ReallocLike, 2, 1, -1, -1}},
-  {LibFunc_reallocf, {ReallocLike, 2, 1, -1, -1}},
-  {LibFunc_strdup, {StrDupLike, 1, -1, -1, -1}},
-  {LibFunc_strndup, {StrDupLike, 2, 1, -1, -1}},
-  {LibFunc___kmpc_alloc_shared, {MallocLike, 1, 0, -1, -1}},
-  // TODO: Handle "int posix_memalign(void **, size_t, size_t)"
+  {LibFunc_malloc, {MallocLike, 1, 0, -1, -1, MallocFamily::Malloc}},
+  {LibFunc_vec_malloc, {MallocLike, 1, 0, -1, -1, MallocFamily::VecMalloc}},
+  {LibFunc_valloc, {MallocLike, 1, 0, -1, -1, MallocFamily::Malloc}},
+  {LibFunc_Znwj, {OpNewLike, 1, 0, -1, -1, MallocFamily::CPPNew}}, // new(unsigned int)
+  {LibFunc_ZnwjRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1, MallocFamily::CPPNew}}, // new(unsigned int, nothrow)
+  {LibFunc_ZnwjSt11align_val_t, {OpNewLike, 2, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned int, align_val_t)
+  {LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned int, align_val_t, nothrow)
+  {LibFunc_Znwm, {OpNewLike, 1, 0, -1, -1, MallocFamily::CPPNew}}, // new(unsigned long)
+  {LibFunc_ZnwmRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1, MallocFamily::CPPNew}}, // new(unsigned long, nothrow)
+  {LibFunc_ZnwmSt11align_val_t, {OpNewLike, 2, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned long, align_val_t)
+  {LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned long, align_val_t, nothrow)
+  {LibFunc_Znaj, {OpNewLike, 1, 0, -1, -1, MallocFamily::CPPNewArray}}, // new[](unsigned int)
+  {LibFunc_ZnajRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1, MallocFamily::CPPNewArray}}, // new[](unsigned int, nothrow)
+  {LibFunc_ZnajSt11align_val_t, {OpNewLike, 2, 0, -1, 1, MallocFamily::CPPNewArrayAligned}}, // new[](unsigned int, align_val_t)
+  {LibFunc_ZnajSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1, MallocFamily::CPPNewArrayAligned}}, // new[](unsigned int, align_val_t, nothrow)
+  {LibFunc_Znam, {OpNewLike, 1, 0, -1, -1, MallocFamily::CPPNewArray}}, // new[](unsigned long)
+  {LibFunc_ZnamRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1, MallocFamily::CPPNewArray}}, // new[](unsigned long, nothrow)
+  {LibFunc_ZnamSt11align_val_t, {OpNewLike, 2, 0, -1, 1, MallocFamily::CPPNewArrayAligned}}, // new[](unsigned long, align_val_t)
+  {LibFunc_ZnamSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1, MallocFamily::CPPNewArrayAligned}}, // new[](unsigned long, align_val_t, nothrow)
+  {LibFunc_msvc_new_int, {OpNewLike, 1, 0, -1, -1, MallocFamily::MSVCNew}}, // new(unsigned int)
+  {LibFunc_msvc_new_int_nothrow, {MallocLike, 2, 0, -1, -1, MallocFamily::MSVCNew}}, // new(unsigned int, nothrow)
+  {LibFunc_msvc_new_longlong, {OpNewLike, 1, 0, -1, -1, MallocFamily::MSVCNew}}, // new(unsigned long long)
+  {LibFunc_msvc_new_longlong_nothrow, {MallocLike, 2, 0, -1, -1, MallocFamily::MSVCNew}}, // new(unsigned long long, nothrow)
+  {LibFunc_msvc_new_array_int, {OpNewLike, 1, 0, -1, -1, MallocFamily::MSVCArrayNew}}, // new[](unsigned int)
+  {LibFunc_msvc_new_array_int_nothrow, {MallocLike, 2, 0, -1, -1, MallocFamily::MSVCArrayNew}}, // new[](unsigned int, nothrow)
+  {LibFunc_msvc_new_array_longlong, {OpNewLike, 1, 0, -1, -1, MallocFamily::MSVCArrayNew}}, // new[](unsigned long long)
+  {LibFunc_msvc_new_array_longlong_nothrow, {MallocLike, 2, 0, -1, -1, MallocFamily::MSVCArrayNew}}, // new[](unsigned long long, nothrow)
+  {LibFunc_aligned_alloc, {AlignedAllocLike, 2, 1, -1, 0, MallocFamily::Malloc}},
+  {LibFunc_memalign, {AlignedAllocLike, 2, 1, -1, 0, MallocFamily::Malloc}},
+  {LibFunc_calloc, {CallocLike, 2, 0, 1, -1, MallocFamily::Malloc}},
+  {LibFunc_vec_calloc, {CallocLike, 2, 0, 1, -1, MallocFamily::VecMalloc}},
+  {LibFunc_realloc, {ReallocLike, 2, 1, -1, -1, MallocFamily::Malloc}},
+  {LibFunc_vec_realloc, {ReallocLike, 2, 1, -1, -1, MallocFamily::VecMalloc}},
+  {LibFunc_reallocf, {ReallocLike, 2, 1, -1, -1, MallocFamily::Malloc}},
+  {LibFunc_strdup, {StrDupLike, 1, -1, -1, -1, MallocFamily::Malloc}},
+  {LibFunc_dunder_strdup, {StrDupLike, 1, -1, -1, -1, MallocFamily::Malloc}},
+  {LibFunc_strndup, {StrDupLike, 2, 1, -1, -1, MallocFamily::Malloc}},
+  {LibFunc_dunder_strndup, {StrDupLike, 2, 1, -1, -1, MallocFamily::Malloc}},
+  {LibFunc___kmpc_alloc_shared, {MallocLike, 1, 0, -1, -1, MallocFamily::KmpcAllocShared}},
 };
+// clang-format on

 static const Function *getCalledFunction(const Value *V, bool &IsNoBuiltin) {
@@ -217,7 +260,7 @@ static Optional<AllocFnsTy> getAllocationSize(const Value *V,
     Result.AllocTy = MallocLike;
     Result.NumParams = Callee->getNumOperands();
     Result.FstParam = Args.first;
-    Result.SndParam = Args.second.getValueOr(-1);
+    Result.SndParam = Args.second.value_or(-1);
     // Allocsize has no way to specify an alignment argument
     Result.AlignParam = -1;
     return Result;
@@ -227,54 +270,53 @@
 /// allocates or reallocates memory (either malloc, calloc, realloc, or strdup
 /// like).
 bool llvm::isAllocationFn(const Value *V, const TargetLibraryInfo *TLI) {
-  return getAllocationData(V, AnyAlloc, TLI).hasValue();
+  return getAllocationData(V, AnyAlloc, TLI).has_value();
 }

 bool llvm::isAllocationFn(
     const Value *V, function_ref<const TargetLibraryInfo &(Function &)> GetTLI) {
-  return getAllocationData(V, AnyAlloc, GetTLI).hasValue();
+  return getAllocationData(V, AnyAlloc, GetTLI).has_value();
 }

 /// Tests if a value is a call or invoke to a library function that
 /// allocates uninitialized memory (such as malloc).
 static bool isMallocLikeFn(const Value *V, const TargetLibraryInfo *TLI) {
-  return getAllocationData(V, MallocOrOpNewLike, TLI).hasValue();
+  return getAllocationData(V, MallocOrOpNewLike, TLI).has_value();
 }

 /// Tests if a value is a call or invoke to a library function that
 /// allocates uninitialized memory with alignment (such as aligned_alloc).
 static bool isAlignedAllocLikeFn(const Value *V, const TargetLibraryInfo *TLI) {
-  return getAllocationData(V, AlignedAllocLike, TLI)
-      .hasValue();
+  return getAllocationData(V, AlignedAllocLike, TLI).has_value();
 }

 /// Tests if a value is a call or invoke to a library function that
 /// allocates zero-filled memory (such as calloc).
 static bool isCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI) {
-  return getAllocationData(V, CallocLike, TLI).hasValue();
+  return getAllocationData(V, CallocLike, TLI).has_value();
 }

 /// Tests if a value is a call or invoke to a library function that
 /// allocates memory similar to malloc or calloc.
 bool llvm::isMallocOrCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI) {
-  return getAllocationData(V, MallocOrCallocLike, TLI).hasValue();
+  return getAllocationData(V, MallocOrCallocLike, TLI).has_value();
 }

 /// Tests if a value is a call or invoke to a library function that
 /// allocates memory (either malloc, calloc, or strdup like).
 bool llvm::isAllocLikeFn(const Value *V, const TargetLibraryInfo *TLI) {
-  return getAllocationData(V, AllocLike, TLI).hasValue();
+  return getAllocationData(V, AllocLike, TLI).has_value();
 }

 /// Tests if a value is a call or invoke to a library function that
 /// reallocates memory (e.g., realloc).
 bool llvm::isReallocLikeFn(const Value *V, const TargetLibraryInfo *TLI) {
-  return getAllocationData(V, ReallocLike, TLI).hasValue();
+  return getAllocationData(V, ReallocLike, TLI).has_value();
 }
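[Editorial note: a recurring pattern in this commit is the mechanical migration of llvm::Optional call sites to the std::optional-style spellings, with hasValue/getValueOr becoming has_value/value_or. A minimal before/after sketch, using an illustrative variable rather than anything from this file:

    // Hedged sketch of the Optional<T> spelling migration seen throughout
    // this diff; MaybeParam and getParam() are hypothetical.
    llvm::Optional<int> MaybeParam = getParam();
    bool Known = MaybeParam.has_value(); // was: MaybeParam.hasValue()
    int Param = MaybeParam.value_or(-1); // was: MaybeParam.getValueOr(-1)

The same rename shows up in nearly every file touched below.]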
 /// Tests if a function is a call or invoke to a library function that
 /// reallocates memory (e.g., realloc).
 bool llvm::isReallocLikeFn(const Function *F, const TargetLibraryInfo *TLI) {
-  return getAllocationDataForFunction(F, ReallocLike, TLI).hasValue();
+  return getAllocationDataForFunction(F, ReallocLike, TLI).has_value();
 }

 bool llvm::isAllocRemovable(const CallBase *CB, const TargetLibraryInfo *TLI) {
@@ -291,13 +333,11 @@ bool llvm::isAllocRemovable(const CallBase *CB, const TargetLibraryInfo *TLI)

 Value *llvm::getAllocAlignment(const CallBase *V,
                                const TargetLibraryInfo *TLI) {
-  assert(isAllocationFn(V, TLI));
-
   const Optional<AllocFnsTy> FnData = getAllocationData(V, AnyAlloc, TLI);
-  if (!FnData.hasValue() || FnData->AlignParam < 0) {
-    return nullptr;
+  if (FnData && FnData->AlignParam >= 0) {
+    return V->getOperand(FnData->AlignParam);
   }
-  return V->getOperand(FnData->AlignParam);
+  return V->getArgOperandWithAttribute(Attribute::AllocAlign);
 }

 /// When we're compiling N-bit code, and the user uses parameters that are
@@ -344,7 +384,7 @@ llvm::getAllocSize(const CallBase *CB,
     if (!Arg)
       return None;

-    APInt MaxSize = Arg->getValue().zextOrSelf(IntTyBits);
+    APInt MaxSize = Arg->getValue().zext(IntTyBits);
     if (Size.ugt(MaxSize))
       Size = MaxSize + 1;
   }
@@ -379,10 +419,12 @@ llvm::getAllocSize(const CallBase *CB,
   return Size;
 }

-Constant *llvm::getInitialValueOfAllocation(const CallBase *Alloc,
+Constant *llvm::getInitialValueOfAllocation(const Value *V,
                                             const TargetLibraryInfo *TLI,
                                             Type *Ty) {
-  assert(isAllocationFn(Alloc, TLI));
+  auto *Alloc = dyn_cast<CallBase>(V);
+  if (!Alloc)
+    return nullptr;

   // malloc and aligned_alloc are uninitialized (undef)
   if (isMallocLikeFn(Alloc, TLI) || isAlignedAllocLikeFn(Alloc, TLI))
@@ -395,43 +437,81 @@ Constant *llvm::getInitialValueOfAllocation(const CallBase *Alloc,
   return nullptr;
 }

+struct FreeFnsTy {
+  unsigned NumParams;
+  // Name of default allocator function to group malloc/free calls by family
+  MallocFamily Family;
+};
+
+// clang-format off
+static const std::pair<LibFunc, FreeFnsTy> FreeFnData[] = {
+  {LibFunc_free, {1, MallocFamily::Malloc}},
+  {LibFunc_vec_free, {1, MallocFamily::VecMalloc}},
+  {LibFunc_ZdlPv, {1, MallocFamily::CPPNew}}, // operator delete(void*)
+  {LibFunc_ZdaPv, {1, MallocFamily::CPPNewArray}}, // operator delete[](void*)
+  {LibFunc_msvc_delete_ptr32, {1, MallocFamily::MSVCNew}}, // operator delete(void*)
+  {LibFunc_msvc_delete_ptr64, {1, MallocFamily::MSVCNew}}, // operator delete(void*)
+  {LibFunc_msvc_delete_array_ptr32, {1, MallocFamily::MSVCArrayNew}}, // operator delete[](void*)
+  {LibFunc_msvc_delete_array_ptr64, {1, MallocFamily::MSVCArrayNew}}, // operator delete[](void*)
+  {LibFunc_ZdlPvj, {2, MallocFamily::CPPNew}}, // delete(void*, uint)
+  {LibFunc_ZdlPvm, {2, MallocFamily::CPPNew}}, // delete(void*, ulong)
+  {LibFunc_ZdlPvRKSt9nothrow_t, {2, MallocFamily::CPPNew}}, // delete(void*, nothrow)
+  {LibFunc_ZdlPvSt11align_val_t, {2, MallocFamily::CPPNewAligned}}, // delete(void*, align_val_t)
+  {LibFunc_ZdaPvj, {2, MallocFamily::CPPNewArray}}, // delete[](void*, uint)
+  {LibFunc_ZdaPvm, {2, MallocFamily::CPPNewArray}}, // delete[](void*, ulong)
+  {LibFunc_ZdaPvRKSt9nothrow_t, {2, MallocFamily::CPPNewArray}}, // delete[](void*, nothrow)
+  {LibFunc_ZdaPvSt11align_val_t, {2, MallocFamily::CPPNewArrayAligned}}, // delete[](void*, align_val_t)
+  {LibFunc_msvc_delete_ptr32_int, {2, MallocFamily::MSVCNew}}, // delete(void*, uint)
+  {LibFunc_msvc_delete_ptr64_longlong, {2, MallocFamily::MSVCNew}}, // delete(void*, ulonglong)
+  {LibFunc_msvc_delete_ptr32_nothrow, {2, MallocFamily::MSVCNew}}, // delete(void*, nothrow)
+  {LibFunc_msvc_delete_ptr64_nothrow, {2, MallocFamily::MSVCNew}}, // delete(void*, nothrow)
+  {LibFunc_msvc_delete_array_ptr32_int, {2, MallocFamily::MSVCArrayNew}}, // delete[](void*, uint)
+  {LibFunc_msvc_delete_array_ptr64_longlong, {2, MallocFamily::MSVCArrayNew}}, // delete[](void*, ulonglong)
+  {LibFunc_msvc_delete_array_ptr32_nothrow, {2, MallocFamily::MSVCArrayNew}}, // delete[](void*, nothrow)
+  {LibFunc_msvc_delete_array_ptr64_nothrow, {2, MallocFamily::MSVCArrayNew}}, // delete[](void*, nothrow)
+  {LibFunc___kmpc_free_shared, {2, MallocFamily::KmpcAllocShared}}, // OpenMP Offloading RTL free
+  {LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t, {3, MallocFamily::CPPNewAligned}}, // delete(void*, align_val_t, nothrow)
+  {LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t, {3, MallocFamily::CPPNewArrayAligned}}, // delete[](void*, align_val_t, nothrow)
+  {LibFunc_ZdlPvjSt11align_val_t, {3, MallocFamily::CPPNewAligned}}, // delete(void*, unsigned int, align_val_t)
+  {LibFunc_ZdlPvmSt11align_val_t, {3, MallocFamily::CPPNewAligned}}, // delete(void*, unsigned long, align_val_t)
+  {LibFunc_ZdaPvjSt11align_val_t, {3, MallocFamily::CPPNewArrayAligned}}, // delete[](void*, unsigned int, align_val_t)
+  {LibFunc_ZdaPvmSt11align_val_t, {3, MallocFamily::CPPNewArrayAligned}}, // delete[](void*, unsigned long, align_val_t)
+};
+// clang-format on
+
+Optional<FreeFnsTy> getFreeFunctionDataForFunction(const Function *Callee,
+                                                   const LibFunc TLIFn) {
+  const auto *Iter =
+      find_if(FreeFnData, [TLIFn](const std::pair<LibFunc, FreeFnsTy> &P) {
+        return P.first == TLIFn;
+      });
+  if (Iter == std::end(FreeFnData))
+    return None;
+  return Iter->second;
+}
+
+Optional<StringRef> llvm::getAllocationFamily(const Value *I,
+                                              const TargetLibraryInfo *TLI) {
+  bool IsNoBuiltin;
+  const Function *Callee = getCalledFunction(I, IsNoBuiltin);
+  if (Callee == nullptr || IsNoBuiltin)
+    return None;
+  LibFunc TLIFn;
+  if (!TLI || !TLI->getLibFunc(*Callee, TLIFn) || !TLI->has(TLIFn))
+    return None;
+  const auto AllocData = getAllocationDataForFunction(Callee, AnyAlloc, TLI);
+  if (AllocData)
+    return mangledNameForMallocFamily(AllocData.getValue().Family);
+  const auto FreeData = getFreeFunctionDataForFunction(Callee, TLIFn);
+  if (FreeData)
+    return mangledNameForMallocFamily(FreeData.getValue().Family);
+  return None;
+}
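[Editorial note: the new getAllocationFamily entry point gives callers one string key per malloc/free family, so a pass can check that an allocation and a deallocation actually belong together before pairing them. A minimal sketch of the intended use, assuming two call sites CB1/CB2 and a TargetLibraryInfo TLI in scope (all illustrative):

    // Only reason about the pair if both calls map to the same family,
    // e.g. malloc/free match, but malloc paired with operator delete does not.
    Optional<StringRef> AllocFamily = getAllocationFamily(CB1, &TLI);
    Optional<StringRef> FreeFamily = getAllocationFamily(CB2, &TLI);
    bool SameFamily = AllocFamily && FreeFamily && *AllocFamily == *FreeFamily;

Note that the same table also feeds isLibFreeFunction below, which previously hard-coded the list.]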
 /// isLibFreeFunction - Returns true if the function is a builtin free()
 bool llvm::isLibFreeFunction(const Function *F, const LibFunc TLIFn) {
-  unsigned ExpectedNumParams;
-  if (TLIFn == LibFunc_free ||
-      TLIFn == LibFunc_ZdlPv || // operator delete(void*)
-      TLIFn == LibFunc_ZdaPv || // operator delete[](void*)
-      TLIFn == LibFunc_msvc_delete_ptr32 || // operator delete(void*)
-      TLIFn == LibFunc_msvc_delete_ptr64 || // operator delete(void*)
-      TLIFn == LibFunc_msvc_delete_array_ptr32 || // operator delete[](void*)
-      TLIFn == LibFunc_msvc_delete_array_ptr64)   // operator delete[](void*)
-    ExpectedNumParams = 1;
-  else if (TLIFn == LibFunc_ZdlPvj || // delete(void*, uint)
-           TLIFn == LibFunc_ZdlPvm || // delete(void*, ulong)
-           TLIFn == LibFunc_ZdlPvRKSt9nothrow_t || // delete(void*, nothrow)
-           TLIFn == LibFunc_ZdlPvSt11align_val_t || // delete(void*, align_val_t)
-           TLIFn == LibFunc_ZdaPvj || // delete[](void*, uint)
-           TLIFn == LibFunc_ZdaPvm || // delete[](void*, ulong)
-           TLIFn == LibFunc_ZdaPvRKSt9nothrow_t || // delete[](void*, nothrow)
-           TLIFn == LibFunc_ZdaPvSt11align_val_t || // delete[](void*, align_val_t)
-           TLIFn == LibFunc_msvc_delete_ptr32_int || // delete(void*, uint)
-           TLIFn == LibFunc_msvc_delete_ptr64_longlong || // delete(void*, ulonglong)
-           TLIFn == LibFunc_msvc_delete_ptr32_nothrow || // delete(void*, nothrow)
-           TLIFn == LibFunc_msvc_delete_ptr64_nothrow || // delete(void*, nothrow)
-           TLIFn == LibFunc_msvc_delete_array_ptr32_int || // delete[](void*, uint)
-           TLIFn == LibFunc_msvc_delete_array_ptr64_longlong || // delete[](void*, ulonglong)
-           TLIFn == LibFunc_msvc_delete_array_ptr32_nothrow || // delete[](void*, nothrow)
-           TLIFn == LibFunc_msvc_delete_array_ptr64_nothrow || // delete[](void*, nothrow)
-           TLIFn == LibFunc___kmpc_free_shared) // OpenMP Offloading RTL free
-    ExpectedNumParams = 2;
-  else if (TLIFn == LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t || // delete(void*, align_val_t, nothrow)
-           TLIFn == LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t || // delete[](void*, align_val_t, nothrow)
-           TLIFn == LibFunc_ZdlPvjSt11align_val_t || // delete(void*, unsigned long, align_val_t)
-           TLIFn == LibFunc_ZdlPvmSt11align_val_t || // delete(void*, unsigned long, align_val_t)
-           TLIFn == LibFunc_ZdaPvjSt11align_val_t || // delete[](void*, unsigned int, align_val_t)
-           TLIFn == LibFunc_ZdaPvmSt11align_val_t)  // delete[](void*, unsigned long, align_val_t)
-    ExpectedNumParams = 3;
-  else
+  Optional<FreeFnsTy> FnData = getFreeFunctionDataForFunction(F, TLIFn);
+  if (!FnData)
     return false;

   // Check free prototype.
@@ -440,7 +520,7 @@ bool llvm::isLibFreeFunction(const Function *F, const LibFunc TLIFn) {
   FunctionType *FTy = F->getFunctionType();
   if (!FTy->getReturnType()->isVoidTy())
     return false;
-  if (FTy->getNumParams() != ExpectedNumParams)
+  if (FTy->getNumParams() != FnData->NumParams)
     return false;
   if (FTy->getParamType(0) != Type::getInt8PtrTy(F->getContext()))
     return false;
@@ -491,11 +571,21 @@ Value *llvm::lowerObjectSizeCall(IntrinsicInst *ObjectSize,
                                  const DataLayout &DL,
                                  const TargetLibraryInfo *TLI,
                                  bool MustSucceed) {
+  return lowerObjectSizeCall(ObjectSize, DL, TLI, /*AAResults=*/nullptr,
+                             MustSucceed);
+}
+
+Value *llvm::lowerObjectSizeCall(IntrinsicInst *ObjectSize,
+                                 const DataLayout &DL,
+                                 const TargetLibraryInfo *TLI, AAResults *AA,
+                                 bool MustSucceed) {
   assert(ObjectSize->getIntrinsicID() == Intrinsic::objectsize &&
          "ObjectSize must be a call to llvm.objectsize!");

   bool MaxVal = cast<ConstantInt>(ObjectSize->getArgOperand(1))->isZero();
   ObjectSizeOpts EvalOptions;
+  EvalOptions.AA = AA;
+
   // Unless we have to fold this to something, try to be as accurate as
   // possible.
   if (MustSucceed)
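[Editorial note: the extra lowerObjectSizeCall overload threads an AAResults pointer into ObjectSizeOpts; the original signature now simply forwards a null AA and keeps its old behavior. A hedged call-site sketch, assuming the surrounding pass provides F, TLI and AA:

    // Fold an llvm.objectsize call, opting in to the alias-analysis-assisted
    // load tracking added below; passing nullptr for AA keeps the old behavior.
    const DataLayout &DL = F.getParent()->getDataLayout();
    if (Value *Folded = lowerObjectSizeCall(ObjSize, DL, &TLI, &AA,
                                            /*MustSucceed=*/false))
      ObjSize->replaceAllUsesWith(Folded);
]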
@@ -559,7 +649,7 @@ STATISTIC(ObjectVisitorLoad,

 APInt ObjectSizeOffsetVisitor::align(APInt Size, MaybeAlign Alignment) {
   if (Options.RoundToAlign && Alignment)
-    return APInt(IntTyBits, alignTo(Size.getZExtValue(), Alignment));
+    return APInt(IntTyBits, alignTo(Size.getZExtValue(), *Alignment));
   return Size;
 }

@@ -573,18 +663,48 @@ ObjectSizeOffsetVisitor::ObjectSizeOffsetVisitor(const DataLayout &DL,
 }

 SizeOffsetType ObjectSizeOffsetVisitor::compute(Value *V) {
+  unsigned InitialIntTyBits = DL.getIndexTypeSizeInBits(V->getType());
+
+  // Stripping pointer casts can strip address space casts which can change
+  // the index type size. The invariant is that we use the value type to
+  // determine the index type size and if we stripped address space casts we
+  // have to readjust the APInt as we pass it upwards in order for the APInt
+  // to match the type the caller passed in.
+  APInt Offset(InitialIntTyBits, 0);
+  V = V->stripAndAccumulateConstantOffsets(
+      DL, Offset, /* AllowNonInbounds */ true, /* AllowInvariantGroup */ true);
+
+  // Later we use the index type size and zero but it will match the type of
+  // the value that is passed to computeImpl.
   IntTyBits = DL.getIndexTypeSizeInBits(V->getType());
   Zero = APInt::getZero(IntTyBits);

-  V = V->stripPointerCasts();
+  bool IndexTypeSizeChanged = InitialIntTyBits != IntTyBits;
+  if (!IndexTypeSizeChanged && Offset.isZero())
+    return computeImpl(V);
+
+  // We stripped an address space cast that changed the index type size or we
+  // accumulated some constant offset (or both). Readjust the bit width to
+  // match the argument index type size and apply the offset, as required.
+  SizeOffsetType SOT = computeImpl(V);
+  if (IndexTypeSizeChanged) {
+    if (knownSize(SOT) && !::CheckedZextOrTrunc(SOT.first, InitialIntTyBits))
+      SOT.first = APInt();
+    if (knownOffset(SOT) && !::CheckedZextOrTrunc(SOT.second, InitialIntTyBits))
+      SOT.second = APInt();
+  }
+  // If the computed offset is "unknown" we cannot add the stripped offset.
+  return {SOT.first,
+          SOT.second.getBitWidth() > 1 ? SOT.second + Offset : SOT.second};
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::computeImpl(Value *V) {
   if (Instruction *I = dyn_cast<Instruction>(V)) {
     // If we have already seen this instruction, bail out. Cycles can happen
     // in unreachable code after constant propagation.
     if (!SeenInsts.insert(I).second)
       return unknown();

-    if (GEPOperator *GEP = dyn_cast<GEPOperator>(V))
-      return visitGEPOperator(*GEP);
     return visit(*I);
   }
   if (Argument *A = dyn_cast<Argument>(V))
@@ -597,12 +717,6 @@ SizeOffsetType ObjectSizeOffsetVisitor::compute(Value *V) {
     return visitGlobalVariable(*GV);
   if (UndefValue *UV = dyn_cast<UndefValue>(V))
     return visitUndefValue(*UV);
-  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
-    if (CE->getOpcode() == Instruction::IntToPtr)
-      return unknown(); // clueless
-    if (CE->getOpcode() == Instruction::GetElementPtr)
-      return visitGEPOperator(cast<GEPOperator>(*CE));
-  }

   LLVM_DEBUG(dbgs() << "ObjectSizeOffsetVisitor::compute() unhandled value: "
                     << *V << '\n');
@@ -617,10 +731,10 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitAllocaInst(AllocaInst &I) {
   if (!I.getAllocatedType()->isSized())
     return unknown();

-  if (isa<ScalableVectorType>(I.getAllocatedType()))
+  TypeSize ElemSize = DL.getTypeAllocSize(I.getAllocatedType());
+  if (ElemSize.isScalable() && Options.EvalMode != ObjectSizeOpts::Mode::Min)
     return unknown();
-
-  APInt Size(IntTyBits, DL.getTypeAllocSize(I.getAllocatedType()));
+  APInt Size(IntTyBits, ElemSize.getKnownMinSize());
   if (!I.isArrayAllocation())
     return std::make_pair(align(Size, I.getAlign()), Zero);

@@ -682,15 +796,6 @@ ObjectSizeOffsetVisitor::visitExtractValueInst(ExtractValueInst&) {
   return unknown();
 }

-SizeOffsetType ObjectSizeOffsetVisitor::visitGEPOperator(GEPOperator &GEP) {
-  SizeOffsetType PtrData = compute(GEP.getPointerOperand());
-  APInt Offset(DL.getIndexTypeSizeInBits(GEP.getPointerOperand()->getType()), 0);
-  if (!bothKnown(PtrData) || !GEP.accumulateConstantOffset(DL, Offset))
-    return unknown();
-
-  return std::make_pair(PtrData.first, PtrData.second + Offset);
-}
-
 SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalAlias(GlobalAlias &GA) {
   if (GA.isInterposable())
     return unknown();
@@ -710,42 +815,161 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitIntToPtrInst(IntToPtrInst&) {
   return unknown();
 }

-SizeOffsetType ObjectSizeOffsetVisitor::visitLoadInst(LoadInst&) {
-  ++ObjectVisitorLoad;
-  return unknown();
-}
+SizeOffsetType ObjectSizeOffsetVisitor::findLoadSizeOffset(
+    LoadInst &Load, BasicBlock &BB, BasicBlock::iterator From,
+    SmallDenseMap<BasicBlock *, SizeOffsetType, 8> &VisitedBlocks,
+    unsigned &ScannedInstCount) {
+  constexpr unsigned MaxInstsToScan = 128;
+
+  auto Where = VisitedBlocks.find(&BB);
+  if (Where != VisitedBlocks.end())
+    return Where->second;
+
+  auto Unknown = [this, &BB, &VisitedBlocks]() {
+    return VisitedBlocks[&BB] = unknown();
+  };
+  auto Known = [&BB, &VisitedBlocks](SizeOffsetType SO) {
+    return VisitedBlocks[&BB] = SO;
+  };
+
+  do {
+    Instruction &I = *From;
+
+    if (I.isDebugOrPseudoInst())
+      continue;
+
+    if (++ScannedInstCount > MaxInstsToScan)
+      return Unknown();
+
+    if (!I.mayWriteToMemory())
+      continue;
+
+    if (auto *SI = dyn_cast<StoreInst>(&I)) {
+      AliasResult AR =
+          Options.AA->alias(SI->getPointerOperand(), Load.getPointerOperand());
+      switch ((AliasResult::Kind)AR) {
+      case AliasResult::NoAlias:
+        continue;
+      case AliasResult::MustAlias:
+        if (SI->getValueOperand()->getType()->isPointerTy())
+          return Known(compute(SI->getValueOperand()));
+        else
+          return Unknown(); // No handling of non-pointer values by `compute`.
+      default:
+        return Unknown();
+      }
+    }

-SizeOffsetType ObjectSizeOffsetVisitor::visitPHINode(PHINode&) {
-  // too complex to analyze statically.
-  return unknown();
+    if (auto *CB = dyn_cast<CallBase>(&I)) {
+      Function *Callee = CB->getCalledFunction();
+      // Bail out on indirect call.
+      if (!Callee)
+        return Unknown();
+
+      LibFunc TLIFn;
+      if (!TLI || !TLI->getLibFunc(*CB->getCalledFunction(), TLIFn) ||
+          !TLI->has(TLIFn))
+        return Unknown();
+
+      // TODO: There are probably more interesting cases to support here.
+      if (TLIFn != LibFunc_posix_memalign)
+        return Unknown();
+
+      AliasResult AR =
+          Options.AA->alias(CB->getOperand(0), Load.getPointerOperand());
+      switch ((AliasResult::Kind)AR) {
+      case AliasResult::NoAlias:
+        continue;
+      case AliasResult::MustAlias:
+        break;
+      default:
+        return Unknown();
+      }
+
+      // Is the error status of posix_memalign correctly checked? If not it
+      // would be incorrect to assume it succeeds and the load doesn't see the
+      // previous value.
+      Optional<bool> Checked = isImpliedByDomCondition(
+          ICmpInst::ICMP_EQ, CB, ConstantInt::get(CB->getType(), 0), &Load, DL);
+      if (!Checked || !*Checked)
+        return Unknown();
+
+      Value *Size = CB->getOperand(2);
+      auto *C = dyn_cast<ConstantInt>(Size);
+      if (!C)
+        return Unknown();
+
+      return Known({C->getValue(), APInt(C->getValue().getBitWidth(), 0)});
+    }
+
+    return Unknown();
+  } while (From-- != BB.begin());
+
+  SmallVector<SizeOffsetType> PredecessorSizeOffsets;
+  for (auto *PredBB : predecessors(&BB)) {
+    PredecessorSizeOffsets.push_back(findLoadSizeOffset(
+        Load, *PredBB, BasicBlock::iterator(PredBB->getTerminator()),
+        VisitedBlocks, ScannedInstCount));
+    if (!bothKnown(PredecessorSizeOffsets.back()))
+      return Unknown();
+  }
+
+  if (PredecessorSizeOffsets.empty())
+    return Unknown();
+
+  return Known(std::accumulate(PredecessorSizeOffsets.begin() + 1,
+                               PredecessorSizeOffsets.end(),
+                               PredecessorSizeOffsets.front(),
+                               [this](SizeOffsetType LHS, SizeOffsetType RHS) {
+                                 return combineSizeOffset(LHS, RHS);
+                               }));
 }

-SizeOffsetType ObjectSizeOffsetVisitor::visitSelectInst(SelectInst &I) {
-  SizeOffsetType TrueSide = compute(I.getTrueValue());
-  SizeOffsetType FalseSide = compute(I.getFalseValue());
-  if (bothKnown(TrueSide) && bothKnown(FalseSide)) {
-    if (TrueSide == FalseSide) {
-      return TrueSide;
-    }
+SizeOffsetType ObjectSizeOffsetVisitor::visitLoadInst(LoadInst &LI) {
+  if (!Options.AA) {
+    ++ObjectVisitorLoad;
+    return unknown();
+  }

-    APInt TrueResult = getSizeWithOverflow(TrueSide);
-    APInt FalseResult = getSizeWithOverflow(FalseSide);
+  SmallDenseMap<BasicBlock *, SizeOffsetType, 8> VisitedBlocks;
+  unsigned ScannedInstCount = 0;
+  SizeOffsetType SO =
+      findLoadSizeOffset(LI, *LI.getParent(), BasicBlock::iterator(LI),
+                         VisitedBlocks, ScannedInstCount);
+  if (!bothKnown(SO))
+    ++ObjectVisitorLoad;
+  return SO;
+}

-    if (TrueResult == FalseResult) {
-      return TrueSide;
-    }
-    if (Options.EvalMode == ObjectSizeOpts::Mode::Min) {
-      if (TrueResult.slt(FalseResult))
-        return TrueSide;
-      return FalseSide;
-    }
-    if (Options.EvalMode == ObjectSizeOpts::Mode::Max) {
-      if (TrueResult.sgt(FalseResult))
-        return TrueSide;
-      return FalseSide;
-    }
+SizeOffsetType ObjectSizeOffsetVisitor::combineSizeOffset(SizeOffsetType LHS,
+                                                          SizeOffsetType RHS) {
+  if (!bothKnown(LHS) || !bothKnown(RHS))
+    return unknown();
+
+  switch (Options.EvalMode) {
+  case ObjectSizeOpts::Mode::Min:
+    return (getSizeWithOverflow(LHS).slt(getSizeWithOverflow(RHS))) ? LHS : RHS;
+  case ObjectSizeOpts::Mode::Max:
+    return (getSizeWithOverflow(LHS).sgt(getSizeWithOverflow(RHS))) ? LHS : RHS;
+  case ObjectSizeOpts::Mode::Exact:
+    return (getSizeWithOverflow(LHS).eq(getSizeWithOverflow(RHS))) ? LHS
+                                                                   : unknown();
   }
-  return unknown();
+  llvm_unreachable("missing an eval mode");
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitPHINode(PHINode &PN) {
+  auto IncomingValues = PN.incoming_values();
+  return std::accumulate(IncomingValues.begin() + 1, IncomingValues.end(),
+                         compute(*IncomingValues.begin()),
+                         [this](SizeOffsetType LHS, Value *VRHS) {
+                           return combineSizeOffset(LHS, compute(VRHS));
+                         });
+}
+
+SizeOffsetType ObjectSizeOffsetVisitor::visitSelectInst(SelectInst &I) {
+  return combineSizeOffset(compute(I.getTrueValue()),
+                           compute(I.getFalseValue()));
 }
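[Editorial note: with Options.AA set, visitLoadInst no longer gives up immediately; findLoadSizeOffset walks backwards, and across predecessors, looking for a must-aliasing pointer store or a properly error-checked posix_memalign call. A sketch of enabling this directly on the visitor, assuming DL, TLI, Ctx and an AAResults AA are in scope:

    ObjectSizeOpts Opts;
    Opts.EvalMode = ObjectSizeOpts::Mode::Exact;
    Opts.AA = &AA; // opt in to the new load tracking
    ObjectSizeOffsetVisitor Visitor(DL, &TLI, Ctx, Opts);
    SizeOffsetType SO = Visitor.compute(Ptr); // Ptr: the pointer being sized
    // On success, SO.first is the object size and SO.second the offset.

The new combineSizeOffset helper is what lets PHIs and selects share the min/max/exact merging logic that was previously open-coded only for selects.]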
 SizeOffsetType ObjectSizeOffsetVisitor::visitUndefValue(UndefValue&) {
@@ -790,7 +1014,7 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute(Value *V) {

   // Erase any instructions we inserted as part of the traversal.
   for (Instruction *I : InsertedInstructions) {
-    I->replaceAllUsesWith(UndefValue::get(I->getType()));
+    I->replaceAllUsesWith(PoisonValue::get(I->getType()));
     I->eraseFromParent();
   }
 }
@@ -919,7 +1143,7 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitIntToPtrInst(IntToPtrInst&) {
   return unknown();
 }

-SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitLoadInst(LoadInst&) {
+SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitLoadInst(LoadInst &LI) {
   return unknown();
 }

@@ -937,10 +1161,10 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitPHINode(PHINode &PHI) {
     SizeOffsetEvalType EdgeData = compute_(PHI.getIncomingValue(i));

     if (!bothKnown(EdgeData)) {
-      OffsetPHI->replaceAllUsesWith(UndefValue::get(IntTy));
+      OffsetPHI->replaceAllUsesWith(PoisonValue::get(IntTy));
       OffsetPHI->eraseFromParent();
       InsertedInstructions.erase(OffsetPHI);
-      SizePHI->replaceAllUsesWith(UndefValue::get(IntTy));
+      SizePHI->replaceAllUsesWith(PoisonValue::get(IntTy));
       SizePHI->eraseFromParent();
       InsertedInstructions.erase(SizePHI);
       return unknown();
diff --git a/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp b/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp
index 36df462c7a66..690d575ef979 100644
--- a/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp
+++ b/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp
@@ -27,11 +27,7 @@
 #include "llvm/Analysis/PhiValues.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
-#include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstrTypes.h"
@@ -44,7 +40,6 @@
 #include "llvm/IR/PredIteratorCache.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
-#include "llvm/IR/User.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -53,10 +48,8 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <cassert>
-#include <cstdint>
 #include <iterator>
 #include <utility>

@@ -414,20 +407,17 @@ MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom(
     isInvariantLoad = true;
   }

-  // Return "true" if and only if the instruction I is either a non-simple
-  // load or a non-simple store.
-  auto isNonSimpleLoadOrStore = [](Instruction *I) -> bool {
+  // True for a volatile instruction.
+  // For a load/store, return true if its atomic ordering is stronger than AO;
+  // for any other instruction, return true if it can read or write memory.
+  auto isComplexForReordering = [](Instruction *I, AtomicOrdering AO) -> bool {
+    if (I->isVolatile())
+      return true;
     if (auto *LI = dyn_cast<LoadInst>(I))
-      return !LI->isSimple();
+      return isStrongerThan(LI->getOrdering(), AO);
     if (auto *SI = dyn_cast<StoreInst>(I))
-      return !SI->isSimple();
-    return false;
-  };
-
-  // Return "true" if I is not a load and not a store, but it does access
-  // memory.
-  auto isOtherMemAccess = [](Instruction *I) -> bool {
-    return !isa<LoadInst>(I) && !isa<StoreInst>(I) && I->mayReadOrWriteMemory();
+      return isStrongerThan(SI->getOrdering(), AO);
+    return I->mayReadOrWriteMemory();
   };

   // Walk backwards through the basic block, looking for dependencies.
@@ -500,8 +490,8 @@ MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom(
       // atomic.
      // FIXME: This is overly conservative.
       if (LI->isAtomic() && isStrongerThanUnordered(LI->getOrdering())) {
-        if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
-            isOtherMemAccess(QueryInst))
+        if (!QueryInst ||
+            isComplexForReordering(QueryInst, AtomicOrdering::NotAtomic))
           return MemDepResult::getClobber(LI);
         if (LI->getOrdering() != AtomicOrdering::Monotonic)
           return MemDepResult::getClobber(LI);
@@ -512,10 +502,10 @@ MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom(
       // If we found a pointer, check if it could be the same as our pointer.
       AliasResult R = BatchAA.alias(LoadLoc, MemLoc);

-      if (isLoad) {
-        if (R == AliasResult::NoAlias)
-          continue;
+      if (R == AliasResult::NoAlias)
+        continue;

+      if (isLoad) {
         // Must aliased loads are defs of each other.
         if (R == AliasResult::MustAlias)
           return MemDepResult::getDef(Inst);
@@ -532,10 +522,6 @@ MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom(
         continue;
       }

-      // Stores don't depend on other no-aliased accesses.
-      if (R == AliasResult::NoAlias)
-        continue;
-
       // Stores don't alias loads from read-only memory.
       if (BatchAA.pointsToConstantMemory(LoadLoc))
         continue;
@@ -549,20 +535,25 @@ MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom(
       // A Monotonic store is OK if the query inst is itself not atomic.
       // FIXME: This is overly conservative.
       if (!SI->isUnordered() && SI->isAtomic()) {
-        if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
-            isOtherMemAccess(QueryInst))
-          return MemDepResult::getClobber(SI);
-        if (SI->getOrdering() != AtomicOrdering::Monotonic)
+        if (!QueryInst ||
+            isComplexForReordering(QueryInst, AtomicOrdering::Unordered))
           return MemDepResult::getClobber(SI);
+        // Ok, if we are here the guard above guarantees that QueryInst is a
+        // non-atomic or unordered load/store. SI is atomic with monotonic or
+        // release semantics (seq_cst for a store is actually release
+        // semantics plus a total order over other seq_cst instructions; as
+        // soon as QueryInst is not seq_cst we can consider it as having
+        // simple release semantics).
+        // Monotonic and release semantics allow re-ordering before the store,
+        // so we are safe to go further and check the aliasing. It will
+        // prohibit re-ordering in case the locations may or must alias.
       }

-      // FIXME: this is overly conservative.
       // While volatile access cannot be eliminated, they do not have to clobber
       // non-aliasing locations, as normal accesses can for example be reordered
       // with volatile accesses.
       if (SI->isVolatile())
-        if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
-            isOtherMemAccess(QueryInst))
+        if (!QueryInst || QueryInst->isVolatile())
           return MemDepResult::getClobber(SI);

       // If alias analysis can tell that this store is guaranteed to not modify
@@ -743,8 +734,6 @@ MemoryDependenceResults::getNonLocalCallDependency(CallBase *QueryCall) {
     llvm::sort(Cache);

     ++NumCacheDirtyNonLocal;
-    // cerr << "CACHED CASE: " << DirtyBlocks.size() << " dirty: "
-    //      << Cache.size() << " cached: " << *QueryInst;
   } else {
     // Seed DirtyBlocks with each of the preds of QueryInst's block.
     BasicBlock *QueryBB = QueryCall->getParent();
@@ -1204,7 +1193,6 @@ bool MemoryDependenceResults::getNonLocalPointerDepFromBB(
     // If we do process a large number of blocks it becomes very expensive and
     // likely it isn't worth worrying about
     if (Result.size() > NumResultsLimit) {
-      Worklist.clear();
       // Sort it now (if needed) so that recursive invocations of
       // getNonLocalPointerDepFromBB and other routines that could reuse the
       // cache value will only see properly sorted cache arrays.
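[Editorial note: the rewritten query predicate compares atomic orderings instead of rejecting every non-simple access: a query instruction only forces a clobber result if it is volatile or its ordering is strictly stronger than the bound passed in (NotAtomic for atomic loads, Unordered for atomic stores). A small sketch of the helper this relies on, from llvm/Support/AtomicOrdering.h; the concrete results shown are my reading of the ordering lattice, not quoted from the commit:

    #include "llvm/Support/AtomicOrdering.h"
    using llvm::AtomicOrdering;
    using llvm::isStrongerThan;

    // isStrongerThan is strict: an ordering is never stronger than itself.
    bool A = isStrongerThan(AtomicOrdering::Monotonic,
                            AtomicOrdering::NotAtomic); // expected: true
    bool B = isStrongerThan(AtomicOrdering::Unordered,
                            AtomicOrdering::Unordered); // expected: false
]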
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index a877b19df866..2ed32227bd9e 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -8,12 +8,10 @@
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicsARM.h"
-#include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 using namespace llvm;
diff --git a/llvm/lib/Analysis/MemorySSA.cpp b/llvm/lib/Analysis/MemorySSA.cpp
index 57f431ec21f5..76371b88812e 100644
--- a/llvm/lib/Analysis/MemorySSA.cpp
+++ b/llvm/lib/Analysis/MemorySSA.cpp
@@ -36,8 +36,8 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Use.h"
 #include "llvm/InitializePasses.h"
@@ -49,10 +49,10 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/GraphWriter.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <cassert>
-#include <cstdlib>
 #include <iterator>
 #include <memory>
 #include <utility>
@@ -130,6 +130,12 @@ public:
   MemorySSAWalkerAnnotatedWriter(MemorySSA *M)
       : MSSA(M), Walker(M->getWalker()) {}

+  void emitBasicBlockStartAnnot(const BasicBlock *BB,
+                                formatted_raw_ostream &OS) override {
+    if (MemoryAccess *MA = MSSA->getMemoryAccess(BB))
+      OS << "; " << *MA << "\n";
+  }
+
   void emitInstructionAnnot(const Instruction *I,
                             formatted_raw_ostream &OS) override {
     if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) {
@@ -732,7 +738,7 @@ template <class AliasAnalysisType> class ClobberWalker {
   struct generic_def_path_iterator
       : public iterator_facade_base<generic_def_path_iterator<T, Walker>,
                                     std::forward_iterator_tag, T *> {
-    generic_def_path_iterator() {}
+    generic_def_path_iterator() = default;
     generic_def_path_iterator(Walker *W, ListIndex N) : W(W), N(N) {}

     T &operator*() const { return curNode(); }
@@ -743,9 +749,9 @@ template <class AliasAnalysisType> class ClobberWalker {
     }

     bool operator==(const generic_def_path_iterator &O) const {
-      if (N.hasValue() != O.N.hasValue())
+      if (N.has_value() != O.N.has_value())
         return false;
-      return !N.hasValue() || *N == *O.N;
+      return !N || *N == *O.N;
     }

   private:
@@ -1397,6 +1403,9 @@ void MemorySSA::OptimizeUses::optimizeUsesInBlock(
       continue;
     }

+    if (MU->isOptimized())
+      continue;
+
     if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) {
       MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true, None);
       continue;
@@ -1585,10 +1594,6 @@ void MemorySSA::buildMemorySSA(BatchAAResults &BAA) {
   SmallPtrSet<BasicBlock *, 16> Visited;
   renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited);

-  ClobberWalkerBase<BatchAAResults> WalkerBase(this, &BAA, DT);
-  CachingWalker<BatchAAResults> WalkerLocal(this, &WalkerBase);
-  OptimizeUses(this, &WalkerLocal, &BAA, DT).optimizeUses();
-
   // Mark the uses in unreachable blocks as live on entry, so that they go
   // somewhere.
   for (auto &BB : F)
@@ -2178,6 +2183,17 @@ bool MemorySSA::dominates(const MemoryAccess *Dominator,
   return dominates(Dominator, cast<MemoryAccess>(Dominatee.getUser()));
 }

+void MemorySSA::ensureOptimizedUses() {
+  if (IsOptimized)
+    return;
+
+  BatchAAResults BatchAA(*AA);
+  ClobberWalkerBase<BatchAAResults> WalkerBase(this, &BatchAA, DT);
+  CachingWalker<BatchAAResults> WalkerLocal(this, &WalkerBase);
+  OptimizeUses(this, &WalkerLocal, &BatchAA, DT).optimizeUses();
+  IsOptimized = true;
+}
+
 void MemoryAccess::print(raw_ostream &OS) const {
   switch (getValueID()) {
   case MemoryPhiVal: return static_cast<const MemoryPhi *>(this)->print(OS);
@@ -2350,6 +2366,7 @@ struct DOTGraphTraits<DOTFuncMSSAInfo *> : public DefaultDOTGraphTraits {

 bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) {
   auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
+  MSSA.ensureOptimizedUses();
   if (DotCFGMSSA != "") {
     DOTFuncMSSAInfo CFGInfo(F, MSSA);
     WriteGraph(&CFGInfo, "", false, "MSSA", DotCFGMSSA);
@@ -2382,6 +2399,7 @@ bool MemorySSAAnalysis::Result::invalidate(

 PreservedAnalyses MemorySSAPrinterPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
   auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
+  MSSA.ensureOptimizedUses();
   if (DotCFGMSSA != "") {
     DOTFuncMSSAInfo CFGInfo(F, MSSA);
     WriteGraph(&CFGInfo, "", false, "MSSA", DotCFGMSSA);
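[Editorial note: buildMemorySSA no longer runs OptimizeUses eagerly; that work moved behind the new ensureOptimizedUses() entry point, which is idempotent via the IsOptimized flag, and the two printer passes above now call it explicitly. Clients that rely on uses being pre-optimized would do the same, sketched here for a new-pass-manager function pass (boilerplate assumed):

    // Inside some FunctionPass::run(F, AM). The cost of use optimization is
    // now paid only by callers that ask for it.
    auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
    MSSA.ensureOptimizedUses();
    // ... walk MemoryUse defining accesses as before ...
]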
diff --git a/llvm/lib/Analysis/MemorySSAUpdater.cpp b/llvm/lib/Analysis/MemorySSAUpdater.cpp
index 9c841883de6d..eb75118210b9 100644
--- a/llvm/lib/Analysis/MemorySSAUpdater.cpp
+++ b/llvm/lib/Analysis/MemorySSAUpdater.cpp
@@ -10,22 +10,15 @@
 //
 //===----------------------------------------------------------------===//
 #include "llvm/Analysis/MemorySSAUpdater.h"
-#include "llvm/Analysis/LoopIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/IteratedDominanceFrontier.h"
+#include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
-#include "llvm/IR/GlobalVariable.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/Metadata.h"
-#include "llvm/IR/Module.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/FormattedStream.h"
 #include <algorithm>

 #define DEBUG_TYPE "memoryssa"
@@ -243,6 +236,7 @@ MemoryAccess *MemorySSAUpdater::tryRemoveTrivialPhi(MemoryPhi *Phi,
 }

 void MemorySSAUpdater::insertUse(MemoryUse *MU, bool RenameUses) {
+  VisitedBlocks.clear();
   InsertedPHIs.clear();
   MU->setDefiningAccess(getPreviousDef(MU));

@@ -311,6 +305,13 @@ static void setMemoryPhiValueForBlock(MemoryPhi *MP, const BasicBlock *BB,
 // point to the correct new defs, to ensure we only have one variable, and no
 // disconnected stores.
 void MemorySSAUpdater::insertDef(MemoryDef *MD, bool RenameUses) {
+  // Don't bother updating dead code.
+  if (!MSSA->DT->isReachableFromEntry(MD->getBlock())) {
+    MD->setDefiningAccess(MSSA->getLiveOnEntryDef());
+    return;
+  }
+
+  VisitedBlocks.clear();
   InsertedPHIs.clear();

   // See if we had a local def, and if not, go hunting.
@@ -427,10 +428,10 @@ void MemorySSAUpdater::insertDef(MemoryDef *MD, bool RenameUses) {
   if (NewPhiSize)
     tryRemoveTrivialPhis(ArrayRef<WeakVH>(&InsertedPHIs[NewPhiIndex], NewPhiSize));

-  // Now that all fixups are done, rename all uses if we are asked. Skip
-  // renaming for defs in unreachable blocks.
+  // Now that all fixups are done, rename all uses if we are asked. The defs
+  // are guaranteed to be in reachable code due to the check at the method
+  // entry.
   BasicBlock *StartBlock = MD->getBlock();
-  if (RenameUses && MSSA->getDomTree().getNode(StartBlock)) {
+  if (RenameUses) {
     SmallPtrSet<BasicBlock *, 16> Visited;
     // We are guaranteed there is a def in the block, because we just got it
     // handed to us in this function.
diff --git a/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp b/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
index fab51d6a7aaf..dc149f326271 100644
--- a/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
+++ b/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
@@ -22,7 +22,7 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner(
     LLVMContext &Ctx, const std::string &ModelPath,
     const std::vector<TensorSpec> &InputSpecs,
     const std::vector<LoggedFeatureSpec> &OutputSpecs)
-    : MLModelRunner(Ctx, MLModelRunner::Kind::Development),
+    : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
      OutputSpecs(OutputSpecs) {
   Evaluator = std::make_unique<TFModelEvaluator>(
       ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
@@ -32,6 +32,10 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner(
     Evaluator.reset();
     return;
   }
+
+  for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
+    setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
+  }
 }

 void *ModelUnderTrainingRunner::evaluateUntyped() {
@@ -43,24 +47,31 @@ void *ModelUnderTrainingRunner::evaluateUntyped() {
   return LastEvaluationResult->getUntypedTensorValue(0);
 }

-void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) {
-  return Evaluator->getUntypedInput(Index);
-}
-
 std::unique_ptr<ModelUnderTrainingRunner>
 ModelUnderTrainingRunner::createAndEnsureValid(
     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
     const std::vector<TensorSpec> &InputSpecs,
     StringRef OutputSpecsPathOverride) {
-  std::unique_ptr<ModelUnderTrainingRunner> MUTR;
   if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
                                               OutputSpecsPathOverride))
-    MUTR.reset(new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs,
-                                            *MaybeOutputSpecs));
+    return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs,
+                                *MaybeOutputSpecs);
+  Ctx.emitError("Could not load the policy model from the provided path");
+  return nullptr;
+}
+
+std::unique_ptr<ModelUnderTrainingRunner>
+ModelUnderTrainingRunner::createAndEnsureValid(
+    LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
+    const std::vector<TensorSpec> &InputSpecs,
+    const std::vector<LoggedFeatureSpec> &OutputSpecs) {
+  std::unique_ptr<ModelUnderTrainingRunner> MUTR;
+  MUTR.reset(
+      new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs));
   if (MUTR && MUTR->isValid())
     return MUTR;

-  Ctx.emitError("Could not load the policy model from the provided path");
+  Ctx.emitError("Could not load or create model evaluator.");
   return nullptr;
 }
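[Editorial note: both model runners now hand the input count to the MLModelRunner base and register each input through setUpBufferForTensor, replacing the per-subclass getTensorUntyped overrides: the TF-backed runner passes the evaluator's real input buffers, while the no-inference runner (below) passes nullptr so the base allocates owned storage. A hypothetical subclass would follow the same shape; MyRunner is illustrative, not an existing API:

    // Sketch: the constructor shape shared by both runners after this change.
    MyRunner::MyRunner(LLVMContext &Ctx, const std::vector<TensorSpec> &Inputs)
        : MLModelRunner(Ctx, MLModelRunner::Kind::NoOp, Inputs.size()) {
      size_t I = 0;
      for (const TensorSpec &TS : Inputs)
        setUpBufferForTensor(I++, TS, /*Buffer=*/nullptr); // base owns storage
    }
]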
diff --git a/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp b/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp
index 64fd5eb1acd4..373aaa48b1d1 100644
--- a/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp
+++ b/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp
@@ -15,8 +15,8 @@
 //===----------------------------------------------------------------------===//

 #include "llvm/Analysis/ModuleDebugInfoPrinter.h"
-#include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/Passes.h"
+#include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/InitializePasses.h"
diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
index 2880ca62a7f8..2b98634ef7bf 100644
--- a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
+++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
@@ -38,7 +38,6 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/ModuleSummaryIndex.h"
@@ -368,7 +367,7 @@ static void computeFunctionSummary(
           // We should have named any anonymous globals
           assert(CalledFunction->hasName());
           auto ScaledCount = PSI->getProfileCount(*CB, BFI);
-          auto Hotness = ScaledCount ? getHotness(ScaledCount.getValue(), PSI)
+          auto Hotness = ScaledCount ? getHotness(*ScaledCount, PSI)
                                      : CalleeInfo::HotnessType::Unknown;
           if (ForceSummaryEdgesCold != FunctionSummary::FSHT_None)
             Hotness = CalleeInfo::HotnessType::Cold;
@@ -490,8 +489,7 @@ static void computeFunctionSummary(
       HasIndirBranchToBlockAddress;
   GlobalValueSummary::GVFlags Flags(
       F.getLinkage(), F.getVisibility(), NotEligibleForImport,
-      /* Live = */ false, F.isDSOLocal(),
-      F.hasLinkOnceODRLinkage() && F.hasGlobalUnnamedAddr());
+      /* Live = */ false, F.isDSOLocal(), F.canBeOmittedFromSymbolTable());
   FunctionSummary::FFlags FunFlags{
       F.hasFnAttribute(Attribute::ReadNone),
       F.hasFnAttribute(Attribute::ReadOnly),
@@ -612,8 +610,7 @@ static void computeVariableSummary(ModuleSummaryIndex &Index,
   bool NonRenamableLocal = isNonRenamableLocal(V);
   GlobalValueSummary::GVFlags Flags(
       V.getLinkage(), V.getVisibility(), NonRenamableLocal,
-      /* Live = */ false, V.isDSOLocal(),
-      V.hasLinkOnceODRLinkage() && V.hasGlobalUnnamedAddr());
+      /* Live = */ false, V.isDSOLocal(), V.canBeOmittedFromSymbolTable());

   VTableFuncList VTableFuncs;
   // If splitting is not enabled, then we compute the summary information
@@ -655,8 +652,7 @@ computeAliasSummary(ModuleSummaryIndex &Index, const GlobalAlias &A,
   bool NonRenamableLocal = isNonRenamableLocal(A);
   GlobalValueSummary::GVFlags Flags(
       A.getLinkage(), A.getVisibility(), NonRenamableLocal,
-      /* Live = */ false, A.isDSOLocal(),
-      A.hasLinkOnceODRLinkage() && A.hasGlobalUnnamedAddr());
+      /* Live = */ false, A.isDSOLocal(), A.canBeOmittedFromSymbolTable());
   auto AS = std::make_unique<AliasSummary>(Flags);
   auto *Aliasee = A.getAliaseeObject();
   auto AliaseeVI = Index.getValueInfo(Aliasee->getGUID());
@@ -733,8 +729,7 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex(
             GlobalValue::InternalLinkage, GlobalValue::DefaultVisibility,
             /* NotEligibleToImport = */ true,
             /* Live = */ true,
-            /* Local */ GV->isDSOLocal(),
-            GV->hasLinkOnceODRLinkage() && GV->hasGlobalUnnamedAddr());
+            /* Local */ GV->isDSOLocal(), GV->canBeOmittedFromSymbolTable());
         CantBePromoted.insert(GV->getGUID());
         // Create the appropriate summary type.
         if (Function *F = dyn_cast<Function>(GV)) {
diff --git a/llvm/lib/Analysis/MustExecute.cpp b/llvm/lib/Analysis/MustExecute.cpp
index 5ca72f5f3623..5cff986245b9 100644
--- a/llvm/lib/Analysis/MustExecute.cpp
+++ b/llvm/lib/Analysis/MustExecute.cpp
@@ -16,14 +16,11 @@
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/AssemblyAnnotationWriter.h"
-#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
-#include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/InitializePasses.h"
-#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormattedStream.h"
 #include "llvm/Support/raw_ostream.h"

@@ -143,7 +140,7 @@ static bool CanProveNotTakenFirstIteration(const BasicBlock *ExitBlock,
     return false;
   auto DL = ExitBlock->getModule()->getDataLayout();
   auto *IVStart = LHS->getIncomingValueForBlock(CurLoop->getLoopPreheader());
-  auto *SimpleValOrNull = SimplifyCmpInst(Cond->getPredicate(),
+  auto *SimpleValOrNull = simplifyCmpInst(Cond->getPredicate(),
                                           IVStart, RHS,
                                           {DL, /*TLI*/ nullptr,
                                            DT, /*AC*/ nullptr, BI});
@@ -494,7 +491,7 @@ template <typename K, typename V, typename FnTy, typename... ArgsTy>
 static V getOrCreateCachedOptional(K Key, DenseMap<K, Optional<V>> &Map,
                                    FnTy &&Fn, ArgsTy&&... args) {
   Optional<V> &OptVal = Map[Key];
-  if (!OptVal.hasValue())
+  if (!OptVal)
     OptVal = Fn(std::forward<ArgsTy>(args)...);
   return OptVal.getValue();
 }
diff --git a/llvm/lib/Analysis/NoInferenceModelRunner.cpp b/llvm/lib/Analysis/NoInferenceModelRunner.cpp
index 7178120ebe4f..1914b22f5d71 100644
--- a/llvm/lib/Analysis/NoInferenceModelRunner.cpp
+++ b/llvm/lib/Analysis/NoInferenceModelRunner.cpp
@@ -10,24 +10,14 @@
 // logs for the default policy, in 'development' mode, but never ask it to
 // 'run'.
 //===----------------------------------------------------------------------===//
-#include "llvm/Config/config.h"
-#if defined(LLVM_HAVE_TF_API)
-
 #include "llvm/Analysis/NoInferenceModelRunner.h"
-#include "llvm/Analysis/Utils/TFUtils.h"

 using namespace llvm;

 NoInferenceModelRunner::NoInferenceModelRunner(
     LLVMContext &Ctx, const std::vector<TensorSpec> &Inputs)
-    : MLModelRunner(Ctx, MLModelRunner::Kind::NoOp) {
-  ValuesBuffer.reserve(Inputs.size());
+    : MLModelRunner(Ctx, MLModelRunner::Kind::NoOp, Inputs.size()) {
+  size_t Index = 0;
   for (const auto &TS : Inputs)
-    ValuesBuffer.push_back(std::make_unique<char[]>(TS.getElementCount() *
-                                                    TS.getElementByteSize()));
-}
-
-void *NoInferenceModelRunner::getTensorUntyped(size_t Index) {
-  return ValuesBuffer[Index].get();
+    setUpBufferForTensor(Index++, TS, nullptr);
 }
-
-#endif // defined(LLVM_HAVE_TF_API)
diff --git a/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp b/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp
index 0826b3078672..6fe056d36668 100644
--- a/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp
@@ -26,8 +26,6 @@
 #include "llvm/Analysis/ObjCARCAnalysisUtils.h"
 #include "llvm/Analysis/Passes.h"
 #include "llvm/IR/Function.h"
-#include "llvm/IR/Instruction.h"
-#include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
diff --git a/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp b/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp
index 6f3d4d536c40..17b40f03a5a5 100644
--- a/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp
+++ b/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp
@@ -47,7 +47,7 @@ OptimizationRemarkEmitter::OptimizationRemarkEmitter(const Function *F)
 bool OptimizationRemarkEmitter::invalidate(
     Function &F, const PreservedAnalyses &PA,
     FunctionAnalysisManager::Invalidator &Inv) {
-  if (OwnedBFI.get()) {
+  if (OwnedBFI) {
     OwnedBFI.reset();
     BFI = nullptr;
   }
@@ -80,7 +80,7 @@ void OptimizationRemarkEmitter::emit(
   computeHotness(OptDiag);

   // Only emit it if its hotness meets the threshold.
-  if (OptDiag.getHotness().getValueOr(0) <
+  if (OptDiag.getHotness().value_or(0) <
       F->getContext().getDiagnosticsHotnessThreshold()) {
     return;
   }
diff --git a/llvm/lib/Analysis/OverflowInstAnalysis.cpp b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
index 87a85e6a7364..8bfd6642f760 100644
--- a/llvm/lib/Analysis/OverflowInstAnalysis.cpp
+++ b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
@@ -12,7 +12,6 @@
 //===----------------------------------------------------------------------===//

 #include "llvm/Analysis/OverflowInstAnalysis.h"
-#include "llvm/IR/Constants.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/PatternMatch.h"
diff --git a/llvm/lib/Analysis/PHITransAddr.cpp b/llvm/lib/Analysis/PHITransAddr.cpp
index 02d084937ccb..7571bd0059cc 100644
--- a/llvm/lib/Analysis/PHITransAddr.cpp
+++ b/llvm/lib/Analysis/PHITransAddr.cpp
@@ -17,7 +17,6 @@
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
-#include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 using namespace llvm;
@@ -35,9 +34,6 @@ static bool CanPHITrans(Instruction *Inst) {
       isa<ConstantInt>(Inst->getOperand(1)))
     return true;

-  //   cerr << "MEMDEP: Could not PHI translate: " << *Pointer;
-  //   if (isa<BitCastInst>(PtrInst) || isa<GetElementPtrInst>(PtrInst))
-  //     cerr << "OP:\t\t\t\t" << *PtrInst->getOperand(0);
   return false;
 }

@@ -226,7 +222,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB,
       return GEP;

     // Simplify the GEP to handle 'gep x, 0' -> x etc.
-    if (Value *V = SimplifyGEPInst(GEP->getSourceElementType(), GEPOps[0],
+    if (Value *V = simplifyGEPInst(GEP->getSourceElementType(), GEPOps[0],
                                    ArrayRef<Value *>(GEPOps).slice(1),
                                    GEP->isInBounds(), {DL, TLI, DT, AC})) {
       for (unsigned i = 0, e = GEPOps.size(); i != e; ++i)
@@ -240,6 +236,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB,
     for (User *U : APHIOp->users()) {
       if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(U))
         if (GEPI->getType() == GEP->getType() &&
+            GEPI->getSourceElementType() == GEP->getSourceElementType() &&
             GEPI->getNumOperands() == GEPOps.size() &&
             GEPI->getParent()->getParent() == CurBB->getParent() &&
             (!DT || DT->dominates(GEPI->getParent(), PredBB))) {
@@ -277,7 +274,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB,
   }

   // See if the add simplifies away.
-  if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, {DL, TLI, DT, AC})) {
+  if (Value *Res = simplifyAddInst(LHS, RHS, isNSW, isNUW, {DL, TLI, DT, AC})) {
     // If we simplified the operands, the LHS is no longer an input, but Res
     // is.
     RemoveInstInputs(LHS, InstInputs);
diff --git a/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
index 268ed9d04741..9d5fa6d0a41b 100644
--- a/llvm/lib/Analysis/ProfileSummaryInfo.cpp
+++ b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
@@ -15,7 +15,6 @@
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Instructions.h"
-#include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/ProfileSummary.h"
 #include "llvm/InitializePasses.h"
@@ -125,7 +124,7 @@ bool ProfileSummaryInfo::isFunctionHotInCallGraph(
       for (const auto &I : BB)
         if (isa<CallInst>(I) || isa<InvokeInst>(I))
           if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr))
-            TotalCallCount += CallCount.getValue();
+            TotalCallCount += *CallCount;
     if (isHotCount(TotalCallCount))
       return true;
   }
@@ -154,7 +153,7 @@ bool ProfileSummaryInfo::isFunctionColdInCallGraph(
       for (const auto &I : BB)
         if (isa<CallInst>(I) || isa<InvokeInst>(I))
           if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr))
-            TotalCallCount += CallCount.getValue();
+            TotalCallCount += *CallCount;
     if (!isColdCount(TotalCallCount))
       return false;
   }
@@ -166,7 +165,7 @@ bool ProfileSummaryInfo::isFunctionColdInCallGraph(

 bool ProfileSummaryInfo::isFunctionHotnessUnknown(const Function &F) const {
   assert(hasPartialSampleProfile() && "Expect partial sample profile");
-  return !F.getEntryCount().hasValue();
+  return !F.getEntryCount();
 }

 template <bool isHot>
@@ -188,7 +187,7 @@ bool ProfileSummaryInfo::isFunctionHotOrColdInCallGraphNthPercentile(
       for (const auto &I : BB)
         if (isa<CallInst>(I) || isa<InvokeInst>(I))
           if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr))
-            TotalCallCount += CallCount.getValue();
+            TotalCallCount += *CallCount;
       if (isHot && isHotCountNthPercentile(PercentileCutoff, TotalCallCount))
         return true;
       if (!isHot && !isColdCountNthPercentile(PercentileCutoff, TotalCallCount))
@@ -316,11 +315,11 @@ bool ProfileSummaryInfo::isColdCountNthPercentile(int PercentileCutoff,
 }

 uint64_t ProfileSummaryInfo::getOrCompHotCountThreshold() const {
-  return HotCountThreshold.getValueOr(UINT64_MAX);
+  return HotCountThreshold.value_or(UINT64_MAX);
 }

 uint64_t ProfileSummaryInfo::getOrCompColdCountThreshold() const {
-  return ColdCountThreshold.getValueOr(0);
+  return ColdCountThreshold.value_or(0);
 }

 bool ProfileSummaryInfo::isHotBlock(const BasicBlock *BB,
diff --git a/llvm/lib/Analysis/PtrUseVisitor.cpp b/llvm/lib/Analysis/PtrUseVisitor.cpp
index 9a834ba4866a..49304818d7ef 100644
--- a/llvm/lib/Analysis/PtrUseVisitor.cpp
+++ b/llvm/lib/Analysis/PtrUseVisitor.cpp
@@ -14,7 +14,6 @@
 #include "llvm/Analysis/PtrUseVisitor.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
-#include <algorithm>

 using namespace llvm;
diff --git a/llvm/lib/Analysis/RegionInfo.cpp b/llvm/lib/Analysis/RegionInfo.cpp
index 3ba0bb9eaf2c..9be23a374eca 100644
--- a/llvm/lib/Analysis/RegionInfo.cpp
+++ b/llvm/lib/Analysis/RegionInfo.cpp
@@ -10,6 +10,7 @@
 #include "llvm/Analysis/RegionInfo.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/DominanceFrontier.h"
 #include "llvm/InitializePasses.h"
 #ifndef NDEBUG
 #include "llvm/Analysis/RegionPrinter.h"
diff --git a/llvm/lib/Analysis/RegionPass.cpp b/llvm/lib/Analysis/RegionPass.cpp
index 10c8569096c6..ddef3be8df37 100644
--- a/llvm/lib/Analysis/RegionPass.cpp
+++ b/llvm/lib/Analysis/RegionPass.cpp
@@ -12,14 +12,16 @@
 // Most of this code has been COPIED from LoopPass.cpp
 //
 //===----------------------------------------------------------------------===//
+
 #include "llvm/Analysis/RegionPass.h"
+#include "llvm/Analysis/RegionInfo.h"
 #include "llvm/IR/OptBisect.h"
 #include "llvm/IR/PassTimingInfo.h"
 #include "llvm/IR/PrintPasses.h"
-#include "llvm/IR/StructuralHash.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Timer.h"
 #include "llvm/Support/raw_ostream.h"
+
 using namespace llvm;

 #define DEBUG_TYPE "regionpassmgr"
@@ -93,12 +95,12 @@ bool RGPassManager::runOnFunction(Function &F) {
       TimeRegion PassTimer(getPassTimer(P));
 #ifdef EXPENSIVE_CHECKS
-      uint64_t RefHash = StructuralHash(F);
+      uint64_t RefHash = P->structuralHash(F);
 #endif
       LocalChanged = P->runOnRegion(CurrentRegion, *this);

 #ifdef EXPENSIVE_CHECKS
-      if (!LocalChanged && (RefHash != StructuralHash(F))) {
+      if (!LocalChanged && (RefHash != P->structuralHash(F))) {
         llvm::errs() << "Pass modifies its input and doesn't report it: "
                      << P->getPassName() << "\n";
         llvm_unreachable("Pass modifies its input and doesn't report it");
diff --git a/llvm/lib/Analysis/RegionPrinter.cpp b/llvm/lib/Analysis/RegionPrinter.cpp
index 1fb5faaa6a71..fbd3d17febff 100644
--- a/llvm/lib/Analysis/RegionPrinter.cpp
+++ b/llvm/lib/Analysis/RegionPrinter.cpp
@@ -10,15 +10,11 @@
 #include "llvm/Analysis/RegionPrinter.h"
 #include "llvm/ADT/DepthFirstIterator.h"
-#include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/DOTGraphTraitsPass.h"
-#include "llvm/Analysis/Passes.h"
 #include "llvm/Analysis/RegionInfo.h"
 #include "llvm/Analysis/RegionIterator.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #ifndef NDEBUG
 #include "llvm/IR/LegacyPassManager.h"
@@ -35,28 +31,20 @@ onlySimpleRegions("only-simple-regions", cl::init(false));

 namespace llvm {

-template<>
-struct DOTGraphTraits<RegionNode*> : public DefaultDOTGraphTraits {
-  DOTGraphTraits (bool isSimple=false)
-    : DefaultDOTGraphTraits(isSimple) {}
+std::string DOTGraphTraits<RegionNode *>::getNodeLabel(RegionNode *Node,
+                                                       RegionNode *Graph) {
+  if (!Node->isSubRegion()) {
+    BasicBlock *BB = Node->getNodeAs<BasicBlock>();

-  std::string getNodeLabel(RegionNode *Node, RegionNode *Graph) {
-
-    if (!Node->isSubRegion()) {
-      BasicBlock *BB = Node->getNodeAs<BasicBlock>();
-
-      if (isSimple())
-        return DOTGraphTraits<DOTFuncInfo *>
-          ::getSimpleNodeLabel(BB, nullptr);
-      else
-        return DOTGraphTraits<DOTFuncInfo *>
-          ::getCompleteNodeLabel(BB, nullptr);
-    }
-
-    return "Not implemented";
+    if (isSimple())
+      return DOTGraphTraits<DOTFuncInfo *>::getSimpleNodeLabel(BB, nullptr);
+    else
+      return DOTGraphTraits<DOTFuncInfo *>::getCompleteNodeLabel(BB, nullptr);
   }
-};
+
+  return "Not implemented";
+}

 template <>
 struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> {
@@ -138,7 +126,7 @@ struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> {
     printRegionCluster(*G->getTopLevelRegion(), GW, 4);
   }
 };
-} //end namespace llvm
+} // end namespace llvm

 namespace {

@@ -149,48 +137,49 @@ struct RegionInfoPassGraphTraits {
 };

 struct RegionPrinter
-    : public DOTGraphTraitsPrinter<RegionInfoPass, false, RegionInfo *,
-                                   RegionInfoPassGraphTraits> {
+    : public DOTGraphTraitsPrinterWrapperPass<
+          RegionInfoPass, false, RegionInfo *, RegionInfoPassGraphTraits> {
   static char ID;
   RegionPrinter()
-      : DOTGraphTraitsPrinter<RegionInfoPass, false, RegionInfo *,
-                              RegionInfoPassGraphTraits>("reg", ID) {
+      : DOTGraphTraitsPrinterWrapperPass<RegionInfoPass, false, RegionInfo *,
+                                         RegionInfoPassGraphTraits>("reg", ID) {
     initializeRegionPrinterPass(*PassRegistry::getPassRegistry());
   }
 };
 char RegionPrinter::ID = 0;

 struct RegionOnlyPrinter
-    : public DOTGraphTraitsPrinter<RegionInfoPass, true, RegionInfo *,
-                                   RegionInfoPassGraphTraits> {
+    : public DOTGraphTraitsPrinterWrapperPass<
+          RegionInfoPass, true, RegionInfo *, RegionInfoPassGraphTraits> {
   static char ID;
   RegionOnlyPrinter()
-      : DOTGraphTraitsPrinter<RegionInfoPass, true, RegionInfo *,
-                              RegionInfoPassGraphTraits>("reg", ID) {
+      : DOTGraphTraitsPrinterWrapperPass<RegionInfoPass, true, RegionInfo *,
+                                         RegionInfoPassGraphTraits>("reg", ID) {
     initializeRegionOnlyPrinterPass(*PassRegistry::getPassRegistry());
   }
 };
 char RegionOnlyPrinter::ID = 0;

 struct RegionViewer
-    : public DOTGraphTraitsViewer<RegionInfoPass, false, RegionInfo *,
-                                  RegionInfoPassGraphTraits> {
+    : public DOTGraphTraitsViewerWrapperPass<
+          RegionInfoPass, false, RegionInfo *, RegionInfoPassGraphTraits> {
   static char ID;
   RegionViewer()
-      : DOTGraphTraitsViewer<RegionInfoPass, false, RegionInfo *,
-                             RegionInfoPassGraphTraits>("reg", ID) {
+      : DOTGraphTraitsViewerWrapperPass<RegionInfoPass, false, RegionInfo *,
+                                        RegionInfoPassGraphTraits>("reg", ID) {
     initializeRegionViewerPass(*PassRegistry::getPassRegistry());
   }
 };
 char RegionViewer::ID = 0;

 struct RegionOnlyViewer
-    : public DOTGraphTraitsViewer<RegionInfoPass, true, RegionInfo *,
-                                  RegionInfoPassGraphTraits> {
+    : public DOTGraphTraitsViewerWrapperPass<RegionInfoPass, true, RegionInfo *,
+                                             RegionInfoPassGraphTraits> {
   static char ID;
   RegionOnlyViewer()
-      : DOTGraphTraitsViewer<RegionInfoPass, true, RegionInfo *,
-                             RegionInfoPassGraphTraits>("regonly", ID) {
+      : DOTGraphTraitsViewerWrapperPass<RegionInfoPass, true, RegionInfo *,
+                                        RegionInfoPassGraphTraits>("regonly",
+                                                                   ID) {
     initializeRegionOnlyViewerPass(*PassRegistry::getPassRegistry());
   }
 };
diff --git a/llvm/lib/Analysis/ReplayInlineAdvisor.cpp b/llvm/lib/Analysis/ReplayInlineAdvisor.cpp
index 294bc38c17ad..afc3d7fc4c35 100644
--- a/llvm/lib/Analysis/ReplayInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/ReplayInlineAdvisor.cpp
@@ -14,9 +14,9 @@
 //===----------------------------------------------------------------------===//

 #include "llvm/Analysis/ReplayInlineAdvisor.h"
-#include "llvm/IR/DebugInfoMetadata.h"
-#include "llvm/IR/Instructions.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Support/LineIterator.h"
+#include "llvm/Support/MemoryBuffer.h"
 #include <memory>

 using namespace llvm;
@@ -26,8 +26,9 @@ using namespace llvm;
 ReplayInlineAdvisor::ReplayInlineAdvisor(
     Module &M, FunctionAnalysisManager &FAM, LLVMContext &Context,
     std::unique_ptr<InlineAdvisor> OriginalAdvisor,
-    const ReplayInlinerSettings &ReplaySettings, bool EmitRemarks)
-    : InlineAdvisor(M, FAM), OriginalAdvisor(std::move(OriginalAdvisor)),
+    const ReplayInlinerSettings &ReplaySettings, bool EmitRemarks,
+    InlineContext IC)
+    : InlineAdvisor(M, FAM, IC), OriginalAdvisor(std::move(OriginalAdvisor)),
       ReplaySettings(ReplaySettings), EmitRemarks(EmitRemarks) {

   auto BufferOrErr = MemoryBuffer::getFileOrSTDIN(ReplaySettings.ReplayFile);
@@ -75,12 +76,15 @@ ReplayInlineAdvisor::ReplayInlineAdvisor(
   HasReplayRemarks = true;
 }

-std::unique_ptr<InlineAdvisor> llvm::getReplayInlineAdvisor(
-    Module &M, FunctionAnalysisManager &FAM, LLVMContext &Context,
-    std::unique_ptr<InlineAdvisor> OriginalAdvisor,
-    const ReplayInlinerSettings &ReplaySettings, bool EmitRemarks) {
+std::unique_ptr<InlineAdvisor>
+llvm::getReplayInlineAdvisor(Module &M, FunctionAnalysisManager &FAM,
+                             LLVMContext &Context,
+                             std::unique_ptr<InlineAdvisor> OriginalAdvisor,
+                             const ReplayInlinerSettings &ReplaySettings,
+                             bool EmitRemarks, InlineContext IC) {
   auto Advisor = std::make_unique<ReplayInlineAdvisor>(
-      M, FAM, Context, std::move(OriginalAdvisor), ReplaySettings, EmitRemarks);
+      M, FAM, Context, std::move(OriginalAdvisor), ReplaySettings, EmitRemarks,
+      IC);
   if (!Advisor->areReplayRemarksLoaded())
     Advisor.reset();
   return Advisor;
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 977fc0911355..207f4df79e45 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -79,7 +79,6 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/ScalarEvolutionDivision.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -96,7 +95,6 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalAlias.h"
 #include "llvm/IR/GlobalValue.h"
-#include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
@@ -104,7 +102,6 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
@@ -125,7 +122,6 @@
 #include <algorithm>
 #include <cassert>
 #include <climits>
-#include <cstddef>
 #include <cstdint>
 #include <cstdlib>
 #include <map>
@@ -146,17 +142,21 @@ STATISTIC(NumTripCountsNotComputed,
 STATISTIC(NumBruteForceTripCountsComputed,
           "Number of loops with trip counts computed by force");

+#ifdef EXPENSIVE_CHECKS
+bool llvm::VerifySCEV = true;
+#else
+bool llvm::VerifySCEV = false;
+#endif
+
 static cl::opt<unsigned>
-MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
-                        cl::ZeroOrMore,
-                        cl::desc("Maximum number of iterations SCEV will "
-                                 "symbolically execute a constant "
-                                 "derived loop"),
-                        cl::init(100));
-
-// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
-static cl::opt<bool> VerifySCEV(
-    "verify-scev", cl::Hidden,
+    MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
+                            cl::desc("Maximum number of iterations SCEV will "
+                                     "symbolically execute a constant "
+                                     "derived loop"),
+                            cl::init(100));
+
+static cl::opt<bool, true> VerifySCEVOpt(
+    "verify-scev", cl::Hidden, cl::location(VerifySCEV),
     cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
 static cl::opt<bool> VerifySCEVStrict(
     "verify-scev-strict", cl::Hidden,
May " "be costly in terms of compile time")); +static cl::opt<unsigned> MaxPhiSCCAnalysisSize( + "scalar-evolution-max-scc-analysis-depth", cl::Hidden, + cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " + "Phi strongly connected components"), + cl::init(8)); + +static cl::opt<bool> + EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, + cl::desc("Handle <= and >= in finite loops"), + cl::init(true)); + //===----------------------------------------------------------------------===// // SCEV class definitions //===----------------------------------------------------------------------===// @@ -519,12 +530,13 @@ void SCEVUnknown::deleted() { } void SCEVUnknown::allUsesReplacedWith(Value *New) { + // Clear this SCEVUnknown from various maps. + SE->forgetMemoizedResults(this); + // Remove this SCEVUnknown from the uniquing map. SE->UniqueSCEVs.RemoveNode(this); - // Update this SCEVUnknown to point to the new value. This is needed - // because there may still be outstanding SCEVs which still point to - // this SCEVUnknown. + // Replace the value pointer in case someone is still using this SCEVUnknown. setValPtr(New); } @@ -1643,10 +1655,12 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. - if (AR->hasNoUnsignedWrap()) - return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1), - getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); + if (AR->hasNoUnsignedWrap()) { + Start = + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1); + Step = getZeroExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); + } // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1688,11 +1702,10 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // Cache knowledge of AR NUW, which is propagated to this AddRec. setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW); // Return the expression with the addrec on the outside. - return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, - Depth + 1), - getZeroExtendExpr(Step, Ty, Depth + 1), L, - AR->getNoWrapFlags()); + Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1); + Step = getZeroExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as signed. // This covers loops that count down. @@ -1707,11 +1720,10 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // Negative step causes unsigned wrap, but it still can't self-wrap. setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW); // Return the expression with the addrec on the outside. - return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, - Depth + 1), - getSignExtendExpr(Step, Ty, Depth + 1), L, - AR->getNoWrapFlags()); + Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1); + Step = getSignExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); } } } @@ -1733,11 +1745,10 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // issue. 
It's not clear that the order of checks does matter, but // it's one of two possible causes for an issue with a change which was // reverted. Be conservative for the moment. - return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, - Depth + 1), - getZeroExtendExpr(Step, Ty, Depth + 1), L, - AR->getNoWrapFlags()); + Start = + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1); + Step = getZeroExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); } // For a negative step, we can extend the operands iff doing so only // increases the unsigned range. @@ -1752,11 +1763,10 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // still can't self-wrap. setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW); // Return the expression with the addrec on the outside. - return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, - Depth + 1), - getSignExtendExpr(Step, Ty, Depth + 1), L, - AR->getNoWrapFlags()); + Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, + Depth + 1); + Step = getSignExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); } } } @@ -1780,9 +1790,10 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) { setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW); - return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1), - getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); + Start = + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1); + Step = getZeroExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); } } @@ -1984,10 +1995,12 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. - if (AR->hasNoSignedWrap()) - return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), - getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW); + if (AR->hasNoSignedWrap()) { + Start = + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1); + Step = getSignExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, SCEV::FlagNSW); + } // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -2030,11 +2043,10 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // Cache knowledge of AR NSW, which is propagated to this AddRec. setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW); // Return the expression with the addrec on the outside. - return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, - Depth + 1), - getSignExtendExpr(Step, Ty, Depth + 1), L, - AR->getNoWrapFlags()); + Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, + Depth + 1); + Step = getSignExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as unsigned. // This covers loops that count up with an unsigned step. @@ -2056,11 +2068,10 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW); // Return the expression with the addrec on the outside.
- return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, - Depth + 1), - getZeroExtendExpr(Step, Ty, Depth + 1), L, - AR->getNoWrapFlags()); + Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, + Depth + 1); + Step = getZeroExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); } } } @@ -2072,9 +2083,10 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // issue. It's not clear that the order of checks does matter, but // it's one of two possible causes for an issue with a change which was // reverted. Be conservative for the moment. - return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), - getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); + Start = + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1); + Step = getSignExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); } // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw> @@ -2096,9 +2108,10 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) { setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW); - return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), - getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); + Start = + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1); + Step = getSignExtendExpr(Step, Ty, Depth + 1); + return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags()); } } @@ -2300,9 +2313,9 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *A = (this->*Extension)( (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0); - const SCEV *B = (this->*Operation)((this->*Extension)(LHS, WideTy, 0), - (this->*Extension)(RHS, WideTy, 0), - SCEV::FlagAnyWrap, 0); + const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0); + const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0); + const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0); return A == B; } @@ -3106,12 +3119,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // TODO: There are some cases where this transformation is not // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of // this transformation should be narrowed down. - if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) - return getAddExpr(getMulExpr(LHSC, Add->getOperand(0), - SCEV::FlagAnyWrap, Depth + 1), - getMulExpr(LHSC, Add->getOperand(1), - SCEV::FlagAnyWrap, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1); + if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) { + const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0), + SCEV::FlagAnyWrap, Depth + 1); + const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1), + SCEV::FlagAnyWrap, Depth + 1); + return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1); + } if (Ops[0]->isAllOnesValue()) { // If we have a mul by -1 of an add, try distributing the -1 among the @@ -3466,12 +3480,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, } // Fold if both operands are constant.
- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { - Constant *LHSCV = LHSC->getValue(); - Constant *RHSCV = RHSC->getValue(); - return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV, - RHSCV))); - } + if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) + return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt())); } } @@ -4002,6 +4012,59 @@ public: } // namespace +/// Return true if V is poison given that AssumedPoison is already poison. +static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) { + // The only way poison may be introduced in a SCEV expression is from a + // poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown, + // not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not* + // introduce poison -- they encode guaranteed, non-speculated knowledge. + // + // Additionally, all SCEV nodes propagate poison from inputs to outputs, + // with the notable exception of umin_seq, where only poison from the first + // operand is (unconditionally) propagated. + struct SCEVPoisonCollector { + bool LookThroughSeq; + SmallPtrSet<const SCEV *, 4> MaybePoison; + SCEVPoisonCollector(bool LookThroughSeq) : LookThroughSeq(LookThroughSeq) {} + + bool follow(const SCEV *S) { + // TODO: We can always follow the first operand, but the SCEVTraversal + // API doesn't support this. + if (!LookThroughSeq && isa<SCEVSequentialMinMaxExpr>(S)) + return false; + + if (auto *SU = dyn_cast<SCEVUnknown>(S)) { + if (!isGuaranteedNotToBePoison(SU->getValue())) + MaybePoison.insert(S); + } + return true; + } + bool isDone() const { return false; } + }; + + // First collect all SCEVs that might result in AssumedPoison to be poison. + // We need to look through umin_seq here, because we want to find all SCEVs + // that *might* result in poison, not only those that are *required* to. + SCEVPoisonCollector PC1(/* LookThroughSeq */ true); + visitAll(AssumedPoison, PC1); + + // AssumedPoison is never poison. As the assumption is false, the implication + // is true. Don't bother walking the other SCEV in this case. + if (PC1.MaybePoison.empty()) + return true; + + // Collect all SCEVs in S that, if poison, *will* result in S being poison + // as well. We cannot look through umin_seq here, as its argument only *may* + // make the result poison. + SCEVPoisonCollector PC2(/* LookThroughSeq */ false); + visitAll(S, PC2); + + // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison, + // it will also make S poison by being part of PC2.MaybePoison. 
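To make the subset test below concrete: for AssumedPoison = %x and S = %x + %y, PC1 collects {%x} and PC2 collects {%x, %y}, so the implication holds (an add propagates poison from every operand); for S = umin_seq(%x, %y), the second collector conservatively refuses to look inside the umin_seq node (per the TODO above), so the implication is rejected even though the first operand would be sound. A minimal standalone analogue of the check, with std::set and std::string standing in for the collector results (not LLVM API):

#include <algorithm>
#include <set>
#include <string>

// Same shape as impliesPoison: every value that *may* make AssumedPoison
// poison must also be one that *will* make S poison.
static bool impliesPoisonSketch(const std::set<std::string> &MayPoisonAssumed,
                                const std::set<std::string> &WillPoisonS) {
  if (MayPoisonAssumed.empty())
    return true; // AssumedPoison can never be poison; vacuously true.
  return std::all_of(
      MayPoisonAssumed.begin(), MayPoisonAssumed.end(),
      [&](const std::string &V) { return WillPoisonS.count(V) != 0; });
}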
+ return all_of(PC1.MaybePoison, + [&](const SCEV *S) { return PC2.MaybePoison.contains(S); }); +} + const SCEV * ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl<const SCEV *> &Ops) { @@ -4010,11 +4073,6 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind, assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!"); if (Ops.size() == 1) return Ops[0]; - if (Ops.size() == 2 && - any_of(Ops, [](const SCEV *Op) { return isa<SCEVConstant>(Op); })) - return getMinMaxExpr( - SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind), - Ops); #ifndef NDEBUG Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); for (unsigned i = 1, e = Ops.size(); i != e; ++i) { @@ -4063,6 +4121,39 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind, return getSequentialMinMaxExpr(Kind, Ops); } + const SCEV *SaturationPoint; + ICmpInst::Predicate Pred; + switch (Kind) { + case scSequentialUMinExpr: + SaturationPoint = getZero(Ops[0]->getType()); + Pred = ICmpInst::ICMP_ULE; + break; + default: + llvm_unreachable("Not a sequential min/max type."); + } + + for (unsigned i = 1, e = Ops.size(); i != e; ++i) { + // We can replace %x umin_seq %y with %x umin %y if either: + // * %y being poison implies %x is also poison. + // * %x cannot be the saturating value (e.g. zero for umin). + if (::impliesPoison(Ops[i], Ops[i - 1]) || + isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1], + SaturationPoint)) { + SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]}; + Ops[i - 1] = getMinMaxExpr( + SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind), + SeqOps); + Ops.erase(Ops.begin() + i); + return getSequentialMinMaxExpr(Kind, Ops); + } + // Fold %x umin_seq %y to %x if %x ule %y. + // TODO: We might be able to prove the predicate for a later operand. + if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) { + Ops.erase(Ops.begin() + i); + return getSequentialMinMaxExpr(Kind, Ops); + } + } + // Okay, it looks like we really DO need an expr. Check to see if we // already have one, otherwise create a new one. FoldingSetNodeID ID; @@ -4265,39 +4356,20 @@ bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { return FoundAddRec; } -/// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}. -/// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an -/// offset I, then return {S', I}, else return {\p S, nullptr}. -static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) { - const auto *Add = dyn_cast<SCEVAddExpr>(S); - if (!Add) - return {S, nullptr}; - - if (Add->getNumOperands() != 2) - return {S, nullptr}; - - auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0)); - if (!ConstOp) - return {S, nullptr}; - - return {Add->getOperand(1), ConstOp->getValue()}; -} - /// Return the ValueOffsetPair set for \p S. \p S can be represented /// by the value and offset from any ValueOffsetPair in the set. -ScalarEvolution::ValueOffsetPairSetVector * -ScalarEvolution::getSCEVValues(const SCEV *S) { +ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) { ExprValueMapType::iterator SI = ExprValueMap.find_as(S); if (SI == ExprValueMap.end()) - return nullptr; + return None; #ifndef NDEBUG if (VerifySCEVMap) { // Check there is no dangling Value in the set returned. 
- for (const auto &VE : SI->second) - assert(ValueExprMap.count(VE.first)); + for (Value *V : SI->second) + assert(ValueExprMap.count(V)); } #endif - return &SI->second; + return SI->second.getArrayRef(); } /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V) @@ -4306,20 +4378,11 @@ ScalarEvolution::getSCEVValues(const SCEV *S) { void ScalarEvolution::eraseValueFromMap(Value *V) { ValueExprMapType::iterator I = ValueExprMap.find_as(V); if (I != ValueExprMap.end()) { - const SCEV *S = I->second; - // Remove {V, 0} from the set of ExprValueMap[S] - if (auto *SV = getSCEVValues(S)) - SV->remove({V, nullptr}); - - // Remove {V, Offset} from the set of ExprValueMap[Stripped] - const SCEV *Stripped; - ConstantInt *Offset; - std::tie(Stripped, Offset) = splitAddExpr(S); - if (Offset != nullptr) { - if (auto *SV = getSCEVValues(Stripped)) - SV->remove({V, Offset}); - } - ValueExprMap.erase(V); + auto EVIt = ExprValueMap.find(I->second); + bool Removed = EVIt->second.remove(V); + (void) Removed; + assert(Removed && "Value not in ExprValueMap?"); + ValueExprMap.erase(I); } } @@ -4330,7 +4393,7 @@ void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) { auto It = ValueExprMap.find_as(V); if (It == ValueExprMap.end()) { ValueExprMap.insert({SCEVCallbackVH(V, this), S}); - ExprValueMap[S].insert({V, nullptr}); + ExprValueMap[S].insert(V); } } @@ -4339,33 +4402,9 @@ void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) { const SCEV *ScalarEvolution::getSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); - const SCEV *S = getExistingSCEV(V); - if (S == nullptr) { - S = createSCEV(V); - // During PHI resolution, it is possible to create two SCEVs for the same - // V, so it is needed to double check whether V->S is inserted into - // ValueExprMap before insert S->{V, 0} into ExprValueMap. - std::pair<ValueExprMapType::iterator, bool> Pair = - ValueExprMap.insert({SCEVCallbackVH(V, this), S}); - if (Pair.second) { - ExprValueMap[S].insert({V, nullptr}); - - // If S == Stripped + Offset, add Stripped -> {V, Offset} into - // ExprValueMap. - const SCEV *Stripped = S; - ConstantInt *Offset = nullptr; - std::tie(Stripped, Offset) = splitAddExpr(S); - // If stripped is SCEVUnknown, don't bother to save - // Stripped -> {V, offset}. It doesn't simplify and sometimes even - // increase the complexity of the expansion code. - // If V is GetElementPtrInst, don't save Stripped -> {V, offset} - // because it may generate add/sub instead of GEP in SCEV expansion. - if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) && - !isa<GetElementPtrInst>(V)) - ExprValueMap[Stripped].insert({V, Offset}); - } - } - return S; + if (const SCEV *S = getExistingSCEV(V)) + return S; + return createSCEVIter(V); } const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { @@ -4795,7 +4834,7 @@ public: SelectInst *SI = cast<SelectInst>(I); Optional<const SCEV *> Res = compareWithBackedgeCondition(SI->getCondition()); - if (Res.hasValue()) { + if (Res) { bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne(); Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue()); } @@ -4803,7 +4842,7 @@ public: } default: { Optional<const SCEV *> Res = compareWithBackedgeCondition(I); - if (Res.hasValue()) + if (Res) Result = Res.getValue(); break; } @@ -5067,6 +5106,9 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { // Instcombine turns add of signmask into xor as a strength reduction step. 
if (RHSC->getValue().isSignMask()) return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1)); + // Binary `xor` is a bit-wise `add`. + if (V->getType()->isIntegerTy(1)) + return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1)); return BinaryOp(Op); case Instruction::LShr: @@ -5489,8 +5531,8 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( return true; auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { - if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) && - !Preds.implies(SE.getEqualPredicate(Expr2, Expr1))) + if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) && + !Preds->implies(SE.getEqualPredicate(Expr2, Expr1))) return false; return true; }; @@ -5872,31 +5914,53 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { if (const SCEV *S = createNodeFromSelectLikePHI(PN)) return S; - // If the PHI has a single incoming value, follow that value, unless the - // PHI's incoming blocks are in a different loop, in which case doing so - // risks breaking LCSSA form. Instcombine would normally zap these, but - // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC})) - if (LI.replacementPreservesLCSSAForm(PN, V)) - return getSCEV(V); + if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC})) + return getSCEV(V); // If it's not a loop phi, we can't handle it yet. return getUnknown(PN); } -const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, - Value *Cond, - Value *TrueVal, - Value *FalseVal) { - // Handle "constant" branch or select. This can occur for instance when a - // loop pass transforms an inner loop and moves on to process the outer loop. - if (auto *CI = dyn_cast<ConstantInt>(Cond)) - return getSCEV(CI->isOne() ? TrueVal : FalseVal); +bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, + SCEVTypes RootKind) { + struct FindClosure { + const SCEV *OperandToFind; + const SCEVTypes RootKind; // Must be a sequential min/max expression. + const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind. + + bool Found = false; + + bool canRecurseInto(SCEVTypes Kind) const { + // We can only recurse into the SCEV expression of the same effective type + // as the type of our root SCEV expression, and into zero-extensions. + return RootKind == Kind || NonSequentialRootKind == Kind || + scZeroExtend == Kind; + }; + + FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind) + : OperandToFind(OperandToFind), RootKind(RootKind), + NonSequentialRootKind( + SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType( + RootKind)) {} + bool follow(const SCEV *S) { + Found = S == OperandToFind; + + return !isDone() && canRecurseInto(S->getSCEVType()); + } + + bool isDone() const { return Found; } + }; + + FindClosure FC(OperandToFind, RootKind); + visitAll(Root, FC); + return FC.Found; +} + +const SCEV *ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond( + Instruction *I, ICmpInst *Cond, Value *TrueVal, Value *FalseVal) { // Try to match some simple smax or umax patterns. - auto *ICI = dyn_cast<ICmpInst>(Cond); - if (!ICI) - return getUnknown(I); + auto *ICI = Cond; Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); @@ -5958,31 +6022,36 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, } break; case ICmpInst::ICMP_NE: - // n != 0 ? 
n+x : 1+x -> umax(n, 1)+x - if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && - isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { - const SCEV *One = getOne(I->getType()); - const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); - const SCEV *LA = getSCEV(TrueVal); - const SCEV *RA = getSCEV(FalseVal); - const SCEV *LDiff = getMinusSCEV(LA, LS); - const SCEV *RDiff = getMinusSCEV(RA, One); - if (LDiff == RDiff) - return getAddExpr(getUMaxExpr(One, LS), LDiff); - } - break; + // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y + std::swap(TrueVal, FalseVal); + LLVM_FALLTHROUGH; case ICmpInst::ICMP_EQ: - // n == 0 ? 1+x : n+x -> umax(n, 1)+x + // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { - const SCEV *One = getOne(I->getType()); - const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); - const SCEV *LA = getSCEV(TrueVal); - const SCEV *RA = getSCEV(FalseVal); - const SCEV *LDiff = getMinusSCEV(LA, One); - const SCEV *RDiff = getMinusSCEV(RA, LS); - if (LDiff == RDiff) - return getAddExpr(getUMaxExpr(One, LS), LDiff); + const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); + const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y + const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y + const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x + const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y + if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1)) + return getAddExpr(getUMaxExpr(X, C), Y); + } + // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...)) + // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...)) + // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...) + // -> umin_seq(x, umin (..., umin_seq(...), ...)) + if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() && + isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) { + const SCEV *X = getSCEV(LHS); + while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X)) + X = ZExt->getOperand(); + if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(I->getType())) { + const SCEV *FalseValExpr = getSCEV(FalseVal); + if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr)) + return getUMinExpr(getNoopOrZeroExtend(X, I->getType()), FalseValExpr, + /*Sequential=*/true); + } } break; default: @@ -5992,12 +6061,95 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, return getUnknown(I); } +static Optional<const SCEV *> +createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, + const SCEV *TrueExpr, const SCEV *FalseExpr) { + assert(CondExpr->getType()->isIntegerTy(1) && + TrueExpr->getType() == FalseExpr->getType() && + TrueExpr->getType()->isIntegerTy(1) && + "Unexpected operands of a select."); + + // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0) + // --> C + (umin_seq cond, x - C) + // + // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C)) + // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0) + // --> C + (umin_seq ~cond, x - C) + + // FIXME: while we can't legally model the case where both of the hands + // are fully variable, we only require that the *difference* is constant. 
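The i1 rewrite derived in the comments above can be checked exhaustively, since i1 addition and subtraction are both XOR; a small self-contained test (illustrative only, not the SCEV API):

#include <cassert>

// umin_seq on i1: a zero first operand forces the result to zero
// (shielding whatever the second operand holds); otherwise it yields B.
static bool uminSeqI1(bool A, bool B) { return A && B; }

int main() {
  for (int Cond = 0; Cond <= 1; ++Cond)
    for (int X = 0; X <= 1; ++X)
      for (int C = 0; C <= 1; ++C) {
        bool Select = Cond ? X : C;                  // i1 cond ? i1 x : i1 C
        bool Rewritten = C ^ uminSeqI1(Cond, X ^ C); // C + umin_seq(cond, x - C)
        assert(Select == Rewritten);
      }
  return 0;
}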
+ if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr)) + return None; + + const SCEV *X, *C; + if (isa<SCEVConstant>(TrueExpr)) { + CondExpr = SE->getNotSCEV(CondExpr); + X = FalseExpr; + C = TrueExpr; + } else { + X = TrueExpr; + C = FalseExpr; + } + return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C), + /*Sequential=*/true)); +} + +static Optional<const SCEV *> createNodeForSelectViaUMinSeq(ScalarEvolution *SE, + Value *Cond, + Value *TrueVal, + Value *FalseVal) { + if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal)) + return None; + + const auto *SECond = SE->getSCEV(Cond); + const auto *SETrue = SE->getSCEV(TrueVal); + const auto *SEFalse = SE->getSCEV(FalseVal); + return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse); +} + +const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq( + Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) { + assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?"); + assert(TrueVal->getType() == FalseVal->getType() && + V->getType() == TrueVal->getType() && + "Types of select hands and of the result must match."); + + // For now, only deal with i1-typed `select`s. + if (!V->getType()->isIntegerTy(1)) + return getUnknown(V); + + if (Optional<const SCEV *> S = + createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal)) + return *S; + + return getUnknown(V); +} + +const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, + Value *TrueVal, + Value *FalseVal) { + // Handle "constant" branch or select. This can occur for instance when a + // loop pass transforms an inner loop and moves on to process the outer loop. + if (auto *CI = dyn_cast<ConstantInt>(Cond)) + return getSCEV(CI->isOne() ? TrueVal : FalseVal); + + if (auto *I = dyn_cast<Instruction>(V)) { + if (auto *ICI = dyn_cast<ICmpInst>(Cond)) { + const SCEV *S = createNodeForSelectOrPHIInstWithICmpInstCond( + I, ICI, TrueVal, FalseVal); + if (!isa<SCEVUnknown>(S)) + return S; + } + } + + return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal); +} + /// Expand GEP instructions into add and multiply operations. This allows them /// to be analyzed by regular SCEV code. const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { - // Don't attempt to analyze GEPs over unsized objects. - if (!GEP->getSourceElementType()->isSized()) - return getUnknown(GEP); + assert(GEP->getSourceElementType()->isSized() && + "GEP source element type must be sized"); SmallVector<const SCEV *, 4> IndexExprs; for (Value *Index : GEP->indices()) @@ -6430,7 +6582,7 @@ ScalarEvolution::getRangeRef(const SCEV *S, // Check if the IR explicitly contains !range metadata. Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); - if (MDRange.hasValue()) + if (MDRange) ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue(), RangeType); @@ -6719,7 +6871,7 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, FalseValue = *FalseVal; // Re-apply the cast we peeled off earlier - if (CastOp.hasValue()) + if (CastOp) switch (*CastOp) { default: llvm_unreachable("Unknown SCEV cast type!"); @@ -7020,6 +7172,211 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); } +const SCEV *ScalarEvolution::createSCEVIter(Value *V) { + // Worklist item with a Value and a bool indicating whether all operands have + // been visited already. 
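The worklist encoding described in the comment above (each value pushed once to expand its operands, once to build its result) is a generic way to turn bottom-up recursion into iteration; a stripped-down sketch, with Node, getOperands, and build as placeholders rather than LLVM API:

#include <set>
#include <utility>
#include <vector>

template <typename Node, typename OpsFn, typename BuildFn>
void visitBottomUp(Node Root, OpsFn getOperands, BuildFn build) {
  std::set<Node> Built; // Plays the role of getExistingSCEV's cache.
  std::vector<std::pair<Node, bool>> Stack;
  Stack.push_back({Root, true});
  Stack.push_back({Root, false});
  while (!Stack.empty()) {
    auto [N, OperandsDone] = Stack.back();
    Stack.pop_back();
    if (Built.count(N))
      continue;
    if (OperandsDone) {
      build(N); // Every operand was built by an earlier iteration.
      Built.insert(N);
      continue;
    }
    // First visit: revisit N only after all of its operands.
    Stack.push_back({N, true});
    for (const Node &Op : getOperands(N))
      Stack.push_back({Op, false});
  }
}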
+ using PointerTy = PointerIntPair<Value *, 1, bool>; + SmallVector<PointerTy> Stack; + + Stack.emplace_back(V, true); + Stack.emplace_back(V, false); + while (!Stack.empty()) { + auto E = Stack.pop_back_val(); + Value *CurV = E.getPointer(); + + if (getExistingSCEV(CurV)) + continue; + + SmallVector<Value *> Ops; + const SCEV *CreatedSCEV = nullptr; + // If all operands have been visited already, create the SCEV. + if (E.getInt()) { + CreatedSCEV = createSCEV(CurV); + } else { + // Otherwise get the operands we need to create SCEVs for before creating + // the SCEV for CurV. If the SCEV for CurV can be constructed trivially, + // just use it. + CreatedSCEV = getOperandsToCreate(CurV, Ops); + } + + if (CreatedSCEV) { + insertValueToMap(CurV, CreatedSCEV); + } else { + // Queue CurV for SCEV creation, followed by its operands which need to + // be constructed first. + Stack.emplace_back(CurV, true); + for (Value *Op : Ops) + Stack.emplace_back(Op, false); + } + } + + return getExistingSCEV(V); +} + +const SCEV * +ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) { + if (!isSCEVable(V->getType())) + return getUnknown(V); + + if (Instruction *I = dyn_cast<Instruction>(V)) { + // Don't attempt to analyze instructions in blocks that aren't + // reachable. Such instructions don't matter, and they aren't required + // to obey basic rules for definitions dominating uses which this + // analysis depends on. + if (!DT.isReachableFromEntry(I->getParent())) + return getUnknown(PoisonValue::get(V->getType())); + } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) + return getConstant(CI); + else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { + if (!GA->isInterposable()) { + Ops.push_back(GA->getAliasee()); + return nullptr; + } + return getUnknown(V); + } else if (!isa<ConstantExpr>(V)) + return getUnknown(V); + + Operator *U = cast<Operator>(V); + if (auto BO = MatchBinaryOp(U, DT)) { + bool IsConstArg = isa<ConstantInt>(BO->RHS); + switch (U->getOpcode()) { + case Instruction::Add: { + // For additions and multiplications, traverse add/mul chains for which we + // can potentially create a single SCEV, to reduce the number of + // get{Add,Mul}Expr calls.
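A short worked trace of the chain walk that follows, on hypothetical IR (not taken from this patch):

//   %t1 = add i32 %a, %b
//   %t2 = add i32 %t1, %c
//   %t3 = add i32 %t2, %d
// Starting from %t3, the loop pushes %d, matches %t2 as another add on the
// LHS, pushes %c, matches %t1, then pushes %b and %a. Ops ends up holding
// {%d, %c, %b, %a}, so one getAddExpr call replaces three nested ones.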
+ do { + if (BO->Op) { + if (BO->Op != V && getExistingSCEV(BO->Op)) { + Ops.push_back(BO->Op); + break; + } + } + Ops.push_back(BO->RHS); + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || (NewBO->Opcode != Instruction::Add && + NewBO->Opcode != Instruction::Sub)) { + Ops.push_back(BO->LHS); + break; + } + BO = NewBO; + } while (true); + return nullptr; + } + + case Instruction::Mul: { + do { + if (BO->Op) { + if (BO->Op != V && getExistingSCEV(BO->Op)) { + Ops.push_back(BO->Op); + break; + } + } + Ops.push_back(BO->RHS); + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || NewBO->Opcode != Instruction::Mul) { + Ops.push_back(BO->LHS); + break; + } + BO = NewBO; + } while (true); + return nullptr; + } + + case Instruction::AShr: + case Instruction::Shl: + case Instruction::Xor: + if (!IsConstArg) + return nullptr; + break; + case Instruction::And: + case Instruction::Or: + if (!IsConstArg && BO->LHS->getType()->isIntegerTy(1)) + return nullptr; + break; + default: + break; + } + + Ops.push_back(BO->LHS); + Ops.push_back(BO->RHS); + return nullptr; + } + + switch (U->getOpcode()) { + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::PtrToInt: + Ops.push_back(U->getOperand(0)); + return nullptr; + + case Instruction::BitCast: + if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) { + Ops.push_back(U->getOperand(0)); + return nullptr; + } + return getUnknown(V); + + case Instruction::SDiv: + case Instruction::SRem: + Ops.push_back(U->getOperand(0)); + Ops.push_back(U->getOperand(1)); + return nullptr; + + case Instruction::GetElementPtr: + assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() && + "GEP source element type must be sized"); + for (Value *Index : U->operands()) + Ops.push_back(Index); + return nullptr; + + case Instruction::IntToPtr: + return getUnknown(V); + + case Instruction::PHI: + // Keep constructing SCEVs for phis recursively for now. + return nullptr; + + case Instruction::Select: + for (Value *Inc : U->operands()) + Ops.push_back(Inc); + return nullptr; + break; + + case Instruction::Call: + case Instruction::Invoke: + if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) { + Ops.push_back(RV); + return nullptr; + } + + if (auto *II = dyn_cast<IntrinsicInst>(U)) { + switch (II->getIntrinsicID()) { + case Intrinsic::abs: + Ops.push_back(II->getArgOperand(0)); + return nullptr; + case Intrinsic::umax: + case Intrinsic::umin: + case Intrinsic::smax: + case Intrinsic::smin: + case Intrinsic::usub_sat: + case Intrinsic::uadd_sat: + Ops.push_back(II->getArgOperand(0)); + Ops.push_back(II->getArgOperand(1)); + return nullptr; + case Intrinsic::start_loop_iterations: + Ops.push_back(II->getArgOperand(0)); + return nullptr; + default: + break; + } + } + break; + } + + return nullptr; +} + const SCEV *ScalarEvolution::createSCEV(Value *V) { if (!isSCEVable(V->getType())) return getUnknown(V); @@ -7030,7 +7387,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // to obey basic rules for definitions dominating uses which this // analysis depends on.
if (!DT.isReachableFromEntry(I->getParent())) - return getUnknown(UndefValue::get(V->getType())); + return getUnknown(PoisonValue::get(V->getType())); } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) return getConstant(CI); else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) @@ -7038,6 +7395,9 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { else if (!isa<ConstantExpr>(V)) return getUnknown(V); + const SCEV *LHS; + const SCEV *RHS; + Operator *U = cast<Operator>(V); if (auto BO = MatchBinaryOp(U, DT)) { switch (BO->Opcode) { @@ -7103,8 +7463,9 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); if (Flags != SCEV::FlagAnyWrap) { - MulOps.push_back( - getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags)); + LHS = getSCEV(BO->LHS); + RHS = getSCEV(BO->RHS); + MulOps.push_back(getMulExpr(LHS, RHS, Flags)); break; } } @@ -7121,14 +7482,20 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return getMulExpr(MulOps); } case Instruction::UDiv: - return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); + LHS = getSCEV(BO->LHS); + RHS = getSCEV(BO->RHS); + return getUDivExpr(LHS, RHS); case Instruction::URem: - return getURemExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); + LHS = getSCEV(BO->LHS); + RHS = getSCEV(BO->RHS); + return getURemExpr(LHS, RHS); case Instruction::Sub: { SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; if (BO->Op) Flags = getNoWrapFlagsFromUB(BO->Op); - return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags); + LHS = getSCEV(BO->LHS); + RHS = getSCEV(BO->RHS); + return getMinusSCEV(LHS, RHS, Flags); } case Instruction::And: // For an expression like x&255 that merely masks off the high bits, @@ -7180,6 +7547,12 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { MulCount); } } + // Binary `and` is a bit-wise `umin`. + if (BO->LHS->getType()->isIntegerTy(1)) { + LHS = getSCEV(BO->LHS); + RHS = getSCEV(BO->RHS); + return getUMinExpr(LHS, RHS); + } break; case Instruction::Or: @@ -7199,6 +7572,12 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW)); } } + // Binary `or` is a bit-wise `umax`. + if (BO->LHS->getType()->isIntegerTy(1)) { + LHS = getSCEV(BO->LHS); + RHS = getSCEV(BO->RHS); + return getUMaxExpr(LHS, RHS); + } break; case Instruction::Xor: @@ -7266,9 +7645,9 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW); } - Constant *X = ConstantInt::get( + ConstantInt *X = ConstantInt::get( getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); - return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags); + return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags); } break; @@ -7394,14 +7773,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return createNodeForPHI(cast<PHINode>(U)); case Instruction::Select: - // U can also be a select constant expr, which let fall through. Since - // createNodeForSelect only works for a condition that is an `ICmpInst`, and - // constant expressions cannot have instructions as operands, we'd have - // returned getUnknown for a select constant expressions anyway. 
- if (isa<Instruction>(U)) - return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0), - U->getOperand(1), U->getOperand(2)); - break; + return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1), + U->getOperand(2)); case Instruction::Call: case Instruction::Invoke: @@ -7415,17 +7788,21 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { getSCEV(II->getArgOperand(0)), /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne()); case Intrinsic::umax: - return getUMaxExpr(getSCEV(II->getArgOperand(0)), - getSCEV(II->getArgOperand(1))); + LHS = getSCEV(II->getArgOperand(0)); + RHS = getSCEV(II->getArgOperand(1)); + return getUMaxExpr(LHS, RHS); case Intrinsic::umin: - return getUMinExpr(getSCEV(II->getArgOperand(0)), - getSCEV(II->getArgOperand(1))); + LHS = getSCEV(II->getArgOperand(0)); + RHS = getSCEV(II->getArgOperand(1)); + return getUMinExpr(LHS, RHS); case Intrinsic::smax: - return getSMaxExpr(getSCEV(II->getArgOperand(0)), - getSCEV(II->getArgOperand(1))); + LHS = getSCEV(II->getArgOperand(0)); + RHS = getSCEV(II->getArgOperand(1)); + return getSMaxExpr(LHS, RHS); case Intrinsic::smin: - return getSMinExpr(getSCEV(II->getArgOperand(0)), - getSCEV(II->getArgOperand(1))); + LHS = getSCEV(II->getArgOperand(0)); + RHS = getSCEV(II->getArgOperand(1)); + return getSMinExpr(LHS, RHS); case Intrinsic::usub_sat: { const SCEV *X = getSCEV(II->getArgOperand(0)); const SCEV *Y = getSCEV(II->getArgOperand(1)); @@ -7640,7 +8017,7 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { Res = Multiple; Res = (unsigned)GreatestCommonDivisor64(*Res, Multiple); } - return Res.getValueOr(1); + return Res.value_or(1); } unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, @@ -7708,7 +8085,7 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L, const SCEV * ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L, - SCEVUnionPredicate &Preds) { + SmallVector<const SCEVPredicate *, 4> &Preds) { return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds); } @@ -7870,7 +8247,6 @@ void ScalarEvolution::forgetLoop(const Loop *L) { if (LoopUsersItr != LoopUsers.end()) { ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(), LoopUsersItr->second.end()); - LoopUsers.erase(LoopUsersItr); } // Drop information about expressions based on loop-header PHIs. @@ -7900,9 +8276,7 @@ void ScalarEvolution::forgetLoop(const Loop *L) { } void ScalarEvolution::forgetTopmostLoop(const Loop *L) { - while (Loop *Parent = L->getParentLoop()) - L = Parent; - forgetLoop(L); + forgetLoop(L->getOutermostLoop()); } void ScalarEvolution::forgetValue(Value *V) { @@ -7944,7 +8318,7 @@ void ScalarEvolution::forgetLoopDispositions(const Loop *L) { /// the relevant loop exiting block using getExact(ExitingBlock, SE). const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE, - SCEVUnionPredicate *Preds) const { + SmallVector<const SCEVPredicate *, 4> *Preds) const { // If any exits were not computable, the loop is not computable. 
if (!isComplete() || ExitNotTaken.empty()) return SE->getCouldNotCompute(); @@ -7966,14 +8340,18 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE, Ops.push_back(BECount); - if (Preds && !ENT.hasAlwaysTruePredicate()) - Preds->add(ENT.Predicate.get()); + if (Preds) + for (auto *P : ENT.Predicates) + Preds->push_back(P); assert((Preds || ENT.hasAlwaysTruePredicate()) && "Predicate should be always true!"); } - return SE->getUMinFromMismatchedTypes(Ops); + // If an earlier exit exits on the first iteration (exit count zero), then + // a later poison exit count should not propagate into the result. These are + // exactly the semantics provided by umin_seq. + return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true); } /// Get the exact not taken count for this loop exit. @@ -8082,16 +8460,8 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( [&](const EdgeExitInfo &EEI) { BasicBlock *ExitBB = EEI.first; const ExitLimit &EL = EEI.second; - if (EL.Predicates.empty()) - return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken, - nullptr); - - std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate); - for (auto *Pred : EL.Predicates) - Predicate->add(Pred); - return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken, - std::move(Predicate)); + EL.Predicates); }); assert((isa<SCEVCouldNotCompute>(ConstantMax) || isa<SCEVConstant>(ConstantMax)) && @@ -8385,11 +8755,6 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( BECount = getUMinFromMismatchedTypes( EL0.ExactNotTaken, EL1.ExactNotTaken, /*Sequential=*/!isa<BinaryOperator>(ExitCond)); - - // If EL0.ExactNotTaken was zero and ExitCond was a short-circuit form, - // it should have been simplified to zero (see the condition (3) above) - assert(!isa<BinaryOperator>(ExitCond) || !EL0.ExactNotTaken->isZero() || - BECount->isZero()); } if (EL0.MaxNotTaken == getCouldNotCompute()) MaxBECount = EL1.MaxNotTaken; @@ -8470,7 +8835,8 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, ControlsExit && loopHasNoAbnormalExits(L) && loopIsFiniteByAssumption(L); // Simplify the operands before analyzing them. (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0, - ControllingFiniteLoop); + (EnableFiniteLoopControl ? ControllingFiniteLoop + : false)); // If we have a comparison of a chrec against a constant, try to use value // ranges to answer this query. @@ -8683,7 +9049,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( // and the kind of shift should match the kind of shift we peeled // off, if any.
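The sequential-umin combination of exit counts earlier in this hunk can be modeled with std::optional playing the role of poison; a toy sketch, not the SCEV implementation:

#include <algorithm>
#include <optional>

using Count = std::optional<unsigned>; // nullopt models a poison count.

static Count umin(Count A, Count B) { // Plain umin: poison on either side wins.
  if (!A || !B)
    return std::nullopt;
  return std::min(*A, *B);
}

static Count uminSeq(Count A, Count B) { // An earlier count of 0 shields B.
  if (A && *A == 0)
    return Count(0u);
  return umin(A, B);
}

// umin(0, poison) is poison, but uminSeq(0, poison) is 0 -- the loop exits
// on the first iteration regardless of the later, poisoned exit count.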
- (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut); + (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut); }; PHINode *PN; @@ -8871,13 +9237,6 @@ static Constant *EvaluateExpression(Value *V, const Loop *L, Operands[i] = C; } - if (CmpInst *CI = dyn_cast<CmpInst>(I)) - return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], - Operands[1], DL, TLI); - if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - if (!LI->isVolatile()) - return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); - } return ConstantFoldInstOperands(I, Operands, DL, TLI); } @@ -9121,58 +9480,42 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { } case scAddExpr: { const SCEVAddExpr *SA = cast<SCEVAddExpr>(V); - if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) { - if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { - unsigned AS = PTy->getAddressSpace(); - Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); - C = ConstantExpr::getBitCast(C, DestPtrTy); + Constant *C = nullptr; + for (const SCEV *Op : SA->operands()) { + Constant *OpC = BuildConstantFromSCEV(Op); + if (!OpC) + return nullptr; + if (!C) { + C = OpC; + continue; } - for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { - Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); - if (!C2) - return nullptr; - - // First pointer! - if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) { - unsigned AS = C2->getType()->getPointerAddressSpace(); - std::swap(C, C2); - Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); - // The offsets have been converted to bytes. We can add bytes to an - // i8* by GEP with the byte count in the first index. - C = ConstantExpr::getBitCast(C, DestPtrTy); - } - - // Don't bother trying to sum two pointers. We probably can't - // statically compute a load that results from it anyway. - if (C2->getType()->isPointerTy()) - return nullptr; - - if (C->getType()->isPointerTy()) { - C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()), - C, C2); - } else { - C = ConstantExpr::getAdd(C, C2); - } + assert(!C->getType()->isPointerTy() && + "Can only have one pointer, and it must be last"); + if (auto *PT = dyn_cast<PointerType>(OpC->getType())) { + // The offsets have been converted to bytes. We can add bytes to an + // i8* by GEP with the byte count in the first index. + Type *DestPtrTy = + Type::getInt8PtrTy(PT->getContext(), PT->getAddressSpace()); + OpC = ConstantExpr::getBitCast(OpC, DestPtrTy); + C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()), + OpC, C); + } else { + C = ConstantExpr::getAdd(C, OpC); } - return C; } - return nullptr; + return C; } case scMulExpr: { const SCEVMulExpr *SM = cast<SCEVMulExpr>(V); - if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { - // Don't bother with pointers at all. - if (C->getType()->isPointerTy()) + Constant *C = nullptr; + for (const SCEV *Op : SM->operands()) { + assert(!Op->getType()->isPointerTy() && "Can't multiply pointers"); + Constant *OpC = BuildConstantFromSCEV(Op); + if (!OpC) return nullptr; - for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { - Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); - if (!C2 || C2->getType()->isPointerTy()) - return nullptr; - C = ConstantExpr::getMul(C, C2); - } - return C; + C = C ? 
ConstantExpr::getMul(C, OpC) : OpC; } - return nullptr; + return C; } case scUDivExpr: { const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V); @@ -9297,15 +9640,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (MadeImprovement) { Constant *C = nullptr; const DataLayout &DL = getDataLayout(); - if (const CmpInst *CI = dyn_cast<CmpInst>(I)) - C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], - Operands[1], DL, &TLI); - else if (const LoadInst *Load = dyn_cast<LoadInst>(I)) { - if (!Load->isVolatile()) - C = ConstantFoldLoadFromConstPtr(Operands[0], Load->getType(), - DL); - } else - C = ConstantFoldInstOperands(I, Operands, DL, &TLI); + C = ConstantFoldInstOperands(I, Operands, DL, &TLI); if (!C) return V; return getSCEV(C); } @@ -9535,15 +9870,15 @@ GetQuadraticEquation(const SCEVAddRecExpr *AddRec) { /// (b) if neither X nor Y exist, return None, /// (c) if exactly one of X and Y exists, return that value. static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) { - if (X.hasValue() && Y.hasValue()) { + if (X && Y) { unsigned W = std::max(X->getBitWidth(), Y->getBitWidth()); - APInt XW = X->sextOrSelf(W); - APInt YW = Y->sextOrSelf(W); + APInt XW = X->sext(W); + APInt YW = Y->sext(W); return XW.slt(YW) ? *X : *Y; } - if (!X.hasValue() && !Y.hasValue()) + if (!X && !Y) return None; - return X.hasValue() ? *X : *Y; + return X ? *X : *Y; } /// Helper function to truncate an optional APInt to a given BitWidth. @@ -9558,7 +9893,7 @@ static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) { /// equation are BW+1 bits wide (to avoid truncation when converting from /// the addrec to the equation). static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) { - if (!X.hasValue()) + if (!X) return None; unsigned W = X->getBitWidth(); if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth)) @@ -9585,13 +9920,13 @@ SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { APInt A, B, C, M; unsigned BitWidth; auto T = GetQuadraticEquation(AddRec); - if (!T.hasValue()) + if (!T) return None; std::tie(A, B, C, M, BitWidth) = *T; LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n"); Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1); - if (!X.hasValue()) + if (!X) return None; ConstantInt *CX = ConstantInt::get(SE.getContext(), *X); @@ -9627,7 +9962,7 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, APInt A, B, C, M; unsigned BitWidth; auto T = GetQuadraticEquation(AddRec); - if (!T.hasValue()) + if (!T) return None; // Be careful about the return value: there can be two reasons for not @@ -9672,7 +10007,7 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, // If SolveQuadraticEquationWrap returns None, it means that there can // be a solution, but the function failed to find it. We cannot treat it // as "no solution". - if (!SO.hasValue() || !UO.hasValue()) + if (!SO || !UO) return { None, false }; // Check the smaller value first to see if it leaves the range. @@ -9690,8 +10025,8 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, std::tie(A, B, C, M, BitWidth) = *T; // Lower bound is inclusive, subtract 1 to represent the exiting value. 
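MinOptional's contract, documented earlier in this hunk, is easy to restate over std::optional; a minimal sketch using plain long instead of the width-adjusted APInts:

#include <algorithm>
#include <optional>

static std::optional<long> minOptional(std::optional<long> X,
                                       std::optional<long> Y) {
  if (X && Y)
    return std::min(*X, *Y); // Both exist: return the smaller value.
  return X ? X : Y;          // Neither, or exactly one, exists.
}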
- APInt Lower = Range.getLower().sextOrSelf(A.getBitWidth()) - 1; - APInt Upper = Range.getUpper().sextOrSelf(A.getBitWidth()); + APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1; + APInt Upper = Range.getUpper().sext(A.getBitWidth()); auto SL = SolveForBoundary(Lower); auto SU = SolveForBoundary(Upper); // If any of the solutions was unknown, no meaningful conclusions can @@ -9776,7 +10111,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) { - const auto *R = cast<SCEVConstant>(getConstant(S.getValue())); + const auto *R = cast<SCEVConstant>(getConstant(*S)); return ExitLimit(R, R, false, Predicates); } return getCouldNotCompute(); @@ -10296,7 +10631,7 @@ ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS, auto ResultSwapped = getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred)); - assert(ResultSwapped.hasValue() && "should be able to analyze both!"); + assert(ResultSwapped && "should be able to analyze both!"); assert(ResultSwapped.getValue() != Result.getValue() && "monotonicity should flip as we flip the predicate"); } @@ -10479,17 +10814,27 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges( return false; if (Pred == CmpInst::ICMP_NE) { - if (CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) || - CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS))) + auto SL = getSignedRange(LHS); + auto SR = getSignedRange(RHS); + if (CheckRanges(SL, SR)) + return true; + auto UL = getUnsignedRange(LHS); + auto UR = getUnsignedRange(RHS); + if (CheckRanges(UL, UR)) return true; auto *Diff = getMinusSCEV(LHS, RHS); return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff); } - if (CmpInst::isSigned(Pred)) - return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)); + if (CmpInst::isSigned(Pred)) { + auto SL = getSignedRange(LHS); + auto SR = getSignedRange(RHS); + return CheckRanges(SL, SR); + } - return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)); + auto UL = getUnsignedRange(LHS); + auto UR = getUnsignedRange(RHS); + return CheckRanges(UL, UR); } bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, @@ -12596,7 +12941,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, if (isQuadratic()) { if (auto S = SolveQuadraticAddRecRange(this, Range, SE)) - return SE.getConstant(S.getValue()); + return SE.getConstant(*S); } return SE.getCouldNotCompute(); @@ -12636,6 +12981,15 @@ bool ScalarEvolution::containsUndefs(const SCEV *S) const { }); } +// Return true when S contains a value that is a nullptr. +bool ScalarEvolution::containsErasedValue(const SCEV *S) const { + return SCEVExprContains(S, [](const SCEV *S) { + if (const auto *SU = dyn_cast<SCEVUnknown>(S)) + return SU->getValue() == nullptr; + return false; + }); +} + /// Return the size of an element read or written by Inst.
const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { Type *Ty; @@ -12820,12 +13174,13 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, L->getHeader()->printAsOperand(OS, /*PrintType=*/false); OS << ": "; - SCEVUnionPredicate Pred; - auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred); + SmallVector<const SCEVPredicate *, 4> Preds; + auto PBT = SE->getPredicatedBackedgeTakenCount(L, Preds); if (!isa<SCEVCouldNotCompute>(PBT)) { OS << "Predicated backedge-taken count is " << *PBT << "\n"; OS << " Predicates:\n"; - Pred.print(OS, 4); + for (auto *P : Preds) + P->print(OS, 4); } else { OS << "Unpredictable predicated backedge-taken count. "; } @@ -13202,12 +13557,10 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) { auto ExprIt = ExprValueMap.find(S); if (ExprIt != ExprValueMap.end()) { - for (auto &ValueAndOffset : ExprIt->second) { - if (ValueAndOffset.second == nullptr) { - auto ValueIt = ValueExprMap.find_as(ValueAndOffset.first); - if (ValueIt != ValueExprMap.end()) - ValueExprMap.erase(ValueIt); - } + for (Value *V : ExprIt->second) { + auto ValueIt = ValueExprMap.find_as(V); + if (ValueIt != ValueExprMap.end()) + ValueExprMap.erase(ValueIt); } ExprValueMap.erase(ExprIt); } @@ -13258,6 +13611,43 @@ ScalarEvolution::getUsedLoops(const SCEV *S, SCEVTraversal<FindUsedLoops>(F).visitAll(S); } +void ScalarEvolution::getReachableBlocks( + SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) { + SmallVector<BasicBlock *> Worklist; + Worklist.push_back(&F.getEntryBlock()); + while (!Worklist.empty()) { + BasicBlock *BB = Worklist.pop_back_val(); + if (!Reachable.insert(BB).second) + continue; + + Value *Cond; + BasicBlock *TrueBB, *FalseBB; + if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB), + m_BasicBlock(FalseBB)))) { + if (auto *C = dyn_cast<ConstantInt>(Cond)) { + Worklist.push_back(C->isOne() ? TrueBB : FalseBB); + continue; + } + + if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) { + const SCEV *L = getSCEV(Cmp->getOperand(0)); + const SCEV *R = getSCEV(Cmp->getOperand(1)); + if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) { + Worklist.push_back(TrueBB); + continue; + } + if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L, + R)) { + Worklist.push_back(FalseBB); + continue; + } + } + } + + append_range(Worklist, successors(BB)); + } +} + void ScalarEvolution::verify() const { ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); ScalarEvolution SE2(F, TLI, AC, DT, LI); @@ -13282,13 +13672,44 @@ void ScalarEvolution::verify() const { }; SCEVMapper SCM(SE2); + SmallPtrSet<BasicBlock *, 16> ReachableBlocks; + SE2.getReachableBlocks(ReachableBlocks, F); + + auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * { + if (containsUndefs(Old) || containsUndefs(New)) { + // SCEV treats "undef" as an unknown but consistent value (i.e. it does + // not propagate undef aggressively). This means we can (and do) fail + // verification in cases where a transform makes a value go from "undef" + // to "undef+1" (say). The transform is fine, since in both cases the + // result is "undef", but SCEV thinks the value increased by 1. + return nullptr; + } + + // Unless VerifySCEVStrict is set, we only compare constant deltas. 
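The getReachableBlocks walk earlier in this hunk prunes edges whose branch condition SCEV can already decide; the shape of that pruned reachability search, with placeholder types (Block, KnownCond) rather than LLVM API, is roughly:

#include <optional>
#include <set>
#include <vector>

struct Block {
  std::vector<Block *> Succs;    // Succs[0]: taken edge, Succs[1]: not taken.
  std::optional<bool> KnownCond; // nullopt: condition not provable.
};

static std::set<Block *> reachable(Block *Entry) {
  std::set<Block *> Seen;
  std::vector<Block *> Worklist{Entry};
  while (!Worklist.empty()) {
    Block *B = Worklist.back();
    Worklist.pop_back();
    if (!Seen.insert(B).second)
      continue;
    if (B->KnownCond && B->Succs.size() == 2) {
      // A provable condition makes one successor dead for verification.
      Worklist.push_back(B->Succs[*B->KnownCond ? 0 : 1]);
      continue;
    }
    for (Block *S : B->Succs)
      Worklist.push_back(S);
  }
  return Seen;
}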
+ const SCEV *Delta = SE2.getMinusSCEV(Old, New); + if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta)) + return nullptr; + + return Delta; + }; while (!LoopStack.empty()) { auto *L = LoopStack.pop_back_val(); llvm::append_range(LoopStack, *L); - auto *CurBECount = SCM.visit( - const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L)); + // Only verify BECounts in reachable loops. For an unreachable loop, + // any BECount is legal. + if (!ReachableBlocks.contains(L->getHeader())) + continue; + + // Only verify cached BECounts. Computing new BECounts may change the + // results of subsequent SCEV uses. + auto It = BackedgeTakenCounts.find(L); + if (It == BackedgeTakenCounts.end()) + continue; + + auto *CurBECount = + SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this))); auto *NewBECount = SE2.getBackedgeTakenCount(L); if (CurBECount == SE2.getCouldNotCompute() || @@ -13301,16 +13722,6 @@ void ScalarEvolution::verify() const { continue; } - if (containsUndefs(CurBECount) || containsUndefs(NewBECount)) { - // SCEV treats "undef" as an unknown but consistent value (i.e. it does - // not propagate undef aggressively). This means we can (and do) fail - // verification in cases where a transform makes the trip count of a loop - // go from "undef" to "undef+1" (say). The transform is fine, since in - // both cases the loop iterates "undef" times, but SCEV thinks we - // increased the trip count of the loop by 1 incorrectly. - continue; - } - if (SE.getTypeSizeInBits(CurBECount->getType()) > SE.getTypeSizeInBits(NewBECount->getType())) NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType()); @@ -13318,10 +13729,8 @@ void ScalarEvolution::verify() const { SE.getTypeSizeInBits(NewBECount->getType())) CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType()); - const SCEV *Delta = SE2.getMinusSCEV(CurBECount, NewBECount); - - // Unless VerifySCEVStrict is set, we only compare constant deltas. - if ((VerifySCEVStrict || isa<SCEVConstant>(Delta)) && !Delta->isZero()) { + const SCEV *Delta = GetDelta(CurBECount, NewBECount); + if (Delta && !Delta->isZero()) { dbgs() << "Trip Count for " << *L << " Changed!\n"; dbgs() << "Old: " << *CurBECount << "\n"; dbgs() << "New: " << *NewBECount << "\n"; @@ -13335,10 +13744,8 @@ void ScalarEvolution::verify() const { SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end()); while (!Worklist.empty()) { Loop *L = Worklist.pop_back_val(); - if (ValidLoops.contains(L)) - continue; - ValidLoops.insert(L); - Worklist.append(L->begin(), L->end()); + if (ValidLoops.insert(L).second) + Worklist.append(L->begin(), L->end()); } for (auto &KV : ValueExprMap) { #ifndef NDEBUG @@ -13351,27 +13758,38 @@ void ScalarEvolution::verify() const { // Check that the value is also part of the reverse map. 
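The GetDelta lambda and its callers encode a deliberately lax comparison policy: undef makes the comparison moot, and symbolic differences are trusted only in strict mode. A compact model of that policy with the SCEV machinery abstracted to booleans (a sketch for illustration only):

#include <cassert>

// Mismatches are reported only when undef is not involved and, unless
// VerifySCEVStrict is set, when the old/new difference folds to a constant.
bool shouldReport(bool AnyUndef, bool DeltaIsConstant, bool DeltaIsZero,
                  bool Strict) {
  if (AnyUndef)
    return false; // undef is consistent-but-unknown; no conclusion possible
  if (!Strict && !DeltaIsConstant)
    return false; // symbolic deltas are ignored outside strict mode
  return !DeltaIsZero;
}

int main() {
  assert(!shouldReport(false, true, true, false));   // delta == 0: fine
  assert(shouldReport(false, true, false, false));   // constant nonzero delta
  assert(!shouldReport(false, false, false, false)); // symbolic: skipped
  assert(!shouldReport(true, true, false, false));   // undef involved: skipped
}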
auto It = ExprValueMap.find(KV.second); - if (It == ExprValueMap.end() || !It->second.contains({KV.first, nullptr})) { + if (It == ExprValueMap.end() || !It->second.contains(KV.first)) { dbgs() << "Value " << *KV.first << " is in ValueExprMap but not in ExprValueMap\n"; std::abort(); } - } - for (const auto &KV : ExprValueMap) { - for (const auto &ValueAndOffset : KV.second) { - if (ValueAndOffset.second != nullptr) + if (auto *I = dyn_cast<Instruction>(&*KV.first)) { + if (!ReachableBlocks.contains(I->getParent())) continue; + const SCEV *OldSCEV = SCM.visit(KV.second); + const SCEV *NewSCEV = SE2.getSCEV(I); + const SCEV *Delta = GetDelta(OldSCEV, NewSCEV); + if (Delta && !Delta->isZero()) { + dbgs() << "SCEV for value " << *I << " changed!\n" + << "Old: " << *OldSCEV << "\n" + << "New: " << *NewSCEV << "\n" + << "Delta: " << *Delta << "\n"; + std::abort(); + } + } + } - auto It = ValueExprMap.find_as(ValueAndOffset.first); + for (const auto &KV : ExprValueMap) { + for (Value *V : KV.second) { + auto It = ValueExprMap.find_as(V); if (It == ValueExprMap.end()) { - dbgs() << "Value " << *ValueAndOffset.first + dbgs() << "Value " << *V << " is in ExprValueMap but not in ValueExprMap\n"; std::abort(); } if (It->second != KV.first) { - dbgs() << "Value " << *ValueAndOffset.first - << " mapped to " << *It->second + dbgs() << "Value " << *V << " mapped to " << *It->second << " rather than " << *KV.first << "\n"; std::abort(); } @@ -13537,18 +13955,25 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS, const SCEV *RHS) { + return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS); +} + +const SCEVPredicate * +ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { FoldingSetNodeID ID; assert(LHS->getType() == RHS->getType() && "Type mismatch between LHS and RHS"); // Unique this node based on the arguments - ID.AddInteger(SCEVPredicate::P_Equal); + ID.AddInteger(SCEVPredicate::P_Compare); + ID.AddInteger(Pred); ID.AddPointer(LHS); ID.AddPointer(RHS); void *IP = nullptr; if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) return S; - SCEVEqualPredicate *Eq = new (SCEVAllocator) - SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); + SCEVComparePredicate *Eq = new (SCEVAllocator) + SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS); UniquePreds.InsertNode(Eq, IP); return Eq; } @@ -13585,18 +14010,24 @@ public: /// \p NewPreds such that the result will be an AddRecExpr. 
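The two verifier loops above cross-check ValueExprMap against ExprValueMap in both directions. The same invariant on a pair of toy maps, as a runnable sketch (std::map/std::set standing in for the LLVM containers):

#include <cassert>
#include <map>
#include <set>
#include <string>

// Forward map: value -> expression; reverse index: expression -> values.
// Invariant mirrored from the verifier: every forward entry must appear in
// the reverse index, and every reverse entry must map back to the same
// expression.
using Fwd = std::map<std::string, int>;
using Rev = std::map<int, std::set<std::string>>;

bool consistent(const Fwd &ValueToExpr, const Rev &ExprToValues) {
  for (const auto &[V, E] : ValueToExpr) {
    auto It = ExprToValues.find(E);
    if (It == ExprToValues.end() || !It->second.count(V))
      return false; // in forward map but missing from reverse index
  }
  for (const auto &[E, Vs] : ExprToValues)
    for (const auto &V : Vs) {
      auto It = ValueToExpr.find(V);
      if (It == ValueToExpr.end() || It->second != E)
        return false; // reverse entry dangling or pointing elsewhere
    }
  return true;
}

int main() {
  Fwd F{{"a", 1}, {"b", 1}};
  Rev R{{1, {"a", "b"}}};
  assert(consistent(F, R));
  R[1].insert("c"); // dangling reverse entry, as the verifier would report
  assert(!consistent(F, R));
}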
static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, - SCEVUnionPredicate *Pred) { + const SCEVPredicate *Pred) { SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred); return Rewriter.visit(S); } const SCEV *visitUnknown(const SCEVUnknown *Expr) { if (Pred) { - auto ExprPreds = Pred->getPredicatesForExpr(Expr); - for (auto *Pred : ExprPreds) - if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) - if (IPred->getLHS() == Expr) - return IPred->getRHS(); + if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) { + for (auto *Pred : U->getPredicates()) + if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) + if (IPred->getLHS() == Expr && + IPred->getPredicate() == ICmpInst::ICMP_EQ) + return IPred->getRHS(); + } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) { + if (IPred->getLHS() == Expr && + IPred->getPredicate() == ICmpInst::ICMP_EQ) + return IPred->getRHS(); + } } return convertToAddRecWithPreds(Expr); } @@ -13636,7 +14067,7 @@ public: private: explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, - SCEVUnionPredicate *Pred) + const SCEVPredicate *Pred) : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {} bool addOverflowAssumption(const SCEVPredicate *P) { @@ -13670,8 +14101,7 @@ private: for (auto *P : PredicatedRewrite->second){ // Wrap predicates from outer loops are not supported. if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) { - auto *AR = cast<const SCEVAddRecExpr>(WP->getExpr()); - if (L != AR->getLoop()) + if (L != WP->getExpr()->getLoop()) return Expr; } if (!addOverflowAssumption(P)) @@ -13681,14 +14111,15 @@ private: } SmallPtrSetImpl<const SCEVPredicate *> *NewPreds; - SCEVUnionPredicate *Pred; + const SCEVPredicate *Pred; const Loop *L; }; } // end anonymous namespace -const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, - SCEVUnionPredicate &Preds) { +const SCEV * +ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, + const SCEVPredicate &Preds) { return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds); } @@ -13715,28 +14146,36 @@ SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind) : FastID(ID), Kind(Kind) {} -SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, - const SCEV *LHS, const SCEV *RHS) - : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) { +SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID, + const ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) + : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) { assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match"); assert(LHS != RHS && "LHS and RHS are the same SCEV"); } -bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { - const auto *Op = dyn_cast<SCEVEqualPredicate>(N); +bool SCEVComparePredicate::implies(const SCEVPredicate *N) const { + const auto *Op = dyn_cast<SCEVComparePredicate>(N); if (!Op) return false; + if (Pred != ICmpInst::ICMP_EQ) + return false; + return Op->LHS == LHS && Op->RHS == RHS; } -bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } +bool SCEVComparePredicate::isAlwaysTrue() const { return false; } -const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } +void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const { + if (Pred == ICmpInst::ICMP_EQ) + OS.indent(Depth) << "Equal predicate: " << 
*LHS << " == " << *RHS << "\n"; + else + OS.indent(Depth) << "Compare predicate: " << *LHS + << " " << CmpInst::getPredicateName(Pred) << ") " + << *RHS << "\n"; -void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { - OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; } SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, @@ -13744,7 +14183,7 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, IncrementWrapFlags Flags) : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {} -const SCEV *SCEVWrapPredicate::getExpr() const { return AR; } +const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; } bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { const auto *Op = dyn_cast<SCEVWrapPredicate>(N); @@ -13793,38 +14232,26 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, } /// Union predicates don't get cached so create a dummy set ID for it. -SCEVUnionPredicate::SCEVUnionPredicate() - : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} +SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds) + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) { + for (auto *P : Preds) + add(P); +} bool SCEVUnionPredicate::isAlwaysTrue() const { return all_of(Preds, [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); } -ArrayRef<const SCEVPredicate *> -SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { - auto I = SCEVToPreds.find(Expr); - if (I == SCEVToPreds.end()) - return ArrayRef<const SCEVPredicate *>(); - return I->second; -} - bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) return all_of(Set->Preds, [this](const SCEVPredicate *I) { return this->implies(I); }); - auto ScevPredsIt = SCEVToPreds.find(N->getExpr()); - if (ScevPredsIt == SCEVToPreds.end()) - return false; - auto &SCEVPreds = ScevPredsIt->second; - - return any_of(SCEVPreds, + return any_of(Preds, [N](const SCEVPredicate *I) { return I->implies(N); }); } -const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; } - void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { for (auto Pred : Preds) Pred->print(OS, Depth); @@ -13837,20 +14264,15 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) { return; } - if (implies(N)) - return; - - const SCEV *Key = N->getExpr(); - assert(Key && "Only SCEVUnionPredicate doesn't have an " - " associated expression!"); - - SCEVToPreds[Key].push_back(N); Preds.push_back(N); } PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L) - : SE(SE), L(L) {} + : SE(SE), L(L) { + SmallVector<const SCEVPredicate*, 4> Empty; + Preds = std::make_unique<SCEVUnionPredicate>(Empty); +} void ScalarEvolution::registerUser(const SCEV *User, ArrayRef<const SCEV *> Ops) { @@ -13875,7 +14297,7 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { if (Entry.second) Expr = Entry.second; - const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds); + const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds); Entry = {Generation, NewSCEV}; return NewSCEV; @@ -13883,22 +14305,27 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() { if (!BackedgeCount) { - SCEVUnionPredicate BackedgePred; - BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred); - addPredicate(BackedgePred); + SmallVector<const SCEVPredicate *, 4> Preds; + BackedgeCount = 
SE.getPredicatedBackedgeTakenCount(&L, Preds); + for (auto *P : Preds) + addPredicate(*P); } return BackedgeCount; } void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { - if (Preds.implies(&Pred)) + if (Preds->implies(&Pred)) return; - Preds.add(&Pred); + + auto &OldPreds = Preds->getPredicates(); + SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end()); + NewPreds.push_back(&Pred); + Preds = std::make_unique<SCEVUnionPredicate>(NewPreds); updateGeneration(); } -const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const { - return Preds; +const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const { + return *Preds; } void PredicatedScalarEvolution::updateGeneration() { @@ -13906,7 +14333,7 @@ void PredicatedScalarEvolution::updateGeneration() { if (++Generation == 0) { for (auto &II : RewriteMap) { const SCEV *Rewritten = II.second.second; - II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)}; + II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)}; } } } @@ -13952,17 +14379,17 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { return nullptr; for (auto *P : NewPreds) - Preds.add(P); + addPredicate(*P); - updateGeneration(); RewriteMap[SE.getSCEV(V)] = {Generation, New}; return New; } PredicatedScalarEvolution::PredicatedScalarEvolution( const PredicatedScalarEvolution &Init) - : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds), - Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { + : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), + Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())), + Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { for (auto I : Init.FlagsMap) FlagsMap.insert(I); } @@ -14243,12 +14670,23 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { ExprsToRewrite.push_back(LHS); } }; - // First, collect conditions from dominating branches. Starting at the loop + + SmallVector<std::pair<Value *, bool>> Terms; + // First, collect information from assumptions dominating the loop. + for (auto &AssumeVH : AC.assumptions()) { + if (!AssumeVH) + continue; + auto *AssumeI = cast<CallInst>(AssumeVH); + if (!DT.dominates(AssumeI, L->getHeader())) + continue; + Terms.emplace_back(AssumeI->getOperand(0), true); + } + + // Second, collect conditions from dominating branches. Starting at the loop // predecessor, climb up the predecessor chain, as long as there are // predecessors that can be found that have unique successors leading to the // original header. // TODO: share this logic with isLoopEntryGuardedByCond. - SmallVector<std::pair<Value *, bool>> Terms; for (std::pair<const BasicBlock *, const BasicBlock *> Pair( L->getLoopPredecessor(), L->getHeader()); Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { @@ -14280,8 +14718,9 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) { auto Predicate = EnterIfTrue ? 
Cmp->getPredicate() : Cmp->getInversePredicate(); - CollectCondition(Predicate, getSCEV(Cmp->getOperand(0)), - getSCEV(Cmp->getOperand(1)), RewriteMap); + const auto *LHS = getSCEV(Cmp->getOperand(0)); + const auto *RHS = getSCEV(Cmp->getOperand(1)); + CollectCondition(Predicate, LHS, RHS, RewriteMap); continue; } @@ -14294,18 +14733,6 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { } } - // Also collect information from assumptions dominating the loop. - for (auto &AssumeVH : AC.assumptions()) { - if (!AssumeVH) - continue; - auto *AssumeI = cast<CallInst>(AssumeVH); - auto *Cmp = dyn_cast<ICmpInst>(AssumeI->getOperand(0)); - if (!Cmp || !DT.dominates(AssumeI, L->getHeader())) - continue; - CollectCondition(Cmp->getPredicate(), getSCEV(Cmp->getOperand(0)), - getSCEV(Cmp->getOperand(1)), RewriteMap); - } - if (RewriteMap.empty()) return Expr; diff --git a/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp b/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp index f4fa159d1ec7..3d47dc6b30df 100644 --- a/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/InitializePasses.h" using namespace llvm; diff --git a/llvm/lib/Analysis/ScalarEvolutionDivision.cpp b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp index 64e908bdf342..0619569bf816 100644 --- a/llvm/lib/Analysis/ScalarEvolutionDivision.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp @@ -15,9 +15,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/IR/Constants.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/ErrorHandling.h" #include <cassert> #include <cstdint> diff --git a/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp index 209ae66ca53e..22dff5efec5c 100644 --- a/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp @@ -13,6 +13,7 @@ #include "llvm/Analysis/ScalarEvolutionNormalization.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" using namespace llvm; diff --git a/llvm/lib/Analysis/ScopedNoAliasAA.cpp b/llvm/lib/Analysis/ScopedNoAliasAA.cpp index e847bf8f0f6b..f510991b4463 100644 --- a/llvm/lib/Analysis/ScopedNoAliasAA.cpp +++ b/llvm/lib/Analysis/ScopedNoAliasAA.cpp @@ -36,7 +36,6 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/InitializePasses.h" diff --git a/llvm/lib/Analysis/StackLifetime.cpp b/llvm/lib/Analysis/StackLifetime.cpp index 9056cc01484d..52e8566aca3c 100644 --- a/llvm/lib/Analysis/StackLifetime.cpp +++ b/llvm/lib/Analysis/StackLifetime.cpp @@ -19,17 +19,12 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormattedStream.h" #include <algorithm> -#include 
<memory> #include <tuple> using namespace llvm; @@ -75,7 +70,7 @@ static const AllocaInst *findMatchingAlloca(const IntrinsicInst &II, auto AllocaSizeInBits = AI->getAllocationSizeInBits(DL); if (!AllocaSizeInBits) return nullptr; - int64_t AllocaSize = AllocaSizeInBits.getValue() / 8; + int64_t AllocaSize = *AllocaSizeInBits / 8; auto *Size = dyn_cast<ConstantInt>(II.getArgOperand(0)); if (!Size) diff --git a/llvm/lib/Analysis/StackSafetyAnalysis.cpp b/llvm/lib/Analysis/StackSafetyAnalysis.cpp index 54f3605ee033..94b646ab7c06 100644 --- a/llvm/lib/Analysis/StackSafetyAnalysis.cpp +++ b/llvm/lib/Analysis/StackSafetyAnalysis.cpp @@ -15,7 +15,6 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ModuleSummaryAnalysis.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/StackLifetime.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DerivedTypes.h" @@ -384,9 +383,9 @@ bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI, const SCEV *Max = SE.getMinusSCEV(ToDiffTy(SE.getConstant(Size.getUpper())), ToDiffTy(AccessSize)); return SE.evaluatePredicateAt(ICmpInst::Predicate::ICMP_SGE, Diff, Min, I) - .getValueOr(false) && + .value_or(false) && SE.evaluatePredicateAt(ICmpInst::Predicate::ICMP_SLE, Diff, Max, I) - .getValueOr(false); + .value_or(false); } /// The function analyzes all local uses of Ptr (alloca or argument) and diff --git a/llvm/lib/Analysis/StratifiedSets.h b/llvm/lib/Analysis/StratifiedSets.h index 60ea2451b0ef..883ebd24efdc 100644 --- a/llvm/lib/Analysis/StratifiedSets.h +++ b/llvm/lib/Analysis/StratifiedSets.h @@ -340,10 +340,10 @@ public: return StratifiedSets<T>(std::move(Values), std::move(StratLinks)); } - bool has(const T &Elem) const { return get(Elem).hasValue(); } + bool has(const T &Elem) const { return get(Elem).has_value(); } bool add(const T &Main) { - if (get(Main).hasValue()) + if (get(Main)) return false; auto NewIndex = getNewUnlinkedIndex(); @@ -560,7 +560,7 @@ private: Optional<StratifiedIndex> indexOf(const T &Val) { auto MaybeVal = get(Val); - if (!MaybeVal.hasValue()) + if (!MaybeVal) return None; auto *Info = *MaybeVal; auto &Link = linksAt(Info->Index); diff --git a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp index ff833b55bbce..3446e50a4344 100644 --- a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp +++ b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp @@ -116,18 +116,16 @@ // around from the latch. 
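The isSafeAccess change above folds two Optional<bool> proofs with value_or(false): an access counts as safe only when both bounds are positively proven, and "could not decide" decays to unsafe. A tiny runnable sketch of that conservative combination (toy code, not the ScalarEvolution API):

#include <cassert>
#include <optional>

// Each inequality may be provably true, provably false, or unknown
// (nullopt); only two positive proofs make the access safe.
bool isSafe(std::optional<bool> GeLowerBound,
            std::optional<bool> LeUpperBound) {
  return GeLowerBound.value_or(false) && LeUpperBound.value_or(false);
}

int main() {
  assert(isSafe(true, true));          // both bounds proven
  assert(!isSafe(true, std::nullopt)); // upper bound unprovable -> unsafe
  assert(!isSafe(false, true));        // lower bound provably violated
}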
// //===----------------------------------------------------------------------===// + #include "llvm/Analysis/SyncDependenceAnalysis.h" -#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include <functional> -#include <stack> -#include <unordered_set> #define DEBUG_TYPE "sync-dependence" @@ -257,7 +255,7 @@ SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); } -SyncDependenceAnalysis::~SyncDependenceAnalysis() {} +SyncDependenceAnalysis::~SyncDependenceAnalysis() = default; // divergence propagator for reducible CFGs struct DivergencePropagator { diff --git a/llvm/lib/Analysis/SyntheticCountsUtils.cpp b/llvm/lib/Analysis/SyntheticCountsUtils.cpp index a3edce76cd88..29c41fda5e28 100644 --- a/llvm/lib/Analysis/SyntheticCountsUtils.cpp +++ b/llvm/lib/Analysis/SyntheticCountsUtils.cpp @@ -14,9 +14,6 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/Analysis/CallGraph.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/ModuleSummaryIndex.h" using namespace llvm; @@ -57,7 +54,7 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( if (!OptProfCount) continue; auto Callee = CGT::edge_dest(E.second); - AdditionalCounts[Callee] += OptProfCount.getValue(); + AdditionalCounts[Callee] += *OptProfCount; } // Update the counts for the nodes in the SCC. @@ -70,7 +67,7 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( if (!OptProfCount) continue; auto Callee = CGT::edge_dest(E.second); - AddCount(Callee, OptProfCount.getValue()); + AddCount(Callee, *OptProfCount); } } diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp index 26bc63983b4e..203858c1cf06 100644 --- a/llvm/lib/Analysis/TFUtils.cpp +++ b/llvm/lib/Analysis/TFUtils.cpp @@ -82,6 +82,33 @@ void serialize(const Message &SE, std::string *OutStr) { *OutStr = SE.SerializeAsString(); } } + +int getTFTypeIndex(TensorType TType) { + switch (TType) { + case TensorType::Double: + return TF_DOUBLE; + case TensorType::Float: + return TF_FLOAT; + case TensorType::Int8: + return TF_INT8; + case TensorType::UInt8: + return TF_UINT8; + case TensorType::Int16: + return TF_INT16; + case TensorType::UInt16: + return TF_UINT16; + case TensorType::Int32: + return TF_INT32; + case TensorType::UInt32: + return TF_UINT32; + case TensorType::Int64: + return TF_INT64; + case TensorType::UInt64: + return TF_UINT64; + case TensorType::Invalid: + llvm_unreachable("Unknown tensor type"); + } +} } // namespace namespace llvm { @@ -105,116 +132,6 @@ private: std::vector<TF_Tensor *> Output; }; -size_t TensorSpec::getElementByteSize() const { - return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex)); -} - -TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex, - const std::vector<int64_t> &Shape) - : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape), - ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1, - std::multiplies<int64_t>())) {} - -Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, - const json::Value &Value) { - auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> { - std::string S; - llvm::raw_string_ostream OS(S); - OS << Value; - Ctx.emitError("Unable to parse 
JSON Value as spec (" + Message + "): " + S); - return None; - }; - // FIXME: accept a Path as a parameter, and use it for error reporting. - json::Path::Root Root("tensor_spec"); - json::ObjectMapper Mapper(Value, Root); - if (!Mapper) - return EmitError("Value is not a dict"); - - std::string TensorName; - int TensorPort = -1; - std::string TensorType; - std::vector<int64_t> TensorShape; - - if (!Mapper.map<std::string>("name", TensorName)) - return EmitError("'name' property not present or not a string"); - if (!Mapper.map<std::string>("type", TensorType)) - return EmitError("'type' property not present or not a string"); - if (!Mapper.map<int>("port", TensorPort)) - return EmitError("'port' property not present or not an int"); - if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape)) - return EmitError("'shape' property not present or not an int array"); - -#define PARSE_TYPE(T, E) \ - if (TensorType == #T) \ - return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort); - TFUTILS_SUPPORTED_TYPES(PARSE_TYPE) -#undef PARSE_TYPE - return None; -} - -Optional<std::vector<LoggedFeatureSpec>> -loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName, - StringRef ModelPath, StringRef SpecFileOverride) { - SmallVector<char, 128> OutputSpecsPath; - StringRef FileName = SpecFileOverride; - if (FileName.empty()) { - llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json"); - FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()}; - } - - auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName); - if (!BufferOrError) { - Ctx.emitError("Error opening output specs file: " + FileName + " : " + - BufferOrError.getError().message()); - return None; - } - auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer()); - if (!ParsedJSONValues) { - Ctx.emitError("Could not parse specs file: " + FileName); - return None; - } - auto ValuesArray = ParsedJSONValues->getAsArray(); - if (!ValuesArray) { - Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, " - "logging_name:<name>} dictionaries"); - return None; - } - std::vector<LoggedFeatureSpec> Ret; - for (const auto &Value : *ValuesArray) - if (const auto *Obj = Value.getAsObject()) - if (const auto *SpecPart = Obj->get("tensor_spec")) - if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart)) - if (auto LoggingName = Obj->getString("logging_name")) { - if (!TensorSpec->isElementType<int64_t>() && - !TensorSpec->isElementType<int32_t>() && - !TensorSpec->isElementType<float>()) { - Ctx.emitError( - "Only int64, int32, and float tensors are supported. " - "Found unsupported type for tensor named " + - TensorSpec->name()); - return None; - } - Ret.push_back({*TensorSpec, LoggingName->str()}); - } - - if (ValuesArray->size() != Ret.size()) { - Ctx.emitError( - "Unable to parse output spec. It should be a json file containing an " - "array of dictionaries. 
Each dictionary must have a 'tensor_spec' key, " - "with a json object describing a TensorSpec; and a 'logging_name' key, " - "which is a string to use as name when logging this tensor in the " - "training log."); - return None; - } - if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) { - Ctx.emitError("The first output spec must describe the decision tensor, " - "and must have the logging_name " + - StringRef(ExpectedDecisionName)); - return None; - } - return Ret; -} - class TFModelEvaluatorImpl { public: TFModelEvaluatorImpl(StringRef SavedModelPath, @@ -383,16 +300,29 @@ TFModelEvaluatorImpl::TFModelEvaluatorImpl( errs() << TF_Message(Status.get()); invalidate(); } + size_t NrSupported = 0; for (size_t I = 0; I < InputSpecs.size(); ++I) { auto &InputSpec = InputSpecs[I]; InputFeed[I] = { TF_GraphOperationByName(Graph.get(), (InputSpec.name()).c_str()), InputSpec.port()}; + if (!InputFeed[I].oper) { + continue; + } + if (NrSupported++ != I) { + errs() + << "Unsupported features must be placed at the end of the InputSpecs"; + invalidate(); + return; + } if (!checkReportAndInvalidate(InputFeed[I], InputSpec)) return; - initInput(I, static_cast<TF_DataType>(InputSpec.typeIndex()), + initInput(I, static_cast<TF_DataType>(getTFTypeIndex(InputSpec.type())), InputSpec.shape()); } + InputFeed.resize(NrSupported); + Input.resize(NrSupported); + for (size_t I = 0; I < OutputSpecsSize; ++I) { auto OutputSpec = GetOutputSpecs(I); OutputFeed[I] = { @@ -470,7 +400,9 @@ void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type, } void *TFModelEvaluator::getUntypedInput(size_t Index) { - return TF_TensorData(Impl->getInput()[Index]); + if (Index < Impl->getInput().size()) + return TF_TensorData(Impl->getInput()[Index]); + return nullptr; } TFModelEvaluator::EvaluationResult::EvaluationResult( @@ -495,13 +427,6 @@ TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const { return TF_TensorData(Impl->getOutput()[Index]); } -#define TFUTILS_GETDATATYPE_IMPL(T, E) \ - template <> int TensorSpec::getDataType<T>() { return E; } - -TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL) - -#undef TFUTILS_GETDATATYPE_IMPL - TFModelEvaluator::EvaluationResult::~EvaluationResult() {} TFModelEvaluator::~TFModelEvaluator() {} diff --git a/llvm/lib/Analysis/TargetLibraryInfo.cpp b/llvm/lib/Analysis/TargetLibraryInfo.cpp index 02923c2c7eb1..8ebdb65e88dc 100644 --- a/llvm/lib/Analysis/TargetLibraryInfo.cpp +++ b/llvm/lib/Analysis/TargetLibraryInfo.cpp @@ -659,12 +659,12 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, TLI.setUnavailable(LibFunc_stpncpy); } - if (T.isPS4()) { - // PS4 does have memalign. + if (T.isPS()) { + // PS4/PS5 do have memalign. TLI.setAvailable(LibFunc_memalign); - // PS4 does not have new/delete with "unsigned int" size parameter; - // it only has the "unsigned long" versions. + // PS4/PS5 do not have new/delete with "unsigned int" size parameter; + // they only have the "unsigned long" versions. 
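For a concrete picture of what the loadOutputSpecs parser above accepts: a hypothetical output_spec.json with the decision tensor first, using the field names the parser maps ("name", "port", "type", "shape", "logging_name"). The tensor names, shapes, and the expected decision name here are invented for illustration; only the structure is taken from the code.

[
  {
    "logging_name": "inlining_decision",
    "tensor_spec": {
      "name": "StatefulPartitionedCall",
      "port": 0,
      "type": "int64_t",
      "shape": [1]
    }
  },
  {
    "logging_name": "reward",
    "tensor_spec": {
      "name": "reward",
      "port": 0,
      "type": "float",
      "shape": [1]
    }
  }
]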
TLI.setUnavailable(LibFunc_ZdaPvj); TLI.setUnavailable(LibFunc_ZdaPvjSt11align_val_t); TLI.setUnavailable(LibFunc_ZdlPvj); @@ -1110,9 +1110,11 @@ bool TargetLibraryInfoImpl::isValidProtoForLibFunc(const FunctionType &FTy, case LibFunc_system: return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); case LibFunc___kmpc_alloc_shared: + return NumParams == 1 && FTy.getReturnType()->isPointerTy(); case LibFunc_malloc: case LibFunc_vec_malloc: - return (NumParams == 1 && FTy.getReturnType()->isPointerTy()); + return NumParams == 1 && FTy.getParamType(0)->isIntegerTy(SizeTBits) && + FTy.getReturnType()->isPointerTy(); case LibFunc_memcmp: return NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && FTy.getParamType(0)->isPointerTy() && diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 25e9dee98e13..66f61961d01b 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -11,7 +11,6 @@ #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/TargetTransformInfoImpl.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" @@ -21,7 +20,6 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/ErrorHandling.h" #include <utility> using namespace llvm; @@ -33,6 +31,11 @@ static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false), cl::Hidden, cl::desc("Recognize reduction patterns.")); +static cl::opt<unsigned> CacheLineSize( + "cache-line-size", cl::init(0), cl::Hidden, + cl::desc("Use this to override the target cache line size when " + "specified by the user.")); + namespace { /// No-op implementation of the TTI interface using the utility base /// classes. 
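The malloc change above tightens prototype validation: the declaration must now take exactly one integer argument of the target's size_t width, besides returning a pointer. A standalone sketch of that shape check with a toy prototype descriptor (not llvm::FunctionType):

#include <cassert>
#include <vector>

// "malloc" is only recognized as the known library function when it
// returns a pointer and takes a single integer exactly as wide as size_t.
struct Proto {
  bool ReturnsPointer;
  std::vector<unsigned> ParamBits; // integer widths of the parameters
};

bool isValidMallocProto(const Proto &P, unsigned SizeTBits) {
  return P.ReturnsPointer && P.ParamBits.size() == 1 &&
         P.ParamBits[0] == SizeTBits;
}

int main() {
  // On a typical 64-bit target, size_t is 64 bits wide.
  assert(isValidMallocProto({true, {64}}, 64));   // void *malloc(size_t)
  assert(!isValidMallocProto({true, {32}}, 64));  // wrong size type
  assert(!isValidMallocProto({false, {64}}, 64)); // does not return a pointer
}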
@@ -179,7 +182,7 @@ bool HardwareLoopInfo::isHardwareLoopCandidate(ScalarEvolution &SE, TargetTransformInfo::TargetTransformInfo(const DataLayout &DL) : TTIImpl(new Model<NoTTIImpl>(NoTTIImpl(DL))) {} -TargetTransformInfo::~TargetTransformInfo() {} +TargetTransformInfo::~TargetTransformInfo() = default; TargetTransformInfo::TargetTransformInfo(TargetTransformInfo &&Arg) : TTIImpl(std::move(Arg.TTIImpl)) {} @@ -350,7 +353,8 @@ bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, Scale, AddrSpace, I); } -bool TargetTransformInfo::isLSRCostLess(LSRCost &C1, LSRCost &C2) const { +bool TargetTransformInfo::isLSRCostLess(const LSRCost &C1, + const LSRCost &C2) const { return TTIImpl->isLSRCostLess(C1, C2); } @@ -398,11 +402,22 @@ bool TargetTransformInfo::isLegalNTLoad(Type *DataType, Align Alignment) const { return TTIImpl->isLegalNTLoad(DataType, Alignment); } +bool TargetTransformInfo::isLegalBroadcastLoad(Type *ElementTy, + ElementCount NumElements) const { + return TTIImpl->isLegalBroadcastLoad(ElementTy, NumElements); +} + bool TargetTransformInfo::isLegalMaskedGather(Type *DataType, Align Alignment) const { return TTIImpl->isLegalMaskedGather(DataType, Alignment); } +bool TargetTransformInfo::isLegalAltInstr( + VectorType *VecTy, unsigned Opcode0, unsigned Opcode1, + const SmallBitVector &OpcodeMask) const { + return TTIImpl->isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask); +} + bool TargetTransformInfo::isLegalMaskedScatter(Type *DataType, Align Alignment) const { return TTIImpl->isLegalMaskedScatter(DataType, Alignment); @@ -470,7 +485,7 @@ bool TargetTransformInfo::isTypeLegal(Type *Ty) const { return TTIImpl->isTypeLegal(Ty); } -InstructionCost TargetTransformInfo::getRegUsageForType(Type *Ty) const { +unsigned TargetTransformInfo::getRegUsageForType(Type *Ty) const { return TTIImpl->getRegUsageForType(Ty); } @@ -507,6 +522,10 @@ bool TargetTransformInfo::supportsEfficientVectorElementLoadStore() const { return TTIImpl->supportsEfficientVectorElementLoadStore(); } +bool TargetTransformInfo::supportsTailCalls() const { + return TTIImpl->supportsTailCalls(); +} + bool TargetTransformInfo::enableAggressiveInterleaving( bool LoopHasReductions) const { return TTIImpl->enableAggressiveInterleaving(LoopHasReductions); @@ -623,8 +642,9 @@ Optional<unsigned> TargetTransformInfo::getVScaleForTuning() const { return TTIImpl->getVScaleForTuning(); } -bool TargetTransformInfo::shouldMaximizeVectorBandwidth() const { - return TTIImpl->shouldMaximizeVectorBandwidth(); +bool TargetTransformInfo::shouldMaximizeVectorBandwidth( + TargetTransformInfo::RegisterKind K) const { + return TTIImpl->shouldMaximizeVectorBandwidth(K); } ElementCount TargetTransformInfo::getMinimumVF(unsigned ElemWidth, @@ -637,6 +657,11 @@ unsigned TargetTransformInfo::getMaximumVF(unsigned ElemWidth, return TTIImpl->getMaximumVF(ElemWidth, Opcode); } +unsigned TargetTransformInfo::getStoreMinimumVF(unsigned VF, Type *ScalarMemTy, + Type *ScalarValTy) const { + return TTIImpl->getStoreMinimumVF(VF, ScalarMemTy, ScalarValTy); +} + bool TargetTransformInfo::shouldConsiderAddressTypePromotion( const Instruction &I, bool &AllowPromotionWithoutCommonHeader) const { return TTIImpl->shouldConsiderAddressTypePromotion( @@ -644,7 +669,8 @@ bool TargetTransformInfo::shouldConsiderAddressTypePromotion( } unsigned TargetTransformInfo::getCacheLineSize() const { - return TTIImpl->getCacheLineSize(); + return CacheLineSize.getNumOccurrences() > 0 ? 
CacheLineSize + : TTIImpl->getCacheLineSize(); } llvm::Optional<unsigned> @@ -742,12 +768,11 @@ InstructionCost TargetTransformInfo::getArithmeticInstrCost( return Cost; } -InstructionCost TargetTransformInfo::getShuffleCost(ShuffleKind Kind, - VectorType *Ty, - ArrayRef<int> Mask, - int Index, - VectorType *SubTp) const { - InstructionCost Cost = TTIImpl->getShuffleCost(Kind, Ty, Mask, Index, SubTp); +InstructionCost TargetTransformInfo::getShuffleCost( + ShuffleKind Kind, VectorType *Ty, ArrayRef<int> Mask, int Index, + VectorType *SubTp, ArrayRef<const Value *> Args) const { + InstructionCost Cost = + TTIImpl->getShuffleCost(Kind, Ty, Mask, Index, SubTp, Args); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } @@ -973,18 +998,21 @@ Value *TargetTransformInfo::getOrCreateResultFromMemIntrinsic( Type *TargetTransformInfo::getMemcpyLoopLoweringType( LLVMContext &Context, Value *Length, unsigned SrcAddrSpace, - unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const { + unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign, + Optional<uint32_t> AtomicElementSize) const { return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAddrSpace, - DestAddrSpace, SrcAlign, DestAlign); + DestAddrSpace, SrcAlign, DestAlign, + AtomicElementSize); } void TargetTransformInfo::getMemcpyLoopResidualLoweringType( SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context, unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace, - unsigned SrcAlign, unsigned DestAlign) const { - TTIImpl->getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes, - SrcAddrSpace, DestAddrSpace, - SrcAlign, DestAlign); + unsigned SrcAlign, unsigned DestAlign, + Optional<uint32_t> AtomicCpySize) const { + TTIImpl->getMemcpyLoopResidualLoweringType( + OpsOut, Context, RemainingBytes, SrcAddrSpace, DestAddrSpace, SrcAlign, + DestAlign, AtomicCpySize); } bool TargetTransformInfo::areInlineCompatible(const Function *Caller, @@ -1155,7 +1183,7 @@ TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { } } -TargetTransformInfo::Concept::~Concept() {} +TargetTransformInfo::Concept::~Concept() = default; TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {} diff --git a/llvm/lib/Analysis/TensorSpec.cpp b/llvm/lib/Analysis/TensorSpec.cpp new file mode 100644 index 000000000000..f6a5882371a7 --- /dev/null +++ b/llvm/lib/Analysis/TensorSpec.cpp @@ -0,0 +1,144 @@ +//===- TensorSpec.cpp - tensor type abstraction ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation file for the abstraction of a tensor type, and JSON loading +// utils. 
+// +//===----------------------------------------------------------------------===// +#include "llvm/Config/config.h" + +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/TensorSpec.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <numeric> + +using namespace llvm; + +namespace llvm { + +#define TFUTILS_GETDATATYPE_IMPL(T, E) \ + template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; } + +SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) + +#undef TFUTILS_GETDATATYPE_IMPL + +TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type, + size_t ElementSize, const std::vector<int64_t> &Shape) + : Name(Name), Port(Port), Type(Type), Shape(Shape), + ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1, + std::multiplies<int64_t>())), + ElementSize(ElementSize) {} + +Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, + const json::Value &Value) { + auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> { + std::string S; + llvm::raw_string_ostream OS(S); + OS << Value; + Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S); + return None; + }; + // FIXME: accept a Path as a parameter, and use it for error reporting. + json::Path::Root Root("tensor_spec"); + json::ObjectMapper Mapper(Value, Root); + if (!Mapper) + return EmitError("Value is not a dict"); + + std::string TensorName; + int TensorPort = -1; + std::string TensorType; + std::vector<int64_t> TensorShape; + + if (!Mapper.map<std::string>("name", TensorName)) + return EmitError("'name' property not present or not a string"); + if (!Mapper.map<std::string>("type", TensorType)) + return EmitError("'type' property not present or not a string"); + if (!Mapper.map<int>("port", TensorPort)) + return EmitError("'port' property not present or not an int"); + if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape)) + return EmitError("'shape' property not present or not an int array"); + +#define PARSE_TYPE(T, E) \ + if (TensorType == #T) \ + return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort); + SUPPORTED_TENSOR_TYPES(PARSE_TYPE) +#undef PARSE_TYPE + return None; +} + +Optional<std::vector<LoggedFeatureSpec>> +loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName, + StringRef ModelPath, StringRef SpecFileOverride) { + SmallVector<char, 128> OutputSpecsPath; + StringRef FileName = SpecFileOverride; + if (FileName.empty()) { + llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json"); + FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()}; + } + + auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName); + if (!BufferOrError) { + Ctx.emitError("Error opening output specs file: " + FileName + " : " + + BufferOrError.getError().message()); + return None; + } + auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer()); + if (!ParsedJSONValues) { + Ctx.emitError("Could not parse specs file: " + FileName); + return None; + } + auto ValuesArray = ParsedJSONValues->getAsArray(); + if (!ValuesArray) { + Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, " + "logging_name:<name>} dictionaries"); + return None; + } + std::vector<LoggedFeatureSpec> Ret; + for (const auto &Value : *ValuesArray) + if (const auto *Obj = Value.getAsObject()) + if (const 
auto *SpecPart = Obj->get("tensor_spec")) + if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart)) + if (auto LoggingName = Obj->getString("logging_name")) { + if (!TensorSpec->isElementType<int64_t>() && + !TensorSpec->isElementType<int32_t>() && + !TensorSpec->isElementType<float>()) { + Ctx.emitError( + "Only int64, int32, and float tensors are supported. " + "Found unsupported type for tensor named " + + TensorSpec->name()); + return None; + } + Ret.push_back({*TensorSpec, LoggingName->str()}); + } + + if (ValuesArray->size() != Ret.size()) { + Ctx.emitError( + "Unable to parse output spec. It should be a json file containing an " + "array of dictionaries. Each dictionary must have a 'tensor_spec' key, " + "with a json object describing a TensorSpec; and a 'logging_name' key, " + "which is a string to use as name when logging this tensor in the " + "training log."); + return None; + } + if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) { + Ctx.emitError("The first output spec must describe the decision tensor, " + "and must have the logging_name " + + StringRef(ExpectedDecisionName)); + return None; + } + return Ret; +} +} // namespace llvm diff --git a/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp index 627a78a2a2fd..9bcbe4a4cc1e 100644 --- a/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp +++ b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp @@ -112,7 +112,6 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/InitializePasses.h" diff --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp index 80051fd5f7c1..201e64770766 100644 --- a/llvm/lib/Analysis/TypeMetadataUtils.cpp +++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp @@ -16,7 +16,6 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" using namespace llvm; diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp index 7573975a3dd3..e6d297877b62 100644 --- a/llvm/lib/Analysis/VFABIDemangling.cpp +++ b/llvm/lib/Analysis/VFABIDemangling.cpp @@ -6,8 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallString.h" #include "llvm/Analysis/VectorUtils.h" using namespace llvm; diff --git a/llvm/lib/Analysis/ValueLatticeUtils.cpp b/llvm/lib/Analysis/ValueLatticeUtils.cpp index 53638c351f72..2bcb4d5b0e6b 100644 --- a/llvm/lib/Analysis/ValueLatticeUtils.cpp +++ b/llvm/lib/Analysis/ValueLatticeUtils.cpp @@ -29,12 +29,13 @@ bool llvm::canTrackGlobalVariableInterprocedurally(GlobalVariable *GV) { !GV->hasDefinitiveInitializer()) return false; return all_of(GV->users(), [&](User *U) { - // Currently all users of a global variable have to be none-volatile loads - // or stores and the global cannot be stored itself. + // Currently all users of a global variable have to be non-volatile loads + // or stores of the global type, and the global cannot be stored itself. 
if (auto *Store = dyn_cast<StoreInst>(U)) - return Store->getValueOperand() != GV && !Store->isVolatile(); + return Store->getValueOperand() != GV && !Store->isVolatile() && + Store->getValueOperand()->getType() == GV->getValueType(); if (auto *Load = dyn_cast<LoadInst>(U)) - return !Load->isVolatile(); + return !Load->isVolatile() && Load->getType() == GV->getValueType(); return false; }); diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index c14bdb8bc262..05d5e47bb8d7 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -26,6 +26,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -70,10 +71,8 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include <algorithm> -#include <array> #include <cassert> #include <cstdint> -#include <iterator> #include <utility> using namespace llvm; @@ -86,13 +85,12 @@ static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses", // According to the LangRef, branching on a poison condition is absolutely // immediate full UB. However, historically we haven't implemented that -// consistently as we have an important transformation (non-trivial unswitch) -// which introduces instances of branch on poison/undef to otherwise well -// defined programs. This flag exists to let us test optimization benefit -// of exploiting the specified behavior (in combination with enabling the -// unswitch fix.) +// consistently as we had an important transformation (non-trivial unswitch) +// which introduced instances of branch on poison/undef to otherwise well +// defined programs. This issue has since been fixed, but the flag is +// temporarily retained to easily diagnose potential regressions. static cl::opt<bool> BranchOnPoisonAsUB("branch-on-poison-as-ub", - cl::Hidden, cl::init(false)); + cl::Hidden, cl::init(true)); /// Returns the bitwidth of the given scalar or pointer type. For vector types, @@ -275,13 +273,39 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, assert(LHS->getType()->isIntOrIntVectorTy() && "LHS and RHS should be integers"); // Look for an inverted mask: (X & ~M) op (Y & M). - Value *M; - if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) && - match(RHS, m_c_And(m_Specific(M), m_Value()))) + { + Value *M; + if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) && + match(RHS, m_c_And(m_Specific(M), m_Value()))) + return true; + if (match(RHS, m_c_And(m_Not(m_Value(M)), m_Value())) && + match(LHS, m_c_And(m_Specific(M), m_Value()))) + return true; + } + + // X op (Y & ~X) + if (match(RHS, m_c_And(m_Not(m_Specific(LHS)), m_Value())) || + match(LHS, m_c_And(m_Not(m_Specific(RHS)), m_Value()))) return true; - if (match(RHS, m_c_And(m_Not(m_Value(M)), m_Value())) && - match(LHS, m_c_And(m_Specific(M), m_Value()))) + + // X op ((X & Y) ^ Y) -- this is the canonical form of the previous pattern + // for constant Y. 
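The masked patterns above hold because the two operands select complementary bit positions. A runnable demonstration in plain C++ (concrete values chosen arbitrarily; this checks the property the analysis certifies, not the matcher itself):

#include <cassert>
#include <cstdint>

int main() {
  // Inverted-mask pattern: (X & ~M) keeps only bits outside M and (Y & M)
  // keeps only bits inside M, so the operands share no set bits.
  uint32_t X = 0xDEADBEEF, Y = 0x12345678, M = 0x00FF00FF;
  uint32_t A = X & ~M, B = Y & M;
  assert((A & B) == 0);
  // That is the property haveNoCommonBitsSet reports; for instance, it
  // lets an add of the two operands be treated as a disjoint or.
  assert((A + B) == (A | B));
  // The "X op (Y & ~X)" pattern is the same idea with X itself as mask.
  assert((X & (Y & ~X)) == 0);
}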
+ Value *Y; + if (match(RHS, + m_c_Xor(m_c_And(m_Specific(LHS), m_Value(Y)), m_Deferred(Y))) || + match(LHS, m_c_Xor(m_c_And(m_Specific(RHS), m_Value(Y)), m_Deferred(Y)))) return true; + + // Look for: (A & B) op ~(A | B) + { + Value *A, *B; + if (match(LHS, m_And(m_Value(A), m_Value(B))) && + match(RHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) + return true; + if (match(RHS, m_And(m_Value(A), m_Value(B))) && + match(LHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) + return true; + } IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType()); KnownBits LHSKnown(IT->getBitWidth()); KnownBits RHSKnown(IT->getBitWidth()); @@ -451,7 +475,12 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, } } - Known = KnownBits::mul(Known, Known2); + bool SelfMultiply = Op0 == Op1; + // TODO: SelfMultiply can be poison, but not undef. + if (SelfMultiply) + SelfMultiply &= + isGuaranteedNotToBeUndefOrPoison(Op0, Q.AC, Q.CxtI, Q.DT, Depth + 1); + Known = KnownBits::mul(Known, Known2, SelfMultiply); // Only make use of no-wrap flags if we failed to compute the sign bit // directly. This matters if the multiplication always overflows, in @@ -656,7 +685,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, if (V->getType()->isPointerTy()) { if (RetainedKnowledge RK = getKnowledgeValidInContext( V, {Attribute::Alignment}, Q.CxtI, Q.DT, Q.AC)) { - Known.Zero.setLowBits(Log2_64(RK.ArgValue)); + if (isPowerOf2_64(RK.ArgValue)) + Known.Zero.setLowBits(Log2_64(RK.ArgValue)); } } @@ -1041,7 +1071,7 @@ static void computeKnownBitsFromShiftOperator( // bits. This check is sunk down as far as possible to avoid the expensive // call to isKnownNonZero if the cheaper checks above fail. if (ShiftAmt == 0) { - if (!ShifterOperandIsNonZero.hasValue()) + if (!ShifterOperandIsNonZero) ShifterOperandIsNonZero = isKnownNonZero(I->getOperand(1), DemandedElts, Depth + 1, Q); if (*ShifterOperandIsNonZero) @@ -1726,8 +1756,7 @@ static void computeKnownBitsFromOperator(const Operator *I, break; } - unsigned FirstZeroHighBit = - 32 - countLeadingZeros(VScaleMax.getValue()); + unsigned FirstZeroHighBit = 32 - countLeadingZeros(*VScaleMax); if (FirstZeroHighBit < BitWidth) Known.Zero.setBitsFrom(FirstZeroHighBit); @@ -2007,6 +2036,63 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts, assert((Known.Zero & Known.One) == 0 && "Bits known to be one AND zero?"); } +/// Try to detect a recurrence that the value of the induction variable is +/// always a power of two (or zero). +static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero, + unsigned Depth, Query &Q) { + BinaryOperator *BO = nullptr; + Value *Start = nullptr, *Step = nullptr; + if (!matchSimpleRecurrence(PN, BO, Start, Step)) + return false; + + // Initial value must be a power of two. + for (const Use &U : PN->operands()) { + if (U.get() == Start) { + // Initial value comes from a different BB, need to adjust context + // instruction for analysis. + Q.CxtI = PN->getIncomingBlock(U)->getTerminator(); + if (!isKnownToBeAPowerOfTwo(Start, OrZero, Depth, Q)) + return false; + } + } + + // Except for Mul, the induction variable must be on the left side of the + // increment expression, otherwise its value can be arbitrary. + if (BO->getOpcode() != Instruction::Mul && BO->getOperand(1) != Step) + return false; + + Q.CxtI = BO->getParent()->getTerminator(); + switch (BO->getOpcode()) { + case Instruction::Mul: + // Power of two is closed under multiplication. 
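The (A & B) versus ~(A | B) case added above is justified by De Morgan: ~(A | B) equals ~A & ~B, so any bit set there is clear in both A and B, and hence clear in A & B. A quick exhaustive check over 8-bit values (standalone demonstration of the identity, not LLVM code):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B) {
      unsigned L = A & B;
      unsigned R = ~(A | B) & 0xFF; // stay within 8 bits
      assert((L & R) == 0);         // no common bits set
      assert((L + R) == (L | R));   // so add and or coincide
    }
}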
+ return (OrZero || Q.IIQ.hasNoUnsignedWrap(BO) || + Q.IIQ.hasNoSignedWrap(BO)) && + isKnownToBeAPowerOfTwo(Step, OrZero, Depth, Q); + case Instruction::SDiv: + // Start value must not be signmask for signed division, so simply being a + // power of two is not sufficient, and it has to be a constant. + if (!match(Start, m_Power2()) || match(Start, m_SignMask())) + return false; + LLVM_FALLTHROUGH; + case Instruction::UDiv: + // Divisor must be a power of two. + // If OrZero is false, cannot guarantee induction variable is non-zero after + // division, same for Shr, unless it is exact division. + return (OrZero || Q.IIQ.isExact(BO)) && + isKnownToBeAPowerOfTwo(Step, false, Depth, Q); + case Instruction::Shl: + return OrZero || Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO); + case Instruction::AShr: + if (!match(Start, m_Power2()) || match(Start, m_SignMask())) + return false; + LLVM_FALLTHROUGH; + case Instruction::LShr: + return OrZero || Q.IIQ.isExact(BO); + default: + return false; + } +} + /// Return true if the given value is known to have exactly one /// bit set when defined. For vectors return true if every element is known to /// be a power of two when defined. Supports values with integer or pointer @@ -2098,6 +2184,30 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, } } + // A PHI node is power of two if all incoming values are power of two, or if + // it is an induction variable where in each step its value is a power of two. + if (const PHINode *PN = dyn_cast<PHINode>(V)) { + Query RecQ = Q; + + // Check if it is an induction variable and always power of two. + if (isPowerOfTwoRecurrence(PN, OrZero, Depth, RecQ)) + return true; + + // Recursively check all incoming values. Limit recursion to 2 levels, so + // that search complexity is limited to number of operands^2. + unsigned NewDepth = std::max(Depth, MaxAnalysisRecursionDepth - 1); + return llvm::all_of(PN->operands(), [&](const Use &U) { + // Value is power of 2 if it is coming from PHI node itself by induction. + if (U.get() == PN) + return true; + + // Change the context instruction to the incoming block where it is + // evaluated. + RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator(); + return isKnownToBeAPowerOfTwo(U.get(), OrZero, NewDepth, RecQ); + }); + } + // An exact divide or right shift can only shift off zero bits, so the result // is a power of two only if the first operand is a power of two and not // copying a sign bit (sdiv int_min, 2). 
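isPowerOfTwoRecurrence above recognizes induction variables that start at a power of two and are only ever rescaled by power-of-two steps (shl, mul, or exact/power-of-two division). Two loops of that shape, checked directly; this is a sketch of the property being inferred, not of the SCEV analysis:

#include <cassert>
#include <cstdint>

// Power-of-two-or-zero test matching the OrZero relaxation in the code.
static bool isPow2OrZero(uint32_t V) { return (V & (V - 1)) == 0; }

int main() {
  for (uint32_t X = 1; X != 0; X <<= 1) // start = 1, step: shl by 1
    assert(isPow2OrZero(X));
  for (uint32_t X = 64; X != 0; X /= 4) // udiv by a power-of-two step
    assert(isPow2OrZero(X));
}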
@@ -2588,6 +2698,9 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
       if (isKnownNonZero(Op, Depth, Q) &&
           isGuaranteedNotToBePoison(Op, Q.AC, Q.CxtI, Q.DT, Depth))
         return true;
+  } else if (const auto *II = dyn_cast<IntrinsicInst>(V)) {
+    if (II->getIntrinsicID() == Intrinsic::vscale)
+      return true;
   }

   KnownBits Known(BitWidth);
@@ -2885,6 +2998,24 @@ static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
   return CLow->sle(*CHigh);
 }

+static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
+                                         const APInt *&CLow,
+                                         const APInt *&CHigh) {
+  assert((II->getIntrinsicID() == Intrinsic::smin ||
+          II->getIntrinsicID() == Intrinsic::smax) && "Must be smin/smax");
+
+  Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
+  auto *InnerII = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
+  if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
+      !match(II->getArgOperand(1), m_APInt(CLow)) ||
+      !match(InnerII->getArgOperand(1), m_APInt(CHigh)))
+    return false;
+
+  if (II->getIntrinsicID() == Intrinsic::smin)
+    std::swap(CLow, CHigh);
+  return CLow->sle(*CHigh);
+}
+
 /// For vector constants, loop over the elements and find the constant with the
 /// minimum number of sign bits. Return 0 if the value is not a vector constant
 /// or if any element was not analyzed; otherwise, return the count for the
@@ -3225,6 +3356,12 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
         // Absolute value reduces number of sign bits by at most 1.
         return Tmp - 1;
+      case Intrinsic::smin:
+      case Intrinsic::smax: {
+        const APInt *CLow, *CHigh;
+        if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
+          return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());
+      }
       }
     }
   }
@@ -3358,9 +3495,6 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(const CallBase &CB,
 /// NOTE: Do not check 'nsz' here because that fast-math-flag does not guarantee
 /// that a value is not -0.0. It only guarantees that -0.0 may be treated
 /// the same as +0.0 in floating-point ops.
-///
-/// NOTE: this function will need to be revisited when we support non-default
-/// rounding modes!
 bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI,
                                 unsigned Depth) {
   if (auto *CFP = dyn_cast<ConstantFP>(V))
@@ -3390,9 +3524,21 @@ bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI,
     case Intrinsic::sqrt:
     case Intrinsic::canonicalize:
       return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1);
+    case Intrinsic::experimental_constrained_sqrt: {
+      // NOTE: This rounding mode restriction may be too strict.
+      const auto *CI = cast<ConstrainedFPIntrinsic>(Call);
+      if (CI->getRoundingMode() == RoundingMode::NearestTiesToEven)
+        return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1);
+      else
+        return false;
+    }
     // fabs(x) != -0.0
     case Intrinsic::fabs:
       return true;
+    // sitofp and uitofp turn into +0.0 for zero.
+    case Intrinsic::experimental_constrained_sitofp:
+    case Intrinsic::experimental_constrained_uitofp:
+      return true;
     }
   }

@@ -4032,69 +4178,83 @@ bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP,
   return true;
 }

+// If V refers to an initialized global constant, set Slice either to
+// its initializer if the size of its elements equals ElementSize, or,
+// for ElementSize == 8, to its representation as an array of unsigned
+// char. Return true on success.
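For the smin/smax clamp case added above: once a value is clamped to [CLow, CHigh], its sign-bit count is at least the smaller of the bounds' counts. A standalone check with clamp(X, -128, 127), whose bounds both need only 8 significant bits and so have 25 sign bits at width 32; numSignBits below is a simple reimplementation for the demo, not ComputeNumSignBits.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Counts leading bits equal to the sign bit.
static unsigned numSignBits(int32_t V) {
  uint32_t U = static_cast<uint32_t>(V);
  unsigned N = 1;
  while (N < 32 && ((U >> 31) & 1) == ((U >> (31 - N)) & 1))
    ++N;
  return N;
}

int main() {
  // clamp(X, -128, 127) can be written smax(smin(X, 127), -128).
  assert(numSignBits(-128) == 25 && numSignBits(127) == 25);
  for (int32_t X : {-1000000, -129, -1, 0, 42, 127, 7654321}) {
    int32_t C = std::min(std::max(X, -128), 127);
    assert(numSignBits(C) >= 25); // min of the bounds' sign-bit counts
  }
}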
 bool llvm::getConstantDataArrayInfo(const Value *V,
                                     ConstantDataArraySlice &Slice,
                                     unsigned ElementSize, uint64_t Offset) {
   assert(V);
 
-  // Look through bitcast instructions and geps.
-  V = V->stripPointerCasts();
+  // Drill down into the pointer expression V, ignoring any intervening
+  // casts, and determine the identity of the object it references along
+  // with the cumulative byte offset into it.
+  const GlobalVariable *GV =
+      dyn_cast<GlobalVariable>(getUnderlyingObject(V));
+  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+    // Fail if V is not based on a constant global object.
+    return false;
 
-  // If the value is a GEP instruction or constant expression, treat it as an
-  // offset.
-  if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
-    // The GEP operator should be based on a pointer to string constant, and is
-    // indexing into the string constant.
-    if (!isGEPBasedOnPointerToString(GEP, ElementSize))
-      return false;
+  const DataLayout &DL = GV->getParent()->getDataLayout();
+  APInt Off(DL.getIndexTypeSizeInBits(V->getType()), 0);
 
-    // If the second index isn't a ConstantInt, then this is a variable index
-    // into the array. If this occurs, we can't say anything meaningful about
-    // the string.
-    uint64_t StartIdx = 0;
-    if (const ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(2)))
-      StartIdx = CI->getZExtValue();
-    else
-      return false;
-    return getConstantDataArrayInfo(GEP->getOperand(0), Slice, ElementSize,
-                                    StartIdx + Offset);
-  }
+  if (GV != V->stripAndAccumulateConstantOffsets(DL, Off,
+                                                 /*AllowNonInbounds*/ true))
+    // Fail if a constant offset could not be determined.
+    return false;
 
-  // The GEP instruction, constant or instruction, must reference a global
-  // variable that is a constant and is initialized. The referenced constant
-  // initializer is the array that we'll use for optimization.
-  const GlobalVariable *GV = dyn_cast<GlobalVariable>(V);
-  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+  uint64_t StartIdx = Off.getLimitedValue();
+  if (StartIdx == UINT64_MAX)
+    // Fail if the constant offset is excessive.
     return false;
 
-  const ConstantDataArray *Array;
-  ArrayType *ArrayTy;
+  Offset += StartIdx;
+
+  ConstantDataArray *Array = nullptr;
+  ArrayType *ArrayTy = nullptr;
+
   if (GV->getInitializer()->isNullValue()) {
     Type *GVTy = GV->getValueType();
-    if ( (ArrayTy = dyn_cast<ArrayType>(GVTy)) ) {
-      // A zeroinitializer for the array; there is no ConstantDataArray.
-      Array = nullptr;
-    } else {
-      const DataLayout &DL = GV->getParent()->getDataLayout();
-      uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy).getFixedSize();
-      uint64_t Length = SizeInBytes / (ElementSize / 8);
-      if (Length <= Offset)
-        return false;
+    uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy).getFixedSize();
+    uint64_t Length = SizeInBytes / (ElementSize / 8);
+
+    Slice.Array = nullptr;
+    Slice.Offset = 0;
+    // Return an empty Slice for undersized constants to let callers
+    // transform even undefined library calls into simpler, well-defined
+    // expressions. This is preferable to making the calls, although it
+    // prevents sanitizers from detecting such calls.
+    Slice.Length = Length < Offset ? 0 : Length - Offset;
+    return true;
+  }
 
-      Slice.Array = nullptr;
-      Slice.Offset = 0;
-      Slice.Length = Length - Offset;
-      return true;
+  auto *Init = const_cast<Constant *>(GV->getInitializer());
+  if (auto *ArrayInit = dyn_cast<ConstantDataArray>(Init)) {
+    Type *InitElTy = ArrayInit->getElementType();
+    if (InitElTy->isIntegerTy(ElementSize)) {
+      // If Init is an initializer for an array of the expected type
+      // and size, use it as is.
+      Array = ArrayInit;
+      ArrayTy = ArrayInit->getType();
     }
-  } else {
-    // This must be a ConstantDataArray.
-    Array = dyn_cast<ConstantDataArray>(GV->getInitializer());
-    if (!Array)
+  }
+
+  if (!Array) {
+    if (ElementSize != 8)
+      // TODO: Handle conversions to larger integral types.
       return false;
-    ArrayTy = Array->getType();
+
+    // Otherwise extract the portion of the initializer starting
+    // at Offset as an array of bytes, and reset Offset.
+    Init = ReadByteArrayFromGlobal(GV, Offset);
+    if (!Init)
+      return false;
+
+    Offset = 0;
+    Array = dyn_cast<ConstantDataArray>(Init);
+    ArrayTy = dyn_cast<ArrayType>(Init->getType());
   }
-  if (!ArrayTy->getElementType()->isIntegerTy(ElementSize))
-    return false;
 
   uint64_t NumElts = ArrayTy->getArrayNumElements();
   if (Offset > NumElts)
@@ -4117,6 +4277,12 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
 
   if (Slice.Array == nullptr) {
     if (TrimAtNul) {
+      // Return a nul-terminated string even for an empty Slice. This is
+      // safe because all existing SimplifyLibcalls callers require string
+      // arguments and the behavior of the functions they fold is undefined
+      // otherwise. Folding the calls this way is preferable to making
+      // the undefined library calls, even though it prevents sanitizers
+      // from reporting such calls.
       Str = StringRef();
       return true;
     }
@@ -4196,9 +4362,13 @@ static uint64_t GetStringLengthH(const Value *V,
     return 0;
 
   if (Slice.Array == nullptr)
+    // Zeroinitializer (including an empty one).
     return 1;
 
-  // Search for nul characters
+  // Search for the first nul character. Return a conservative result even
+  // when there is no nul. This is safe since otherwise the string function
+  // being folded, such as strlen, has undefined behavior, and folding is
+  // preferable to making the undefined library call.
   unsigned NullIndex = 0;
   for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) {
     if (Slice.Array->getElementAsInteger(Slice.Offset + NullIndex) == 0)
@@ -4517,13 +4687,40 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V,
   const Operator *Inst = dyn_cast<Operator>(V);
   if (!Inst)
     return false;
+
+  return isSafeToSpeculativelyExecuteWithOpcode(Inst->getOpcode(), Inst, CtxI,
+                                                DT, TLI);
+}
+
+bool llvm::isSafeToSpeculativelyExecuteWithOpcode(unsigned Opcode,
+                                                  const Operator *Inst,
+                                                  const Instruction *CtxI,
+                                                  const DominatorTree *DT,
+                                                  const TargetLibraryInfo *TLI) {
+#ifndef NDEBUG
+  if (Inst->getOpcode() != Opcode) {
+    // Check that the operands are actually compatible with the Opcode
+    // override.
+    auto hasEqualReturnAndLeadingOperandTypes =
+        [](const Operator *Inst, unsigned NumLeadingOperands) {
+          if (Inst->getNumOperands() < NumLeadingOperands)
+            return false;
+          const Type *ExpectedType = Inst->getType();
+          for (unsigned ItOp = 0; ItOp < NumLeadingOperands; ++ItOp)
+            if (Inst->getOperand(ItOp)->getType() != ExpectedType)
+              return false;
+          return true;
+        };
+    assert(!Instruction::isBinaryOp(Opcode) ||
+           hasEqualReturnAndLeadingOperandTypes(Inst, 2));
+    assert(!Instruction::isUnaryOp(Opcode) ||
+           hasEqualReturnAndLeadingOperandTypes(Inst, 1));
+  }
+#endif
 
   for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i)
     if (Constant *C = dyn_cast<Constant>(Inst->getOperand(i)))
       if (C->canTrap())
        return false;
 
-  switch (Inst->getOpcode()) {
+  switch (Opcode) {
   default:
     return true;
   case Instruction::UDiv:
@@ -4554,7 +4751,9 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V,
     return false;
   }
   case Instruction::Load: {
-    const LoadInst *LI = cast<LoadInst>(Inst);
+    const LoadInst *LI = dyn_cast<LoadInst>(Inst);
+    if (!LI)
+      return false;
     if (mustSuppressSpeculation(*LI))
       return false;
     const DataLayout &DL = LI->getModule()->getDataLayout();
@@ -4563,7 +4762,9 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V,
                                               TLI);
   }
   case Instruction::Call: {
-    auto *CI = cast<const CallInst>(Inst);
+    auto *CI = dyn_cast<const CallInst>(Inst);
+    if (!CI)
+      return false;
     const Function *Callee = CI->getCalledFunction();
 
     // The called function could have undefined behavior or side-effects, even
@@ -4595,8 +4796,20 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V,
   }
 }
 
-bool llvm::mayBeMemoryDependent(const Instruction &I) {
-  return I.mayReadOrWriteMemory() || !isSafeToSpeculativelyExecute(&I);
+bool llvm::mayHaveNonDefUseDependency(const Instruction &I) {
+  if (I.mayReadOrWriteMemory())
+    // Memory dependency possible
+    return true;
+  if (!isSafeToSpeculativelyExecute(&I))
+    // Can't move above a may-throw call or an infinite loop, nor, for an
+    // inalloca alloca, above a stacksave call.
+    return true;
+  if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+    // 1) Can't reorder two inf-loop calls, even if readonly
+    // 2) Also can't reorder an inf-loop call below an instruction which isn't
+    //    safe to speculatively execute. (Inverse of above)
+    return true;
+  return false;
 }
 
 /// Convert ConstantRange OverflowResult into ValueTracking OverflowResult.
@@ -4766,6 +4979,22 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
                                                    AssumptionCache *AC,
                                                    const Instruction *CxtI,
                                                    const DominatorTree *DT) {
+  // X - (X % ?)
+  // The remainder of a value can't have greater magnitude than itself,
+  // so the subtraction can't overflow.
+
+  // X - (X -nuw ?)
+  // In the minimal case, this would simplify to "?", so there's no subtract
+  // at all. But if this analysis is used to peek through casts, for example,
+  // then determining no-overflow may allow other transforms.
+
+  // TODO: There are other patterns like this.
+  //       See simplifyICmpWithBinOpOnLHS() for candidates.
+  if (match(RHS, m_URem(m_Specific(LHS), m_Value())) ||
+      match(RHS, m_NUWSub(m_Specific(LHS), m_Value())))
+    if (isGuaranteedNotToBeUndefOrPoison(LHS, AC, CxtI, DT))
+      return OverflowResult::NeverOverflows;
+
   // Checking for conditions implied by dominating conditions may be expensive.
   // Limit it to usub_with_overflow calls for now.
if (match(CxtI, @@ -4789,6 +5018,19 @@ OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { + // X - (X % ?) + // The remainder of a value can't have greater magnitude than itself, + // so the subtraction can't overflow. + + // X - (X -nsw ?) + // In the minimal case, this would simplify to "?", so there's no subtract + // at all. But if this analysis is used to peek through casts, for example, + // then determining no-overflow may allow other transforms. + if (match(RHS, m_SRem(m_Specific(LHS), m_Value())) || + match(RHS, m_NSWSub(m_Specific(LHS), m_Value()))) + if (isGuaranteedNotToBeUndefOrPoison(LHS, AC, CxtI, DT)) + return OverflowResult::NeverOverflows; + // If LHS and RHS each have at least two sign bits, the subtraction // cannot overflow. if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 && @@ -5100,7 +5342,9 @@ static bool isGuaranteedNotToBeUndefOrPoison(const Value *V, } if (auto *I = dyn_cast<LoadInst>(V)) - if (I->getMetadata(LLVMContext::MD_noundef)) + if (I->hasMetadata(LLVMContext::MD_noundef) || + I->hasMetadata(LLVMContext::MD_dereferenceable) || + I->hasMetadata(LLVMContext::MD_dereferenceable_or_null)) return true; if (programUndefinedIfUndefOrPoison(V, PoisonOnly)) @@ -5125,10 +5369,10 @@ static bool isGuaranteedNotToBeUndefOrPoison(const Value *V, auto *TI = Dominator->getBlock()->getTerminator(); Value *Cond = nullptr; - if (auto BI = dyn_cast<BranchInst>(TI)) { + if (auto BI = dyn_cast_or_null<BranchInst>(TI)) { if (BI->isConditional()) Cond = BI->getCondition(); - } else if (auto SI = dyn_cast<SwitchInst>(TI)) { + } else if (auto SI = dyn_cast_or_null<SwitchInst>(TI)) { Cond = SI->getCondition(); } @@ -5763,20 +6007,6 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT) return {SPF_UNKNOWN, SPNB_NA, false}; - // Z = X -nsw Y - // (X >s Y) ? 0 : Z ==> (Z >s 0) ? 0 : Z ==> SMIN(Z, 0) - // (X <s Y) ? 0 : Z ==> (Z <s 0) ? 0 : Z ==> SMAX(Z, 0) - if (match(TrueVal, m_Zero()) && - match(FalseVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) - return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; - - // Z = X -nsw Y - // (X >s Y) ? Z : 0 ==> (Z >s 0) ? Z : 0 ==> SMAX(Z, 0) - // (X <s Y) ? Z : 0 ==> (Z <s 0) ? Z : 0 ==> SMIN(Z, 0) - if (match(FalseVal, m_Zero()) && - match(TrueVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) - return {Pred == CmpInst::ICMP_SGT ? 
SPF_SMAX : SPF_SMIN, SPNB_NA, false}; - const APInt *C1; if (!match(CmpRHS, m_APInt(C1))) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -6576,11 +6806,38 @@ Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, if (LHS == RHS) return LHSIsTrue; - const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS); - if (RHSCmp) + if (const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS)) return isImpliedCondition(LHS, RHSCmp->getPredicate(), RHSCmp->getOperand(0), RHSCmp->getOperand(1), DL, LHSIsTrue, Depth); + + if (Depth == MaxAnalysisRecursionDepth) + return None; + + // LHS ==> (RHS1 || RHS2) if LHS ==> RHS1 or LHS ==> RHS2 + // LHS ==> !(RHS1 && RHS2) if LHS ==> !RHS1 or LHS ==> !RHS2 + const Value *RHS1, *RHS2; + if (match(RHS, m_LogicalOr(m_Value(RHS1), m_Value(RHS2)))) { + if (Optional<bool> Imp = + isImpliedCondition(LHS, RHS1, DL, LHSIsTrue, Depth + 1)) + if (*Imp == true) + return true; + if (Optional<bool> Imp = + isImpliedCondition(LHS, RHS2, DL, LHSIsTrue, Depth + 1)) + if (*Imp == true) + return true; + } + if (match(RHS, m_LogicalAnd(m_Value(RHS1), m_Value(RHS2)))) { + if (Optional<bool> Imp = + isImpliedCondition(LHS, RHS1, DL, LHSIsTrue, Depth + 1)) + if (*Imp == false) + return false; + if (Optional<bool> Imp = + isImpliedCondition(LHS, RHS2, DL, LHSIsTrue, Depth + 1)) + if (*Imp == false) + return false; + } + return None; } @@ -7072,66 +7329,25 @@ getOffsetFromIndex(const GEPOperator *GEP, unsigned Idx, const DataLayout &DL) { Optional<int64_t> llvm::isPointerOffset(const Value *Ptr1, const Value *Ptr2, const DataLayout &DL) { - Ptr1 = Ptr1->stripPointerCasts(); - Ptr2 = Ptr2->stripPointerCasts(); + APInt Offset1(DL.getIndexTypeSizeInBits(Ptr1->getType()), 0); + APInt Offset2(DL.getIndexTypeSizeInBits(Ptr2->getType()), 0); + Ptr1 = Ptr1->stripAndAccumulateConstantOffsets(DL, Offset1, true); + Ptr2 = Ptr2->stripAndAccumulateConstantOffsets(DL, Offset2, true); // Handle the trivial case first. - if (Ptr1 == Ptr2) { - return 0; - } + if (Ptr1 == Ptr2) + return Offset2.getSExtValue() - Offset1.getSExtValue(); const GEPOperator *GEP1 = dyn_cast<GEPOperator>(Ptr1); const GEPOperator *GEP2 = dyn_cast<GEPOperator>(Ptr2); - // If one pointer is a GEP see if the GEP is a constant offset from the base, - // as in "P" and "gep P, 1". - // Also do this iteratively to handle the the following case: - // Ptr_t1 = GEP Ptr1, c1 - // Ptr_t2 = GEP Ptr_t1, c2 - // Ptr2 = GEP Ptr_t2, c3 - // where we will return c1+c2+c3. - // TODO: Handle the case when both Ptr1 and Ptr2 are GEPs of some common base - // -- replace getOffsetFromBase with getOffsetAndBase, check that the bases - // are the same, and return the difference between offsets. - auto getOffsetFromBase = [&DL](const GEPOperator *GEP, - const Value *Ptr) -> Optional<int64_t> { - const GEPOperator *GEP_T = GEP; - int64_t OffsetVal = 0; - bool HasSameBase = false; - while (GEP_T) { - auto Offset = getOffsetFromIndex(GEP_T, 1, DL); - if (!Offset) - return None; - OffsetVal += *Offset; - auto Op0 = GEP_T->getOperand(0)->stripPointerCasts(); - if (Op0 == Ptr) { - HasSameBase = true; - break; - } - GEP_T = dyn_cast<GEPOperator>(Op0); - } - if (!HasSameBase) - return None; - return OffsetVal; - }; - - if (GEP1) { - auto Offset = getOffsetFromBase(GEP1, Ptr2); - if (Offset) - return -*Offset; - } - if (GEP2) { - auto Offset = getOffsetFromBase(GEP2, Ptr1); - if (Offset) - return Offset; - } - // Right now we handle the case when Ptr1/Ptr2 are both GEPs with an identical // base. 
After that base, they may have some number of common (and // potentially variable) indices. After that they handle some constant // offset, which determines their offset from each other. At this point, we // handle no other case. - if (!GEP1 || !GEP2 || GEP1->getOperand(0) != GEP2->getOperand(0)) + if (!GEP1 || !GEP2 || GEP1->getOperand(0) != GEP2->getOperand(0) || + GEP1->getSourceElementType() != GEP2->getSourceElementType()) return None; // Skip any common indices and track the GEP types. @@ -7140,9 +7356,10 @@ Optional<int64_t> llvm::isPointerOffset(const Value *Ptr1, const Value *Ptr2, if (GEP1->getOperand(Idx) != GEP2->getOperand(Idx)) break; - auto Offset1 = getOffsetFromIndex(GEP1, Idx, DL); - auto Offset2 = getOffsetFromIndex(GEP2, Idx, DL); - if (!Offset1 || !Offset2) + auto IOffset1 = getOffsetFromIndex(GEP1, Idx, DL); + auto IOffset2 = getOffsetFromIndex(GEP2, Idx, DL); + if (!IOffset1 || !IOffset2) return None; - return *Offset2 - *Offset1; + return *IOffset2 - *IOffset1 + Offset2.getSExtValue() - + Offset1.getSExtValue(); } diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index 655c248907f6..f863a1ffad3a 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -40,7 +40,7 @@ static cl::opt<unsigned> MaxInterleaveGroupFactor( /// Return true if all of the intrinsic's arguments and return type are scalars /// for the scalar form of the intrinsic, and vectors for the vector form of the /// intrinsic (except operands that are marked as always being scalar by -/// hasVectorInstrinsicScalarOpd). +/// isVectorIntrinsicWithScalarOpAtArg). bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { switch (ID) { case Intrinsic::abs: // Begin integer bit-manipulation. @@ -89,6 +89,8 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { case Intrinsic::fmuladd: case Intrinsic::powi: case Intrinsic::canonicalize: + case Intrinsic::fptosi_sat: + case Intrinsic::fptoui_sat: return true; default: return false; @@ -96,8 +98,8 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { } /// Identifies if the vector form of the intrinsic has a scalar operand. -bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, - unsigned ScalarOpdIdx) { +bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, + unsigned ScalarOpdIdx) { switch (ID) { case Intrinsic::abs: case Intrinsic::ctlz: @@ -114,11 +116,14 @@ bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, } } -bool llvm::hasVectorInstrinsicOverloadedScalarOpd(Intrinsic::ID ID, - unsigned ScalarOpdIdx) { +bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, + unsigned OpdIdx) { switch (ID) { + case Intrinsic::fptosi_sat: + case Intrinsic::fptoui_sat: + return OpdIdx == 0; case Intrinsic::powi: - return (ScalarOpdIdx == 1); + return OpdIdx == 1; default: return false; } @@ -496,6 +501,116 @@ bool llvm::widenShuffleMaskElts(int Scale, ArrayRef<int> Mask, return true; } +void llvm::processShuffleMasks( + ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs, + unsigned NumOfUsedRegs, function_ref<void()> NoInputAction, + function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction, + function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction) { + SmallVector<SmallVector<SmallVector<int>>> Res(NumOfDestRegs); + // Try to perform better estimation of the permutation. + // 1. Split the source/destination vectors into real registers. + // 2. Do the mask analysis to identify which real registers are + // permuted. 
+  int Sz = Mask.size();
+  unsigned SzDest = Sz / NumOfDestRegs;
+  unsigned SzSrc = Sz / NumOfSrcRegs;
+  for (unsigned I = 0; I < NumOfDestRegs; ++I) {
+    auto &RegMasks = Res[I];
+    RegMasks.assign(NumOfSrcRegs, {});
+    // Check which source register the values in this dest register come
+    // from.
+    for (unsigned K = 0; K < SzDest; ++K) {
+      int Idx = I * SzDest + K;
+      if (Idx == Sz)
+        break;
+      if (Mask[Idx] >= Sz || Mask[Idx] == UndefMaskElem)
+        continue;
+      int SrcRegIdx = Mask[Idx] / SzSrc;
+      // Add a cost of PermuteTwoSrc for each new source register permute,
+      // if we have more than one source register.
+      if (RegMasks[SrcRegIdx].empty())
+        RegMasks[SrcRegIdx].assign(SzDest, UndefMaskElem);
+      RegMasks[SrcRegIdx][K] = Mask[Idx] % SzSrc;
+    }
+  }
+  // Process split mask.
+  for (unsigned I = 0; I < NumOfUsedRegs; ++I) {
+    auto &Dest = Res[I];
+    int NumSrcRegs =
+        count_if(Dest, [](ArrayRef<int> Mask) { return !Mask.empty(); });
+    switch (NumSrcRegs) {
+    case 0:
+      // No input vectors were used!
+      NoInputAction();
+      break;
+    case 1: {
+      // Find the single non-empty mask, i.e. the one source register used.
+      auto *It =
+          find_if(Dest, [](ArrayRef<int> Mask) { return !Mask.empty(); });
+      unsigned SrcReg = std::distance(Dest.begin(), It);
+      SingleInputAction(*It, SrcReg, I);
+      break;
+    }
+    default: {
+      // The first mask is a permutation of a single register. Since we have >2
+      // input registers to shuffle, we merge the masks for the first 2
+      // registers and generate a shuffle of 2 registers rather than reordering
+      // the first register and then shuffling it with the second register.
+      // Next, generate the shuffles of the resulting register + the remaining
+      // registers from the list.
+      auto &&CombineMasks = [](MutableArrayRef<int> FirstMask,
+                               ArrayRef<int> SecondMask) {
+        for (int Idx = 0, VF = FirstMask.size(); Idx < VF; ++Idx) {
+          if (SecondMask[Idx] != UndefMaskElem) {
+            assert(FirstMask[Idx] == UndefMaskElem &&
+                   "Expected undefined mask element.");
+            FirstMask[Idx] = SecondMask[Idx] + VF;
+          }
+        }
+      };
+      auto &&NormalizeMask = [](MutableArrayRef<int> Mask) {
+        for (int Idx = 0, VF = Mask.size(); Idx < VF; ++Idx) {
+          if (Mask[Idx] != UndefMaskElem)
+            Mask[Idx] = Idx;
+        }
+      };
+      int SecondIdx;
+      do {
+        int FirstIdx = -1;
+        SecondIdx = -1;
+        MutableArrayRef<int> FirstMask, SecondMask;
+        for (unsigned I = 0; I < NumOfDestRegs; ++I) {
+          SmallVectorImpl<int> &RegMask = Dest[I];
+          if (RegMask.empty())
+            continue;
+
+          if (FirstIdx == SecondIdx) {
+            FirstIdx = I;
+            FirstMask = RegMask;
+            continue;
+          }
+          SecondIdx = I;
+          SecondMask = RegMask;
+          CombineMasks(FirstMask, SecondMask);
+          ManyInputsAction(FirstMask, FirstIdx, SecondIdx);
+          NormalizeMask(FirstMask);
+          RegMask.clear();
+          SecondMask = FirstMask;
+          SecondIdx = FirstIdx;
+        }
+        if (FirstIdx != SecondIdx && SecondIdx >= 0) {
+          CombineMasks(SecondMask, FirstMask);
+          ManyInputsAction(SecondMask, SecondIdx, FirstIdx);
+          Dest[FirstIdx].clear();
+          NormalizeMask(SecondMask);
+        }
+      } while (SecondIdx >= 0);
+      break;
+    }
+    }
+  }
+}
+
 MapVector<Instruction *, uint64_t>
 llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                                const TargetTransformInfo *TTI) {
@@ -543,9 +658,8 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
     Value *Val = Worklist.pop_back_val();
     Value *Leader = ECs.getOrInsertLeaderValue(Val);
 
-    if (Visited.count(Val))
+    if (!Visited.insert(Val).second)
       continue;
-    Visited.insert(Val);
 
     // Non-instructions terminate a chain successfully.
     if (!isa<Instruction>(Val))
@@ -1387,7 +1501,7 @@ void VFABI::getVectorVariantNames(
 #ifndef NDEBUG
     LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << S << "'\n");
     Optional<VFInfo> Info = VFABI::tryDemangleForVFABI(S, *(CI.getModule()));
-    assert(Info.hasValue() && "Invalid name for a VFABI variant.");
+    assert(Info && "Invalid name for a VFABI variant.");
     assert(CI.getModule()->getFunction(Info.getValue().VectorName) &&
            "Vector function is missing.");
 #endif
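As a side note on the computeMinimumValueSizes cleanup above: set insertion already reports whether the element was new, so the separate count()-then-insert() pair collapses into a single call. A minimal standalone C++ sketch of the idiom, not taken from this commit (std::set stands in for LLVM's SmallPtrSet, whose insert() has the same return convention):

#include <cstdio>
#include <set>
#include <vector>

int main() {
  // A worklist with duplicates, as a vectorizer worklist might contain.
  std::vector<int> Worklist = {1, 2, 1, 3, 2};
  std::set<int> Visited;
  for (int Val : Worklist) {
    // insert() returns {iterator, bool}; the bool is false when Val was
    // already present, so one call both tests and records the visit.
    if (!Visited.insert(Val).second)
      continue;
    std::printf("visiting %d\n", Val);
  }
  return 0;
}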